diff --git a/lsy_drone_racing/controller_manager.py b/lsy_drone_racing/controller_manager.py
new file mode 100644
index 00000000..f4e2cffc
--- /dev/null
+++ b/lsy_drone_racing/controller_manager.py
@@ -0,0 +1,177 @@
+"""Asynchronous controller manager for multi-process control of multiple drones.
+
+This module provides a controller manager that allows multiple controllers to run in separate
+processes without blocking other controllers or the main process.
+"""
+
+from __future__ import annotations
+
+import multiprocessing as mp
+from queue import Empty
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+from lsy_drone_racing.control.controller import Controller
+
+if TYPE_CHECKING:
+    from multiprocessing.synchronize import Event
+
+    from numpy.typing import NDArray
+
+
+class ControllerManager:
+    """Multi-process safe manager class for asynchronous/non-blocking controller execution.
+
+    Note:
+        The controller manager currently does not support step and episode callbacks.
+
+    Todo:
+        Implement an automated return mechanism for the controllers.
+    """
+
+    def __init__(self, controllers: list[type[Controller]], default_action: NDArray):
+        """Initialize the controller manager.
+
+        Args:
+            controllers: Controller classes, one per drone. Each class is instantiated in its own
+                process once `start` is called.
+            default_action: Action reported for a controller until it has produced its first
+                action. It is tiled to one row per controller.
+
+        Raises:
+            TypeError: If an entry of `controllers` is not a subclass of `Controller`.
+        """
+        # The entries are classes that get instantiated in `_control_loop`, hence issubclass. An
+        # isinstance check against Controller is always False for a class object.
+        if not all(isinstance(c, type) and issubclass(c, Controller) for c in controllers):
+            raise TypeError("Invalid controller type(s)!")
+        self._controllers_cls = controllers
+        self._obs_queues = [mp.Queue(1) for _ in controllers]
+        self._action_queues = [mp.Queue(1) for _ in controllers]
+        self._ready = [mp.Event() for _ in controllers]
+        self._shutdown = [mp.Event() for _ in controllers]
+        self._controller_procs: list[mp.Process] = []
+        self._actions = np.tile(default_action, (len(controllers), 1))
+
+    def start(self, init_args: tuple | None = None, init_kwargs: dict | None = None):
+        """Start one process per controller and block until all controllers are constructed.
+
+        Args:
+            init_args: Positional arguments passed to every controller constructor.
+            init_kwargs: Keyword arguments passed to every controller constructor.
+
+        Raises:
+            RuntimeError: If a controller process exits before it signals that it is ready.
+        """
+        init_args = () if init_args is None else init_args
+        init_kwargs = {} if init_kwargs is None else init_kwargs
+        for i, controller_cls in enumerate(self._controllers_cls):
+            args = (
+                controller_cls,
+                init_args,
+                init_kwargs,
+                self._obs_queues[i],
+                self._action_queues[i],
+                self._ready[i],
+                self._shutdown[i],
+            )
+            # Daemon processes are terminated together with the main process, so a controller
+            # that is still waiting for observations cannot keep the interpreter from exiting.
+            proc = mp.Process(target=self._control_loop, args=args, daemon=True)
+            proc.start()
+            self._controller_procs.append(proc)
+        for proc, ready in zip(self._controller_procs, self._ready):
+            # Poll instead of waiting indefinitely: if the constructor raised in the child, the
+            # ready event is never set and a plain wait() would hang forever.
+            while not ready.wait(timeout=0.1):
+                if not proc.is_alive():
+                    raise RuntimeError("Controller process exited before it was ready.")
+
+    def update_obs(self, obs: dict, info: dict):
+        """Pass the observation and info updates to all controller processes.
+
+        Args:
+            obs: The observation dictionary.
+            info: The info dictionary.
+        """
+        for obs_queue in self._obs_queues:
+            _clear_producing_queue(obs_queue)
+            obs_queue.put((obs, info))
+
+    def latest_actions(self) -> NDArray:
+        """Get the latest actions from all controllers without blocking.
+
+        Returns:
+            A copy of the action array with one row per controller. The row of a controller that
+            has no new action available keeps its previous value.
+        """
+        for i, action_queue in enumerate(self._action_queues):
+            # The controller process may clear its queue right before it puts the next action.
+            # Never block on that: keep the previous action instead.
+            try:
+                self._actions[i] = action_queue.get_nowait()
+            except Empty:
+                pass
+        return np.array(self._actions)
+
+    def close(self, timeout: float = 1.0):
+        """Signal all controller processes to shut down and wait for them to exit.
+
+        Args:
+            timeout: Time in seconds to wait for each process before it is terminated.
+        """
+        for shutdown in self._shutdown:
+            shutdown.set()
+        for proc in self._controller_procs:
+            proc.join(timeout)
+            if proc.is_alive():  # E.g. still busy inside compute_control
+                proc.terminate()
+                proc.join()
+        self._controller_procs.clear()
+
+    @staticmethod
+    def _control_loop(
+        controller_cls: type[Controller],
+        init_args: tuple,
+        init_kwargs: dict,
+        obs_queue: mp.Queue,
+        action_queue: mp.Queue,
+        ready: Event,
+        shutdown: Event,
+    ):
+        """Run a single controller in the current process until shutdown is requested.
+
+        Args:
+            controller_cls: The controller class to instantiate.
+            init_args: Positional arguments for the controller constructor.
+            init_kwargs: Keyword arguments for the controller constructor.
+            obs_queue: Queue from which (obs, info) tuples are read.
+            action_queue: Queue into which the computed actions are put.
+            ready: Event that is set once the controller has been constructed.
+            shutdown: Event that stops the loop once it is set.
+        """
+        controller = controller_cls(*init_args, **init_kwargs)
+        ready.set()
+        while not shutdown.is_set():
+            try:
+                # Wake up periodically so that the shutdown event is checked even if no new
+                # observation arrives. A plain get() would block forever in that case.
+                obs, info = obs_queue.get(timeout=0.1)
+            except Empty:
+                continue
+            action = controller.compute_control(obs, info)
+            _clear_producing_queue(action_queue)
+            action_queue.put_nowait(action)
+
+
+def _clear_producing_queue(queue: mp.Queue):
+    """Clear the queue if it is not empty and this process is the ONLY producer.
+
+    Warning:
+        Only works for queues with a length of 1.
+    """
+    try:
+        queue.get_nowait()
+    except Empty:  # Nothing to clear, or the consumer took the last item in between
+        pass
diff --git a/lsy_drone_racing/envs/race_core.py b/lsy_drone_racing/envs/race_core.py
index 70a723bf..b3696ac5 100644
--- a/lsy_drone_racing/envs/race_core.py
+++ b/lsy_drone_racing/envs/race_core.py
@@ -19,7 +19,6 @@
 from crazyflow.sim.symbolic import symbolic_attitude
 from flax.struct import dataclass
 from gymnasium import spaces
