ljyflores committed
Commit 851657f · 1 parent: ed8d715

Update app

Files changed (1)
  1. app.py +7 -53
app.py CHANGED
@@ -1,6 +1,6 @@
 import streamlit as st
 
-from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import pipeline
 
 dataset_example_dictionary = {
     "cochrane": [
@@ -33,28 +33,9 @@ def load(dataset_name, model_variant_name):
         model=model_dictionary[dataset_name][model_variant_name]
     )
 
+@st.cache_data()
 def predict(text, pipeline):
-    return pipeline(text, max_length=768)
-
-# @st.cache_resource
-# def load(dataset_name, model_variant_name):
-#     tokenizer = AutoTokenizer.from_pretrained(model_dictionary[dataset_name][model_variant_name])
-#     model = AutoModelForSeq2SeqLM.from_pretrained(model_dictionary[dataset_name][model_variant_name])
-#     return pipeline("text2text-generation", model="ljyflores/bart_xsum_cochrane_finetune")
-
-# def encode(text, _tokenizer):
-#     """This function takes a batch of samples,
-#     and tokenizes them into IDs for the model."""
-#     # Tokenize the Findings (the input)
-#     model_inputs = _tokenizer(
-#         [text], padding=True, truncation=True, return_tensors="pt"
-#     )
-#     return model_inputs
-
-# def predict(text, model, tokenizer):
-#     model_inputs = encode(text, tokenizer)
-#     model_outputs = model.generate(**model_inputs, max_length=768).detach()
-#     return tokenizer.batch_decode(model_outputs)
+    return pipeline(text, max_length=768, do_sample=False)
 
 def clean(s):
     return s.replace("<s>","").replace("</s>","")
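
Note on the new decorator: st.cache_data builds its cache key by hashing a cached function's arguments, and a transformers pipeline object is not hashable by Streamlit's default hasher, so the predict() above is likely to raise Streamlit's UnhashableParamError when called. A minimal sketch of the documented workaround, with the parameter renamed to a hypothetical "_pipeline" (the leading underscore tells Streamlit to skip hashing that argument):

import streamlit as st

# Sketch only: "_pipeline" is a hypothetical rename of the committed
# parameter. st.cache_data skips hashing arguments whose names start
# with an underscore, which avoids UnhashableParamError for pipelines.
@st.cache_data()
def predict(text, _pipeline):
    return _pipeline(text, max_length=768, do_sample=False)
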
@@ -77,38 +58,11 @@ st.text_area("Text to Simplify:", key="text", height=275)
 
 # Load model and run inference
 if st.button("Simplify!"):
-    # # Number 1
-    # # tokenizer_baseline, model_baseline = load(dataset_option, "baseline")
-    # # model_outputs_baseline = predict(st.session_state.text, model_baseline, tokenizer_baseline)[0]
-
-    # pipeline_baseline = load(dataset_option, "baseline")
-    # # model_outputs_baseline = predict(st.session_state.text, pipeline_baseline)[0]["generated_text"]
 
-    # # pipeline_baseline = pipeline(
-    # #     "text2text-generation",
-    # #     model=model_dictionary[dataset_option]["baseline"]
-    # # )
-    # model_outputs_baseline = pipeline_baseline(
-    #     st.session_state.text,
-    #     max_length=768,
-    #     do_sample=False
-    # )
-    # st.write(f"Baseline: {clean(model_outputs_baseline)}")
+    pipeline_baseline = load(dataset_option, "baseline")
+    model_outputs_baseline = predict(st.session_state.text, pipeline_baseline)[0]["generated_text"]
+    st.write(f"Baseline: {clean(model_outputs_baseline)}")
 
-    # # Number 2
-    # tokenizer_ul, model_ul = load(dataset_option, "ul")
-    # model_outputs_ul = predict(st.session_state.text, model_ul, tokenizer_ul)[0]
-
     pipeline_ul = load(dataset_option, "ul")
-    # model_outputs_ul = predict(st.session_state.text, pipeline_ul)[0]["generated_text"]
-
-    # pipeline_ul = pipeline(
-    # #     "text2text-generation",
-    # #     model=model_dictionary[dataset_option]["ul"]
-    # )
-    model_outputs_ul = pipeline_ul(
-        st.session_state.text,
-        max_length=768,
-        do_sample=False
-    )
+    model_outputs_ul = predict(st.session_state.text, pipeline_ul)[0]["generated_text"]
     st.write(f"Unlikelihood Learning: {clean(model_outputs_ul)}")