# Source file: starvla_benchmark/starvla_policy.py
# (File-viewer header carried over from the original listing: 126 lines, 4.9 KiB, Python.)

import pickle
from joysim.annotations.config_class import configclass, field
from joysim.annotations.stereotype import stereotype
from joysim.controllers.motion_plan_controller import MotionPlanController
from joysim.extensions.benchmark.action import RobotAction
from joysim.extensions.benchmark.benchmark import (
BenchmarkAction,
BenchmarkObservation,
ControlMode,
)
from joysim.extensions.benchmark.policy import Policy, PolicyConfig
from joysim.utils.log import Log
from joysim.utils.pose import Pose
import numpy as np
import requests
@configclass
@stereotype.register_config("starvla")
class StarvlaPolicyConfig(PolicyConfig):
    """Configuration for the policy registered under the ``"starvla"`` key.

    Every option is declared through the project's ``field`` helper; the
    ``comment`` string attached to each field is its authoritative description.
    """

    # Robot this policy reads observations from and sends actions to.
    robot_name: str = field(
        default="None",
        required=True,
        comment="The name of the robot",
    )

    # Sensors whose RGB frames are forwarded to the policy server.
    sensor_names: list[str] = field(
        default=["Hand_Camera", "Left_Camera", "Right_Camera"],
        required=True,
        comment="The names of the sensors",
    )

    # HTTP endpoint that receives the pickled observation.
    server_url: str = field(
        default="http://127.0.0.1:5000/policy",
        required=True,
        comment="StarVLA policy server url",
    )

    # Natural-language instruction sent along with every observation.
    prompt: str = field(
        default="pick the object",
        required=True,
        comment="task instruction",
    )
