Skip to content

Commit

Permalink
[wip,broken] Add vectorized multi-drone env
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 16, 2025
1 parent 26141b0 commit ee8efea
Show file tree
Hide file tree
Showing 5 changed files with 683 additions and 23 deletions.
1 change: 1 addition & 0 deletions benchmarks/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import lsy_drone_racing
env = gymnasium.make('MultiDroneRacing-v0',
n_envs=1, # TODO: Remove this for single-world envs
n_drones=config.env.n_drones,
freq=config.env.freq,
sim_config=config.sim,
Expand Down
3 changes: 2 additions & 1 deletion lsy_drone_racing/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@

# region MultiEnvs

# TODO: Register specialized, non-vectorized envs for single worlds
register(
id="MultiDroneRacing-v0",
entry_point="lsy_drone_racing.envs.multi_drone_race:MultiDroneRacingEnv",
entry_point="lsy_drone_racing.envs.vec_drone_race:VectorMultiDroneRaceEnv",
max_episode_steps=1800,
disable_env_checker=True,
)
37 changes: 19 additions & 18 deletions lsy_drone_racing/envs/multi_drone_race.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ class MultiDroneRacingEnv(gymnasium.Env):
"""A Gymnasium environment for drone racing simulations.
This environment simulates a drone racing scenario where multiple drones navigate through a
series of gates in a predefined track. It uses the Sim class for physics simulation and supports
various configuration options for randomization, disturbances, and physics models.
series of gates in a predefined track. It supports various configuration options for
randomization, disturbances, and physics models.
The environment provides:
- A customizable track with gates and obstacles
- Configurable simulation and control frequencies
- Support for different physics models (e.g., PyBullet, mathematical dynamics)
- Support for different physics models (e.g., identified dynamics, analytical dynamics)
- Randomization of drone properties and initial conditions
- Disturbance modeling for realistic flight conditions
- Symbolic expressions for advanced control techniques (optional)
Expand Down Expand Up @@ -86,6 +86,7 @@ class MultiDroneRacingEnv(gymnasium.Env):

def __init__(

Check failure on line 87 in lsy_drone_racing/envs/multi_drone_race.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D417)

lsy_drone_racing/envs/multi_drone_race.py:87:9: D417 Missing argument description in the docstring for `__init__`: `n_envs`
self,
n_envs: int,
n_drones: int,
freq: int,
sim_config: ConfigDict,
Expand All @@ -111,6 +112,7 @@ def __init__(
"""
super().__init__()
self.sim = Sim(
n_worlds=n_envs,
n_drones=n_drones,
physics=sim_config.physics,
control=sim_config.get("control", "state"),
Expand All @@ -130,25 +132,24 @@ def __init__(
self.random_resets = random_resets
self.sensor_range = sensor_range
self.gates, self.obstacles, self.drone = self.load_track(track)
self.n_gates = len(track.gates)
specs = {} if disturbances is None else disturbances
self.disturbances = {mode: rng_spec2fn(spec) for mode, spec in specs.items()}
specs = {} if randomizations is None else randomizations
self.randomizations = {mode: rng_spec2fn(spec) for mode, spec in specs.items()}

# Spaces
self.action_space = spaces.Box(low=-1, high=1, shape=(n_drones, 13))
n_obstacles = len(track.obstacles)
n_gates, n_obstacles = len(track.gates), len(track.obstacles)
self.observation_space = spaces.Dict(
{
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
"rpy": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
"vel": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
"ang_vel": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
"target_gate": spaces.Discrete(self.n_gates, start=-1),
"gates_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(self.n_gates, 3)),
"gates_rpy": spaces.Box(low=-np.pi, high=np.pi, shape=(self.n_gates, 3)),
"gates_visited": spaces.Box(low=0, high=1, shape=(self.n_gates,), dtype=bool),
"target_gate": spaces.Discrete(n_gates, start=-1),
"gates_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(n_gates, 3)),
"gates_rpy": spaces.Box(low=-np.pi, high=np.pi, shape=(n_gates, 3)),
"gates_visited": spaces.Box(low=0, high=1, shape=(n_gates,), dtype=bool),
"obstacles_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(n_obstacles, 3)),
"obstacles_visited": spaces.Box(low=0, high=1, shape=(n_obstacles,), dtype=bool),
}
Expand All @@ -162,7 +163,7 @@ def __init__(
self.target_gate = np.zeros(self.sim.n_drones, dtype=int)
self._steps = 0
self._last_drone_pos = np.zeros((self.sim.n_drones, 3))
self.gates_visited = np.zeros((self.sim.n_drones, self.n_gates), dtype=bool)
self.gates_visited = np.zeros((self.sim.n_drones, n_gates), dtype=bool)
self.obstacles_visited = np.zeros((self.sim.n_drones, n_obstacles), dtype=bool)

# Compile the reset and step functions with custom hooks
Expand Down Expand Up @@ -242,17 +243,18 @@ def step(
)
self.sim.data = self.warp_disabled_drones(self.sim.data, self.disabled_drones)
# TODO: Clean up the accelerated functions
n_gates = len(self.gates["pos"])
gate_id = self.target_gate % n_gates
passed = self._gate_passed(
self.target_gate,
gate_id,
self.gates["mocap_ids"],
self.sim.data.mjx_data.mocap_pos[0],
self.sim.data.mjx_data.mocap_quat[0],
self.sim.data.states.pos[0],
self._last_drone_pos,
self.n_gates,
)
self.target_gate += np.array(passed) * ~self.disabled_drones
self.target_gate[self.target_gate >= self.n_gates] = -1
self.target_gate[self.target_gate >= n_gates] = -1
self._last_drone_pos = self.sim.data.states.pos[0]
return self.obs(), self.reward(), self.terminated(), False, self.info()

Expand Down Expand Up @@ -397,8 +399,8 @@ def load_track(self, track: dict) -> tuple[dict, dict, dict]:

def load_contact_masks(self) -> NDArray[np.bool_]:
"""Load contact masks for the simulation that zero out irrelevant contacts per drone."""
n_obstacles = len(self.obstacles["pos"])
object_contacts = n_obstacles + self.n_gates * 5 + 1 # 5 geoms per gate, 1 for the floor
n_gates, n_obstacles = len(self.gates["pos"]), len(self.obstacles["pos"])
object_contacts = n_obstacles + n_gates * 5 + 1 # 5 geoms per gate, 1 for the floor
drone_contacts = (self.sim.n_drones - 1) * self.sim.n_drones // 2
n_contacts = self.sim.n_drones * object_contacts + drone_contacts
masks = np.zeros((self.sim.n_drones, n_contacts), dtype=bool)
Expand Down Expand Up @@ -462,21 +464,20 @@ def _load_track_into_sim(self, gates: dict, obstacles: dict):
@staticmethod
@jax.jit
def _gate_passed(
target_gate: NDArray,
gate_id: int,
mocap_ids: NDArray,
mocap_pos: Array,
mocap_quat: Array,
drone_pos: Array,
last_drone_pos: NDArray,
n_gates: int,
) -> bool:
"""Check if the drone has passed a gate.
Returns:
True if the drone has passed a gate, else False.
"""
# TODO: Test. Cover cases with no gates.
ids = mocap_ids[target_gate % n_gates]
ids = mocap_ids[gate_id]
gate_pos = mocap_pos[ids]
gate_rot = JaxR.from_quat(mocap_quat[ids][..., [1, 2, 3, 0]])
gate_size = (0.45, 0.45)
Expand Down
Loading

0 comments on commit ee8efea

Please sign in to comment.