File size: 3,297 Bytes
c9d7b4f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
# %%
import json
import pickle
import re
from pathlib import Path
# %%
def load_pickle(fp):
    """Yield each object that was pickled sequentially into the file *fp*.

    Repeatedly calls ``pickle.load`` on the open stream and stops cleanly
    when the stream is exhausted (signalled by ``EOFError``).

    NOTE: ``pickle`` executes arbitrary code on load — only use on trusted,
    locally produced result files (as this script does).
    """
    with open(fp, "rb") as stream:
        while True:
            try:
                yield pickle.load(stream)
            except EOFError:
                return
# %%
fd = Path("model_outputs")
# %%
# %%
# %%
# # %%
# # concat pickle results (1/22)
# list(fd.glob("results_gemma_*"))[0]
#
# # %%
# fps = sorted(fd.glob("results_gemma_*"))
# all_responses = dict()
# errors = set()
# for fp in fps:
# responses = list(load_pickle(str(fp)))
# print(fp.name, len(responses), responses[0][0], responses[-1][0])
# for r in responses:
# if r[-1]:
# errors.add((r[0], str(r[-1])))
# all_responses.setdefault(r[:2], set())
# all_responses[r[:2]].add(r)
# errors = sorted(errors)
#
# # %%
# assert all(len(v) == 1 for v in all_responses.values()), f"Duplicated response(s) found"
#
# # %%
# duplicated = {k: v for k, v in all_responses.items() if len(v) > 1}
#
# # %%
# concatenated = [list(v)[0] for v in all_responses.values()]
#
# # %%
# with open(fd / "gemma2_9b_results_depre_250122/results_gemma-2-9b-it.single_turn.jsonl", "w", encoding="utf8") as o:
# for i in concatenated:
# json.dump({
# "game": i[0],
# "session": i[1],
# "turn": 1,
# "response": i[2],
# "solved": i[3][0],
# "val_msg": i[3][1],
# "error": repr(i[4]) if i[4] else i[4],
# }, o, ensure_ascii=False)
# o.write("\n")
# %%
# %%
# %%
# %%
# %%
# %%
# Rerun gemma, resolving errors
# %%
import os
import json
import pandas as pd
# %%
# Hard-code the rerun window via environment variables; consumed just below
# as slice bounds into GAME_NAMES (i.e. only game index 7, the half-open
# range [7, 8)).
os.environ["TG_GAME_ST"] = "7"
os.environ["TG_GAME_ED"] = "8"
# %%
# Optional game-slice bounds taken from the environment; unset means
# "no bound" (full slice).
_raw_st = os.getenv("TG_GAME_ST", None)
_raw_ed = os.getenv("TG_GAME_ED", None)
st, ed = (None if raw is None else int(raw) for raw in (_raw_st, _raw_ed))
# Suffix the output file with the start index when a window is active, so
# partial reruns don't clobber each other.
_suffix = "" if st is None else f".{st}"
fp_out = f"model_outputs/results_gemma-2-9b-it{_suffix}.jsonl"
# %%
from tqdm import tqdm
from itertools import product
from transformers import AutoTokenizer, AutoModelForCausalLM
from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS, game_filename, _game_class_from_name
os.environ.setdefault("TEXTGAMES_OUTPUT_DIR", "user_outputs")
# %%
# Load the previously concatenated single-turn gemma results into a DataFrame
# (one row per game/session/turn record); reused below for buffered lookups.
with open(fd / "gemma2_9b_results_depre_250122/results_gemma-2-9b-it.single_turn.jsonl", "r", encoding="utf-8") as f:
    df = pd.read_json(f, lines=True)
# %%
# Notebook-style inspection: show the loaded column names.
df.columns
# %%
from agents import run_with_agent
from agents.gemma_2_9b_it import gemma_postproc
# %%
# Memoized {prompt_text: session_id} mapping per problemset key, so each
# problemset JSON is parsed once instead of on every call.
_prompt_sid_cache = {}


def get_buffered_response(texts, game_name, difficulty_level, turn):
    """Return the previously recorded turn-1 response for this prompt, if any.

    Parameters
    ----------
    texts : list[str]
        Conversation texts; ``texts[0]`` is assumed to be the problem prompt
        used to look up the session id — TODO confirm against the caller.
    game_name, difficulty_level
        Identify both the problemset file and the matching rows of the
        module-level ``df``.
    turn : int
        Only turn 1 is buffered; later turns return ``None`` so the agent
        generates a fresh response.

    Returns
    -------
    str | None
        The buffered response, or ``None`` when no buffered row exists.
    """
    if turn > 1:
        return None
    key = f"{game_filename(game_name)}_{difficulty_level}"
    # Rows for this game/level, indexed for direct (session, turn) lookup.
    cur_df = df.loc[df.game == key].set_index(["session", "turn"])
    # Parse the problemset JSON once per key and memoize the inverted mapping.
    if key not in _prompt_sid_cache:
        with open(f"problemsets/{key}.json", "r", encoding="utf8") as f:
            _sid_prompt_dict = json.load(f)
        _prompt_sid_cache[key] = {v: k for k, v in _sid_prompt_dict.items()}
    sid = _prompt_sid_cache[key][texts[0]]
    try:
        return cur_df.loc[(sid, turn)].response
    except KeyError:
        # No buffered row for this session/turn — let the agent regenerate.
        return None
# %%
run_with_agent(fp_out, get_buffered_response, get_postprocess=gemma_postproc, game_names_list=GAME_NAMES[st:ed], n_turns=1)
# %%
# %%
# type(cur_df.loc[(sid, 1)].response)
# %%
# %%
# %%
# %%
# %%
# %%
# %%
|