diff --git a/lsy_drone_racing/envs/multi_drone_race.py b/lsy_drone_racing/envs/multi_drone_race.py
index 1e400184..3ec64b3a 100644
--- a/lsy_drone_racing/envs/multi_drone_race.py
+++ b/lsy_drone_racing/envs/multi_drone_race.py
@@ -101,7 +101,9 @@ def step(
         obs, reward, terminated, truncated, info = self._step(action)
         obs = {k: v[0] for k, v in obs.items()}
         info = {k: v[0] for k, v in info.items()}
-        return obs, reward[0], terminated[0], truncated[0], info
+        # TODO: Fix by moving towards pettingzoo API
+        # https://pettingzoo.farama.org/api/parallel/
+        return obs, reward[0, 0], terminated[0].all(), truncated[0].all(), info
 
 
 class VecMultiDroneRaceEnv(RaceCoreEnv, VectorEnv):
diff --git a/lsy_drone_racing/envs/race_core.py b/lsy_drone_racing/envs/race_core.py
index a2d91b66..70a723bf 100644
--- a/lsy_drone_racing/envs/race_core.py
+++ b/lsy_drone_racing/envs/race_core.py
@@ -358,20 +358,20 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
             self.obstacles["nominal_pos"],
         )
         obs = {
-            "pos": np.array(self.sim.data.states.pos, dtype=np.float32),
-            "quat": np.array(self.sim.data.states.quat, dtype=np.float32),
-            "vel": np.array(self.sim.data.states.vel, dtype=np.float32),
-            "ang_vel": np.array(self.sim.data.states.ang_vel, dtype=np.float32),
-            "target_gate": np.array(self.data.target_gate, dtype=int),
-            "gates_pos": np.asarray(gates_pos, dtype=np.float32),
-            "gates_quat": np.asarray(gates_quat, dtype=np.float32),
-            "gates_visited": np.asarray(self.data.gates_visited, dtype=bool),
-            "obstacles_pos": np.asarray(obstacles_pos, dtype=np.float32),
-            "obstacles_visited": np.asarray(self.data.obstacles_visited, dtype=bool),
+            "pos": self.sim.data.states.pos,
+            "quat": self.sim.data.states.quat,
+            "vel": self.sim.data.states.vel,
+            "ang_vel": self.sim.data.states.ang_vel,
+            "target_gate": self.data.target_gate,
+            "gates_pos": gates_pos,
+            "gates_quat": gates_quat,
+            "gates_visited": self.data.gates_visited,
+            "obstacles_pos": obstacles_pos,
+            "obstacles_visited": self.data.obstacles_visited,
         }
         return obs
 
-    def reward(self) -> NDArray[np.float32]:
+    def reward(self) -> Array:
         """Compute the reward for the current state.
 
         Note:
@@ -382,19 +382,19 @@ def reward(self) -> NDArray[np.float32]:
         Returns:
             Reward for the current state.
         """
-        return np.array(-1.0 * (self.data.target_gate == -1), dtype=np.float32)
+        return -1.0 * (self.data.target_gate == -1)  # Implicit float conversion
 
-    def terminated(self) -> NDArray[np.bool_]:
+    def terminated(self) -> Array:
         """Check if the episode is terminated.
 
         Returns:
             True if all drones have been disabled, else False.
         """
-        return np.array(self.data.disabled_drones, dtype=bool)
+        return self.data.disabled_drones
 
-    def truncated(self) -> NDArray[np.bool_]:
+    def truncated(self) -> Array:
         """Array of booleans indicating if the episode is truncated."""
-        return np.tile(self.data.steps >= self.data.max_episode_steps, (self.sim.n_drones, 1))
+        return self._truncated(self.data.steps, self.data.max_episode_steps, self.sim.n_drones)
 
     def info(self) -> dict:
         """Return an info dictionary containing additional information about the environment."""
@@ -494,6 +494,11 @@ def _obs(
         obstacles_pos = jp.where(mask, real_pos[:, None], nominal_obstacle_pos[None, None])
         return gates_pos, gates_quat, obstacles_pos
 
+    @staticmethod
+    @partial(jax.jit, static_argnames="n_drones")
+    def _truncated(steps: Array, max_episode_steps: Array, n_drones: int) -> Array:
+        return jp.tile(steps >= max_episode_steps, (n_drones, 1))
+
     @staticmethod
     def _disabled_drones(pos: Array, contacts: Array, data: EnvData) -> Array:
         disabled = jp.logical_or(data.disabled_drones, jp.any(pos < data.pos_limit_low, axis=-1))
diff --git a/scripts/multi_sim.py b/scripts/multi_sim.py
index 760cf08f..fd6ee186 100644
--- a/scripts/multi_sim.py
+++ b/scripts/multi_sim.py
@@ -16,6 +16,7 @@
 import fire
 import gymnasium
 import numpy as np
+from gymnasium.wrappers.jax_to_numpy import JaxToNumpy
 
 from lsy_drone_racing.utils import load_config, load_controller
 
@@ -69,7 +70,9 @@ def simulate(
         randomizations=config.env.get("randomizations"),
         random_resets=config.env.random_resets,
         seed=config.env.seed,
+        action_space=config.env.action_space,
     )
+    env = JaxToNumpy(env)
 
     for _ in range(n_runs):  # Run n_runs episodes with the controller
         obs, info = env.reset()
@@ -81,7 +84,9 @@
             curr_time = i / config.env.freq
 
             action = controller.compute_control(obs, info)
-            action = np.array([action] * config.env.n_drones * env.unwrapped.sim.n_worlds)
+            action = np.array(
+                [action] * config.env.n_drones * env.unwrapped.sim.n_worlds, dtype=np.float32
+            )
             action[1, 0] += 0.2
             obs, reward, terminated, truncated, info = env.step(action)
             done = terminated | truncated
@@ -92,9 +97,15 @@
             # Synchronize the GUI.
             if config.sim.gui:
                 if ((i * fps) % config.env.freq) < fps:
-                    env.render()
+                    try:
+                        env.render()
+                    # TODO: JaxToNumpy not working with None (returned by env.render()). Open issue
+                    # in gymnasium and fix this.
+                    except Exception as e:
+                        if not e.args[0].startswith("No known conversion for Jax type"):
+                            raise e
             i += 1
-            if done.all():
+            if done:
                 break
 
         controller.episode_callback()  # Update the controller internal state and models.