Skip to content

Commit

Permalink
[wip,broken] Add contact masking. Prepare reset logic
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 10, 2025
1 parent 28fd1cb commit e877fcb
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
14 changes: 6 additions & 8 deletions lsy_drone_racing/envs/drone_racing_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def __init__(self, config: dict):
self.gates, self.obstacles, self.drone = self.load_track(config.env.track)
self.n_gates = len(config.env.track.gates)
self.disturbances = self.load_disturbances(config.env.get("disturbances", None))
self.contact_mask = np.ones((self.sim.n_worlds, 29), dtype=bool)
self.contact_mask[..., 0] = 0 # Ignore contacts with the floor

self.gates_visited = np.array([False] * len(config.env.track.gates))
self.obstacles_visited = np.array([False] * len(config.env.track.obstacles))
Expand All @@ -161,6 +163,8 @@ def reset(
self.sim.seed(self.config.env.seed)
if seed is not None:
self.sim.seed(seed)
# Randomization of gates, obstacles and drones is compiled into the sim reset function with
# the sim.reset_hook function, so we don't need to explicitly do it here
self.sim.reset()
# TODO: Add randomization of gates, obstacles, drone, and disturbances
states = self.sim.data.states.replace(
Expand Down Expand Up @@ -278,7 +282,7 @@ def terminated(self) -> bool:
}
if state not in self.state_space:
return True # Drone is out of bounds
if self.sim.contacts("drone:0").any():
if np.logical_and(self.sim.contacts("drone:0"), self.contact_mask).any():
return True
if self.sim.data.states.pos[0, 0, 2] < 0.0:
return True
Expand Down Expand Up @@ -320,17 +324,11 @@ def _load_track_into_sim(self, gates: dict, obstacles: dict):
for i in range(len(obstacles["pos"])):
obstacle = frame.attach_body(obstacle_spec.find_body("world"), "", f":o{i}")
obstacle.pos = obstacles["pos"][i]
# TODO: Simplify rebuilding the simulation after changing the mujoco model
self.sim.mj_model, self.sim.mj_data, self.sim.mjx_model, mjx_data = self.sim.compile_mj(
spec
)
self.sim.data = self.sim.data.replace(mjx_data=mjx_data)
self.sim.default_data = self.sim.data.replace()
self.sim.build()

def load_disturbances(self, disturbances: dict | None = None) -> dict:
"""Load the disturbances from the config."""
dist = {}
dist = {} # TODO: Add jax disturbances for the simulator dynamics
if disturbances is None: # Default: no passive disturbances.
return dist
for mode, spec in disturbances.items():
Expand Down
4 changes: 3 additions & 1 deletion scripts/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,7 @@ def log_episode_stats(obs: dict, info: dict, config: Munch, curr_time: float):


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
logging.basicConfig()
logging.getLogger("lsy_drone_racing").setLevel(logging.INFO)
logger.setLevel(logging.INFO)
fire.Fire(simulate)

0 comments on commit e877fcb

Please sign in to comment.