Spaces:
Runtime error
Runtime error
"second commit"
Browse files- __pycache__/dataset.cpython-39.pyc +0 -0
- __pycache__/model.cpython-39.pyc +0 -0
- __pycache__/utils.cpython-39.pyc +0 -0
- app.py +124 -4
- cache/intent/embeddings.pt +3 -0
- cache/intent/intent2idx.json +152 -0
- cache/intent/vocab.pkl +3 -0
- cache/slot/embeddings.pt +3 -0
- cache/slot/tag2idx.json +11 -0
- cache/slot/vocab.pkl +3 -0
- ckpt/intent/model_checkpoint.pth +3 -0
- data/intent/eval.json +0 -0
- data/intent/test.json +0 -0
- data/intent/train.json +0 -0
- data/slot/eval.json +0 -0
- data/slot/test.json +0 -0
- data/slot/train.json +0 -0
- dataset.py +74 -0
- environment.yml +10 -0
- model.py +76 -0
- utils.py +42 -0
__pycache__/dataset.cpython-39.pyc
ADDED
Binary file (2.76 kB). View file
|
|
__pycache__/model.cpython-39.pyc
ADDED
Binary file (2.21 kB). View file
|
|
__pycache__/utils.cpython-39.pyc
ADDED
Binary file (2.72 kB). View file
|
|
app.py
CHANGED
@@ -1,7 +1,127 @@
|
|
1 |
import gradio as gr
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from typing import Dict, List
|
3 |
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.optim as optim
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
|
9 |
+
import json
|
10 |
+
import pickle
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
from dataset import SeqClsDataset
|
14 |
+
from utils import Vocab
|
15 |
+
from model import SeqClassifier
|
16 |
+
|
17 |
+
import ipdb
|
18 |
+
|
19 |
+
max_len = 128
|
20 |
+
hidden_size = 256
|
21 |
+
num_layers = 2
|
22 |
+
dropout = 0.1
|
23 |
+
bidirectional = True
|
24 |
+
lr = 1e-3
|
25 |
+
batch_size = 64
|
26 |
+
num_epoch = 5
|
27 |
+
|
28 |
+
|
29 |
+
TRAIN = "train"
|
30 |
+
DEV = "eval"
|
31 |
+
TEST = "test"
|
32 |
+
SPLITS = [TRAIN, DEV, TEST]
|
33 |
+
|
34 |
+
device = "cpu"
|
35 |
+
data_dir = Path("./data/intent/")
|
36 |
+
ckpt_dir = Path("./ckpt/intent/")
|
37 |
+
cache_dir = Path("./cache/intent/")
|
38 |
+
# Before executing, place intent2idx.json, embeddings.pt, vocab.pkl, and utils.py in /content
|
39 |
+
with open(cache_dir / "vocab.pkl", "rb") as f:
|
40 |
+
vocab: Vocab = pickle.load(f)
|
41 |
+
intent_idx_path = cache_dir / "intent2idx.json"
|
42 |
+
intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())
|
43 |
+
data_paths = {split: data_dir / f"{split}.json" for split in SPLITS}
|
44 |
+
data = {split: json.loads(path.read_text()) for split, path in data_paths.items()}
|
45 |
+
datasets: Dict[str, SeqClsDataset] = {
|
46 |
+
split: SeqClsDataset(split_data, vocab, intent2idx, max_len)
|
47 |
+
for split, split_data in data.items()
|
48 |
+
}
|
49 |
+
#ipdb.set_trace()
|
50 |
+
test_loader = DataLoader(datasets['test'], batch_size=batch_size, shuffle=False)
|
51 |
+
embeddings = torch.load(cache_dir / "embeddings.pt")
|
52 |
+
embeddings.to(device)
|
53 |
+
|
54 |
+
# Load the best model after training
|
55 |
+
# Initialize a new model with the same architecture
|
56 |
+
best_model = SeqClassifier(
|
57 |
+
embeddings=embeddings,
|
58 |
+
hidden_size=hidden_size,
|
59 |
+
num_layers=num_layers,
|
60 |
+
dropout=dropout,
|
61 |
+
bidirectional=bidirectional,
|
62 |
+
num_class=len(intent2idx)
|
63 |
+
).to(device)
|
64 |
+
|
65 |
+
# Define the path to the checkpoint file
|
66 |
+
ckpt_path = ckpt_dir / "model_checkpoint.pth"
|
67 |
+
|
68 |
+
# Load the model's state_dict and optimizer's state_dict from the checkpoint
|
69 |
+
checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
|
70 |
+
|
71 |
+
# Load the model's weights
|
72 |
+
best_model.load_state_dict(checkpoint['model_state_dict']).to(device)
|
73 |
+
|
74 |
+
# Reinitialize the optimizer with the model's parameters and load its state
|
75 |
+
'''weight_decay = 1e-5
|
76 |
+
optimizer = optim.Adam(best_model.parameters(), lr=lr, weight_decay=weight_decay)
|
77 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])'''
|
78 |
+
|
79 |
+
# Retrieve the epoch number from the checkpoint
|
80 |
+
epoch = checkpoint['epoch']
|
81 |
+
|
82 |
+
# Set the best model to evaluation mode
|
83 |
+
best_model.eval()
|
84 |
+
|
85 |
+
|
86 |
+
dic_intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())
|
87 |
+
dic_idx2label = {idx: intent for intent, idx in dic_intent2idx.items()}
|
88 |
+
|
89 |
+
def Tidx2label(idx: int):
|
90 |
+
return dic_idx2label[idx]
|
91 |
+
|
92 |
+
with open(cache_dir / "vocab.pkl", "rb") as f:
|
93 |
+
vocab: Vocab = pickle.load(f)
|
94 |
+
|
95 |
+
# 把句子做成embeddings的索引
|
96 |
+
def collate_fn(texts: str) -> torch.tensor:
|
97 |
+
# 提取所有樣本的文本數據和標籤數據
|
98 |
+
texts = texts.split()
|
99 |
+
|
100 |
+
# 使用 vocab 將文本數據轉換為整數索引序列,並指定最大長度
|
101 |
+
encoded_texts = vocab.encode_batch([[text for text in texts]], to_len=max_len)
|
102 |
+
|
103 |
+
# 將整數索引序列轉換為 PyTorch 張量
|
104 |
+
encoded_text = torch.tensor(encoded_texts)
|
105 |
+
return encoded_text
|
106 |
+
|
107 |
+
|
108 |
+
def classify(text):
|
109 |
+
encoded_text = collate_fn(text).to(device)
|
110 |
+
output = best_model(encoded_text[0])
|
111 |
+
Predicted_class = torch.argmax(output).item()
|
112 |
+
prediction = Tidx2label(Predicted_class)
|
113 |
+
return prediction
|
114 |
+
|
115 |
+
demo = gr.Interface(
|
116 |
+
fn=classify,
|
117 |
+
inputs=gr.Textbox(placeholder="請輸入一段文字..."),
|
118 |
+
outputs="label",
|
119 |
+
interpretation="default",
|
120 |
+
examples=[
|
121 |
+
["Take me to church"],
|
122 |
+
["tell me what to call you"],
|
123 |
+
["could you be a person"]
|
124 |
+
]
|
125 |
+
)
|
126 |
+
|
127 |
+
demo.launch()
|
cache/intent/embeddings.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f48c2a4bb711ddd28a95f849b676ab6c76a4aeff3ba01976ccea97a4808ce790
|
3 |
+
size 7789931
|
cache/intent/intent2idx.json
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"last_maintenance": 0,
|
3 |
+
"car_rental": 1,
|
4 |
+
"transactions": 2,
|
5 |
+
"user_name": 3,
|
6 |
+
"credit_limit": 4,
|
7 |
+
"date": 5,
|
8 |
+
"greeting": 6,
|
9 |
+
"international_fees": 7,
|
10 |
+
"gas": 8,
|
11 |
+
"calculator": 9,
|
12 |
+
"redeem_rewards": 10,
|
13 |
+
"change_ai_name": 11,
|
14 |
+
"alarm": 12,
|
15 |
+
"pin_change": 13,
|
16 |
+
"update_playlist": 14,
|
17 |
+
"what_can_i_ask_you": 15,
|
18 |
+
"translate": 16,
|
19 |
+
"change_accent": 17,
|
20 |
+
"text": 18,
|
21 |
+
"thank_you": 19,
|
22 |
+
"where_are_you_from": 20,
|
23 |
+
"goodbye": 21,
|
24 |
+
"recipe": 22,
|
25 |
+
"interest_rate": 23,
|
26 |
+
"ingredients_list": 24,
|
27 |
+
"tire_pressure": 25,
|
28 |
+
"definition": 26,
|
29 |
+
"who_do_you_work_for": 27,
|
30 |
+
"todo_list": 28,
|
31 |
+
"improve_credit_score": 29,
|
32 |
+
"meaning_of_life": 30,
|
33 |
+
"change_speed": 31,
|
34 |
+
"exchange_rate": 32,
|
35 |
+
"next_holiday": 33,
|
36 |
+
"make_call": 34,
|
37 |
+
"insurance_change": 35,
|
38 |
+
"spending_history": 36,
|
39 |
+
"meal_suggestion": 37,
|
40 |
+
"fun_fact": 38,
|
41 |
+
"restaurant_suggestion": 39,
|
42 |
+
"tire_change": 40,
|
43 |
+
"calendar_update": 41,
|
44 |
+
"confirm_reservation": 42,
|
45 |
+
"next_song": 43,
|
46 |
+
"are_you_a_bot": 44,
|
47 |
+
"yes": 45,
|
48 |
+
"find_phone": 46,
|
49 |
+
"cancel_reservation": 47,
|
50 |
+
"what_is_your_name": 48,
|
51 |
+
"bill_balance": 49,
|
52 |
+
"direct_deposit": 50,
|
53 |
+
"flight_status": 51,
|
54 |
+
"order_status": 52,
|
55 |
+
"maybe": 53,
|
56 |
+
"transfer": 54,
|
57 |
+
"freeze_account": 55,
|
58 |
+
"cancel": 56,
|
59 |
+
"shopping_list": 57,
|
60 |
+
"measurement_conversion": 58,
|
61 |
+
"jump_start": 59,
|
62 |
+
"international_visa": 60,
|
63 |
+
"travel_alert": 61,
|
64 |
+
"oil_change_when": 62,
|
65 |
+
"accept_reservations": 63,
|
66 |
+
"report_lost_card": 64,
|
67 |
+
"pto_request_status": 65,
|
68 |
+
"repeat": 66,
|
69 |
+
"directions": 67,
|
70 |
+
"payday": 68,
|
71 |
+
"smart_home": 69,
|
72 |
+
"damaged_card": 70,
|
73 |
+
"lost_luggage": 71,
|
74 |
+
"carry_on": 72,
|
75 |
+
"insurance": 73,
|
76 |
+
"what_song": 74,
|
77 |
+
"current_location": 75,
|
78 |
+
"ingredient_substitution": 76,
|
79 |
+
"order": 77,
|
80 |
+
"todo_list_update": 78,
|
81 |
+
"reset_settings": 79,
|
82 |
+
"replacement_card_duration": 80,
|
83 |
+
"order_checks": 81,
|
84 |
+
"roll_dice": 82,
|
85 |
+
"new_card": 83,
|
86 |
+
"vaccines": 84,
|
87 |
+
"pto_used": 85,
|
88 |
+
"time": 86,
|
89 |
+
"how_old_are_you": 87,
|
90 |
+
"account_blocked": 88,
|
91 |
+
"card_declined": 89,
|
92 |
+
"who_made_you": 90,
|
93 |
+
"shopping_list_update": 91,
|
94 |
+
"rewards_balance": 92,
|
95 |
+
"restaurant_reviews": 93,
|
96 |
+
"change_user_name": 94,
|
97 |
+
"spelling": 95,
|
98 |
+
"nutrition_info": 96,
|
99 |
+
"restaurant_reservation": 97,
|
100 |
+
"timer": 98,
|
101 |
+
"cook_time": 99,
|
102 |
+
"whisper_mode": 100,
|
103 |
+
"travel_notification": 101,
|
104 |
+
"routing": 102,
|
105 |
+
"book_hotel": 103,
|
106 |
+
"apr": 104,
|
107 |
+
"w2": 105,
|
108 |
+
"gas_type": 106,
|
109 |
+
"schedule_meeting": 107,
|
110 |
+
"meeting_schedule": 108,
|
111 |
+
"reminder": 109,
|
112 |
+
"income": 110,
|
113 |
+
"plug_type": 111,
|
114 |
+
"what_are_your_hobbies": 112,
|
115 |
+
"schedule_maintenance": 113,
|
116 |
+
"report_fraud": 114,
|
117 |
+
"food_last": 115,
|
118 |
+
"traffic": 116,
|
119 |
+
"no": 117,
|
120 |
+
"reminder_update": 118,
|
121 |
+
"book_flight": 119,
|
122 |
+
"mpg": 120,
|
123 |
+
"pto_balance": 121,
|
124 |
+
"tell_joke": 122,
|
125 |
+
"calories": 123,
|
126 |
+
"balance": 124,
|
127 |
+
"rollover_401k": 125,
|
128 |
+
"weather": 126,
|
129 |
+
"change_language": 127,
|
130 |
+
"distance": 128,
|
131 |
+
"play_music": 129,
|
132 |
+
"min_payment": 130,
|
133 |
+
"sync_device": 131,
|
134 |
+
"pay_bill": 132,
|
135 |
+
"taxes": 133,
|
136 |
+
"share_location": 134,
|
137 |
+
"bill_due": 135,
|
138 |
+
"pto_request": 136,
|
139 |
+
"calendar": 137,
|
140 |
+
"uber": 138,
|
141 |
+
"do_you_have_pets": 139,
|
142 |
+
"change_volume": 140,
|
143 |
+
"timezone": 141,
|
144 |
+
"application_status": 142,
|
145 |
+
"flip_coin": 143,
|
146 |
+
"credit_score": 144,
|
147 |
+
"oil_change_how": 145,
|
148 |
+
"expiration_date": 146,
|
149 |
+
"credit_limit_change": 147,
|
150 |
+
"how_busy": 148,
|
151 |
+
"travel_suggestion": 149
|
152 |
+
}
|
cache/intent/vocab.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9d4fa520420cf60655dd67114826cef0f8be23bc7ca07cdb3c072f2a400e242b
|
3 |
+
size 78973
|
cache/slot/embeddings.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:faba49b73dfdd2a98dbbfe7b53eed50b8edd9df716169e8f837558c5e24c42bf
|
3 |
+
size 4941099
|
cache/slot/tag2idx.json
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"O": 0,
|
3 |
+
"B-date": 1,
|
4 |
+
"I-time": 2,
|
5 |
+
"B-time": 3,
|
6 |
+
"B-last_name": 4,
|
7 |
+
"I-people": 5,
|
8 |
+
"B-people": 6,
|
9 |
+
"I-date": 7,
|
10 |
+
"B-first_name": 8
|
11 |
+
}
|
cache/slot/vocab.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c711af8ba9cba928df00a20913b2bcdd0738ab3b9210b4b9f10d0ff9dcf27f16
|
3 |
+
size 49861
|
ckpt/intent/model_checkpoint.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:65fdb8e191b37fc6866acd1699f8978736bfb975b176d5ee0464f43301d928e8
|
3 |
+
size 56947301
|
data/intent/eval.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/intent/test.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/intent/train.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/slot/eval.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/slot/test.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/slot/train.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
dataset.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
|
7 |
+
from utils import Vocab
|
8 |
+
|
9 |
+
|
10 |
+
class SeqClsDataset(Dataset):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
data: List[Dict],
|
14 |
+
vocab: Vocab,
|
15 |
+
label_mapping: Dict[str, int],
|
16 |
+
max_len: int,
|
17 |
+
):
|
18 |
+
self.data = data
|
19 |
+
self.vocab = vocab
|
20 |
+
self.label_mapping = label_mapping
|
21 |
+
self._idx2label = {idx: intent for intent, idx in self.label_mapping.items()}
|
22 |
+
self.max_len = max_len
|
23 |
+
|
24 |
+
def __len__(self) -> int:
|
25 |
+
return len(self.data)
|
26 |
+
|
27 |
+
def __getitem__(self, index) -> Dict:
|
28 |
+
instance = self.data[index]
|
29 |
+
return instance
|
30 |
+
|
31 |
+
@property
|
32 |
+
def num_classes(self) -> int:
|
33 |
+
return len(self.label_mapping)
|
34 |
+
|
35 |
+
def collate_fn(self, samples: List[Dict]) -> Dict:
|
36 |
+
# sample就是batch data
|
37 |
+
# collate_fn幫你把batch data編碼成詞彙的索引
|
38 |
+
# batch[0] = {'text': '~', 'intent': '~', 'id': 'train-0'}
|
39 |
+
|
40 |
+
# 提取所有樣本的文本數據和標籤數據
|
41 |
+
texts = samples["text"]
|
42 |
+
labels = samples["intent"]
|
43 |
+
|
44 |
+
# 使用 vocab 將文本數據轉換為整數索引序列,並指定最大長度
|
45 |
+
encoded_texts = self.vocab.encode_batch([text.split() for text in texts], to_len=self.max_len)
|
46 |
+
|
47 |
+
# 將標籤數據轉換為整數索引序列
|
48 |
+
encoded_labels = [self.label_mapping[label] for label in labels]
|
49 |
+
|
50 |
+
# 將整數索引序列轉換為 PyTorch 張量
|
51 |
+
encoded_text = torch.tensor(encoded_texts)
|
52 |
+
encoded_label = torch.tensor(encoded_labels)
|
53 |
+
|
54 |
+
# 創建批次數據字典
|
55 |
+
batch_data = {
|
56 |
+
"encoded_text": encoded_text,
|
57 |
+
"encoded_label": encoded_label
|
58 |
+
}
|
59 |
+
|
60 |
+
return batch_data
|
61 |
+
|
62 |
+
def label2idx(self, label: str):
|
63 |
+
return self.label_mapping[label]
|
64 |
+
|
65 |
+
def idx2label(self, idx: int):
|
66 |
+
return self._idx2label[idx]
|
67 |
+
|
68 |
+
|
69 |
+
class SeqTaggingClsDataset(SeqClsDataset):
|
70 |
+
ignore_idx = -100
|
71 |
+
|
72 |
+
def collate_fn(self, samples):
|
73 |
+
# TODO: implement collate_fn
|
74 |
+
raise NotImplementedError
|
environment.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: adl-hw1
|
2 |
+
channels:
|
3 |
+
- defaults
|
4 |
+
dependencies:
|
5 |
+
- python=3.9
|
6 |
+
- cudatoolkit=10.2
|
7 |
+
- cudnn=7.6
|
8 |
+
- pip
|
9 |
+
- pip:
|
10 |
+
- pip-tools
|
model.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
# Set device
|
7 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
8 |
+
|
9 |
+
class SeqClassifier(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
embeddings: torch.tensor,
|
13 |
+
hidden_size: int,
|
14 |
+
num_layers: int,
|
15 |
+
dropout: float,
|
16 |
+
bidirectional: bool,
|
17 |
+
num_class: int,
|
18 |
+
) -> None:
|
19 |
+
super(SeqClassifier, self).__init__()
|
20 |
+
self.embed = nn.Embedding.from_pretrained(embeddings, freeze=False)
|
21 |
+
self.hidden_size=hidden_size
|
22 |
+
self.num_layers=num_layers
|
23 |
+
self.dropout=dropout
|
24 |
+
self.bidirectional=bidirectional
|
25 |
+
self.num_class=num_class
|
26 |
+
|
27 |
+
# model architecture
|
28 |
+
self.rnn = nn.LSTM(
|
29 |
+
input_size=embeddings.size(1),
|
30 |
+
hidden_size=hidden_size,
|
31 |
+
num_layers=num_layers,
|
32 |
+
dropout=dropout,
|
33 |
+
bidirectional=bidirectional,
|
34 |
+
batch_first=True
|
35 |
+
)
|
36 |
+
self.dropout_layer = nn.Dropout(p=self.dropout)
|
37 |
+
self.fc = nn.Linear(self.encoder_output_size, num_class)
|
38 |
+
|
39 |
+
@property
|
40 |
+
def encoder_output_size(self) -> int:
|
41 |
+
# calculate the output dimension of rnn
|
42 |
+
if self.bidirectional:
|
43 |
+
return self.hidden_size * 2
|
44 |
+
else:
|
45 |
+
return self.hidden_size
|
46 |
+
|
47 |
+
def forward(self, batch) -> torch.Tensor:
|
48 |
+
# 將輸入嵌入到詞嵌入空間,就是把詞索引換成詞向量
|
49 |
+
embedded = self.embed(batch)
|
50 |
+
|
51 |
+
# 過 LSTM 層
|
52 |
+
rnn_output, _ = self.rnn(embedded)
|
53 |
+
rnn_output = self.dropout_layer(rnn_output)
|
54 |
+
|
55 |
+
if not self.training:
|
56 |
+
last_hidden_state_forward = rnn_output[ -1, :self.hidden_size] # 正向方向的隐藏状态
|
57 |
+
last_hidden_state_backward = rnn_output[ 0, self.hidden_size:] # 反向方向的隐藏状态
|
58 |
+
combined_hidden_state = torch.cat((last_hidden_state_forward, last_hidden_state_backward), dim=0)
|
59 |
+
|
60 |
+
# 通過全連接層
|
61 |
+
logits = self.fc(combined_hidden_state)
|
62 |
+
return logits # 返回預測結果
|
63 |
+
|
64 |
+
last_hidden_state_forward = rnn_output[:, -1, :self.hidden_size] # 正向方向的隐藏状态
|
65 |
+
last_hidden_state_backward = rnn_output[:, 0, self.hidden_size:] # 反向方向的隐藏状态
|
66 |
+
combined_hidden_state = torch.cat((last_hidden_state_forward, last_hidden_state_backward), dim=1)
|
67 |
+
|
68 |
+
# 通過全連接層
|
69 |
+
logits = self.fc(combined_hidden_state)
|
70 |
+
return logits # 返回預測結果
|
71 |
+
|
72 |
+
|
73 |
+
class SeqTagger(SeqClassifier):
|
74 |
+
def forward(self, batch) -> Dict[str, torch.Tensor]:
|
75 |
+
# TODO: implement model forward
|
76 |
+
raise NotImplementedError
|
utils.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Iterable, List
|
2 |
+
|
3 |
+
class Vocab:
|
4 |
+
PAD = "[PAD]"
|
5 |
+
UNK = "[UNK]"
|
6 |
+
|
7 |
+
def __init__(self, vocab: Iterable[str]) -> None:
|
8 |
+
self.token2idx = {
|
9 |
+
Vocab.PAD: 0,
|
10 |
+
Vocab.UNK: 1,
|
11 |
+
**{token: i for i, token in enumerate(vocab, 2)},
|
12 |
+
}
|
13 |
+
|
14 |
+
@property
|
15 |
+
def pad_id(self) -> int:
|
16 |
+
return self.token2idx[Vocab.PAD]
|
17 |
+
|
18 |
+
@property
|
19 |
+
def unk_id(self) -> int:
|
20 |
+
return self.token2idx[Vocab.UNK]
|
21 |
+
|
22 |
+
@property
|
23 |
+
def tokens(self) -> List[str]:
|
24 |
+
return list(self.token2idx.keys())
|
25 |
+
|
26 |
+
def token_to_id(self, token: str) -> int:
|
27 |
+
return self.token2idx.get(token, self.unk_id)
|
28 |
+
|
29 |
+
def encode(self, tokens: List[str]) -> List[int]:
|
30 |
+
return [self.token_to_id(token) for token in tokens]
|
31 |
+
|
32 |
+
def encode_batch(
|
33 |
+
self, batch_tokens: List[List[str]], to_len: int = None
|
34 |
+
) -> List[List[int]]:
|
35 |
+
batch_ids = [self.encode(tokens) for tokens in batch_tokens]
|
36 |
+
to_len = max(len(ids) for ids in batch_ids) if to_len is None else to_len
|
37 |
+
padded_ids = pad_to_len(batch_ids, to_len, self.pad_id)
|
38 |
+
return padded_ids
|
39 |
+
|
40 |
+
def pad_to_len(seqs: List[List[int]], to_len: int, padding: int) -> List[List[int]]:
|
41 |
+
paddeds = [seq[:to_len] + [padding] * max(0, to_len - len(seq)) for seq in seqs]
|
42 |
+
return paddeds
|