Skip to content

Commit

Permalink
Remove old code. Improve tests. Improve gate_passed function
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 22, 2025
1 parent 6eb3cf5 commit 0df29e2
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 165 deletions.
75 changes: 27 additions & 48 deletions lsy_drone_racing/envs/race_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
randomize_gate_rpy_fn,
randomize_obstacle_pos_fn,
)
from lsy_drone_racing.utils.utils import gate_passed

if TYPE_CHECKING:
from crazyflow.sim.structs import SimData
Expand Down Expand Up @@ -457,15 +458,15 @@ def _step_env(
n_gates = len(data.gate_mj_ids)
disabled_drones = RaceCoreEnv._disabled_drones(drone_pos, drone_quat, contacts, data)
gates_pos = mocap_pos[:, data.gate_mj_ids]
obstacles_pos = mocap_pos[:, data.obstacle_mj_ids]
# We need to convert the mocap quat from MuJoCo order to scipy order
gates_quat = mocap_quat[:, data.gate_mj_ids][..., [3, 0, 1, 2]]
obstacles_pos = mocap_pos[:, data.obstacle_mj_ids]
# Extract the gate poses of the current target gates and check if the drones have passed
# them between the last and current position
gate_ids = data.gate_mj_ids[data.target_gate % n_gates]
gate_pos = gates_pos[jp.arange(gates_pos.shape[0])[:, None], gate_ids]
gate_quat = gates_quat[jp.arange(gates_quat.shape[0])[:, None], gate_ids]
passed = RaceCoreEnv._gate_passed(drone_pos, data.last_drone_pos, gate_pos, gate_quat)
passed = gate_passed(drone_pos, data.last_drone_pos, gate_pos, gate_quat, (0.45, 0.45))
# Update the target gate index. Increment by one if drones have passed a gate
target_gate = data.target_gate + passed * ~disabled_drones
target_gate = jp.where(target_gate >= n_gates, -1, target_gate)
Expand Down Expand Up @@ -511,6 +512,30 @@ def _obs(
obstacles_pos = jp.where(mask, real_pos[:, None], nominal_obstacle_pos[None, None])
return gates_pos, gates_rpy, obstacles_pos

@staticmethod
def _disabled_drones(pos: Array, quat: Array, contacts: Array, data: EnvData) -> Array:
rpy = JaxR.from_quat(quat).as_euler("xyz")
disabled = jp.logical_or(data.disabled_drones, jp.all(pos < data.pos_limit_low, axis=-1))
disabled = jp.logical_or(disabled, jp.all(pos > data.pos_limit_high, axis=-1))
disabled = jp.logical_or(disabled, jp.all(rpy < data.rpy_limit_low, axis=-1))
disabled = jp.logical_or(disabled, jp.all(rpy > data.rpy_limit_high, axis=-1))
disabled = jp.logical_or(disabled, data.target_gate == -1)
contacts = jp.any(jp.logical_and(contacts[:, None, :], data.contact_masks), axis=-1)
disabled = jp.logical_or(disabled, contacts)
return disabled

@staticmethod
def _visited(drone_pos: Array, target_pos: Array, sensor_range: float, visited: Array) -> Array:
dpos = drone_pos[..., None, :2] - target_pos[:, None, :, :2]
return jp.logical_or(visited, jp.linalg.norm(dpos, axis=-1) < sensor_range)

@staticmethod
@jax.jit
def _warp_disabled_drones(data: SimData, mask: Array) -> SimData:
"""Warp the disabled drones below the ground."""
pos = jax.numpy.where(mask[..., None], -1, data.states.pos)
return data.replace(states=data.states.replace(pos=pos))

def _load_track(self, track: dict) -> tuple[dict, dict, dict]:
"""Load the track from the config file."""
gate_pos = np.array([g["pos"] for g in track.gates])
Expand Down Expand Up @@ -593,52 +618,6 @@ def _load_contact_masks(self, sim: Sim) -> Array:
masks = np.tile(masks[None, ...], (sim.n_worlds, 1, 1))
return masks

@staticmethod
def _disabled_drones(pos: Array, quat: Array, contacts: Array, data: EnvData) -> Array:
rpy = JaxR.from_quat(quat).as_euler("xyz")
disabled = jp.logical_or(data.disabled_drones, jp.all(pos < data.pos_limit_low, axis=-1))
disabled = jp.logical_or(disabled, jp.all(pos > data.pos_limit_high, axis=-1))
disabled = jp.logical_or(disabled, jp.all(rpy < data.rpy_limit_low, axis=-1))
disabled = jp.logical_or(disabled, jp.all(rpy > data.rpy_limit_high, axis=-1))
disabled = jp.logical_or(disabled, data.target_gate == -1)
contacts = jp.any(jp.logical_and(contacts[:, None, :], data.contact_masks), axis=-1)
disabled = jp.logical_or(disabled, contacts)
return disabled

@staticmethod
def _gate_passed(
drone_pos: Array, last_drone_pos: Array, gate_pos: Array, gate_quat: Array
) -> bool:
"""Check if the drone has passed a gate.
Returns:
True if the drone has passed a gate, else False.
"""
gate_rot = JaxR.from_quat(gate_quat)
gate_size = (0.45, 0.45)
last_pos_local = gate_rot.apply(last_drone_pos - gate_pos, inverse=True)
pos_local = gate_rot.apply(drone_pos - gate_pos, inverse=True)
# Check if the line between the last position and the current position intersects the plane.
# If so, calculate the point of the intersection and check if it is within the gate box.
passed_plane = (last_pos_local[..., 1] < 0) & (pos_local[..., 1] > 0)
alpha = -last_pos_local[..., 1] / (pos_local[..., 1] - last_pos_local[..., 1])
x_intersect = alpha * (pos_local[..., 0]) + (1 - alpha) * last_pos_local[..., 0]
z_intersect = alpha * (pos_local[..., 2]) + (1 - alpha) * last_pos_local[..., 2]
in_box = (abs(x_intersect) < gate_size[0] / 2) & (abs(z_intersect) < gate_size[1] / 2)
return passed_plane & in_box

@staticmethod
def _visited(drone_pos: Array, target_pos: Array, sensor_range: float, visited: Array) -> Array:
dpos = drone_pos[..., None, :2] - target_pos[:, None, :, :2]
return jp.logical_or(visited, jp.linalg.norm(dpos, axis=-1) < sensor_range)

@staticmethod
@jax.jit
def _warp_disabled_drones(data: SimData, mask: Array) -> SimData:
"""Warp the disabled drones below the ground."""
pos = jax.numpy.where(mask[..., None], -1, data.states.pos)
return data.replace(states=data.states.replace(pos=pos))


# region Factories

Expand Down
4 changes: 2 additions & 2 deletions lsy_drone_racing/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
dependency for sim-only scripts.
"""

from lsy_drone_racing.utils.utils import check_gate_pass, load_config, load_controller, map2pi
from lsy_drone_racing.utils.utils import gate_passed, load_config, load_controller

__all__ = ["load_config", "load_controller", "check_gate_pass", "map2pi"]
__all__ = ["load_config", "load_controller", "gate_passed"]
57 changes: 24 additions & 33 deletions lsy_drone_racing/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,26 @@
import inspect
import logging
import sys
from functools import partial
from typing import TYPE_CHECKING, Type

import numpy as np
import jax
import toml
from jax.numpy import vectorize
from jax.scipy.spatial.transform import Rotation as R
from ml_collections import ConfigDict
from scipy.spatial.transform import Rotation as R

from lsy_drone_racing.control.controller import BaseController

if TYPE_CHECKING:
from pathlib import Path
from typing import Any

from numpy.typing import NDArray
from jax import Array

logger = logging.getLogger(__name__)


def map2pi(angle: NDArray[np.floating]) -> NDArray[np.floating]:
"""Map an angle or array of angles to the interval of [-pi, pi].
Args:
angle: Number or array of numbers.
Returns:
The remapped angles.
"""
return ((angle + np.pi) % (2 * np.pi)) - np.pi


def load_controller(path: Path) -> Type[BaseController]:
"""Load the controller module from the given path and return the Controller class.
Expand Down Expand Up @@ -89,12 +79,14 @@ def load_config(path: Path) -> ConfigDict:
return ConfigDict(toml.load(f))


def check_gate_pass(
gate_pos: np.ndarray,
gate_rot: R,
gate_size: np.ndarray,
drone_pos: np.ndarray,
last_drone_pos: np.ndarray,
@jax.jit
@partial(vectorize, signature="(3),(3),(3),(4)->()", excluded=[4])
def gate_passed(
drone_pos: Array,
last_drone_pos: Array,
gate_pos: Array,
gate_quat: Array,
gate_size: tuple[float, float],
) -> bool:
"""Check if the drone has passed the current gate.
Expand All @@ -110,23 +102,22 @@ def check_gate_pass(
goal changes.
Args:
gate_pos: The position of the gate in the world frame.
gate_rot: The rotation of the gate.
gate_size: The size of the gate box in meters.
drone_pos: The position of the drone in the world frame.
last_drone_pos: The position of the drone in the world frame at the last time step.
gate_pos: The position of the gate in the world frame.
gate_quat: The rotation of the gate as a wxyz quaternion.
gate_size: The size of the gate box in meters.
"""
# Transform last and current drone position into current gate frame.
assert isinstance(gate_rot, R), "gate_rot has to be a Rotation object."
gate_rot = R.from_quat(gate_quat)
last_pos_local = gate_rot.apply(last_drone_pos - gate_pos, inverse=True)
pos_local = gate_rot.apply(drone_pos - gate_pos, inverse=True)
# Check the plane intersection. If passed, calculate the point of the intersection and check if
# it is within the gate box.
if last_pos_local[1] < 0 and pos_local[1] > 0: # Drone has passed the goal plane
alpha = -last_pos_local[1] / (pos_local[1] - last_pos_local[1])
x_intersect = alpha * (pos_local[0]) + (1 - alpha) * last_pos_local[0]
z_intersect = alpha * (pos_local[2]) + (1 - alpha) * last_pos_local[2]
# Divide gate size by 2 to get the distance from the center to the edges
if abs(x_intersect) < gate_size[0] / 2 and abs(z_intersect) < gate_size[1] / 2:
return True
return False
passed_plane = (last_pos_local[1] < 0) & (pos_local[1] > 0)
alpha = -last_pos_local[1] / (pos_local[1] - last_pos_local[1])
x_intersect = alpha * (pos_local[0]) + (1 - alpha) * last_pos_local[0]
z_intersect = alpha * (pos_local[2]) + (1 - alpha) * last_pos_local[2]
# Divide gate size by 2 to get the distance from the center to the edges
in_box = (abs(x_intersect) < gate_size[0] / 2) & (abs(z_intersect) < gate_size[1] / 2)
return passed_plane & in_box
77 changes: 30 additions & 47 deletions tests/integration/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,65 +6,48 @@
import lsy_drone_racing # noqa: F401, environment registrations
from lsy_drone_racing.utils import load_config

CONFIG_FILES = ["level0.toml", "level1.toml", "level2.toml", "level3.toml"]
MULTI_CONFIG_FILES = ["multi_level0.toml", "multi_level3.toml"]
CONFIG_FILES = {
"DroneRacing-v0": ["level0.toml", "level1.toml", "level2.toml", "level3.toml"],
"MultiDroneRacing-v0": ["multi_level0.toml", "multi_level3.toml"],
}
ENV_IDS = ["DroneRacing-v0", "MultiDroneRacing-v0"]


@pytest.mark.parametrize("physics", ["analytical", "sys_id"])
@pytest.mark.parametrize("config_file", CONFIG_FILES)
@pytest.mark.parametrize(
("env_id", "config_file"),
[(env_id, config_file) for env_id in ENV_IDS for config_file in CONFIG_FILES[env_id]],
)
@pytest.mark.integration
def test_envs(physics: str, config_file: str):
def test_single_drone_envs(env_id: str, config_file: str, physics: str):
"""Test the simulation environments with different physics modes and config files."""
config = load_config(Path(__file__).parents[2] / "config" / config_file)
assert hasattr(config.sim, "physics"), "Physics mode is not set"
config.sim.physics = physics # override physics mode
assert hasattr(config.env, "id"), "Environment ID is not set"
config.env.id = "DroneRacing-v0" # override environment ID

env = gymnasium.make(
"DroneRacing-v0",
freq=config.env.freq,
sim_config=config.sim,
sensor_range=config.env.sensor_range,
track=config.env.track,
disturbances=config.env.get("disturbances"),
randomizations=config.env.get("randomizations"),
random_resets=config.env.random_resets,
seed=config.env.seed,
)
kwargs = {
"freq": config.env.freq,
"sim_config": config.sim,
"sensor_range": config.env.sensor_range,
"track": config.env.track,
"disturbances": config.env.get("disturbances"),
"randomizations": config.env.get("randomizations"),
"random_resets": config.env.random_resets,
"seed": config.env.seed,
}
if "n_drones" in config.env:
kwargs["n_drones"] = config.env.n_drones

env = gymnasium.make(env_id, **kwargs)
env.reset()
for _ in range(10):
_, _, terminated, truncated, _ = env.step(env.action_space.sample())
if terminated or truncated:
break
for _ in range(100):
_, _, _, _, _ = env.step(env.action_space.sample())
env.close()


@pytest.mark.parametrize("physics", ["analytical", "sys_id"])
@pytest.mark.parametrize("config_file", MULTI_CONFIG_FILES)
@pytest.mark.integration
def test_vec_envs(physics: str, config_file: str):
"""Test the simulation environments with different physics modes and config files."""
config = load_config(Path(__file__).parents[2] / "config" / config_file)
assert hasattr(config.sim, "physics"), "Physics mode is not set"
config.sim.physics = physics # override physics mode
assert hasattr(config.env, "id"), "Environment ID is not set"
config.env.id = "MultiDroneRacing-v0" # override environment ID

env = gymnasium.make_vec(
"MultiDroneRacing-v0",
num_envs=2,
n_drones=config.env.n_drones,
freq=config.env.freq,
sim_config=config.sim,
sensor_range=config.env.sensor_range,
track=config.env.track,
disturbances=config.env.get("disturbances"),
randomizations=config.env.get("randomizations"),
random_resets=config.env.random_resets,
seed=config.env.seed,
)
kwargs["num_envs"] = 2
env = gymnasium.make_vec(env_id, **kwargs)
env.reset()
for _ in range(10):
env.step(env.action_space.sample())
for _ in range(100):
_, _, _, _, _ = env.step(env.action_space.sample())
env.close()
52 changes: 17 additions & 35 deletions tests/unit/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from scipy.spatial.transform import Rotation as R

from lsy_drone_racing.control.controller import BaseController
from lsy_drone_racing.utils import check_gate_pass, load_config, load_controller, map2pi
from lsy_drone_racing.utils import gate_passed, load_config, load_controller


@pytest.mark.unit
Expand All @@ -24,48 +24,30 @@ def test_load_controller():


@pytest.mark.unit
def test_map2pi():
assert map2pi(0) == 0
assert map2pi(np.pi) == -np.pi
assert map2pi(-np.pi) == -np.pi
assert map2pi(2 * np.pi) == 0
assert map2pi(-2 * np.pi) == 0
assert np.allclose(map2pi(np.arange(10) * 2 * np.pi), np.zeros(10))
assert np.max(map2pi(np.linspace(-100, 100, num=1000))) <= np.pi
assert np.min(map2pi(np.linspace(-100, 100, num=1000))) >= -np.pi


@pytest.mark.unit
def test_check_gate_pass():
def test_gate_pass():
# TODO: Check accelerated function in RaceCore instead
gate_pos = np.array([0, 0, 0])
gate_rot = R.from_euler("xyz", [0, 0, 0])
gate_quat = R.identity().as_quat()
gate_size = np.array([1, 1])
# Test passing through the gate
assert check_gate_pass(gate_pos, gate_rot, gate_size, np.array([0, 1, 0]), np.array([0, -1, 0]))
drone_pos, last_drone_pos = np.array([0, 1, 0]), np.array([0, -1, 0])
assert gate_passed(drone_pos, last_drone_pos, gate_pos, gate_quat, gate_size)
# Test passing outside the gate boundaries
assert not check_gate_pass(
gate_pos, gate_rot, gate_size, np.array([2, 1, 0]), np.array([2, -1, 0])
)
drone_pos, last_drone_pos = np.array([2, 1, 0]), np.array([2, -1, 0])
assert not gate_passed(drone_pos, last_drone_pos, gate_pos, gate_quat, gate_size)
# Test passing close to the gate
assert not check_gate_pass(
gate_pos, gate_rot, gate_size, np.array([0.51, 1, 0]), np.array([0.51, -1, 0])
)
drone_pos, last_drone_pos = np.array([0.51, 1, 0]), np.array([0.51, -1, 0])
assert not gate_passed(drone_pos, last_drone_pos, gate_pos, gate_quat, gate_size)
# Test passing opposite direction
assert not check_gate_pass(
gate_pos, gate_rot, gate_size, np.array([0, -1, 0]), np.array([0, 1, 0])
)
assert not gate_passed(drone_pos, last_drone_pos, gate_pos, gate_quat, gate_size)
# Test with rotated gate
rotated_gate = R.from_euler("xyz", [0, np.pi / 4, 0])
assert check_gate_pass(
gate_pos, rotated_gate, gate_size, np.array([0.5, 0.5, 0]), np.array([-0.5, -0.5, 0])
)
rotated_gate_quat = R.from_euler("xyz", [0, np.pi / 4, 0]).as_quat()
drone_pos, last_drone_pos = np.array([0.5, 0.5, 0]), np.array([-0.5, -0.5, 0])
assert gate_passed(drone_pos, last_drone_pos, gate_pos, rotated_gate_quat, gate_size)
# Test with moved gate
moved_gate_pos = np.array([1, 1, 1])
assert check_gate_pass(
moved_gate_pos, gate_rot, gate_size, np.array([1, 2, 1]), np.array([1, 0, 1])
)
drone_pos, last_drone_pos = np.array([1, 2, 1]), np.array([1, 0, 1])
assert gate_passed(drone_pos, last_drone_pos, moved_gate_pos, gate_quat, gate_size)
# Test not crossing the plane
assert not check_gate_pass(
gate_pos, gate_rot, gate_size, np.array([0, -0.5, 0]), np.array([0, -1, 0])
)
drone_pos, last_drone_pos = np.array([0, -0.5, 0]), np.array([0, -1, 0])
assert not gate_passed(drone_pos, last_drone_pos, gate_pos, gate_quat, gate_size)

0 comments on commit 0df29e2

Please sign in to comment.