finish benchmark debug
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
import pickle
|
||||
|
||||
from joysim.annotations.config_class import configclass, field
|
||||
from joysim.annotations.stereotype import stereotype
|
||||
from joysim.app import JoySim
|
||||
from joysim.controllers.motion_plan_controller import MotionPlanController
|
||||
from joysim.core.scene_manager import SceneManager
|
||||
from joysim.extensions.benchmark.action import RobotAction
|
||||
from joysim.extensions.benchmark.benchmark import (
|
||||
@@ -9,19 +12,21 @@ from joysim.extensions.benchmark.benchmark import (
|
||||
ControlMode,
|
||||
)
|
||||
from joysim.extensions.benchmark.policy import Policy, PolicyConfig
|
||||
|
||||
from joysim.utils.log import Log
|
||||
from joysim.utils.pose import Pose
|
||||
import numpy as np
|
||||
import pickle
|
||||
import requests
|
||||
|
||||
|
||||
@configclass
|
||||
@stereotype.register_config("starvla")
|
||||
class StarvlaPolicyConfig(PolicyConfig):
|
||||
|
||||
robot_name: str = field(default="my_robot", required=True, comment="The name of the robot")
|
||||
object_name: str = field(default="target", required=True, comment="The name of the object")
|
||||
|
||||
sensor_names: list[str] = field(
|
||||
default=["Hand_Camera", "Left_Camera", "Right_Camera"],
|
||||
required=True,
|
||||
comment="The names of the sensors"
|
||||
)
|
||||
server_url: str = field(
|
||||
default="http://127.0.0.1:5000/policy",
|
||||
required=True,
|
||||
@@ -34,7 +39,7 @@ class StarvlaPolicyConfig(PolicyConfig):
|
||||
comment="task instruction"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@stereotype.register_model("starvla")
|
||||
class StarvlaPolicy(Policy):
|
||||
|
||||
@@ -42,70 +47,76 @@ class StarvlaPolicy(Policy):
|
||||
super().__init__(config)
|
||||
|
||||
self.robot_name = config.robot_name
|
||||
self.object_name = config.object_name
|
||||
self.sensor_names = config.sensor_names
|
||||
self.server_url = config.server_url
|
||||
self.prompt = config.prompt
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
self.current_ee_position_state = None
|
||||
self.current_ee_euler_xyz_state = None
|
||||
self.current_gripper_state = None
|
||||
|
||||
def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
|
||||
pass
|
||||
|
||||
def needs_observation(self) -> bool:
|
||||
return True
|
||||
|
||||
def _handle_server_error(self, response: requests.Response) -> None:
|
||||
if response.status_code == 500:
|
||||
err_obj = pickle.loads(response.content)
|
||||
Log.error(f"StarVLA server error: {err_obj['error']}")
|
||||
Log.error(f"Traceback: {err_obj['traceback']}")
|
||||
exit(0)
|
||||
elif response.status_code != 200:
|
||||
Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}")
|
||||
exit(0)
|
||||
|
||||
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
|
||||
|
||||
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
|
||||
joint_positions = robot_obs["joint_positions"]
|
||||
robot_position = robot_obs["position"]
|
||||
robot_quaternion = robot_obs["rotation"]
|
||||
|
||||
state = np.concatenate([
|
||||
robot_position,
|
||||
robot_quaternion,
|
||||
np.array([0.0])
|
||||
])
|
||||
|
||||
camera_obs = benchmark_observation.get_sensor_observations()
|
||||
rgb = camera_obs["rgb"]
|
||||
|
||||
obs = {
|
||||
"state": np.expand_dims(state, axis=0),
|
||||
"joint": np.expand_dims(joint_positions, axis=0),
|
||||
"rgb": np.expand_dims(rgb, axis=0),
|
||||
"prompt": self.prompt
|
||||
}
|
||||
ee_pose_base = robot_obs["ee_pose_base"]
|
||||
ee_position, ee_euler_xyz = ee_pose_base["position"],ee_pose_base["euler_xyz"]
|
||||
gripper = 1.0 if robot_obs["gripper_state"]["opened"] else 0.0
|
||||
state = np.concatenate([ee_position,ee_euler_xyz,np.array([gripper])])
|
||||
self.current_ee_position_state = np.array(ee_position).astype(np.float64)
|
||||
self.current_ee_euler_xyz_state = np.array(ee_euler_xyz).astype(np.float64)
|
||||
self.current_gripper_state = np.array([gripper])
|
||||
rgb_data = {}
|
||||
for sensor_name in self.sensor_names:
|
||||
sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
|
||||
rgb_data[sensor_name] = sensor_obs["rgb"].data.cpu().numpy().astype(np.uint8)
|
||||
obs = {"state": state,"rgb": rgb_data,"prompt": self.prompt}
|
||||
|
||||
return obs
|
||||
|
||||
def compute_action(self, observation: dict) -> dict:
|
||||
|
||||
payload = pickle.dumps(observation)
|
||||
|
||||
response = requests.post(
|
||||
self.server_url,
|
||||
data=payload,
|
||||
headers={"Content-Type": "application/octet-stream"}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f"StarVLA server error: {response.text}")
|
||||
|
||||
self._handle_server_error(response)
|
||||
result = pickle.loads(response.content)
|
||||
|
||||
return result
|
||||
|
||||
def postprocess_action(self, action: dict) -> BenchmarkAction:
|
||||
|
||||
benchmark_action = BenchmarkAction()
|
||||
|
||||
robot = SceneManager.get_robot(self.robot_name)
|
||||
joint_names = robot.get_planner().get_plannable_joint_names()
|
||||
|
||||
joint_positions = action["action"][0]
|
||||
# get base frame end-effector pose # TODO: Make sure add or multiply the current state
|
||||
ee_position = action["ee_delta_position_chunks"][0] + self.current_ee_position_state
|
||||
ee_euler_xyz = action["ee_delta_euler_xyz_chunks"][0] + self.current_ee_euler_xyz_state
|
||||
|
||||
ee_pose = Pose(position=ee_position, euler_xyz=ee_euler_xyz)
|
||||
ik_result = MotionPlanController.solve_ik(
|
||||
robot_name=self.robot_name,
|
||||
base_frame_ee_pose=ee_pose,
|
||||
).unwrap()
|
||||
if not ik_result["success"]:
|
||||
Log.error(f"IK failed: {ik_result['status']}. Ignore this action.")
|
||||
return benchmark_action
|
||||
|
||||
joint_names = ik_result["result"]["plannable_joint_names"]
|
||||
joint_positions = ik_result["result"]["plannable_joint_positions"][0]
|
||||
benchmark_action.add_robot_action(
|
||||
RobotAction(
|
||||
control_mode=ControlMode.POSITION,
|
||||
@@ -119,5 +130,5 @@ class StarvlaPolicy(Policy):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
js = JoySim("./benchmark.yaml")
|
||||
js = JoySim("/home/ubuntu/projects/benchmark/benchmark.yaml")
|
||||
js.start()
|
||||
Reference in New Issue
Block a user