272 lines
12 KiB
Python
272 lines
12 KiB
Python
import pickle
|
|
import time
|
|
import json
|
|
import numpy as np
|
|
from scipy.spatial.transform import Rotation as R
|
|
import requests
|
|
|
|
from fastsim.annotations.config_class import configclass, field
|
|
from fastsim.annotations.stereotype import stereotype
|
|
from fastsim.controllers.spawnable_controller import SpawnableController
|
|
from fastsim.controllers.visualize_controller import VisualizeController
|
|
from fastsim.unisim.robots.models.modular_robot import ModularRobot
|
|
from fastsim.utils.namespace import PoseVisualType, SimulatorType
|
|
from fastsim.unisim.robots.actuator_configs.grippers import GripperDriveJointConfig
|
|
from fastsim.extensions.benchmark.action import RobotAction
|
|
from fastsim.extensions.benchmark.benchmark import (
|
|
BenchmarkAction,
|
|
BenchmarkObservation,
|
|
ControlMode,
|
|
)
|
|
from fastsim.extensions.benchmark.policy import Policy, PolicyConfig
|
|
from fastsim.utils.log import Log
|
|
from fastsim.utils.pose import Pose
|
|
|
|
@configclass
@stereotype.register_config("starvla")
class StarvlaPolicyConfig(PolicyConfig):
    """Configuration for the ``"starvla"`` policy (a remote StarVLA server client)."""

    robot_name: str = field(default="None", required=True, comment="The name of the robot")
    visualize_action_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the action end effector pose")
    visualize_state_ee_pose: bool = field(default=False, required=True, comment="Whether to visualize the state end effector pose")
    visualize_bounding_box_targets: list[str] = field(
        default_factory=list,
        required=False,
        comment="Spawnable object names to draw bounding boxes for (empty = off)",
    )
    # Use a factory (like visualize_bounding_box_targets above) so each config
    # instance gets its own list instead of sharing one mutable default.
    sensor_names: list[str] = field(
        default_factory=lambda: ["Hand_Camera", "Left_Camera", "Right_Camera"],
        required=True,
        comment="The names of the sensors",
    )
    # Base URL; the policy appends "/health" and "/inference" to it.
    server_url: str = field(
        default="http://127.0.0.1:5000/policy",
        required=True,
        comment="StarVLA policy server url",
    )
    prompt: str = field(
        default="pick the object",
        required=True,
        comment="task instruction",
    )
    # Number of steps executed from each returned action chunk before the
    # server is queried again; clamped at runtime to the chunk length.
    run_trunk_size: int = field(
        default=16,
        required=True,
        comment="The number of chunks to run in one inference step",
    )
