Module easyagents.env

This module contains support classes and methods to interact with OpenAI gym environments, as well as a gym env implementation for unit tests (lineworld).

see https://github.com/openai/gym

View Source
"""This module contains support classes and methods to interact with OpenAI gym environments,

    as well as a gym env implementation for unit tests (lineworld).

    see https://github.com/openai/gym

"""

import gym, gym.envs, gym.error, gym.spaces

import inspect

import math

import matplotlib as plt

import numpy as np

from typing import List, Optional, Tuple, Dict, Callable, Any

def _is_registered_with_gym(gym_env_name: str) -> bool:
    """Tells whether gym knows an environment registered under the id gym_env_name.

    Args:
        gym_env_name: the gym id to look up.

    Returns:
        True if gym has a registration for gym_env_name, False if gym reports it as unregistered.
    """
    try:
        found_spec = gym.envs.registration.spec(gym_env_name)
    except gym.error.UnregisteredEnv:
        return False
    # a successful lookup is expected to always yield a spec object
    assert found_spec is not None
    return True

# noinspection DuplicatedCode
def register_with_gym(gym_env_name: str, entry_point: type, max_episode_steps: int = 100000, **kwargs):
    """Registers the class entry_point in gym by the name gym_env_name, allowing overriding registrations.

    Thus different implementations of the same class (and the same name) may be registered consecutively.
    The latest registered version is used for instantiation.
    This facilitates developing an environment in a jupyter notebook without having to
    re-register a modified class under a new name.

    Limitation: the max_episode_steps value of the first registration holds for all registrations
        with the same gym_env_name (gym itself is only called on the first registration of a name).

    Args:
        gym_env_name: the gym environment name to be used as argument with gym.make
        entry_point: the class to be registered with gym id gym_env_name; must be a subclass of gym.Env
        max_episode_steps: all episodes end at the latest after this number of steps
        kwargs: the args passed to the entry_point constructor call

    Raises:
        AssertionError: if gym_env_name is not a non-empty str, or entry_point is not a subclass of gym.Env
    """
    assert gym_env_name is not None, "None is not an admissible environment name"
    assert isinstance(gym_env_name, str), "gym_env_name is not a str"
    assert len(gym_env_name) > 0, "empty string is not an admissible environment name"
    assert inspect.isclass(entry_point), "entry_point not a class"
    # classes are always callable, so the subclass check is the only remaining requirement
    assert issubclass(entry_point, gym.Env), "entry_point not a subclass of gym.Env"

    # gym only ever sees _ShimEnv as the entry point; _ShimEnv looks up the currently stored
    # (entry_point, kwargs) pair by name when it is constructed. Hence gym needs to be told only once per name,
    # while later registrations merely replace the stored pair.
    if gym_env_name not in _ShimEnv._entry_points:
        gym.envs.registration.register(id=gym_env_name,
                                       entry_point=_ShimEnv,
                                       max_episode_steps=max_episode_steps,
                                       kwargs={_ShimEnv._KWARG_GYM_NAME: gym_env_name})
    _ShimEnv._entry_points[gym_env_name] = (entry_point, kwargs)

