finish benchmark debug
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1 +1,2 @@
|
|||||||
.vscode
|
.vscode
|
||||||
|
__pycache__/
|
||||||
28
Readme.md
Normal file
28
Readme.md
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# StarVLA Benchmark
|
||||||
|
|
||||||
|
## 1 Installation
|
||||||
|
|
||||||
|
### 1.1 Install Fastsim
|
||||||
|
|
||||||
|
Install Fastsim from: [https://git.hofee.cn/hofee/fastsim.git](https://git.hofee.cn/hofee/fastsim.git)
|
||||||
|
|
||||||
|
### 1.2 Install StarVLA
|
||||||
|
|
||||||
|
Install edited version of StarVLA from `XiongHao Wu`
|
||||||
|
|
||||||
|
## 2 Run Benchmark
|
||||||
|
|
||||||
|
### 2.1 Run FastSim
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda activate fastsim
|
||||||
|
joysim launch_simulation --config ./benchmark.yaml # policy will be auto-scanned.
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.2 Run StarVLA inference server
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda activate starvla
|
||||||
|
python starvla_inference_server.py --config ./benchmark.yaml
|
||||||
|
```
|
||||||
|
|
||||||
234
benchmark.yaml
234
benchmark.yaml
@@ -2,7 +2,6 @@ general:
|
|||||||
scan_project: true
|
scan_project: true
|
||||||
root_paths:
|
root_paths:
|
||||||
asset: /home/ubuntu/projects/gen_data/data
|
asset: /home/ubuntu/projects/gen_data/data
|
||||||
output: /home/ubuntu/output
|
|
||||||
checkpoints: /home/ubuntu/data/models
|
checkpoints: /home/ubuntu/data/models
|
||||||
|
|
||||||
simulation:
|
simulation:
|
||||||
@@ -13,160 +12,197 @@ simulation:
|
|||||||
livestream: 0
|
livestream: 0
|
||||||
|
|
||||||
scene:
|
scene:
|
||||||
name: default_scene_name
|
name: kujiale_multispace
|
||||||
position: [0, 0, 0]
|
|
||||||
rotation: [1, 0, 0, 0]
|
|
||||||
base_config:
|
base_config:
|
||||||
name: default_base
|
stereotype: usd
|
||||||
source: primitive
|
name: _827313_home_workspace_00
|
||||||
stereotype: ground_plane
|
source: local
|
||||||
ground_size: [100,100]
|
asset_path: asset://scenes/kujiale_multispace/827313_home/workspace_00.usd
|
||||||
|
|
||||||
object_cfg_dict:
|
object_cfg_dict:
|
||||||
table:
|
omni6DPose_can_016:
|
||||||
name: simple_table
|
name: omni6DPose_can_016
|
||||||
position: [0.5, 0, 0.25]
|
|
||||||
source: primitive
|
|
||||||
stereotype: rigid
|
|
||||||
primitive_type: cuboid
|
|
||||||
primitive_size: [0.5, 1, 0.5]
|
|
||||||
mass: 1e4
|
|
||||||
|
|
||||||
target:
|
|
||||||
name: target
|
|
||||||
position: [0.4, 0.0, 0.5]
|
|
||||||
scale: [0.001, 0.001, 0.001]
|
|
||||||
axis_y_up: true
|
|
||||||
asset_path: asset://objects/omni6DPose/ball/omni6DPose_ball_020/Aligned.usd
|
|
||||||
stereotype: rigid
|
stereotype: rigid
|
||||||
source: local
|
source: local
|
||||||
|
asset_path: asset://objects/omni6DPose/can/omni6DPose_can_016/Aligned.usd
|
||||||
|
scale:
|
||||||
|
- 0.001
|
||||||
|
- 0.001
|
||||||
|
- 0.001
|
||||||
|
position:
|
||||||
|
- 0.0
|
||||||
|
- -3.79243
|
||||||
|
- 0.5
|
||||||
|
quaternion:
|
||||||
|
- -0.304408012043137
|
||||||
|
- -0.304408012043137
|
||||||
|
- 0.638228612805745
|
||||||
|
- 0.6382286128057448
|
||||||
|
axis_y_up: true
|
||||||
|
omni6DPose_plug_001:
|
||||||
|
name: omni6DPose_plug_001
|
||||||
|
stereotype: rigid
|
||||||
|
source: local
|
||||||
|
asset_path: asset://objects/omni6DPose/plug/omni6DPose_plug_001/Aligned.usd
|
||||||
|
scale:
|
||||||
|
- 0.001
|
||||||
|
- 0.001
|
||||||
|
- 0.001
|
||||||
|
position:
|
||||||
|
- 0.419859
|
||||||
|
- -4.152430000000001
|
||||||
|
- 0.510259093
|
||||||
|
quaternion:
|
||||||
|
- 0.7070997233636068
|
||||||
|
- 0.707099723363607
|
||||||
|
- 0.003159306745230588
|
||||||
|
- 0.003159306745230588
|
||||||
|
axis_y_up: true
|
||||||
robot_cfg_dict:
|
robot_cfg_dict:
|
||||||
robot:
|
Franka:
|
||||||
name: my_robot
|
name: Franka
|
||||||
asset_path: asset://Franka/franka_robotiq_2f85_zedmini.usd
|
asset_path: asset://Franka/franka_robotiq_2f85_zedmini.usd
|
||||||
position: [0, 0, 0]
|
position:
|
||||||
|
- -0.3779152929316859
|
||||||
|
- -3.943187336951741
|
||||||
|
- 0.3130090550954169
|
||||||
|
rotation:
|
||||||
|
- 0.9987576855319703
|
||||||
|
- 0.0
|
||||||
|
- 0.0
|
||||||
|
- 0.04983056883903615
|
||||||
stereotype: single_gripper_arm_robot
|
stereotype: single_gripper_arm_robot
|
||||||
source: local
|
source: local
|
||||||
init_joint_position:
|
init_joint_position:
|
||||||
panda_joint2: -0.1633
|
panda_joint2: -0.1633
|
||||||
panda_joint4: -1.070
|
panda_joint4: -1.07
|
||||||
panda_joint6: 0.8933
|
panda_joint6: 0.8933
|
||||||
panda_joint7: 0.785
|
panda_joint7: 0.785
|
||||||
|
|
||||||
arm_actuator_name: franka_arm
|
arm_actuator_name: franka_arm
|
||||||
gripper_actuator_name: robotiq_2f_85
|
gripper_actuator_name: robotiq_2f_85
|
||||||
|
|
||||||
use_planner: true
|
use_planner: true
|
||||||
planner_cfg:
|
planner_cfg:
|
||||||
stereotype: curobo
|
stereotype: curobo
|
||||||
lazy_init: true
|
lazy_init: true
|
||||||
robot_config_file: asset://curobo/franka_robotiq_2f85/franka_robotiq_2f85.yml
|
robot_config_file: asset://curobo/franka_robotiq_2f85/franka_robotiq_2f85.yml
|
||||||
world_config_source: stage
|
world_config_source: stage
|
||||||
world_stage_ignore_substrings: [my_robot]
|
|
||||||
world_stage_only_paths: [/World]
|
|
||||||
world_stage_reference_prim_path: /World/Robot/SingleGripperArmRobot/my_robot
|
|
||||||
|
|
||||||
sensor_cfg_dict:
|
sensor_cfg_dict:
|
||||||
front_camera:
|
Hand_Camera:
|
||||||
name: front_camera
|
name: Hand_Camera
|
||||||
stereotype: camera
|
stereotype: camera
|
||||||
position: [0.8, 0.0, 0.8]
|
data_types:
|
||||||
data_types: [rgb, depth, normals]
|
- rgb
|
||||||
|
- depth
|
||||||
|
- normals
|
||||||
width: 1280
|
width: 1280
|
||||||
height: 720
|
height: 720
|
||||||
camera_model: pinhole
|
camera_model: pinhole
|
||||||
fix_camera: true
|
fix_camera: true
|
||||||
|
focal_length: 2.8
|
||||||
left_camera:
|
horizontal_aperture: 4.890881131191918
|
||||||
name: left_camera
|
vertical_aperture: 2.7608816125932627
|
||||||
|
convention: opengl
|
||||||
|
attach_to:
|
||||||
|
target_name: Franka
|
||||||
|
is_articulation_part: true
|
||||||
|
articulation_part_name: panda_link8
|
||||||
|
create_fixed_joint: true
|
||||||
|
local_position:
|
||||||
|
- -0.07176474936469446
|
||||||
|
- 0.02890958100382394
|
||||||
|
- 0.01978286477078585
|
||||||
|
local_rotation:
|
||||||
|
- 0.12352531576657892
|
||||||
|
- 0.7000519980574638
|
||||||
|
- -0.6933396337721066
|
||||||
|
- -0.11810524383481495
|
||||||
|
Left_Camera:
|
||||||
|
name: Left_Camera
|
||||||
stereotype: camera
|
stereotype: camera
|
||||||
position: [0.6, 0.7, 0.8]
|
data_types:
|
||||||
data_types: [rgb, depth, normals]
|
- rgb
|
||||||
|
- depth
|
||||||
|
- normals
|
||||||
width: 1280
|
width: 1280
|
||||||
height: 720
|
height: 720
|
||||||
camera_model: pinhole
|
camera_model: pinhole
|
||||||
fix_camera: true
|
fix_camera: false
|
||||||
|
focal_length: 2.1
|
||||||
right_camera:
|
horizontal_aperture: 5.019302546405283
|
||||||
name: right_camera
|
vertical_aperture: 2.833796298140747
|
||||||
|
convention: opengl
|
||||||
|
attach_to:
|
||||||
|
target_name: Franka
|
||||||
|
local_position:
|
||||||
|
- -0.07383269512283744
|
||||||
|
- -0.4566116797983716
|
||||||
|
- 0.5664518136443555
|
||||||
|
local_rotation:
|
||||||
|
- 0.7669289562134721
|
||||||
|
- 0.4437472984144337
|
||||||
|
- -0.232167881174614
|
||||||
|
- -0.4012560108236337
|
||||||
|
Right_Camera:
|
||||||
|
name: Right_Camera
|
||||||
stereotype: camera
|
stereotype: camera
|
||||||
position: [0.6, -0.7, 0.8]
|
data_types:
|
||||||
data_types: [rgb, depth, normals]
|
- rgb
|
||||||
|
- depth
|
||||||
|
- normals
|
||||||
width: 1280
|
width: 1280
|
||||||
height: 720
|
height: 720
|
||||||
camera_model: pinhole
|
camera_model: pinhole
|
||||||
fix_camera: true
|
fix_camera: false
|
||||||
|
focal_length: 2.1
|
||||||
|
horizontal_aperture: 5.053278483887542
|
||||||
|
vertical_aperture: 2.833796298140747
|
||||||
|
convention: opengl
|
||||||
|
attach_to:
|
||||||
|
target_name: Franka
|
||||||
|
local_position:
|
||||||
|
- 0.4501524009531093
|
||||||
|
- 0.7206899873248545
|
||||||
|
- 0.27525652672526973
|
||||||
|
local_rotation:
|
||||||
|
- 0.007056820167439314
|
||||||
|
- 0.007107689842383092
|
||||||
|
- 0.709606075042274
|
||||||
|
- 0.7045274304789896
|
||||||
|
|
||||||
extension:
|
extension:
|
||||||
extension_cfg_dict:
|
extension_cfg_dict:
|
||||||
my_data_collect:
|
benchmark_data_collect:
|
||||||
enable: true
|
enable: true
|
||||||
stereotype: data_collect
|
stereotype: data_collect
|
||||||
observer_cfgs:
|
observer_cfgs:
|
||||||
- stereotype: robot_observer
|
- stereotype: robot_observer
|
||||||
name: my_robot
|
name: Franka
|
||||||
observe_joint_positions: true
|
|
||||||
observe_joint_velocities: true
|
|
||||||
observe_joint_accelerations: true
|
|
||||||
observe_joint_position_targets: true
|
|
||||||
observe_joint_velocity_targets: true
|
|
||||||
observe_position: true
|
|
||||||
observe_rotation: true
|
|
||||||
observe_ee_pose: true
|
observe_ee_pose: true
|
||||||
observe_gripper_state: true
|
observe_gripper_state: true
|
||||||
observe_gripper_drive_state: true
|
observe_gripper_drive_state: true
|
||||||
- stereotype: sensor_observer
|
- stereotype: sensor_observer
|
||||||
name: front_camera
|
name: Hand_Camera
|
||||||
observe_intrinsic_matrix: true
|
|
||||||
observe_extrinsic_matrix: true
|
|
||||||
observe_rgb: true
|
observe_rgb: true
|
||||||
observe_depth: true
|
|
||||||
observe_normals: true
|
|
||||||
- stereotype: sensor_observer
|
- stereotype: sensor_observer
|
||||||
name: left_camera
|
name: Left_Camera
|
||||||
observe_intrinsic_matrix: true
|
|
||||||
observe_extrinsic_matrix: true
|
|
||||||
observe_rgb: true
|
observe_rgb: true
|
||||||
observe_depth: true
|
|
||||||
observe_normals: true
|
|
||||||
- stereotype: sensor_observer
|
- stereotype: sensor_observer
|
||||||
name: right_camera
|
name: Right_Camera
|
||||||
observe_intrinsic_matrix: true
|
|
||||||
observe_extrinsic_matrix: true
|
|
||||||
observe_rgb: true
|
observe_rgb: true
|
||||||
observe_depth: true
|
|
||||||
observe_normals: true
|
|
||||||
|
|
||||||
- stereotype: task_observer
|
starvla_benchmark:
|
||||||
name: task
|
|
||||||
|
|
||||||
- stereotype: object_observer
|
|
||||||
name: target
|
|
||||||
observe_position: true
|
|
||||||
observe_rotation: true
|
|
||||||
observe_scale: true
|
|
||||||
|
|
||||||
my_benchmark:
|
|
||||||
enable: true
|
enable: true
|
||||||
stereotype: benchmark
|
stereotype: benchmark
|
||||||
data_collector_name: my_data_collect
|
data_collector_name: benchmark_data_collect
|
||||||
goals:
|
goals:
|
||||||
- name: reach_target
|
- name: cola on top of book
|
||||||
description: Reach the target
|
description: check if the cola bottle is on the book
|
||||||
stereotype: pose
|
stereotype: on_top
|
||||||
pose_A_source: ee
|
object_A_name: omni6DPose_plug_001
|
||||||
pose_A_params:
|
object_B_name: omni6DPose_can_016
|
||||||
robot_name: my_robot
|
|
||||||
pose_B_source: spawnable
|
|
||||||
pose_B_params:
|
|
||||||
spawnable_name: target
|
|
||||||
position_tolerance: 0.005
|
|
||||||
policy:
|
policy:
|
||||||
stereotype: starvla
|
stereotype: starvla
|
||||||
robot_name: my_robot
|
robot_name: Franka
|
||||||
object_name: target
|
sensor_names: [Hand_Camera, Left_Camera, Right_Camera]
|
||||||
prompt: pick the cola bottle and place it on the book
|
prompt: pick the cola bottle and place it on the book
|
||||||
|
|
||||||
policy_server:
|
policy_server:
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
import pickle
|
import pickle
|
||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import traceback
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -77,49 +79,17 @@ class StarvlaInferenceServer:
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def parse_observation(self, obs):
|
def parse_observation(self, obs, target_size=(320, 180)):
|
||||||
|
|
||||||
rgb = obs["rgb"][-1]
|
left_rgb, right_rgb, wrist_rgb = obs["rgb"]["Left_Camera"], obs["rgb"]["Right_Camera"], obs["rgb"]["Hand_Camera"]
|
||||||
state = obs["state"][-1]
|
|
||||||
joint = obs.get("joint", None)
|
|
||||||
prompt = obs["prompt"]
|
|
||||||
|
|
||||||
left = rgb[:, :, :3]
|
img_left = Image.fromarray(cv2.resize(left_rgb, target_size))
|
||||||
right = rgb[:, :, 3:6]
|
img_right = Image.fromarray(cv2.resize(right_rgb, target_size))
|
||||||
wrist = rgb[:, :, 6:9]
|
img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
|
||||||
|
|
||||||
target_size = (320, 180)
|
state_vec = obs["state"]
|
||||||
|
|
||||||
left = cv2.resize(left, target_size)
|
return img_left, img_right, img_wrist, state_vec, obs["prompt"]
|
||||||
right = cv2.resize(right, target_size)
|
|
||||||
wrist = cv2.resize(wrist, target_size)
|
|
||||||
|
|
||||||
img_left = Image.fromarray(left)
|
|
||||||
img_right = Image.fromarray(right)
|
|
||||||
img_wrist = Image.fromarray(wrist)
|
|
||||||
|
|
||||||
if self.state_mode == "joint8":
|
|
||||||
|
|
||||||
joint_last = joint[-1]
|
|
||||||
gripper = state[9]
|
|
||||||
|
|
||||||
state_vec = np.concatenate(
|
|
||||||
[joint_last, np.array([gripper])],
|
|
||||||
axis=0
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
xyz = state[0:3]
|
|
||||||
rot6d = state[3:9]
|
|
||||||
gripper = state[9]
|
|
||||||
|
|
||||||
state_vec = np.concatenate(
|
|
||||||
[xyz, rot6d[:3], np.array([gripper])],
|
|
||||||
axis=0
|
|
||||||
)
|
|
||||||
|
|
||||||
return img_left, img_right, img_wrist, state_vec, prompt
|
|
||||||
|
|
||||||
def inference(self, observation: dict) -> dict:
|
def inference(self, observation: dict) -> dict:
|
||||||
|
|
||||||
@@ -141,18 +111,31 @@ class StarvlaInferenceServer:
|
|||||||
actions = actions.cpu().numpy()
|
actions = actions.cpu().numpy()
|
||||||
|
|
||||||
if actions.ndim == 3:
|
if actions.ndim == 3:
|
||||||
actions = actions[0]
|
actions = actions[0] # (8, 7)
|
||||||
|
return {"ee_delta_position_chunks": actions[:, :3].tolist(),
|
||||||
return {"action": actions.astype(np.float32)}
|
"ee_delta_euler_xyz_chunks": actions[:, 3:6].tolist(),
|
||||||
|
"gripper_chunks": actions[:, 6:7].tolist()}
|
||||||
|
|
||||||
def register_routes(self):
|
def register_routes(self):
|
||||||
|
|
||||||
@self.app.route("/policy", methods=["POST"])
|
@self.app.route("/policy", methods=["POST"])
|
||||||
def policy():
|
def policy():
|
||||||
data = pickle.loads(request.data)
|
try:
|
||||||
result = self.inference(data)
|
data = pickle.loads(request.data)
|
||||||
body = pickle.dumps(result, protocol=4)
|
result = self.inference(data)
|
||||||
return Response(body, mimetype="application/octet-stream")
|
body = pickle.dumps(result, protocol=4)
|
||||||
|
return Response(body, mimetype="application/octet-stream")
|
||||||
|
except Exception as e:
|
||||||
|
err_obj = {
|
||||||
|
"error": str(e),
|
||||||
|
"traceback": traceback.format_exc(),
|
||||||
|
}
|
||||||
|
body = pickle.dumps(err_obj, protocol=4)
|
||||||
|
return Response(
|
||||||
|
body,
|
||||||
|
mimetype="application/octet-stream",
|
||||||
|
status=500,
|
||||||
|
)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
||||||
@@ -168,6 +151,15 @@ class StarvlaInferenceServer:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
config_path = "./benchmark.yaml"
|
parser = argparse.ArgumentParser(description="StarVLA inference server")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
type=str,
|
||||||
|
default="./benchmark.yaml",
|
||||||
|
help="Path to benchmark.yaml",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
config_path = args.config
|
||||||
server = StarvlaInferenceServer(config_path)
|
server = StarvlaInferenceServer(config_path)
|
||||||
server.run()
|
server.run()
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
|
import pickle
|
||||||
|
|
||||||
from joysim.annotations.config_class import configclass, field
|
from joysim.annotations.config_class import configclass, field
|
||||||
from joysim.annotations.stereotype import stereotype
|
from joysim.annotations.stereotype import stereotype
|
||||||
from joysim.app import JoySim
|
from joysim.app import JoySim
|
||||||
|
from joysim.controllers.motion_plan_controller import MotionPlanController
|
||||||
from joysim.core.scene_manager import SceneManager
|
from joysim.core.scene_manager import SceneManager
|
||||||
from joysim.extensions.benchmark.action import RobotAction
|
from joysim.extensions.benchmark.action import RobotAction
|
||||||
from joysim.extensions.benchmark.benchmark import (
|
from joysim.extensions.benchmark.benchmark import (
|
||||||
@@ -9,19 +12,21 @@ from joysim.extensions.benchmark.benchmark import (
|
|||||||
ControlMode,
|
ControlMode,
|
||||||
)
|
)
|
||||||
from joysim.extensions.benchmark.policy import Policy, PolicyConfig
|
from joysim.extensions.benchmark.policy import Policy, PolicyConfig
|
||||||
|
from joysim.utils.log import Log
|
||||||
|
from joysim.utils.pose import Pose
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pickle
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
@configclass
|
@configclass
|
||||||
@stereotype.register_config("starvla")
|
@stereotype.register_config("starvla")
|
||||||
class StarvlaPolicyConfig(PolicyConfig):
|
class StarvlaPolicyConfig(PolicyConfig):
|
||||||
|
|
||||||
robot_name: str = field(default="my_robot", required=True, comment="The name of the robot")
|
robot_name: str = field(default="my_robot", required=True, comment="The name of the robot")
|
||||||
object_name: str = field(default="target", required=True, comment="The name of the object")
|
sensor_names: list[str] = field(
|
||||||
|
default=["Hand_Camera", "Left_Camera", "Right_Camera"],
|
||||||
|
required=True,
|
||||||
|
comment="The names of the sensors"
|
||||||
|
)
|
||||||
server_url: str = field(
|
server_url: str = field(
|
||||||
default="http://127.0.0.1:5000/policy",
|
default="http://127.0.0.1:5000/policy",
|
||||||
required=True,
|
required=True,
|
||||||
@@ -42,70 +47,76 @@ class StarvlaPolicy(Policy):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.robot_name = config.robot_name
|
self.robot_name = config.robot_name
|
||||||
self.object_name = config.object_name
|
self.sensor_names = config.sensor_names
|
||||||
self.server_url = config.server_url
|
self.server_url = config.server_url
|
||||||
self.prompt = config.prompt
|
self.prompt = config.prompt
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
pass
|
self.current_ee_position_state = None
|
||||||
|
self.current_ee_euler_xyz_state = None
|
||||||
|
self.current_gripper_state = None
|
||||||
|
|
||||||
def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
|
def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def needs_observation(self) -> bool:
|
def needs_observation(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _handle_server_error(self, response: requests.Response) -> None:
|
||||||
|
if response.status_code == 500:
|
||||||
|
err_obj = pickle.loads(response.content)
|
||||||
|
Log.error(f"StarVLA server error: {err_obj['error']}")
|
||||||
|
Log.error(f"Traceback: {err_obj['traceback']}")
|
||||||
|
exit(0)
|
||||||
|
elif response.status_code != 200:
|
||||||
|
Log.error(f"StarVLA server error with status code <{response.status_code}> : {response.text}")
|
||||||
|
exit(0)
|
||||||
|
|
||||||
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
|
def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
|
||||||
|
|
||||||
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
|
robot_obs = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
|
||||||
joint_positions = robot_obs["joint_positions"]
|
ee_pose_base = robot_obs["ee_pose_base"]
|
||||||
robot_position = robot_obs["position"]
|
ee_position, ee_euler_xyz = ee_pose_base["position"],ee_pose_base["euler_xyz"]
|
||||||
robot_quaternion = robot_obs["rotation"]
|
gripper = 1.0 if robot_obs["gripper_state"]["opened"] else 0.0
|
||||||
|
state = np.concatenate([ee_position,ee_euler_xyz,np.array([gripper])])
|
||||||
state = np.concatenate([
|
self.current_ee_position_state = np.array(ee_position).astype(np.float64)
|
||||||
robot_position,
|
self.current_ee_euler_xyz_state = np.array(ee_euler_xyz).astype(np.float64)
|
||||||
robot_quaternion,
|
self.current_gripper_state = np.array([gripper])
|
||||||
np.array([0.0])
|
rgb_data = {}
|
||||||
])
|
for sensor_name in self.sensor_names:
|
||||||
|
sensor_obs = benchmark_observation.get_sensor_observations(sensor_name)
|
||||||
camera_obs = benchmark_observation.get_sensor_observations()
|
rgb_data[sensor_name] = sensor_obs["rgb"].data.cpu().numpy().astype(np.uint8)
|
||||||
rgb = camera_obs["rgb"]
|
obs = {"state": state,"rgb": rgb_data,"prompt": self.prompt}
|
||||||
|
|
||||||
obs = {
|
|
||||||
"state": np.expand_dims(state, axis=0),
|
|
||||||
"joint": np.expand_dims(joint_positions, axis=0),
|
|
||||||
"rgb": np.expand_dims(rgb, axis=0),
|
|
||||||
"prompt": self.prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def compute_action(self, observation: dict) -> dict:
|
def compute_action(self, observation: dict) -> dict:
|
||||||
|
|
||||||
payload = pickle.dumps(observation)
|
payload = pickle.dumps(observation)
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.server_url,
|
self.server_url,
|
||||||
data=payload,
|
data=payload,
|
||||||
headers={"Content-Type": "application/octet-stream"}
|
headers={"Content-Type": "application/octet-stream"}
|
||||||
)
|
)
|
||||||
|
self._handle_server_error(response)
|
||||||
if response.status_code != 200:
|
|
||||||
raise RuntimeError(f"StarVLA server error: {response.text}")
|
|
||||||
|
|
||||||
result = pickle.loads(response.content)
|
result = pickle.loads(response.content)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def postprocess_action(self, action: dict) -> BenchmarkAction:
|
def postprocess_action(self, action: dict) -> BenchmarkAction:
|
||||||
|
|
||||||
benchmark_action = BenchmarkAction()
|
benchmark_action = BenchmarkAction()
|
||||||
|
|
||||||
robot = SceneManager.get_robot(self.robot_name)
|
# get base frame end-effector pose # TODO: Make sure add or multiply the current state
|
||||||
joint_names = robot.get_planner().get_plannable_joint_names()
|
ee_position = action["ee_delta_position_chunks"][0] + self.current_ee_position_state
|
||||||
|
ee_euler_xyz = action["ee_delta_euler_xyz_chunks"][0] + self.current_ee_euler_xyz_state
|
||||||
|
|
||||||
joint_positions = action["action"][0]
|
ee_pose = Pose(position=ee_position, euler_xyz=ee_euler_xyz)
|
||||||
|
ik_result = MotionPlanController.solve_ik(
|
||||||
|
robot_name=self.robot_name,
|
||||||
|
base_frame_ee_pose=ee_pose,
|
||||||
|
).unwrap()
|
||||||
|
if not ik_result["success"]:
|
||||||
|
Log.error(f"IK failed: {ik_result['status']}. Ignore this action.")
|
||||||
|
return benchmark_action
|
||||||
|
|
||||||
|
joint_names = ik_result["result"]["plannable_joint_names"]
|
||||||
|
joint_positions = ik_result["result"]["plannable_joint_positions"][0]
|
||||||
benchmark_action.add_robot_action(
|
benchmark_action.add_robot_action(
|
||||||
RobotAction(
|
RobotAction(
|
||||||
control_mode=ControlMode.POSITION,
|
control_mode=ControlMode.POSITION,
|
||||||
@@ -119,5 +130,5 @@ class StarvlaPolicy(Policy):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
js = JoySim("./benchmark.yaml")
|
js = JoySim("/home/ubuntu/projects/benchmark/benchmark.yaml")
|
||||||
js.start()
|
js.start()
|
||||||
Reference in New Issue
Block a user