CatoEr commited on
Commit
435431a
1 Parent(s): 3bf30df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -34
app.py CHANGED
@@ -38,47 +38,39 @@ model_race.to(device)
38
  model_race.load_state_dict(torch.load('best_model_race.pt', map_location=torch.device('cpu')))
39
 
40
 
41
- def predict(text):
42
- sentences = [
43
- text
44
- ]
45
 
46
- tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", normalization=True)
47
 
48
- encoded_sentences = tokenizer(
49
- sentences,
50
- padding=True,
51
- truncation=True,
52
- return_tensors='pt',
53
- max_length=128,
54
- )
 
 
 
55
 
56
- input_ids = encoded_sentences["input_ids"].to(device)
57
- attention_mask = encoded_sentences["attention_mask"].to(device)
58
 
59
- model_race.eval()
60
- with torch.no_grad():
61
- outputs = model_race(input_ids, attention_mask)
62
- probs = torch.nn.functional.softmax(outputs, dim=1)
63
- predictions = torch.argmax(outputs, dim=1)
64
- predictions = predictions.cpu().numpy()
65
 
66
- output_string = ""
67
- for i, prob in enumerate(probs[0]):
68
- print(f"{labels[i]}: %{round(prob.item() * 100, 2)}")
69
- output_string += f"{labels[i]}: %{round(prob.item() * 100, 2)}\n"
70
 
71
- print(labels[predictions[0]])
72
- output_string += f"Predicted as: {labels[predictions[0]]}"
 
 
 
 
 
 
 
73
 
74
- return output_string
75
-
76
-
77
- demo = gr.Interface(
78
- fn=predict,
79
- inputs=["text"],
80
- outputs=["text"],
81
- )
82
 
83
  demo.launch()
84
 
 
 
38
  model_race.load_state_dict(torch.load('best_model_race.pt', map_location=torch.device('cpu')))
39
 
40
 
41
+ max_textboxes = 10
 
 
 
42
 
 
43
 
44
+ def update_textboxes(k):
45
+ components = []
46
+ if k is None:
47
+ k = 0
48
+ for i in range(max_textboxes):
49
+ if i < k:
50
+ components.append(gr.update(visible=True))
51
+ else:
52
+ components.append(gr.update(visible=False))
53
+ return components
54
 
 
 
55
 
56
+ def clear_textboxes():
57
+ return [gr.update(value='') for _ in range(max_textboxes)]
 
 
 
 
58
 
 
 
 
 
59
 
60
+ with gr.Blocks() as demo:
61
+ with gr.Row():
62
+ with gr.Column(scale=1):
63
+ s = gr.Slider(1, max_textboxes, value=1, step=1, label="How many tweets do you want to enter:")
64
+ textboxes = [gr.Textbox(label=f"Tweet {i + 1}", visible=(i == 0)) for i in range(max_textboxes)]
65
+ s.change(fn=update_textboxes, inputs=s, outputs=textboxes)
66
+ btn = gr.Button("Predict")
67
+ with gr.Column(scale=1):
68
+ output = gr.Textbox(label="Profile of User")
69
 
70
+ btn.click(fn=predict, inputs=textboxes, outputs=output)
71
+ btn_clear = gr.Button("Clear")
72
+ btn_clear.click(fn=clear_textboxes, outputs=textboxes)
 
 
 
 
 
73
 
74
  demo.launch()
75
 
76
+