Skip to content

Commit

Permalink
Rename BaseController to Controller
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Feb 13, 2025
1 parent 37a7b96 commit 5a4beb6
Show file tree
Hide file tree
Showing 11 changed files with 27 additions and 29 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,6 @@ In a second terminal:
python scripts/deploy.py --controller <your_controller.py> --config level3.toml
```

where `<your_controller.py>` implements a controller that inherits from `lsy_drone_racing.control.BaseController`
where `<your_controller.py>` implements a controller that inherits from `lsy_drone_racing.control.Controller`


4 changes: 2 additions & 2 deletions docs/getting_started/general.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ Welcome to the LSY Drone Racing Project! This is a platform developed by the LSY

Implementing Your Own Algorithms
--------------------------------
To implement your own controller, you need to implement a ``Controller`` class in the :mod:`lsy_drone_racing.control` module. The only restriction we place on controllers is that they have to implement the interface defined by the :class:`BaseController <lsy_drone_racing.control.controller.BaseController>` class. Apart from that, you are encouraged to use the full spectrum of control algorithms, e.g., MPC, trajectory optimization, reinforcement learning, etc., to compete in the challenge. Please make sure to put your controller implementation in the :mod:`lsy_drone_racing.control` module to make sure that it is correctly recognized by our scripts.
To implement your own controller, you need to implement a ``Controller`` class in the :mod:`lsy_drone_racing.control` module. The only restriction we place on controllers is that they have to implement the interface defined by the :class:`Controller <lsy_drone_racing.control.controller.Controller>` class. Apart from that, you are encouraged to use the full spectrum of control algorithms, e.g., MPC, trajectory optimization, reinforcement learning, etc., to compete in the challenge. Please make sure to put your controller implementation in the :mod:`lsy_drone_racing.control` module to make sure that it is correctly recognized by our scripts.

.. note::
Make sure to inherit from the base class for your controller implementation. This ensures that your controller is compatible with our scripts. Also make sure to only create one controller class per file. Otherwise, we do not know which controller to load from the file.

.. warning::
You are not allowed to modify the interface of the :class:`BaseController <lsy_drone_racing.control.controller.BaseController>` class. Doing so will make your controller incompatible with the deployment environment and we won't be able to run your controller on our setup.
You are not allowed to modify the interface of the :class:`Controller <lsy_drone_racing.control.controller.Controller>` class. Doing so will make your controller incompatible with the deployment environment and we won't be able to run your controller on our setup.

.. warning::
Many students are enthusiastic about deep reinforcement learning and try to use it to solve the challenge. While you are completely free in choosing your control algorithm, we know from experience that training good agents is non-trivial, requires significant compute, and can be difficult to transfer into the real world setup. Students taking this approach should make sure they already have some experience with RL, and take their policies to the real world setup early to address potential sim2real issues.
Expand Down
6 changes: 3 additions & 3 deletions lsy_drone_racing/control/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
To give you an idea of what you need to do, we also include some example implementations:
* :class:`~.BaseController`: The abstract base class defining the interface for all controllers.
* :class:`~.Controller`: The abstract base class defining the interface for all controllers.
* :class:`PPOController <lsy_drone_racing.control.ppo_controller.PPOController>`: An example
implementation using a pre-trained Proximal Policy Optimization (PPO) model.
* :class:`PPOController <lsy_drone_racing.control.trajectory_controller.TrajectoryController>`: A
controller that follows a pre-defined trajectory using cubic spline interpolation.
"""

from lsy_drone_racing.control.controller import BaseController
from lsy_drone_racing.control.controller import Controller

__all__ = ["BaseController"]
__all__ = ["Controller"]
4 changes: 2 additions & 2 deletions lsy_drone_racing/control/attitude_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
from scipy.interpolate import CubicSpline
from scipy.spatial.transform import Rotation as R

from lsy_drone_racing.control import BaseController
from lsy_drone_racing.control import Controller

if TYPE_CHECKING:
from numpy.typing import NDArray


class AttitudeController(BaseController):
class AttitudeController(Controller):
"""Example of a controller using the collective thrust and attitude interface.
Modified from https://github.com/utiasDSL/crazyswarm-import/blob/ad2f7ea987f458a504248a1754b124ba39fc2f21/ros_ws/src/crazyswarm/scripts/position_ctl_m.py
Expand Down
6 changes: 3 additions & 3 deletions lsy_drone_racing/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from that, you are free to add any additional methods, attributes, or classes to your controller.
As an example, you could load the weights of a neural network in the constructor and use it to
compute the control commands in the :meth:`compute_control <.BaseController.compute_control>`
method. You could also use the :meth:`step_callback <.BaseController.step_callback>` method to
compute the control commands in the :meth:`compute_control <.Controller.compute_control>`
method. You could also use the :meth:`step_callback <.Controller.step_callback>` method to
update the controller state at runtime.
Note:
Expand All @@ -24,7 +24,7 @@
from numpy.typing import NDArray


class BaseController(ABC):
class Controller(ABC):
"""Base class for controller implementations."""

def __init__(self, obs: dict[str, NDArray[np.floating]], info: dict, config: dict):
Expand Down
4 changes: 2 additions & 2 deletions lsy_drone_racing/control/ppo_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
import numpy as np
from stable_baselines3 import PPO

from lsy_drone_racing.control import BaseController
from lsy_drone_racing.control import Controller

if TYPE_CHECKING:
from numpy.typing import NDArray


class PPOController(BaseController):
class PPOController(Controller):
"""Controller using a pre-trained PPO model."""

def __init__(self, initial_obs: dict[str, NDArray[np.floating]], initial_info: dict):
Expand Down
4 changes: 2 additions & 2 deletions lsy_drone_racing/control/trajectory_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
import numpy as np
from scipy.interpolate import CubicSpline

from lsy_drone_racing.control import BaseController
from lsy_drone_racing.control import Controller

if TYPE_CHECKING:
from numpy.typing import NDArray


class TrajectoryController(BaseController):
class TrajectoryController(Controller):
"""Controller that follows a pre-defined trajectory."""

def __init__(self, obs: dict[str, NDArray[np.floating]], info: dict, config: dict):
Expand Down
14 changes: 6 additions & 8 deletions lsy_drone_racing/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from jax.scipy.spatial.transform import Rotation as R
from ml_collections import ConfigDict

from lsy_drone_racing.control.controller import BaseController
from lsy_drone_racing.control.controller import Controller

if TYPE_CHECKING:
from pathlib import Path
Expand All @@ -26,7 +26,7 @@
logger = logging.getLogger(__name__)


def load_controller(path: Path) -> Type[BaseController]:
def load_controller(path: Path) -> Type[Controller]:
"""Load the controller module from the given path and return the Controller class.
Args:
Expand All @@ -45,17 +45,15 @@ def filter(mod: Any) -> bool:
Args:
mod: Any attribute of the controller module to be checked.
"""
subcls = inspect.isclass(mod) and issubclass(mod, BaseController)
subcls = inspect.isclass(mod) and issubclass(mod, Controller)
return subcls and mod.__module__ == controller_module.__name__

controllers = inspect.getmembers(controller_module, filter)
controllers = [c for _, c in controllers if issubclass(c, BaseController)]
assert len(controllers) > 0, (
f"No controller found in {path}. Have you subclassed BaseController?"
)
controllers = [c for _, c in controllers if issubclass(c, Controller)]
assert len(controllers) > 0, f"No controller found in {path}. Have you subclassed Controller?"
assert len(controllers) == 1, f"Multiple controllers found in {path}. Only one is allowed."
controller_module.Controller = controllers[0]
assert issubclass(controller_module.Controller, BaseController)
assert issubclass(controller_module.Controller, Controller)

try:
return controller_module.Controller
Expand Down
4 changes: 2 additions & 2 deletions scripts/multi_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
if TYPE_CHECKING:
from ml_collections import ConfigDict

from lsy_drone_racing.control.controller import BaseController
from lsy_drone_racing.control.controller import Controller
from lsy_drone_racing.envs.multi_drone_race import MultiDroneRacingEnv


Expand Down Expand Up @@ -76,7 +76,7 @@ def simulate(

for _ in range(n_runs): # Run n_runs episodes with the controller
obs, info = env.reset()
controller: BaseController = controller_cls(obs, info, config)
controller: Controller = controller_cls(obs, info, config)
i = 0
fps = 60

Expand Down
4 changes: 2 additions & 2 deletions scripts/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
if TYPE_CHECKING:
from ml_collections import ConfigDict

from lsy_drone_racing.control.controller import BaseController
from lsy_drone_racing.control.controller import Controller
from lsy_drone_racing.envs.drone_race import DroneRaceEnv


Expand Down Expand Up @@ -73,7 +73,7 @@ def simulate(
for _ in range(n_runs): # Run n_runs episodes with the controller
done = False
obs, info = env.reset()
controller: BaseController = controller_cls(obs, info, config)
controller: Controller = controller_cls(obs, info, config)
i = 0
fps = 60

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ml_collections import ConfigDict
from scipy.spatial.transform import Rotation as R

from lsy_drone_racing.control.controller import BaseController
from lsy_drone_racing.control.controller import Controller
from lsy_drone_racing.utils import gate_passed, load_config, load_controller


Expand All @@ -20,7 +20,7 @@ def test_load_controller():
c = load_controller(
Path(__file__).parents[3] / "lsy_drone_racing/control/trajectory_controller.py"
)
assert issubclass(c, BaseController), f"Controller {c} is not a subclass of BaseController"
assert issubclass(c, Controller), f"Controller {c} is not a subclass of `Controller`"


@pytest.mark.unit
Expand Down

0 comments on commit 5a4beb6

Please sign in to comment.