class _LineWorldEnv(gym.Env):
    """Simple environment for fast unit tests, registered as 'LineWorld-v0'.

        * an agent lives in a finite linear world; each position holds a reward
        * at each moment it is in a certain position
        * initial position is the middle (index floor(len(world) / 2))
        * some positions gain rewards, some don't; a reward is collected only once
        * with the default world, rewards are between 0 and 15
        * agent can either move left (action 0) or right (action 1)
        * objective: maximize total reward = sum(collected rewards) - number of steps
        * cost per step: -1
        * done condition: agent is at pos 0 or total reward <= -20
    """

    @staticmethod
    def register_with_gym():
        """Registers this environment with gym and returns the gym environment name."""
        result = "LineWorld-v0"
        register_with_gym(result, _LineWorldEnv)
        return result

    def __init__(self, world: Optional[List[int]] = None):
        """Creates the lineworld; its size and rewards are given by the world arg.

        Args:
            world: list of rewards to collect in each position of the lineworld.
                Defaults to [10, 0, 0, 5, 0, 2, 15].
        """
        if world is None:
            world = [10, 0, 0, 5, 0, 2, 15]
        # len() instead of truthiness, so that a numpy array is accepted as world, too
        assert len(world) > 0, "world must not be None or empty."
        self.world: np.ndarray = np.array(world)
        number_of_actions: int = 2
        self.action_space: gym.spaces.Discrete = gym.spaces.Discrete(number_of_actions)
        self.size_of_world: int = len(world)
        self.max_reward: int = max(world)
        self.min_reward: int = min(world)
        # the environment's current state is described by the position of the agent and the remaining rewards
        self.observation_size: int = 1 + self.size_of_world
        # The bounds must cover all observation components: the position (0 .. size_of_world - 1)
        # and the remaining rewards, which drop to 0 once collected. For the default world this
        # yields the same bounds as before (0 .. 15).
        low_bound = min(0, self.min_reward)
        high_bound = max(self.max_reward, self.size_of_world - 1)
        low: np.ndarray = np.full(self.observation_size, low_bound)
        high: np.ndarray = np.full(self.observation_size, high_bound)
        self.observation_space: gym.spaces.Box = gym.spaces.Box(low=low, high=high, dtype=np.float32)
        self.reward_range: Tuple[int, int] = (self.min_reward, self.max_reward)
        self.steps: int = 0
        self.done: bool = False
        self.pos: int = 0
        self._figure = None
        self.reset()

    def get_observation(self):
        """Returns the agent's position followed by the remaining rewards as a 1-dim numpy array."""
        return np.append([self.pos], self.remaining_rewards)

    def reset(self):
        """Restores all rewards, moves the agent to the middle and returns the initial observation."""
        self.total_reward = 0
        self.done = False
        self.pos = math.floor(len(self.world) / 2)
        self.steps = 0
        self.remaining_rewards = np.array(self.world, copy=True)
        return self.get_observation()

    def step(self, action):
        """Performs action on this lineworld.

        Args:
            action: 0 ==> move left, 1 ==> move right. A python int, any numpy integer scalar
                or a numpy array holding exactly one element.

        Returns:
            the tuple (observation, reward, done, info)
        """
        if isinstance(action, np.ndarray):
            assert action.size == 1, "action of type numpy.array as invalid size"
            # item() extracts the single element; int() on a >0-dim array is deprecated in numpy
            action = int(action.item())
        if isinstance(action, np.integer):
            # covers np.int64 (the default integer dtype on most platforms), not only np.int32
            action = int(action)
        assert isinstance(action, int)

        if action <= 0 and self.pos > 0:
            self.pos -= 1
        if action > 0 and self.pos < self.size_of_world - 1:
            self.pos += 1

        # collect the reward at the new position (only once) and pay the step cost of 1
        reward = self.remaining_rewards[self.pos] - 1
        self.total_reward += reward
        self.remaining_rewards[self.pos] = 0
        self.done = (self.pos == 0) or (self.total_reward <= -20)
        self.steps += 1
        observation = self.get_observation()
        # gym expects info to be a dict; wrappers such as TimeLimit write keys into it
        info = {}
        return observation, reward, self.done, info

    def _render_to_ansi(self):
        """Returns a one-line textual description of the current state."""
        return f'position: {self.pos}, remaining rewards: {self.remaining_rewards},' + \
               f'total reward so far: {self.total_reward}, steps so far: {self.steps}, game done: {self.done}'

    def _render_to_figure(self):
        """Renders the current state as a graph with matplotlib and returns the figure."""
        # subplots / close live in matplotlib.pyplot, not in the top-level matplotlib package
        import matplotlib.pyplot as pyplot

        if self._figure is not None:
            pyplot.close(self._figure)
        self._figure, ax = pyplot.subplots(1, figsize=(8, 4))
        ax.set_ylim(bottom=-1, top=self.max_reward + 1)
        # plain integer positions: a uint8 range would wrap around for worlds with more than 255 positions
        x = np.arange(0, self.size_of_world, 1)
        y = self.remaining_rewards
        ax.plot([self.pos, self.pos], [0, 2], 'r^-')
        ax.scatter(x, y, s=75)
        self._figure.canvas.draw()
        return self._figure

    def _render_to_rgb(self):
        """Converts the output of _render_to_figure to a uint8 rgb array of shape (rows, cols, 3)."""
        import matplotlib.pyplot as pyplot

        figure = self._render_to_figure()
        figure.canvas.draw()
        # buffer_rgba() replaces the deprecated canvas.tostring_rgb() and yields an (rows, cols, 4) buffer;
        # np.array(...) copies the rgb channels so the result stays valid after the figure is closed.
        rgba = np.asarray(figure.canvas.buffer_rgba())
        result = np.array(rgba[:, :, :3], dtype=np.uint8)
        pyplot.close(figure)
        self._figure = None
        return result

    def render(self, mode='ansi'):
        """Renders the current state.

        Args:
            mode: 'ansi' returns a text description, 'human' a matplotlib figure,
                'rgb_array' a uint8 numpy array of shape (rows, cols, 3);
                any other mode is delegated to gym.Env.render.
        """
        if mode == 'ansi':
            return self._render_to_ansi()
        elif mode == 'human':
            return self._render_to_figure()
        elif mode == 'rgb_array':
            return self._render_to_rgb()
        else:
            return super().render(mode=mode)

