minhdang14902 commited on
Commit
8e66f46
·
verified ·
1 Parent(s): d65d545

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -3
app.py CHANGED
@@ -361,7 +361,11 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
361
  INPUT_MAX_LEN = 128 # Adjusted input length
362
  OUTPUT_MAX_LEN = 256 # Adjusted output length
363
 
364
- MODEL_NAME = "VietAI/vit5-base"
 
 
 
 
365
 
366
  tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=INPUT_MAX_LEN)
367
 
@@ -397,8 +401,13 @@ class T5Model(pl.LightningModule):
397
  def configure_optimizers(self):
398
  return AdamW(self.parameters(), lr=0.0001)
399
 
400
- train_model = T5Model.load_from_checkpoint('./data-law/law-model-v1.ckpt')
401
- train_model.freeze()
 
 
 
 
 
402
 
403
 
404
  def generate_question(question):
@@ -428,6 +437,7 @@ def generate_question(question):
428
  ]
429
 
430
  response = " ".join(preds[0].split())
 
431
  return response
432
 
433
  # st.title("Chatbot Roberta")
 
361
  INPUT_MAX_LEN = 128 # Adjusted input length
362
  OUTPUT_MAX_LEN = 256 # Adjusted output length
363
 
364
+ @st.cache_data
365
+ def download_model_name():
366
+ MODEL_NAME = "VietAI/vit5-base"
367
+
368
+ download_model_name()
369
 
370
  tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=INPUT_MAX_LEN)
371
 
 
401
  def configure_optimizers(self):
402
  return AdamW(self.parameters(), lr=0.0001)
403
 
404
+
405
+ @st.cache_data
406
+ def load_t5():
407
+ train_model = T5Model.load_from_checkpoint('./data-law/law-model-v1.ckpt')
408
+ train_model.freeze()
409
+
410
+ load_t5()
411
 
412
 
413
  def generate_question(question):
 
437
  ]
438
 
439
  response = " ".join(preds[0].split())
440
+ print('T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5T5')
441
  return response
442
 
443
  # st.title("Chatbot Roberta")