ale-py
Advanced tools
@@ -0,1 +1,3 @@ | ||
| """Python module for interacting with ALE c++ interface and gymnasium wrapper.""" | ||
| import os | ||
@@ -2,0 +4,0 @@ import platform |
@@ -31,2 +31,3 @@ import os | ||
| """ | ||
| @property | ||
@@ -58,2 +59,3 @@ def value(self) -> int: | ||
| """ | ||
| @property | ||
@@ -107,5 +109,5 @@ def value(self) -> int: | ||
| @overload | ||
| def act(self, action: Action) -> int: ... | ||
| def act(self, action: Action, paddle_strength: float = 1.0) -> int: ... | ||
| @overload | ||
| def act(self, action: int) -> int: ... | ||
| def act(self, action: int, paddle_strength: float = 1.0) -> int: ... | ||
| def cloneState(self, *, include_rng: bool = False) -> ALEState: ... | ||
@@ -112,0 +114,0 @@ def cloneSystemState(self) -> ALEState: ... |
+154
-46
@@ -0,5 +1,9 @@ | ||
| """Gymnasium wrapper around the Arcade Learning Environment (ALE).""" | ||
| from __future__ import annotations | ||
| import sys | ||
| from functools import lru_cache | ||
| from typing import Any, Literal | ||
| from warnings import warn | ||
@@ -20,2 +24,4 @@ import ale_py | ||
| class AtariEnvStepMetadata(TypedDict): | ||
| """Step info options.""" | ||
| lives: int | ||
@@ -28,6 +34,3 @@ episode_frame_number: int | ||
| class AtariEnv(gymnasium.Env, utils.EzPickle): | ||
| """ | ||
| (A)rcade (L)earning (Gym) (Env)ironment. | ||
| A Gym wrapper around the Arcade Learning Environment (ALE). | ||
| """ | ||
| """Gymnasium wrapper around the Arcade Learning Environment (ALE).""" | ||
@@ -46,7 +49,9 @@ # FPS can differ per ROM, therefore, dynamically collect the fps once the game is loaded | ||
| full_action_space: bool = False, | ||
| continuous: bool = False, | ||
| continuous_action_threshold: float = 0.5, | ||
| max_num_frames_per_episode: int | None = None, | ||
| render_mode: Literal["human", "rgb_array"] | None = None, | ||
| ): | ||
| """ | ||
| Initialize the ALE for Gymnasium. | ||
| """Initialize the ALE for Gymnasium. | ||
| Default parameters are taken from Machado et al., 2018. | ||
@@ -64,2 +69,4 @@ | ||
| full_action_space: bool => Use full action space? | ||
| continuous: bool => Use continuous actions? | ||
| continuous_action_threshold: float => threshold used for continuous actions. | ||
 | max_num_frames_per_episode: int => Maximum number of frames per episode. | ||
@@ -124,2 +131,4 @@ Once `max_num_frames_per_episode` is reached the episode is | ||
| full_action_space=full_action_space, | ||
| continuous=continuous, | ||
| continuous_action_threshold=continuous_action_threshold, | ||
| max_num_frames_per_episode=max_num_frames_per_episode, | ||
@@ -157,10 +166,28 @@ render_mode=render_mode, | ||
| # initialize action space | ||
| # get the set of legal actions | ||
| if continuous and not full_action_space: | ||
| warn( | ||
| "`continuous` is set to `True`, but `full_action_space` is set to `False`. " | ||
| "This will error out when the continuous actions are discretized to illegal action spaces. " | ||
| "Therefore, `full_action_space` has been automatically set to `True`." | ||
| ) | ||
| self._action_set = ( | ||
| self.ale.getLegalActionSet() | ||
| if full_action_space | ||
| if (full_action_space or continuous) | ||
| else self.ale.getMinimalActionSet() | ||
| ) | ||
| self.action_space = spaces.Discrete(len(self._action_set)) | ||
| # action space | ||
| self.continuous = continuous | ||
| self.continuous_action_threshold = continuous_action_threshold | ||
| if continuous: | ||
| # Actions are radius, theta, and fire, where first two are the | ||
| # parameters of polar coordinates. | ||
| self.action_space = spaces.Box( | ||
| np.array([0.0, -np.pi, 0.0]).astype(np.float32), | ||
| np.array([1.0, np.pi, 1.0]).astype(np.float32), | ||
| ) # radius, theta, fire. First two are polar coordinates. | ||
| else: | ||
| self.action_space = spaces.Discrete(len(self._action_set)) | ||
| # initialize observation space | ||
@@ -202,3 +229,3 @@ if self._obs_type == "ram": | ||
| def reset( # pyright: ignore[reportIncompatibleMethodOverride] | ||
| def reset( | ||
| self, | ||
@@ -209,3 +236,3 @@ *, | ||
| ) -> tuple[np.ndarray, AtariEnvStepMetadata]: | ||
| """Resets environment and returns initial observation.""" | ||
| """Resets environment and returns initial episode observation.""" | ||
| super().reset(seed=seed, options=options) | ||
@@ -231,9 +258,10 @@ | ||
| self, | ||
| action: int, | ||
| action: int | np.ndarray, | ||
| ) -> tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]: | ||
| """ | ||
| Perform one agent step, i.e., repeats `action` frameskip # of steps. | ||
| """Perform one agent step, i.e., repeats `action` frameskip # of steps. | ||
| Args: | ||
| action_ind: int => Action index to execute | ||
| action: int | np.ndarray => | ||
| if `continuous=False` -> action index to execute | ||
| if `continuous=True` -> numpy array of r, theta, fire | ||
@@ -244,4 +272,3 @@ Returns: | ||
| Note: `metadata` contains the keys "lives" and "rgb" if | ||
| render_mode == 'rgb_array'. | ||
 | Note: `metadata` contains the key "lives". | ||
