update replay

This commit is contained in:
2025-09-08 15:09:20 +08:00
parent ee05f1339c
commit 3cdeb975b7
6 changed files with 535 additions and 11 deletions

View File

@@ -0,0 +1,187 @@
import os
import json
import numpy as np
from dataclasses import dataclass
from pyboot.utils.log import Log
from replay_init_dependencies.gen_traj import generate_circular_trajectory, generate_camera_data_from_poses, \
generate_trajectory_from_description, TrajDescription
@dataclass
class ReplayConfig:
replay_dir: str = ""
recording_data_dir: str = ""
task_template_file: str = ""
use_origin_camera: bool = False
use_trajectory_camera: bool = True
space_interpolation: bool = False
class ReplayInitializer:
TRAJECTORY_FILE_NAME: str = "camera_trajectory.json"
def __init__(self, replay_config: ReplayConfig):
self.replay_config = replay_config
def init_replay(self):
if not os.path.exists(self.replay_config.replay_dir):
os.makedirs(self.replay_config.replay_dir)
recording_data_dict = self.load_recording_data()
data_length = len(recording_data_dict)
task_template = self.load_task_template(data_length)
self.workspace = self.load_workspace(task_template)
self.generate_trajectory_file()
self.generate_replay_input(task_template, recording_data_dict)
task_template_file = os.path.join(self.replay_config.replay_dir,
os.path.basename(self.replay_config.task_template_file))
command = f"omni_python main.py --task_template {task_template_file} --use_recording"
Log.success(f"Replay initialized successfully. You can now run the replay by running following command:")
Log.success(f"{command}")
def load_workspace(self, task_template):
workspace_name = task_template["scene"]["scene_id"].split("/")[-1]
workspace_info = task_template["scene"]["function_space_objects"][workspace_name]
return workspace_info
def load_recording_data(self):
recording_data_dict = {}
for data_dir in os.listdir(self.replay_config.recording_data_dir):
if os.path.isdir(os.path.join(self.replay_config.recording_data_dir, data_dir)):
Log.info(f"Loading recording data from {data_dir}")
state = None
task_result = None
for file in os.listdir(os.path.join(self.replay_config.recording_data_dir, data_dir)):
if file == "state.json":
state = json.load(open(os.path.join(self.replay_config.recording_data_dir, data_dir, file)))
elif file == "task_result.json":
task_result = json.load(
open(os.path.join(self.replay_config.recording_data_dir, data_dir, file)))
if state is None or task_result is None:
Log.warning(f"state or task result not found in {data_dir}")
else:
recording_data_dict[data_dir] = {
"idx": int(data_dir.split('_')[-1][:-1]) + 1,
"state": state,
"task_result": task_result
}
Log.success(f"Recording data loaded: {len(recording_data_dict)} data in total.")
return recording_data_dict
def load_task_template(self, data_length):
task_template = json.load(open(self.replay_config.task_template_file))
replay_config = {
"use_origin_camera": self.replay_config.use_origin_camera,
"use_trajectory_camera": self.replay_config.use_trajectory_camera,
"replay_dir": self.replay_config.replay_dir,
"episodes": data_length,
"cam_trajectory_file": os.path.join(self.replay_config.replay_dir, self.TRAJECTORY_FILE_NAME),
"space_interpolation": self.replay_config.space_interpolation
}
task_template["replay"] = replay_config
Log.success(f"Task template loaded successfully. replay config: {replay_config}")
return task_template
def generate_trajectory_file(self):
workspace_position = self.workspace["position"]
workspace_size = self.workspace["size"]
trajectory_data = []
# poses_mid = generate_circular_trajectory(0.5,0.3,workspace_position)
# poses_low = generate_circular_trajectory(0.05,0.5,workspace_position)
# poses_high = generate_circular_trajectory(0.7,0.15,workspace_position)
desc_1 = [
{"type": TrajDescription.LRot, "value": 20},
{"type": TrajDescription.RRot, "value": 20},
{"type": TrajDescription.DownRot, "value": 20},
{"type": TrajDescription.UpRot, "value": 20},
{"type": TrajDescription.Left, "value": workspace_size[0] / 2},
{"type": TrajDescription.Front, "value": workspace_size[1] / 2},
{"type": TrajDescription.Back, "value": workspace_size[1] / 2},
{"type": TrajDescription.Right, "value": workspace_size[0] / 2}
]
desc_2 = [
{"type": TrajDescription.LRot, "value": 20},
{"type": TrajDescription.RRot, "value": 20},
{"type": TrajDescription.DownRot, "value": 20},
{"type": TrajDescription.UpRot, "value": 20},
{"type": TrajDescription.Right, "value": workspace_size[0] / 2},
{"type": TrajDescription.Front, "value": workspace_size[1] / 2},
{"type": TrajDescription.Back, "value": workspace_size[1] / 2},
{"type": TrajDescription.Left, "value": workspace_size[0] / 2}
]
desc_3 = [
{"type": TrajDescription.DownRot, "value": 20},
{"type": TrajDescription.UpRot, "value": 20},
{"type": TrajDescription.Left, "value": workspace_size[0] / 2},
{"type": TrajDescription.Front, "value": workspace_size[1] / 2},
{"type": TrajDescription.RRot, "value": 75},
{"type": TrajDescription.LRot, "value": 75},
{"type": TrajDescription.Back, "value": workspace_size[1] / 2},
{"type": TrajDescription.Right, "value": workspace_size[0] / 2}
]
desc_4 = [
{"type": TrajDescription.DownRot, "value": 20},
{"type": TrajDescription.UpRot, "value": 20},
{"type": TrajDescription.Right, "value": workspace_size[0] / 2},
{"type": TrajDescription.Front, "value": workspace_size[1] / 2},
{"type": TrajDescription.LRot, "value": 75},
{"type": TrajDescription.RRot, "value": 75},
{"type": TrajDescription.Back, "value": workspace_size[1] / 2},
{"type": TrajDescription.Left, "value": workspace_size[0] / 2}
]
start_pt = workspace_position - np.array([0, workspace_size[1], -0.3])
workspace_bottom_center = workspace_position - np.array([0, 0, workspace_size[2] / 2])
poses_1 = generate_trajectory_from_description(start_pt, workspace_bottom_center, desc_1)
poses_2 = generate_trajectory_from_description(start_pt, workspace_bottom_center, desc_2)
poses_3 = generate_trajectory_from_description(start_pt, workspace_bottom_center, desc_3)
poses_4 = generate_trajectory_from_description(start_pt, workspace_bottom_center, desc_4)
camera_data_1 = generate_camera_data_from_poses(poses_1, "cam_1")
camera_data_2 = generate_camera_data_from_poses(poses_2, "cam_2")
camera_data_3 = generate_camera_data_from_poses(poses_3, "cam_3")
camera_data_4 = generate_camera_data_from_poses(poses_4, "cam_4")
trajectory_data = [camera_data_1, camera_data_2, camera_data_3, camera_data_4]
trajectory_path = os.path.join(self.replay_config.replay_dir, self.TRAJECTORY_FILE_NAME)
with open(trajectory_path, "w") as f:
json.dump(trajectory_data, f)
Log.success(f"Trajectory file generated successfully. Trajectory file saved to {trajectory_path}")
def generate_trajectory_file_from_json(self):
workspace_position = self.workspace["position"]
trajectory_data = []
poses_mid = generate_circular_trajectory(0.5, 0.3, workspace_position)
poses_low = generate_circular_trajectory(0.05, 0.5, workspace_position)
poses_high = generate_circular_trajectory(0.7, 0.15, workspace_position)
camera_data_mid = generate_camera_data_from_poses(poses_mid, "cam_mid")
camera_data_low = generate_camera_data_from_poses(poses_low, "cam_low")
camera_data_high = generate_camera_data_from_poses(poses_high, "cam_high")
trajectory_data = [camera_data_mid, camera_data_low, camera_data_high]
trajectory_path = os.path.join(self.replay_config.replay_dir, self.TRAJECTORY_FILE_NAME)
with open(trajectory_path, "w") as f:
json.dump(trajectory_data, f)
Log.success(f"Trajectory file generated successfully. Trajectory file saved to {trajectory_path}")
def generate_replay_input(self, task_template, recording_data_dict):
states_dir = os.path.join(self.replay_config.replay_dir, "states")
results_dir = os.path.join(self.replay_config.replay_dir, "results")
if not os.path.exists(states_dir):
os.makedirs(states_dir)
if not os.path.exists(results_dir):
os.makedirs(results_dir)
for data_dir in recording_data_dict:
state = recording_data_dict[data_dir]["state"]
task_result = recording_data_dict[data_dir]["task_result"]
state_file = os.path.join(states_dir, f"{recording_data_dict[data_dir]['idx']}.state.json")
result_file = os.path.join(results_dir, f"{recording_data_dict[data_dir]['idx']}.task_result.json")
with open(state_file, "w") as f:
json.dump(state, f)
with open(result_file, "w") as f:
json.dump(task_result, f)
with open(os.path.join(self.replay_config.replay_dir, os.path.basename(self.replay_config.task_template_file)),
"w") as f:
json.dump(task_template, f)
Log.success(f"Replay input generated: {len(recording_data_dict)} data in total.")
Log.success(
f"Replay input saved to '{self.replay_config.replay_dir}', new task file is saved as '{os.path.basename(self.replay_config.task_template_file)}' in '{self.replay_config.replay_dir}'")