Update benchmark.yaml and inference server: adjust benchmark parameters, modify inference server routes, and enhance policy handling for chunked actions.
This commit is contained in:
@@ -30,8 +30,8 @@ scene:
|
|||||||
- 0.001
|
- 0.001
|
||||||
- 0.001
|
- 0.001
|
||||||
position:
|
position:
|
||||||
- 0.0
|
- 0.2
|
||||||
- -3.79243
|
- -4.15243
|
||||||
- 0.5
|
- 0.5
|
||||||
quaternion:
|
quaternion:
|
||||||
- -0.304408012043137
|
- -0.304408012043137
|
||||||
@@ -204,7 +204,8 @@ extension:
|
|||||||
stereotype: starvla
|
stereotype: starvla
|
||||||
robot_name: Franka
|
robot_name: Franka
|
||||||
sensor_names: [Hand_Camera, Left_Camera, Right_Camera]
|
sensor_names: [Hand_Camera, Left_Camera, Right_Camera]
|
||||||
prompt: pick the cola bottle and place it on the book
|
prompt: pick up the white plug
|
||||||
|
run_trunk_size: 16
|
||||||
|
|
||||||
recorder:
|
recorder:
|
||||||
enable: false # set to true to record the data
|
enable: false # set to true to record the data
|
||||||
@@ -214,7 +215,7 @@ extension:
|
|||||||
backend_root_path: output://benchmark_record
|
backend_root_path: output://benchmark_record
|
||||||
|
|
||||||
policy_server:
|
policy_server:
|
||||||
ckpt_path: checkpoints://0309_qwenpi_droid_cola_post/final_model/pytorch_model.pt
|
ckpt_path: checkpoints://0318_qwenpi_droid_pretrain_8node/checkpoints/steps_30000_pytorch_model.pt
|
||||||
ckpt_source: local
|
ckpt_source: local
|
||||||
host: 0.0.0.0
|
host: 0.0.0.0
|
||||||
port: 5000
|
port: 5000
|
||||||
|
|||||||
@@ -111,14 +111,14 @@ class StarvlaInferenceServer:
|
|||||||
actions = actions.cpu().numpy()
|
actions = actions.cpu().numpy()
|
||||||
|
|
||||||
if actions.ndim == 3:
|
if actions.ndim == 3:
|
||||||
actions = actions[0] # (8, 7)
|
actions = actions[0] # (16, 10)
|
||||||
return {"ee_delta_position_chunks": actions[:, :3].tolist(),
|
return {"ee_delta_position_chunks": actions[:, :3].tolist(),
|
||||||
"ee_delta_euler_xyz_chunks": actions[:, 3:6].tolist(),
|
"ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
|
||||||
"gripper_chunks": actions[:, 6:7].tolist()}
|
"gripper_chunks": actions[:, 9:10].tolist()}
|
||||||
|
|
||||||
def register_routes(self):
|
def register_routes(self):
|
||||||
|
|
||||||
@self.app.route("/policy", methods=["POST"])
|
@self.app.route("/policy/inference", methods=["POST"])
|
||||||
def policy():
|
def policy():
|
||||||
try:
|
try:
|
||||||
data = pickle.loads(request.data)
|
data = pickle.loads(request.data)
|
||||||
@@ -136,6 +136,11 @@ class StarvlaInferenceServer:
|
|||||||
mimetype="application/octet-stream",
|
mimetype="application/octet-stream",
|
||||||
status=500,
|
status=500,
|
||||||
)
|
)
|
||||||
|
@self.app.route("/policy/health", methods=["GET"])
|
||||||
|
def health():
|
||||||
|
if self.model is None:
|
||||||
|
return Response("Failed to load model", mimetype="application/json", status=500)
|
||||||
|
return Response("OK", mimetype="application/json")
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from joysim.utils.log import Log
|
|||||||
from joysim.utils.pose import Pose
|
from joysim.utils.pose import Pose
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
import time
|
||||||
|
|
||||||
@configclass
|
@configclass
|
||||||
@stereotype.register_config("starvla")
|
@stereotype.register_config("starvla")
|
||||||
@@ -37,6 +38,12 @@ class StarvlaPolicyConfig(PolicyConfig):
|
|||||||
comment="task instruction"
|
comment="task instruction"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
run_trunk_size: int = field(
|
||||||
|
default=16,
|
||||||
|
required=True,
|
||||||
|
comment="The number of chunks to run in one inference step"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@stereotype.register_model("starvla")
|
@stereotype.register_model("starvla")
|
||||||
class StarvlaPolicy(Policy):
|
class StarvlaPolicy(Policy):
|
||||||
@@ -51,33 +58,39 @@ class StarvlaPolicy(Policy):
|
|||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.current_ee_position_state = None
|
self.current_ee_position_state = None
|
||||||
self.current_ee_euler_xyz_state = None
|
self.current_ee_rot6d_state = None
|
||||||
self.current_gripper_state = None
|
self.current_gripper_state = None
|
||||||
|
self.current_chunk_id = 0
|
||||||
|
self.current_chunk_result = None
|
||||||
|
self.run_trunk_size = self.config.run_trunk_size
|
||||||
|
|
||||||
def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
|
def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
|
||||||
pass
|
Log.info(f"Waiting for StarVLA inference server to be ready...")
|
||||||
def needs_observation(self) -> bool:
|
while True:
|
||||||
return True
|
try:
|
||||||
|
if requests.get(f"{self.server_url}/health", timeout=1.0).status_code == 200:
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
time.sleep(1)
|
||||||
|
Log.success(f"StarVLA inference server is ready.")
|
||||||
|
|
||||||
|
def needs_observation(self) -> bool:
|
||||||
|
return self.current_chunk_id == 0
|
||||||
|
|
||||||
def _handle_server_error(self, response: requests.Response) -> None:
|
def _handle_server_error(self, response: requests.Response) -> None:
|
||||||
if response.status_code == 500:
|
if response.status_code == 500:
|
||||||
err_obj = pickle.loads(response.content)
|
err_obj = pickle.loads(response.content)
|
||||||
Log.error(f"StarVLA server error: {err_obj['error']}")
|
Log.error(f"StarVLA server error: {err_obj['error']}")
|
||||||
Log.error(f"Traceback: {err_obj['traceback']}")
|
Log.error(f"Traceback: {err_obj['traceback']}", exit=True)
|
||||||
exit(0)
|
|
||||||
elif response.status_code != 200:
|
elif response.status_code != 200:
|
||||||
Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}")
|
Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}", exit=True)
|
||||||
exit(0)
|
|
||||||
|
|
||||||
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
|
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
|
||||||
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
|
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
|
||||||
ee_pose_base = robot_obs["ee_pose_base"]
|
ee_pose_base = robot_obs["ee_pose_base"]
|
||||||
ee_position, ee_euler_xyz = ee_pose_base["position"],ee_pose_base["euler_xyz"]
|
ee_position, ee_rot6d = ee_pose_base["position"],ee_pose_base["rot6d"]
|
||||||
gripper = 0.0 if robot_obs["gripper_state"]["opened"] else 1.0
|
gripper = 0.0 if robot_obs["gripper_state"]["opened"] else 1.0
|
||||||
state = np.concatenate([ee_position,ee_euler_xyz,np.array([gripper])])
|
state = np.concatenate([ee_position,ee_rot6d,np.array([gripper])])
|
||||||
self.current_ee_position_state = np.array(ee_position).astype(np.float64)
|
|
||||||
self.current_ee_euler_xyz_state = np.array(ee_euler_xyz).astype(np.float64)
|
|
||||||
self.current_gripper_state = np.array([gripper])
|
|
||||||
rgb_data = {}
|
rgb_data = {}
|
||||||
for sensor_name in self.sensor_names:
|
for sensor_name in self.sensor_names:
|
||||||
sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
|
sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
|
||||||
@@ -87,22 +100,36 @@ class StarvlaPolicy(Policy):
|
|||||||
return obs
|
return obs
|
||||||
|
|
||||||
def compute_action(self, observation: dict) -> dict:
|
def compute_action(self, observation: dict) -> dict:
|
||||||
payload = pickle.dumps(observation)
|
if self.current_chunk_result is None:
|
||||||
response = requests.post(
|
self.current_ee_position_state = np.array(observation["state"][:3]).astype(np.float64)
|
||||||
self.server_url,
|
self.current_ee_rot6d_state = np.array(observation["state"][3:9]).astype(np.float64)
|
||||||
data=payload,
|
self.current_gripper_state = np.array([observation["state"][9]])
|
||||||
headers={"Content-Type": "application/octet-stream"}
|
payload = pickle.dumps(observation)
|
||||||
)
|
response = requests.post(
|
||||||
self._handle_server_error(response)
|
f"{self.server_url}/inference",
|
||||||
result = pickle.loads(response.content)
|
data=payload,
|
||||||
|
headers={"Content-Type": "application/octet-stream"}
|
||||||
|
)
|
||||||
|
self._handle_server_error(response)
|
||||||
|
result = pickle.loads(response.content)
|
||||||
|
max_trunk_size = len(result["ee_delta_position_chunks"])
|
||||||
|
if self.run_trunk_size > max_trunk_size:
|
||||||
|
Log.warning(f"Run trunk size {self.run_trunk_size} is greater than the number of chunks {max_trunk_size}. Set run trunk size to {max_trunk_size}.")
|
||||||
|
self.run_trunk_size = max_trunk_size
|
||||||
|
self.run_trunk_size = max_trunk_size
|
||||||
|
self.current_chunk_result = result
|
||||||
|
else:
|
||||||
|
result = self.current_chunk_result
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def postprocess_action(self, action: dict) -> BenchmarkAction:
|
def postprocess_action(self, action: dict) -> BenchmarkAction:
|
||||||
benchmark_action = BenchmarkAction()
|
benchmark_action = BenchmarkAction()
|
||||||
|
|
||||||
# get base frame end-effector pose
|
# get base frame end-effector pose
|
||||||
delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][0], euler_xyz=action["ee_delta_euler_xyz_chunks"][0])
|
delta_ee_pose = Pose(position=action["ee_delta_position_chunks"][self.current_chunk_id], rot6d=action["ee_delta_rot6d_chunks"][self.current_chunk_id])
|
||||||
curr_state_ee_pose = Pose(position=self.current_ee_position_state, euler_xyz=self.current_ee_euler_xyz_state)
|
curr_state_ee_pose = Pose(position=self.current_ee_position_state, rot6d=self.current_ee_rot6d_state)
|
||||||
|
Log.debug(f"trunck_id: {self.current_chunk_id}, curr_state_ee_pose: {curr_state_ee_pose}")
|
||||||
curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose # action2base = state2base * action2state
|
curr_action_ee_pose = curr_state_ee_pose * delta_ee_pose # action2base = state2base * action2state
|
||||||
ik_result = MotionPlanController.solve_ik(
|
ik_result = MotionPlanController.solve_ik(
|
||||||
robot_name=self.robot_name,
|
robot_name=self.robot_name,
|
||||||
@@ -122,5 +149,8 @@ class StarvlaPolicy(Policy):
|
|||||||
joint_positions=joint_positions
|
joint_positions=joint_positions
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.current_chunk_id += 1
|
||||||
|
if self.current_chunk_id == self.run_trunk_size:
|
||||||
|
self.current_chunk_id = 0
|
||||||
|
self.current_chunk_result = None
|
||||||
return benchmark_action
|
return benchmark_action
|
||||||
Reference in New Issue
Block a user