Spaces:
Runtime error
Runtime error
import numpy as np | |
import os | |
import IPython | |
import random | |
import json | |
from gensim.utils import save_text | |
class Memory:
    """Maintain a buffer of generated tasks, assets and task code.

    The "offline" memory is the set of JSON index files stored under
    ``cfg["prompt_data_path"]`` plus the task source files under
    ``cliport/tasks/`` and ``cliport/generated_tasks/`` (both relative to
    the current working directory). The "online" memory is the in-process
    copy held in the ``online_*_buffer`` dictionaries.
    """

    def __init__(self, cfg):
        """Load the offline memory into the online buffers.

        Args:
            cfg: mapping that provides at least ``"prompt_folder"``,
                ``"prompt_data_path"`` and ``"load_memory"``; ``save_run``
                additionally reads ``"model_output_dir"``.
        """
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.data_path = cfg["prompt_data_path"]
        self.cfg = cfg

        # a chat history is a list of strings
        self.chat_log = []
        self.online_task_buffer = {}
        self.online_code_buffer = {}
        self.online_asset_buffer = {}

        # directly load current offline memory into online memory
        base_tasks, base_assets, base_task_codes = self.load_offline_memory()
        self.online_task_buffer.update(base_tasks)
        self.online_asset_buffer.update(base_assets)

        # Load each code file. The original cliport task folder takes
        # precedence over the generated one; a file found in neither
        # folder is skipped without an error.
        for task_file in base_task_codes:
            for folder in ("cliport/tasks/", "cliport/generated_tasks/"):
                code_path = folder + task_file
                if os.path.exists(code_path):
                    with open(code_path) as fhandle:
                        self.online_code_buffer[task_file] = fhandle.read()
                    break

        print(f"load {len(self.online_code_buffer)} tasks for memory from offline to online:")

        # NOTE: ``task_code_embedding`` is only set when the cache file
        # exists; otherwise the attribute is left undefined.
        cache_embedding_path = "outputs/task_cache_embedding.npz"
        if os.path.exists(cache_embedding_path):
            print("task code embeding:", cache_embedding_path)
            self.task_code_embedding = np.load(cache_embedding_path)

    @staticmethod
    def _read_json(path):
        """Return the parsed JSON content of ``path``, closing the file afterwards."""
        with open(path) as infile:
            return json.load(infile)

    def save_run(self, new_task):
        """Save the whole chat history of the current run.

        Args:
            new_task: task description; its ``"task-name"`` entry names the
                output.
        """
        output_name = f'{new_task["task-name"]}_full_output'
        print("save all interaction to :", output_name)
        unroll_chatlog = "".join(self.chat_log)
        save_text(self.cfg['model_output_dir'], output_name, unroll_chatlog)

    def save_task_to_online(self, new_task, code):
        """Store a task and its code in the online buffers only.

        Nothing is written to disk. Unlike the offline index, which lists
        file names only, the online code buffer maps the code file name to
        the actual source text.

        Args:
            new_task: task description containing a ``"task-name"`` entry.
            code: source code of the task.
        """
        self.online_task_buffer[new_task['task-name']] = new_task
        code_file_name = new_task["task-name"].replace("-", "_") + ".py"
        self.online_code_buffer[code_file_name] = code

    def save_task_to_offline(self, new_task, code):
        """Persist a task description and its code to the offline memory.

        The code file is written to ``cliport/generated_tasks/`` and
        registered in ``generated_task_codes.json`` only when its file name
        is not registered yet; an already registered file is left
        untouched. The task description in ``generated_tasks.json`` is
        written (or overwritten) in every case.

        Args:
            new_task: task description containing a ``"task-name"`` entry.
            code: source code of the task.
        """
        generated_task_code_path = os.path.join(
            self.cfg["prompt_data_path"], "generated_task_codes.json"
        )
        generated_task_codes = self._read_json(generated_task_code_path)

        new_file_path = new_task["task-name"].replace("-", "_") + ".py"
        if new_file_path not in generated_task_codes:
            generated_task_codes.append(new_file_path)

            python_file_path = "cliport/generated_tasks/" + new_file_path
            print(f"save {new_task['task-name']} to ", python_file_path)

            with open(python_file_path, "w") as fhandle:
                fhandle.write(code)

            with open(generated_task_code_path, "w") as outfile:
                json.dump(generated_task_codes, outfile, indent=4)
        else:
            # new_file_path already carries the ".py" suffix
            print(f"{new_file_path} already exists.")

        # save task descriptions
        generated_task_path = os.path.join(
            self.cfg["prompt_data_path"], "generated_tasks.json"
        )
        generated_tasks = self._read_json(generated_task_path)
        generated_tasks[new_task["task-name"]] = new_task

        with open(generated_task_path, "w") as outfile:
            json.dump(generated_tasks, outfile, indent=4)

    def load_offline_memory(self):
        """Read the task descriptions, assets and code file names from disk.

        When ``cfg["load_memory"]`` is truthy, the generated tasks and the
        generated code file names are merged into the base ones. Generated
        assets are not merged (that step was disabled in the original code).

        Returns:
            A tuple ``(base_tasks, base_assets, base_task_codes)`` holding
            the parsed content of the corresponding JSON files; the code
            below treats the first as a dict and the last as a list of
            file names.
        """
        base_tasks = self._read_json(os.path.join(self.data_path, "base_tasks.json"))
        base_assets = self._read_json(os.path.join(self.data_path, "base_assets.json"))
        base_task_codes = self._read_json(
            os.path.join(self.data_path, "base_task_codes.json")
        )

        if self.cfg["load_memory"]:
            generated_task_path = os.path.join(self.data_path, "generated_tasks.json")
            generated_task_code_path = os.path.join(
                self.data_path, "generated_task_codes.json"
            )

            print("original base task num:", len(base_tasks))
            base_tasks.update(self._read_json(generated_task_path))

            # append only unseen file names so the original order is kept
            for task in self._read_json(generated_task_code_path):
                if task not in base_task_codes:
                    base_task_codes.append(task)

            print("current base task num:", len(base_tasks))

        return base_tasks, base_assets, base_task_codes