codeblacks committed on
Commit
c6cd033
1 Parent(s): 10da7cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -1,14 +1,14 @@
1
- from transformers import LongformerTokenizer, LongformerModel
2
  import torch
3
  import gradio as gr
4
 
5
- # Load the pre-trained Longformer model and tokenizer
6
- tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
7
- model = LongformerModel.from_pretrained('allenai/longformer-base-4096')
8
 
9
- def get_longformer_embeddings(sentences):
10
  # Tokenize input sentences
11
- inputs = tokenizer(sentences, return_tensors='pt', padding=True, truncation=True, max_length=2048)
12
  # Get embeddings
13
  with torch.no_grad():
14
  outputs = model(**inputs)
@@ -17,11 +17,11 @@ def get_longformer_embeddings(sentences):
17
 
18
  # Define the Gradio interface
19
  interface = gr.Interface(
20
- fn=get_longformer_embeddings, # Function to call
21
  inputs=gr.Textbox(lines=2, placeholder="Enter sentences here, one per line"), # Input component
22
  outputs=gr.JSON(), # Output component
23
- title="Sentence Embeddings with Longformer", # Interface title
24
- description="Enter sentences to get their embeddings with Longformer (up to 2048 tokens)." # Description
25
  )
26
 
27
  # Launch the interface
 
1
+ from transformers import AutoTokenizer, AutoModel
2
  import torch
3
  import gradio as gr
4
 
5
+ # Load the pre-trained paraphrase-mpnet-base-v2 model and tokenizer
6
+ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/paraphrase-mpnet-base-v2')
7
+ model = AutoModel.from_pretrained('sentence-transformers/paraphrase-mpnet-base-v2')
8
 
9
+ def get_mpnet_embeddings(sentences):
10
  # Tokenize input sentences
11
+ inputs = tokenizer(sentences, return_tensors='pt', padding=True, truncation=True, max_length=512)
12
  # Get embeddings
13
  with torch.no_grad():
14
  outputs = model(**inputs)
 
17
 
18
  # Define the Gradio interface
19
  interface = gr.Interface(
20
+ fn=get_mpnet_embeddings, # Function to call
21
  inputs=gr.Textbox(lines=2, placeholder="Enter sentences here, one per line"), # Input component
22
  outputs=gr.JSON(), # Output component
23
+ title="Sentence Embeddings with MPNet", # Interface title
24
+ description="Enter sentences to get their embeddings with paraphrase-mpnet-base-v2 (up to 512 tokens)." # Description
25
  )
26
 
27
  # Launch the interface