Spaces:
Sleeping
Sleeping
minhdang14902
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -359,14 +359,11 @@ def extract_answer(inputs, outputs, tokenizer):
|
|
359 |
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
360 |
|
361 |
INPUT_MAX_LEN = 128 # Adjusted input length
|
362 |
-
OUTPUT_MAX_LEN =
|
363 |
|
364 |
-
|
365 |
-
|
366 |
-
MODEL_NAME = "VietAI/vit5-base"
|
367 |
-
return MODEL_NAME
|
368 |
|
369 |
-
MODEL_NAME = download_model_name()
|
370 |
|
371 |
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=INPUT_MAX_LEN)
|
372 |
|
@@ -403,16 +400,14 @@ class T5Model(pl.LightningModule):
|
|
403 |
return AdamW(self.parameters(), lr=0.0001)
|
404 |
|
405 |
|
406 |
-
@st.cache_data
|
407 |
-
def load_t5():
|
408 |
-
train_model = T5Model.load_from_checkpoint('./data-law/law-model-v1.ckpt')
|
409 |
-
train_model.freeze()
|
410 |
-
return train_model
|
411 |
|
412 |
-
train_model =
|
|
|
|
|
413 |
|
414 |
|
415 |
def generate_question(question):
|
|
|
416 |
inputs_encoding = tokenizer(
|
417 |
question,
|
418 |
add_special_tokens=True,
|
@@ -423,6 +418,7 @@ def generate_question(question):
|
|
423 |
return_tensors="pt"
|
424 |
).to(DEVICE)
|
425 |
|
|
|
426 |
generate_ids = train_model.model.generate(
|
427 |
input_ids=inputs_encoding["input_ids"],
|
428 |
attention_mask=inputs_encoding["attention_mask"],
|
@@ -432,7 +428,8 @@ def generate_question(question):
|
|
432 |
no_repeat_ngram_size=2,
|
433 |
early_stopping=True,
|
434 |
)
|
435 |
-
|
|
|
436 |
preds = [
|
437 |
tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
438 |
for gen_id in generate_ids
|
|
|
359 |
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
360 |
|
361 |
INPUT_MAX_LEN = 128 # Adjusted input length
|
362 |
+
OUTPUT_MAX_LEN = 256 # Adjusted output length
|
363 |
|
364 |
+
MODEL_NAME = "VietAI/vit5-base"
|
365 |
+
|
|
|
|
|
366 |
|
|
|
367 |
|
368 |
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=INPUT_MAX_LEN)
|
369 |
|
|
|
400 |
return AdamW(self.parameters(), lr=0.0001)
|
401 |
|
402 |
|
|
|
|
|
|
|
|
|
|
|
403 |
|
404 |
+
train_model = T5Model.load_from_checkpoint('./data-law/law-model-v1.ckpt')
|
405 |
+
train_model.freeze()
|
406 |
+
|
407 |
|
408 |
|
409 |
def generate_question(question):
|
410 |
+
print("tokenizer")
|
411 |
inputs_encoding = tokenizer(
|
412 |
question,
|
413 |
add_special_tokens=True,
|
|
|
418 |
return_tensors="pt"
|
419 |
).to(DEVICE)
|
420 |
|
421 |
+
print("generate id")
|
422 |
generate_ids = train_model.model.generate(
|
423 |
input_ids=inputs_encoding["input_ids"],
|
424 |
attention_mask=inputs_encoding["attention_mask"],
|
|
|
428 |
no_repeat_ngram_size=2,
|
429 |
early_stopping=True,
|
430 |
)
|
431 |
+
|
432 |
+
print("decode")
|
433 |
preds = [
|
434 |
tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
435 |
for gen_id in generate_ids
|