xjlulu committed
Commit 530d98b
1 Parent(s): 4ecf8d9

"second commit"

__pycache__/dataset.cpython-39.pyc ADDED
Binary file (2.76 kB)

__pycache__/model.cpython-39.pyc ADDED
Binary file (2.21 kB)

__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.72 kB)
app.py CHANGED
@@ -1,7 +1,127 @@
  import gradio as gr
+ from typing import Dict, List
 
- def greet(name):
-     return "Hello " + name + "!!"
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader
 
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import json
+ import pickle
+ from pathlib import Path
+
+ from dataset import SeqClsDataset
+ from utils import Vocab
+ from model import SeqClassifier
+
+ import ipdb
+
+ max_len = 128
+ hidden_size = 256
+ num_layers = 2
+ dropout = 0.1
+ bidirectional = True
+ lr = 1e-3
+ batch_size = 64
+ num_epoch = 5
+
+
+ TRAIN = "train"
+ DEV = "eval"
+ TEST = "test"
+ SPLITS = [TRAIN, DEV, TEST]
+
+ device = "cpu"
+ data_dir = Path("./data/intent/")
+ ckpt_dir = Path("./ckpt/intent/")
+ cache_dir = Path("./cache/intent/")
+ # Before executing, place intent2idx.json, embeddings.pt, vocab.pkl, and utils.py in /content
+ with open(cache_dir / "vocab.pkl", "rb") as f:
+     vocab: Vocab = pickle.load(f)
+ intent_idx_path = cache_dir / "intent2idx.json"
+ intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())
+ data_paths = {split: data_dir / f"{split}.json" for split in SPLITS}
+ data = {split: json.loads(path.read_text()) for split, path in data_paths.items()}
+ datasets: Dict[str, SeqClsDataset] = {
+     split: SeqClsDataset(split_data, vocab, intent2idx, max_len)
+     for split, split_data in data.items()
+ }
+ # ipdb.set_trace()
+ test_loader = DataLoader(datasets['test'], batch_size=batch_size, shuffle=False)
+ embeddings = torch.load(cache_dir / "embeddings.pt")
+ embeddings.to(device)
+
+ # Load the best model after training
+ # Initialize a new model with the same architecture
+ best_model = SeqClassifier(
+     embeddings=embeddings,
+     hidden_size=hidden_size,
+     num_layers=num_layers,
+     dropout=dropout,
+     bidirectional=bidirectional,
+     num_class=len(intent2idx)
+ ).to(device)
+
+ # Define the path to the checkpoint file
+ ckpt_path = ckpt_dir / "model_checkpoint.pth"
+
+ # Load the model's state_dict and optimizer's state_dict from the checkpoint
+ checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
+
+ # Load the model's weights (load_state_dict returns a key report, not the model, so it is not chained with .to())
+ best_model.load_state_dict(checkpoint['model_state_dict'])
+
+ # Reinitialize the optimizer with the model's parameters and load its state
+ '''weight_decay = 1e-5
+ optimizer = optim.Adam(best_model.parameters(), lr=lr, weight_decay=weight_decay)
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])'''
+
+ # Retrieve the epoch number from the checkpoint
+ epoch = checkpoint['epoch']
+
+ # Set the best model to evaluation mode
+ best_model.eval()
+
+
+ dic_intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())
+ dic_idx2label = {idx: intent for intent, idx in dic_intent2idx.items()}
+
+ def Tidx2label(idx: int):
+     return dic_idx2label[idx]
+
+ with open(cache_dir / "vocab.pkl", "rb") as f:
+     vocab: Vocab = pickle.load(f)
+
+ # Turn a sentence into embedding (vocabulary) indices
+ def collate_fn(texts: str) -> torch.Tensor:
+     # Split the input text into whitespace-separated tokens
+     texts = texts.split()
+
+     # Use vocab to convert the tokens into an integer index sequence, padded to max_len
+     encoded_texts = vocab.encode_batch([[text for text in texts]], to_len=max_len)
+
+     # Convert the integer index sequence into a PyTorch tensor
+     encoded_text = torch.tensor(encoded_texts)
+     return encoded_text
+
+
+ def classify(text):
+     encoded_text = collate_fn(text).to(device)
+     output = best_model(encoded_text[0])
+     Predicted_class = torch.argmax(output).item()
+     prediction = Tidx2label(Predicted_class)
+     return prediction
+
+ demo = gr.Interface(
+     fn=classify,
+     inputs=gr.Textbox(placeholder="請輸入一段文字..."),
+     outputs="label",
+     interpretation="default",
+     examples=[
+         ["Take me to church"],
+         ["tell me what to call you"],
+         ["could you be a person"]
+     ]
+ )
+
+ demo.launch()
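For readers following the change, the snippet below is a minimal sketch (not part of the commit) of how one sentence flows through the same preprocessing and prediction path that the new collate_fn and classify set up. It assumes the names defined in app.py above (vocab, max_len, best_model, Tidx2label) are in scope and a PyTorch version that accepts the unbatched LSTM input app.py already relies on:

    # Illustrative only: one sentence through the app.py prediction path.
    import torch
    tokens = "how do i freeze my account".split()       # whitespace tokenization, as in collate_fn
    ids = vocab.encode_batch([tokens], to_len=max_len)   # pad/truncate to max_len (128) with [PAD] ids
    x = torch.tensor(ids)                                # shape (1, 128)
    logits = best_model(x[0])                            # eval-mode forward takes a single unbatched sequence
    print(Tidx2label(torch.argmax(logits).item()))       # an intent label string, e.g. "freeze_account" if predicted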
cache/intent/embeddings.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f48c2a4bb711ddd28a95f849b676ab6c76a4aeff3ba01976ccea97a4808ce790
+ size 7789931
cache/intent/intent2idx.json ADDED
@@ -0,0 +1,152 @@
+ {
+   "last_maintenance": 0,
+   "car_rental": 1,
+   "transactions": 2,
+   "user_name": 3,
+   "credit_limit": 4,
+   "date": 5,
+   "greeting": 6,
+   "international_fees": 7,
+   "gas": 8,
+   "calculator": 9,
+   "redeem_rewards": 10,
+   "change_ai_name": 11,
+   "alarm": 12,
+   "pin_change": 13,
+   "update_playlist": 14,
+   "what_can_i_ask_you": 15,
+   "translate": 16,
+   "change_accent": 17,
+   "text": 18,
+   "thank_you": 19,
+   "where_are_you_from": 20,
+   "goodbye": 21,
+   "recipe": 22,
+   "interest_rate": 23,
+   "ingredients_list": 24,
+   "tire_pressure": 25,
+   "definition": 26,
+   "who_do_you_work_for": 27,
+   "todo_list": 28,
+   "improve_credit_score": 29,
+   "meaning_of_life": 30,
+   "change_speed": 31,
+   "exchange_rate": 32,
+   "next_holiday": 33,
+   "make_call": 34,
+   "insurance_change": 35,
+   "spending_history": 36,
+   "meal_suggestion": 37,
+   "fun_fact": 38,
+   "restaurant_suggestion": 39,
+   "tire_change": 40,
+   "calendar_update": 41,
+   "confirm_reservation": 42,
+   "next_song": 43,
+   "are_you_a_bot": 44,
+   "yes": 45,
+   "find_phone": 46,
+   "cancel_reservation": 47,
+   "what_is_your_name": 48,
+   "bill_balance": 49,
+   "direct_deposit": 50,
+   "flight_status": 51,
+   "order_status": 52,
+   "maybe": 53,
+   "transfer": 54,
+   "freeze_account": 55,
+   "cancel": 56,
+   "shopping_list": 57,
+   "measurement_conversion": 58,
+   "jump_start": 59,
+   "international_visa": 60,
+   "travel_alert": 61,
+   "oil_change_when": 62,
+   "accept_reservations": 63,
+   "report_lost_card": 64,
+   "pto_request_status": 65,
+   "repeat": 66,
+   "directions": 67,
+   "payday": 68,
+   "smart_home": 69,
+   "damaged_card": 70,
+   "lost_luggage": 71,
+   "carry_on": 72,
+   "insurance": 73,
+   "what_song": 74,
+   "current_location": 75,
+   "ingredient_substitution": 76,
+   "order": 77,
+   "todo_list_update": 78,
+   "reset_settings": 79,
+   "replacement_card_duration": 80,
+   "order_checks": 81,
+   "roll_dice": 82,
+   "new_card": 83,
+   "vaccines": 84,
+   "pto_used": 85,
+   "time": 86,
+   "how_old_are_you": 87,
+   "account_blocked": 88,
+   "card_declined": 89,
+   "who_made_you": 90,
+   "shopping_list_update": 91,
+   "rewards_balance": 92,
+   "restaurant_reviews": 93,
+   "change_user_name": 94,
+   "spelling": 95,
+   "nutrition_info": 96,
+   "restaurant_reservation": 97,
+   "timer": 98,
+   "cook_time": 99,
+   "whisper_mode": 100,
+   "travel_notification": 101,
+   "routing": 102,
+   "book_hotel": 103,
+   "apr": 104,
+   "w2": 105,
+   "gas_type": 106,
+   "schedule_meeting": 107,
+   "meeting_schedule": 108,
+   "reminder": 109,
+   "income": 110,
+   "plug_type": 111,
+   "what_are_your_hobbies": 112,
+   "schedule_maintenance": 113,
+   "report_fraud": 114,
+   "food_last": 115,
+   "traffic": 116,
+   "no": 117,
+   "reminder_update": 118,
+   "book_flight": 119,
+   "mpg": 120,
+   "pto_balance": 121,
+   "tell_joke": 122,
+   "calories": 123,
+   "balance": 124,
+   "rollover_401k": 125,
+   "weather": 126,
+   "change_language": 127,
+   "distance": 128,
+   "play_music": 129,
+   "min_payment": 130,
+   "sync_device": 131,
+   "pay_bill": 132,
+   "taxes": 133,
+   "share_location": 134,
+   "bill_due": 135,
+   "pto_request": 136,
+   "calendar": 137,
+   "uber": 138,
+   "do_you_have_pets": 139,
+   "change_volume": 140,
+   "timezone": 141,
+   "application_status": 142,
+   "flip_coin": 143,
+   "credit_score": 144,
+   "oil_change_how": 145,
+   "expiration_date": 146,
+   "credit_limit_change": 147,
+   "how_busy": 148,
+   "travel_suggestion": 149
+ }
cache/intent/vocab.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9d4fa520420cf60655dd67114826cef0f8be23bc7ca07cdb3c072f2a400e242b
+ size 78973
cache/slot/embeddings.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:faba49b73dfdd2a98dbbfe7b53eed50b8edd9df716169e8f837558c5e24c42bf
+ size 4941099
cache/slot/tag2idx.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "O": 0,
+   "B-date": 1,
+   "I-time": 2,
+   "B-time": 3,
+   "B-last_name": 4,
+   "I-people": 5,
+   "B-people": 6,
+   "I-date": 7,
+   "B-first_name": 8
+ }
cache/slot/vocab.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c711af8ba9cba928df00a20913b2bcdd0738ab3b9210b4b9f10d0ff9dcf27f16
+ size 49861
ckpt/intent/model_checkpoint.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:65fdb8e191b37fc6866acd1699f8978736bfb975b176d5ee0464f43301d928e8
+ size 56947301
data/intent/eval.json ADDED
The diff for this file is too large to render. See raw diff
 
data/intent/test.json ADDED
The diff for this file is too large to render. See raw diff
 
data/intent/train.json ADDED
The diff for this file is too large to render. See raw diff
 
data/slot/eval.json ADDED
The diff for this file is too large to render. See raw diff
 
data/slot/test.json ADDED
The diff for this file is too large to render. See raw diff
 
data/slot/train.json ADDED
The diff for this file is too large to render. See raw diff
 
dataset.py ADDED
@@ -0,0 +1,74 @@
+ from typing import List, Dict
+
+ import torch
+
+ from torch.utils.data import Dataset
+
+ from utils import Vocab
+
+
+ class SeqClsDataset(Dataset):
+     def __init__(
+         self,
+         data: List[Dict],
+         vocab: Vocab,
+         label_mapping: Dict[str, int],
+         max_len: int,
+     ):
+         self.data = data
+         self.vocab = vocab
+         self.label_mapping = label_mapping
+         self._idx2label = {idx: intent for intent, idx in self.label_mapping.items()}
+         self.max_len = max_len
+
+     def __len__(self) -> int:
+         return len(self.data)
+
+     def __getitem__(self, index) -> Dict:
+         instance = self.data[index]
+         return instance
+
+     @property
+     def num_classes(self) -> int:
+         return len(self.label_mapping)
+
+     def collate_fn(self, samples: List[Dict]) -> Dict:
+         # samples is the raw batch data
+         # collate_fn encodes the batch into vocabulary indices
+         # samples[0] = {'text': '~', 'intent': '~', 'id': 'train-0'}
+
+         # Extract the text and label data from every sample
+         texts = [sample["text"] for sample in samples]
+         labels = [sample["intent"] for sample in samples]
+
+         # Use vocab to turn the texts into integer index sequences, padded to max_len
+         encoded_texts = self.vocab.encode_batch([text.split() for text in texts], to_len=self.max_len)
+
+         # Turn the label strings into integer indices
+         encoded_labels = [self.label_mapping[label] for label in labels]
+
+         # Convert the integer index sequences into PyTorch tensors
+         encoded_text = torch.tensor(encoded_texts)
+         encoded_label = torch.tensor(encoded_labels)
+
+         # Build the batch dictionary
+         batch_data = {
+             "encoded_text": encoded_text,
+             "encoded_label": encoded_label
+         }
+
+         return batch_data
+
+     def label2idx(self, label: str):
+         return self.label_mapping[label]
+
+     def idx2label(self, idx: int):
+         return self._idx2label[idx]
+
+
+ class SeqTaggingClsDataset(SeqClsDataset):
+     ignore_idx = -100
+
+     def collate_fn(self, samples):
+         # TODO: implement collate_fn
+         raise NotImplementedError
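As a usage note, the sketch below (illustrative, not part of the commit) shows how SeqClsDataset is typically paired with a DataLoader via its collate_fn. The two inline samples are made up; vocab and intent2idx stand for the objects loaded from cache/intent/ in app.py, and both intent labels appear in intent2idx.json above:

    # Illustrative only: wiring SeqClsDataset to a DataLoader with its collate_fn.
    from torch.utils.data import DataLoader
    samples = [
        {"text": "what is my credit limit", "intent": "credit_limit", "id": "train-0"},
        {"text": "play some jazz", "intent": "play_music", "id": "train-1"},
    ]
    ds = SeqClsDataset(samples, vocab, intent2idx, max_len=128)
    loader = DataLoader(ds, batch_size=2, shuffle=True, collate_fn=ds.collate_fn)
    batch = next(iter(loader))
    print(batch["encoded_text"].shape, batch["encoded_label"].shape)   # torch.Size([2, 128]) torch.Size([2])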
environment.yml ADDED
@@ -0,0 +1,10 @@
+ name: adl-hw1
+ channels:
+   - defaults
+ dependencies:
+   - python=3.9
+   - cudatoolkit=10.2
+   - cudnn=7.6
+   - pip
+   - pip:
+     - pip-tools
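To reproduce this environment, the usual conda workflow should apply: conda env create -f environment.yml followed by conda activate adl-hw1. Note that the file pins CUDA 10.2 and cuDNN 7.6, while app.py above runs with device = "cpu".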
model.py ADDED
@@ -0,0 +1,76 @@
+ from typing import Dict
+
+ import torch
+ import torch.nn as nn
+
+ # Set device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ class SeqClassifier(nn.Module):
+     def __init__(
+         self,
+         embeddings: torch.tensor,
+         hidden_size: int,
+         num_layers: int,
+         dropout: float,
+         bidirectional: bool,
+         num_class: int,
+     ) -> None:
+         super(SeqClassifier, self).__init__()
+         self.embed = nn.Embedding.from_pretrained(embeddings, freeze=False)
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+         self.dropout = dropout
+         self.bidirectional = bidirectional
+         self.num_class = num_class
+
+         # model architecture
+         self.rnn = nn.LSTM(
+             input_size=embeddings.size(1),
+             hidden_size=hidden_size,
+             num_layers=num_layers,
+             dropout=dropout,
+             bidirectional=bidirectional,
+             batch_first=True
+         )
+         self.dropout_layer = nn.Dropout(p=self.dropout)
+         self.fc = nn.Linear(self.encoder_output_size, num_class)
+
+     @property
+     def encoder_output_size(self) -> int:
+         # calculate the output dimension of rnn
+         if self.bidirectional:
+             return self.hidden_size * 2
+         else:
+             return self.hidden_size
+
+     def forward(self, batch) -> torch.Tensor:
+         # Map the input token indices to word embeddings
+         embedded = self.embed(batch)
+
+         # Run through the LSTM layers
+         rnn_output, _ = self.rnn(embedded)
+         rnn_output = self.dropout_layer(rnn_output)
+
+         if not self.training:
+             last_hidden_state_forward = rnn_output[-1, :self.hidden_size]   # forward-direction hidden state (unbatched input)
+             last_hidden_state_backward = rnn_output[0, self.hidden_size:]   # backward-direction hidden state
+             combined_hidden_state = torch.cat((last_hidden_state_forward, last_hidden_state_backward), dim=0)
+
+             # Pass through the fully connected layer
+             logits = self.fc(combined_hidden_state)
+             return logits  # return the prediction
+
+         last_hidden_state_forward = rnn_output[:, -1, :self.hidden_size]   # forward-direction hidden state (batched input)
+         last_hidden_state_backward = rnn_output[:, 0, self.hidden_size:]   # backward-direction hidden state
+         combined_hidden_state = torch.cat((last_hidden_state_forward, last_hidden_state_backward), dim=1)
+
+         # Pass through the fully connected layer
+         logits = self.fc(combined_hidden_state)
+         return logits  # return the prediction
+
+
+ class SeqTagger(SeqClassifier):
+     def forward(self, batch) -> Dict[str, torch.Tensor]:
+         # TODO: implement model forward
+         raise NotImplementedError
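A quick shape check (illustrative only, using random stand-in embeddings rather than the pretrained embeddings.pt) clarifies the two forward paths above: batched input during training, a single unbatched sequence during inference. The dimensions 5000 x 300 are arbitrary for the sketch, and num_class=150 matches intent2idx.json:

    # Illustrative only: SeqClassifier output shapes with random embeddings.
    import torch
    fake_embeddings = torch.randn(5000, 300)                 # stand-in for embeddings.pt (vocab_size x emb_dim)
    model = SeqClassifier(fake_embeddings, hidden_size=256, num_layers=2,
                          dropout=0.1, bidirectional=True, num_class=150)
    x = torch.randint(0, 5000, (8, 128))                     # a batch of 8 padded index sequences
    model.train()
    print(model(x).shape)                                    # torch.Size([8, 150])
    model.eval()
    print(model(x[0]).shape)                                 # torch.Size([150]) for one unbatched sequence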
utils.py ADDED
@@ -0,0 +1,42 @@
+ from typing import Iterable, List
+
+ class Vocab:
+     PAD = "[PAD]"
+     UNK = "[UNK]"
+
+     def __init__(self, vocab: Iterable[str]) -> None:
+         self.token2idx = {
+             Vocab.PAD: 0,
+             Vocab.UNK: 1,
+             **{token: i for i, token in enumerate(vocab, 2)},
+         }
+
+     @property
+     def pad_id(self) -> int:
+         return self.token2idx[Vocab.PAD]
+
+     @property
+     def unk_id(self) -> int:
+         return self.token2idx[Vocab.UNK]
+
+     @property
+     def tokens(self) -> List[str]:
+         return list(self.token2idx.keys())
+
+     def token_to_id(self, token: str) -> int:
+         return self.token2idx.get(token, self.unk_id)
+
+     def encode(self, tokens: List[str]) -> List[int]:
+         return [self.token_to_id(token) for token in tokens]
+
+     def encode_batch(
+         self, batch_tokens: List[List[str]], to_len: int = None
+     ) -> List[List[int]]:
+         batch_ids = [self.encode(tokens) for tokens in batch_tokens]
+         to_len = max(len(ids) for ids in batch_ids) if to_len is None else to_len
+         padded_ids = pad_to_len(batch_ids, to_len, self.pad_id)
+         return padded_ids
+
+ def pad_to_len(seqs: List[List[int]], to_len: int, padding: int) -> List[List[int]]:
+     paddeds = [seq[:to_len] + [padding] * max(0, to_len - len(seq)) for seq in seqs]
+     return paddeds
+ return paddeds