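"""StarVLA policy client for JoySim benchmarks.

Sends the robot's base pose, joint positions, a camera RGB frame and a task
prompt to a StarVLA policy server over HTTP (pickled payload), and converts
the server's reply into a joint-position ``BenchmarkAction``. Running this
file directly starts JoySim with ``./benchmark.yaml``.
"""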
import pickle
from typing import Any, Dict

import numpy as np
import requests

from joysim.annotations.config_class import configclass, field
from joysim.annotations.stereotype import stereotype
from joysim.app import JoySim
from joysim.core.scene_manager import SceneManager
from joysim.extensions.benchmark.action import RobotAction
from joysim.extensions.benchmark.benchmark import (
    BenchmarkAction,
    BenchmarkObservation,
    ControlMode,
)
from joysim.extensions.benchmark.policy import Policy, PolicyConfig


@configclass
@stereotype.register_config("starvla")
class StarvlaPolicyConfig(PolicyConfig):
    """Configuration for the StarVLA remote policy (registered as ``"starvla"``).

    Every field passes ``required=True`` to joysim's ``field`` together with a
    default value.
    """

    # Robot whose observations are sent to the server and whose joints are
    # commanded with the returned action.
    robot_name: str = field(
        default="my_robot",
        required=True,
        comment="The name of the robot",
    )
    # Target object name; StarvlaPolicy stores it but does not otherwise read it.
    object_name: str = field(
        default="target",
        required=True,
        comment="The name of the object",
    )
    # HTTP endpoint that receives the pickled observation via POST.
    server_url: str = field(
        default="http://127.0.0.1:5000/policy",
        required=True,
        comment="StarVLA policy server url",
    )
    # Natural-language task instruction included in every request payload.
    prompt: str = field(
        default="pick the object",
        required=True,
        comment="task instruction",
    )


@stereotype.register_model("starvla")
class StarvlaPolicy(Policy):
    """Policy that delegates inference to a remote StarVLA HTTP server.

    Per step:

    * ``preprocess_observation`` gathers the robot's base pose, joint
      positions, a camera RGB frame and the task prompt.
    * ``compute_action`` pickles that dict, POSTs it to ``server_url`` and
      unpickles the reply.
    * ``postprocess_action`` turns the reply's ``"action"`` entry into a
      position-controlled ``RobotAction`` over the robot's plannable joints.
    """

    # Seconds ``requests`` waits (connect / read) before raising ``Timeout``.
    # Without any timeout a stalled or unreachable server would block the
    # simulation step indefinitely.
    DEFAULT_REQUEST_TIMEOUT_S: float = 60.0

    def __init__(self, config: StarvlaPolicyConfig):
        """Copy the robot, server and task settings out of ``config``."""
        super().__init__(config)

        self.robot_name = config.robot_name
        # Stored for completeness; no method of this class reads it.
        self.object_name = config.object_name
        self.server_url = config.server_url
        self.prompt = config.prompt
        # StarvlaPolicyConfig declares no timeout field, so honor one only if
        # the config object carries it; otherwise use the class default.
        self.request_timeout = getattr(
            config, "request_timeout", self.DEFAULT_REQUEST_TIMEOUT_S
        )

    def reset(self) -> None:
        """No-op: the policy holds no per-episode state to clear."""

    def warmup(self, benchmark_observation: BenchmarkObservation) -> None:
        """No-op: no warm-up request is sent to the server."""

    def needs_observation(self) -> bool:
        """Return ``True``; actions are computed from the current observation."""
        return True

    def preprocess_observation(
        self, benchmark_observation: BenchmarkObservation
    ) -> Dict[str, Any]:
        """Build the request dict sent to the StarVLA server.

        Args:
            benchmark_observation: Current benchmark observation.

        Returns:
            A dict with four keys:

            * ``"state"``, ``"joint"`` and ``"rgb"``: each gets a leading batch
              axis of size 1 via ``np.expand_dims``.
            * ``"prompt"``: the configured task instruction string.
        """
        robot_obs = benchmark_observation.get_robot_observations(self.robot_name)[
            "robot_data"
        ]
        joint_positions = robot_obs["joint_positions"]
        robot_position = robot_obs["position"]
        # NOTE(review): "rotation" is treated as a quaternion; its component
        # order (wxyz vs xyzw) is not visible here, so confirm against joysim.
        robot_quaternion = robot_obs["rotation"]

        # Base position + orientation followed by one constant 0.0 entry.
        # The extra slot is presumably a gripper/padding value expected by the
        # server's state layout (TODO confirm).
        state = np.concatenate([robot_position, robot_quaternion, np.array([0.0])])

        # NOTE(review): takes the "rgb" entry without naming a camera, which
        # assumes a single RGB sensor. Confirm.
        camera_obs = benchmark_observation.get_sensor_observations()
        rgb = camera_obs["rgb"]

        return {
            "state": np.expand_dims(state, axis=0),
            "joint": np.expand_dims(joint_positions, axis=0),
            "rgb": np.expand_dims(rgb, axis=0),
            "prompt": self.prompt,
        }

    def compute_action(self, observation: Dict[str, Any]) -> Dict[str, Any]:
        """Send ``observation`` to the StarVLA server and return its reply.

        Args:
            observation: Request dict produced by ``preprocess_observation``.

        Returns:
            The unpickled server response; ``postprocess_action`` reads its
            ``"action"`` key.

        Raises:
            RuntimeError: If the server answers with a status other than 200.
            requests.Timeout: If connecting to or reading from the server takes
                longer than ``self.request_timeout`` seconds.
        """
        payload = pickle.dumps(observation)

        response = requests.post(
            self.server_url,
            data=payload,
            headers={"Content-Type": "application/octet-stream"},
            timeout=self.request_timeout,
        )

        if response.status_code != 200:
            raise RuntimeError(f"StarVLA server error: {response.text}")

        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # response. Only point server_url at a trusted StarVLA server.
        return pickle.loads(response.content)

    def postprocess_action(self, action: Dict[str, Any]) -> BenchmarkAction:
        """Convert a server reply into a position-control ``BenchmarkAction``.

        Args:
            action: Server reply from ``compute_action``; must contain an
                ``"action"`` entry.

        Returns:
            A ``BenchmarkAction`` holding one ``ControlMode.POSITION``
            ``RobotAction`` for ``self.robot_name``.
        """
        benchmark_action = BenchmarkAction()

        robot = SceneManager.get_robot(self.robot_name)
        joint_names = robot.get_planner().get_plannable_joint_names()

        # NOTE(review): index 0 presumably selects the first step of a batched
        # or chunked prediction, and its length is assumed to match
        # joint_names. Verify against the server's output format.
        joint_positions = action["action"][0]

        benchmark_action.add_robot_action(
            RobotAction(
                control_mode=ControlMode.POSITION,
                robot_name=self.robot_name,
                joint_names=joint_names,
                joint_positions=joint_positions,
            )
        )

        return benchmark_action


if __name__ == "__main__":
    # Start the simulator from the benchmark config; the path is relative to
    # the current working directory.
    JoySim("./benchmark.yaml").start()