File size: 4,771 Bytes
5953ef9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import os
import re
import statistics as sts
from collections import defaultdict
from pathlib import Path
from rex.utils.dict import get_dict_content
from rex.utils.io import load_json
from rich.console import Console
from rich.table import Table
# Root directory holding one sub-directory per few-shot run.
inputs_dir = Path("mirror_fewshot_outputs")
# Run directories are matched against the active pattern below; the groups are
# (task, seed, shot). Swap in the commented pattern to summarise the other
# experiment variant.
# regex = re.compile(r"Mirror_SingleTask_(.*?)_seed(\d+)_(\d+)shot")
regex = re.compile(r"Mirror_wPT_woInst_(.*?)_seed(\d+)_(\d+)shot")

# task -> shot -> list of per-seed scores (one entry per matching run dir).
# `shot` is kept as the string captured by the regex; it is only converted to
# int later, when ordering columns.
results = defaultdict(lambda: defaultdict(list))

# Iterate in sorted order: the final table's row order follows the insertion
# order of `results`, and the raw directory listing order is arbitrary, so
# sorting makes the report reproducible across machines/filesystems.
for dpath in sorted(inputs_dir.iterdir()):
    re_matched = regex.match(dpath.name)
    if not (dpath.is_dir() and re_matched):
        continue
    task, seed, shot = re_matched.groups()
    metrics = load_json(dpath / "measures" / "test.final.json")
    if "Ent_" in task:
        results[task][shot].append(
            get_dict_content(metrics, "metrics.ent.micro.f1")
        )
    elif "Rel_" in task or "ABSA_" in task:
        results[task][shot].append(
            get_dict_content(metrics, "metrics.rel.rel.micro.f1")
        )
    elif "Event_" in task:
        # Event tasks contribute two rows: trigger and argument classification.
        results[task + "_Trigger"][shot].append(
            get_dict_content(metrics, "metrics.event.trigger_cls.f1")
        )
        results[task + "_Arg"][shot].append(
            get_dict_content(metrics, "metrics.event.arg_cls.f1")
        )
    else:
        raise RuntimeError(
            f"Unrecognised task type {task!r} in {dpath}: expected a name "
            "containing 'Ent_', 'Rel_', 'ABSA_' or 'Event_'"
        )
def _format_seed_scores(seeds):
    """Format one table cell from a list of per-seed scores.

    Scores are scaled by 100 for display (presumably they are fractions in
    [0, 1] -- confirm against the metric files). Returns ``"mean±stdev"`` with
    two decimals, or only the mean when a single seed is available, because
    ``statistics.stdev`` needs at least two data points.
    """
    mean = 100 * sts.mean(seeds)
    if len(seeds) < 2:
        return f"{mean:.2f}"
    return f"{mean:.2f}±{100 * sts.stdev(seeds):.2f}"


# One column per shot count present in the data, ordered numerically, so the
# header always lines up with the values even if some task lacks a shot.
all_shots = sorted(
    {shot for per_shot in results.values() for shot in per_shot}, key=int
)

table = Table(title="Few-shot results")
table.add_column("Task", justify="center")
for shot in all_shots:
    table.add_column(f"{int(shot)}-shot", justify="right")
table.add_column("Avg.", justify="right")

for task, per_shot in results.items():
    row = []
    all_seeds = []
    for shot in all_shots:
        # .get() avoids inserting empty lists into the defaultdict.
        seeds = per_shot.get(shot)
        if not seeds:
            row.append("-")  # no run for this task at this shot count
            continue
        all_seeds.extend(seeds)
        row.append(_format_seed_scores(seeds))
    # "Avg." pools every seed of every shot for the task.
    row.append(f"{100 * sts.mean(all_seeds):.2f}")
    table.add_row(task, *row)

console = Console()
console.print(table)
"""
Few-shot results wPT wInst
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃        Task         ┃      1-shot ┃      5-shot ┃    10-shot ┃  Avg. ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│     Ent_CoNLL03     │  77.50±1.64 │  82.73±2.29 │ 84.48±1.62 │ 81.57 │
│     Rel_CoNLL04     │ 34.66±10.52 │  52.23±3.16 │ 58.68±1.77 │ 48.52 │
│ Event_ACE05_Trigger │  49.50±3.59 │ 65.61±19.29 │ 60.68±2.45 │ 58.60 │
│   Event_ACE05_Arg   │  23.46±1.66 │ 48.32±28.91 │ 41.90±1.95 │ 37.89 │
│     ABSA_16res      │  67.06±0.56 │ 73.51±14.75 │ 68.70±1.46 │ 69.76 │
└─────────────────────┴─────────────┴─────────────┴────────────┴───────┘
Few-shot results wPT woInst
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃        Task         ┃      1-shot ┃     5-shot ┃    10-shot ┃  Avg. ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│     Ent_CoNLL03     │  76.33±1.74 │ 82.50±1.87 │ 84.47±1.18 │ 81.10 │
│ woInst_Rel_CoNLL04  │  34.86±6.20 │ 48.00±4.44 │ 55.65±2.53 │ 46.17 │
│     Rel_CoNLL04     │ 26.83±15.22 │ 47.39±3.60 │ 55.38±2.41 │ 43.20 │
│ Event_ACE05_Trigger │  46.60±1.09 │ 57.21±3.51 │ 59.67±3.20 │ 54.49 │
│   Event_ACE05_Arg   │  21.60±3.61 │ 34.43±3.63 │ 39.62±2.60 │ 31.88 │
│     ABSA_16res      │  8.10±18.11 │ 52.73±5.52 │ 57.32±1.73 │ 39.38 │
└─────────────────────┴─────────────┴────────────┴────────────┴───────┘
"""
|