class _ShimEnv(gym.Wrapper):
    """Wrapper that redirects the instantiation of a gym environment to its currently registered implementation.

    The gym name to instantiate is passed in the constructor kwargs under _KWARG_GYM_NAME. The implementation
    and its constructor kwargs are looked up in the class-level _entry_points table at construction time.
    """

    _KWARG_GYM_NAME = "shimenv_gym_name"

    # gym env name -> (entry point class, kwargs for its constructor) of the latest registration
    _entry_points: Dict[str, Tuple[Callable, Dict]] = {}

    def __init__(self, **kwargs):
        """Creates the implementation registered under kwargs[_KWARG_GYM_NAME] and wraps it."""
        assert _ShimEnv._KWARG_GYM_NAME in kwargs, f'{_ShimEnv._KWARG_GYM_NAME} missing from kwargs'
        env_name = kwargs[_ShimEnv._KWARG_GYM_NAME]
        self._gym_env_name = env_name
        factory, factory_kwargs = _ShimEnv._entry_points[env_name]
        self._gym_env = factory(**factory_kwargs)
        super().__init__(self._gym_env)

    def step(self, action):
        """Forwards action to the wrapped environment and hands back its result unchanged."""
        wrapped = self.env
        return wrapped.step(action)

    def reset(self, **kwargs):
        """Resets the wrapped environment, passing all kwargs through."""
        wrapped = self.env
        return wrapped.reset(**kwargs)

