Skip to content

Commit

Permalink
[wip, broken] Add race randomization
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 10, 2025
1 parent e877fcb commit 4c25828
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 8 deletions.
3 changes: 3 additions & 0 deletions lsy_drone_racing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""LSY drone racing package for the Autonomous Drone Racing class @ TUM."""
from crazyflow.utils import enable_cache

import lsy_drone_racing.envs # noqa: F401, register environments with gymnasium

enable_cache() # Enable persistent caching of jax functions
78 changes: 70 additions & 8 deletions lsy_drone_racing/envs/drone_racing_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,28 @@
import copy as copy
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable

import gymnasium
import mujoco
import numpy as np
from crazyflow import Sim
from crazyflow.sim.sim import identity
from gymnasium import spaces
from scipy.spatial.transform import Rotation as R

from lsy_drone_racing.envs.utils import (
randomize_drone_inertia_fn,
randomize_drone_mass_fn,
randomize_drone_pos_fn,
randomize_drone_quat_fn,
)
from lsy_drone_racing.sim.noise import NoiseList
from lsy_drone_racing.utils import check_gate_pass

if TYPE_CHECKING:
from crazyflow.sim.structs import SimData
from jax import Array
from numpy.typing import NDArray

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -106,6 +115,7 @@ def __init__(self, config: dict):
)
if config.sim.sim_freq % config.env.freq != 0:
raise ValueError(f"({config.sim.sim_freq=}) is no multiple of ({config.env.freq=})")

self.action_space = spaces.Box(low=-1, high=1, shape=(13,))
n_gates, n_obstacles = len(config.env.track.gates), len(config.env.track.obstacles)
self.observation_space = spaces.Dict(
Expand Down Expand Up @@ -134,16 +144,20 @@ def __init__(self, config: dict):
"ang_vel": spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64),
}
)

self.target_gate = 0
self.symbolic = self.sim.symbolic() if config.env.symbolic else None
self._steps = 0
self._last_drone_pos = np.zeros(3)
self.gates, self.obstacles, self.drone = self.load_track(config.env.track)
self.n_gates = len(config.env.track.gates)
self.disturbances = self.load_disturbances(config.env.get("disturbances", None))
self.randomization = self.load_randomizations(config.env.get("randomization", None))
self.contact_mask = np.ones((self.sim.n_worlds, 29), dtype=bool)
self.contact_mask[..., 0] = 0 # Ignore contacts with the floor

self.setup_sim()

self.gates_visited = np.array([False] * len(config.env.track.gates))
self.obstacles_visited = np.array([False] * len(config.env.track.obstacles))

Expand All @@ -167,13 +181,6 @@ def reset(
# the sim.reset_hook function, so we don't need to explicitly do it here
self.sim.reset()
# TODO: Add randomization of gates, obstacles, drone, and disturbances
states = self.sim.data.states.replace(
pos=self.drone["pos"].reshape((1, 1, 3)),
quat=self.drone["quat"].reshape((1, 1, 4)),
vel=self.drone["vel"].reshape((1, 1, 3)),
rpy_rates=self.drone["rpy_rates"].reshape((1, 1, 3)),
)
self.sim.data = self.sim.data.replace(states=states)
self.target_gate = 0
self._steps = 0
self._last_drone_pos[:] = self.sim.data.states.pos[0, 0]
Expand Down Expand Up @@ -335,6 +342,24 @@ def load_disturbances(self, disturbances: dict | None = None) -> dict:
dist[mode] = NoiseList.from_specs([spec])
return dist

def load_randomizations(self, randomizations: dict | None = None) -> dict:
"""Load the randomization from the config."""
if randomizations is None:
return {}
return {}

def setup_sim(self):
"""Setup the simulation data and build the reset and step functions with custom hooks."""
pos = self.drone["pos"].reshape(self.sim.data.states.pos.shape)
quat = self.drone["quat"].reshape(self.sim.data.states.quat.shape)
vel = self.drone["vel"].reshape(self.sim.data.states.vel.shape)
rpy_rates = self.drone["rpy_rates"].reshape(self.sim.data.states.rpy_rates.shape)
states = self.sim.data.states.replace(pos=pos, quat=quat, vel=vel, rpy_rates=rpy_rates)
self.sim.data = self.sim.data.replace(states=states)
reset_hook = build_reset_hook(self.randomization)
self.sim.reset_hook = reset_hook
self.sim.build(mjx=False, data=False) # Save the reset state and rebuild the reset function

def gate_passed(self) -> bool:
"""Check if the drone has passed a gate.
Expand All @@ -355,6 +380,43 @@ def close(self):
self.sim.close()


