Skip to content

Commit

Permalink
[wip,broken] Improve performance of vectorized core env
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 19, 2025
1 parent 4ee7f35 commit 1c845e8
Show file tree
Hide file tree
Showing 6 changed files with 765 additions and 15 deletions.
8 changes: 5 additions & 3 deletions lsy_drone_racing/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class BaseController(ABC):
"""Base class for controller implementations."""

def __init__(self, initial_obs: dict[str, NDArray[np.floating]], initial_info: dict):
def __init__(self, obs: dict[str, NDArray[np.floating]], info: dict, config: dict):
"""Initialization of the controller.
Instructions:
Expand All @@ -36,9 +36,11 @@ def __init__(self, initial_obs: dict[str, NDArray[np.floating]], initial_info: d
constants, counters, pre-plan trajectories, etc.
Args:
initial_obs: The initial observation of the environment's state. See the environment's
obs: The initial observation of the environment's state. See the environment's
observation space for details.
initial_info: Additional environment information from the reset.
info: The initial environment information from the reset.
config: The race configuration. See the config files for details. Contains additional
information such as disturbance configurations, randomizations, etc.
"""

@abstractmethod
Expand Down
12 changes: 7 additions & 5 deletions lsy_drone_racing/control/trajectory_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,17 @@
class TrajectoryController(BaseController):
"""Controller that follows a pre-defined trajectory."""

def __init__(self, initial_obs: dict[str, NDArray[np.floating]], initial_info: dict):
def __init__(self, obs: dict[str, NDArray[np.floating]], info: dict, config: dict):
"""Initialization of the controller.
Args:
initial_obs: The initial observation of the environment's state. See the environment's
obs: The initial observation of the environment's state. See the environment's
observation space for details.
initial_info: Additional environment information from the reset.
info: The initial environment information from the reset.
config: The race configuration. See the config files for details. Contains additional
information such as disturbance configurations, randomizations, etc.
"""
super().__init__(initial_obs, initial_info)
super().__init__(obs, info, config)
waypoints = np.array(
[
[1.0, 1.0, 0.05],
Expand All @@ -52,7 +54,7 @@ def __init__(self, initial_obs: dict[str, NDArray[np.floating]], initial_info: d
t = np.linspace(0, self.t_total, len(waypoints))
self.trajectory = CubicSpline(t, waypoints)
self._tick = 0
self._freq = initial_info["env_freq"]
self._freq = config.env.freq

def compute_control(
self, obs: dict[str, NDArray[np.floating]], info: dict | None = None
Expand Down
2 changes: 1 addition & 1 deletion lsy_drone_racing/envs/vec_drone_race.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(
# Env settings
self.freq = freq
self.seed = seed
self.autoreset = True # Can be overridden by subclasses
self.autoreset = False # Can be overridden by subclasses
self.device = jax.devices(device)[0]
self.symbolic = symbolic_attitude(1 / self.freq)
self.random_resets = random_resets
Expand Down
Loading

0 comments on commit 1c845e8

Please sign in to comment.