Skip to content

Commit

Permalink
Fix name collisions between core env and gym envs. Improve benchmarking
Browse files: browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 22, 2025
1 parent 53d8b36 commit 86f40de
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 55 deletions.
31 changes: 22 additions & 9 deletions benchmarks/main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,12 +5,12 @@
from sim import time_multi_drone_reset, time_multi_drone_step, time_sim_reset, time_sim_step


def print_benchmark_results(name: str, timings: list[float]):
print(f"\nResults for {name}:")
def print_benchmark_results(name: str, timings: list[float], n_envs: int, device: str):
print(f"\nResults for {name} ({n_envs} envs, {device}):")
print(f"Mean/std: {np.mean(timings):.2e}s +- {np.std(timings):.2e}s")
print(f"Min time: {np.min(timings):.2e}s")
print(f"Max time: {np.max(timings):.2e}s")
print(f"FPS: {1 / np.mean(timings):.2f}")
print(f"FPS: {n_envs / np.mean(timings):.2f}")


def main(
Expand All @@ -19,18 +19,31 @@ def main(
multi_drone: bool = False,
reset: bool = True,
step: bool = True,
vec_size: int = 1,
device: str = "cpu",
):
reset_fn, step_fn = time_sim_reset, time_sim_step
if multi_drone:
reset_fn, step_fn = time_multi_drone_reset, time_multi_drone_step
if reset:
timings = reset_fn(n_tests=n_tests, number=number)
print_benchmark_results(name="Racing env reset", timings=timings / number)
timings = reset_fn(n_tests=n_tests, number=number, n_envs=vec_size, device=device)
print_benchmark_results(
name="Racing env reset", timings=timings / number, n_envs=vec_size, device=device
)
if step:
timings = step_fn(n_tests=n_tests, number=number)
print_benchmark_results(name="Racing env steps", timings=timings / number)
timings = step_fn(n_tests=n_tests, number=number, physics_mode="sys_id")
print_benchmark_results(name="Racing env steps (sys_id backend)", timings=timings / number)
timings = step_fn(n_tests=n_tests, number=number, n_envs=vec_size, device=device)
print_benchmark_results(
name="Racing env steps", timings=timings / number, n_envs=vec_size, device=device
)
timings = step_fn(
n_tests=n_tests, number=number, physics_mode="sys_id", n_envs=vec_size, device=device
)
print_benchmark_results(
name="Racing env steps (sys_id backend)",
timings=timings / number,
n_envs=vec_size,
device=device,
)
# timings = step_fn(n_tests=n_tests, number=number, physics_mode="mujoco")
# print_benchmark_results(name="Sim steps (mujoco backend)", timings=timings / number)

Expand Down
99 changes: 66 additions & 33 deletions benchmarks/sim.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,11 +19,13 @@

env_setup_code = """
import gymnasium
import jax
import lsy_drone_racing
env = gymnasium.make(
env = gymnasium.make_vec(
config.env.id,
num_envs={num_envs},
freq=config.env.freq,
sim_config=config.sim,
sensor_range=config.env.sensor_range,
Expand All @@ -32,21 +34,29 @@
randomizations=config.env.get("randomizations"),
random_resets=config.env.random_resets,
seed=config.env.seed,
device='{device}',
)
# JIT compile the reset and step functions
env.reset()
env.step(env.action_space.sample()) # JIT compile
env.reset()
env.action_space.seed(42)
action = env.action_space.sample()
env.step(env.action_space.sample())
jax.block_until_ready(env.unwrapped.data)
# JIT masked reset (used in autoreset)
mask = env.unwrapped.data.marked_for_reset
mask = mask.at[0].set(True)
env.unwrapped._reset(mask=mask) # enforce masked reset compile
jax.block_until_ready(env.unwrapped.data)
env.action_space.seed(2)
"""

attitude_env_setup_code = """
import gymnasium
import jax
import lsy_drone_racing
env = gymnasium.make('DroneRacingAttitude-v0',
config.env.id,
env = gymnasium.make_vec('DroneRacingAttitude-v0',
num_envs={num_envs},
freq=config.env.freq,
sim_config=config.sim,
sensor_range=config.env.sensor_range,
Expand All @@ -55,12 +65,19 @@
randomizations=config.env.get("randomizations"),
random_resets=config.env.random_resets,
seed=config.env.seed,
device='{device}',
)
# JIT compile the reset and step functions
env.reset()
env.step(env.action_space.sample()) # JIT compile
env.reset()
env.action_space.seed(42)
action = env.action_space.sample()
env.step(env.action_space.sample())
jax.block_until_ready(env.unwrapped.data)
# JIT masked reset (used in autoreset)
mask = env.unwrapped.data.marked_for_reset
mask = mask.at[0].set(True)
env.unwrapped._reset(mask=mask) # enforce masked reset compile
jax.block_until_ready(env.unwrapped.data)
env.action_space.seed(2)
"""

load_multi_drone_config_code = f"""
Expand All @@ -76,9 +93,11 @@
import jax
import lsy_drone_racing
from lsy_drone_racing.envs.multi_drone_race import MultiDroneRaceEnv
from lsy_drone_racing.envs.multi_drone_race import VecMultiDroneRaceEnv
env = gymnasium.make('MultiDroneRacing-v0',
env = gymnasium.make_vec('MultiDroneRacing-v0',
num_envs={num_envs},
n_drones=config.env.n_drones,
freq=config.env.freq,
sim_config=config.sim,
Expand All @@ -88,58 +107,72 @@
randomizations=config.env.get("randomizations"),
random_resets=config.env.random_resets,
seed=config.env.seed,
device='cpu',
device='{device}',
)
# JIT compile the reset and step functions
env.reset()
# JIT step
env.step(env.action_space.sample())
jax.block_until_ready(env.unwrapped.data)
# JIT masked reset (used in autoreset)
mask = env.unwrapped.data.marked_for_reset
mask = mask.at[0].set(True)
super(MultiDroneRaceEnv, env.unwrapped).reset(mask=mask) # enforce masked reset compile
env.unwrapped._reset(mask=mask) # enforce masked reset compile
jax.block_until_ready(env.unwrapped.data)
env.action_space.seed(2)
"""


def time_sim_reset(n_tests: int = 10, number: int = 1) -> NDArray[np.floating]:
setup = load_config_code + env_setup_code
def time_sim_reset(
n_tests: int = 10, number: int = 1, n_envs: int = 1, device: str = "cpu"
) -> NDArray[np.floating]:
setup = load_config_code + env_setup_code.format(num_envs=n_envs, device=device)
stmt = """env.reset()"""
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))


