Skip to content

Commit

Permalink
Small renamings
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 13, 2025
1 parent 4e81ac7 commit 26141b0
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions lsy_drone_racing/envs/multi_drone_race.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from crazyflow import Sim
from crazyflow.sim.symbolic import symbolic_attitude
from gymnasium import spaces
from jax.scipy.spatial.transform import Rotation as JaxR
from scipy.spatial.transform import Rotation as R

from lsy_drone_racing.envs.randomize import (
Expand Down Expand Up @@ -302,7 +303,7 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
@staticmethod
@jax.jit
def _obs_gates(
gates_visited: NDArray,
visited: NDArray,
drone_pos: Array,
mocap_pos: Array,
mocap_quat: Array,
Expand All @@ -312,14 +313,12 @@ def _obs_gates(
nominal_rpy: NDArray,
) -> tuple[Array, Array, Array]:
"""Get the nominal or real gate positions and orientations depending on the sensor range."""
real_quat = mocap_quat[mocap_ids][..., [1, 2, 3, 0]]
real_rpy = jax.scipy.spatial.transform.Rotation.from_quat(real_quat).as_euler("xyz")
real_rpy = JaxR.from_quat(mocap_quat[mocap_ids][..., [1, 2, 3, 0]]).as_euler("xyz")
dpos = drone_pos[..., None, :2] - mocap_pos[mocap_ids, :2]
in_range = jp.linalg.norm(dpos, axis=-1) < sensor_range
gates_visited = jp.logical_or(gates_visited, in_range)
gates_pos = jp.where(gates_visited[..., None], mocap_pos[mocap_ids], nominal_pos)
gates_rpy = jp.where(gates_visited[..., None], real_rpy, nominal_rpy)
return gates_visited, gates_pos, gates_rpy
visited = jp.logical_or(visited, jp.linalg.norm(dpos, axis=-1) < sensor_range)
gates_pos = jp.where(visited[..., None], mocap_pos[mocap_ids], nominal_pos)
gates_rpy = jp.where(visited[..., None], real_rpy, nominal_rpy)
return visited, gates_pos, gates_rpy

@staticmethod
@jax.jit
Expand All @@ -332,8 +331,7 @@ def _obs_obstacles(
nominal_pos: NDArray,
) -> tuple[Array, Array]:
dpos = drone_pos[..., None, :2] - mocap_pos[mocap_ids, :2]
in_range = jp.linalg.norm(dpos, axis=-1) < sensor_range
visited = jp.logical_or(visited, in_range)
visited = jp.logical_or(visited, jp.linalg.norm(dpos, axis=-1) < sensor_range)
return visited, jp.where(visited[..., None], mocap_pos[mocap_ids], nominal_pos)

def reward(self) -> float:
Expand Down Expand Up @@ -375,7 +373,7 @@ def _disabled_drones(
contacts: Array,
contact_masks: NDArray,
) -> Array:
rpy = jax.scipy.spatial.transform.Rotation.from_quat(quat).as_euler("xyz")
rpy = JaxR.from_quat(quat).as_euler("xyz")
disabled = jp.logical_or(disabled_drones, jp.all(pos < pos_low, axis=-1))
disabled = jp.logical_or(disabled, jp.all(pos > pos_high, axis=-1))
disabled = jp.logical_or(disabled, jp.all(rpy < rpy_low, axis=-1))
Expand Down Expand Up @@ -480,8 +478,7 @@ def _gate_passed(
# TODO: Test. Cover cases with no gates.
ids = mocap_ids[target_gate % n_gates]
gate_pos = mocap_pos[ids]
gate_quat = mocap_quat[ids][..., [1, 2, 3, 0]]
gate_rot = jax.scipy.spatial.transform.Rotation.from_quat(gate_quat)
gate_rot = JaxR.from_quat(mocap_quat[ids][..., [1, 2, 3, 0]])
gate_size = (0.45, 0.45)
last_pos_local = gate_rot.apply(last_drone_pos - gate_pos, inverse=True)
pos_local = gate_rot.apply(drone_pos - gate_pos, inverse=True)
Expand Down

0 comments on commit 26141b0

Please sign in to comment.