adapt for dexterous hands

This commit is contained in:
QiyangYan
2026-05-22 18:38:16 +08:00
parent 53796c1e63
commit 3d3da4e17f
3 changed files with 340 additions and 236 deletions

View File

@@ -21,6 +21,39 @@ def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.
return np.pad(x, pad_width, constant_values=value)
return x
def normalize_states(states, statistics):
    """Scale raw state values into [-1, 1] using q01/q99 statistics.

    Args:
        states: Array-like of raw state values. The last axis must line up
            with the per-dimension ``q01``/``q99`` vectors. Plain Python
            sequences are accepted; non-floating input is promoted to
            float64 so the scaled result is not truncated.
        statistics: Mapping that provides
            ``statistics["new_embodiment"]["state"]["q01"]`` and ``["q99"]``.

    Returns:
        ``np.ndarray`` with the same shape as ``states``. Floating inputs keep
        their dtype. Dimensions where ``q01 != q99`` are mapped with
        ``2 * (x - q01) / (q99 - q01) - 1``; dimensions where ``q01 == q99``
        keep their raw value. Every entry is then clipped to ``[-1, 1]``.
    """
    states = np.asarray(states)
    if not np.issubdtype(states.dtype, np.floating):
        # Integer/bool input would otherwise force the statistics and the
        # result into an integer dtype and silently drop the fraction.
        states = states.astype(np.float64)

    stats = statistics["new_embodiment"]["state"]
    q01 = np.asarray(stats["q01"]).astype(states.dtype)
    q99 = np.asarray(stats["q99"]).astype(states.dtype)

    # Where q01 == q99 the formula is undefined (division by zero), so those
    # dimensions are passed through unchanged.
    mask = q01 != q99
    # Substitute a harmless denominator on degenerate dimensions; the value
    # computed there is discarded by the np.where below.
    span = np.where(mask, q99 - q01, np.ones_like(q01))

    # Formula: 2 * (x - q01) / (q99 - q01) - 1
    normalized = np.where(mask, 2 * ((states - q01) / span) - 1, states)

    # Clip everything (including passed-through values) to [-1, 1].
    return np.clip(normalized, -1, 1)
def unnormalize_actions(normalized_actions, statistics):
    """Map actions from the normalized [-1, 1] range back to raw units.

    Inverts the q01/q99 scaling: ``(a + 1) / 2 * (q99 - q01) + q01``, using
    the bounds stored under ``statistics["new_embodiment"]["action"]``.
    The bounds are cast to the dtype of ``normalized_actions``; the input is
    not clipped before the mapping.
    """
    action_stats = statistics["new_embodiment"]["action"]
    target_dtype = normalized_actions.dtype
    lower, upper = (
        np.array(action_stats[key]).astype(target_dtype)
        for key in ("q01", "q99")
    )

    # Shift [-1, 1] onto [0, 1] first, then stretch onto [q01, q99].
    unit_interval = (normalized_actions + 1) / 2
    return unit_interval * (upper - lower) + lower
class StarvlaInferenceServer:
def __init__(self, config_path: str):
@@ -38,8 +71,6 @@ class StarvlaInferenceServer:
self.host = policy_server_cfg.get("host", "0.0.0.0")
self.port = policy_server_cfg.get("port", 5000)
self.use_bf16 = policy_server_cfg.get("use_bf16", True)
self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")
print("Loading StarVLA model...")
self.model = self.load_model()
@@ -74,42 +105,36 @@ class StarvlaInferenceServer:
model = build_framework(cfg=cfg)
model.norm_stats = norm_stats
state_dict = torch.load(self.ckpt_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
if self.use_bf16:
model = model.to(torch.bfloat16)
model = model.eval()
state_dict = torch.load(self.ckpt_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model = model.to("cuda")
model = model.to("cuda").eval()
self.norm_stats = norm_stats
self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)
return model
def parse_observation(self, obs, target_size=(320, 180)):
left_rgb, right_rgb, wrist_rgb = obs["rgb"]["Left_Camera"], obs["rgb"]["Right_Camera"], obs["rgb"]["Hand_Camera"]
head_rgb = obs["rgb"]["head_camera"]
img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
state_vec = obs["state"]
# import ipdb;ipdb.set_trace()
img_head = Image.fromarray(cv2.resize(head_rgb, target_size))
state_vec = normalize_states(obs["state"], self.norm_stats)
# state_vec = pad_to_dim(np.array(state_vec), 100, axis=-1)
return img_left, img_right, img_wrist, state_vec, obs["prompt"]
return img_head, state_vec, obs["prompt"]
def inference(self, observation: dict) -> dict:
img_left, img_right, img_wrist, state_vec, prompt = \
img_head, state_vec, prompt = \
self.parse_observation(observation)
print(f"{state_vec.shape}")
vla_input = {
# "batch_images": [[img_left, img_right, img_wrist]],
"image": [img_left],
"image": [img_head],
"lang": prompt,
"state": state_vec
"state": state_vec[None, :], # (1, 62)
}
with torch.no_grad():
@@ -122,9 +147,16 @@ class StarvlaInferenceServer:
if actions.ndim == 3:
actions = actions[0] # (16, 10)
return {"ee_delta_position_chunks": actions[:, :3].tolist(),
"ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
"gripper_width_chunks": actions[:, 9:10].tolist()}
actions = unnormalize_actions(actions, self.norm_stats)
return {"left_arm": {
"ee_delta_position_chunks": actions[:, :3].tolist(),
"ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
"finger_chunks": actions[:, 9:31].tolist()},
"right_arm": {
"ee_delta_position_chunks": actions[:, 31:34].tolist(),
"ee_delta_rot6d_chunks": actions[:, 34:40].tolist(),
"finger_chunks": actions[:, 40:62].tolist()}
}
def register_routes(self):