File size: 1,819 Bytes
8fc2b4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import openai
import argparse
import os
from cliport import tasks
from cliport.dataset import RavensDataset
from cliport.environments.environment import Environment

from pygments import highlight
from pygments.lexers import PythonLexer
from pygments.formatters import TerminalFormatter

import time
import random
import json
import traceback
import pybullet as p
import IPython
from gensim.topdown_sim_runner import TopDownSimulationRunner
import hydra
from datetime import datetime

from gensim.memory import Memory
from gensim.utils import set_gpt_model, clear_messages, format_finetune_prompt

@hydra.main(config_path='../cliport/cfg', config_name='data', version_base="1.2")
def main(cfg):
    """Generate tasks with a fine-tuned completion model and simulate them.

    For each of ``cfg['trials']`` trials, the model ``cfg.target_model`` is
    prompted with ``format_finetune_prompt(cfg.target_task)``. The returned
    text is passed to ``simulation_runner.task_creation`` and then simulated.
    Stats are printed after every trial and saved once all trials finish.

    Args:
        cfg: Hydra config (``data`` under ``../cliport/cfg``). Keys read here:
            ``target_task``, ``target_model``, ``openai_key``,
            ``output_folder``, ``prompt_folder``, ``gpt_model``, ``trials``,
            and optionally ``seed`` and ``max_tokens``. ``model_output_dir``
            is written back into ``cfg`` before the runner is built.
    """
    task = cfg.target_task
    model = cfg.target_model
    prompt = format_finetune_prompt(task)

    openai.api_key = cfg['openai_key']
    # NOTE(review): this timestamp contains ':' which is not a valid path
    # character on Windows; kept as-is so existing output dir names match.
    model_time = datetime.now().strftime("%d_%m_%Y_%H:%M:%S")
    cfg['model_output_dir'] = os.path.join(cfg['output_folder'], cfg['prompt_folder'] + "_" + model_time)
    if 'seed' in cfg:
        cfg['model_output_dir'] = cfg['model_output_dir'] + f"_{cfg['seed']}"

    set_gpt_model(cfg['gpt_model'])
    memory = Memory(cfg)
    simulation_runner = TopDownSimulationRunner(cfg, memory)

    # The legacy Completions endpoint defaults to max_tokens=16, which cuts the
    # generated task code off almost immediately. Request a real budget; it can
    # be overridden from the config.
    max_tokens = cfg['max_tokens'] if 'max_tokens' in cfg else 1024

    for trial_i in range(cfg['trials']):
        response = openai.Completion.create(
            model=model,
            prompt=prompt,
            max_tokens=max_tokens)
        res = response["choices"][0]["text"]
        simulation_runner.task_creation(res)
        simulation_runner.simulate_task()
        simulation_runner.print_current_stats()

    simulation_runner.save_stats()




# NOTE(review): no few-shot prompts are loaded in this script; main() builds its prompt with format_finetune_prompt(task).


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on main() composes
    # cfg from the Hydra config files and any command-line overrides.
    main()