finish task_gen

This commit is contained in:
2025-09-05 11:10:42 +08:00
parent da022d4f83
commit 4e51158215
17 changed files with 82 additions and 50 deletions

View File

@@ -1,12 +1,33 @@
import json
import os
from pyboot.runner import Runner
from others.task_generate import OldTaskGenerator
from pyboot.utils.log import Log
from task_gen_dependencies.task_generate import OldTaskGenerator
class TaskGenerator(Runner):
def __init__(self, config_path: str):
super().__init__(config_path)
self.generate_config = self.config["generate"]
self.input_target_task_templates_path = self.generate_config["input_target_task_templates_path"]
self.input_data_root = self.generate_config["input_data_root"]
self.output_task_root_dir = self.generate_config.get("output_task_root_dir", None)
if self.output_task_root_dir is None:
self.output_task_root_dir = os.path.join(self.workspace_path, "task_root_dir")
self.target_task_templates = json.load(open(self.input_target_task_templates_path, "r"))
def run(self):
pass
for task_template_name, task_template_path in self.target_task_templates.items():
task_template = json.load(open(task_template_path, "r"))
self.generate_from_template(task_template_name, task_template)
Log.success(f"Generated {task_template['recording_setting']['num_of_episode']} tasks from <{task_template_name}>")
def generate_from_template(self, template: dict):
pass
def generate_from_template(self, template_name: str, template: dict):
task_dir = os.path.join(self.output_task_root_dir, template_name)
task_num = template["recording_setting"]["num_of_episode"]
task_name = template["task"]
old_task_generator = OldTaskGenerator(template, self.input_data_root)
if not os.path.exists(task_dir):
os.makedirs(task_dir)
old_task_generator.generate_tasks(save_path=task_dir, task_num=task_num, task_name=task_name)