Update benchmark.yaml and inference server: Adjusted benchmark parameters, modified inference server routes, and enhanced policy handling for chunked actions.
This commit is contained in:
@@ -111,14 +111,14 @@ class StarvlaInferenceServer:
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
if actions.ndim == 3:
|
||||
actions = actions[0] # (8, 7)
|
||||
actions = actions[0] # (16, 10)
|
||||
return {"ee_delta_position_chunks": actions[:, :3].tolist(),
|
||||
"ee_delta_euler_xyz_chunks": actions[:, 3:6].tolist(),
|
||||
"gripper_chunks": actions[:, 6:7].tolist()}
|
||||
"ee_delta_rot6d_chunks": actions[:, 3:9].tolist(),
|
||||
"gripper_chunks": actions[:, 9:10].tolist()}
|
||||
|
||||
def register_routes(self):
|
||||
|
||||
@self.app.route("/policy", methods=["POST"])
|
||||
@self.app.route("/policy/inference", methods=["POST"])
|
||||
def policy():
|
||||
try:
|
||||
data = pickle.loads(request.data)
|
||||
@@ -136,6 +136,11 @@ class StarvlaInferenceServer:
|
||||
mimetype="application/octet-stream",
|
||||
status=500,
|
||||
)
|
||||
@self.app.route("/policy/health", methods=["GET"])
|
||||
def health():
|
||||
if self.model is None:
|
||||
return Response("Failed to load model", mimetype="application/json", status=500)
|
||||
return Response("OK", mimetype="application/json")
|
||||
|
||||
def run(self):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user