Enhance benchmark configuration and gripper handling: Added 'ee_link_name' and 'action_frequency' to benchmark.yaml, introduced gripper width mapping in policy, and updated inference server to reflect gripper width in actions.

This commit is contained in:
hufei.hofee
2026-03-20 23:05:18 +08:00
parent 457c26b868
commit 833ade73fd
4 changed files with 467 additions and 25 deletions

View File

@@ -1,8 +1,13 @@
import pickle
import time
import json
import numpy as np
import requests
from joysim.annotations.config_class import configclass, field
from joysim.annotations.stereotype import stereotype
from joysim.controllers.motion_plan_controller import MotionPlanController
from joysim.controllers.spawnable_controller import SpawnableController
from joysim.core.robots.configs.actuator_configs.grippers import GripperDriveJointConfig
from joysim.extensions.benchmark.action import RobotAction
from joysim.extensions.benchmark.benchmark import (
BenchmarkAction,
@@ -12,15 +17,13 @@ from joysim.extensions.benchmark.benchmark import (
from joysim.extensions.benchmark.policy import Policy, PolicyConfig
from joysim.utils.log import Log
from joysim.utils.pose import Pose
import numpy as np
import requests
import time
@configclass
@stereotype.register_config("starvla")
class StarvlaPolicyConfig(PolicyConfig):
robot_name: str = field(default="None", required=True, comment="The name of the robot")
gripper_width_mapper_file: str = field(default="", required=True, comment="The file path to the gripper width mapper")
sensor_names: list[str] = field(
default=["Hand_Camera", "Left_Camera", "Right_Camera"],
required=True,
@@ -55,14 +58,24 @@ class StarvlaPolicy(Policy):
self.sensor_names = config.sensor_names
self.server_url = config.server_url
self.prompt = config.prompt
self.gripper_width_mapper = json.load(open(config.gripper_width_mapper_file, "r"))
def reset(self) -> None:
self.current_ee_position_state = None
self.current_ee_rot6d_state = None
self.current_gripper_state = None
self.current_gripper_width = None
self.current_chunk_id = 0
self.current_chunk_result = None
self.run_trunk_size = self.config.run_trunk_size
self.drive_joints: dict[str, GripperDriveJointConfig] = SpawnableController.control_robot(self.robot_name, "get_gripper_drive_joints").unwrap()
for joint_name, joint_config in self.drive_joints.items():
SpawnableController.control_robot(self.robot_name, "set_joint_stiffness", parameters={"joint_names": [joint_name], "stiffness": joint_config.position_control_stiffness}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_damping", parameters={"joint_names": [joint_name], "damping": joint_config.position_control_damping}).unwrap()
SpawnableController.control_robot(self.robot_name, "set_joint_effort_limit", parameters={"joint_names": [joint_name], "effort_limit": joint_config.position_control_effort_limit}).unwrap()
self.max_width = float("-inf")
self.min_width = float("inf")
for entry in self.gripper_width_mapper:
self.max_width = max(self.max_width, entry["width"])
self.min_width = min(self.min_width, entry["width"])
def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
Log.info(f"Waiting for StarVLA inference server to be ready...")
@@ -87,10 +100,10 @@ class StarvlaPolicy(Policy):
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
ee_pose_base = robot_obs["ee_pose_base"]
ee_pose_base = robot_obs["ee_pose"]["base_frame"]
ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"]
gripper = 0.0 if robot_obs["gripper_state"]["opened"] else 1.0
state = np.concatenate([ee_position,ee_rot6d,np.array([gripper])])
normalized_gripper_width = self.__map_joint_position_to_normalized_width(robot_obs["gripper_drive_state"]["position"][0])
state = np.concatenate([ee_position,ee_rot6d,np.array([normalized_gripper_width])])
rgb_data = {}
for sensor_name in self.sensor_names:
sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
@@ -103,7 +116,7 @@ class StarvlaPolicy(Policy):
if self.current_chunk_result is None:
self.current_ee_position_state = np.array(observation["state"][:3]).astype(np.float64)
self.current_ee_rot6d_state = np.array(observation["state"][3:9]).astype(np.float64)
self.current_gripper_state = np.array([observation["state"][9]])
self.current_gripper_width = np.array([observation["state"][9]])
payload = pickle.dumps(observation)
response = requests.post(
f"{self.server_url}/inference",
@@ -122,6 +135,26 @@ class StarvlaPolicy(Policy):
result = self.current_chunk_result
return result
def __map_joint_position_to_normalized_width(self, joint_position: float) -> float:
for entry in self.gripper_width_mapper:
if round(entry["angel"], 2) == round(joint_position, 2):
return 1-(entry["width"] - self.min_width) / (self.max_width - self.min_width)
raise ValueError(f"Joint position {joint_position} not found in gripper width mapper")
def __map_gripper_joint_position(self, normalized_gripper_width: float) -> float:
joint_positions = []
joint_names = []
if normalized_gripper_width > 0.5:
for joint_name, joint_config in self.drive_joints.items():
joint_positions.append(joint_config.close_position)
joint_names.append(joint_name)
else:
for joint_name, joint_config in self.drive_joints.items():
joint_positions.append(joint_config.open_position)
joint_names.append(joint_name)
return joint_positions, joint_names
def postprocess_action(self, action: dict) -> BenchmarkAction:
benchmark_action = BenchmarkAction()
@@ -129,24 +162,24 @@ class StarvlaPolicy(Policy):
# get base frame end-effector pose
delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action["ee_delta_rot6d_chunks"][self.current_chunk_id])
curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state)
Log.debug(f"trunck_id: {self.current_chunk_id}, curr_state_ee_pose: {curr_state_ee_pose}")
curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose # action2base = state2base * action2state
ik_result = MotionPlanController.solve_ik(
robot_name=self.robot_name,
base_frame_ee_pose=curr_action_ee_pose,
).unwrap()
if not ik_result["success"]:
Log.error(f"IK failed. Ignore this action.")
return benchmark_action
curr_action_gripper_width = action["gripper_width_chunks"][self.current_chunk_id]
joint_names = ik_result["result"]["plannable_joint_names"]
joint_positions = ik_result["result"]["plannable_joint_positions"][0]
gripper_joint_positions, gripper_joint_names = self.__map_gripper_joint_position(curr_action_gripper_width[0])
Log.debug(f"action_gripper_joint_positions: {gripper_joint_positions}, action_normalized_gripper_width: {round(curr_action_gripper_width[0], 2)}")
benchmark_action.add_robot_action(
RobotAction(
control_mode=ControlMode.POSITION,
robot_name=self.robot_name,
joint_names=joint_names,
joint_positions=joint_positions
joint_names=gripper_joint_names,
joint_positions=gripper_joint_positions
)
)
benchmark_action.add_robot_action(
RobotAction(
control_mode=ControlMode.EE_POSE,
robot_name=self.robot_name,
ee_pose=curr_action_ee_pose
)
)
self.current_chunk_id += 1