finish load inference server model
This commit is contained in:
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
.vscode
|
||||||
179
benchmark.yaml
Normal file
179
benchmark.yaml
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
general:
|
||||||
|
scan_project: true
|
||||||
|
root_paths:
|
||||||
|
asset: /home/ubuntu/projects/gen_data/data
|
||||||
|
output: /home/ubuntu/output
|
||||||
|
checkpoints: /home/ubuntu/data/models
|
||||||
|
|
||||||
|
simulation:
|
||||||
|
launch_config:
|
||||||
|
device: cuda
|
||||||
|
enable_cameras: true
|
||||||
|
headless: false
|
||||||
|
livestream: 0
|
||||||
|
|
||||||
|
scene:
|
||||||
|
name: default_scene_name
|
||||||
|
position: [0, 0, 0]
|
||||||
|
rotation: [1, 0, 0, 0]
|
||||||
|
base_config:
|
||||||
|
name: default_base
|
||||||
|
source: primitive
|
||||||
|
stereotype: ground_plane
|
||||||
|
ground_size: [100,100]
|
||||||
|
|
||||||
|
object_cfg_dict:
|
||||||
|
table:
|
||||||
|
name: simple_table
|
||||||
|
position: [0.5, 0, 0.25]
|
||||||
|
source: primitive
|
||||||
|
stereotype: rigid
|
||||||
|
primitive_type: cuboid
|
||||||
|
primitive_size: [0.5, 1, 0.5]
|
||||||
|
mass: 1e4
|
||||||
|
|
||||||
|
target:
|
||||||
|
name: target
|
||||||
|
position: [0.4, 0.0, 0.5]
|
||||||
|
scale: [0.001, 0.001, 0.001]
|
||||||
|
axis_y_up: true
|
||||||
|
asset_path: asset://objects/omni6DPose/ball/omni6DPose_ball_020/Aligned.usd
|
||||||
|
stereotype: rigid
|
||||||
|
source: local
|
||||||
|
|
||||||
|
|
||||||
|
robot_cfg_dict:
|
||||||
|
robot:
|
||||||
|
name: my_robot
|
||||||
|
asset_path: asset://Franka/franka_robotiq_2f85_zedmini.usd
|
||||||
|
position: [0, 0, 0]
|
||||||
|
stereotype: single_gripper_arm_robot
|
||||||
|
source: local
|
||||||
|
init_joint_position:
|
||||||
|
panda_joint2: -0.1633
|
||||||
|
panda_joint4: -1.070
|
||||||
|
panda_joint6: 0.8933
|
||||||
|
panda_joint7: 0.785
|
||||||
|
|
||||||
|
arm_actuator_name: franka_arm
|
||||||
|
gripper_actuator_name: robotiq_2f_85
|
||||||
|
|
||||||
|
use_planner: true
|
||||||
|
planner_cfg:
|
||||||
|
stereotype: curobo
|
||||||
|
lazy_init: true
|
||||||
|
robot_config_file: asset://curobo/franka_robotiq_2f85/franka_robotiq_2f85.yml
|
||||||
|
world_config_source: stage
|
||||||
|
world_stage_ignore_substrings: [my_robot]
|
||||||
|
world_stage_only_paths: [/World]
|
||||||
|
world_stage_reference_prim_path: /World/Robot/SingleGripperArmRobot/my_robot
|
||||||
|
|
||||||
|
sensor_cfg_dict:
|
||||||
|
front_camera:
|
||||||
|
name: front_camera
|
||||||
|
stereotype: camera
|
||||||
|
position: [0.8, 0.0, 0.8]
|
||||||
|
data_types: [rgb, depth, normals]
|
||||||
|
width: 1280
|
||||||
|
height: 720
|
||||||
|
camera_model: pinhole
|
||||||
|
fix_camera: true
|
||||||
|
|
||||||
|
left_camera:
|
||||||
|
name: left_camera
|
||||||
|
stereotype: camera
|
||||||
|
position: [0.6, 0.7, 0.8]
|
||||||
|
data_types: [rgb, depth, normals]
|
||||||
|
width: 1280
|
||||||
|
height: 720
|
||||||
|
camera_model: pinhole
|
||||||
|
fix_camera: true
|
||||||
|
|
||||||
|
right_camera:
|
||||||
|
name: right_camera
|
||||||
|
stereotype: camera
|
||||||
|
position: [0.6, -0.7, 0.8]
|
||||||
|
data_types: [rgb, depth, normals]
|
||||||
|
width: 1280
|
||||||
|
height: 720
|
||||||
|
camera_model: pinhole
|
||||||
|
fix_camera: true
|
||||||
|
|
||||||
|
extension:
|
||||||
|
extension_cfg_dict:
|
||||||
|
my_data_collect:
|
||||||
|
enable: true
|
||||||
|
stereotype: data_collect
|
||||||
|
observer_cfgs:
|
||||||
|
- stereotype: robot_observer
|
||||||
|
name: my_robot
|
||||||
|
observe_joint_positions: true
|
||||||
|
observe_joint_velocities: true
|
||||||
|
observe_joint_accelerations: true
|
||||||
|
observe_joint_position_targets: true
|
||||||
|
observe_joint_velocity_targets: true
|
||||||
|
observe_position: true
|
||||||
|
observe_rotation: true
|
||||||
|
observe_ee_pose: true
|
||||||
|
observe_gripper_state: true
|
||||||
|
observe_gripper_drive_state: true
|
||||||
|
- stereotype: sensor_observer
|
||||||
|
name: front_camera
|
||||||
|
observe_intrinsic_matrix: true
|
||||||
|
observe_extrinsic_matrix: true
|
||||||
|
observe_rgb: true
|
||||||
|
observe_depth: true
|
||||||
|
observe_normals: true
|
||||||
|
- stereotype: sensor_observer
|
||||||
|
name: left_camera
|
||||||
|
observe_intrinsic_matrix: true
|
||||||
|
observe_extrinsic_matrix: true
|
||||||
|
observe_rgb: true
|
||||||
|
observe_depth: true
|
||||||
|
observe_normals: true
|
||||||
|
- stereotype: sensor_observer
|
||||||
|
name: right_camera
|
||||||
|
observe_intrinsic_matrix: true
|
||||||
|
observe_extrinsic_matrix: true
|
||||||
|
observe_rgb: true
|
||||||
|
observe_depth: true
|
||||||
|
observe_normals: true
|
||||||
|
|
||||||
|
- stereotype: task_observer
|
||||||
|
name: task
|
||||||
|
|
||||||
|
- stereotype: object_observer
|
||||||
|
name: target
|
||||||
|
observe_position: true
|
||||||
|
observe_rotation: true
|
||||||
|
observe_scale: true
|
||||||
|
|
||||||
|
my_benchmark:
|
||||||
|
enable: true
|
||||||
|
stereotype: benchmark
|
||||||
|
data_collector_name: my_data_collect
|
||||||
|
goals:
|
||||||
|
- name: reach_target
|
||||||
|
description: Reach the target
|
||||||
|
stereotype: pose
|
||||||
|
pose_A_source: ee
|
||||||
|
pose_A_params:
|
||||||
|
robot_name: my_robot
|
||||||
|
pose_B_source: spawnable
|
||||||
|
pose_B_params:
|
||||||
|
spawnable_name: target
|
||||||
|
position_tolerance: 0.005
|
||||||
|
policy:
|
||||||
|
stereotype: starvla
|
||||||
|
robot_name: my_robot
|
||||||
|
object_name: target
|
||||||
|
prompt: pick the cola bottle and place it on the book
|
||||||
|
|
||||||
|
policy_server:
|
||||||
|
ckpt_path: checkpoints://0309_qwenpi_droid_cola_post/final_model/pytorch_model.pt
|
||||||
|
ckpt_source: local
|
||||||
|
host: 0.0.0.0
|
||||||
|
port: 5000
|
||||||
|
use_bf16: true
|
||||||
|
unnorm_key: oxe_bridge
|
||||||
|
state_mode: ee_pose7
|
||||||
173
starvla_inference_server.py
Normal file
173
starvla_inference_server.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
|
||||||
|
import yaml
|
||||||
|
import pickle
|
||||||
|
import os
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
from flask import Flask, request, Response
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class StarvlaInferenceServer:
    """Serve a StarVLA checkpoint over HTTP through one Flask ``POST /policy`` route.

    Construction reads the YAML config, resolves the checkpoint path, loads the
    model onto CUDA and registers the route; ``run()`` starts the Flask app.
    """

    def __init__(self, config_path: str):
        """Load settings from ``config_path``, then build the model and Flask app.

        Args:
            config_path: Path to a YAML file that contains a ``policy_server``
                mapping and a ``general.root_paths`` mapping.

        Raises:
            KeyError: If a required config key is missing, or the checkpoint
                URL scheme has no entry in ``general.root_paths``.
        """
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)

        policy_server_cfg = cfg["policy_server"]
        root_paths = cfg["general"]["root_paths"]
        self.ckpt_source = policy_server_cfg["ckpt_source"]
        self.ckpt_path = self._resolve_ckpt_path(
            ckpt_url=policy_server_cfg["ckpt_path"],
            root_paths=root_paths,
        )
        self.host = policy_server_cfg.get("host", "0.0.0.0")
        self.port = policy_server_cfg.get("port", 5000)
        self.use_bf16 = policy_server_cfg.get("use_bf16", True)
        self.unnorm_key = policy_server_cfg.get("unnorm_key", "oxe_bridge")
        self.state_mode = policy_server_cfg.get("state_mode", "ee_pose7")

        print("Loading StarVLA model...")
        self.model = self.load_model()
        print("Model loaded.")

        self.app = Flask(__name__)
        self.register_routes()

    @staticmethod
    def _resolve_ckpt_path(ckpt_url: str, root_paths: dict) -> str:
        """Turn a ``<scheme>://relative/path`` URL into a filesystem path.

        The URL scheme is looked up in ``root_paths`` and the remainder of the
        URL is joined onto that root. A string without a scheme is returned
        unchanged.

        Args:
            ckpt_url: Plain path or ``<scheme>://...`` style location.
            root_paths: Mapping from scheme name to root directory.

        Returns:
            The resolved path.

        Raises:
            KeyError: If the scheme has no (non-empty) entry in ``root_paths``.
        """
        parsed = urlparse(ckpt_url)
        if not parsed.scheme:
            return ckpt_url

        root = root_paths.get(parsed.scheme)
        if not root:
            raise KeyError(
                f"cannot find root path for scheme '{parsed.scheme}' "
                f"in root_paths: {root_paths}"
            )

        # urlparse puts the first path component in ``netloc``; glue it back on.
        rel = (parsed.netloc + parsed.path).lstrip("/")
        return os.path.join(root, rel)

    def load_model(self):
        """Build the StarVLA framework, load the checkpoint and move it to CUDA.

        Also stores ``self.norm_stats`` and ``self.action_norm_stats`` (the
        ``action`` entry under ``self.unnorm_key``, or ``None`` when absent).

        Returns:
            The model in eval mode on the ``cuda`` device, cast to bfloat16
            when ``self.use_bf16`` is set.
        """
        # Imported lazily so the module can be imported without starVLA installed.
        from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
        # Import from the package itself: importing ``...framework.__init__``
        # executes the package's __init__ a second time under another name.
        from starVLA.model.framework import build_framework

        model_config, norm_stats = read_mode_config(self.ckpt_path)

        cfg = dict_to_namespace(model_config)
        # Weights come from ``self.ckpt_path`` below, so clear the configured one.
        cfg.trainer.pretrained_checkpoint = None

        model = build_framework(cfg=cfg)
        model.norm_stats = norm_stats

        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

        if self.use_bf16:
            model = model.to(torch.bfloat16)

        model = model.to("cuda").eval()

        self.norm_stats = norm_stats
        self.action_norm_stats = norm_stats.get(self.unnorm_key, {}).get("action", None)

        return model

    def _build_state_vector(self, state, joint=None):
        """Assemble the model's proprioceptive vector for ``self.state_mode``.

        ``joint8`` yields the latest joint row followed by the gripper value;
        any other mode yields ``state[0:3]``, ``state[3:6]`` and the gripper
        value.

        The gripper value is ``state[9]`` when the state has at least ten
        entries; shorter states use their last entry instead of raising
        ``IndexError``.

        Raises:
            ValueError: If ``joint8`` is selected without joint data, or the
                state has fewer than seven entries in the pose mode.
        """
        state = np.asarray(state)
        gripper = state[9] if state.shape[0] > 9 else state[-1]

        if self.state_mode == "joint8":
            if joint is None:
                raise ValueError(
                    "state_mode 'joint8' requires a 'joint' entry in the observation"
                )
            return np.concatenate([joint[-1], np.array([gripper])], axis=0)

        if state.shape[0] < 7:
            raise ValueError(
                f"state needs at least 7 entries, got {state.shape[0]}"
            )
        # Equivalent to the former ``state[3:9][:3]`` for ten-entry states.
        return np.concatenate(
            [state[0:3], state[3:6], np.array([gripper])],
            axis=0,
        )

    def parse_observation(self, obs):
        """Convert a request payload into model inputs.

        Uses the last entry of ``obs["rgb"]`` and ``obs["state"]``. The image is
        split on its channel axis into three 3-channel views (left, right,
        wrist), each resized to 320x180 (``cv2.resize`` takes width, height).
        NOTE(review): this assumes an HxWx9 array in a dtype PIL accepts
        (e.g. uint8) - confirm against the client.

        Args:
            obs: Mapping with ``rgb``, ``state``, ``prompt`` and optionally
                ``joint``.

        Returns:
            ``(img_left, img_right, img_wrist, state_vec, prompt)``.
        """
        rgb = obs["rgb"][-1]
        state = obs["state"][-1]
        joint = obs.get("joint", None)
        prompt = obs["prompt"]

        target_size = (320, 180)
        img_left, img_right, img_wrist = (
            Image.fromarray(cv2.resize(rgb[:, :, start:start + 3], target_size))
            for start in (0, 3, 6)
        )

        state_vec = self._build_state_vector(state, joint)

        return img_left, img_right, img_wrist, state_vec, prompt

    @staticmethod
    def _actions_to_numpy(actions):
        """Return ``actions`` as a float32 NumPy array without a batch axis.

        Tensors are cast to float32 before ``.numpy()`` because NumPy has no
        bfloat16 dtype. A 3-D result has its first (batch) axis dropped.
        """
        if isinstance(actions, torch.Tensor):
            actions = actions.detach().float().cpu().numpy()
        actions = np.asarray(actions)
        if actions.ndim == 3:
            actions = actions[0]
        return actions.astype(np.float32)

    def inference(self, observation: dict) -> dict:
        """Run one forward pass and return ``{"action": float32 ndarray}``.

        Raises:
            KeyError: If the model output has no ``normalized_actions`` entry.
        """
        img_left, img_right, img_wrist, state_vec, prompt = \
            self.parse_observation(observation)

        vla_input = {
            "batch_images": [[img_left, img_right, img_wrist]],
            "instructions": [prompt],
            "state": [state_vec],
        }

        with torch.no_grad():
            output = self.model.predict_action(**vla_input)

        actions = output.get("normalized_actions")
        if actions is None:
            raise KeyError("model output does not contain 'normalized_actions'")

        # NOTE(review): the values returned are the model's *normalized*
        # actions; ``self.action_norm_stats`` is stored but never applied here.
        # Confirm whether the consumer expects un-normalized actions.
        return {"action": self._actions_to_numpy(actions)}

    def register_routes(self):
        """Attach the ``POST /policy`` handler to ``self.app``."""

        @self.app.route("/policy", methods=["POST"])
        def policy():
            # SECURITY: pickle.loads executes arbitrary code from the request
            # body. Only expose this endpoint to trusted clients/networks
            # (the default host 0.0.0.0 listens on every interface).
            data = pickle.loads(request.data)
            result = self.inference(data)
            body = pickle.dumps(result, protocol=4)
            return Response(body, mimetype="application/octet-stream")

    def run(self):
        """Print the bind address and start the threaded Flask server."""
        print("StarVLA policy server running")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")

        self.app.run(
            host=self.host,
            port=self.port,
            threaded=True,
        )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: build the server from the local benchmark config
    # (this loads the model) and start serving.
    StarvlaInferenceServer("./benchmark.yaml").run()
