minhdang14902 committed on
Commit
7c8c302
·
verified ·
1 Parent(s): 9718bbe

Update app.py

Browse files
Files changed (1)
  1. app.py +10 -13
app.py CHANGED
@@ -359,14 +359,11 @@ def extract_answer(inputs, outputs, tokenizer):
359
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
360
 
361
  INPUT_MAX_LEN = 128 # Adjusted input length
362
- OUTPUT_MAX_LEN = 512 # Adjusted output length
363
 
364
- @st.cache_data
365
- def download_model_name():
366
- MODEL_NAME = "VietAI/vit5-base"
367
- return MODEL_NAME
368
 
369
- MODEL_NAME = download_model_name()
370
 
371
  tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=INPUT_MAX_LEN)
372
 
@@ -403,16 +400,14 @@ class T5Model(pl.LightningModule):
403
  return AdamW(self.parameters(), lr=0.0001)
404
 
405
 
406
- @st.cache_data
407
- def load_t5():
408
- train_model = T5Model.load_from_checkpoint('./data-law/law-model-v1.ckpt')
409
- train_model.freeze()
410
- return train_model
411
 
412
- train_model = load_t5()
 
 
413
 
414
 
415
  def generate_question(question):
 
416
  inputs_encoding = tokenizer(
417
  question,
418
  add_special_tokens=True,
@@ -423,6 +418,7 @@ def generate_question(question):
423
  return_tensors="pt"
424
  ).to(DEVICE)
425
 
 
426
  generate_ids = train_model.model.generate(
427
  input_ids=inputs_encoding["input_ids"],
428
  attention_mask=inputs_encoding["attention_mask"],
@@ -432,7 +428,8 @@ def generate_question(question):
432
  no_repeat_ngram_size=2,
433
  early_stopping=True,
434
  )
435
-
 
436
  preds = [
437
  tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
438
  for gen_id in generate_ids
 
359
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
360
 
361
  INPUT_MAX_LEN = 128 # Adjusted input length
362
+ OUTPUT_MAX_LEN = 256 # Adjusted output length
363
 
364
+ MODEL_NAME = "VietAI/vit5-base"
365
+
 
 
366
 
 
367
 
368
  tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=INPUT_MAX_LEN)
369
 
 
400
  return AdamW(self.parameters(), lr=0.0001)
401
 
402
 
 
 
 
 
 
403
 
404
+ train_model = T5Model.load_from_checkpoint('./data-law/law-model-v1.ckpt')
405
+ train_model.freeze()
406
+
407
 
408
 
409
  def generate_question(question):
410
+ print("tokenizer")
411
  inputs_encoding = tokenizer(
412
  question,
413
  add_special_tokens=True,
 
418
  return_tensors="pt"
419
  ).to(DEVICE)
420
 
421
+ print("generate id")
422
  generate_ids = train_model.model.generate(
423
  input_ids=inputs_encoding["input_ids"],
424
  attention_mask=inputs_encoding["attention_mask"],
 
428
  no_repeat_ngram_size=2,
429
  early_stopping=True,
430
  )
431
+
432
+ print("decode")
433
  preds = [
434
  tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
435
  for gen_id in generate_ids