File size: 1,326 Bytes
8fc2b4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import numpy as np
import os
import hydra
import random

import re
import openai
import IPython
import time
import pybullet as p
import traceback
from datetime import datetime
from pprint import pprint
import cv2
import re
import random
import json

from gensim.agent import Agent
from gensim.critic import Critic
from gensim.sim_runner import SimulationRunner
from gensim.memory import Memory
from gensim.utils import set_gpt_model, clear_messages


@hydra.main(config_path='../cliport/cfg', config_name='data', version_base="1.2")
def main(cfg):
    """Run the task-generation loop for ``cfg['trials']`` trials.

    Sets the OpenAI key and GPT model from ``cfg`` and derives a timestamped
    ``cfg['model_output_dir']``, suffixed with ``cfg['seed']`` when a seed
    key is present. It then builds the Memory / Agent / Critic /
    SimulationRunner pipeline. Each trial creates a task, simulates it, and
    prints the running statistics. The statistics are saved once at the end,
    including when a trial raises.

    Args:
        cfg: Hydra config composed from the ``data`` config in
            ``../cliport/cfg`` plus command-line overrides. Keys read
            directly here: ``openai_key``, ``output_folder``,
            ``prompt_folder``, ``seed`` (optional), ``gpt_model`` and
            ``trials``. ``model_output_dir`` is written back into it.
    """
    openai.api_key = cfg['openai_key']

    # A per-run timestamp gives every run its own output directory.
    # NOTE(review): ':' is not a valid path character on Windows; the format
    # is kept unchanged so existing run-folder naming stays compatible.
    model_time = datetime.now().strftime("%d_%m_%Y_%H:%M:%S")
    cfg['model_output_dir'] = os.path.join(
        cfg['output_folder'], cfg['prompt_folder'] + "_" + model_time)
    if 'seed' in cfg:
        cfg['model_output_dir'] = cfg['model_output_dir'] + f"_{cfg['seed']}"

    set_gpt_model(cfg['gpt_model'])
    memory = Memory(cfg)
    agent = Agent(cfg, memory)
    critic = Critic(cfg, memory)
    simulation_runner = SimulationRunner(cfg, agent, critic, memory)

    try:
        for _ in range(cfg['trials']):
            simulation_runner.task_creation()
            simulation_runner.simulate_task()
            simulation_runner.print_current_stats()
    finally:
        # Persist whatever was gathered even if a trial raises (e.g. an API
        # error or KeyboardInterrupt). Otherwise every completed trial's
        # statistics would be lost. The original exception still propagates.
        simulation_runner.save_stats()


if __name__ == '__main__':
    # Hydra parses the command line and passes the composed config as `cfg`.
    main()