Intradiction commited on
Commit
3a5b288
1 Parent(s): 49d3c2c

Update SA LORA pipe with newly trained model

Browse files

Now uses bert-base-uncased trained on 3k samples, better scores than previous

Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -1,16 +1,16 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoTokenizer, AutoModel
3
  from peft.auto import AutoPeftModelForSequenceClassification
4
  from tensorboard.backend.event_processing import event_accumulator
5
  from peft import PeftModel
6
  import plotly.express as px
7
  import pandas as pd
8
 
9
- tokenizer1 = AutoTokenizer.from_pretrained("albert-base-v2")
10
 
11
- loraModel = AutoPeftModelForSequenceClassification.from_pretrained("Intradiction/text_classification_WithLORA")
12
- tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
13
 
 
 
 
14
  tokenizer2 = AutoTokenizer.from_pretrained("microsoft/deberta-v3-xsmall")
15
  # base_model = AutoModel.from_pretrained("microsoft/deberta-v3-xsmall")
16
  # peft_model_id = "rajevan123/STS-Lora-Fine-Tuning-Capstone-Deberta-small"
@@ -19,16 +19,22 @@ tokenizer2 = AutoTokenizer.from_pretrained("microsoft/deberta-v3-xsmall")
19
 
20
 
21
  # Handle calls to DistilBERT------------------------------------------
 
 
 
 
 
 
22
  distilBERTUntrained_pipe = pipeline("sentiment-analysis", model="bert-base-uncased")
23
  distilBERTnoLORA_pipe = pipeline(model="Intradiction/text_classification_NoLORA")
24
- distilBERTwithLORA_pipe = pipeline("sentiment-analysis", model=loraModel, tokenizer=tokenizer)
25
 
26
  #text class models
27
  def distilBERTnoLORA_fn(text):
28
  return distilBERTnoLORA_pipe(text)
29
 
30
  def distilBERTwithLORA_fn(text):
31
- return distilBERTwithLORA_pipe(text)
32
 
33
  def distilBERTUntrained_fn(text):
34
  return distilBERTUntrained_pipe(text)
@@ -425,7 +431,7 @@ with gr.Blocks(
425
  btnSTSStats.click(fn=displayMetricStatsTextSTSNoLora, outputs=STSNoLoraStats)
426
  btnSTSStats.click(fn=displayMetricStatsTextSTSLora, outputs=STSLoraStats)
427
 
428
- with gr.Tab("More informatioen"):
429
  gr.Markdown("stuff to add")
430
 
431
 
 
1
  import gradio as gr
2
+ from transformers import pipeline, AutoTokenizer, AutoModel, BertForSequenceClassification
3
  from peft.auto import AutoPeftModelForSequenceClassification
4
  from tensorboard.backend.event_processing import event_accumulator
5
  from peft import PeftModel
6
  import plotly.express as px
7
  import pandas as pd
8
 
 
9
 
 
 
10
 
11
+ loraModel = AutoPeftModelForSequenceClassification.from_pretrained("Intradiction/text_classification_WithLORA")
12
+ #tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
13
+ tokenizer1 = AutoTokenizer.from_pretrained("albert-base-v2")
14
  tokenizer2 = AutoTokenizer.from_pretrained("microsoft/deberta-v3-xsmall")
15
  # base_model = AutoModel.from_pretrained("microsoft/deberta-v3-xsmall")
16
  # peft_model_id = "rajevan123/STS-Lora-Fine-Tuning-Capstone-Deberta-small"
 
19
 
20
 
21
  # Handle calls to DistilBERT------------------------------------------
22
+ base_model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
23
+ peft_model_id = "Intradiction/BERT-SA-LORA"
24
+ model = PeftModel.from_pretrained(model=base_model, model_id=peft_model_id)
25
+ sa_merged_model = model.merge_and_unload()
26
+ bbu_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
27
+
28
  distilBERTUntrained_pipe = pipeline("sentiment-analysis", model="bert-base-uncased")
29
  distilBERTnoLORA_pipe = pipeline(model="Intradiction/text_classification_NoLORA")
30
+ SentimentAnalysis_LORA_pipe = pipeline("sentiment-analysis", model=sa_merged_model, tokenizer=bbu_tokenizer)
31
 
32
  #text class models
33
  def distilBERTnoLORA_fn(text):
34
  return distilBERTnoLORA_pipe(text)
35
 
36
  def distilBERTwithLORA_fn(text):
37
+ return SentimentAnalysis_LORA_pipe(text)
38
 
39
  def distilBERTUntrained_fn(text):
40
  return distilBERTUntrained_pipe(text)
 
431
  btnSTSStats.click(fn=displayMetricStatsTextSTSNoLora, outputs=STSNoLoraStats)
432
  btnSTSStats.click(fn=displayMetricStatsTextSTSLora, outputs=STSLoraStats)
433
 
434
+ with gr.Tab("More information"):
435
  gr.Markdown("stuff to add")
436
 
437