@stereotype.register_model("starvla")
class StarvlaPolicy(Policy):
    """Policy that delegates action inference to a remote StarVLA HTTP server.

    Per step the pipeline is:

    1. ``preprocess_observation`` packs the end-effector state, the RGB frames
       and the prompt into a plain dict (and caches the current state).
    2. ``compute_action`` POSTs that dict, pickled, to ``server_url`` and
       unpickles the reply.
    3. ``postprocess_action`` composes the first predicted delta pose with the
       cached end-effector pose, solves IK and emits joint position targets.
    """

    def __init__(self, config: StarvlaPolicyConfig):
        """Copy the relevant settings from *config* and clear the cached state."""
        super().__init__(config)
        self.robot_name = config.robot_name
        self.sensor_names = config.sensor_names
        self.server_url = config.server_url
        self.prompt = config.prompt
        # Create the cached-state attributes here as well, so they exist even
        # if a caller touches the policy before the first ``reset()``.
        self.current_ee_position_state = None
        self.current_ee_euler_xyz_state = None
        self.current_gripper_state = None

    def reset(self) -> None:
        """Forget the end-effector / gripper state cached from the last observation."""
        self.current_ee_position_state = None
        self.current_ee_euler_xyz_state = None
        self.current_gripper_state = None

    def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
        """No warm-up is performed for this policy."""
        pass

    def needs_observation(self) -> bool:
        """Return ``True``: this policy needs an observation to act."""
        return True

    def _handle_server_error(self, response: requests.Response) -> None:
        """Log a failed server response and terminate the process.

        Does nothing for HTTP 200. For any other status the error is logged and
        ``SystemExit(1)`` is raised, so the process exits with a failure code.

        Raises:
            SystemExit: when ``response.status_code`` is not 200.
        """
        status = response.status_code
        if status == 200:
            return

        if status == 500:
            # SECURITY: pickle.loads executes arbitrary code embedded in the
            # payload. Only point ``server_url`` at a trusted StarVLA server.
            try:
                err_obj = pickle.loads(response.content)
                error_text = err_obj["error"]
                traceback_text = err_obj["traceback"]
            except Exception:
                # The 500 body is not the expected pickled {"error", "traceback"}
                # dict (e.g. it was produced by a proxy or the web framework).
                # Unpickling can raise many exception types, so catch broadly
                # and fall back to the raw response instead of masking the
                # real failure with a secondary one.
                Log.error(f"StarVLA server error with status code <{status}> : {response.text}")
            else:
                Log.error(f"StarVLA server error: {error_text}")
                Log.error(f"Traceback: {traceback_text}")
        else:
            Log.error(f"StarVLA server error with status code <{status}> : {response.text}")

        # Exit with a non-zero code: this is a failure, and the ``exit`` helper
        # injected by ``site`` is not guaranteed to be available.
        raise SystemExit(1)

    def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
        """Build the request payload for the policy server.

        Side effect: caches the current end-effector position, Euler angles and
        gripper value on ``self`` for later use by ``postprocess_action``.

        Returns:
            A dict with keys ``"state"`` (position, euler_xyz and gripper value
            concatenated), ``"rgb"`` (sensor name -> uint8 array) and
            ``"prompt"``.
        """
        robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
        ee_pose_base = robot_obs["ee_pose_base"]
        ee_position = ee_pose_base["position"]
        ee_euler_xyz = ee_pose_base["euler_xyz"]
        # Gripper encoding sent to the server: 0.0 when opened, 1.0 otherwise.
        gripper = 0.0 if robot_obs["gripper_state"]["opened"] else 1.0
        state = np.concatenate([ee_position, ee_euler_xyz, np.array([gripper])])

        self.current_ee_position_state = np.array(ee_position).astype(np.float64)
        self.current_ee_euler_xyz_state = np.array(ee_euler_xyz).astype(np.float64)
        self.current_gripper_state = np.array([gripper])

        rgb_data = {}
        for sensor_name in self.sensor_names:
            sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
            # presumably a torch tensor behind ``.data`` — moved to host memory
            # and converted to a uint8 numpy array so it can be pickled.
            rgb_data[sensor_name] = sensor_obs["rgb"].data.cpu().numpy().astype(np.uint8)

        return {"state": state, "rgb": rgb_data, "prompt": self.prompt}

    def compute_action(self, observation: dict) -> dict:
        """Send *observation* to the policy server and return its decoded reply.

        The observation is pickled and POSTed as ``application/octet-stream``;
        a non-200 reply terminates the process via ``_handle_server_error``.
        """
        payload = pickle.dumps(observation)
        # NOTE(review): no ``timeout`` is passed, so this call blocks for as
        # long as the server takes (or forever if it hangs) — confirm whether
        # a bound is wanted given the expected inference latency.
        response = requests.post(
            self.server_url,
            data=payload,
            headers={"Content-Type": "application/octet-stream"},
        )
        self._handle_server_error(response)
        # SECURITY: unpickling network data executes arbitrary code if the
        # server is compromised; the endpoint must be trusted.
        return pickle.loads(response.content)

    def postprocess_action(self, action: dict) -> BenchmarkAction:
        """Turn the server's reply into a joint-position ``BenchmarkAction``.

        Only the first entry of ``"ee_delta_position_chunks"`` and
        ``"ee_delta_euler_xyz_chunks"`` is used. The delta is composed with the
        end-effector pose cached by ``preprocess_observation`` and passed to IK.

        Returns:
            A ``BenchmarkAction`` carrying one position-controlled
            ``RobotAction``, or an empty ``BenchmarkAction`` when IK fails.
        """
        benchmark_action = BenchmarkAction()

        # get base frame end-effector pose
        delta_ee_pose = Pose(
            position=action["ee_delta_position_chunks"][0],
            euler_xyz=action["ee_delta_euler_xyz_chunks"][0],
        )
        curr_state_ee_pose = Pose(
            position=self.current_ee_position_state,
            euler_xyz=self.current_ee_euler_xyz_state,
        )
        # action2base = state2base * action2state
        curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose

        ik_result = MotionPlanController.solve_ik(
            robot_name=self.robot_name,
            base_frame_ee_pose=curr_action_ee_pose,
        ).unwrap()
        if not ik_result["success"]:
            Log.error("IK failed. Ignore this action.")
            return benchmark_action

        ik_solution = ik_result["result"]
        benchmark_action.add_robot_action(
            RobotAction(
                control_mode=ControlMode.POSITION,
                robot_name=self.robot_name,
                joint_names=ik_solution["plannable_joint_names"],
                # Index 0 is taken from the returned joint positions
                # (presumably the first of several solutions — TODO confirm).
                joint_positions=ik_solution["plannable_joint_positions"][0],
            )
        )
        return benchmark_action