Skip to content

Commit

Permalink
Improve benchmarking
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 17, 2025
1 parent ee8efea commit 9205213
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 32 deletions.
28 changes: 17 additions & 11 deletions benchmarks/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,26 @@ def print_benchmark_results(name: str, timings: list[float]):
print(f"FPS: {1 / np.mean(timings):.2f}")


def main(n_tests: int = 10, sim_steps: int = 10, multi_drone: bool = False):
def main(
n_tests: int = 2,
number: int = 100,
multi_drone: bool = False,
reset: bool = True,
step: bool = True,
):
reset_fn, step_fn = time_sim_reset, time_sim_step
if multi_drone:
reset_fn, step_fn = time_multi_drone_reset, time_multi_drone_step
timings = reset_fn(n_tests=n_tests)
print_benchmark_results(name="Sim reset", timings=timings)
timings = step_fn(n_tests=n_tests, sim_steps=sim_steps)
print_benchmark_results(name="Sim steps", timings=timings / sim_steps)
timings = step_fn(n_tests=n_tests, sim_steps=sim_steps, physics_mode="sys_id")
print_benchmark_results(name="Sim steps (sys_id backend)", timings=timings / sim_steps)
timings = step_fn(n_tests=n_tests, sim_steps=sim_steps, physics_mode="mujoco")
print_benchmark_results(name="Sim steps (mujoco backend)", timings=timings / sim_steps)
timings = step_fn(n_tests=n_tests, sim_steps=sim_steps, physics_mode="sys_id")
print_benchmark_results(name="Sim steps (sys_id backend)", timings=timings / sim_steps)
if reset:
timings = reset_fn(n_tests=n_tests, number=number)
print_benchmark_results(name="Sim reset", timings=timings / number)
if step:
timings = step_fn(n_tests=n_tests, number=number)
print_benchmark_results(name="Sim steps", timings=timings / number)
timings = step_fn(n_tests=n_tests, number=number, physics_mode="sys_id")
print_benchmark_results(name="Sim steps (sys_id backend)", timings=timings / number)
# timings = step_fn(n_tests=n_tests, number=number, physics_mode="mujoco")
# print_benchmark_results(name="Sim steps (mujoco backend)", timings=timings / number)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def main():
config = load_config(Path(__file__).parents[1] / "config/level0.toml")
config = load_config(Path(__file__).parents[1] / "config/level3.toml")
env = gymnasium.make("DroneRacing-v0", config=config)
env.reset()
for _ in range(1_000):
Expand Down
39 changes: 19 additions & 20 deletions benchmarks/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
env.reset()
env.step(env.action_space.sample()) # JIT compile
env.reset()
env.action_space.seed(42)
action = env.action_space.sample()
"""

attitude_env_setup_code = """
Expand All @@ -57,14 +59,16 @@
env.reset()
env.step(env.action_space.sample()) # JIT compile
env.reset()
env.action_space.seed(42)
action = env.action_space.sample()
"""

load_multi_drone_config_code = f"""
from pathlib import Path
from lsy_drone_racing.utils import load_config
config = load_config(Path('{Path(__file__).parents[1] / "config/multi_level0.toml"}'))
config = load_config(Path('{Path(__file__).parents[1] / "config/multi_level3.toml"}'))
"""

multi_drone_env_setup_code = """
Expand All @@ -87,42 +91,39 @@
env.reset()
env.step(env.action_space.sample()) # JIT compile
env.reset()
env.action_space.seed(2)
"""


def time_sim_reset(n_tests: int = 10) -> NDArray[np.floating]:
def time_sim_reset(n_tests: int = 10, number: int = 1) -> NDArray[np.floating]:
setup = load_config_code + env_setup_code
stmt = """env.reset()"""
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=1, repeat=n_tests))
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))


def time_sim_step(
n_tests: int = 10, sim_steps: int = 100, physics_mode: str = "analytical"
n_tests: int = 10, number: int = 1, physics_mode: str = "analytical"
) -> NDArray[np.floating]:
modify_config_code = f"""config.sim.physics = '{physics_mode}'\n"""
setup = load_config_code + modify_config_code + env_setup_code + "\nenv.reset()"
stmt = f"""
for _ in range({sim_steps}):
env.step(env.action_space.sample())"""
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=1, repeat=n_tests))
stmt = """env.step(action)"""
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))


def time_sim_attitude_step(n_tests: int = 10, sim_steps: int = 100) -> NDArray[np.floating]:
def time_sim_attitude_step(n_tests: int = 10, number: int = 1) -> NDArray[np.floating]:
setup = load_config_code + attitude_env_setup_code + "\nenv.reset()"
stmt = f"""
for _ in range({sim_steps}):
env.step(env.action_space.sample())"""
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=1, repeat=n_tests))
stmt = """env.step(action)"""
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))


def time_multi_drone_reset(n_tests: int = 10) -> NDArray[np.floating]:
def time_multi_drone_reset(n_tests: int = 10, number: int = 1) -> NDArray[np.floating]:
setup = load_multi_drone_config_code + multi_drone_env_setup_code + "\nenv.reset()"
stmt = """env.reset()"""
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=1, repeat=n_tests))
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))


def time_multi_drone_step(
n_tests: int = 10, sim_steps: int = 100, physics_mode: str = "analytical"
n_tests: int = 10, number: int = 100, physics_mode: str = "analytical"
) -> NDArray[np.floating]:
modify_config_code = f"""config.sim.physics = '{physics_mode}'\n"""
setup = (
Expand All @@ -131,7 +132,5 @@ def time_multi_drone_step(
+ multi_drone_env_setup_code
+ "\nenv.reset()"
)
stmt = f"""
for _ in range({sim_steps}):
env.step(env.action_space.sample())"""
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=1, repeat=n_tests))
stmt = """env.step(env.action_space.sample())"""
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))

0 comments on commit 9205213

Please sign in to comment.