Update app.py
app.py CHANGED
@@ -8,29 +8,29 @@ from transformers import (
 )
 
 # Function to load VQA pipeline
-@st.
+@st.cache(allow_output_mutation=True)
 def load_vqa_pipeline():
     return pipeline(task="visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
 
 # Function to load BERT-based pipeline
-@st.
+@st.cache(allow_output_mutation=True)
 def load_bbu_pipeline():
     return pipeline(task="fill-mask", model="bert-base-uncased")
 
 # Function to load Blenderbot model
-@st.
+@st.cache(allow_output_mutation=True)
 def load_blenderbot_model():
     model_name = "facebook/blenderbot-400M-distill"
     tokenizer = BlenderbotTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
     return BlenderbotForConditionalGeneration.from_pretrained(pretrained_model_name_or_path=model_name)
 
 # Function to load GPT-2 pipeline
-@st.
+@st.cache(allow_output_mutation=True)
 def load_gpt2_pipeline():
     return pipeline(task="text-generation", model="gpt2")
 
 # Function to load BERTopic models
-@st.
+@st.cache(allow_output_mutation=True)
 def load_topic_models():
     topic_model_1 = BERTopic.load(path="davanstrien/chat_topics")
     topic_model_2 = BERTopic.load(path="MaartenGr/BERTopic_ArXiv")
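For context, the decorator added on each loader caches the returned model object across Streamlit reruns, so the weights are downloaded and loaded only once per session. A minimal sketch of how these cached loaders might be consumed in the app follows; the rest of app.py is not part of this diff, so the widget layout, prompt text, and the idea of calling load_bbu_pipeline here are assumptions for illustration only.

import streamlit as st

# Hypothetical usage of the cached loaders defined above (not shown in this diff).
bbu = load_bbu_pipeline()  # returned pipeline is reused across reruns thanks to st.cache

st.title("Multi-model demo")
prompt = st.text_input("Masked sentence", "Paris is the [MASK] of France.")
if prompt:
    # The fill-mask pipeline returns a list of candidate completions with scores.
    for candidate in bbu(prompt)[:3]:
        st.write(candidate["sequence"], round(candidate["score"], 3))

Note that newer Streamlit releases deprecate st.cache in favor of st.cache_resource for objects like models and pipelines; the sketch above keeps the st.cache(allow_output_mutation=True) form used in this commit.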