ibraheemmoosa committed on
Commit
a2dbd3d
·
verified ·
1 Parent(s): cae2767

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -1
app.py CHANGED
@@ -1,8 +1,44 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
def predict(source, translation1, translation2):
    """Return placeholder preference scores for two candidate translations.

    The prompt string is assembled from the three inputs but is not used
    yet; the scores returned are fixed.

    Args:
        source: The source sentence.
        translation1: First candidate translation.
        translation2: Second candidate translation.

    Returns:
        A dict mapping 'Translation 1' and 'Translation 2' to constant
        scores (0.70 and 0.30).
    """
    prompt_template = "Source: {} Translation 0: {} Translation 1: {}"
    # Built for parity with the real model input format; currently unused.
    model_input = prompt_template.format(source, translation1, translation2)

    scores = {}
    scores['Translation 1'] = 0.70
    scores['Translation 2'] = 0.30
    return scores
 
 
 
 
 
6
 
7
  source_textbox = gr.Textbox(label="Source", info="Source Sentence", value="Le chat est sur la tapis.")
8
  translation1_textbox = gr.Textbox(label="Translation 1", info="Translation 1", value="The cat is on the bed.")
 
import gradio as gr
import torch
from transformers import (
    AutoModel,
    AutoTokenizer,
    MT5EncoderModel,
    PretrainedConfig,
    PreTrainedModel,
)
3
+
class MTRankerConfig(PretrainedConfig):
    """Configuration for ``MTRanker``.

    Stores the name of the backbone checkpoint in ``backbone``; every other
    keyword argument is forwarded to ``PretrainedConfig``.
    """

    def __init__(self, backbone: str = 'google/mt5-base', **kwargs) -> None:
        """Record the backbone checkpoint name, then run the base initializer.

        Args:
            backbone: Checkpoint identifier of the encoder backbone.
            **kwargs: Passed through unchanged to ``PretrainedConfig``.
        """
        # Set the custom attribute before delegating, as the original did.
        self.backbone = backbone
        super().__init__(**kwargs)
9
+
10
+
11
+
class MTRanker(PreTrainedModel):
    """Two-class ranker: MT5 encoder, masked mean pooling, linear classifier.

    ``forward`` returns one row of ``num_classes`` (2) logits per input row.
    """

    config_class = MTRankerConfig

    def __init__(self, config):
        """Build the encoder named by ``config.backbone`` and the classifier head.

        Args:
            config: An ``MTRankerConfig``; only ``config.backbone`` is read here.
        """
        super().__init__(config)
        self.encoder = MT5EncoderModel.from_pretrained(config.backbone)
        self.num_classes = 2
        self.classifier = torch.nn.Linear(
            self.encoder.config.hidden_size, self.num_classes
        )

    def forward(self, input_ids, attention_mask):
        """Compute classification logits for a batch of tokenized inputs.

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: 1 for real tokens and 0 for padding, shape
                (batch, seq_len).

        Returns:
            Tensor of shape (batch, num_classes) with unnormalized logits.
        """
        hidden_states = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        # Broadcasting over the hidden dimension replaces the explicit expand.
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = torch.sum(hidden_states * mask, dim=1)

        # Count tokens on the integer mask so the count stays exact, and clamp
        # to 1 so an all-padding row yields zeros instead of 0/0 = NaN.
        token_counts = torch.sum(attention_mask, dim=1, keepdim=True).clamp(min=1)

        pooled_hidden_state = summed / token_counts
        return self.classifier(pooled_hidden_state)
28
+
29
+
# The config is created only to name the backbone whose tokenizer is loaded;
# the ranker itself is loaded from the 'ibraheemmoosa/mt-ranker-base' checkpoint.
config = MTRankerConfig(backbone='google/mt5-base')
tokenizer = AutoTokenizer.from_pretrained(config.backbone)
model = MTRanker.from_pretrained('ibraheemmoosa/mt-ranker-base')
33
 
def predict(source, translation1, translation2):
    """Score which of two candidate translations the ranker prefers.

    Formats the three strings into the ranker prompt, runs the module-level
    ``model`` on the tokenized prompt, and converts the two logits into
    probabilities.

    Args:
        source: The source sentence.
        translation1: First candidate translation.
        translation2: Second candidate translation.

    Returns:
        A dict mapping 'Translation 1' and 'Translation 2' to plain Python
        floats (softmax probabilities over the two logits).
    """
    model_input = "Source: {} Translation 0: {} Translation 1: {}".format(
        source, translation1, translation2
    )
    inputs = tokenizer(
        [model_input],
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )

    # Run on whatever device the model lives on; autocast needs its type.
    device = next(model.parameters()).device
    inputs = inputs.to(device)

    # Inference only, so skip autograd bookkeeping. `torch.autocast` replaces
    # the bare `autocast` name, which was never imported.
    with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16):
        logits = model(inputs.input_ids, inputs.attention_mask)

    # Single-item batch: take row 0 of the (1, 2) probability matrix.
    output_scores = torch.softmax(logits.float(), dim=1)[0]

    # Return plain floats rather than 0-dim tensors for the UI layer.
    return {
        'Translation 1': output_scores[0].item(),
        'Translation 2': output_scores[1].item(),
    }
42
 
43
  source_textbox = gr.Textbox(label="Source", info="Source Sentence", value="Le chat est sur la tapis.")
44
  translation1_textbox = gr.Textbox(label="Translation 1", info="Translation 1", value="The cat is on the bed.")