Skip to content

Commit

Permalink
Fix multi-drone envs with JaxToNumpy wrapper
Browse files (browse the repository at this point in the history)
  • Loading branch information
amacati committed Feb 11, 2025
1 parent 2a347cf commit 37a7b96
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 20 deletions.
4 changes: 3 additions & 1 deletion lsy_drone_racing/envs/multi_drone_race.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def step(
obs, reward, terminated, truncated, info = self._step(action)
obs = {k: v[0] for k, v in obs.items()}
info = {k: v[0] for k, v in info.items()}
return obs, reward[0], terminated[0], truncated[0], info
# TODO: Fix by moving towards pettingzoo API
# https://pettingzoo.farama.org/api/parallel/
return obs, reward[0, 0], terminated[0].all(), truncated[0].all(), info


class VecMultiDroneRaceEnv(RaceCoreEnv, VectorEnv):
Expand Down
37 changes: 21 additions & 16 deletions lsy_drone_racing/envs/race_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,20 +358,20 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
self.obstacles["nominal_pos"],
)
obs = {
"pos": np.array(self.sim.data.states.pos, dtype=np.float32),
"quat": np.array(self.sim.data.states.quat, dtype=np.float32),
"vel": np.array(self.sim.data.states.vel, dtype=np.float32),
"ang_vel": np.array(self.sim.data.states.ang_vel, dtype=np.float32),
"target_gate": np.array(self.data.target_gate, dtype=int),
"gates_pos": np.asarray(gates_pos, dtype=np.float32),
"gates_quat": np.asarray(gates_quat, dtype=np.float32),
"gates_visited": np.asarray(self.data.gates_visited, dtype=bool),
"obstacles_pos": np.asarray(obstacles_pos, dtype=np.float32),
"obstacles_visited": np.asarray(self.data.obstacles_visited, dtype=bool),
"pos": self.sim.data.states.pos,
"quat": self.sim.data.states.quat,
"vel": self.sim.data.states.vel,
"ang_vel": self.sim.data.states.ang_vel,
"target_gate": self.data.target_gate,
"gates_pos": gates_pos,
"gates_quat": gates_quat,
"gates_visited": self.data.gates_visited,
"obstacles_pos": obstacles_pos,
"obstacles_visited": self.data.obstacles_visited,
}
return obs

def reward(self) -> NDArray[np.float32]:
def reward(self) -> Array:
"""Compute the reward for the current state.
Note:
Expand All @@ -382,19 +382,19 @@ def reward(self) -> NDArray[np.float32]:
Returns:
Reward for the current state.
"""
return np.array(-1.0 * (self.data.target_gate == -1), dtype=np.float32)
return -1.0 * (self.data.target_gate == -1) # Implicit float conversion

def terminated(self) -> NDArray[np.bool_]:
def terminated(self) -> Array:
"""Check if the episode is terminated.
Returns:
True if all drones have been disabled, else False.
"""
return np.array(self.data.disabled_drones, dtype=bool)
return self.data.disabled_drones

def truncated(self) -> NDArray[np.bool_]:
def truncated(self) -> Array:
"""Array of booleans indicating if the episode is truncated."""
return np.tile(self.data.steps >= self.data.max_episode_steps, (self.sim.n_drones, 1))
return self._truncated(self.data.steps, self.data.max_episode_steps, self.sim.n_drones)

def info(self) -> dict:
"""Return an info dictionary containing additional information about the environment."""
Expand Down Expand Up @@ -494,6 +494,11 @@ def _obs(
obstacles_pos = jp.where(mask, real_pos[:, None], nominal_obstacle_pos[None, None])
return gates_pos, gates_quat, obstacles_pos

@staticmethod
@partial(jax.jit, static_argnames="n_drones")
def _truncated(steps: Array, max_episode_steps: Array, n_drones: int) -> Array:
    """Compute the truncation flags and repeat them once per drone.

    Args:
        steps: Current step counter(s) of the environment.
        max_episode_steps: Step limit that ``steps`` is compared against.
        n_drones: Number of repetitions along the leading axis. Declared static for
            ``jax.jit`` because it determines the shape of the output.

    Returns:
        The boolean array ``steps >= max_episode_steps`` tiled with repetitions
        ``(n_drones, 1)``.
    """
    # NOTE(review): the resulting layout depends on the shape of ``steps``, which is not
    # visible here -- confirm it matches the layout of the per-drone termination flags.
    return jp.tile(steps >= max_episode_steps, (n_drones, 1))

@staticmethod
def _disabled_drones(pos: Array, contacts: Array, data: EnvData) -> Array:
disabled = jp.logical_or(data.disabled_drones, jp.any(pos < data.pos_limit_low, axis=-1))
Expand Down
17 changes: 14 additions & 3 deletions scripts/multi_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import fire
import gymnasium
import numpy as np
from gymnasium.wrappers.jax_to_numpy import JaxToNumpy

from lsy_drone_racing.utils import load_config, load_controller

Expand Down Expand Up @@ -69,7 +70,9 @@ def simulate(
randomizations=config.env.get("randomizations"),
random_resets=config.env.random_resets,
seed=config.env.seed,
action_space=config.env.action_space,
)
env = JaxToNumpy(env)

for _ in range(n_runs): # Run n_runs episodes with the controller
obs, info = env.reset()
Expand All @@ -81,7 +84,9 @@ def simulate(
curr_time = i / config.env.freq

action = controller.compute_control(obs, info)
action = np.array([action] * config.env.n_drones * env.unwrapped.sim.n_worlds)
action = np.array(
[action] * config.env.n_drones * env.unwrapped.sim.n_worlds, dtype=np.float32
)
action[1, 0] += 0.2
obs, reward, terminated, truncated, info = env.step(action)
done = terminated | truncated
Expand All @@ -92,9 +97,15 @@ def simulate(
# Synchronize the GUI.
if config.sim.gui:
if ((i * fps) % config.env.freq) < fps:
env.render()
try:
env.render()
# TODO: JaxToNumpy not working with None (returned by env.render()). Open issue
# in gymnasium and fix this.
except Exception as e:
if not e.args[0].startswith("No known conversion for Jax type"):
raise e
i += 1
if done.all():
if done:
break

controller.episode_callback() # Update the controller internal state and models.
Expand Down

0 comments on commit 37a7b96

Please sign in to comment.