Update benchmark.yaml and inference server: adjust benchmark parameters, modify inference server routes, and enhance policy handling for chunked actions.
This commit is contained in:
@@ -30,8 +30,8 @@ scene:
|
||||
- 0.001
|
||||
- 0.001
|
||||
position:
|
||||
- 0.0
|
||||
- -3.79243
|
||||
- 0.2
|
||||
- -4.15243
|
||||
- 0.5
|
||||
quaternion:
|
||||
- -0.304408012043137
|
||||
@@ -204,7 +204,8 @@ extension:
|
||||
stereotype: starvla
|
||||
robot_name: Franka
|
||||
sensor_names: [Hand_Camera, Left_Camera, Right_Camera]
|
||||
prompt: pick the cola bottle and place it on the book
|
||||
prompt: pick up the white plug
|
||||
run_trunk_size: 16
|
||||
|
||||
recorder:
|
||||
enable: false # set to true to record the data
|
||||
@@ -214,7 +215,7 @@ extension:
|
||||
backend_root_path: output://benchmark_record
|
||||
|
||||
policy_server:
|
||||
ckpt_path: checkpoints://0309_qwenpi_droid_cola_post/final_model/pytorch_model.pt
|
||||
ckpt_path: checkpoints://0318_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt
|
||||
ckpt_source: local
|
||||
host: 0.0.0.0
|
||||
port: 5000
|
||||
|
||||
@@ -111,14 +111,14 @@ class StarvlaInferenceServer:
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
if actions.ndim == 3:
|
||||
actions = actions[0] # (8, 7)
|
||||
actions = actions[0] # (16, 10)
|
||||
return {"ee_delta_position_chunks": actions[:, :3].tolist(),
|
||||
"ee_delta_euler_xyz_chunks": actions[:, 3:6].tolist(),
|
||||
"gripper_chunks": actions[:, 6:7].tolist()}
|
||||
"ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
|
||||
"gripper_chunks": actions[:, 9:10].tolist()}
|
||||
|
||||
def register_routes(self):
|
||||
|
||||
@self.app.route("/policy", methods=["POST"])
|
||||
@self.app.route("/policy/inference", methods=["POST"])
|
||||
def policy():
|
||||
try:
|
||||
data = pickle.loads(request.data)
|
||||
@@ -136,6 +136,11 @@ class StarvlaInferenceServer:
|
||||
mimetype="application/octet-stream",
|
||||
status=500,
|
||||
)
|
||||
@self.app.route("/policy/health", methods=["GET"])
|
||||
def health():
|
||||
if self.model is None:
|
||||
return Response("Failed to load model", mimetype="application/json", status=500)
|
||||
return Response("OK", mimetype="application/json")
|
||||
|
||||
def run(self):
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from joysim.utils.log import Log
|
||||
from joysim.utils.pose import Pose
|
||||
import numpy as np
|
||||
import requests
|
||||
import time
|
||||
|
||||
@configclass
|
||||
@stereotype.register_config("starvla")
|
||||
@@ -37,6 +38,12 @@ class StarvlaPolicyConfig(PolicyConfig):
|
||||
comment="task instruction"
|
||||
)
|
||||
|
||||
run_trunk_size: int = field(
|
||||
default=16,
|
||||
required=True,
|
||||
comment="The number of chunks to run in one inference step"
|
||||
)
|
||||
|
||||
|
||||
@stereotype.register_model("starvla")
|
||||
class StarvlaPolicy(Policy):
|
||||
@@ -51,33 +58,39 @@ class StarvlaPolicy(Policy):
|
||||
|
||||
def reset(self) -> None:
|
||||
self.current_ee_position_state = None
|
||||
self.current_ee_euler_xyz_state = None
|
||||
self.current_ee_rot6d_state = None
|
||||
self.current_gripper_state = None
|
||||
self.current_chunk_id = 0
|
||||
self.current_chunk_result = None
|
||||
self.run_trunk_size = self.config.run_trunk_size
|
||||
|
||||
def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
|
||||
pass
|
||||
def needs_observation(self) -> bool:
|
||||
return True
|
||||
Log.info(f"Waiting for StarVLA inference server to be ready...")
|
||||
while True:
|
||||
try:
|
||||
if requests.get(f"{self.server_url}/health", timeout=1.0).status_code == 200:
|
||||
break
|
||||
except Exception:
|
||||
time.sleep(1)
|
||||
Log.success(f"StarVLA inference server is ready.")
|
||||
|
||||
def needs_observation(self) -> bool:
|
||||
return self.current_chunk_id == 0
|
||||
|
||||
def _handle_server_error(self, response: requests.Response) -> None:
|
||||
if response.status_code == 500:
|
||||
err_obj = pickle.loads(response.content)
|
||||
Log.error(f"StarVLA server error: {err_obj['error']}")
|
||||
Log.error(f"Traceback: {err_obj['traceback']}")
|
||||
exit(0)
|
||||
Log.error(f"Traceback: {err_obj['traceback']}", exit=True)
|
||||
elif response.status_code != 200:
|
||||
Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}")
|
||||
exit(0)
|
||||
Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}", exit=True)
|
||||
|
||||
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
|
||||
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
|
||||
ee_pose_base = robot_obs["ee_pose_base"]
|
||||
ee_position, ee_euler_xyz = ee_pose_base["position"],ee_pose_base["euler_xyz"]
|
||||
ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"]
|
||||
gripper = 0.0 if robot_obs["gripper_state"]["opened"] else 1.0
|
||||
state = np.concatenate([ee_position,ee_euler_xyz,np.array([gripper])])
|
||||
self.current_ee_position_state = np.array(ee_position).astype(np.float64)
|
||||
self.current_ee_euler_xyz_state = np.array(ee_euler_xyz).astype(np.float64)
|
||||
self.current_gripper_state = np.array([gripper])
|
||||
state = np.concatenate([ee_position,ee_rot6d,np.array([gripper])])
|
||||
rgb_data = {}
|
||||
for sensor_name in self.sensor_names:
|
||||
sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
|
||||
@@ -87,22 +100,36 @@ class StarvlaPolicy(Policy):
|
||||
return obs
|
||||
|
||||
def compute_action(self, observation: dict) -> dict:
|
||||
payload = pickle.dumps(observation)
|
||||
response = requests.post(
|
||||
self.server_url,
|
||||
data=payload,
|
||||
headers={"Content-Type": "application/octet-stream"}
|
||||
)
|
||||
self._handle_server_error(response)
|
||||
result = pickle.loads(response.content)
|
||||
if self.current_chunk_result is None:
|
||||
self.current_ee_position_state = np.array(observation["state"][:3]).astype(np.float64)
|
||||
self.current_ee_rot6d_state = np.array(observation["state"][3:9]).astype(np.float64)
|
||||
self.current_gripper_state = np.array([observation["state"][9]])
|
||||
payload = pickle.dumps(observation)
|
||||
response = requests.post(
|
||||
f"{self.server_url}/inference",
|
||||
data=payload,
|
||||
headers={"Content-Type": "application/octet-stream"}
|
||||
)
|
||||
self._handle_server_error(response)
|
||||
result = pickle.loads(response.content)
|
||||
max_trunk_size = len(result["ee_delta_position_chunks"])
|
||||
if self.run_trunk_size > max_trunk_size:
|
||||
Log.warning(f"Run trunk size {self.run_trunk_size} is greater than the number of chunks {max_trunk_size}. Set run trunk size to {max_trunk_size}.")
|
||||
self.run_trunk_size = max_trunk_size
|
||||
self.run_trunk_size = max_trunk_size
|
||||
self.current_chunk_result = result
|
||||
else:
|
||||
result = self.current_chunk_result
|
||||
|
||||
return result
|
||||
|
||||
def postprocess_action(self, action: dict) -> BenchmarkAction:
|
||||
benchmark_action = BenchmarkAction()
|
||||
|
||||
# get base frame end-effector pose
|
||||
delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][0], euler_xyz=action["ee_delta_euler_xyz_chunks"][0])
|
||||
curr_state_ee_pose = Pose(position=self.current_ee_position_state, euler_xyz=self.current_ee_euler_xyz_state)
|
||||
delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action["ee_delta_rot6d_chunks"][self.current_chunk_id])
|
||||
curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state)
|
||||
Log.debug(f"trunck_id: {self.current_chunk_id}, curr_state_ee_pose: {curr_state_ee_pose}")
|
||||
curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose # action2base = state2base * action2state
|
||||
ik_result = MotionPlanController.solve_ik(
|
||||
robot_name=self.robot_name,
|
||||
@@ -122,5 +149,8 @@ class StarvlaPolicy(Policy):
|
||||
joint_positions=joint_positions
|
||||
)
|
||||
)
|
||||
|
||||
self.current_chunk_id += 1
|
||||
if self.current_chunk_id == self.run_trunk_size:
|
||||
self.current_chunk_id = 0
|
||||
self.current_chunk_result = None
|
||||
return benchmark_action
|
||||
Reference in New Issue
Block a user