Skip to content

Commit

Permalink
Add a ControllerManager for multi-drone deployments. Fix tests and li…
Browse files Browse the repository at this point in the history
…nting
  • Loading branch information
amacati committed Feb 13, 2025
1 parent 5a4beb6 commit 5d278b5
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 7 deletions.
110 changes: 110 additions & 0 deletions lsy_drone_racing/controller_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Asynchronous controller manager for multi-process control of multiple drones.
This module provides a controller manager that allows multiple controllers to run in separate
processes without blocking other controllers or the main process.
"""

from __future__ import annotations

import multiprocessing as mp
from queue import Empty
from typing import TYPE_CHECKING

import numpy as np

from lsy_drone_racing.control.controller import Controller

if TYPE_CHECKING:
from multiprocessing.synchronize import Event

from numpy.typing import NDArray


class ControllerManager:
"""Multi-process safe manager class for asynchronous/non-blocking controller execution.
Note:
The controller manager currently does not support step and episode callbacks.
Todo:
Implement an automated return mechanism for the controllers.
"""

def __init__(self, controllers: list[Controller], default_action: NDArray):
"""Initialize the controller manager."""
assert all(isinstance(c, Controller) for c in controllers), "Invalid controller type(s)!"
self._controllers_cls = controllers
self._obs_queues = [mp.Queue(1) for _ in controllers]
self._action_queues = [mp.Queue(1) for _ in controllers]
self._ready = [mp.Event() for _ in controllers]
self._shutdown = [mp.Event() for _ in controllers]
self._actions = np.tile(default_action, (len(controllers), 1))

def start(self, init_args: tuple | None = None, init_kwargs: dict | None = None):
"""Start the controller manager."""
for i, c in enumerate(self._controllers_cls):
args = (
c,
tuple() if init_args is None else init_args,
dict() if init_kwargs is None else init_kwargs,
self._obs_queues[i],
self._action_queues[i],
self._ready[i],
self._shutdown[i],
)
self._controller_procs.append(mp.Process(target=self._control_loop, args=args))
self._controller_procs[-1].start()
for ready in self._ready: # Wait for all controllers to be ready
ready.wait()

def update_obs(self, obs: dict, info: dict):
"""Pass the observation and info updates to all controller processes.
Args:
obs: The observation dictionary.
info: The info dictionary.
"""
for obs_queue in self._obs_queues:
_clear_producing_queue(obs_queue)
obs_queue.put((obs, info))

def latest_actions(self) -> NDArray:
"""Get the latest actions from all controllers."""
for i, action_queue in enumerate(self._action_queues):
if not action_queue.empty(): # Length of queue is 1 -> action is ready
# The action queue could be cleared in between the check and the get() call. Since
# the controller processes immediately put the next action into the queue, this
# minimum block time is acceptable.
self._actions[i] = action_queue.get()
return np.array(self._actions)

@staticmethod
def _control_loop(
cls: type[Controller],
init_args: tuple,
init_kwargs: dict,
obs_queue: mp.Queue,
action_queue: mp.Queue,
ready: Event,
shutdown: Event,
):
controller = cls(*init_args, **init_kwargs)
ready.set()
while not shutdown.is_set():
obs, info = obs_queue.get() # Blocks until new observation is available
action = controller.compute_control(obs, info)
_clear_producing_queue(action_queue)
action_queue.put_nowait(action)


def _clear_producing_queue(queue: mp.Queue):
"""Clear the queue if it is not empty and this process is the ONLY producer.
Warning:
Only works for queues with a length of 1.
"""
if not queue.empty(): # There are remaining items in the queue
try:
queue.get_nowait()
except Empty: # Another process could have consumed the last item in between
pass # This is fine, the queue is empty
1 change: 0 additions & 1 deletion lsy_drone_racing/envs/race_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from crazyflow.sim.symbolic import symbolic_attitude
from flax.struct import dataclass
from gymnasium import spaces
from jax.scipy.spatial.transform import Rotation as JaxR
from scipy.spatial.transform import Rotation as R

from lsy_drone_racing.envs.randomize import (
Expand Down
1 change: 0 additions & 1 deletion lsy_drone_racing/vicon.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import yaml
from crazyswarm.msg import StateVector
from rosgraph import Master
from scipy.spatial.transform import Rotation as R
from tf2_msgs.msg import TFMessage

from lsy_drone_racing.utils.import_utils import get_ros_package_path
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ classifiers = [
dependencies = [
"fire >= 0.6.0",
"numpy >= 1.24.1, < 2.0.0",
"PyYAML >= 6.0.1", # TODO: Remove after removing crazyswarm dependency
"rospkg >= 1.5.1", # TODO: Remove after moving to cflib
"PyYAML >= 6.0.1", # TODO: Remove after removing crazyswarm dependency
"rospkg >= 1.5.1", # TODO: Remove after moving to cflib
"scipy >= 1.10.1",
"gymnasium >= 1.0.0",
"toml >= 0.10.2",
"ml_collections >= 1.0",
"ml-collections >= 1.0",
]

[project.optional-dependencies]
Expand Down
84 changes: 84 additions & 0 deletions scripts/multi_deploy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#!/usr/bin/env python
"""Launch script for the real race with multiple drones.
Usage:
python deploy.py <path/to/controller.py> <path/to/config.toml>
"""

from __future__ import annotations

import logging
import time
from pathlib import Path
from typing import TYPE_CHECKING

import fire
import gymnasium
import rospy

from lsy_drone_racing.controller_manager import ControllerManager
from lsy_drone_racing.utils import load_config, load_controller

if TYPE_CHECKING:
from lsy_drone_racing.envs.drone_racing_deploy_env import (
DroneRacingAttitudeDeployEnv,
DroneRacingDeployEnv,
)

# rospy.init_node changes the default logging configuration of Python, which is bad practice at
# best. As a workaround, we can create loggers under the ROS root logger `rosout`.
# Also see https://github.com/ros/ros_comm/issues/1384
logger = logging.getLogger("rosout." + __name__)


def main(config: str = "multi_level3.toml"):
"""Deployment script to run the controller on the real drone.
Args:
config: Path to the competition configuration. Assumes the file is in `config/`.
controller: The name of the controller file in `lsy_drone_racing/control/` or None. If None,
the controller specified in the config file is used.
"""
config = load_config(Path(__file__).parents[1] / "config" / config)
env_id = "DroneRacingAttitudeDeploy-v0" if "Thrust" in config.env.id else "DroneRacingDeploy-v0"
env: DroneRacingDeployEnv | DroneRacingAttitudeDeployEnv = gymnasium.make(env_id, config=config)
obs, info = env.reset()

module_path = Path(__file__).parents[1] / "lsy_drone_racing/control"
controller_paths = [module_path / p if p.is_relative() else p for p in config.controller.files]
controller_manager = ControllerManager([load_controller(p) for p in controller_paths])
controller_manager.start(init_args=(obs, info))

try:
start_time = time.perf_counter()
while not rospy.is_shutdown():
t_loop = time.perf_counter()
obs, info = env.unwrapped.obs, env.unwrapped.info
# Compute the control action asynchronously. This limits delays and prevents slow
# controllers from blocking the controllers for other drones.
controller_manager.update_obs(obs, info)
actions = controller_manager.latest_actions()
next_obs, reward, terminated, truncated, info = env.step(actions)
controller_manager.step_callback(actions, next_obs, reward, terminated, truncated, info)
obs = next_obs
if terminated or truncated:
break
if (dt := (time.perf_counter() - t_loop)) < (1 / config.env.freq):
time.sleep(1 / config.env.freq - dt)
else:
exc = dt - 1 / config.env.freq
logger.warning(f"Controller execution time exceeded loop frequency by {exc:.3f}s.")
ep_time = time.perf_counter() - start_time
controller_manager.episode_callback()
logger.info(
f"Track time: {ep_time:.3f}s" if obs["target_gate"] == -1 else "Task not completed"
)
finally:
env.close()


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
fire.Fire(main)
4 changes: 3 additions & 1 deletion tests/integration/test_controllers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import gymnasium
import numpy as np
import pytest
from gymnasium.wrappers import JaxToNumpy

from lsy_drone_racing.utils import load_config, load_controller

Expand Down Expand Up @@ -59,10 +60,11 @@ def test_attitude_controller(physics: str):
random_resets=config.env.random_resets,
seed=config.env.seed,
)
env = JaxToNumpy(env)
obs, info = env.reset()
ctrl = ctrl_cls(obs, info, config)
while True:
action = ctrl.compute_control(obs, info)
action = ctrl.compute_control(obs, info).astype(np.float32)
obs, reward, terminated, truncated, info = env.step(action)
ctrl.step_callback(action, obs, reward, terminated, truncated, info)
if terminated or truncated:
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/envs/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import gymnasium
import pytest
from gymnasium.utils.env_checker import check_env
from gymnasium.wrappers import JaxToNumpy

from lsy_drone_racing.utils import load_config

Expand Down Expand Up @@ -31,4 +32,4 @@ def test_passive_checker_wrapper_warnings(action_space: str):
seed=config.env.seed,
disable_env_checker=False,
)
check_env(env.unwrapped)
check_env(JaxToNumpy(env))

0 comments on commit 5d278b5

Please sign in to comment.