Skip to content

Commit

Permalink
Update Vicon to track gates.
Browse files Browse the repository at this point in the history
Use exact gate position during deployment when sufficiently close.
[TODO: Test on real setup]
  • Loading branch information
amacati committed Jun 4, 2024
1 parent bfd654b commit ab59746
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 36 deletions.
12 changes: 12 additions & 0 deletions lsy_drone_racing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ def euler_from_quaternion(x: float, y: float, z: float, w: float) -> tuple[float
return roll_x, pitch_y, yaw_z # in radians


def map2pi(angle: np.ndarray) -> np.ndarray:
"""Map an angle or array of angles to the interval of [-pi, pi].
Args:
angle: Number or array of numbers.
Returns:
The remapped angles.
"""
return ((angle + np.pi) % (2 * np.pi)) - np.pi


def load_controller(path: Path) -> Type[BaseController]:
"""Load the controller module from the given path and return the Controller class.
Expand Down
72 changes: 54 additions & 18 deletions lsy_drone_racing/vicon.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,66 @@
from __future__ import annotations

import time

import numpy as np
import rospy
import yaml
from rosgraph import Master
from tf2_msgs.msg import TFMessage

from lsy_drone_racing.import_utils import get_ros_package_path
from lsy_drone_racing.utils import euler_from_quaternion
from lsy_drone_racing.utils import euler_from_quaternion, map2pi


class ViconWatcher:
class Vicon:
"""Vicon interface for the pose estimation data for the drone and any other tracked objects.
Vicon sends a stream of ROS messages containing the current pose data. We subscribe to these
messages and save the pose data for each object in dictionaries. Users can then retrieve the
latest pose data directly from these dictionaries.
"""

def __init__(self, track_names: list[str] = []):
def __init__(
self, track_names: list[str] = [], auto_track_drone: bool = True, timeout: float = 0.0
):
"""Load the crazyflies.yaml file and register the subscribers for the Vicon pose data.
Args:
track_names: The names of any additional objects besides the drone to track.
auto_track_drone: Infer the drone name and add it to the positions if True.
timeout: If greater than 0, Vicon waits for position updates of all tracked objects
before returning.
"""
assert Master("/rosnode").is_online(), "ROS is not running. Please run hover.launch first!"
try:
rospy.init_node("playback_node")
except rospy.exceptions.ROSException:
... # ROS node is already running which is fine for us
config_path = get_ros_package_path("crazyswarm") / "launch/crazyflies.yaml"
assert config_path.exists(), "Crazyfly config file missing!"
with open(config_path, "r") as f:
config = yaml.load(f, yaml.SafeLoader)
assert len(config["crazyflies"]) == 1, "Only one crazyfly allowed at a time!"
self.drone_name = f"cf{config['crazyflies'][0]['id']}"

self.drone_name = None
if auto_track_drone:
with open(get_ros_package_path("crazyswarm") / "launch/crazyflies.yaml", "r") as f:
config = yaml.load(f, yaml.SafeLoader)
assert len(config["crazyflies"]) == 1, "Only one crazyfly allowed at a time!"
self.drone_name = f"cf{config['crazyflies'][0]['id']}"
track_names.insert(0, self.drone_name)
self.track_names = track_names
# Register the Vicon subscribers for the drone and any other tracked object
self.pos: dict[str, np.ndarray] = {"cf": np.array([])}
self.rpy: dict[str, np.ndarray] = {"cf": np.array([])}
for track_name in track_names: # Initialize the objects' pose
self.pos[track_name], self.rpy[track_name] = np.array([]), np.array([])
self.pos: dict[str, np.ndarray] = {}
self.rpy: dict[str, np.ndarray] = {}
self.vel: dict[str, np.ndarray] = {}
self.ang_vel: dict[str, np.ndarray] = {}
self.time: dict[str, float] = {}

self.sub = rospy.Subscriber("/tf", TFMessage, self.save_pose)
if timeout:
tstart = time.time()
while not self.active and time.time() - tstart < timeout:
time.sleep(0.01)
if not self.active:
raise TimeoutError(
"Timeout while fetching initial position updates for all tracked objects."
f"Missing objects: {[k for k in self.track_names if k not in self.ang_vel]}"
)

def save_pose(self, data: TFMessage):
"""Save the position and orientation of all transforms.
Expand All @@ -51,10 +69,18 @@ def save_pose(self, data: TFMessage):
data: The TF message containing the objects' pose.
"""
for tf in data.transforms:
name = "cf" if tf.child_frame_id == self.drone_name else tf.child_frame_id
name = tf.child_frame_id.split("/")[-1]
if name not in self.pos:
continue
T, R = tf.transform.translation, tf.transform.rotation
self.pos[name] = np.array([T.x, T.y, T.z])
self.rpy[name] = np.array(euler_from_quaternion(R.x, R.y, R.z, R.w))
pos = np.array([T.x, T.y, T.z])
rpy = np.array(euler_from_quaternion(R.x, R.y, R.z, R.w))
if self.pos[name]:
self.vel[name] = (pos - self.pos[name]) / (time.time() - self.time[name])
self.ang_vel[name] = map2pi(rpy - self.rpy[name]) / (time.time() - self.time[name])
self.time[name] = time.time()
self.pos[name] = pos
self.rpy[name] = rpy

def pose(self, name: str) -> tuple[np.ndarray, np.ndarray]:
"""Get the latest pose of a tracked object.
Expand All @@ -67,7 +93,17 @@ def pose(self, name: str) -> tuple[np.ndarray, np.ndarray]:
"""
return self.pos[name], self.rpy[name]

@property
def poses(self) -> tuple[np.ndarray, np.ndarray]:
"""Get the latest poses of all objects."""
return np.stack(self.pos.values()), np.stack(self.rpy.values())

@property
def names(self) -> list[str]:
"""Get a list of actively tracked names."""
return list(self.pos.keys())

@property
def active(self) -> bool:
"""Check if Vicon has sent data for each object."""
return all(p.size > 0 for p in self.pos.values())
return all([name in self.ang_vel for name in self.track_names])
37 changes: 19 additions & 18 deletions scripts/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from lsy_drone_racing.command import Command, apply_command
from lsy_drone_racing.import_utils import get_ros_package_path, pycrazyswarm
from lsy_drone_racing.utils import check_gate_pass, load_controller
from lsy_drone_racing.vicon import ViconWatcher
from lsy_drone_racing.vicon import Vicon

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -93,15 +93,9 @@ def main(config: str = "config/getting_started.yaml", controller: str = "example
time_helper = swarm.timeHelper
cf = swarm.allcfs.crazyflies[0]

vicon = ViconWatcher() # TODO: Integrate autodetection of gate and obstacle positions

timeout = 5.0
tstart = time.time()
while not vicon.active:
logger.info("Waiting for vicon...")
time.sleep(1)
if time.time() - tstart > timeout:
raise TimeoutError("Vicon unavailable.")
gate_names = [f"gate{i}" for i in range(1, len(config.quadrotor_config.gates) + 1)]
obstacle_names = [f"obstacle{i}" for i in range(1, len(config.quadrotor_config.obstacles) + 1)]
vicon = Vicon(track_names=gate_names + obstacle_names, timeout=1.0)

config_path = Path(config).resolve()
assert config_path.is_file(), "Config file does not exist!"
Expand Down Expand Up @@ -137,8 +131,9 @@ def main(config: str = "config/getting_started.yaml", controller: str = "example
_, env_info = env.reset()

# Override environment state and evaluate constraints
drone_pos_and_vel = [vicon.pos["cf"][0], 0, vicon.pos["cf"][1], 0, vicon.pos["cf"][2], 0]
drone_rot_and_agl_vel = [vicon.rpy["cf"][0], vicon.rpy["cf"][1], vicon.rpy["cf"][2], 0, 0, 0]
drone_pos, drone_rot = vicon.pos[vicon.drone_name], vicon.rpy[vicon.drone_name]
drone_pos_and_vel = [drone_pos[0], 0, drone_pos[1], 0, drone_pos[2], 0]
drone_rot_and_agl_vel = [drone_rot[0], drone_rot[1], drone_rot[2], 0, 0, 0]
env.state = drone_pos_and_vel + drone_rot_and_agl_vel
constraint_values = env.constraints.get_values(env, only_state=True)
x_reference = config.quadrotor_config.task_info.stabilization_goal
Expand Down Expand Up @@ -179,13 +174,17 @@ def main(config: str = "config/getting_started.yaml", controller: str = "example
p = vicon.pos["cf"]
# This only looks at the x-y plane, could be improved
# TODO: Replace with 3D distance once gate poses are given with height
gate_dist = np.sqrt(np.sum((p[0:2] - gate_poses[target_gate_id][0:2]) ** 2))
gate_dist = np.sqrt(np.sum((p[0:2] - vicon.pos[gate_names[target_gate_id]][0:2]) ** 2))
if gate_dist < 0.45:
current_target_gate_pos = vicon.pos[gate_names[target_gate_id]]
else:
current_target_gate_pos = gate_poses[target_gate_id][0:6]
info = {
"mse": np.sum(state_error**2),
"collision": (None, False), # Leave always false in sim2real
"current_target_gate_id": target_gate_id,
"current_target_gate_in_range": gate_dist < 0.45,
"current_target_gate_pos": gate_poses[target_gate_id][0:6], # Always "exact"
"current_target_gate_pos": current_target_gate_pos,
"current_target_gate_type": gate_poses[target_gate_id][6],
"at_goal_position": False, # Leave always false in sim2real
"task_completed": False, # Leave always false in sim2real
Expand All @@ -194,19 +193,21 @@ def main(config: str = "config/getting_started.yaml", controller: str = "example
}

# Check if the drone has passed the current gate
if check_gate_pass(gate_poses[target_gate_id], vicon.pos["cf"], last_drone_pos):
if check_gate_pass(
gate_poses[target_gate_id], vicon.pos[vicon.drone_name], last_drone_pos
):
target_gate_id += 1
print(f"Gate {target_gate_id} passed in {curr_time:.4}s")
last_drone_pos = vicon.pos["cf"].copy()
last_drone_pos = vicon.pos[vicon.drone_name].copy()

if target_gate_id == len(gate_poses): # Reached the end
target_gate_id = -1
total_time = time.time() - start_time

# Get the latest vicon observation and call the controller
p = vicon.pos["cf"]
p = vicon.pos[vicon.drone_name]
drone_pos_and_vel = [p[0], 0, p[1], 0, p[2], 0]
r = vicon.rpy["cf"]
r = vicon.rpy[vicon.drone_name]
drone_rot_and_agl_vel = [r[0], r[1], r[2], 0, 0, 0]
vicon_obs = drone_pos_and_vel + drone_rot_and_agl_vel
# In sim2real: Reward always 0, done always false
Expand Down

0 comments on commit ab59746

Please sign in to comment.