Skip to content

Commit

Permalink
Implement vectorized, multi-drone core race env.
Browse files Browse the repository at this point in the history
Add specialized environments based off of RaceCoreEnv. Switch from rpy_rates to ang_vel
  • Loading branch information
amacati committed Jan 20, 2025
1 parent 20cbbce commit 493e8f0
Show file tree
Hide file tree
Showing 15 changed files with 368 additions and 1,466 deletions.
4 changes: 2 additions & 2 deletions benchmarks/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@
import jax
import lsy_drone_racing
from lsy_drone_racing.envs.multi_drone_race import MultiDroneRaceEnv
env = gymnasium.make('MultiDroneRacing-v0',
n_envs=1, # TODO: Remove this for single-world envs
n_drones=config.env.n_drones,
freq=config.env.freq,
sim_config=config.sim,
Expand All @@ -98,7 +98,7 @@
# JIT masked reset (used in autoreset)
mask = env.unwrapped.data.marked_for_reset
mask = mask.at[0].set(True)
env.unwrapped.reset(mask=mask)
super(MultiDroneRaceEnv, env.unwrapped).reset(mask=mask) # enforce masked reset compile
jax.block_until_ready(env.unwrapped.data)
env.action_space.seed(2)
"""
Expand Down
2 changes: 1 addition & 1 deletion config/level0.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,4 @@ pos = [0.0, 1.0, 1.4]
pos = [1.0, 1.0, 0.05]
rpy = [0, 0, 0]
vel = [0, 0, 0]
rpy_rates = [0, 0, 0]
ang_vel = [0, 0, 0]
2 changes: 1 addition & 1 deletion config/level1.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pos = [0.0, 1.0, 1.4]
pos = [1.0, 1.0, 0.05]
rpy = [0, 0, 0]
vel = [0, 0, 0]
rpy_rates = [0, 0, 0]
ang_vel = [0, 0, 0]

[env.disturbances.action]
fn = "normal"
Expand Down
2 changes: 1 addition & 1 deletion config/level2.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pos = [0.0, 1.0, 1.4]
pos = [1.0, 1.0, 0.05]
rpy = [0, 0, 0]
vel = [0, 0, 0]
rpy_rates = [0, 0, 0]
ang_vel = [0, 0, 0]

[env.disturbances.action]
fn = "normal"
Expand Down
2 changes: 1 addition & 1 deletion config/level3.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pos = [-0.5, 0.0, 1.4]
pos = [1.0, 1.0, 0.07]
rpy = [0, 0, 0]
vel = [0, 0, 0]
rpy_rates = [0, 0, 0]
ang_vel = [0, 0, 0]

[env.disturbances.action]
fn = "normal"
Expand Down
2 changes: 1 addition & 1 deletion config/multi_level0.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ pos = [0.0, 1.0, 1.4]
pos = [[1.0, 1.0, 0.05], [1.2, 1.0, 0.05]]
rpy = [[0, 0, 0], [0, 0, 0]]
vel = [[0, 0, 0], [0, 0, 0]]
rpy_rates = [[0, 0, 0], [0, 0, 0]]
ang_vel = [[0, 0, 0], [0, 0, 0]]
2 changes: 1 addition & 1 deletion config/multi_level3.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pos = [0.0, 1.0, 1.4]
pos = [[1.0, 1.0, 0.05], [1.2, 1.0, 0.05]]
rpy = [[0, 0, 0], [0, 0, 0]]
vel = [[0, 0, 0], [0, 0, 0]]
rpy_rates = [[0, 0, 0], [0, 0, 0]]
ang_vel = [[0, 0, 0], [0, 0, 0]]

[env.disturbances.action]
fn = "normal"
Expand Down
30 changes: 16 additions & 14 deletions lsy_drone_racing/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,44 +15,46 @@

from gymnasium import register

from lsy_drone_racing.envs.drone_racing_env import DroneRacingEnv

__all__ = ["DroneRacingEnv"]
# region SingleDroneEnvs

register(
id="DroneRacing-v0",
entry_point="lsy_drone_racing.envs.drone_racing_env:DroneRacingEnv",
entry_point="lsy_drone_racing.envs.drone_race:DroneRaceEnv",
vector_entry_point="lsy_drone_racing.envs.drone_race:VecDroneRaceEnv",
max_episode_steps=1800, # 30 seconds * 60 Hz,
disable_env_checker=True, # Remove warnings about 2D observations
)

register(
id="DroneRacingAttitude-v0",
entry_point="lsy_drone_racing.envs.drone_racing_env:DroneRacingAttitudeEnv",
entry_point="lsy_drone_racing.envs.drone_race:DroneRaceAttitudeEnv",
vector_entry_point="lsy_drone_racing.envs.drone_race:VecDroneRaceAttitudeEnv",
max_episode_steps=1800,
disable_env_checker=True,
)

# region MultiDroneEnvs

register(
id="DroneRacingDeploy-v0",
entry_point="lsy_drone_racing.envs.drone_racing_deploy_env:DroneRacingDeployEnv",
id="MultiDroneRacing-v0",
entry_point="lsy_drone_racing.envs.multi_drone_race:MultiDroneRaceEnv",
vector_entry_point="lsy_drone_racing.envs.multi_drone_race:VecMultiDroneRaceEnv",
max_episode_steps=1800,
disable_env_checker=True,
)

# region DeployEnvs

register(
id="DroneRacingAttitudeDeploy-v0",
entry_point="lsy_drone_racing.envs.drone_racing_deploy_env:DroneRacingAttitudeDeployEnv",
id="DroneRacingDeploy-v0",
entry_point="lsy_drone_racing.envs.drone_racing_deploy_env:DroneRacingDeployEnv",
max_episode_steps=1800,
disable_env_checker=True,
)

# region MultiEnvs

# TODO: Register specialized, non-vectorized envs for single worlds
register(
id="MultiDroneRacing-v0",
entry_point="lsy_drone_racing.envs.vec_drone_race:VectorMultiDroneRaceEnv",
id="DroneRacingAttitudeDeploy-v0",
entry_point="lsy_drone_racing.envs.drone_racing_deploy_env:DroneRacingAttitudeDeployEnv",
max_episode_steps=1800,
disable_env_checker=True,
)
110 changes: 110 additions & 0 deletions lsy_drone_racing/envs/drone_race.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Single drone racing environments."""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

from gymnasium import Env
from gymnasium.vector import VectorEnv
from gymnasium.vector.utils import batch_space

from lsy_drone_racing.envs.race_core import RaceCoreEnv, action_space, observation_space

if TYPE_CHECKING:
import numpy as np
from ml_collections import ConfigDict
from numpy.typing import NDArray


class DroneRaceEnv(RaceCoreEnv, Env):
    """Single-drone racing environment for a single simulated world.

    Thin wrapper around ``RaceCoreEnv`` that pins ``n_envs=1`` and ``n_drones=1`` and removes
    the two leading dimensions (indexed with ``[0, 0]``) from everything the core returns, so
    callers work with un-batched observations, infos, rewards and flags.
    """

    def __init__(
        self,
        freq: int,
        sim_config: ConfigDict,
        sensor_range: float,
        track: ConfigDict | None = None,
        disturbances: ConfigDict | None = None,
        randomizations: ConfigDict | None = None,
        random_resets: bool = False,
        seed: int = 1337,
        max_episode_steps: int = 1500,
        device: Literal["cpu", "gpu"] = "cpu",
    ):
        """Create the core environment and define the un-batched action/observation spaces.

        All arguments are forwarded unchanged to ``RaceCoreEnv.__init__``; their exact
        semantics are defined there.

        Args:
            freq: Environment frequency setting.
            sim_config: Simulation configuration.
            sensor_range: Sensor range setting.
            track: Track configuration. Must expose ``gates`` and ``obstacles`` because their
                lengths size the observation space. Although the signature defaults to
                ``None``, a track is required in practice.
            disturbances: Optional disturbance configuration.
            randomizations: Optional randomization configuration.
            random_resets: Random reset flag.
            seed: Seed passed to the core environment.
            max_episode_steps: Episode step limit passed to the core environment.
            device: Device the core environment should use.

        Raises:
            ValueError: If ``track`` is ``None``.
        """
        super().__init__(
            n_envs=1,
            n_drones=1,
            freq=freq,
            sim_config=sim_config,
            sensor_range=sensor_range,
            track=track,
            disturbances=disturbances,
            randomizations=randomizations,
            random_resets=random_resets,
            seed=seed,
            max_episode_steps=max_episode_steps,
            device=device,
        )
        if track is None:
            # Fail with a clear message instead of an AttributeError on ``track.gates`` below.
            raise ValueError(
                "track must be provided: its gates and obstacles define the observation space."
            )
        self.action_space = action_space("state")
        n_gates, n_obstacles = len(track.gates), len(track.obstacles)
        self.observation_space = observation_space(n_gates, n_obstacles)
        # NOTE(review): presumably read by RaceCoreEnv to decide whether finished worlds are
        # reset automatically; confirm against race_core.
        self.autoreset = False

    def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[dict, dict]:
        """Reset the environment and return un-batched observations and info.

        Args:
            seed: Forwarded to ``RaceCoreEnv.reset``.
            options: Forwarded to ``RaceCoreEnv.reset``.

        Returns:
            The observation and info dictionaries of the core environment, with every value
            indexed at ``[0, 0]``.
        """
        obs, info = super().reset(seed=seed, options=options)
        obs = {k: v[0, 0] for k, v in obs.items()}
        info = {k: v[0, 0] for k, v in info.items()}
        return obs, info

    def step(self, action: NDArray[np.floating]) -> tuple[dict, float, bool, bool, dict]:
        """Advance the environment by one step and return un-batched results.

        Args:
            action: Action passed unchanged to ``RaceCoreEnv.step``.

        Returns:
            Observation, reward, terminated, truncated and info of the core environment, each
            indexed at ``[0, 0]`` (dictionary values are indexed individually).
        """
        obs, reward, terminated, truncated, info = super().step(action)
        obs = {k: v[0, 0] for k, v in obs.items()}
        info = {k: v[0, 0] for k, v in info.items()}
        return obs, reward[0, 0], terminated[0, 0], truncated[0, 0], info


class VecDroneRaceEnv(RaceCoreEnv, VectorEnv):
    """Vectorized single-drone racing environment.

    Runs ``num_envs`` worlds with one drone each on top of ``RaceCoreEnv`` and removes the
    second dimension (indexed with ``[:, 0]``) from everything the core returns, so outputs
    keep only their leading dimension.
    """

    def __init__(
        self,
        num_envs: int,
        freq: int,
        sim_config: ConfigDict,
        sensor_range: float,
        track: ConfigDict | None = None,
        disturbances: ConfigDict | None = None,
        randomizations: ConfigDict | None = None,
        random_resets: bool = False,
        seed: int = 1337,
        max_episode_steps: int = 1500,
        device: Literal["cpu", "gpu"] = "cpu",
    ):
        """Create the core environment and define the single and batched spaces.

        Apart from ``num_envs`` (passed to the core as ``n_envs``), all arguments are forwarded
        unchanged to ``RaceCoreEnv.__init__``; their exact semantics are defined there.

        Args:
            num_envs: Number of worlds; also the batch size of the batched spaces.
            freq: Environment frequency setting.
            sim_config: Simulation configuration.
            sensor_range: Sensor range setting.
            track: Track configuration. Must expose ``gates`` and ``obstacles`` because their
                lengths size the observation space. Although the signature defaults to
                ``None``, a track is required in practice.
            disturbances: Optional disturbance configuration.
            randomizations: Optional randomization configuration.
            random_resets: Random reset flag.
            seed: Seed passed to the core environment.
            max_episode_steps: Episode step limit passed to the core environment.
            device: Device the core environment should use.

        Raises:
            ValueError: If ``track`` is ``None``.
        """
        super().__init__(
            n_envs=num_envs,
            n_drones=1,
            freq=freq,
            sim_config=sim_config,
            sensor_range=sensor_range,
            track=track,
            disturbances=disturbances,
            randomizations=randomizations,
            random_resets=random_resets,
            seed=seed,
            max_episode_steps=max_episode_steps,
            device=device,
        )
        if track is None:
            # Fail with a clear message instead of an AttributeError on ``track.gates`` below.
            raise ValueError(
                "track must be provided: its gates and obstacles define the observation space."
            )
        self.single_action_space = action_space("state")
        self.action_space = batch_space(self.single_action_space, num_envs)
        n_gates, n_obstacles = len(track.gates), len(track.obstacles)
        self.single_observation_space = observation_space(n_gates, n_obstacles)
        self.observation_space = batch_space(self.single_observation_space, num_envs)

    def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[dict, dict]:
        """Reset the environment and return observations and info without the second dimension.

        Args:
            seed: Forwarded to ``RaceCoreEnv.reset``.
            options: Forwarded to ``RaceCoreEnv.reset``.

        Returns:
            The observation and info dictionaries of the core environment, with every value
            indexed at ``[:, 0]``.
        """
        obs, info = super().reset(seed=seed, options=options)
        obs = {k: v[:, 0] for k, v in obs.items()}
        info = {k: v[:, 0] for k, v in info.items()}
        return obs, info

    def step(self, action: NDArray[np.floating]) -> tuple[dict, float, bool, bool, dict]:
        """Advance the environment by one step and return results without the second dimension.

        Args:
            action: Action passed unchanged to ``RaceCoreEnv.step``.

        Returns:
            Observation, reward, terminated, truncated and info of the core environment, each
            indexed at ``[:, 0]`` (dictionary values are indexed individually).

            NOTE(review): reward/terminated/truncated are ``[:, 0]`` slices, so they keep their
            leading dimension; the scalar ``float``/``bool`` return annotation looks copied
            from the single-world environment and should be confirmed and updated.
        """
        obs, reward, terminated, truncated, info = super().step(action)
        obs = {k: v[:, 0] for k, v in obs.items()}
        info = {k: v[:, 0] for k, v in info.items()}
        return obs, reward[:, 0], terminated[:, 0], truncated[:, 0], info
Loading

0 comments on commit 493e8f0

Please sign in to comment.