|
|
|
import json |
|
import pickle |
|
import re |
|
from pathlib import Path |
|
|
|
|
|
|
|
def load_pickle(fp):
    """Lazily yield each object stored back-to-back in the pickle file *fp*.

    The file is opened in binary mode and ``pickle.load`` is called on it
    repeatedly. Iteration stops quietly as soon as ``pickle.load`` raises
    ``EOFError``; any other exception propagates to the caller.

    NOTE(review): unpickling can execute arbitrary code, so *fp* should only
    ever point at files this project produced itself.
    """
    with open(fp, "rb") as stream:
        try:
            while True:
                yield pickle.load(stream)
        except EOFError:
            # Nothing left to read: finish the generator normally.
            return
|
|
|
|
|
|
|
# Base directory (relative to the working directory) under which the
# previously generated model result files are looked up.
fd = Path("model_outputs")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import json |
|
import pandas as pd |
|
|
|
|
|
# Pin the start/end indices of the game slice for this run. These assignments
# are unconditional, so any TG_GAME_ST / TG_GAME_ED value set outside the
# script is overwritten.
os.environ["TG_GAME_ST"] = "7"
os.environ["TG_GAME_ED"] = "8"

# Read the bounds back from the environment; an unset variable becomes None,
# anything else must parse as an integer.
st = os.getenv("TG_GAME_ST", None)
ed = os.getenv("TG_GAME_ED", None)
st = int(st) if st is not None else None
ed = int(ed) if ed is not None else None

# The output file name carries ".<st>" before the extension whenever a start
# index is set, and no such suffix otherwise.
if st is None:
    fp_out = "model_outputs/results_gemma-2-9b-it.jsonl"
else:
    fp_out = f"model_outputs/results_gemma-2-9b-it.{st}.jsonl"
|
|
|
|
|
|
|
from tqdm import tqdm |
|
from itertools import product |
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS, game_filename, _game_class_from_name |
|
|
|
# Provide a default output directory without overriding a value that is
# already present in the environment.
# NOTE(review): this runs after `textgames` has been imported; if that package
# reads TEXTGAMES_OUTPUT_DIR at import time the default arrives too late -
# confirm against the package.
if "TEXTGAMES_OUTPUT_DIR" not in os.environ:
    os.environ["TEXTGAMES_OUTPUT_DIR"] = "user_outputs"
|
|
|
|
|
|
|
# Load the earlier single-turn results (one JSON record per line) into a
# DataFrame; get_buffered_response reads its game, session, turn and response
# columns.
with (fd / "gemma2_9b_results_depre_250122/results_gemma-2-9b-it.single_turn.jsonl").open(
    "r", encoding="utf-8"
) as f:
    df = pd.read_json(f, lines=True)

# Bare expression kept from interactive use: it only evaluates the column
# index and has no effect when the file runs as a script.
df.columns
|
|
|
|
|
|
|
from agents import run_with_agent |
|
from agents.gemma_2_9b_it import gemma_postproc |
|
|
|
|
|
|
|
# Cache of {problemset key -> {prompt text -> session id}}. Each problemset
# JSON file is parsed and inverted once instead of on every call.
_PROMPT_SID_CACHE = {}


def _prompt_to_sid(game_key):
    """Return the prompt -> session-id mapping for ``problemsets/<game_key>.json``.

    The file is read on first use and the inverted mapping is cached. If several
    session ids share the same prompt, the last one in the file wins.
    """
    mapping = _PROMPT_SID_CACHE.get(game_key)
    if mapping is None:
        with open(f"problemsets/{game_key}.json", "r", encoding="utf8") as f:
            # presumably {session id: prompt}; verify against the problemset files
            _sid_prompt_dict = json.load(f)
        mapping = {v: k for k, v in _sid_prompt_dict.items()}
        _PROMPT_SID_CACHE[game_key] = mapping
    return mapping


def get_buffered_response(texts, game_name, difficulty_level, turn):
    """Look up a previously recorded response for a first-turn prompt.

    Args:
        texts: Sequence whose first element is the prompt text; it is used to
            find the session id in the problemset file.
        game_name: Game name, converted with ``game_filename`` to build the
            problemset / result key ``<game_filename>_<difficulty_level>``.
        difficulty_level: Difficulty level appended to that key.
        turn: Turn number; only turns up to 1 are served from the buffer.

    Returns:
        The ``response`` value stored in ``df`` for ``(session, turn)``, or
        ``None`` when ``turn > 1`` or no matching row exists.

    Raises:
        KeyError: If ``texts[0]`` is not a prompt in the problemset file.
        FileNotFoundError: If the problemset file does not exist.
    """
    if turn > 1:
        return None

    game_key = f"{game_filename(game_name)}_{difficulty_level}"
    cur_df = df.loc[df.game == game_key].set_index(["session", "turn"])
    sid = _prompt_to_sid(game_key)[texts[0]]

    # NOTE(review): JSON object keys are always strings; if `df.session` was
    # parsed as integers this lookup would always miss and return None - confirm.
    # NOTE(review): if (session, turn) is duplicated in `df`, `.response` is a
    # Series rather than a single value.
    try:
        return cur_df.loc[(sid, turn)].response
    except KeyError:
        return None
|
|
|
|
|
|
|
# Run the selected slice of games for a single turn, passing the buffered
# response lookup and the Gemma post-processing hook; `fp_out` is the first
# argument (presumably the results path - confirm against
# agents.run_with_agent).
run_with_agent(
    fp_out,
    get_buffered_response,
    get_postprocess=gemma_postproc,
    game_names_list=GAME_NAMES[st:ed],
    n_turns=1,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|