156 lines
6.4 KiB
Python
156 lines
6.4 KiB
Python
import pickle
|
|
|
|
from joysim.annotations.config_class import configclass, field
|
|
from joysim.annotations.stereotype import stereotype
|
|
from joysim.controllers.motion_plan_controller import MotionPlanController
|
|
from joysim.extensions.benchmark.action import RobotAction
|
|
from joysim.extensions.benchmark.benchmark import (
|
|
BenchmarkAction,
|
|
BenchmarkObservation,
|
|
ControlMode,
|
|
)
|
|
from joysim.extensions.benchmark.policy import Policy, PolicyConfig
|
|
from joysim.utils.log import Log
|
|
from joysim.utils.pose import Pose
|
|
import numpy as np
|
|
import requests
|
|
import time
|
|
|
|
@configclass
@stereotype.register_config("starvla")
class StarvlaPolicyConfig(PolicyConfig):
    """Settings consumed by the ``StarvlaPolicy`` model registered as "starvla"."""

    # Robot whose observations are read and whose IK is solved. The "None"
    # string is only a placeholder; required=True presumably forces callers to
    # supply a real name — confirm against configclass semantics.
    robot_name: str = field(default="None", required=True, comment="The name of the robot")

    # Cameras whose "rgb" data is sent to the server, keyed by these names.
    # NOTE(review): mutable list default; safe only if configclass/field copies
    # defaults per instance — confirm.
    sensor_names: list[str] = field(
        default=["Hand_Camera", "Left_Camera", "Right_Camera"],
        required=True,
        comment="The names of the sensors",
    )

    # Base URL of the server; the policy calls "<url>/health" and "<url>/inference".
    server_url: str = field(
        default="http://127.0.0.1:5000/policy",
        required=True,
        comment="StarVLA policy server url",
    )

    # Language instruction sent alongside every observation.
    prompt: str = field(default="pick the object", required=True, comment="task instruction")

    # How many entries of each returned action chunk are executed before a new
    # observation is requested (clamped by the policy to the chunk length).
    run_trunk_size: int = field(
        default=16,
        required=True,
        comment="The number of chunks to run in one inference step",
    )
