Skip to content

Commit

Permalink
Pass mocap IDs to randomize factories. Fix wrong config flags
Browse files — browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 13, 2025
1 parent cac75f7 commit 82a70b6
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 42 deletions.
2 changes: 1 addition & 1 deletion benchmarks/config/test.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ attitude_freq = 500 # Controller frequency, in Hz.
gui = false # Enable/disable PyBullet's GUI

[env]
random_resets = false # Whether to re-seed the random number generator between episodes
random_resets = true # Whether to re-seed the random number generator between episodes
seed = 1337 # Random seed
freq = 50 # Frequency of the environment's step function, in Hz
symbolic = false # Whether to include symbolic expressions in the info dict. Note: This can interfere with multiprocessing! If you want to parallelize your training, set this to false.
Expand Down
2 changes: 1 addition & 1 deletion config/level0.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ high = [0.1, 0.1, 0.1]

[env]
id = "DroneRacing-v0" # Either "DroneRacing-v0" or "DroneRacingThrust-v0". If using "DroneRacingThrust-v0", the drone will use the thrust controller instead of the position controller.
random_resets = true # Whether to re-seed the random number generator between episodes
random_resets = false # Whether to re-seed the random number generator between episodes
seed = 1337 # Random seed
freq = 50 # Frequency of the environment's step function, in Hz
symbolic = false # Whether to include symbolic expressions in the info dict. Note: This can interfere with multiprocessing! If you want to parallelize your training, set this to false.
Expand Down
2 changes: 1 addition & 1 deletion config/level1.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ gui = false # Enable/disable PyBullet's GUI

[env]
id = "DroneRacing-v0" # Either "DroneRacing-v0" or "DroneRacingThrust-v0". If using "DroneRacingThrust-v0", the drone will use the thrust controller instead of the position controller.
random_resets = true # Whether to re-seed the random number generator between episodes
random_resets = false # Whether to re-seed the random number generator between episodes
seed = 1337 # Random seed
freq = 50 # Frequency of the environment's step function, in Hz
symbolic = false # Whether to include symbolic expressions in the info dict. Note: This can interfere with multiprocessing! If you want to parallelize your training, set this to false.
Expand Down
2 changes: 1 addition & 1 deletion config/level2.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ gui = false # Enable/disable PyBullet's GUI

[env]
id = "DroneRacing-v0" # Either "DroneRacing-v0" or "DroneRacingThrust-v0". If using "DroneRacingThrust-v0", the drone will use the thrust controller instead of the position controller.
random_resets = true # Whether to re-seed the random number generator between episodes
random_resets = false # Whether to re-seed the random number generator between episodes
seed = 1337 # Random seed
freq = 50 # Frequency of the environment's step function, in Hz
symbolic = false # Whether to include symbolic expressions in the info dict. Note: This can interfere with multiprocessing! If you want to parallelize your training, set this to false.
Expand Down
2 changes: 1 addition & 1 deletion config/level3.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ gui = false # Enable/disable PyBullet's GUI

[env]
id = "DroneRacing-v0" # Either "DroneRacing-v0" or "DroneRacingThrust-v0". If using "DroneRacingThrust-v0", the drone will use the thrust controller instead of the position controller.
random_resets = false # Whether to re-seed the random number generator between episodes
random_resets = true # Whether to re-seed the random number generator between episodes
seed = 1337 # Random seed
freq = 50 # Frequency of the environment's step function, in Hz
symbolic = false # Whether to include symbolic expressions in the info dict. Note: This can interfere with multiprocessing! If you want to parallelize your training, set this to false.
Expand Down
46 changes: 38 additions & 8 deletions lsy_drone_racing/envs/drone_racing_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,15 @@
from gymnasium import spaces
from scipy.spatial.transform import Rotation as R

from lsy_drone_racing.envs.randomize import randomize_sim_fn
from lsy_drone_racing.envs.randomize import (
randomize_drone_inertia_fn,
randomize_drone_mass_fn,
randomize_drone_pos_fn,
randomize_drone_quat_fn,
randomize_gate_pos_fn,
randomize_gate_rpy_fn,
randomize_obstacle_pos_fn,
)
from lsy_drone_racing.utils import check_gate_pass

if TYPE_CHECKING:
Expand Down Expand Up @@ -147,7 +155,7 @@ def __init__(self, config: dict):
self.gates, self.obstacles, self.drone = self.load_track(config.env.track)
self.n_gates = len(config.env.track.gates)
self.disturbances = self.load_disturbances(config.env.get("disturbances", None))
self.randomization = self.load_randomizations(config.env.get("randomization", None))
self.randomizations = self.load_randomizations(config.env.get("randomization", None))
self.contact_mask = np.ones((self.sim.n_worlds, 25), dtype=bool)
self.contact_mask[..., 0] = 0 # Ignore contacts with the floor

Expand Down Expand Up @@ -250,7 +258,7 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
obstacles_pos[self.obstacles_visited] = self.obstacles["pos"][self.obstacles_visited]
obs["obstacles_pos"] = obstacles_pos.astype(np.float32)
obs["obstacles_visited"] = self.obstacles_visited
# TODO: Observation disturbances?
# TODO: Decide on observation disturbances
return obs

