reynardryanda commited on
Commit
04e4555
·
1 Parent(s): 805fdc9

another test

Browse files
Files changed (2) hide show
  1. app.py +15 -22
  2. requirements.txt +1 -2
app.py CHANGED
@@ -1,26 +1,19 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
- import torch
4
 
5
- model_name = "cross-encoder/multi-nli-xlm-r-100"
6
- tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
8
 
9
- def generate_prediction(input_text):
10
- input_ids = tokenizer.encode(input_text, truncation=True, padding=True, return_tensors='pt')
11
- outputs = model(input_ids)
12
- predicted_label = torch.argmax(outputs.logits)
13
- label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}
14
- predicted_label_text = label_map[predicted_label.item()]
15
- return predicted_label_text
 
 
 
16
 
17
- input_text = gr.inputs.Textbox(label="Input text")
18
- output_text = gr.outputs.Textbox(label="Output text")
19
-
20
- gr.Interface(
21
- generate_prediction,
22
- inputs=input_text,
23
- outputs=output_text,
24
- title="Text Classifier",
25
- description="A Hugging Face cross-encoder model for text classification.",
26
- ).launch()
 
1
  import gradio as gr
2
+ from sentence_transformers import SentenceTransformer, util
 
3
 
4
+ model_name = "cross-encoder/ms-marco-TinyBERT-L-6"
5
+ model = SentenceTransformer(model_name)
 
6
 
7
+ def classify_text(input_text):
8
+ premise = "The cat is on the mat."
9
+ input_embedding = model.encode([input_text, premise], convert_to_tensor=True)
10
+ similarity_score = util.pytorch_cos_sim(input_embedding[0], input_embedding[1])[0][0]
11
+ if similarity_score > 0.7:
12
+ return "entailment"
13
+ elif similarity_score < 0.3:
14
+ return "contradiction"
15
+ else:
16
+ return "neutral"
17
 
18
+ iface = gr.Interface(fn=classify_text, inputs="text", outputs="text", title="Cross-Encoder with SentenceTransformer")
19
+ iface.launch()
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,2 +1 @@
1
- transformers
2
- torch
 
1
+ sentence_transformers