def time_sim_step(
n_tests: int = 10, number: int = 1, physics_mode: str = "analytical"
n_tests: int = 10,
number: int = 1,
physics_mode: str = "analytical",
n_envs: int = 1,
device: str = "cpu",
) -> NDArray[np.floating]:
modify_config_code = f"""config.sim.physics = '{physics_mode}'\n"""
setup = load_config_code + modify_config_code + env_setup_code + "\nenv.reset()"
stmt = """env.step(action)"""
_env_setup_code = env_setup_code.format(num_envs=n_envs, device=device)
setup = load_config_code + modify_config_code + _env_setup_code + "\nenv.reset()"
stmt = """env.step(env.action_space.sample())"""
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))


def time_sim_attitude_step(n_tests: int = 10, number: int = 1) -> NDArray[np.floating]:
setup = load_config_code + attitude_env_setup_code + "\nenv.reset()"
stmt = """env.step(action)"""
def time_sim_attitude_step(
n_tests: int = 10, number: int = 1, n_envs: int = 1, device: str = "cpu"
) -> NDArray[np.floating]:
env_setup_code = attitude_env_setup_code.format(num_envs=n_envs, device=device)
setup = load_config_code + env_setup_code + "\nenv.reset()"
stmt = """env.step(env.action_space.sample())"""
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))


def time_multi_drone_reset(n_tests: int = 10, number: int = 1) -> NDArray[np.floating]:
setup = load_multi_drone_config_code + multi_drone_env_setup_code + "\nenv.reset()"
def time_multi_drone_reset(
n_tests: int = 10, number: int = 1, n_envs: int = 1, device: str = "cpu"
) -> NDArray[np.floating]:
env_setup_code = multi_drone_env_setup_code.format(num_envs=n_envs, device=device)
setup = load_multi_drone_config_code + env_setup_code + "\nenv.reset()"
stmt = """env.reset()"""
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))


def time_multi_drone_step(
n_tests: int = 10, number: int = 100, physics_mode: str = "analytical"
n_tests: int = 10,
number: int = 100,
physics_mode: str = "analytical",
n_envs: int = 1,
device: str = "cpu",
) -> NDArray[np.floating]:
modify_config_code = f"""config.sim.physics = '{physics_mode}'\n"""
setup = (
load_multi_drone_config_code
+ modify_config_code
+ multi_drone_env_setup_code
+ "\nenv.reset()"
)
env_setup_code = multi_drone_env_setup_code.format(num_envs=n_envs, device=device)

