adapt for dexterous hands
This commit is contained in:
@@ -21,6 +21,39 @@ def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.
|
||||
return np.pad(x, pad_width, constant_values=value)
|
||||
return x
|
||||
|
||||
|
||||
def normalize_states(states, statistics):
|
||||
stats = statistics["new_embodiment"]["state"]
|
||||
q01 = np.array(stats["q01"]).astype(states.dtype)
|
||||
q99 = np.array(stats["q99"]).astype(states.dtype)
|
||||
|
||||
# In the case of q01 == q99, the normalization will be undefined
|
||||
# So we set the normalized values to the original values
|
||||
mask = q01 != q99
|
||||
normalized = np.zeros_like(states)
|
||||
|
||||
# Normalize the values where q01 != q99
|
||||
# Formula: 2 * (x - q01) / (q99 - q01) - 1
|
||||
normalized[..., mask] = (states[..., mask] - q01[..., mask]) / (
|
||||
q99[..., mask] - q01[..., mask]
|
||||
)
|
||||
normalized[..., mask] = 2 * normalized[..., mask] - 1
|
||||
|
||||
# Set the normalized values to the original values where q01 == q99
|
||||
normalized[..., ~mask] = states[..., ~mask]
|
||||
|
||||
# Clip the normalized values to be between -1 and 1
|
||||
normalized = np.clip(normalized, -1, 1)
|
||||
return normalized
|
||||
|
||||
|
||||
def unnormalize_actions(normalized_actions, statistics):
|
||||
stats = statistics["new_embodiment"]["action"]
|
||||
q01 = np.array(stats["q01"]).astype(normalized_actions.dtype)
|
||||
q99 = np.array(stats["q99"]).astype(normalized_actions.dtype)
|
||||
|
||||
return (normalized_actions + 1) / 2 * (q99 - q01) + q01
|
||||
|
||||
class StarvlaInferenceServer:
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
@@ -38,8 +71,6 @@ class StarvlaInferenceServer:
|
||||
self.host = policy_server_cfg.get("host", "0.0.0.0")
|
||||
self.port = policy_server_cfg.get("port", 5000)
|
||||
self.use_bf16 = policy_server_cfg.get("use_bf16", True)
|
||||
self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
|
||||
self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")
|
||||
|
||||
print("Loading StarVLA model...")
|
||||
self.model = self.load_model()
|
||||
@@ -74,42 +105,36 @@ class StarvlaInferenceServer:
|
||||
model = build_framework(cfg=cfg)
|
||||
model.norm_stats = norm_stats
|
||||
|
||||
state_dict = torch.load(self.ckpt_path, map_location="cpu")
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
if self.use_bf16:
|
||||
model = model.to(torch.bfloat16)
|
||||
model = model.eval()
|
||||
|
||||
state_dict = torch.load(self.ckpt_path, map_location="cpu")
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
model = model.to("cuda")
|
||||
|
||||
model = model.to("cuda").eval()
|
||||
|
||||
self.norm_stats = norm_stats
|
||||
self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)
|
||||
|
||||
return model
|
||||
|
||||
def parse_observation(self, obs, target_size=(320, 180)):
|
||||
|
||||
left_rgb, right_rgb, wrist_rgb = obs["rgb"]["Left_Camera"], obs["rgb"]["Right_Camera"], obs["rgb"]["Hand_Camera"]
|
||||
head_rgb = obs["rgb"]["head_camera"]
|
||||
|
||||
img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
|
||||
img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
|
||||
img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
|
||||
|
||||
state_vec = obs["state"]
|
||||
# import ipdb;ipdb.set_trace()
|
||||
img_head = Image.fromarray(cv2.resize(head_rgb, target_size))
|
||||
state_vec = normalize_states(obs["state"], self.norm_stats)
|
||||
# state_vec = pad_to_dim(np.array(state_vec), 100, axis=-1)
|
||||
return img_left, img_right, img_wrist, state_vec, obs["prompt"]
|
||||
return img_head, state_vec, obs["prompt"]
|
||||
|
||||
def inference(self, observation: dict) -> dict:
|
||||
|
||||
img_left, img_right, img_wrist, state_vec, prompt = \
|
||||
img_head, state_vec, prompt = \
|
||||
self.parse_observation(observation)
|
||||
print(f"{state_vec.shape}")
|
||||
vla_input = {
|
||||
# "batch_images": [[img_left, img_right, img_wrist]],
|
||||
"image": [img_left],
|
||||
"image": [img_head],
|
||||
"lang": prompt,
|
||||
"state": state_vec
|
||||
"state": state_vec[None, :], # (1, 62)
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -122,9 +147,16 @@ class StarvlaInferenceServer:
|
||||
|
||||
if actions.ndim == 3:
|
||||
actions = actions[0] # (16, 10)
|
||||
return {"ee_delta_position_chunks": actions[:, :3].tolist(),
|
||||
"ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
|
||||
"gripper_width_chunks": actions[:, 9:10].tolist()}
|
||||
actions = unnormalize_actions(actions, self.norm_stats)
|
||||
return {"left_arm": {
|
||||
"ee_delta_position_chunks": actions[:, :3].tolist(),
|
||||
"ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
|
||||
"finger_chunks": actions[:, 9:31].tolist()},
|
||||
"right_arm": {
|
||||
"ee_delta_position_chunks": actions[:, 31:34].tolist(),
|
||||
"ee_delta_rot6d_chunks": actions[:, 34:40].tolist(),
|
||||
"finger_chunks": actions[:, 40:62].tolist()}
|
||||
}
|
||||
|
||||
def register_routes(self):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user