# evaluate.py
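"""Run a trained policy on the real follower arm and record the rollout.

The script loads a checkpoint plus the dataset statistics saved at training
time, streams camera frames and joint positions, lets the policy drive the
follower for one episode, and saves the trajectory to an HDF5 file in the
same layout as the training episodes.

Typical invocation (a sketch, assuming DEVICE and the configs are set up):

    DEVICE=cpu python evaluate.py --task task1
"""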
from config.config import POLICY_CONFIG, TASK_CONFIG, TRAIN_CONFIG, ROBOT_PORTS  # must import first
import os
import cv2
import h5py
import torch
import pickle
import argparse
import numpy as np
from time import time

from robot import Robot
from training.utils import *  # expected to provide make_policy, get_image, pwm2pos, pos2pwm, vel2pwm
# parse the task name via command line
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='task1')
args = parser.parse_args()
task = args.task  # currently unused below
# config
cfg = TASK_CONFIG
policy_config = POLICY_CONFIG
train_cfg = TRAIN_CONFIG
device = os.environ['DEVICE']  # e.g. 'cpu' or 'cuda'; must be set in the environment
def capture_image(cam):
    # Capture a single frame
    _, frame = cam.read()
    # Convert from OpenCV's BGR channel order to RGB
    image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Crop coordinates (top-left and bottom-right corners of the rectangle)
    x1, y1 = 400, 0
    x2, y2 = 1600, 900
    # Crop the image
    image = image[y1:y2, x1:x2]
    # Resize to the resolution the policy expects
    image = cv2.resize(image, (cfg['cam_width'], cfg['cam_height']), interpolation=cv2.INTER_AREA)
    return image
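# Note: the hard-coded crop assumes the camera delivers frames of at least
# 1600x900 pixels (e.g. a 1080p feed); a smaller feed would make the slice
# silently return a truncated image. Checking frame.shape once at startup is
# a cheap way to verify this on a new camera.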
if __name__ == "__main__":
    # init camera
    cam = cv2.VideoCapture(cfg['camera_port'])
    # Check if the camera opened successfully
    if not cam.isOpened():
        raise IOError("Cannot open camera")

    # init follower
    follower = Robot(device_name=ROBOT_PORTS['follower'])
    # load the policy
    ckpt_path = os.path.join(train_cfg['checkpoint_dir'], train_cfg['eval_ckpt_name'])
    policy = make_policy(policy_config['policy_class'], policy_config)
    loading_status = policy.load_state_dict(torch.load(ckpt_path, map_location=torch.device(device)))
    print(loading_status)
    policy.to(device)
    policy.eval()
    print(f'Loaded: {ckpt_path}')

    # load the dataset statistics saved at training time
    stats_path = os.path.join(train_cfg['checkpoint_dir'], 'dataset_stats.pkl')
    with open(stats_path, 'rb') as f:
        stats = pickle.load(f)

    # normalize observations / de-normalize actions with the training statistics
    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['action_std'] + stats['action_mean']
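    # The policy operates entirely in normalized space:
    #   qpos_norm = (qpos - qpos_mean) / qpos_std
    #   action    = action_norm * action_std + action_mean
    # so evaluation must reuse the exact statistics computed during training,
    # otherwise the de-normalized goal positions come out wrongly scaled.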
    query_frequency = policy_config['num_queries']
    if policy_config['temporal_agg']:
        query_frequency = 1
        num_queries = policy_config['num_queries']
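    # Without temporal aggregation the policy is queried once every
    # num_queries steps and the predicted action chunk is executed open-loop
    # in between. With temporal aggregation it is queried at every step and
    # the overlapping chunks are ensembled below (the scheme used by
    # ACT-style action-chunking policies).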
    # settle the follower: flush a burst of position reads and grab one frame
    for i in range(90):
        follower.read_position()
    _ = capture_image(cam)

    # initial observation, in the same format as the recorded episodes
    obs = {
        'qpos': pwm2pos(follower.read_position()),
        'qvel': vel2pwm(follower.read_velocity()),
        'images': {cn: capture_image(cam) for cn in cfg['camera_names']}
    }
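    # pwm2pos / vel2pwm / pos2pwm come from training.utils and convert
    # between raw servo readings and the units the policy was trained on;
    # the observation dict mirrors the format of the recorded episodes.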
    os.system('say "start"')  # audio cue ('say' is available on macOS)

    n_rollouts = 1
    for i in range(n_rollouts):
        ### evaluation loop
        if policy_config['temporal_agg']:
            all_time_actions = torch.zeros([cfg['episode_len'], cfg['episode_len'] + num_queries, cfg['state_dim']]).to(device)
        qpos_history = torch.zeros((1, cfg['episode_len'], cfg['state_dim'])).to(device)
        with torch.inference_mode():
            # init buffers
            obs_replay = []
            action_replay = []
            for t in range(cfg['episode_len']):
                qpos_numpy = np.array(obs['qpos'])
                qpos = pre_process(qpos_numpy)
                qpos = torch.from_numpy(qpos).float().to(device).unsqueeze(0)
                qpos_history[:, t] = qpos
                curr_image = get_image(obs['images'], cfg['camera_names'], device)

                if t % query_frequency == 0:
                    all_actions = policy(qpos, curr_image)
                if policy_config['temporal_agg']:
                    # every past query predicted a chunk that covers step t;
                    # gather those predictions and ensemble them
                    all_time_actions[[t], t:t + num_queries] = all_actions
                    actions_for_curr_step = all_time_actions[:, t]
                    actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
                    actions_for_curr_step = actions_for_curr_step[actions_populated]
                    # exponential weights w_i ~ exp(-k * i), with i = 0 for the
                    # oldest prediction, so earlier predictions weigh slightly more
                    k = 0.01
                    exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
                    exp_weights = exp_weights / exp_weights.sum()
                    exp_weights = torch.from_numpy(exp_weights.astype(np.float32)).to(device).unsqueeze(dim=1)
                    raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
                else:
                    raw_action = all_actions[:, t % query_frequency]

                ### post-process actions
                raw_action = raw_action.squeeze(0).cpu().numpy()
                action = post_process(raw_action)
                action = pos2pwm(action).astype(int)

                ### take action
                follower.set_goal_pos(action)

                ### update obs
                obs = {
                    'qpos': pwm2pos(follower.read_position()),
                    'qvel': vel2pwm(follower.read_velocity()),
                    'images': {cn: capture_image(cam) for cn in cfg['camera_names']}
                }
                ### store data
                obs_replay.append(obs)
                action_replay.append(action)
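        # Note: there is no explicit rate limiting in the loop above; the
        # control frequency is whatever the camera read, serial I/O and the
        # policy forward pass happen to allow.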
        os.system('say "stop"')

        # create a dictionary to store the data
        data_dict = {
            '/observations/qpos': [],
            '/observations/qvel': [],
            '/action': [],
        }
        # there may be more than one camera
        for cam_name in cfg['camera_names']:
            data_dict[f'/observations/images/{cam_name}'] = []

        # store the observations and actions
        for o, a in zip(obs_replay, action_replay):
            data_dict['/observations/qpos'].append(o['qpos'])
            data_dict['/observations/qvel'].append(o['qvel'])
            data_dict['/action'].append(a)
            # store the images, one frame per camera per timestep
            for cam_name in cfg['camera_names']:
                data_dict[f'/observations/images/{cam_name}'].append(o['images'][cam_name])
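        # The data_dict keys double as HDF5 paths: '/observations/qpos'
        # becomes the 'qpos' dataset inside the 'observations' group, which
        # is what the bulk write at the end of the save block relies on.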
        t0 = time()
        max_timesteps = len(data_dict['/observations/qpos'])
        # create data dir if it doesn't exist
        data_dir = cfg['dataset_dir']
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)
        # count the files in the directory to pick the next episode index
        idx = len([name for name in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, name))])
        dataset_path = os.path.join(data_dir, f'episode_{idx}')
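        # Caveat: the episode index is just a file count, so this assumes
        # data_dir contains nothing but episode files; any stray file would
        # shift the numbering.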
        # save the data
        with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
            root.attrs['sim'] = True
            obs = root.create_group('observations')  # note: shadows the obs dict; harmless for a single rollout
            image = obs.create_group('images')
            for cam_name in cfg['camera_names']:
                _ = image.create_dataset(cam_name, (max_timesteps, cfg['cam_height'], cfg['cam_width'], 3),
                                         dtype='uint8', chunks=(1, cfg['cam_height'], cfg['cam_width'], 3))
            qpos = obs.create_dataset('qpos', (max_timesteps, cfg['state_dim']))
            qvel = obs.create_dataset('qvel', (max_timesteps, cfg['state_dim']))
            action = root.create_dataset('action', (max_timesteps, cfg['action_dim']))
            for name, array in data_dict.items():
                root[name][...] = array
    # disable torque so the arm can be moved by hand again
    follower._disable_torque()