45 lines
2.2 KiB
Python
45 lines
2.2 KiB
Python
import json
|
|
import os
|
|
from typing import List
|
|
|
|
from pyboot import stereotype
|
|
from pyboot.runner import Runner
|
|
from pyboot.utils.log import Log
|
|
from task_gen_dependencies.task_generate import OldTaskGenerator
|
|
|
|
|
|
@stereotype.runner("task_generator")
class TaskGenerator(Runner):
    """Runner that generates tasks from a collection of JSON task templates.

    The ``"generate"`` section of the runner configuration supplies:

    * ``input_target_task_templates_path`` (required): path to a JSON file
      mapping each template name to the path of that template's JSON file.
    * ``input_data_root`` (required): passed through to ``OldTaskGenerator``.
    * ``output_task_root_dir`` (optional): directory under which one
      sub-directory per template is created. Defaults to
      ``<workspace_path>/task_root_dir``.
    * ``output_generate_result_path`` (optional): JSON file that receives the
      per-template generation results. Defaults to
      ``<workspace_path>/generate_result_path.json``.

    NOTE(review): ``self.config`` and ``self.workspace_path`` are presumably
    provided by ``Runner.__init__`` -- confirm against pyboot.
    """

    def __init__(self, config_path: str):
        """Read the generation settings and load the template index.

        Args:
            config_path: Path of the runner configuration, forwarded to
                ``Runner.__init__``.

        Raises:
            KeyError: If a required configuration key is missing.
            OSError: If the template index file cannot be opened.
            json.JSONDecodeError: If the template index is not valid JSON.
        """
        super().__init__(config_path)
        self.generate_config = self.config["generate"]
        self.input_target_task_templates_path = self.generate_config["input_target_task_templates_path"]
        self.input_data_root = self.generate_config["input_data_root"]
        self.output_task_root_dir = self.generate_config.get("output_task_root_dir", None)
        self.output_generate_result_path = self.generate_config.get("output_generate_result_path", None)
        # A missing key and an explicit null both fall back to the workspace defaults.
        if self.output_task_root_dir is None:
            self.output_task_root_dir = os.path.join(self.workspace_path, "task_root_dir")
        if self.output_generate_result_path is None:
            self.output_generate_result_path = os.path.join(self.workspace_path, "generate_result_path.json")
        # Use a context manager so the file handle is closed deterministically.
        with open(self.input_target_task_templates_path, "r") as templates_file:
            self.target_task_templates = json.load(templates_file)

    def run(self):
        """Generate tasks for every template and save the results as JSON.

        The file at ``self.output_generate_result_path`` receives a JSON
        object mapping each template name to the value returned by
        ``generate_from_template`` for that template.
        """
        generate_results = {}
        for task_template_name, task_template_path in self.target_task_templates.items():
            with open(task_template_path, "r") as template_file:
                task_template = json.load(template_file)
            generated_tasks_path_list = self.generate_from_template(task_template_name, task_template)
            Log.success(f"Generated {len(generated_tasks_path_list)} tasks from <{task_template_name}>")
            # Key by the actual template name; a literal key here would make
            # every template overwrite the previous one's results.
            generate_results[task_template_name] = generated_tasks_path_list
        # Closing the file via ``with`` guarantees the JSON is flushed to disk.
        with open(self.output_generate_result_path, "w") as result_file:
            json.dump(generate_results, result_file)

    def generate_from_template(self, template_name: str, template: dict) -> List[str]:
        """Generate the tasks described by a single template.

        Args:
            template_name: Name of the template; also used as the name of the
                sub-directory of ``self.output_task_root_dir`` that receives
                the generated tasks.
            template: Parsed template. Must contain ``"task"`` and
                ``"recording_setting"`` -> ``"num_of_episode"``.

        Returns:
            Whatever ``OldTaskGenerator.generate_tasks`` returns -- presumably
            the list of generated task file paths (verify against
            ``task_gen_dependencies``).

        Raises:
            KeyError: If the template lacks one of the required keys.
        """
        task_dir = os.path.join(self.output_task_root_dir, template_name)
        task_num = template["recording_setting"]["num_of_episode"]
        task_name = template["task"]

        old_task_generator = OldTaskGenerator(template, self.input_data_root)
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(task_dir, exist_ok=True)
        return old_task_generator.generate_tasks(save_path=task_dir, task_num=task_num, task_name=task_name)
|