Skip to content

Commit

Permalink
Refactor jit'ed functions
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 13, 2025
1 parent b451fd1 commit f1b2976
Showing 1 changed file with 66 additions and 108 deletions.
174 changes: 66 additions & 108 deletions lsy_drone_racing/envs/multi_drone_race.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
randomize_gate_rpy_fn,
randomize_obstacle_pos_fn,
)
from lsy_drone_racing.utils import check_gate_pass

if TYPE_CHECKING:
from crazyflow.sim.structs import SimData
Expand Down Expand Up @@ -227,7 +226,7 @@ def step(
self.sim.step(self.sim.freq // self.freq)
# TODO: Clean up the accelerated functions
self.disabled_drones = np.array(
self.update_active_drones_acc(
self._disabled_drones(
self.sim.data.states.pos[0],
self.sim.data.states.quat[0],
self.pos_bounds.low,
Expand All @@ -242,7 +241,7 @@ def step(
)
self.sim.data = self.warp_disabled_drones(self.sim.data, self.disabled_drones)
# TODO: Clean up the accelerated functions
passed = self.gate_passed_accelerated(
passed = self._gate_passed(
self.target_gate,
self.gates["mocap_ids"],
self.sim.data.mjx_data.mocap_pos[0],
Expand All @@ -263,21 +262,11 @@ def render(self):
def obs(self) -> dict[str, NDArray[np.floating]]:
"""Return the observation of the environment."""
# TODO: Accelerate this function
obs = {
"pos": np.array(self.sim.data.states.pos[0], dtype=np.float32),
"rpy": R.from_quat(self.sim.data.states.quat[0]).as_euler("xyz").astype(np.float32),
"vel": np.array(self.sim.data.states.vel[0], dtype=np.float32),
"ang_vel": np.array(self.sim.data.states.rpy_rates[0], dtype=np.float32),
}
obs["target_gate"] = self.target_gate
# Add the gate and obstacle poses to the info. If gates or obstacles are in sensor range,
# use the actual pose, otherwise use the nominal pose.
drone_pos = self.sim.data.states.pos[0]
# Performance optimization: Get a continuous slice instead of using a list of indices which
# copies the data. Assumes that the mocap ids are consecutive.
gates_visited, gates_pos, gates_rpy = self.obs_acc_gates(
gates_visited, gates_pos, gates_rpy = self._obs_gates(
self.gates_visited,
drone_pos,
self.sim.data.states.pos[0],
self.sim.data.mjx_data.mocap_pos[0],
self.sim.data.mjx_data.mocap_quat[0],
self.gates["mocap_ids"],
Expand All @@ -286,62 +275,66 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
self.gates["nominal_rpy"],
)
self.gates_visited = np.asarray(gates_visited, dtype=bool)
obs["gates_pos"] = np.asarray(gates_pos, dtype=np.float32)
obs["gates_rpy"] = np.asarray(gates_rpy, dtype=np.float32)
obs["gates_visited"] = self.gates_visited

obstacles_visited, obstacles_pos = self.obs_acc_obstacles(
obstacles_visited, obstacles_pos = self._obs_obstacles(
self.obstacles_visited,
drone_pos,
self.sim.data.states.pos[0],
self.sim.data.mjx_data.mocap_pos[0],
self.obstacles["mocap_ids"],
self.sensor_range,
self.obstacles["nominal_pos"],
)
self.obstacles_visited = np.asarray(obstacles_visited, dtype=bool)
obs["obstacles_pos"] = np.asarray(obstacles_pos, dtype=np.float32)
obs["obstacles_visited"] = self.obstacles_visited
# TODO: Decide on observation disturbances
obs = {
"pos": np.array(self.sim.data.states.pos[0], dtype=np.float32),
"rpy": R.from_quat(self.sim.data.states.quat[0]).as_euler("xyz").astype(np.float32),
"vel": np.array(self.sim.data.states.vel[0], dtype=np.float32),
"ang_vel": np.array(self.sim.data.states.rpy_rates[0], dtype=np.float32),
"target_gate": self.target_gate,
"gates_pos": np.asarray(gates_pos, dtype=np.float32),
"gates_rpy": np.asarray(gates_rpy, dtype=np.float32),
"gates_visited": self.gates_visited,
"obstacles_pos": np.asarray(obstacles_pos, dtype=np.float32),
"obstacles_visited": self.obstacles_visited,
}
return obs

@staticmethod
@jax.jit
def _obs_gates(
    gates_visited: NDArray,
    drone_pos: Array,
    mocap_pos: Array,
    mocap_quat: Array,
    mocap_ids: NDArray,
    sensor_range: float,
    nominal_pos: NDArray,
    nominal_rpy: NDArray,
) -> tuple[Array, Array, Array]:
    """Get the nominal or real gate positions and orientations depending on the sensor range.

    A gate is flagged as visited for a drone once that drone has been closer than
    ``sensor_range`` to it. The distance only uses the first two position components (the
    horizontal plane, assuming xyz ordering). Visited gates report their real pose taken from
    the mocap data, all other gates report their nominal pose.

    Args:
        gates_visited: Boolean flags of the gates each drone has already seen. The flags are
            sticky: once True, an entry stays True.
        drone_pos: Current drone positions.
        mocap_pos: Positions of all mocap bodies.
        mocap_quat: Orientations of all mocap bodies as scalar-first quaternions.
        mocap_ids: Indices of the gates in the mocap arrays.
        sensor_range: Distance below which the real gate pose becomes observable.
        nominal_pos: Nominal gate positions.
        nominal_rpy: Nominal gate orientations as xyz Euler angles.

    Returns:
        The updated visited flags, the observed gate positions and the observed gate
        orientations as xyz Euler angles.
    """
    real_pos = mocap_pos[mocap_ids]
    # The mocap quaternions are scalar-first (w, x, y, z), the rotation class expects the
    # scalar component last (x, y, z, w).
    real_quat = mocap_quat[mocap_ids][..., [1, 2, 3, 0]]
    real_rpy = jax.scipy.spatial.transform.Rotation.from_quat(real_quat).as_euler("xyz")
    # Add a gate axis to the drone positions to get all pairwise drone-gate offsets.
    dpos = drone_pos[..., None, :2] - real_pos[:, :2]
    in_range = jp.linalg.norm(dpos, axis=-1) < sensor_range
    gates_visited = jp.logical_or(gates_visited, in_range)
    mask = gates_visited[..., None]
    gates_pos = jp.where(mask, real_pos, nominal_pos)
    gates_rpy = jp.where(mask, real_rpy, nominal_rpy)
    return gates_visited, gates_pos, gates_rpy

@staticmethod
@jax.jit
def _obs_obstacles(
    visited: NDArray,
    drone_pos: Array,
    mocap_pos: Array,
    mocap_ids: NDArray,
    sensor_range: float,
    nominal_pos: NDArray,
) -> tuple[Array, Array]:
    """Get the nominal or real obstacle positions depending on the sensor range.

    An obstacle is flagged as visited for a drone once that drone has been closer than
    ``sensor_range`` to it. The distance only uses the first two position components (the
    horizontal plane, assuming xyz ordering). Visited obstacles report their real position
    taken from the mocap data, all other obstacles report their nominal position.

    Args:
        visited: Boolean flags of the obstacles each drone has already seen. The flags are
            sticky: once True, an entry stays True.
        drone_pos: Current drone positions.
        mocap_pos: Positions of all mocap bodies.
        mocap_ids: Indices of the obstacles in the mocap arrays.
        sensor_range: Distance below which the real obstacle position becomes observable.
        nominal_pos: Nominal obstacle positions.

    Returns:
        The updated visited flags and the observed obstacle positions.
    """
    real_pos = mocap_pos[mocap_ids]
    # Add an obstacle axis to the drone positions to get all pairwise drone-obstacle offsets.
    dpos = drone_pos[..., None, :2] - real_pos[:, :2]
    in_range = jp.linalg.norm(dpos, axis=-1) < sensor_range
    visited = jp.logical_or(visited, in_range)
    return visited, jp.where(visited[..., None], real_pos, nominal_pos)

def reward(self) -> float:
"""Compute the reward for the current state.
Expand All @@ -368,33 +361,20 @@ def info(self) -> dict:
"""Return an info dictionary containing additional information about the environment."""
return {"collisions": np.any(self.sim.contacts(), axis=-1), "symbolic_model": self.symbolic}

def update_active_drones(self):
    """Update the mask of disabled drones from the current simulation state.

    A drone stays disabled once it has been disabled. It is newly disabled if all of its
    position or rpy components are below the lower or above the upper bounds, if its target
    gate is -1, or if it has a contact that is selected by the contact masks.
    """
    # TODO: Accelerate
    states = self.sim.data.states
    drone_pos = states.pos[0, ...]
    drone_rpy = R.from_quat(states.quat[0, ...]).as_euler("xyz")
    conditions = (
        np.all(drone_pos < self.pos_bounds.low, axis=-1),
        np.all(drone_pos > self.pos_bounds.high, axis=-1),
        np.all(drone_rpy < self.rpy_bounds.low, axis=-1),
        np.all(drone_rpy > self.rpy_bounds.high, axis=-1),
        self.target_gate == -1,
        np.any(np.logical_and(self.sim.contacts(), self.contact_masks), axis=-1),
    )
    # Fold the conditions into the existing mask so previously disabled drones stay disabled.
    mask = self.disabled_drones
    for condition in conditions:
        mask = np.logical_or(mask, condition)
    self.disabled_drones = mask

@staticmethod
@jax.jit
def update_active_drones_acc(
pos,
quat,
pos_low,
pos_high,
rpy_low,
rpy_high,
target_gate,
disabled_drones,
contacts,
contact_masks,
):
def _disabled_drones(
pos: Array,
quat: Array,
pos_low: NDArray,
pos_high: NDArray,
rpy_low: NDArray,
rpy_high: NDArray,
target_gate: NDArray,
disabled_drones: NDArray,
contacts: Array,
contact_masks: NDArray,
) -> Array:
rpy = jax.scipy.spatial.transform.Rotation.from_quat(quat).as_euler("xyz")
disabled = jp.logical_or(disabled_drones, jp.all(pos < pos_low, axis=-1))
disabled = jp.logical_or(disabled, jp.all(pos > pos_high, axis=-1))
Expand Down Expand Up @@ -481,30 +461,9 @@ def _load_track_into_sim(self, gates: dict, obstacles: dict):
mocap_ids = [int(mj_model.body(f"obstacle:{i}").mocapid) for i in range(n_obstacles)]
obstacles["mocap_ids"] = np.array(mocap_ids, dtype=np.int32)

def gate_passed(self) -> NDArray[np.bool_]:
    """Check for each drone if it has passed its target gate.

    Returns:
        A boolean array with one entry per drone that is True if the drone has passed its
        target gate since the last recorded drone position, else False. All entries are False
        if there are no gates.
    """
    n_drones = self.sim.n_drones
    passed = np.zeros(n_drones, dtype=bool)
    if self.n_gates <= 0:
        return passed
    # Wrap the target gate index into the valid range before looking up the mocap ids.
    gate_ids = self.target_gate % self.n_gates
    gate_mj_id = self.gates["mocap_ids"][gate_ids]
    # Keep the leading drone axis. Squeezing it would collapse the array to a single position
    # vector for environments with exactly one drone and break the per-drone indexing below.
    gate_pos = self.sim.data.mjx_data.mocap_pos[0, gate_mj_id]
    gate_rot = R.from_quat(self.sim.data.mjx_data.mocap_quat[0, gate_mj_id], scalar_first=True)
    drone_pos = self.sim.data.states.pos[0]
    gate_size = (0.45, 0.45)
    for i in range(n_drones):
        passed[i] = check_gate_pass(
            gate_pos[i], gate_rot[i], gate_size, drone_pos[i], self._last_drone_pos[i]
        )
    return passed

@staticmethod
@jax.jit
def gate_passed_accelerated(
def _gate_passed(
target_gate: NDArray,
mocap_ids: NDArray,
mocap_pos: Array,
Expand All @@ -518,18 +477,16 @@ def gate_passed_accelerated(
Returns:
True if the drone has passed a gate, else False.
"""
# TODO: Test, refactor, optimize. Cover cases with no gates.
gate_ids = target_gate % n_gates
gate_mj_id = mocap_ids[gate_ids]
gate_pos = mocap_pos[gate_mj_id]
gate_rot = jax.scipy.spatial.transform.Rotation.from_quat(
mocap_quat[gate_mj_id][..., [1, 2, 3, 0]]
)
# TODO: Test. Cover cases with no gates.
ids = mocap_ids[target_gate % n_gates]
gate_pos = mocap_pos[ids]
gate_quat = mocap_quat[ids][..., [1, 2, 3, 0]]
gate_rot = jax.scipy.spatial.transform.Rotation.from_quat(gate_quat)
gate_size = (0.45, 0.45)
last_pos_local = gate_rot.apply(last_drone_pos - gate_pos, inverse=True)
pos_local = gate_rot.apply(drone_pos - gate_pos, inverse=True)
# Check the plane intersection. If passed, calculate the point of the intersection and check if
# it is within the gate box.
# Check if the line between the last position and the current position intersects the plane.
# If so, calculate the point of the intersection and check if it is within the gate box.
passed_plane = (last_pos_local[..., 1] < 0) & (pos_local[..., 1] > 0)
alpha = -last_pos_local[..., 1] / (pos_local[..., 1] - last_pos_local[..., 1])
x_intersect = alpha * (pos_local[..., 0]) + (1 - alpha) * last_pos_local[..., 0]
Expand All @@ -540,6 +497,7 @@ def gate_passed_accelerated(
@staticmethod
@jax.jit
def warp_disabled_drones(data: SimData, mask: NDArray) -> SimData:
    """Warp the disabled drones below the ground.

    Args:
        data: Simulation data whose drone positions are updated.
        mask: Flags of the drones to warp. Reshaped to broadcast against the positions.

    Returns:
        The simulation data where every position component of the masked drones is set to -1.
    """
    states = data.states
    # Reshape so that the mask broadcasts over the leading axis and the position components.
    warped_pos = jax.numpy.where(mask.reshape((1, -1, 1)), -1, states.pos)
    return data.replace(states=states.replace(pos=warped_pos))
Expand Down

0 comments on commit f1b2976

Please sign in to comment.