|
||||||
123
starvla_policy.py
Normal file
123
starvla_policy.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
from joysim.annotations.config_class import configclass, field
|
||||||
|
from joysim.annotations.stereotype import stereotype
|
||||||
|
from joysim.app import JoySim
|
||||||
|
from joysim.core.scene_manager import SceneManager
|
||||||
|
from joysim.extensions.benchmark.action import RobotAction
|
||||||
|
from joysim.extensions.benchmark.benchmark import (
|
||||||
|
BenchmarkAction,
|
||||||
|
BenchmarkObservation,
|
||||||
|
ControlMode,
|
||||||
|
)
|
||||||
|
from joysim.extensions.benchmark.policy import Policy, PolicyConfig
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pickle
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
@configclass
@stereotype.register_config("starvla")
class StarvlaPolicyConfig(PolicyConfig):
    """Settings for the ``starvla`` policy stereotype."""

    # Scene entities the policy refers to by name.
    robot_name: str = field(
        default="my_robot",
        required=True,
        comment="The name of the robot",
    )
    object_name: str = field(
        default="target",
        required=True,
        comment="The name of the object",
    )

    # Endpoint that observations are POSTed to.
    server_url: str = field(default="http://127.0.0.1:5000/policy", required=True, comment="StarVLA policy server url")

    # Instruction text sent with every observation.
    prompt: str = field(default="pick the object", required=True, comment="task instruction")
