-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a ControllerManager for multi-drone deployments. Fix tests and li…
…nting
- Loading branch information
Showing
7 changed files
with
202 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
"""Asynchronous controller manager for multi-process control of multiple drones. | ||
This module provides a controller manager that allows multiple controllers to run in separate | ||
processes without blocking other controllers or the main process. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import multiprocessing as mp | ||
from queue import Empty | ||
from typing import TYPE_CHECKING | ||
|
||
import numpy as np | ||
|
||
from lsy_drone_racing.control.controller import Controller | ||
|
||
if TYPE_CHECKING: | ||
from multiprocessing.synchronize import Event | ||
|
||
from numpy.typing import NDArray | ||
|
||
|
||
class ControllerManager: | ||
"""Multi-process safe manager class for asynchronous/non-blocking controller execution. | ||
Note: | ||
The controller manager currently does not support step and episode callbacks. | ||
Todo: | ||
Implement an automated return mechanism for the controllers. | ||
""" | ||
|
||
def __init__(self, controllers: list[Controller], default_action: NDArray): | ||
"""Initialize the controller manager.""" | ||
assert all(isinstance(c, Controller) for c in controllers), "Invalid controller type(s)!" | ||
self._controllers_cls = controllers | ||
self._obs_queues = [mp.Queue(1) for _ in controllers] | ||
self._action_queues = [mp.Queue(1) for _ in controllers] | ||
self._ready = [mp.Event() for _ in controllers] | ||
self._shutdown = [mp.Event() for _ in controllers] | ||
self._actions = np.tile(default_action, (len(controllers), 1)) | ||
|
||
def start(self, init_args: tuple | None = None, init_kwargs: dict | None = None): | ||
"""Start the controller manager.""" | ||
for i, c in enumerate(self._controllers_cls): | ||
args = ( | ||
c, | ||
tuple() if init_args is None else init_args, | ||
dict() if init_kwargs is None else init_kwargs, | ||
self._obs_queues[i], | ||
self._action_queues[i], | ||
self._ready[i], | ||
self._shutdown[i], | ||
) | ||
self._controller_procs.append(mp.Process(target=self._control_loop, args=args)) | ||
self._controller_procs[-1].start() | ||
for ready in self._ready: # Wait for all controllers to be ready | ||
ready.wait() | ||
|
||
def update_obs(self, obs: dict, info: dict): | ||
"""Pass the observation and info updates to all controller processes. | ||
Args: | ||
obs: The observation dictionary. | ||
info: The info dictionary. | ||
""" | ||
for obs_queue in self._obs_queues: | ||
_clear_producing_queue(obs_queue) | ||
obs_queue.put((obs, info)) | ||
|
||
def latest_actions(self) -> NDArray: | ||
"""Get the latest actions from all controllers.""" | ||
for i, action_queue in enumerate(self._action_queues): | ||
if not action_queue.empty(): # Length of queue is 1 -> action is ready | ||
# The action queue could be cleared in between the check and the get() call. Since | ||
# the controller processes immediately put the next action into the queue, this | ||
# minimum block time is acceptable. | ||
self._actions[i] = action_queue.get() | ||
return np.array(self._actions) | ||
|
||
@staticmethod | ||
def _control_loop( | ||
cls: type[Controller], | ||
init_args: tuple, | ||
init_kwargs: dict, | ||
obs_queue: mp.Queue, | ||
action_queue: mp.Queue, | ||
ready: Event, | ||
shutdown: Event, | ||
): | ||
controller = cls(*init_args, **init_kwargs) | ||
ready.set() | ||
while not shutdown.is_set(): | ||
obs, info = obs_queue.get() # Blocks until new observation is available | ||
action = controller.compute_control(obs, info) | ||
_clear_producing_queue(action_queue) | ||
action_queue.put_nowait(action) | ||
|
||
|
||
def _clear_producing_queue(queue: mp.Queue): | ||
"""Clear the queue if it is not empty and this process is the ONLY producer. | ||
Warning: | ||
Only works for queues with a length of 1. | ||
""" | ||
if not queue.empty(): # There are remaining items in the queue | ||
try: | ||
queue.get_nowait() | ||
except Empty: # Another process could have consumed the last item in between | ||
pass # This is fine, the queue is empty |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
#!/usr/bin/env python | ||
"""Launch script for the real race with multiple drones. | ||
Usage: | ||
python deploy.py <path/to/controller.py> <path/to/config.toml> | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
import time | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING | ||
|
||
import fire | ||
import gymnasium | ||
import rospy | ||
|
||
from lsy_drone_racing.controller_manager import ControllerManager | ||
from lsy_drone_racing.utils import load_config, load_controller | ||
|
||
if TYPE_CHECKING: | ||
from lsy_drone_racing.envs.drone_racing_deploy_env import ( | ||
DroneRacingAttitudeDeployEnv, | ||
DroneRacingDeployEnv, | ||
) | ||
|
||
# rospy.init_node changes the default logging configuration of Python, which is bad practice at | ||
# best. As a workaround, we can create loggers under the ROS root logger `rosout`. | ||
# Also see https://github.com/ros/ros_comm/issues/1384 | ||
logger = logging.getLogger("rosout." + __name__) | ||
|
||
|
||
def main(config: str = "multi_level3.toml"): | ||
"""Deployment script to run the controller on the real drone. | ||
Args: | ||
config: Path to the competition configuration. Assumes the file is in `config/`. | ||
controller: The name of the controller file in `lsy_drone_racing/control/` or None. If None, | ||
the controller specified in the config file is used. | ||
""" | ||
config = load_config(Path(__file__).parents[1] / "config" / config) | ||
env_id = "DroneRacingAttitudeDeploy-v0" if "Thrust" in config.env.id else "DroneRacingDeploy-v0" | ||
env: DroneRacingDeployEnv | DroneRacingAttitudeDeployEnv = gymnasium.make(env_id, config=config) | ||
obs, info = env.reset() | ||
|
||
module_path = Path(__file__).parents[1] / "lsy_drone_racing/control" | ||
controller_paths = [module_path / p if p.is_relative() else p for p in config.controller.files] | ||
controller_manager = ControllerManager([load_controller(p) for p in controller_paths]) | ||
controller_manager.start(init_args=(obs, info)) | ||
|
||
try: | ||
start_time = time.perf_counter() | ||
while not rospy.is_shutdown(): | ||
t_loop = time.perf_counter() | ||
obs, info = env.unwrapped.obs, env.unwrapped.info | ||
# Compute the control action asynchronously. This limits delays and prevents slow | ||
# controllers from blocking the controllers for other drones. | ||
controller_manager.update_obs(obs, info) | ||
actions = controller_manager.latest_actions() | ||
next_obs, reward, terminated, truncated, info = env.step(actions) | ||
controller_manager.step_callback(actions, next_obs, reward, terminated, truncated, info) | ||
obs = next_obs | ||
if terminated or truncated: | ||
break | ||
if (dt := (time.perf_counter() - t_loop)) < (1 / config.env.freq): | ||
time.sleep(1 / config.env.freq - dt) | ||
else: | ||
exc = dt - 1 / config.env.freq | ||
logger.warning(f"Controller execution time exceeded loop frequency by {exc:.3f}s.") | ||
ep_time = time.perf_counter() - start_time | ||
controller_manager.episode_callback() | ||
logger.info( | ||
f"Track time: {ep_time:.3f}s" if obs["target_gate"] == -1 else "Task not completed" | ||
) | ||
finally: | ||
env.close() | ||
|
||
|
||
if __name__ == "__main__": | ||
logging.basicConfig(level=logging.INFO) | ||
fire.Fire(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters