finish task_gen
This commit is contained in:
@@ -1,12 +1,33 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from pyboot.runner import Runner
|
||||
from others.task_generate import OldTaskGenerator
|
||||
from pyboot.utils.log import Log
|
||||
from task_gen_dependencies.task_generate import OldTaskGenerator
|
||||
|
||||
class TaskGenerator(Runner):
|
||||
def __init__(self, config_path: str):
|
||||
super().__init__(config_path)
|
||||
self.generate_config = self.config["generate"]
|
||||
self.input_target_task_templates_path = self.generate_config["input_target_task_templates_path"]
|
||||
self.input_data_root = self.generate_config["input_data_root"]
|
||||
self.output_task_root_dir = self.generate_config.get("output_task_root_dir", None)
|
||||
if self.output_task_root_dir is None:
|
||||
self.output_task_root_dir = os.path.join(self.workspace_path, "task_root_dir")
|
||||
self.target_task_templates = json.load(open(self.input_target_task_templates_path, "r"))
|
||||
|
||||
def run(self):
|
||||
pass
|
||||
for task_template_name, task_template_path in self.target_task_templates.items():
|
||||
task_template = json.load(open(task_template_path, "r"))
|
||||
self.generate_from_template(task_template_name, task_template)
|
||||
Log.success(f"Generated {task_template['recording_setting']['num_of_episode']} tasks from <{task_template_name}>")
|
||||
|
||||
def generate_from_template(self, template: dict):
|
||||
pass
|
||||
def generate_from_template(self, template_name: str, template: dict):
|
||||
task_dir = os.path.join(self.output_task_root_dir, template_name)
|
||||
task_num = template["recording_setting"]["num_of_episode"]
|
||||
task_name = template["task"]
|
||||
|
||||
old_task_generator = OldTaskGenerator(template, self.input_data_root)
|
||||
if not os.path.exists(task_dir):
|
||||
os.makedirs(task_dir)
|
||||
old_task_generator.generate_tasks(save_path=task_dir, task_num=task_num, task_name=task_name)
|
||||
|
||||
@@ -10,9 +10,13 @@ class TaskTemplatesDivider(Runner):
|
||||
super().__init__(config_path)
|
||||
self.divide_config = self.config["divide"]
|
||||
|
||||
self.task_templates_root_dir = self.divide_config["task_templates_root_dir"]
|
||||
self.output_task_templates_dir = self.divide_config["output_task_templates_dir"]
|
||||
self.output_template_targets_dir = self.divide_config["output_template_targets_dir"]
|
||||
self.input_task_templates_root_dir = self.divide_config["input_task_templates_root_dir"]
|
||||
self.output_task_templates_dir = self.divide_config.get("output_task_templates_dir", None)
|
||||
if self.output_task_templates_dir is None:
|
||||
self.output_task_templates_dir = os.path.join(self.workspace_path, "task_templates")
|
||||
self.output_template_targets_dir = self.divide_config.get("output_template_targets_dir", None)
|
||||
if self.output_template_targets_dir is None:
|
||||
self.output_template_targets_dir = os.path.join(self.workspace_path, "template_targets")
|
||||
self.divide_num = self.divide_config["divide_num"]
|
||||
self.total_nums = self.divide_config["total_nums"]
|
||||
|
||||
@@ -26,11 +30,11 @@ class TaskTemplatesDivider(Runner):
|
||||
|
||||
def load_all_task_templates(self):
|
||||
task_list = []
|
||||
for task_template_dir in os.listdir(self.task_templates_root_dir):
|
||||
if os.path.isdir(os.path.join(self.task_templates_root_dir, task_template_dir)):
|
||||
for file in os.listdir(os.path.join(self.task_templates_root_dir, task_template_dir)):
|
||||
for task_template_dir in os.listdir(self.input_task_templates_root_dir):
|
||||
if os.path.isdir(os.path.join(self.input_task_templates_root_dir, task_template_dir)):
|
||||
for file in os.listdir(os.path.join(self.input_task_templates_root_dir, task_template_dir)):
|
||||
if file.endswith('.json'):
|
||||
task_list.append(os.path.join(self.task_templates_root_dir, task_template_dir, file))
|
||||
task_list.append(os.path.join(self.input_task_templates_root_dir, task_template_dir, file))
|
||||
Log.success(f"Loaded {len(task_list)} tasks")
|
||||
return task_list
|
||||
|
||||
|
||||
Reference in New Issue
Block a user