diff --git a/README.md b/README.md index e6da75a7..927c3b9f 100644 --- a/README.md +++ b/README.md @@ -418,6 +418,6 @@ In a second terminal: python scripts/deploy.py --controller --config level3.toml ``` -where `` implements a controller that inherits from `lsy_drone_racing.control.BaseController` +where `` implements a controller that inherits from `lsy_drone_racing.control.Controller` diff --git a/docs/getting_started/general.rst b/docs/getting_started/general.rst index 1fc76dd8..ae1d70ae 100644 --- a/docs/getting_started/general.rst +++ b/docs/getting_started/general.rst @@ -5,13 +5,13 @@ Welcome to the LSY Drone Racing Project! This is a platform developed by the LSY Implementing Your Own Algorithms -------------------------------- -To implement your own controller, you need to implement a ``Controller`` class in the :mod:`lsy_drone_racing.control` module. The only restriction we place on controllers is that they have to implement the interface defined by the :class:`BaseController ` class. Apart from that, you are encouraged to use the full spectrum of control algorithms, e.g., MPC, trajectory optimization, reinforcement learning, etc., to compete in the challenge. Please make sure to put your controller implementation in the :mod:`lsy_drone_racing.control` module to make sure that it is correctly recognized by our scripts. +To implement your own controller, you need to implement a ``Controller`` class in the :mod:`lsy_drone_racing.control` module. The only restriction we place on controllers is that they have to implement the interface defined by the :class:`Controller ` class. Apart from that, you are encouraged to use the full spectrum of control algorithms, e.g., MPC, trajectory optimization, reinforcement learning, etc., to compete in the challenge. Please make sure to put your controller implementation in the :mod:`lsy_drone_racing.control` module to make sure that it is correctly recognized by our scripts. .. note:: Make sure to inherit from the base class for your controller implementation. This ensures that your controller is compatible with our scripts. Also make sure to only create one controller class per file. Otherwise, we do not know which controller to load from the file. .. warning:: - You are not allowed to modify the interface of the :class:`BaseController ` class. Doing so will make your controller incompatible with the deployment environment and we won't be able to run your controller on our setup. + You are not allowed to modify the interface of the :class:`Controller ` class. Doing so will make your controller incompatible with the deployment environment and we won't be able to run your controller on our setup. .. warning:: Many students are enthusiastic about deep reinforcement learning and try to use it to solve the challenge. While you are completely free in choosing your control algorithm, we know from experience that training good agents is non-trivial, requires significant compute, and can be difficult to transfer into the real world setup. Students taking this approach should make sure they already have some experience with RL, and take their policies to the real world setup early to address potential sim2real issues. diff --git a/lsy_drone_racing/control/__init__.py b/lsy_drone_racing/control/__init__.py index 9ea06e69..b4793749 100644 --- a/lsy_drone_racing/control/__init__.py +++ b/lsy_drone_racing/control/__init__.py @@ -6,13 +6,13 @@ To give you an idea of what you need to do, we also include some example implementations: -* :class:`~.BaseController`: The abstract base class defining the interface for all controllers. +* :class:`~.Controller`: The abstract base class defining the interface for all controllers. * :class:`PPOController `: An example implementation using a pre-trained Proximal Policy Optimization (PPO) model. * :class:`PPOController `: A controller that follows a pre-defined trajectory using cubic spline interpolation. """ -from lsy_drone_racing.control.controller import BaseController +from lsy_drone_racing.control.controller import Controller -__all__ = ["BaseController"] +__all__ = ["Controller"] diff --git a/lsy_drone_racing/control/attitude_controller.py b/lsy_drone_racing/control/attitude_controller.py index dab7d758..c91c43f1 100644 --- a/lsy_drone_racing/control/attitude_controller.py +++ b/lsy_drone_racing/control/attitude_controller.py @@ -18,13 +18,13 @@ from scipy.interpolate import CubicSpline from scipy.spatial.transform import Rotation as R -from lsy_drone_racing.control import BaseController +from lsy_drone_racing.control import Controller if TYPE_CHECKING: from numpy.typing import NDArray -class AttitudeController(BaseController): +class AttitudeController(Controller): """Example of a controller using the collective thrust and attitude interface. Modified from https://github.com/utiasDSL/crazyswarm-import/blob/ad2f7ea987f458a504248a1754b124ba39fc2f21/ros_ws/src/crazyswarm/scripts/position_ctl_m.py diff --git a/lsy_drone_racing/control/controller.py b/lsy_drone_racing/control/controller.py index 1937a3ba..ad84fb7f 100644 --- a/lsy_drone_racing/control/controller.py +++ b/lsy_drone_racing/control/controller.py @@ -5,8 +5,8 @@ from that, you are free to add any additional methods, attributes, or classes to your controller. As an example, you could load the weights of a neural network in the constructor and use it to -compute the control commands in the :meth:`compute_control <.BaseController.compute_control>` -method. You could also use the :meth:`step_callback <.BaseController.step_callback>` method to +compute the control commands in the :meth:`compute_control <.Controller.compute_control>` +method. You could also use the :meth:`step_callback <.Controller.step_callback>` method to update the controller state at runtime. Note: @@ -24,7 +24,7 @@ from numpy.typing import NDArray -class BaseController(ABC): +class Controller(ABC): """Base class for controller implementations.""" def __init__(self, obs: dict[str, NDArray[np.floating]], info: dict, config: dict): diff --git a/lsy_drone_racing/control/ppo_controller.py b/lsy_drone_racing/control/ppo_controller.py index 4193e871..01d32299 100644 --- a/lsy_drone_racing/control/ppo_controller.py +++ b/lsy_drone_racing/control/ppo_controller.py @@ -17,13 +17,13 @@ import numpy as np from stable_baselines3 import PPO -from lsy_drone_racing.control import BaseController +from lsy_drone_racing.control import Controller if TYPE_CHECKING: from numpy.typing import NDArray -class PPOController(BaseController): +class PPOController(Controller): """Controller using a pre-trained PPO model.""" def __init__(self, initial_obs: dict[str, NDArray[np.floating]], initial_info: dict): diff --git a/lsy_drone_racing/control/trajectory_controller.py b/lsy_drone_racing/control/trajectory_controller.py index 4a9a7b87..949fbd05 100644 --- a/lsy_drone_racing/control/trajectory_controller.py +++ b/lsy_drone_racing/control/trajectory_controller.py @@ -16,13 +16,13 @@ import numpy as np from scipy.interpolate import CubicSpline -from lsy_drone_racing.control import BaseController +from lsy_drone_racing.control import Controller if TYPE_CHECKING: from numpy.typing import NDArray -class TrajectoryController(BaseController): +class TrajectoryController(Controller): """Controller that follows a pre-defined trajectory.""" def __init__(self, obs: dict[str, NDArray[np.floating]], info: dict, config: dict): diff --git a/lsy_drone_racing/utils/utils.py b/lsy_drone_racing/utils/utils.py index 0d717042..cbab2e5b 100644 --- a/lsy_drone_racing/utils/utils.py +++ b/lsy_drone_racing/utils/utils.py @@ -15,7 +15,7 @@ from jax.scipy.spatial.transform import Rotation as R from ml_collections import ConfigDict -from lsy_drone_racing.control.controller import BaseController +from lsy_drone_racing.control.controller import Controller if TYPE_CHECKING: from pathlib import Path @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) -def load_controller(path: Path) -> Type[BaseController]: +def load_controller(path: Path) -> Type[Controller]: """Load the controller module from the given path and return the Controller class. Args: @@ -45,17 +45,15 @@ def filter(mod: Any) -> bool: Args: mod: Any attribute of the controller module to be checked. """ - subcls = inspect.isclass(mod) and issubclass(mod, BaseController) + subcls = inspect.isclass(mod) and issubclass(mod, Controller) return subcls and mod.__module__ == controller_module.__name__ controllers = inspect.getmembers(controller_module, filter) - controllers = [c for _, c in controllers if issubclass(c, BaseController)] - assert len(controllers) > 0, ( - f"No controller found in {path}. Have you subclassed BaseController?" - ) + controllers = [c for _, c in controllers if issubclass(c, Controller)] + assert len(controllers) > 0, f"No controller found in {path}. Have you subclassed Controller?" assert len(controllers) == 1, f"Multiple controllers found in {path}. Only one is allowed." controller_module.Controller = controllers[0] - assert issubclass(controller_module.Controller, BaseController) + assert issubclass(controller_module.Controller, Controller) try: return controller_module.Controller diff --git a/scripts/multi_sim.py b/scripts/multi_sim.py index fd6ee186..71be6982 100644 --- a/scripts/multi_sim.py +++ b/scripts/multi_sim.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from ml_collections import ConfigDict - from lsy_drone_racing.control.controller import BaseController + from lsy_drone_racing.control.controller import Controller from lsy_drone_racing.envs.multi_drone_race import MultiDroneRacingEnv @@ -76,7 +76,7 @@ def simulate( for _ in range(n_runs): # Run n_runs episodes with the controller obs, info = env.reset() - controller: BaseController = controller_cls(obs, info, config) + controller: Controller = controller_cls(obs, info, config) i = 0 fps = 60 diff --git a/scripts/sim.py b/scripts/sim.py index 945a0df7..9919a9f1 100644 --- a/scripts/sim.py +++ b/scripts/sim.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from ml_collections import ConfigDict - from lsy_drone_racing.control.controller import BaseController + from lsy_drone_racing.control.controller import Controller from lsy_drone_racing.envs.drone_race import DroneRaceEnv @@ -73,7 +73,7 @@ def simulate( for _ in range(n_runs): # Run n_runs episodes with the controller done = False obs, info = env.reset() - controller: BaseController = controller_cls(obs, info, config) + controller: Controller = controller_cls(obs, info, config) i = 0 fps = 60 diff --git a/tests/unit/utils/test_utils.py b/tests/unit/utils/test_utils.py index 1a47fc81..f19347df 100644 --- a/tests/unit/utils/test_utils.py +++ b/tests/unit/utils/test_utils.py @@ -5,7 +5,7 @@ from ml_collections import ConfigDict from scipy.spatial.transform import Rotation as R -from lsy_drone_racing.control.controller import BaseController +from lsy_drone_racing.control.controller import Controller from lsy_drone_racing.utils import gate_passed, load_config, load_controller @@ -20,7 +20,7 @@ def test_load_controller(): c = load_controller( Path(__file__).parents[3] / "lsy_drone_racing/control/trajectory_controller.py" ) - assert issubclass(c, BaseController), f"Controller {c} is not a subclass of BaseController" + assert issubclass(c, Controller), f"Controller {c} is not a subclass of `Controller`" @pytest.mark.unit