-from jax.scipy.spatial.transform import Rotation as JaxR
 from scipy.spatial.transform import Rotation as R
 
 from lsy_drone_racing.envs.randomize import (
diff --git a/lsy_drone_racing/vicon.py b/lsy_drone_racing/vicon.py
index 507e1fe9..59873758 100644
--- a/lsy_drone_racing/vicon.py
+++ b/lsy_drone_racing/vicon.py
@@ -20,7 +20,6 @@
 import yaml
 from crazyswarm.msg import StateVector
 from rosgraph import Master
-from scipy.spatial.transform import Rotation as R
 from tf2_msgs.msg import TFMessage
 
 from lsy_drone_racing.utils.import_utils import get_ros_package_path
diff --git a/pyproject.toml b/pyproject.toml
index e3b46539..aa609f5e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,12 +20,12
@@ classifiers = [
 dependencies = [
     "fire >= 0.6.0",
     "numpy >= 1.24.1, < 2.0.0",
-    "PyYAML >= 6.0.1", # TODO: Remove after removing crazyswarm dependency
-    "rospkg >= 1.5.1", # TODO: Remove after moving to cflib
+    "PyYAML >= 6.0.1", # TODO: Remove after removing crazyswarm dependency
+    "rospkg >= 1.5.1", # TODO: Remove after moving to cflib
     "scipy >= 1.10.1",
     "gymnasium >= 1.0.0",
     "toml >= 0.10.2",
-    "ml_collections >= 1.0",
+    "ml-collections >= 1.0",
 ]
 
 [project.optional-dependencies]
diff --git a/scripts/multi_deploy.py b/scripts/multi_deploy.py
new file mode 100644
index 00000000..862819df
--- /dev/null
+++ b/scripts/multi_deploy.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+"""Launch script for the real race with multiple drones.
+
+Usage:
+
+python scripts/multi_deploy.py --config multi_level3.toml
+
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import fire
+import gymnasium
+import numpy as np
+import rospy
+
+from lsy_drone_racing.controller_manager import ControllerManager
+from lsy_drone_racing.utils import load_config, load_controller
+
+if TYPE_CHECKING:
+    from lsy_drone_racing.envs.drone_racing_deploy_env import (
+        DroneRacingAttitudeDeployEnv,
+        DroneRacingDeployEnv,
+    )
+
+# rospy.init_node changes the default logging configuration of Python, which is bad practice at
+# best. As a workaround, we can create loggers under the ROS root logger `rosout`.
+# Also see https://github.com/ros/ros_comm/issues/1384
+logger = logging.getLogger("rosout." + __name__)
+
+
+def main(config: str = "multi_level3.toml"):
+    """Deployment script to run the controllers on the real drones.
+
+    Args:
+        config: Name of the competition configuration. Assumes the file is in `config/`.
+    """
+    config = load_config(Path(__file__).parents[1] / "config" / config)
+    env_id = "DroneRacingAttitudeDeploy-v0" if "Thrust" in config.env.id else "DroneRacingDeploy-v0"
+    env: DroneRacingDeployEnv | DroneRacingAttitudeDeployEnv = gymnasium.make(env_id, config=config)
+    obs, info = env.reset()
+
+    module_path = Path(__file__).parents[1] / "lsy_drone_racing/control"
+    # pathlib drops `module_path` when joining with an absolute path, so absolute entries are kept
+    # unchanged and relative entries are resolved against the control package.
+    controller_paths = [module_path / p for p in config.controller.files]
+    # NOTE(review): The manager returns this action for every drone until its controller has
+    # produced a first action. Zeros are a placeholder; assumes the last axis of the action space
+    # is the per-drone action dimension. Confirm that this is a safe command for real hardware.
+    default_action = np.zeros(env.action_space.shape[-1], dtype=env.action_space.dtype)
+    controller_manager = ControllerManager(
+        [load_controller(p) for p in controller_paths], default_action
+    )
+    # The config is passed as well to match the `ctrl_cls(obs, info, config)` constructor call in
+    # tests/integration/test_controllers.py.
+    controller_manager.start(init_args=(obs, info, config))
+
+    try:
+        start_time = time.perf_counter()
+        while not rospy.is_shutdown():
+            t_loop = time.perf_counter()
+            obs, info = env.unwrapped.obs, env.unwrapped.info
+            # Compute the control action asynchronously. This limits delays and prevents slow
+            # controllers from blocking the controllers for other drones.
+            controller_manager.update_obs(obs, info)
+            actions = controller_manager.latest_actions()
+            next_obs, reward, terminated, truncated, info = env.step(actions)
+            # ControllerManager does not provide step or episode callbacks, so none are called.
+            obs = next_obs
+            if terminated or truncated:
+                break
+            if (dt := (time.perf_counter() - t_loop)) < (1 / config.env.freq):
+                time.sleep(1 / config.env.freq - dt)
+            else:
+                exc = dt - 1 / config.env.freq
+                logger.warning(f"Controller execution time exceeded loop frequency by {exc:.3f}s.")
+        ep_time = time.perf_counter() - start_time
+        logger.info(
+            f"Track time: {ep_time:.3f}s" if obs["target_gate"] == -1 else "Task not completed"
+        )
+    finally:
+        env.close()
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    fire.Fire(main)
diff --git a/tests/integration/test_controllers.py b/tests/integration/test_controllers.py
index 4881d96d..134f66b6 100644
--- a/tests/integration/test_controllers.py
+++ b/tests/integration/test_controllers.py
@@ -3,6 +3,7 @@
 import
gymnasium import numpy as np import pytest +from gymnasium.wrappers import JaxToNumpy from lsy_drone_racing.utils import load_config, load_controller @@ -59,10 +60,11 @@ def test_attitude_controller(physics: str): random_resets=config.env.random_resets, seed=config.env.seed, ) + env = JaxToNumpy(env) obs, info = env.reset() ctrl = ctrl_cls(obs, info, config) while True: - action = ctrl.compute_control(obs, info) + action = ctrl.compute_control(obs, info).astype(np.float32) obs, reward, terminated, truncated, info = env.step(action) ctrl.step_callback(action, obs, reward, terminated, truncated, info) if terminated or truncated: diff --git a/tests/unit/envs/test_envs.py b/tests/unit/envs/test_envs.py index 0b4e80a9..01cb4cdc 100644 --- a/tests/unit/envs/test_envs.py +++ b/tests/unit/envs/test_envs.py @@ -4,6 +4,7 @@ import gymnasium import pytest from gymnasium.utils.env_checker import check_env +from gymnasium.wrappers import JaxToNumpy from lsy_drone_racing.utils import load_config @@ -31,4 +32,4 @@ def test_passive_checker_wrapper_warnings(action_space: str): seed=config.env.seed, disable_env_checker=False, ) - check_env(env.unwrapped) + check_env(JaxToNumpy(env))