adapt for starvla-dex
This commit is contained in:
@@ -106,13 +106,14 @@ class StarvlaInferenceServer:
|
||||
self.parse_observation(observation)
|
||||
print(f"{state_vec.shape}")
|
||||
vla_input = {
|
||||
"batch_images": [[img_left, img_right, img_wrist]],
|
||||
"instructions": [prompt],
|
||||
"state": [state_vec]
|
||||
# "batch_images": [[img_left, img_right, img_wrist]],
|
||||
"image": [img_left],
|
||||
"lang": prompt,
|
||||
"state": state_vec
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
output = self.model.predict_action(**vla_input)
|
||||
output = self.model.predict_action(examples=vla_input)
|
||||
|
||||
actions = output.get("normalized_actions")
|
||||
|
||||
@@ -176,4 +177,4 @@ if __name__ == "__main__":
|
||||
|
||||
config_path = args.config
|
||||
server = StarvlaInferenceServer(config_path)
|
||||
server.run()
|
||||
server.run()
|
||||
|
||||
Reference in New Issue
Block a user