minhdang14902 commited on
Commit
9d212d1
·
verified ·
1 Parent(s): 5bb1928

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +432 -0
app.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
4
+ import nltk
5
+ from transformers.models.roberta.modeling_roberta import *
6
+ from transformers import RobertaForQuestionAnswering
7
+ from nltk import word_tokenize
8
+ import json
9
+ import pandas as pd
10
+ # import re
11
+ import base64
12
+ # Set the background image
13
+ # background_image = """
14
+ # <style>
15
+ # [data-testid="stAppViewContainer"] > .main {
16
+ # background-image: url("https://images.unsplash.com/photo-1542281286-9e0a16bb7366");
17
+ # background-size: 100vw 100vh; # This sets the size to cover 100% of the viewport width and height
18
+ # background-position: center;
19
+ # background-repeat: no-repeat;
20
+ # }
21
+ # </style>
22
+ # """
23
+ # st.markdown(background_image, unsafe_allow_html=True)
24
+
25
+ # def set_bg_hack(main_bg):
26
+ # '''
27
+ # A function to unpack an image from root folder and set as bg.
28
+
29
+ # Returns
30
+ # -------
31
+ # The background.
32
+ # '''
33
+ # # set bg name
34
+ # main_bg_ext = "png"
35
+
36
+ # st.markdown(
37
+ # f"""
38
+ # <style>
39
+ # .stApp {{
40
+ # background: url(data:image/{main_bg_ext};base64,{base64.b64encode(open(main_bg, "rb").read()).decode()});
41
+ # background-size: cover
42
+ # }}
43
+ # </style>
44
+ # """,
45
+ # unsafe_allow_html=True
46
+ # )
47
+ # set_bg_hack("Background.png")
48
+
49
+ # image_url = "logo1.png"
50
+
51
+ # # Hiển thị hình ảnh mà không có caption và điều chỉnh kích thước nhỏ lại
52
+ # st.image(image_url, width=100)
53
+
54
+ # Download punkt for nltk
55
+ print("===================================================================")
56
+ @st.cache_data
57
+ def download_nltk_punkt():
58
+ nltk.download('punkt_tab')
59
+
60
+ # Cache loading PhoBert model and tokenizer
61
+ @st.cache_data
62
+ def load_phoBert():
63
+ model = AutoModelForSequenceClassification.from_pretrained('minhdang14902/Phobert_Law')
64
+ tokenizer = AutoTokenizer.from_pretrained('minhdang14902/Phobert_Law')
65
+ return model, tokenizer
66
+
67
+
68
+
69
+ # Call the cached functions
70
+ download_nltk_punkt()
71
+ phoBert_model, phoBert_tokenizer = load_phoBert()
72
+
73
+ # Initialize the pipeline with the loaded PhoBert model and tokenizer
74
+ chatbot_pipeline = pipeline("sentiment-analysis", model=phoBert_model, tokenizer=phoBert_tokenizer)
75
+
76
+ # Load spaCy Vietnamese model
77
+ # nlp = spacy.load('vi_core_news_lg')
78
+
79
+ # Load intents from json file
80
+ def load_json_file(filename):
81
+ with open(filename) as f:
82
+ file = json.load(f)
83
+ return file
84
+
85
+ filename = './Law_2907.json'
86
+ intents = load_json_file(filename)
87
+
88
+ def create_df():
89
+ df = pd.DataFrame({
90
+ 'Pattern': [],
91
+ 'Tag': []
92
+ })
93
+ return df
94
+
95
+ df = create_df()
96
+
97
+ def extract_json_info(json_file, df):
98
+ for intent in json_file['intents']:
99
+ for pattern in intent['patterns']:
100
+ sentence_tag = [pattern, intent['tag']]
101
+ df.loc[len(df.index)] = sentence_tag
102
+ return df
103
+
104
+ df = extract_json_info(intents, df)
105
+ df2 = df.copy()
106
+
107
+ labels = df2['Tag'].unique().tolist()
108
+ labels = [s.strip() for s in labels]
109
+ num_labels = len(labels)
110
+ id2label = {id: label for id, label in enumerate(labels)}
111
+ label2id = {label: id for id, label in enumerate(labels)}
112
+
113
+ # def tokenize_with_spacy(text):
114
+ # doc = nlp(text)
115
+ # tokens = [token.text for token in doc]
116
+ # tokenized_text = ' '.join(tokens)
117
+ # tokenized_text = re.sub(r'(?<!\s)([.,?])', r' \1', tokenized_text)
118
+ # tokenized_text = re.sub(r'([.,?])(?!\s)', r'\1 ', tokenized_text)
119
+ # return tokenized_text
120
+
121
+ # Load Roberta model and tokenizer
122
+
123
+ _CHECKPOINT_FOR_DOC = "roberta-base"
124
+ _CONFIG_FOR_DOC = "RobertaConfig"
125
+ _TOKENIZER_FOR_DOC = "RobertaTokenizer"
126
+
127
+
128
+ class MRCQuestionAnswering(RobertaPreTrainedModel):
129
+ config_class = RobertaConfig
130
+
131
+ def _reorder_cache(self, past, beam_idx):
132
+ pass
133
+
134
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
135
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
136
+
137
+ def __init__(self, config):
138
+ super().__init__(config)
139
+ self.num_labels = config.num_labels
140
+
141
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
142
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
143
+
144
+ self.init_weights()
145
+
146
+ def forward(
147
+ self,
148
+ input_ids=None,
149
+ words_lengths=None,
150
+ start_idx=None,
151
+ end_idx=None,
152
+ attention_mask=None,
153
+ token_type_ids=None,
154
+ position_ids=None,
155
+ head_mask=None,
156
+ inputs_embeds=None,
157
+ start_positions=None,
158
+ end_positions=None,
159
+ span_answer_ids=None,
160
+ output_attentions=None,
161
+ output_hidden_states=None,
162
+ return_dict=None,
163
+ ):
164
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
165
+
166
+ outputs = self.roberta(
167
+ input_ids,
168
+ attention_mask=attention_mask,
169
+ token_type_ids=None, # Roberta doesn't use token_type_ids
170
+ position_ids=position_ids,
171
+ head_mask=head_mask,
172
+ inputs_embeds=inputs_embeds,
173
+ output_attentions=output_attentions,
174
+ output_hidden_states=output_hidden_states,
175
+ return_dict=return_dict,
176
+ )
177
+
178
+ sequence_output = outputs[0]
179
+
180
+ context_embedding = sequence_output
181
+
182
+ batch_size = input_ids.shape[0]
183
+ max_sub_word = input_ids.shape[1]
184
+ max_word = words_lengths.shape[1]
185
+ align_matrix = torch.zeros((batch_size, max_word, max_sub_word))
186
+
187
+ for i, sample_length in enumerate(words_lengths):
188
+ for j in range(len(sample_length)):
189
+ start_idx = torch.sum(sample_length[:j])
190
+ align_matrix[i][j][start_idx: start_idx + sample_length[j]] = 1 if sample_length[j] > 0 else 0
191
+
192
+ align_matrix = align_matrix.to(context_embedding.device)
193
+ context_embedding_align = torch.bmm(align_matrix, context_embedding)
194
+
195
+ logits = self.qa_outputs(context_embedding_align)
196
+ start_logits, end_logits = logits.split(1, dim=-1)
197
+ start_logits = start_logits.squeeze(-1).contiguous()
198
+ end_logits = end_logits.squeeze(-1).contiguous()
199
+
200
+ total_loss = None
201
+ if start_positions is not None and end_positions is not None:
202
+ if len(start_positions.size()) > 1:
203
+ start_positions = start_positions.squeeze(-1)
204
+ if len(end_positions.size()) > 1:
205
+ end_positions = end_positions.squeeze(-1)
206
+ ignored_index = start_logits.size(1)
207
+ start_positions = start_positions.clamp(0, ignored_index)
208
+ end_positions = end_positions.clamp(0, ignored_index)
209
+
210
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
211
+ start_loss = loss_fct(start_logits, start_positions)
212
+ end_loss = loss_fct(end_logits, end_positions)
213
+ total_loss = (start_loss + end_loss) / 2
214
+
215
+ if not return_dict:
216
+ output = (start_logits, end_logits) + outputs[2:]
217
+ return ((total_loss,) + output) if total_loss is not None else output
218
+
219
+ return QuestionAnsweringModelOutput(
220
+ loss=total_loss,
221
+ start_logits=start_logits,
222
+ end_logits=end_logits,
223
+ hidden_states=outputs.hidden_states,
224
+ attentions=outputs.attentions,
225
+ )
226
+
227
+ # roberta_model_checkpoint = "minhdang14902/Roberta_edu"
228
+ # roberta_tokenizer = AutoTokenizer.from_pretrained(roberta_model_checkpoint)
229
+ # roberta_model = MRCQuestionAnswering.from_pretrained(roberta_model_checkpoint)
230
+
231
+ # Cache loading Roberta model and tokenizer
232
+ @st.cache_data
233
+ def load_roberta_model():
234
+ model = MRCQuestionAnswering.from_pretrained('minhdang14902/Roberta_Law')
235
+ tokenizer = AutoTokenizer.from_pretrained('minhdang14902/Roberta_Law')
236
+ return model, tokenizer
237
+
238
+ roberta_model, roberta_tokenizer = load_roberta_model()
239
+
240
+
241
+ def chatRoberta(text):
242
+ label = label2id[chatbot_pipeline(text)[0]['label']]
243
+ response = intents['intents'][label]['responses']
244
+ print(response[0])
245
+
246
+ QA_input = {
247
+ 'question': text,
248
+ 'context': response[0]
249
+ }
250
+
251
+ # Tokenize input
252
+ encoded_input = tokenize_function(QA_input, roberta_tokenizer)
253
+
254
+ # Prepare batch samples
255
+ batch_samples = data_collator([encoded_input], roberta_tokenizer)
256
+
257
+ # Model prediction
258
+ roberta_model.eval()
259
+ with torch.no_grad():
260
+ inputs = {
261
+ 'input_ids': batch_samples['input_ids'],
262
+ 'attention_mask': batch_samples['attention_mask'],
263
+ 'words_lengths': batch_samples['words_lengths'],
264
+ }
265
+ outputs = roberta_model(**inputs)
266
+
267
+ # Extract answer
268
+ result = extract_answer([encoded_input], outputs, roberta_tokenizer)
269
+ context = response[0]
270
+ return result, context
271
+
272
+ def tokenize_function(example, tokenizer):
273
+ question_word = word_tokenize(example["question"])
274
+ context_word = word_tokenize(example["context"])
275
+
276
+ question_sub_words_ids = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(w)) for w in question_word]
277
+ context_sub_words_ids = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(w)) for w in context_word]
278
+ valid = True
279
+ if len([j for i in question_sub_words_ids + context_sub_words_ids for j in i]) > tokenizer.model_max_length - 1:
280
+ valid = False
281
+
282
+ question_sub_words_ids = [[tokenizer.bos_token_id]] + question_sub_words_ids + [[tokenizer.eos_token_id]]
283
+ context_sub_words_ids = context_sub_words_ids + [[tokenizer.eos_token_id]]
284
+
285
+ input_ids = [j for i in question_sub_words_ids + context_sub_words_ids for j in i]
286
+ if len(input_ids) > tokenizer.model_max_length:
287
+ valid = False
288
+
289
+ words_lengths = [len(item) for item in question_sub_words_ids + context_sub_words_ids]
290
+
291
+ return {
292
+ "input_ids": input_ids,
293
+ "words_lengths": words_lengths,
294
+ "valid": valid
295
+ }
296
+
297
+ def data_collator(samples, tokenizer):
298
+ if len(samples) == 0:
299
+ return {}
300
+
301
+ def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
302
+ size = max(v.size(0) for v in values)
303
+ res = values[0].new(len(values), size).fill_(pad_idx)
304
+
305
+ def copy_tensor(src, dst):
306
+ assert dst.numel() == src.numel()
307
+ if move_eos_to_beginning:
308
+ assert src[-1] == eos_idx
309
+ dst[0] = eos_idx
310
+ dst[1:] = src[:-1]
311
+ else:
312
+ dst.copy_(src)
313
+
314
+ for i, v in enumerate(values):
315
+ copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
316
+ return res
317
+
318
+ input_ids = collate_tokens([torch.tensor(item['input_ids']) for item in samples], pad_idx=tokenizer.pad_token_id)
319
+ attention_mask = torch.zeros_like(input_ids)
320
+ for i in range(len(samples)):
321
+ attention_mask[i][:len(samples[i]['input_ids'])] = 1
322
+ words_lengths = collate_tokens([torch.tensor(item['words_lengths']) for item in samples], pad_idx=0)
323
+
324
+ batch_samples = {
325
+ 'input_ids': input_ids,
326
+ 'attention_mask': attention_mask,
327
+ 'words_lengths': words_lengths,
328
+ }
329
+
330
+ return batch_samples
331
+
332
+ def extract_answer(inputs, outputs, tokenizer):
333
+ plain_result = []
334
+ for sample_input, start_logit, end_logit in zip(inputs, outputs.start_logits, outputs.end_logits):
335
+ sample_words_length = sample_input['words_lengths']
336
+ input_ids = sample_input['input_ids']
337
+ answer_start = sum(sample_words_length[:torch.argmax(start_logit)])
338
+ answer_end = sum(sample_words_length[:torch.argmax(end_logit) + 1])
339
+
340
+ if answer_start <= answer_end:
341
+ answer = tokenizer.convert_tokens_to_string(
342
+ tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
343
+ if answer == tokenizer.bos_token:
344
+ answer = ''
345
+ else:
346
+ answer = ''
347
+
348
+ score_start = torch.max(torch.softmax(start_logit, dim=-1)).cpu().detach().numpy().tolist()
349
+ score_end = torch.max(torch.softmax(end_logit, dim=-1)).cpu().detach().numpy().tolist()
350
+ plain_result.append({
351
+ "answer": answer,
352
+ "score_start": score_start,
353
+ "score_end": score_end
354
+ })
355
+ return plain_result
356
+
357
+ # st.title("Chatbot Roberta")
358
+ # st.write("Hi! Tôi là trợ lý của bạn trong việc trả lời các câu hỏi.")
359
+ # text = st.text_input("User: ", key="input")
360
+
361
+ # if 'chat_history' not in st.session_state:
362
+ # st.session_state['chat_history'] = []
363
+
364
+
365
+ # def get_response(text):
366
+ # st.subheader("The Answer is:")
367
+ # st.write(text)
368
+ # answer, context = chatRoberta(text)
369
+ # result = answer[0]['answer']
370
+ # if result == "":
371
+ # return "Xin lỗi, tôi không thể tìm được đáp án phù hợp cho câu hỏi này ... Hãy thử trả lời bằng câu hỏi khác!"
372
+ # return result
373
+
374
+ # if st.button("Chat!"):
375
+ # st.session_state['chat_history'].append(("User", text))
376
+
377
+ # response = get_response(text)
378
+
379
+ # st.subheader("The Response is:")
380
+ # message = st.empty()
381
+ # result = ""
382
+ # for chunk in response:
383
+ # result += chunk
384
+ # message.markdown(result + "❚ ")
385
+ # message.markdown(result)
386
+ # st.session_state['chat_history'].append(("Bot", result))
387
+
388
+ # for i, (sender, message) in enumerate(st.session_state['chat_history']):
389
+ # if sender == "User":
390
+ # st.text_area(f"User:", value=message, height=100, max_chars=None, key=f"user_{i}")
391
+ # else:
392
+ # st.text_area(f"Bot:", value=message, height=100, max_chars=None, key=f"bot_{i}")
393
+
394
+ def get_response(text):
395
+ # Thay thế hàm này bằng model của bạn để lấy câu trả lời từ bot
396
+ # st.subheader("The Answer is:")
397
+ # st.write(text)
398
+ answer, context = chatRoberta(text)
399
+ result = answer[0]['answer']
400
+ if result == "":
401
+ return "Xin lỗi, tôi không thể tìm được đáp án phù hợp cho câu hỏi này ... Hãy thử trả lời bằng câu hỏi khác!"
402
+ return result
403
+
404
+ st.title("General Law Chatbot")
405
+
406
+ # Khởi tạo lịch sử tin nhắn
407
+ if "messages" not in st.session_state:
408
+ st.session_state.messages = []
409
+
410
+ # Hiển thị các tin nhắn từ lịch sử
411
+ for message in st.session_state.messages:
412
+ with st.chat_message(message["role"]):
413
+ st.markdown(message["content"])
414
+
415
+ # Nhận input từ người dùng
416
+ if prompt := st.chat_input("What is up?"):
417
+ # Thêm tin nhắn của người dùng vào lịch sử
418
+ st.session_state.messages.append({"role": "user", "content": prompt})
419
+
420
+ # Hiển thị tin nhắn của người dùng trong giao diện
421
+ with st.chat_message("user"):
422
+ st.markdown(prompt)
423
+
424
+ # Lấy câu trả lời từ bot
425
+ response = get_response(prompt)
426
+
427
+ # Hiển thị câu trả lời của bot trong giao diện
428
+ with st.chat_message("assistant"):
429
+ st.markdown(response)
430
+
431
+ # Thêm câu trả lời của bot vào lịch sử
432
+ st.session_state.messages.append({"role": "assistant", "content": response})