| """ | ||
@@ -257,6 +284,29 @@ # If frameskip is a length 2 tuple then it's stochastic | ||
| # action formatting | ||
| if self.continuous: | ||
| # compute the x, y, fire of the joystick | ||
| assert isinstance(action, np.ndarray) | ||
| x, y = action[0] * np.cos(action[1]), action[0] * np.sin(action[1]) | ||
| action_idx = self.map_action_idx( | ||
| left_center_right=( | ||
| -int(x < self.continuous_action_threshold) | ||
| + int(x > self.continuous_action_threshold) | ||
| ), | ||
| down_center_up=( | ||
| -int(y < self.continuous_action_threshold) | ||
| + int(y > self.continuous_action_threshold) | ||
| ), | ||
| fire=(action[-1] > self.continuous_action_threshold), | ||
| ) | ||
| strength = action[0] | ||
| else: | ||
| action_idx = self._action_set[action] | ||
| strength = 1.0 | ||
| # Frameskip | ||
| reward = 0.0 | ||
| for _ in range(frameskip): | ||
| reward += self.ale.act(self._action_set[action]) | ||
| reward += self.ale.act(action_idx, strength) | ||
| is_terminal = self.ale.game_over(with_truncation=False) | ||
@@ -268,11 +318,3 @@ is_truncated = self.ale.game_truncated() | ||
| def render(self) -> np.ndarray | None: | ||
| """ | ||
| Render is not supported by ALE. We use a paradigm similar to | ||
| Gym3 which allows you to specify `render_mode` during construction. | ||
| For example, | ||
| gym.make("ale-py:Pong-v0", render_mode="human") | ||
| will display the ALE and maintain the proper interval to match the | ||
| FPS target set by the ROM. | ||
| """ | ||
| """Renders the ALE with `rgb_array` and `human` options.""" | ||
| if self.render_mode == "rgb_array": | ||
@@ -289,6 +331,3 @@ return self.ale.getScreenRGB() | ||
| def _get_obs(self) -> np.ndarray: | ||
| """ | ||
| Retrieves the current observation. | ||
| This is dependent on `self._obs_type`. | ||
| """ | ||
| """Retrieves the current observation using `obs_type`.""" | ||
| if self._obs_type == "ram": | ||
@@ -301,3 +340,5 @@ return self.ale.getRAM() | ||
| else: | ||
| raise error.Error(f"Unrecognized observation type: {self._obs_type}") | ||
| raise error.Error( | ||
| f"Unrecognized observation type: {self._obs_type}, expected: 'ram', 'rgb' and 'grayscale'." | ||
| ) | ||
@@ -311,6 +352,12 @@ def _get_info(self) -> AtariEnvStepMetadata: | ||
| @lru_cache(1) | ||
| def get_keys_to_action(self) -> dict[tuple[int, ...], ale_py.Action]: | ||
| """Return keymapping -> actions for human play. | ||
| Up, down, left and right are wasd keys with fire being space. | ||
| No op is 'e' | ||
| Returns: | ||
| Dictionary of key values to actions | ||
| """ | ||
| Return keymapping -> actions for human play. | ||
| """ | ||
| UP = ord("w") | ||
@@ -321,5 +368,6 @@ LEFT = ord("a") | ||
| FIRE = ord(" ") | ||
| NOOP = ord("e") | ||
| mapping = { | ||
| ale_py.Action.NOOP: (None,), | ||
| ale_py.Action.NOOP: (NOOP,), | ||
| ale_py.Action.UP: (UP,), | ||
@@ -348,13 +396,66 @@ ale_py.Action.FIRE: (FIRE,), | ||
| # | ||
| return dict( | ||
| zip( | ||
| map(lambda action: tuple(sorted(mapping[action])), self._action_set), | ||
| range(len(self._action_set)), | ||
| return { | ||
| tuple(sorted(mapping[act_idx])): act_idx for act_idx in self._action_set | ||
| } | ||
| @lru_cache(18) | ||
| def map_action_idx( | ||
| self, left_center_right: int, down_center_up: int, fire: bool | ||
| ) -> int: | ||
| """Return an action idx given unit actions for underlying env.""" | ||
| # no op and fire | ||
| if left_center_right == 0 and down_center_up == 0 and not fire: | ||
| return ale_py.Action.NOOP | ||
| elif left_center_right == 0 and down_center_up == 0 and fire: | ||
| return ale_py.Action.FIRE | ||
| # cardinal no fire | ||
| elif left_center_right == -1 and down_center_up == 0 and not fire: | ||
| return ale_py.Action.LEFT | ||
| elif left_center_right == 1 and down_center_up == 0 and not fire: | ||
| return ale_py.Action.RIGHT | ||
| elif left_center_right == 0 and down_center_up == -1 and not fire: | ||
| return ale_py.Action.DOWN | ||
| elif left_center_right == 0 and down_center_up == 1 and not fire: | ||
| return ale_py.Action.UP | ||
| # cardinal fire | ||
| if left_center_right == -1 and down_center_up == 0 and fire: | ||
| return ale_py.Action.LEFTFIRE | ||
| elif left_center_right == 1 and down_center_up == 0 and fire: | ||
| return ale_py.Action.RIGHTFIRE | ||
| elif left_center_right == 0 and down_center_up == -1 and fire: | ||
| return ale_py.Action.DOWNFIRE | ||
| elif left_center_right == 0 and down_center_up == 1 and fire: | ||
| return ale_py.Action.UPFIRE | ||
| # diagonal no fire | ||
| elif left_center_right == -1 and down_center_up == -1 and not fire: | ||
| return ale_py.Action.DOWNLEFT | ||
| elif left_center_right == 1 and down_center_up == -1 and not fire: | ||
| return ale_py.Action.DOWNRIGHT | ||
| elif left_center_right == -1 and down_center_up == 1 and not fire: | ||
| return ale_py.Action.UPLEFT | ||
| elif left_center_right == 1 and down_center_up == 1 and not fire: | ||
| return ale_py.Action.UPRIGHT | ||
| # diagonal fire | ||
| elif left_center_right == -1 and down_center_up == -1 and fire: | ||
| return ale_py.Action.DOWNLEFTFIRE | ||
| elif left_center_right == 1 and down_center_up == -1 and fire: | ||
| return ale_py.Action.DOWNRIGHTFIRE | ||
| elif left_center_right == -1 and down_center_up == 1 and fire: | ||
| return ale_py.Action.UPLEFTFIRE | ||
| elif left_center_right == 1 and down_center_up == 1 and fire: | ||
| return ale_py.Action.UPRIGHTFIRE | ||
| # just in case | ||
| else: | ||
| raise LookupError( | ||
| "Unexpected action mapping, expected `left_center_right` and `down_center_up` to be in {-1, 0, 1} and `fire` to only be `True` or `False`. " | ||
| f"Received {left_center_right=}, {down_center_up=} and {fire=}." | ||
| ) | ||
| ) | ||
| def get_action_meanings(self) -> list[str]: | ||
| """ | ||
| Return the meaning of each integer action. | ||
| """ | ||
| """Return the meaning of each action.""" | ||
| keys = ale_py.Action.__members__.values() | ||
@@ -366,5 +467,12 @@ values = ale_py.Action.__members__.keys() | ||
| def clone_state(self, include_rng: bool = False) -> ale_py.ALEState: | ||
| """Clone emulator state w/o system state. Restoring this state will | ||
| *not* give an identical environment. For complete cloning and restoring | ||
| of the full state, see `{clone,restore}_full_state()`.""" | ||
| """Clone emulator state. | ||
| To reproduce identical states, specify `include_rng` to `True`. | ||
| Args: | ||
| include_rng: If to include the system RNG within the state | ||
| Returns: | ||
| The cloned ALE state | ||
| """ | ||
| return self.ale.cloneState(include_rng=include_rng) | ||
@@ -371,0 +479,0 @@ |
@@ -0,1 +1,3 @@ | ||
| """Registration for Atari environments.""" | ||
| from __future__ import annotations | ||
@@ -11,2 +13,4 @@ | ||
| class EnvFlavour(NamedTuple): | ||
| """Environment flavour for env id suffix and kwargs.""" | ||
| suffix: str | ||
@@ -17,2 +21,4 @@ kwargs: Mapping[str, Any] | Callable[[str], Mapping[str, Any]] | ||
| class EnvConfig(NamedTuple): | ||
| """Environment config for version, kwargs and flavours.""" | ||
| version: str | ||
@@ -24,10 +30,5 @@ kwargs: Mapping[str, Any] | ||
| def _rom_id_to_name(rom: str) -> str: | ||
| """ | ||
| Let the ROM ID be the ROM identifier in snake_case. | ||
| For example, `space_invaders` | ||
| The ROM name is the ROM ID in pascalcase. | ||
| For example, `SpaceInvaders` | ||
| """Converts the Rom ID (snake_case) to ROM name in PascalCase. | ||
| This function converts the ROM ID to the ROM name. | ||
| i.e., snakecase -> pascalcase | ||
| For example, `space_invaders` to `SpaceInvaders` | ||
| """ | ||
@@ -78,2 +79,3 @@ return rom.title().replace("_", "") | ||
| def register_v0_v4_envs(): | ||
| """Registers all v0 and v4 environments.""" | ||
| legacy_games = [ | ||
@@ -184,2 +186,3 @@ "adventure", | ||
| def register_v5_envs(): | ||
| """Register all v5 environments.""" | ||
| all_games = roms.get_all_rom_ids() | ||
@@ -186,0 +189,0 @@ obs_types = ["rgb", "ram"] |
@@ -0,1 +1,3 @@ | ||
| """Rom module with functions for collecting individual and all ROMS files.""" | ||
| from __future__ import annotations | ||
@@ -2,0 +4,0 @@ |
+3
-3
| Metadata-Version: 2.1 | ||
| Name: ale-py | ||
| Version: 0.9.1 | ||
| Version: 0.10.0 | ||
| Summary: The Arcade Learning Environment (ALE) - a platform for AI research. | ||
@@ -26,3 +26,3 @@ Author: Marc G. Bellemare, Yavar Naddaf, Joel Veness, Michael Bowling | ||
| License-File: LICENSE.md | ||
| Requires-Dist: numpy | ||
| Requires-Dist: numpy <2.0 | ||
| Requires-Dist: importlib-metadata >=4.10.0 ; python_version < "3.10" | ||
@@ -37,3 +37,3 @@ Requires-Dist: typing-extensions ; python_version < "3.11" | ||
| <a href="#the-arcade-learning-environment"> | ||
| <img alt="Arcade Learning Environment" align="right" width=75 src="https://raw.githubusercontent.com/Farama-Foundation/Arcade-Learning-Environment/master/docs/static/ale.svg" /> | ||
| <img alt="Arcade Learning Environment" align="right" width=75 src="https://github.com/Farama-Foundation/Arcade-Learning-Environment/blob/master/docs/_static/img/ale.svg" /> | ||
| </a> | ||
@@ -40,0 +40,0 @@ =============================== |
+11
-11
@@ -1,6 +0,6 @@ | ||
| ale_py/registration.py,sha256=ZDsxO9E3geLnMJqCVuQE20C_0WTyyuoLoimELNCcGKc,5680 | ||
| ale_py/env.py,sha256=dy8xt9LaIV9y8rlNbvWQEQbgpp94oh3-2VwkVsOwib4,13603 | ||
| ale_py/__init__.pyi,sha256=NhS3zU9f3PULxCmT0HaBzCbF5kO7jmiX0kUO7k2A0DI,6711 | ||
| ale_py/_ale_py.cpython-310-darwin.so,sha256=kmAroj2txNEZ9he2fzvU-TAXFuuynW5HV13pNgT8Shg,2112264 | ||
| ale_py/__init__.py,sha256=JrUg_5FJoyEsV-I3gN_t-dHOnWnmFNoo8E4jg10VFxk,2028 | ||
| ale_py/registration.py,sha256=dQZET9V89CzTLhy-YM_D7aZzkYAtpg5wm2JD1iIKMAw,5779 | ||
| ale_py/env.py,sha256=nCLwE8svW58exSGkjCiAgK1vsfmWVB5Xd9nClgW4sO8,18386 | ||
| ale_py/__init__.pyi,sha256=5F6AAGzOv82CFCOn9F5nmgEcavo_Qk6F2THpps9irj4,6773 | ||
| ale_py/_ale_py.cpython-310-darwin.so,sha256=GPhmpyU9M9xh7eLxddipomUv9T-b0p7pljMqBXjYok0,2113848 | ||
| ale_py/__init__.py,sha256=46VRo5hRwmIRnvv-wY6CeTf7j3QMq7r-cJCpAizQAZA,2111 | ||
| ale_py/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 | ||
@@ -54,3 +54,3 @@ ale_py/libSDL2-2.0.dylib,sha256=KPp35ygdbwzrvaWj-F68-ceb2dNXhIL1dExeFmCKEIY,1489704 | ||
| ale_py/roms/othello.bin,sha256=Tj2eWSwPVRM3eUQAH3jye7oUyG4lgr_QvkxvHbo6xs8,2048 | ||
| ale_py/roms/__init__.py,sha256=bxjR-rdRbp5fyU6xmnyewPR96WbLo5ewrmu0fyr5Ayk,2102 | ||
| ale_py/roms/__init__.py,sha256=NlSx4chmhvCuXnsd_ljP5UDUUBh6LBTEpRTbjWMrEDM,2181 | ||
| ale_py/roms/atlantis.bin,sha256=zvS3uR6iRfeCdgShxiClyn-UKIwsa1TZCpMd6O61In0,4096 | ||
@@ -119,6 +119,6 @@ ale_py/roms/tic_tac_toe_3d.bin,sha256=EaSGzMNtgdQ12X1QMxzrvn1i48_N1YL-1uXFt8xeBgc,2048 | ||
| ale_py/roms/superman.bin,sha256=d9I_JpEhZmJyjrYwugk2OE6pSoHUEEdZLSf-4SIQsjM,4096 | ||
| ale_py-0.9.1.dist-info/LICENSE.md,sha256=MVVce6oSrBFQvJpI-22FtV2c6hM2rjr6k9ZT9o_km7I,17879 | ||
| ale_py-0.9.1.dist-info/RECORD,, | ||
| ale_py-0.9.1.dist-info/WHEEL,sha256=iRoY_TYTYgfo7UDBnTsHzZnvy1iAOsBao9G-bA7mwCk,111 | ||
| ale_py-0.9.1.dist-info/top_level.txt,sha256=CjHTcYlCUfmSaCGbLLpKEKIY5hZteuawsHZy3TWVyNk,7 | ||
| ale_py-0.9.1.dist-info/METADATA,sha256=v3yxzhuHZqzIRC3L9EG3ftydBjSQvbQhXaVownSrCWo,7592 | ||
| ale_py-0.10.0.dist-info/LICENSE.md,sha256=MVVce6oSrBFQvJpI-22FtV2c6hM2rjr6k9ZT9o_km7I,17879 | ||
| ale_py-0.10.0.dist-info/RECORD,, | ||
| ale_py-0.10.0.dist-info/WHEEL,sha256=TJezrjTCs8ObXv15Jt93b2-wgiyMkFb14ApgSyz8vFs,111 | ||
| ale_py-0.10.0.dist-info/top_level.txt,sha256=CjHTcYlCUfmSaCGbLLpKEKIY5hZteuawsHZy3TWVyNk,7 | ||
| ale_py-0.10.0.dist-info/METADATA,sha256=n5sZXl9SpnrR9bbEiEcSrIuCHF2b4k4WAWzFPCiAQbA,7593 |
+1
-1
| Wheel-Version: 1.0 | ||
| Generator: setuptools (72.1.0) | ||
| Generator: setuptools (75.1.0) | ||
| Root-Is-Purelib: false | ||
| Tag: cp310-cp310-macosx_10_15_x86_64 | ||
Sorry, the diff of this file is too big to display
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.