def build_reset_hook(randomizations: dict) -> Callable[[SimData, Array[bool]], SimData]:
"""Build the reset hook for the simulation."""
modify_drone_pos = identity
if "drone_pos" in randomizations:
modify_drone_pos = randomize_drone_pos_fn(randomizations["drone_pos"])
modify_drone_quat = identity
if "drone_rpy" in randomizations:
modify_drone_quat = randomize_drone_quat_fn(randomizations["drone_rpy"])
modify_drone_mass = identity
if "drone_mass" in randomizations:
modify_drone_mass = randomize_drone_mass_fn(randomizations["drone_mass"])
modify_drone_inertia = identity
if "drone_inertia" in randomizations:
modify_drone_inertia = randomize_drone_inertia_fn(randomizations["drone_inertia"])
modify_gate_pos = identity
if "gate_pos" in randomizations:
modify_gate_pos = randomize_gate_pos_fn(randomizations["gate_pos"])

Check failure on line 399 in lsy_drone_racing/envs/drone_racing_env.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

lsy_drone_racing/envs/drone_racing_env.py:399:27: F821 Undefined name `randomize_gate_pos_fn`
modify_gate_rpy = identity
if "gate_rpy" in randomizations:
modify_gate_rpy = randomize_gate_rpy_fn(randomizations["gate_rpy"])

Check failure on line 402 in lsy_drone_racing/envs/drone_racing_env.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

lsy_drone_racing/envs/drone_racing_env.py:402:27: F821 Undefined name `randomize_gate_rpy_fn`
modify_obstacle_pos = identity
if "obstacle_pos" in randomizations:
modify_obstacle_pos = randomize_obstacle_pos_fn(randomizations["obstacle_pos"])

Check failure on line 405 in lsy_drone_racing/envs/drone_racing_env.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

lsy_drone_racing/envs/drone_racing_env.py:405:31: F821 Undefined name `randomize_obstacle_pos_fn`

def reset_hook(data: SimData, mask: Array[bool]) -> SimData:
data = modify_drone_pos(data, mask)
data = modify_drone_quat(data, mask)
data = modify_drone_mass(data, mask)
data = modify_drone_inertia(data, mask)
data = modify_gate_pos(data, mask)
data = modify_gate_rpy(data, mask)
data = modify_obstacle_pos(data, mask)
return data

return reset_hook


class DroneRacingThrustEnv(DroneRacingEnv):
"""Drone racing environment with a collective thrust attitude command interface.
Expand Down
66 changes: 66 additions & 0 deletions lsy_drone_racing/envs/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import Callable

Check failure on line 1 in lsy_drone_racing/envs/utils.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D100)

lsy_drone_racing/envs/utils.py:1:1: D100 Missing docstring in public module

import jax
import jax.numpy as jp
from crazyflow.sim.structs import SimData
from crazyflow.utils import leaf_replace
from jax import Array
from jax.scipy.spatial.transform import Rotation as R


def randomize_drone_pos_fn(
rng: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
) -> Callable[[SimData, Array], SimData]:
"""Create a function that randomizes the drone position."""

def randomize_drone_pos(data: SimData, mask: Array) -> SimData:
key, subkey = jax.random.split(data.core.rng_key)
drone_pos = data.states.pos + rng(subkey, shape=data.states.pos.shape)
states = leaf_replace(data.states, mask, pos=drone_pos)
return data.replace(core=data.core.replace(rng_key=key), states=states)

return randomize_drone_pos


def randomize_drone_quat_fn(
rng: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
) -> Callable[[SimData, Array], SimData]:
"""Create a function that randomizes the drone quaternion."""

def randomize_drone_quat(data: SimData, mask: Array) -> SimData:
key, subkey = jax.random.split(data.core.rng_key)
rpy = R.from_quat(data.states.quat).as_euler("xyz")
quat = R.from_euler("xyz", rpy + rng(subkey, shape=rpy.shape)).as_quat()
states = leaf_replace(data.states, mask, quat=quat)
return data.replace(core=data.core.replace(rng_key=key), states=states)

return randomize_drone_quat


def randomize_drone_mass_fn(
rng: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
) -> Callable[[SimData, Array], SimData]:
"""Create a function that randomizes the drone mass."""

def randomize_drone_mass(data: SimData, mask: Array) -> SimData:
key, subkey = jax.random.split(data.core.rng_key)
mass = data.states.mass + rng(subkey, shape=data.params.mass.shape)
states = leaf_replace(data.states, mask, mass=mass)
return data.replace(core=data.core.replace(rng_key=key), states=states)

return randomize_drone_mass


def randomize_drone_inertia_fn(
rng: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
) -> Callable[[SimData, Array], SimData]:
"""Create a function that randomizes the drone inertia."""

def randomize_drone_inertia(data: SimData, mask: Array) -> SimData:
key, subkey = jax.random.split(data.core.rng_key)
J = data.params.J + rng(subkey, shape=data.params.J.shape)
J_inv = jp.linalg.inv(J)
states = leaf_replace(data.states, mask, J=J, J_inv=J_inv)
return data.replace(core=data.core.replace(rng_key=key), states=states)

return randomize_drone_inertia

0 comments on commit 4c25828

Please sign in to comment.