class _StepCountEnv(gym.core.Env):
    """Debug Env that runs forever, counting the calls to reset and step.

    The counters are class attributes and are therefore shared by all instances; use clear() to reset them.
    Observations are one-element lists holding the current step count. Every step yields reward 1
    and never signals done.
    """

    metadata = {'render.modes': ['ansi']}
    reward_range = (0, 1)
    max = 10 ** 7
    action_space = gym.spaces.discrete.Discrete(2)
    observation_space = gym.spaces.Box(low=0, high=max, shape=(1,))
    step_count: int = 0
    reset_count: int = 0

    @staticmethod
    def register_with_gym():
        """Registers this environment with gym and returns the gym environment name."""
        result = "_StepCountEnv-v0"
        register_with_gym(result, _StepCountEnv)
        return result

    @staticmethod
    def clear():
        """Resets the shared step and reset counters to 0."""
        _StepCountEnv.step_count = 0
        _StepCountEnv.reset_count = 0

    def __str__(self):
        return f'reset_count={_StepCountEnv.reset_count} step_count={_StepCountEnv.step_count}'

    def step(self, action):
        """Counts the call and returns (observation, reward, done, info); the action is ignored."""
        _StepCountEnv.step_count += 1
        # info must be a dict: gym's TimeLimit wrapper (applied via max_episode_steps) writes
        # 'TimeLimit.truncated' into it once the step limit is reached
        return [_StepCountEnv.step_count], 1, False, {}

    def reset(self):
        """Counts the call and returns the current step count as observation."""
        _StepCountEnv.reset_count += 1
        return [_StepCountEnv.step_count]

    def render(self, mode='ansi'):
        """Returns the counter summary produced by __str__, regardless of mode."""
        return str(self)

Functions

register_with_gym

def register_with_gym(
    gym_env_name: str,
    entry_point: type,
    max_episode_steps: int = 100000,
    **kwargs
)

Registers the class entry_point in gym by the name gym_env_name allowing overriding registrations.

Thus different implementations of the same class (and the same name) may be registered consecutively. The latest registered version is used for instantiation. This facilitates developing an environment in a jupyter notebook without having to re-register a modified class under a new name.

Limitation: the max_episode_steps value of the first registration holds for all registrations with the same gym_env_name.

Args: gym_env_name: the gym environment name to be used as argument with gym.make; max_episode_steps: all episodes end at the latest after this number of steps; entry_point: the class to be registered with gym id gym_env_name; kwargs: the args passed to the entry_point constructor call.

View Source
def register_with_gym(gym_env_name: str, entry_point: type, max_episode_steps: int = 100000, **kwargs):
    """Registers the class entry_point in gym by the name gym_env_name, allowing overriding registrations.

    Thus different implementations of the same class (and the same name) may be registered consecutively.
    The latest registered version is used for instantiation.
    This facilitates developing an environment in a jupyter notebook without having to
    re-register a modified class under a new name.

    Limitation: the max_episode_steps value of the first registration holds for all registrations
        with the same gym_env_name (gym itself is only called on the first registration of a name).

    Args:
        gym_env_name: the gym environment name to be used as argument with gym.make
        entry_point: the class to be registered with gym id gym_env_name; must be a subclass of gym.Env
        max_episode_steps: all episodes end at the latest after this number of steps
        kwargs: the args passed to the entry_point constructor call

    Raises:
        AssertionError: if gym_env_name is not a non-empty str, or entry_point is not a subclass of gym.Env
    """
    assert gym_env_name is not None, "None is not an admissible environment name"
    assert isinstance(gym_env_name, str), "gym_env_name is not a str"
    assert len(gym_env_name) > 0, "empty string is not an admissible environment name"
    assert inspect.isclass(entry_point), "entry_point not a class"
    # classes are always callable, so the subclass check is the only remaining requirement
    assert issubclass(entry_point, gym.Env), "entry_point not a subclass of gym.Env"

    # gym only ever sees _ShimEnv as the entry point; _ShimEnv looks up the currently stored
    # (entry_point, kwargs) pair by name when it is constructed. Hence gym needs to be told only once per name,
    # while later registrations merely replace the stored pair.
    if gym_env_name not in _ShimEnv._entry_points:
        gym.envs.registration.register(id=gym_env_name,
                                       entry_point=_ShimEnv,
                                       max_episode_steps=max_episode_steps,
                                       kwargs={_ShimEnv._KWARG_GYM_NAME: gym_env_name})
    _ShimEnv._entry_points[gym_env_name] = (entry_point, kwargs)