|
|
|
|
|
|
@stereotype.register_model("starvla")
class StarvlaPolicy(Policy):
    """Policy that delegates inference to a remote StarVLA server over HTTP.

    Each query sends the end-effector state, camera images and the task prompt
    to ``{server_url}/inference`` and receives chunks of end-effector deltas
    (``ee_delta_position_chunks`` / ``ee_delta_rot6d_chunks``) relative to the
    state captured at query time. Up to ``run_trunk_size`` entries of a chunk
    are executed, one per ``postprocess_action`` call, each converted to joint
    targets via IK, before a fresh observation is requested.
    """

    def __init__(self, config: StarvlaPolicyConfig):
        super().__init__(config)

        self.robot_name = config.robot_name
        self.sensor_names = config.sensor_names
        self.server_url = config.server_url
        self.prompt = config.prompt
        # Make the per-episode attributes exist even if the caller uses the
        # policy before calling reset().
        self.reset()

    def reset(self) -> None:
        """Drop any cached chunk and the reference state it was computed from."""
        self.current_ee_position_state = None
        self.current_ee_rot6d_state = None
        self.current_gripper_state = None
        self.current_chunk_id = 0
        self.current_chunk_result = None
        # compute_action may clamp this to a shorter chunk; restore it per episode.
        self.run_trunk_size = self.config.run_trunk_size

    def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
        """Block until ``{server_url}/health`` answers with HTTP 200.

        Polls roughly once per second. Both connection errors and non-200
        responses wait before retrying, so the server is never hammered.
        """
        Log.info("Waiting for StarVLA inference server to be ready...")
        health_url = f"{self.server_url}/health"
        while True:
            try:
                if requests.get(health_url, timeout=1.0).status_code == 200:
                    break
            except requests.RequestException:
                # Server not reachable yet; retry after the sleep below.
                pass
            time.sleep(1)
        Log.success("StarVLA inference server is ready.")

    def needs_observation(self) -> bool:
        """Return True when a new chunk must be requested (at the start of a chunk)."""
        return self.current_chunk_id == 0

    def _handle_server_error(self, response: requests.Response) -> None:
        """Log and abort (``Log.error(..., exit=True)``) on any non-200 response.

        A 500 response is expected to carry a pickled dict with ``error`` and
        ``traceback`` keys. If it cannot be decoded that way, the generic
        status-code message with the raw body is logged instead, so the original
        failure is not masked by a decoding error.
        """
        if response.status_code == 200:
            return
        if response.status_code == 500:
            try:
                # NOTE(security): pickle.loads can execute arbitrary code from
                # the payload; server_url must point at a trusted server.
                err_obj = pickle.loads(response.content)
                error_msg = err_obj["error"]
                tb_msg = err_obj["traceback"]
            except Exception:
                # Not the expected pickled error dict; report the raw body below.
                pass
            else:
                Log.error(f"StarVLA server error: {error_msg}")
                Log.error(f"Traceback: {tb_msg}", exit=True)
                return
        Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}", exit=True)

    def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
        """Build the request payload sent to the StarVLA server.

        Returns:
            dict with keys:
              * ``"state"``: concatenation ``[ee_position, ee_rot6d, gripper]``;
                compute_action slices it as ``[:3]``, ``[3:9]``, ``[9]``, so
                position is presumably length 3 and rot6d length 6 — confirm.
                ``gripper`` is 0.0 when the gripper is opened, 1.0 otherwise.
              * ``"rgb"``: sensor name -> uint8 numpy array of that sensor's RGB data.
              * ``"prompt"``: the configured task instruction.
        """
        robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
        ee_pose_base = robot_obs["ee_pose_base"]
        ee_position, ee_rot6d = ee_pose_base["position"], ee_pose_base["rot6d"]
        gripper = 0.0 if robot_obs["gripper_state"]["opened"] else 1.0
        state = np.concatenate([ee_position, ee_rot6d, np.array([gripper])])

        rgb_data = {}
        for sensor_name in self.sensor_names:
            sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
            rgb_data[sensor_name] = sensor_obs["rgb"].data.cpu().numpy().astype(np.uint8)

        return {"state": state, "rgb": rgb_data, "prompt": self.prompt}

    def _request_inference(self, observation: dict) -> dict:
        """POST the pickled observation to ``/inference`` and return the decoded reply."""
        response = requests.post(
            f"{self.server_url}/inference",
            data=pickle.dumps(observation),
            headers={"Content-Type": "application/octet-stream"},
        )
        self._handle_server_error(response)
        # NOTE(security): pickle.loads on a network response can execute
        # arbitrary code; only talk to a trusted server.
        return pickle.loads(response.content)

    def compute_action(self, observation: dict) -> dict:
        """Return the active action chunk, querying the server if none is cached.

        On a fresh query, the state in ``observation["state"]`` is stored as the
        reference pose for the returned deltas. ``run_trunk_size`` is clamped
        (with a warning) to the number of entries the server actually returned.
        """
        if self.current_chunk_result is not None:
            return self.current_chunk_result

        state = observation["state"]
        self.current_ee_position_state = np.array(state[:3], dtype=np.float64)
        self.current_ee_rot6d_state = np.array(state[3:9], dtype=np.float64)
        self.current_gripper_state = np.array([state[9]])

        result = self._request_inference(observation)

        # Both sequences are indexed per step, so only the shorter one is usable.
        max_trunk_size = min(
            len(result["ee_delta_position_chunks"]),
            len(result["ee_delta_rot6d_chunks"]),
        )
        if max_trunk_size == 0:
            Log.error("StarVLA server returned an empty action chunk.", exit=True)
        if self.run_trunk_size > max_trunk_size:
            Log.warning(f"Run trunk size {self.run_trunk_size} is greater than the number of chunks {max_trunk_size}. Set run trunk size to {max_trunk_size}.")
            self.run_trunk_size = max_trunk_size

        self.current_chunk_result = result
        return result

    def _build_action(self, action: dict, chunk_id: int) -> BenchmarkAction:
        """Convert chunk entry ``chunk_id`` into a joint-position action via IK.

        Returns an empty BenchmarkAction (after logging) when IK fails.
        """
        benchmark_action = BenchmarkAction()

        # Compose the predicted delta with the reference state captured at
        # query time to obtain the target pose in the base frame.
        delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][chunk_id], rot6d=action["ee_delta_rot6d_chunks"][chunk_id])
        curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state)
        Log.debug(f"trunck_id: {chunk_id}, curr_state_ee_pose: {curr_state_ee_pose}")
        curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose  # action2base = state2base * action2state

        ik_result = MotionPlanController.solve_ik(
            robot_name=self.robot_name,
            base_frame_ee_pose=curr_action_ee_pose,
        ).unwrap()
        if not ik_result["success"]:
            Log.error("IK failed. Ignore this action.")
            return benchmark_action

        joint_names = ik_result["result"]["plannable_joint_names"]
        joint_positions = ik_result["result"]["plannable_joint_positions"][0]
        benchmark_action.add_robot_action(
            RobotAction(
                control_mode=ControlMode.POSITION,
                robot_name=self.robot_name,
                joint_names=joint_names,
                joint_positions=joint_positions,
            )
        )
        return benchmark_action

    def _advance_chunk(self) -> None:
        """Move to the next chunk entry; drop the chunk once enough have run."""
        self.current_chunk_id += 1
        # ``>=`` rather than ``==`` so a non-positive run_trunk_size cannot let
        # the cursor run past the chunk and never request a new observation.
        if self.current_chunk_id >= self.run_trunk_size:
            self.current_chunk_id = 0
            self.current_chunk_result = None

    def postprocess_action(self, action: dict) -> BenchmarkAction:
        """Turn the current chunk entry into a BenchmarkAction and advance the cursor.

        The cursor advances even when IK fails. The reference state is only
        refreshed at the start of a chunk, so retrying the same entry would fail
        identically on every step and the policy would never progress.
        """
        benchmark_action = self._build_action(action, self.current_chunk_id)
        self._advance_chunk()
        return benchmark_action