Skip to content

Commit

Permalink
[wip,broken] Accelerate core env
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 19, 2025
1 parent 1c845e8 commit 20cbbce
Show file tree
Hide file tree
Showing 4 changed files with 391 additions and 1,004 deletions.
16 changes: 12 additions & 4 deletions benchmarks/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@

multi_drone_env_setup_code = """
import gymnasium
import jax
import lsy_drone_racing
env = gymnasium.make('MultiDroneRacing-v0',
n_envs=1000, # TODO: Remove this for single-world envs
n_envs=1, # TODO: Remove this for single-world envs
n_drones=config.env.n_drones,
freq=config.env.freq,
sim_config=config.sim,
Expand All @@ -87,11 +88,18 @@
randomizations=config.env.get("randomizations"),
random_resets=config.env.random_resets,
seed=config.env.seed,
device='gpu',
device='cpu',
)
env.reset()
env.step(env.action_space.sample()) # JIT compile
env.reset()
# JIT step
env.step(env.action_space.sample())
jax.block_until_ready(env.unwrapped.data)
# JIT masked reset (used in autoreset)
mask = env.unwrapped.data.marked_for_reset
mask = mask.at[0].set(True)
env.unwrapped.reset(mask=mask)
jax.block_until_ready(env.unwrapped.data)
env.action_space.seed(2)
"""

Expand Down
Loading

0 comments on commit 20cbbce

Please sign in to comment.