IBounhas committed
Commit d17a9c4 · 1 Parent(s): d445114

Update app.py

Files changed (1)
  1. app.py +66 -65
app.py CHANGED
@@ -1,54 +1,54 @@
- import gradio as gr
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
- import torch
- from sentence_transformers import SentenceTransformer, models
- param_max_length = 256
-
- # Define a function that takes a text input and returns the result
- def analyze_text(input):
-     # Your processing or model inference code here
-     result = predict_similarity(input)
-     return result
-
- param_model_name = "CAMeL-Lab/bert-base-arabic-camelbert-msa-sixteenth"
-
- tokenizer = AutoTokenizer.from_pretrained(param_model_name)
-
- class BertForSTS(torch.nn.Module):
-
-     def __init__(self):
-         super(BertForSTS, self).__init__()
-         #self.bert = models.Transformer('bert-base-uncased', max_seq_length=128)
-         #self.bert = AutoModelForSequenceClassification.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-msa-sixteenth")
-         self.bert = models.Transformer(param_model_name, max_seq_length=param_max_length)
-
-         dimension = self.bert.get_word_embedding_dimension()
-         #print(dimension)
-         self.pooling_layer = models.Pooling(dimension)
-         self.dropout = torch.nn.Dropout(0.1)
-
-         # relu activation function
-         self.relu = torch.nn.ReLU()
-
-         # dense layer 1
-         self.fc1 = torch.nn.Linear(dimension, 512)
-
-         # dense layer 2 (Output layer)
-         self.fc2 = torch.nn.Linear(512, 512)
-         #self.pooling_layer = models.Pooling(self.bert.config.hidden_size)
-         self.sts_bert = SentenceTransformer(modules=[self.bert, self.pooling_layer, self.fc1])
-         #self.sts_bert = SentenceTransformer(modules=[self.bert, self.pooling_layer, self.fc1, self.relu, self.dropout, self.fc2])
-     def forward(self, input_data):
-         #print(input_data)
-         x = self.bert(input_data)
-         x = self.pooling_layer(x)
-         x = self.fc1(x['sentence_embedding'])
-         x = self.relu(x)
-         x = self.dropout(x)
-         #x = self.fc2(x)
-
-         return x
-
  import requests
 
@@ -57,30 +57,31 @@ response = requests.get(file_url)
 
  with open("model.pt", "wb") as f:
      f.write(response.content)
      f.close()
 
- model_load_path = "model.pt"
- model = BertForSTS()
- model.load_state_dict(torch.load(model_load_path))
- model.to(device)
-
- def predict_similarity(sentence_pair):
-     test_input = tokenizer(sentence_pair, padding='max_length', max_length=param_max_length, truncation=True, return_tensors="pt").to(device)
-     test_input['input_ids'] = test_input['input_ids']
-     print(test_input['input_ids'])
-     test_input['attention_mask'] = test_input['attention_mask']
-     del test_input['token_type_ids']
-     output = model(test_input)
-     sim = torch.nn.functional.cosine_similarity(output[0], output[1], dim=0).item()*2-1
-
-     return sim
-
- # Create a Gradio interface with a text input zone
- iface = gr.Interface(
-     fn=analyze_text,  # The function to be called with user input
-     inputs=[gr.Textbox(), gr.Textbox()],
-     outputs="text"  # Display the result as text
- )
-
- # # Launch the Gradio interface
- iface.launch()
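
The removed version above carries a few latent bugs: `device` is used but never defined, `analyze_text` accepts a single argument even though the Gradio interface supplies two textboxes, dropout stays active at inference, and the cosine score is rescaled with `*2-1`, which maps the natural [-1, 1] range onto [-3, 1]. The following is a minimal corrected sketch of the same pipeline, assuming the checkpoint has already been downloaded to model.pt; the rescaling is dropped and the raw cosine similarity returned.

import gradio as gr
import torch
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer, models

param_model_name = "CAMeL-Lab/bert-base-arabic-camelbert-msa-sixteenth"
param_max_length = 256

# Define the device explicitly; the removed code referenced an undefined `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(param_model_name)

class BertForSTS(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = models.Transformer(param_model_name, max_seq_length=param_max_length)
        dimension = self.bert.get_word_embedding_dimension()
        self.pooling_layer = models.Pooling(dimension)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(0.1)
        self.fc1 = torch.nn.Linear(dimension, 512)
        # fc2 and sts_bert are unused in forward() but kept so the
        # saved checkpoint's state_dict keys still match this class.
        self.fc2 = torch.nn.Linear(512, 512)
        self.sts_bert = SentenceTransformer(modules=[self.bert, self.pooling_layer, self.fc1])

    def forward(self, input_data):
        x = self.bert(input_data)             # adds token embeddings to the feature dict
        x = self.pooling_layer(x)             # adds the pooled sentence embedding
        x = self.fc1(x['sentence_embedding'])
        x = self.relu(x)
        x = self.dropout(x)
        return x

model = BertForSTS()
model.load_state_dict(torch.load("model.pt", map_location=device))
model.to(device)
model.eval()  # disables dropout at inference time

def predict_similarity(sentence_pair):
    # Tokenize both sentences as one batch of two.
    test_input = tokenizer(sentence_pair, padding='max_length',
                           max_length=param_max_length, truncation=True,
                           return_tensors="pt").to(device)
    test_input.pop('token_type_ids', None)
    with torch.no_grad():
        output = model(test_input)
    # Raw cosine similarity between the two sentence embeddings, in [-1, 1].
    return torch.nn.functional.cosine_similarity(output[0], output[1], dim=0).item()

# Accept two inputs, one per Textbox, instead of a single argument.
def analyze_text(sentence_1, sentence_2):
    return predict_similarity([sentence_1, sentence_2])

iface = gr.Interface(
    fn=analyze_text,
    inputs=[gr.Textbox(), gr.Textbox()],
    outputs="text",
)
iface.launch()

The new revision, shown below, takes a different route: it comments out the entire app and keeps only the checkpoint download, adding a debug print of the response body.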
 
+ # import gradio as gr
+ # from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ # import torch
+ # from sentence_transformers import SentenceTransformer, models
+ # param_max_length = 256
+
+ # # Define a function that takes a text input and returns the result
+ # def analyze_text(input):
+ #     # Your processing or model inference code here
+ #     result = predict_similarity(input)
+ #     return result
+
+ # param_model_name = "CAMeL-Lab/bert-base-arabic-camelbert-msa-sixteenth"
+
+ # tokenizer = AutoTokenizer.from_pretrained(param_model_name)
+
+ # class BertForSTS(torch.nn.Module):
+
+ #     def __init__(self):
+ #         super(BertForSTS, self).__init__()
+ #         #self.bert = models.Transformer('bert-base-uncased', max_seq_length=128)
+ #         #self.bert = AutoModelForSequenceClassification.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-msa-sixteenth")
+ #         self.bert = models.Transformer(param_model_name, max_seq_length=param_max_length)
+
+ #         dimension = self.bert.get_word_embedding_dimension()
+ #         #print(dimension)
+ #         self.pooling_layer = models.Pooling(dimension)
+ #         self.dropout = torch.nn.Dropout(0.1)
+
+ #         # relu activation function
+ #         self.relu = torch.nn.ReLU()
+
+ #         # dense layer 1
+ #         self.fc1 = torch.nn.Linear(dimension, 512)
+
+ #         # dense layer 2 (Output layer)
+ #         self.fc2 = torch.nn.Linear(512, 512)
+ #         #self.pooling_layer = models.Pooling(self.bert.config.hidden_size)
+ #         self.sts_bert = SentenceTransformer(modules=[self.bert, self.pooling_layer, self.fc1])
+ #         #self.sts_bert = SentenceTransformer(modules=[self.bert, self.pooling_layer, self.fc1, self.relu, self.dropout, self.fc2])
+ #     def forward(self, input_data):
+ #         #print(input_data)
+ #         x = self.bert(input_data)
+ #         x = self.pooling_layer(x)
+ #         x = self.fc1(x['sentence_embedding'])
+ #         x = self.relu(x)
+ #         x = self.dropout(x)
+ #         #x = self.fc2(x)
+
+ #         return x

  import requests
 
 
  with open("model.pt", "wb") as f:
      f.write(response.content)
+     print(response.content)
      f.close()
 
+ # model_load_path = "model.pt"
+ # model = BertForSTS()
+ # model.load_state_dict(torch.load(model_load_path))
+ # model.to(device)
+
+ # def predict_similarity(sentence_pair):
+ #     test_input = tokenizer(sentence_pair, padding='max_length', max_length=param_max_length, truncation=True, return_tensors="pt").to(device)
+ #     test_input['input_ids'] = test_input['input_ids']
+ #     print(test_input['input_ids'])
+ #     test_input['attention_mask'] = test_input['attention_mask']
+ #     del test_input['token_type_ids']
+ #     output = model(test_input)
+ #     sim = torch.nn.functional.cosine_similarity(output[0], output[1], dim=0).item()*2-1
+
+ #     return sim
+
+ # # Create a Gradio interface with a text input zone
+ # iface = gr.Interface(
+ #     fn=analyze_text,  # The function to be called with user input
+ #     inputs=[gr.Textbox(), gr.Textbox()],
+ #     outputs="text"  # Display the result as text
+ # )
+
+ # # # Launch the Gradio interface
+ # iface.launch()
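
Printing `response.content` dumps the entire checkpoint (or an HTML error page) as raw bytes into the logs, which suggests the added line is meant to debug a failed download. A lighter-weight check is to inspect the HTTP status and headers instead; a minimal sketch, using a hypothetical placeholder URL since the real `file_url` is set in lines this diff does not show:

import requests

# Hypothetical placeholder; the actual file_url is defined outside the hunks above.
file_url = "https://example.com/model.pt"

response = requests.get(file_url)
# Fail loudly on HTTP errors instead of silently writing an error page to model.pt.
response.raise_for_status()
print(response.status_code,
      response.headers.get("Content-Type"),
      len(response.content), "bytes")

with open("model.pt", "wb") as f:
    f.write(response.content)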