Spaces:
Running
Running
add models
Browse files
models/__pycache__/distilbart_cnn_12_6.cpython-310.pyc
ADDED
Binary file (1.33 kB). View file
|
|
models/__pycache__/t5_small_medium_title_generation.cpython-310.pyc
ADDED
Binary file (1.08 kB). View file
|
|
models/distilbart_cnn_12_6.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
2 |
+
|
3 |
+
# DistilBART-CNN-12-6 checkpoint used for summarization.
# The tokenizer and model are created once, at import time, so that the
# functions in this module do not pay the loading cost on every call.
SUMMARIZATION_MODEL = "sshleifer/distilbart-cnn-12-6"
tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL)
model = AutoModelForSeq2SeqLM.from_pretrained(
    SUMMARIZATION_MODEL,
    device_map="cuda:0",
)
|
8 |
+
|
9 |
+
def summarize(text, max_len=20):
    """
    Summarizes the given text using the DistilBART-CNN-12-6 model.

    Args:
        text (str): The text to be summarized.
        max_len (int, optional): Maximum number of tokens of ``text`` that
            are passed to the model; longer input is truncated.
            Defaults to 20.

    Returns:
        str: The summarized text.
    """
    # NOTE(review): max_len bounds the tokenized INPUT, not the summary.
    # The summary length is bounded by max_new_tokens below. A 20-token
    # default input window looks unintentionally small -- confirm with callers.
    input_ids = tokenizer(
        text,
        return_tensors="pt",
        max_length=max_len,
        truncation=True,
    ).input_ids

    # Move the input ids to the device the model lives on instead of
    # assuming the current CUDA device.
    input_ids = input_ids.to(model.device)

    outputs = model.generate(
        input_ids,
        max_new_tokens=100,
        num_beams=8,
        length_penalty=0.2,
        early_stopping=False,
    )
    # Only one sequence is generated, so decode the first (and only) output.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
37 |
+
|
38 |
+
def summarizePipeline(text):
    """
    Summarizes the given text through a transformers "summarization" pipeline.

    The pipeline is built from the module-level ``model`` and ``tokenizer``.

    Args:
        text (str): The text to be summarized.

    Returns:
        str: The ``summary_text`` field of the first pipeline result.
    """
    from transformers import pipeline as build_pipeline

    summarizer = build_pipeline(
        "summarization",
        model=model,
        tokenizer=tokenizer,
    )

    results = summarizer(text)
    first_result = results[0]
    return first_result["summary_text"]
|
models/flan_t5_xl.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, torch, accelerate
|
2 |
+
from langchain.llms import HuggingFacePipeline
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoModelForCausalLM
|
4 |
+
|
5 |
+
# NOTE(review): this module is named flan_t5_xl but loads the "large"
# checkpoint -- confirm which model size is intended.
model_id = 'google/flan-t5-large'
tokenizer = AutoTokenizer.from_pretrained(model_id)
# load_in_8bit=True asks transformers to load the weights 8-bit quantized.
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True)

# Text-to-text generation pipeline around the model and tokenizer above.
pipe = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=512,
)

# The wrapper's field is the lowercase ``pipeline``; keyword names are
# case-sensitive, so ``Pipeline=`` does not set it.
local_llm = HuggingFacePipeline(pipeline=pipe)
|
models/t5_small_medium_title_generation.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import nltk
|
2 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
3 |
+
import torch
|
4 |
+
|
5 |
+
_T5_TITLE_MODEL_ID = "fabiochiu/t5-small-medium-title-generation"

# Holds the loaded tokenizer and model so the checkpoint is read once per
# process instead of on every t5model() call.
_T5_TITLE_CACHE = {}


def _load_t5_title_model():
    """Return the (tokenizer, model) pair, loading both on the first call only."""
    if not _T5_TITLE_CACHE:
        # Load into locals first so a failure part-way through does not leave
        # a half-filled cache behind.
        loaded_tokenizer = AutoTokenizer.from_pretrained(_T5_TITLE_MODEL_ID)
        loaded_model = AutoModelForSeq2SeqLM.from_pretrained(
            _T5_TITLE_MODEL_ID,
            device_map="cuda:0",
            torch_dtype=torch.float16,
        )
        _T5_TITLE_CACHE["tokenizer"] = loaded_tokenizer
        _T5_TITLE_CACHE["model"] = loaded_model
    return _T5_TITLE_CACHE["tokenizer"], _T5_TITLE_CACHE["model"]


def t5model(prompt: str) -> str:
    """
    Generate a short title for ``prompt`` with the T5 title-generation model.

    Args:
        prompt: The text to generate a title for.

    Returns:
        The decoded text of the first generated sequence.
    """
    tokenizer, model = _load_t5_title_model()

    # T5 task prefixes are written as "<task>: <text>", with a space after
    # the colon.
    inputs = tokenizer(
        ["summarize: " + prompt],
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    )

    # Move the inputs tensor to the same device as the model tensor
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Beam search combined with sampling: output is not deterministic.
    outputs = model.generate(
        **inputs,
        num_beams=8,
        do_sample=True,
        min_length=8,
        max_length=15,
    )

    decoded_output = tokenizer.batch_decode(
        outputs, skip_special_tokens=True
    )[0]

    return decoded_output
|
requirements.txt
CHANGED
@@ -6,3 +6,4 @@ accelerate
|
|
6 |
langchain
|
7 |
yt-dlp
|
8 |
rich
|
|
|
|
6 |
langchain
|
7 |
yt-dlp
|
8 |
rich
|
9 |
+
gradio
|