def reward(self) -> float:
Expand Down Expand Up @@ -347,7 +355,9 @@ def setup_sim(self):
rpy_rates = self.drone["rpy_rates"].reshape(self.sim.data.states.rpy_rates.shape)
states = self.sim.data.states.replace(pos=pos, quat=quat, vel=vel, rpy_rates=rpy_rates)
self.sim.data = self.sim.data.replace(states=states)
self.sim.reset_hook = build_reset_hook(self.randomization)
self.sim.reset_hook = build_reset_hook(
self.randomizations, self.gates["mocap_ids"], self.obstacles["mocap_ids"]
)
if "dynamics" in self.disturbances:
self.sim.disturbance_fn = build_dynamics_disturbance_fn(self.disturbances["dynamics"])
self.sim.build(mjx=False, data=False) # Save the reset state and rebuild the reset function
Expand Down Expand Up @@ -400,13 +410,33 @@ def close(self):
self.sim.close()


def build_reset_hook(randomizations: dict) -> Callable[[SimData, Array], SimData]:
def build_reset_hook(
randomizations: dict, gate_mocap_ids: list[int], obstacle_mocap_ids: list[int]
) -> Callable[[SimData, Array], SimData]:
"""Build the reset hook for the simulation."""
randomizations = [randomize_sim_fn(target, rng) for target, rng in randomizations.items()]
randomization_fns = []
for target, rng in randomizations.items():
match target:
case "drone_pos":
randomization_fns.append(randomize_drone_pos_fn(rng))
case "drone_rpy":
randomization_fns.append(randomize_drone_quat_fn(rng))
case "drone_mass":
randomization_fns.append(randomize_drone_mass_fn(rng))
case "drone_inertia":
randomization_fns.append(randomize_drone_inertia_fn(rng))
case "gate_pos":
randomization_fns.append(randomize_gate_pos_fn(rng, gate_mocap_ids))
case "gate_rpy":
randomization_fns.append(randomize_gate_rpy_fn(rng, gate_mocap_ids))
case "obstacle_pos":
randomization_fns.append(randomize_obstacle_pos_fn(rng, obstacle_mocap_ids))
case _:
raise ValueError(f"Invalid target: {target}")

def reset_hook(data: SimData, mask: Array) -> SimData:
for randomize in randomizations:
data = randomize(data, mask)
for randomize_fn in randomization_fns:
data = randomize_fn(data, mask)
return data

return reset_hook
Expand Down
32 changes: 3 additions & 29 deletions lsy_drone_racing/envs/randomize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,6 @@
from jax.scipy.spatial.transform import Rotation as R


def randomize_sim_fn(
target: str, randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array]
) -> Callable[[SimData, Array], SimData]:
"""Create a function that randomizes aspects of the simulation."""
match target:
case "drone_pos":
return randomize_drone_pos_fn(randomize_fn)
case "drone_rpy":
return randomize_drone_quat_fn(randomize_fn)
case "drone_mass":
return randomize_drone_mass_fn(randomize_fn)
case "drone_inertia":
return randomize_drone_inertia_fn(randomize_fn)
case "gate_pos":
return randomize_gate_pos_fn(randomize_fn)
case "gate_rpy":
return randomize_gate_rpy_fn(randomize_fn)
case "obstacle_pos":
return randomize_obstacle_pos_fn(randomize_fn)
case _:
raise ValueError(f"Invalid target: {target}")


def randomize_drone_pos_fn(
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
) -> Callable[[SimData, Array], SimData]:
Expand Down Expand Up @@ -97,10 +74,9 @@ def randomize_drone_inertia(data: SimData, mask: Array) -> SimData:


def randomize_gate_pos_fn(
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array], gate_ids: list[int]
) -> Callable[[SimData, Array], SimData]:
"""Create a function that randomizes the gate position."""
gate_ids = [0, 1, 2, 3] # TODO: Make this dynamic

def randomize_gate_pos(data: SimData, mask: Array) -> SimData:
key, subkey = jax.random.split(data.core.rng_key)
Expand All @@ -114,10 +90,9 @@ def randomize_gate_pos(data: SimData, mask: Array) -> SimData:


def randomize_gate_rpy_fn(
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array], gate_ids: list[int]
) -> Callable[[SimData, Array], SimData]:
"""Create a function that randomizes the gate rotation."""
gate_ids = [0, 1, 2, 3] # TODO: Make this dynamic

def randomize_gate_rpy(data: SimData, mask: Array) -> SimData:
key, subkey = jax.random.split(data.core.rng_key)
Expand All @@ -133,10 +108,9 @@ def randomize_gate_rpy(data: SimData, mask: Array) -> SimData:


def randomize_obstacle_pos_fn(
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
randomize_fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array], obstacle_ids: list[int]
) -> Callable[[SimData, Array], SimData]:
"""Create a function that randomizes the obstacle position."""
obstacle_ids = [4, 5, 6, 7] # TODO: Make this dynamic

def randomize_obstacle_pos(data: SimData, mask: Array) -> SimData:
key, subkey = jax.random.split(data.core.rng_key)
Expand Down

0 comments on commit 82a70b6

Please sign in to comment.