Skip to content

Commit

Permalink
Optimize performance. Fix gate detection
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 13, 2025
1 parent 423979b commit a73f377
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
25 changes: 12 additions & 13 deletions lsy_drone_racing/envs/drone_racing_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,6 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
"vel": np.array(self.sim.data.states.vel[0, 0], dtype=np.float32),
"ang_vel": np.array(self.sim.data.states.rpy_rates[0, 0], dtype=np.float32),
}
obs["ang_vel"][:] = R.from_euler("xyz", obs["rpy"]).apply(obs["ang_vel"], inverse=True)

obs["target_gate"] = self.target_gate if self.target_gate < len(self.gates) else -1
# Add the gate and obstacle poses to the info. If gates or obstacles are in sensor range,
# use the actual pose, otherwise use the nominal pose.
Expand All @@ -247,9 +245,7 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
obstacles_pos[self.obstacles_visited] = self.obstacles["pos"][self.obstacles_visited]
obs["obstacles_pos"] = obstacles_pos.astype(np.float32)
obs["obstacles_visited"] = self.obstacles_visited

if "observation" in self.disturbances:
obs = self.disturbances["observation"].apply(obs)
# TODO: Observation disturbances?
return obs

def reward(self) -> float:
Expand Down Expand Up @@ -370,10 +366,13 @@ def _load_track_into_sim(self, gates: dict, obstacles: dict):
assert not hasattr(self.sim.data, "gate_pos")
assert not hasattr(self.sim.data, "obstacle_pos")

gate_ids = [self.sim.mj_model.body(f"gate:{i}").id for i in range(n_gates)]
gates["ids"] = gate_ids
obstacle_ids = [self.sim.mj_model.body(f"obstacle:{i}").id for i in range(n_obstacles)]
obstacles["ids"] = obstacle_ids
mj_model = self.sim.mj_model
gates["ids"] = [mj_model.body(f"gate:{i}").id for i in range(n_gates)]
gates["mocap_ids"] = [int(mj_model.body(f"gate:{i}").mocapid) for i in range(n_gates)]
obstacles["ids"] = [mj_model.body(f"obstacle:{i}").id for i in range(n_obstacles)]
obstacles["mocap_ids"] = [
int(mj_model.body(f"obstacle:{i}").mocapid) for i in range(n_obstacles)
]

def gate_passed(self) -> bool:
"""Check if the drone has passed a gate.
Expand All @@ -383,12 +382,12 @@ def gate_passed(self) -> bool:
"""
if self.n_gates <= 0 or self.target_gate >= self.n_gates or self.target_gate == -1:
return False
gate_id = self.gates["ids"][self.target_gate]
gate_pos = self.sim.data.mjx_data.mocap_pos[0, gate_id]
gate_quat = self.sim.data.mjx_data.mocap_quat[0, gate_id][..., [3, 0, 1, 2]]
gate_mj_id = self.gates["mocap_ids"][self.target_gate]
gate_pos = self.sim.data.mjx_data.mocap_pos[0, gate_mj_id].squeeze()
gate_rot = R.from_quat(self.sim.data.mjx_data.mocap_quat[0, gate_mj_id], scalar_first=True)
drone_pos = self.sim.data.states.pos[0, 0]
gate_size = (0.45, 0.45)
return check_gate_pass(gate_pos, gate_quat, gate_size, drone_pos, self._last_drone_pos)
return check_gate_pass(gate_pos, gate_rot, gate_size, drone_pos, self._last_drone_pos)

def close(self):
"""Close the environment by stopping the drone and landing back at the starting position."""
Expand Down
6 changes: 3 additions & 3 deletions lsy_drone_racing/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def load_config(path: Path) -> Munch:

def check_gate_pass(
gate_pos: np.ndarray,
gate_quat: np.ndarray,
gate_rot: R,
gate_size: np.ndarray,
drone_pos: np.ndarray,
last_drone_pos: np.ndarray,
Expand All @@ -111,13 +111,13 @@ def check_gate_pass(
Args:
gate_pos: The position of the gate in the world frame.
gate_quat: The quaternion of the gate in the world frame.
gate_rot: The rotation of the gate.
gate_size: The size of the gate box in meters.
drone_pos: The position of the drone in the world frame.
last_drone_pos: The position of the drone in the world frame at the last time step.
"""
# Transform last and current drone position into current gate frame.
gate_rot = R(gate_quat, normalize=False, copy=False)
assert isinstance(gate_rot, R), "gate_rot has to be a Rotation object."
last_pos_local = gate_rot.apply(last_drone_pos - gate_pos, inverse=True)
pos_local = gate_rot.apply(drone_pos - gate_pos, inverse=True)
# Check the plane intersection. If passed, calculate the point of the intersection and check if
Expand Down

0 comments on commit a73f377

Please sign in to comment.