Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? `gymnax` brings the power of `jit` and `vmap`/`pmap` to the classic gym API. It supports a range of different environments, including classic control, bsuite, MinAtar and a collection of classic/meta RL tasks. `gymnax` allows explicit functional control of environment settings (random seed or hyperparameters), which enables accelerated & parallelized rollouts for different configurations (e.g. for meta RL). By executing both environment and policy on the accelerator, it facilitates the Anakin sub-architecture proposed in the Podracer paper (Hessel et al., 2021) and highly distributed evolutionary optimization (using e.g. `evosax`). We provide training pipelines & checkpoints for both PPO & ES in `gymnax-blines`. Get started here 👉.
* All displayed speeds are estimated for 1M step transitions (random policy) on an NVIDIA A100 GPU using `jit`-compiled episode rollouts with 2000 environment workers. For more detailed speed comparisons on different accelerators (CPU, RTX 2080Ti) and MLP policies, please refer to the `gymnax-blines` documentation.
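As a quick orientation, here is a minimal sketch of the core functional API that the feature snippets below build on. It assumes the registered `CartPole-v1` environment and standard JAX imports:

import jax
import jax.numpy as jnp
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Instantiate the environment and its default settings/hyperparameters.
env, env_params = gymnax.make("CartPole-v1")

# Fully functional reset/step: PRNG key, state and params are passed explicitly.
obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_act)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)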
If you want to get the most recent commit, please install directly from the repository:
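A sketch of that install command, assuming the package lives in the RobertTLange/gymnax GitHub repository with a `main` branch:

pip install git+https://github.com/RobertTLange/gymnax.git@main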
- Environment vectorization & acceleration: Easy composition of JAX primitives (e.g. `jit`, `vmap`, `pmap`); a short usage sketch follows the snippet below:
jit_step = jax.jit(env.step)  # Jit-accelerated step transition
# vmap across random keys for batched rollouts
reset_rng = jax.vmap(env.reset, in_axes=(0, None))
step_rng = jax.vmap(env.step, in_axes=(0, 0, 0, None))
# vmap across environment parameters (e.g. for meta-learning)
reset_params = jax.vmap(env.reset, in_axes=(None, 0))
step_params = jax.vmap(env.step, in_axes=(None, 0, 0, 0))
For speed comparisons with standard vectorized NumPy environments, check out `gymnax-blines`.
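As referenced above, a brief usage sketch of the vmapped functions, batching `reset`/`step` over a set of PRNG keys (the batch size of 8 and the all-zero actions are arbitrary placeholders):

# One PRNG key per parallel environment.
vmap_keys = jax.random.split(rng, 8)

# Batched reset/step: outputs gain a leading batch dimension of 8.
obs, state = reset_rng(vmap_keys, env_params)
actions = jnp.zeros(8, dtype=jnp.int32)
n_obs, n_state, reward, done, _ = step_rng(vmap_keys, state, actions, env_params)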
- Scan through entire episode rollouts: You can also `lax.scan` through entire `reset`/`step` episode loops for fast compilation; a `jit`/`vmap` usage note follows the snippet below:
def rollout(rng_input, policy_params, env_params, steps_in_episode):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment.
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)

    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""
        obs, state, policy_params, rng = state_input
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        action = model.apply(policy_params, obs)
        next_obs, next_state, reward, done, _ = env.step(
            rng_step, state, action, env_params
        )
        carry = [next_obs, next_state, policy_params, rng]
        return carry, [obs, action, reward, next_obs, done]

    # Scan over the episode step loop.
    _, scan_out = jax.lax.scan(
        policy_step,
        [obs, state, policy_params, rng_episode],
        (),
        steps_in_episode,
    )
    # Unpack the stacked trajectory (each output has a leading time dimension).
    obs, action, reward, next_obs, done = scan_out
    return obs, action, reward, next_obs, done
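As noted above, the `rollout` function itself composes with `jit` and `vmap`. A sketch, assuming `model` and `policy_params` are defined; the episode length must be static because it sets the `lax.scan` length, and the batch of 16 keys and 200 steps are arbitrary choices:

# Vectorize over episode PRNG keys, then jit with a static episode length.
batch_rollout = jax.jit(
    jax.vmap(rollout, in_axes=(0, None, None, None)),
    static_argnums=(3,),
)
batch_keys = jax.random.split(rng, 16)
obs, action, reward, next_obs, done = batch_rollout(
    batch_keys, policy_params, env_params, 200
)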
- Built-in visualization tools: You can also smoothly generate GIF animations using the `Visualizer` tool, which covers all `classic_control`, `MinAtar` and most `misc` environments:
from gymnax.visualize import Visualizer

state_seq, reward_seq = [], []
rng, rng_reset = jax.random.split(rng)
obs, env_state = env.reset(rng_reset, env_params)
while True:
    state_seq.append(env_state)
    rng, rng_act, rng_step = jax.random.split(rng, 3)
    # Sample a random action and step the environment.
    action = env.action_space(env_params).sample(rng_act)
    next_obs, next_env_state, reward, done, info = env.step(
        rng_step, env_state, action, env_params
    )
    reward_seq.append(reward)
    if done:
        break
    else:
        obs = next_obs
        env_state = next_env_state

cum_rewards = jnp.cumsum(jnp.array(reward_seq))
vis = Visualizer(env, env_params, state_seq, cum_rewards)
vis.animate("docs/anim.gif")
- Training pipelines & pretrained agents: Check out `gymnax-blines` for trained agents, expert rollout visualizations and PPO/ES pipelines. The agents are minimally tuned, but can help you get up and running.
- Simple batch agent evaluation (work-in-progress):
from gymnax.experimental import RolloutWrapper

# Define the rollout manager for the Pendulum environment.
manager = RolloutWrapper(model.apply, env_name="Pendulum-v1")

# Simple single-episode rollout for the policy.
obs, action, reward, next_obs, done, cum_ret = manager.single_rollout(rng, policy_params)

# Multiple rollouts for the same network (different rng, e.g. for evaluation).
rng_batch = jax.random.split(rng, 10)
obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout(
    rng_batch, policy_params
)

# Multiple rollouts for different networks + rng (e.g. for ES).
batch_params = jax.tree_util.tree_map(
    lambda x: jnp.tile(x, (5, 1)).reshape(5, *x.shape), policy_params
)
obs, action, reward, next_obs, done, cum_ret = manager.population_rollout(
    rng_batch, batch_params
)