setup = load_multi_drone_config_code + modify_config_code + env_setup_code + "\nenv.reset()"
stmt = """env.step(env.action_space.sample())"""
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))
8 changes: 4 additions & 4 deletions lsy_drone_racing/envs/drone_race.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -78,7 +78,7 @@ def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[d
Returns:
The initial observation and info.
"""
obs, info = super().reset(seed=seed, options=options)
obs, info = self._reset(seed=seed, options=options)
obs = {k: v[0, 0] for k, v in obs.items()}
info = {k: v[0, 0] for k, v in info.items()}
return obs, info
Expand All @@ -92,7 +92,7 @@ def step(self, action: NDArray[np.floating]) -> tuple[dict, float, bool, bool, d
Returns:
Observation, reward, terminated, truncated, and info.
"""
obs, reward, terminated, truncated, info = super().step(action)
obs, reward, terminated, truncated, info = self._step(action)
obs = {k: v[0, 0] for k, v in obs.items()}
info = {k: v[0, 0] for k, v in info.items()}
return obs, float(reward[0, 0]), bool(terminated[0, 0]), bool(truncated[0, 0]), info
Expand Down Expand Up @@ -163,7 +163,7 @@ def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[d
Returns:
The initial observation and info.
"""
obs, info = super().reset(seed=seed, options=options)
obs, info = self._reset(seed=seed, options=options)
obs = {k: v[:, 0] for k, v in obs.items()}
info = {k: v[:, 0] for k, v in info.items()}
return obs, info
Expand All @@ -179,7 +179,7 @@ def step(
Returns:
Observation, reward, terminated, truncated, and info.
"""
obs, reward, terminated, truncated, info = super().step(action)
obs, reward, terminated, truncated, info = self._step(action)
obs = {k: v[:, 0] for k, v in obs.items()}
info = {k: v[:, 0] for k, v in info.items()}
return obs, reward[:, 0], terminated[:, 0], truncated[:, 0], info
26 changes: 24 additions & 2 deletions lsy_drone_racing/envs/multi_drone_race.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -82,7 +82,7 @@ def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[d
Returns:
Observation and info for all drones.
"""
obs, info = super().reset(seed=seed, options=options)
obs, info = self._reset(seed=seed, options=options)
obs = {k: v[0] for k, v in obs.items()}
info = {k: v[0] for k, v in info.items()}
return obs, info
Expand All @@ -98,7 +98,7 @@ def step(
Returns:
Observation, reward, terminated, truncated, and info for all drones.
"""
obs, reward, terminated, truncated, info = super().step(action)
obs, reward, terminated, truncated, info = self._step(action)
obs = {k: v[0] for k, v in obs.items()}
info = {k: v[0] for k, v in info.items()}
return obs, reward[0], terminated[0], truncated[0], info
Expand Down Expand Up @@ -165,3 +165,25 @@ def __init__(
build_observation_space(n_gates, n_obstacles), n_drones
)
self.observation_space = batch_space(batch_space(self.single_observation_space), num_envs)

def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[dict, dict]:
"""Reset the environment for all drones.
Args:
seed: Random seed.
options: Additional reset options. Not used.
Returns:
Observation and info for all drones.
"""
return self._reset(seed=seed, options=options)

def step(
self, action: NDArray[np.floating]
) -> tuple[dict, NDArray[np.floating], NDArray[np.bool_], NDArray[np.bool_], dict]:
"""Step the environment for all drones.
Args:
action: Action for all drones, i.e., a batch of (n_drones, action_dim) arrays.
"""
return self._step(action)
16 changes: 9 additions & 7 deletions lsy_drone_racing/envs/race_core.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -270,7 +270,7 @@ def __init__(
self.device,
)

def reset(
def _reset(
self, *, seed: int | None = None, options: dict | None = None, mask: Array | None = None
) -> tuple[dict[str, NDArray[np.floating]], dict]:
"""Reset the environment.
Expand All @@ -292,10 +292,10 @@ def reset(
# Randomization of gates, obstacles and drones is compiled into the sim reset function with
# the sim.reset_hook function, so we don't need to explicitly do it here
self.sim.reset(mask=mask)
self.data = self._reset(self.data, self.sim.data.states.pos, mask)
self.data = self._reset_env_data(self.data, self.sim.data.states.pos, mask)
return self.obs(), self.info()

def step(
def _step(
self, action: NDArray[np.floating]
) -> tuple[dict[str, NDArray[np.floating]], float, bool, bool, dict]:
"""Step the firmware_wrapper class and its environment.
Expand All @@ -321,10 +321,12 @@ def step(
# previous flags, not the ones from the current step
marked_for_reset = self.data.marked_for_reset
# Apply the environment logic with updated simulation data.
self.data = self._step(self.data, drone_pos, drone_quat, mocap_pos, mocap_quat, contacts)
self.data = self._step_env(
self.data, drone_pos, drone_quat, mocap_pos, mocap_quat, contacts
)
# Auto-reset envs. Add configuration option to disable for single-world envs
if self.autoreset and marked_for_reset.any():
self.reset(mask=marked_for_reset)
self._reset(mask=marked_for_reset)
return self.obs(), self.reward(), self.terminated(), self.truncated(), self.info()

def apply_action(self, action: NDArray[np.floating]):
Expand Down Expand Up @@ -422,7 +424,7 @@ def symbolic_model(self) -> SymbolicModel:

@staticmethod
@jax.jit
def _reset(data: EnvData, drone_pos: Array, mask: Array | None = None) -> EnvData:
def _reset_env_data(data: EnvData, drone_pos: Array, mask: Array | None = None) -> EnvData:
"""Reset auxiliary variables of the environment data."""
mask = jp.ones(data.steps.shape, dtype=bool) if mask is None else mask
target_gate = jp.where(mask[..., None], 0, data.target_gate)
Expand All @@ -443,7 +445,7 @@ def _reset(data: EnvData, drone_pos: Array, mask: Array | None = None) -> EnvDat

@staticmethod
@jax.jit
def _step(
def _step_env(
data: EnvData,
drone_pos: Array,
drone_quat: Array,
Expand Down

0 comments on commit 86f40de

Please sign in to comment.