This commit is contained in:
Junhan
2026-05-07 19:31:45 +08:00
parent 7ce2823c56
commit 2514cc943d
6 changed files with 5976 additions and 94 deletions

View File

@@ -12,6 +12,14 @@ import cv2
from flask import Flask, request, Response
from PIL import Image
def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
"""Pad an array to the target dimension with zeros along the specified axis."""
current_dim = x.shape[axis]
if current_dim < target_dim:
pad_width = [(0, 0)] * len(x.shape)
pad_width[axis] = (0, target_dim - current_dim)
return np.pad(x, pad_width, constant_values=value)
return x
class StarvlaInferenceServer:
@@ -19,7 +27,7 @@ class StarvlaInferenceServer:
with open(config_path, "r") as f:
cfg = yaml.safe_load(f)
policy_server_cfg = cfg["policy_server"]
root_paths = cfg["general"]["root_paths"]
self.ckpt_source = policy_server_cfg["ckpt_source"]
@@ -88,14 +96,15 @@ class StarvlaInferenceServer:
img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
state_vec = obs["state"]
# import ipdb;ipdb.set_trace()
# state_vec = pad_to_dim(np.array(state_vec), 100, axis=-1)
return img_left, img_right, img_wrist, state_vec, obs["prompt"]
def inference(self, observation: dict) -> dict:
img_left, img_right, img_wrist, state_vec, prompt = \
self.parse_observation(observation)
print(f"{state_vec.shape}")
vla_input = {
"batch_images": [[img_left, img_right, img_wrist]],
"instructions": [prompt],