Skip to content

Commit

Permalink
Add dynamics disturbances
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 13, 2025
1 parent 9348740 commit cac75f7
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions lsy_drone_racing/envs/drone_racing_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,15 @@ def reset(
Observation and info.
"""
if not self.config.env.random_resets:
self.np_random = np.random.default_rng(seed=self.config.env.seed)
self.sim.seed(self.config.env.seed)
if seed is not None:
self.np_random = np.random.default_rng(seed=self.config.env.seed)
self.sim.seed(seed)
# Randomization of gates, obstacles and drones is compiled into the sim reset function with
# the sim.reset_hook function, so we don't need to explicitly do it here
self.sim.reset()

# TODO: Add disturbances
self.target_gate = 0
self._steps = 0
self._last_drone_pos = self.sim.data.states.pos[0, 0]
Expand All @@ -199,9 +200,13 @@ def step(
action: Full-state command [x, y, z, vx, vy, vz, ax, ay, az, yaw, rrate, prate, yrate]
to follow.
"""
# TODO: Add action noise
assert action.shape == self.action_space.shape, f"Invalid action shape: {action.shape}"
self.sim.state_control(action.reshape((1, 1, 13)))
action = action.reshape((1, 1, 13))
if "action" in self.disturbances:
key, subkey = jax.random.split(self.sim.data.core.rng_key)
action += self.disturbances["action"](subkey, (1, 1, 13))
self.sim.data = self.sim.data.replace(core=self.sim.data.core.replace(rng_key=key))
self.sim.state_control(action)
self.sim.step(self.sim.freq // self.config.env.freq)
self.target_gate += self.gate_passed()
if self.target_gate == self.n_gates:
Expand Down Expand Up @@ -308,7 +313,6 @@ def load_track(self, track: dict) -> tuple[dict, dict, dict]:

def load_disturbances(self, disturbances: dict | None = None) -> dict:
"""Load the disturbances from the config."""
# TODO: Add jax disturbances for the simulator dynamics
if disturbances is None: # Default: no passive disturbances.
return {}
return {mode: self.load_random_fn(spec) for mode, spec in disturbances.items()}
Expand Down Expand Up @@ -344,6 +348,8 @@ def setup_sim(self):
states = self.sim.data.states.replace(pos=pos, quat=quat, vel=vel, rpy_rates=rpy_rates)
self.sim.data = self.sim.data.replace(states=states)
self.sim.reset_hook = build_reset_hook(self.randomization)
if "dynamics" in self.disturbances:
self.sim.disturbance_fn = build_dynamics_disturbance_fn(self.disturbances["dynamics"])
self.sim.build(mjx=False, data=False) # Save the reset state and rebuild the reset function

def _load_track_into_sim(self, gates: dict, obstacles: dict):
Expand Down Expand Up @@ -406,6 +412,20 @@ def reset_hook(data: SimData, mask: Array) -> SimData:
return reset_hook


def build_dynamics_disturbance_fn(
fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
) -> Callable[[SimData], SimData]:
"""Build the dynamics disturbance function for the simulation."""

def dynamics_disturbance(data: SimData) -> SimData:
key, subkey = jax.random.split(data.core.rng_key)
states = data.states
states = states.replace(force=states.force + fn(subkey, states.force.shape)) # World frame
return data.replace(states=states, core=data.core.replace(rng_key=key))

return dynamics_disturbance


class DroneRacingThrustEnv(DroneRacingEnv):
"""Drone racing environment with a collective thrust attitude command interface.
Expand Down Expand Up @@ -433,7 +453,10 @@ def step(
action: Thrust command [thrust, roll, pitch, yaw].
"""
assert action.shape == self.action_space.shape, f"Invalid action shape: {action.shape}"
# TODO: Add action noise
if "action" in self.disturbances:
key, subkey = jax.random.split(self.sim.data.core.rng_key)
action += self.disturbances["action"](subkey, (1, 1, 4))
self.sim.data = self.sim.data.replace(core=self.sim.data.core.replace(rng_key=key))
self.sim.attitude_control(action.reshape((1, 1, 4)).astype(np.float32))
self.sim.step(self.sim.freq // self.config.env.freq)
self.target_gate += self.gate_passed()
Expand Down

0 comments on commit cac75f7

Please sign in to comment.