update
This commit is contained in:
@@ -12,6 +12,14 @@ import cv2
|
||||
from flask import Flask, request, Response
|
||||
from PIL import Image
|
||||
|
||||
def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
|
||||
"""Pad an array to the target dimension with zeros along the specified axis."""
|
||||
current_dim = x.shape[axis]
|
||||
if current_dim < target_dim:
|
||||
pad_width = [(0, 0)] * len(x.shape)
|
||||
pad_width[axis] = (0, target_dim - current_dim)
|
||||
return np.pad(x, pad_width, constant_values=value)
|
||||
return x
|
||||
|
||||
class StarvlaInferenceServer:
|
||||
|
||||
@@ -19,7 +27,7 @@ class StarvlaInferenceServer:
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
|
||||
policy_server_cfg = cfg["policy_server"]
|
||||
root_paths = cfg["general"]["root_paths"]
|
||||
self.ckpt_source = policy_server_cfg["ckpt_source"]
|
||||
@@ -88,14 +96,15 @@ class StarvlaInferenceServer:
|
||||
img_wrist = Image.fromarray(cv2.resize(wrist_rgb, target_size))
|
||||
|
||||
state_vec = obs["state"]
|
||||
|
||||
# import ipdb;ipdb.set_trace()
|
||||
# state_vec = pad_to_dim(np.array(state_vec), 100, axis=-1)
|
||||
return img_left, img_right, img_wrist, state_vec, obs["prompt"]
|
||||
|
||||
def inference(self, observation: dict) -> dict:
|
||||
|
||||
|
||||
img_left, img_right, img_wrist, state_vec, prompt = \
|
||||
self.parse_observation(observation)
|
||||
|
||||
print(f"{state_vec.shape}")
|
||||
vla_input = {
|
||||
"batch_images": [[img_left, img_right, img_wrist]],
|
||||
"instructions": [prompt],
|
||||
|
||||
Reference in New Issue
Block a user