"""Ravens main training script.""" | |
import os | |
import pickle | |
import json | |
import numpy as np | |
import hydra | |
from cliport import agents | |
from cliport import dataset | |
from cliport import tasks | |
from cliport.utils import utils | |
from cliport.environments.environment import Environment | |
from torch.utils.data import DataLoader | |
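# NOTE: `main()` is invoked without arguments at the bottom of this file, so Hydra
# must inject `vcfg`. The config path/name below follow CLIPort's usual layout
# (./cfg/eval.yaml) and are an assumption if this script lives elsewhere.
@hydra.main(config_path='./cfg', config_name='eval')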
def main(vcfg):
    # Load train cfg
    tcfg = utils.load_hydra_config(vcfg['train_config'])

    # Initialize environment and task.
    env = Environment(
        vcfg['assets_root'],
        disp=vcfg['disp'],
        shared_memory=vcfg['shared_memory'],
        hz=480,
        record_cfg=vcfg['record']
    )
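    # `record_cfg` only configures video capture; recording is started per-episode
    # below via env.start_rec() when vcfg['record']['save_video'] is set.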

    # Choose eval mode and task.
    mode = vcfg['mode']
    eval_task = vcfg['eval_task']
    print(f"Evaluation task: {eval_task}")
    if mode not in {'train', 'val', 'test'}:
        raise Exception("Invalid mode. Valid options: train, val, test")

    # Load eval dataset.
    dataset_type = vcfg['type']
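    # A 'multi' dataset type groups several tasks under `eval_task`; otherwise a
    # single-task dataset is read from <data_dir>/<eval_task>-<mode>.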
    if 'multi' in dataset_type:
        ds = dataset.RavensMultiTaskDataset(vcfg['data_dir'],
                                            tcfg,
                                            group=eval_task,
                                            mode=mode,
                                            n_demos=vcfg['n_demos'],
                                            augment=False)
    else:
        ds = dataset.RavensDataset(os.path.join(vcfg['data_dir'], f"{eval_task}-{mode}"),
                                   tcfg,
                                   n_demos=vcfg['n_demos'],
                                   augment=False)

    all_results = {}
    name = '{}-{}-n{}'.format(eval_task, vcfg['agent'], vcfg['n_demos'])

    # Save path for results.
    json_name = f"multi-results-{mode}.json" if 'multi' in vcfg['model_path'] else f"results-{mode}.json"
    save_path = vcfg['save_path']
    print(f"Save path for results: {save_path}")
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    save_json = os.path.join(save_path, f'{name}-{json_name}')

    # Load existing results.
    existing_results = {}
    if os.path.exists(save_json):
        with open(save_json, 'r') as f:
            existing_results = json.load(f)

    # Make a list of checkpoints to eval.
    ckpts_to_eval = list_ckpts_to_eval(vcfg, existing_results)

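    # The loader below is only handed to the agent constructor (which expects
    # train/val datasets); the evaluation loop itself reads episodes via ds.load().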
    data_loader = DataLoader(ds, shuffle=False,
                             pin_memory=False,
                             num_workers=1)

    # Evaluation loop
    print(f"Evaluating: {str(ckpts_to_eval)}")
    for ckpt in ckpts_to_eval:
        model_file = os.path.join(vcfg['model_path'], ckpt)

        if not os.path.exists(model_file) or not os.path.isfile(model_file):
            print(f"Checkpoint not found: {model_file}")
            continue
        elif not vcfg['update_results'] and ckpt in existing_results:
            print(f"Skipping because of existing results for {model_file}.")
            continue

        results = []
        mean_reward = 0.0

        # Run testing for each training run.
        for train_run in range(vcfg['n_repeats']):
            # Initialize agent.
            utils.set_seed(train_run, torch=True)
            agent = agents.names[vcfg['agent']](name, tcfg, data_loader, data_loader)

            # Load checkpoint
            agent.load(model_file)
            print(f"Loaded: {model_file}")

            record = vcfg['record']['save_video']
            n_demos = vcfg['n_demos']

            # Run testing and save total rewards with last transition info.
            for i in range(0, n_demos):
                print(f'Test: {i + 1}/{n_demos}')
                try:
                    episode, seed = ds.load(i)
                except Exception:
                    print(f"Skipping bad example {i}")
                    continue
                goal = episode[-1]
                total_reward = 0
                np.random.seed(seed)

                # set task
                if 'multi' in dataset_type:
                    task_name = ds.get_curr_task()
                    task = tasks.names[task_name]()
                    print(f'Evaluating on {task_name}')
                else:
                    task_name = vcfg['eval_task']
                    task = tasks.names[task_name]()

                task.mode = mode
                env.seed(seed)
                env.set_task(task)
                obs = env.reset()
                info = env.info
                reward = 0

                # Start recording video (NOTE: super slow)
                if record:
                    video_name = f'{task_name}-{i+1:06d}'
                    if 'multi' in vcfg['model_task']:
                        video_name = f"{vcfg['model_task']}-{video_name}"
                    env.start_rec(video_name)

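                # Roll out the policy for at most task.max_steps; per-step rewards
                # accumulate into total_reward (a fully solved Ravens episode
                # typically totals 1.0).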
                for _ in range(task.max_steps):
                    act = agent.act(obs, info, goal)
                    lang_goal = info['lang_goal']
                    # print(f'Lang Goal: {lang_goal}')
                    obs, reward, done, info = env.step(act)
                    total_reward += reward
                    # print(f'Total Reward: {total_reward:.3f} | Done: {done}\n')
                    if done:
                        break

                results.append((total_reward, info))
                mean_reward = np.mean([r for r, i in results])
                print(f'Mean: {mean_reward} | Task: {task_name} | Ckpt: {ckpt}')

                # End recording video
                if record:
                    env.end_rec()

        all_results[ckpt] = {
            'episodes': results,
            'mean_reward': mean_reward,
        }

        # Save results in a json file.
        if vcfg['save_results']:
            print(f"Saving results to: {save_json}")

            # Merge with existing results before writing.
            if os.path.exists(save_json):
                with open(save_json, 'r') as f:
                    existing_results = json.load(f)
                existing_results.update(all_results)
                all_results = existing_results

            with open(save_json, 'w') as f:
                json.dump(all_results, f, indent=4)


def list_ckpts_to_eval(vcfg, existing_results):
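    """Return the list of checkpoint filenames to evaluate, based on vcfg['checkpoint_type']."""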
    ckpts_to_eval = []

    # Just the last.ckpt
    if vcfg['checkpoint_type'] == 'last':
        last_ckpt = 'last.ckpt'
        ckpts_to_eval.append(last_ckpt)

    # Validation checkpoints that haven't been already evaluated.
    elif vcfg['checkpoint_type'] == 'val_missing':
        checkpoints = sorted([c for c in os.listdir(vcfg['model_path']) if "steps=" in c])
        ckpts_to_eval = [c for c in checkpoints if c not in existing_results]

    # Find the best checkpoint from validation and run eval on the test set.
    elif vcfg['checkpoint_type'] == 'test_best':
        result_jsons = [c for c in os.listdir(vcfg['results_path']) if "results-val" in c]
        if 'multi' in vcfg['model_task']:
            result_jsons = [r for r in result_jsons if "multi" in r]
        else:
            result_jsons = [r for r in result_jsons if "multi" not in r]

        if len(result_jsons) > 0:
            result_json = result_jsons[0]
            with open(os.path.join(vcfg['results_path'], result_json), 'r') as f:
                eval_res = json.load(f)
            best_checkpoint = 'last.ckpt'
            best_success = -1.0
            for ckpt, res in eval_res.items():
                if res['mean_reward'] > best_success:
                    best_checkpoint = ckpt
                    best_success = res['mean_reward']
            print(best_checkpoint)
            ckpt = best_checkpoint
            ckpts_to_eval.append(ckpt)
        else:
            print("No best val ckpt found. Using last.ckpt")
            ckpt = 'last.ckpt'
            ckpts_to_eval.append(ckpt)

    # Load a specific checkpoint with a substring e.g: 'steps=10000'
    else:
        print(f"Looking for: {vcfg['checkpoint_type']}")
        checkpoints = [c for c in os.listdir(vcfg['model_path']) if vcfg['checkpoint_type'] in c]
        checkpoint = checkpoints[0] if len(checkpoints) > 0 else ""
        ckpt = checkpoint
        ckpts_to_eval.append(ckpt)

    print("ckpts_to_eval:", ckpts_to_eval)
    return ckpts_to_eval


if __name__ == '__main__':
    main()