diff --git a/gia/eval/callback.py b/gia/eval/callback.py index 5c3a080..4b6198f 100644 --- a/gia/eval/callback.py +++ b/gia/eval/callback.py @@ -2,10 +2,10 @@ import glob import json import subprocess -import wandb from accelerate import Accelerator from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments +import wandb from gia.config import Arguments from gia.eval.utils import is_slurm_available diff --git a/gia/eval/evaluator.py b/gia/eval/evaluator.py index 91b645c..3e2cae7 100644 --- a/gia/eval/evaluator.py +++ b/gia/eval/evaluator.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from gia.config.arguments import Arguments @@ -5,11 +7,12 @@ from gia.model import GiaModel class Evaluator: - def __init__(self, args: Arguments, task: str) -> None: + def __init__(self, args: Arguments, task: str, mean_random: Optional[float] = None) -> None: self.args = args self.task = task + self.mean_random = mean_random - @torch.no_grad() + @torch.inference_mode() def evaluate(self, model: GiaModel) -> float: return self._evaluate(model) diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py index ec5e5b2..eeaf7cb 100644 --- a/gia/eval/rl/envs/core.py +++ b/gia/eval/rl/envs/core.py @@ -177,7 +177,6 @@ def make(task_name: str, num_envs: int = 1): elif task_name.startswith("metaworld"): import gymnasium as gym - import metaworld env_id = TASK_TO_ENV_MAPPING[task_name] env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs) diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py index f0d0b9b..39dc0d2 100644 --- a/gia/eval/rl/gia_agent.py +++ b/gia/eval/rl/gia_agent.py @@ -54,7 +54,7 @@ class GiaAgent: self.action_space = action_space self.deterministic = deterministic self.device = next(model.parameters()).device - self._max_length = self.model.config.max_position_embeddings - 10 + self._max_length = self.model.config.max_position_embeddings - 100 # TODO: fix this if isinstance(observation_space, 
spaces.Box): self._observation_key = "continuous_observations" @@ -75,6 +75,11 @@ class GiaAgent: ) -> Tuple[Tuple[Tensor, Tensor], ...]: return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values) + def set_model(self, model: GiaModel) -> None: + self.model = model + self.device = next(model.parameters()).device + self._max_length = self.model.config.max_position_embeddings - 100  # TODO: fix this (keep in sync with __init__) + def reset(self, num_envs: int = 1) -> None: if self.prompter is not None: prompts = self.prompter.generate_prompts(num_envs) diff --git a/gia/eval/rl/gym_evaluator.py b/gia/eval/rl/gym_evaluator.py index f8531ee..754c05d 100644 --- a/gia/eval/rl/gym_evaluator.py +++ b/gia/eval/rl/gym_evaluator.py @@ -1,7 +1,7 @@ import gym from gym.vector.vector_env import VectorEnv -from gia.eval.mappings import TASK_TO_ENV_MAPPING +# from gia.eval.rl.envs.mappings import TASK_TO_ENV_MAPPING from gia.eval.rl.rl_evaluator import RLEvaluator diff --git a/gia/eval/rl/rl_evaluator.py b/gia/eval/rl/rl_evaluator.py index c5cc423..91189f3 100644 --- a/gia/eval/rl/rl_evaluator.py +++ b/gia/eval/rl/rl_evaluator.py @@ -8,6 +8,10 @@ from gia.eval.rl.gia_agent import GiaAgent class RLEvaluator(Evaluator): + def __init__(self, args, task): + super().__init__(args, task) + self.agent = GiaAgent()  # FIXME: GiaAgent.__init__ requires model/observation_space/action_space; this no-arg call will raise TypeError + def _build_env(self) -> VectorEnv: # TODO: maybe just a gym.Env ? raise NotImplementedError diff --git a/gia/eval/rl/scores_dict.json b/gia/eval/rl/scores_dict.json index 1b8ebee..ff7d030 100644 --- a/gia/eval/rl/scores_dict.json +++ b/gia/eval/rl/scores_dict.json @@ -929,8 +929,8 @@ }, "metaworld-assembly": { "expert": { - "mean": 311.29314618777823, - "std": 75.04282151450695 + "mean": 3523.81468486244, + "std": 63.22745220327798 }, "random": { "mean": 220.65601680730813,