|
||||||
|
|
||||||
|
|
||||||
|
@stereotype.register_model("starvla")
class StarvlaPolicy(Policy):
    """Policy that obtains its actions from a remote StarVLA HTTP server.

    Observations are pickled and POSTed to ``server_url``; the pickled reply is
    turned into a position-control ``RobotAction`` for ``robot_name``.
    """

    def __init__(self, config: StarvlaPolicyConfig):
        """Copy the relevant settings from ``config`` onto the instance."""
        super().__init__(config)

        self.robot_name = config.robot_name
        self.object_name = config.object_name
        self.server_url = config.server_url
        self.prompt = config.prompt

    def reset(self) -> None:
        """No per-episode state to clear."""
        return None

    def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
        """No warm-up work is performed."""
        return None

    def needs_observation(self) -> bool:
        """This policy always asks for an observation."""
        return True

    def preprocess_observation(self, benchmark_observation: BenchmarkObservation) -> dict:
        """Pack robot state, joint positions, RGB data and the prompt into a dict.

        ``state`` is position + rotation + a constant ``0.0`` slot (presumably
        a gripper placeholder - confirm). Every array gains a leading axis of
        length one.
        """
        robot_data = benchmark_observation.get_robot_observations(self.robot_name)["robot_data"]
        joints = robot_data["joint_positions"]
        base_position = robot_data["position"]
        base_rotation = robot_data["rotation"]

        packed_state = np.concatenate([base_position, base_rotation, np.array([0.0])])

        frames = benchmark_observation.get_sensor_observations()["rgb"]

        return {
            "state": np.expand_dims(packed_state, axis=0),
            "joint": np.expand_dims(joints, axis=0),
            "rgb": np.expand_dims(frames, axis=0),
            "prompt": self.prompt,
        }

    def compute_action(self, observation: dict) -> dict:
        """POST the pickled observation and return the unpickled reply.

        Raises:
            RuntimeError: If the server answers with a status other than 200.
        """
        reply = requests.post(
            self.server_url,
            data=pickle.dumps(observation),
            headers={"Content-Type": "application/octet-stream"},
        )

        if reply.status_code == 200:
            return pickle.loads(reply.content)

        raise RuntimeError(f"StarVLA server error: {reply.text}")

    def postprocess_action(self, action: dict) -> BenchmarkAction:
        """Wrap the first row of ``action["action"]`` as a position command."""
        wrapped = BenchmarkAction()

        planner = SceneManager.get_robot(self.robot_name).get_planner()
        names = planner.get_plannable_joint_names()

        # Only the first entry of the returned action array is executed.
        first_step = action["action"][0]

        command = RobotAction(
            control_mode=ControlMode.POSITION,
            robot_name=self.robot_name,
            joint_names=names,
            joint_positions=first_step,
        )
        wrapped.add_robot_action(command)

        return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: launch JoySim with the local benchmark config.
    JoySim("./benchmark.yaml").start()
|
||||||
Reference in New Issue
Block a user