|
|
|
|
|
|
@stereotype.register_model("starvla")
class StarvlaPolicy(Policy):
    """Benchmark policy backed by a remote StarVLA inference server.

    Each query POSTs a pickled observation (62-D bimanual state, RGB images
    from ``sensor_names`` and the task ``prompt``) to
    ``{server_url}/inference`` and receives, per arm, chunks of *relative*
    actions: end-effector pose deltas (position + rot6d) and finger-joint
    deltas. The first ``run_trunk_size`` steps of a chunk are executed, one
    per ``postprocess_action`` call, before the server is queried again.

    NOTE(security): requests and replies are exchanged with ``pickle``;
    unpickling executes arbitrary code, so ``server_url`` must point at a
    trusted server.
    """

    # Per-arm state width: 3 (EE position) + 6 (EE rot6d) + 22 (finger joints).
    _ARM_STATE_DIM = 31

    def __init__(self, config: StarvlaPolicyConfig):
        super().__init__(config)
        self.robot_name = config.robot_name
        self.sensor_names = config.sensor_names
        self.server_url = config.server_url
        self.prompt = config.prompt
        self.visualize_action_ee_pose = config.visualize_action_ee_pose
        self.visualize_state_ee_pose = config.visualize_state_ee_pose
        self.visualize_bounding_box_targets = list(config.visualize_bounding_box_targets or [])
        # Chunk bookkeeping. reset() re-initializes these; defining them here
        # keeps needs_observation()/compute_action() safe if called earlier.
        self.current_state = {}
        self.current_chunk_id = 0
        self.current_chunk_result = None
        self.run_trunk_size = config.run_trunk_size

    def reset(self) -> None:
        """Clear chunk state and (re)resolve the robot and its hand joint names."""
        self.current_state = {}
        self.current_chunk_id = 0
        self.current_chunk_result = None
        self.run_trunk_size = self.config.run_trunk_size
        self.robot: ModularRobot = SpawnableController.get_spawnable_data(self.robot_name).unwrap()
        self.left_hand_joints = SpawnableController.control_robot(
            self.robot_name,
            "get_actuator_joint_names",
            parameters={"actuator_name": "left_hand"},
        ).unwrap()
        self.right_hand_joints = SpawnableController.control_robot(
            self.robot_name,
            "get_actuator_joint_names",
            parameters={"actuator_name": "right_hand"},
        ).unwrap()

    def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
        """Block until ``{server_url}/health`` answers HTTP 200.

        Polls roughly once per second with no overall deadline. Both request
        errors and non-200 replies are retried.
        """
        Log.info("Waiting for StarVLA inference server to be ready...")
        health_url = f"{self.server_url}/health"
        while True:
            try:
                if requests.get(health_url, timeout=1.0).status_code == 200:
                    break
            except requests.RequestException:
                pass
            # Sleep after *every* failed attempt (not only on exceptions) so a
            # server that answers non-200 is not hammered in a busy loop.
            time.sleep(1)
        Log.success("StarVLA inference server is ready.")

    def needs_observation(self) -> bool:
        """A fresh observation is needed only at the start of a new chunk."""
        return self.current_chunk_id == 0

    def _handle_server_error(self, response: requests.Response) -> None:
        """Report a non-200 server reply via ``Log.error(..., exit=True)``.

        A 500 reply is expected to carry a pickled dict with ``"error"`` and
        ``"traceback"`` keys. If that body cannot be decoded (e.g. a plain
        framework error page), the raw text is logged instead, so the real
        failure is not masked by an unpickling exception.
        """
        if response.status_code == 200:
            return
        if response.status_code == 500:
            try:
                err_obj = pickle.loads(response.content)
                server_error, server_traceback = err_obj["error"], err_obj["traceback"]
            except Exception:  # any decode/shape failure: fall back to raw text
                Log.error(f"StarVLA server error with status code <500> : {response.text}", exit=True)
                return
            Log.error(f"StarVLA server error: {server_error}")
            Log.error(f"Traceback: {server_traceback}", exit=True)
        else:
            Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}", exit=True)

    def split_joints(self, state_or_action, keys=None) -> dict:
        """Split a stacked per-arm vector into named components.

        Args:
            state_or_action: array-like whose last axis is ``31 * len(keys)``
                wide: per key, 3 EE position + 6 EE rot6d + 22 finger values.
            keys: names of the consecutive per-arm segments; defaults to
                ``["left_arm", "right_arm"]``.

        Returns:
            ``{key: {"ee_pos", "ee_rot6d", "finger_qpos"}}`` with array values
            split along the last axis.

        Raises:
            ValueError: if the last dimension does not match ``31 * len(keys)``.
        """
        if keys is None:
            keys = ["left_arm", "right_arm"]
        if not keys:
            return {}
        values = np.asarray(state_or_action)
        total_dim = self._ARM_STATE_DIM * len(keys)
        if values.shape[-1] != total_dim:
            raise ValueError(f"Expected last dimension to be {total_dim}, got {values.shape[-1]}")
        # One equal-width segment per key (the previous single split point at
        # index 31 only worked for exactly two keys).
        segments = np.split(values, len(keys), axis=-1)
        result = {}
        for key, segment in zip(keys, segments):
            ee_pos, ee_rot6d, finger_qpos = np.split(segment, [3, 9], axis=-1)
            result[key] = {"ee_pos": ee_pos, "ee_rot6d": ee_rot6d, "finger_qpos": finger_qpos}
        return result

    def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
        """Build the server payload: 62-D state, per-sensor uint8 RGB, prompt.

        State layout: [left pos(3), left rot6d(6), left fingers(22),
        right pos(3), right rot6d(6), right fingers(22)], with EE poses taken
        in the robot base frame.
        """
        robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
        left_ee = robot_obs["ee_pose"]["left_arm"]["base_frame"]
        right_ee = robot_obs["ee_pose"]["right_arm"]["base_frame"]
        # Assumes joint_positions holds only the 44 finger joints, left hand
        # first (per the original "finger joints(44) only" note) — confirm.
        finger_positions = robot_obs["joint_positions"]
        state = np.concatenate(
            [
                left_ee["position"], left_ee["rot6d"], finger_positions[:22],
                right_ee["position"], right_ee["rot6d"], finger_positions[22:],
            ],
            axis=-1,
        )  # (62,)
        rgb_data = {}
        for sensor_name in self.sensor_names:
            sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
            rgb_data[sensor_name] = sensor_obs["rgb"].data.cpu().numpy().astype(np.uint8)
        return {"state": state, "rgb": rgb_data, "prompt": self.prompt}

    def compute_action(self, observation: dict) -> dict:
        """Return the current action chunk, querying the server if none is cached.

        When a chunk is cached, ``observation`` is ignored. On a query, the
        observed state is stored in ``current_state``; the chunk's relative
        actions are applied against it in ``postprocess_action``.
        """
        if self.current_chunk_result is not None:
            return self.current_chunk_result

        self.current_state.update(self.split_joints(observation["state"]))
        response = requests.post(
            f"{self.server_url}/inference",
            data=pickle.dumps(observation),
            headers={"Content-Type": "application/octet-stream"},
        )
        self._handle_server_error(response)
        # NOTE(security): unpickling network data; trusted server only.
        result = pickle.loads(response.content)

        max_trunk_size = len(result["right_arm"]["ee_delta_position_chunks"])
        # Clamp only when the configured size exceeds the chunk length (the old
        # unconditional reassignment silently overrode run_trunk_size).
        if self.run_trunk_size > max_trunk_size:
            Log.warning(f"Run trunk size {self.run_trunk_size} is greater than the number of chunks {max_trunk_size}. Set run trunk size to {max_trunk_size}.")
            self.run_trunk_size = max_trunk_size
        self.current_chunk_result = result
        return result

    def postprocess_action(self, action: dict) -> BenchmarkAction:
        """Turn step ``current_chunk_id`` of ``action`` into robot commands.

        For every arm, emits a POSITION command for the hand joints
        (state fingers + finger delta) and an EE_POSE command
        (state EE pose composed with the EE delta). Then advances the step and
        drops the cached chunk once ``run_trunk_size`` steps have run.
        """
        benchmark_action = BenchmarkAction()
        step = self.current_chunk_id
        for arm_key in self.robot['arms'].keys():
            arm_action = action[arm_key]
            arm_state = self.current_state[arm_key]
            delta_ee_pose = Pose(
                position=arm_action["ee_delta_position_chunks"][step],
                rot6d=arm_action["ee_delta_rot6d_chunks"][step],
            )
            state_ee_pose = Pose(position=arm_state["ee_pos"], rot6d=arm_state["ee_rot6d"])
            # action2base = state2base * action2state
            action_ee_pose = state_ee_pose * delta_ee_pose
            finger_qpos = np.asarray(arm_action["finger_chunks"][step]) + arm_state["finger_qpos"]
            joint_names = self.left_hand_joints if arm_key == "left_arm" else self.right_hand_joints
            benchmark_action.add_robot_action(
                RobotAction(
                    control_mode=ControlMode.POSITION,
                    robot_name=self.robot_name,
                    joint_names=joint_names,
                    joint_positions=finger_qpos,
                )
            )
            benchmark_action.add_robot_action(
                RobotAction(
                    control_mode=ControlMode.EE_POSE,
                    robot_name=self.robot_name,
                    ee_pose=action_ee_pose,
                    arm_name=arm_key,
                )
            )
            self._visualize_base_frame_ee_poses(state_ee_pose, action_ee_pose, arm_name=arm_key)
        self._visualize_bounding_boxes()

        self.current_chunk_id += 1
        # ">=" so a non-positive run_trunk_size cannot leave the counter stuck.
        if self.current_chunk_id >= self.run_trunk_size:
            self.current_chunk_id = 0
            self.current_chunk_result = None
        return benchmark_action

    # ------------------- Visualization -------------------
    def _visualize_base_frame_ee_poses(
        self, pose_state_base: Pose, pose_action_base: Pose, arm_name: str = ""
    ) -> None:
        """Draw base-frame state/action EE poses in world frame, if enabled.

        ``arm_name`` (optional) is appended to the marker names so that each
        arm gets its own markers instead of overwriting a shared one.
        """
        if not (self.visualize_action_ee_pose or self.visualize_state_ee_pose):
            return
        suffix = f"_{arm_name}" if arm_name else ""
        robot_base_world = SpawnableController.control_robot(
            self.robot_name, "get_pose"
        ).unwrap()
        if self.visualize_state_ee_pose:
            VisualizeController.create_pose_visualization(
                robot_base_world * pose_state_base,
                name=f"{self.robot_name}/starvla_state_ee{suffix}",
                simulator=SimulatorType.ISAACLAB,
                pose_type=PoseVisualType.COORDINATE,
                extra_params={"axis_length": 0.08, "thickness": 0.006},
            ).unwrap()
        if self.visualize_action_ee_pose:
            VisualizeController.create_pose_visualization(
                robot_base_world * pose_action_base,
                name=f"{self.robot_name}/starvla_action_ee{suffix}",
                simulator=SimulatorType.ISAACLAB,
                pose_type=PoseVisualType.COORDINATE,
                extra_params={"axis_length": 0.1, "thickness": 0.006},
            ).unwrap()

    def _visualize_bounding_boxes(self) -> None:
        """Draw a bounding box for each configured spawnable target (if any)."""
        for target_name in self.visualize_bounding_box_targets:
            VisualizeController.visualize_target_bounding_box(
                target_name, simulator=SimulatorType.ISAACLAB
            ).unwrap()