emircanerol commited on
Commit
a17b609
1 Parent(s): 1ca08a7

Add application file

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from peft import PeftModel, PeftConfig
4
+ from transformers import AutoModelForTokenClassification
5
+
6
+ def test_mask(model, sample):
7
+ """
8
+ Masks the padded tokens in the input.
9
+ Args:
10
+ data (list): List of strings.
11
+ Returns:
12
+ dataset (list): List of dictionaries.
13
+ """
14
+
15
+ tokens = dict()
16
+
17
+ input_tokens = [i + 3 for i in sample.encode('utf-8')]
18
+ input_tokens.append(0) # eos token
19
+ tokens['input_ids'] = torch.tensor([input_tokens], dtype=torch.int64, device=model.device)
20
+
21
+ # Create attention mask
22
+ tokens['attention_mask'] = torch.ones_like(tokens['input_ids'], dtype=torch.int64, device=model.device)
23
+
24
+ return tokens
25
+
26
+ def rewrite(model, data):
27
+ """
28
+ Rewrites the input text with the model.
29
+ Args:
30
+ model (torch.nn.Module): Model.
31
+ data (dict): Dictionary containing 'input_ids' and 'attention_mask'.
32
+ Returns:
33
+ output (str): Rewritten text.
34
+ """
35
+
36
+ with torch.no_grad():
37
+ pred = torch.argmax(model(**data).logits, dim=2).squeeze(0)
38
+
39
+ output = list() # save the indices of the characters as list of integers
40
+
41
+ # Conversion table for Turkish characters {100: [300, 350], ...}
42
+ en2tr = {en: tr for tr, en in zip(list(map(list, map(str.encode, list('ÜİĞŞÇÖüığşçö')))), list(map(ord, list('UIGSCOuigsco'))))}
43
+
44
+ for inp, lab in zip((data['input_ids'].squeeze(0) - 3).tolist(), pred.tolist()):
45
+ if lab and inp in en2tr:
46
+ # if the model predicts a diacritic, replace it with the corresponding Turkish character
47
+ output.extend(en2tr[inp])
48
+ elif inp >= 0: output.append(inp)
49
+ return bytes(output).decode()
50
+
51
+ def try_it(text):
52
+ sample = test_mask(model, text)
53
+ return rewrite(model, sample)
54
+
55
+
56
+ if __name__ == '__main__':
57
+ config = PeftConfig.from_pretrained("bite-the-byte/byt5-small-deASCIIfy-TR")
58
+ model = AutoModelForTokenClassification.from_pretrained("google/byt5-small")
59
+ model = PeftModel.from_pretrained(model, "bite-the-byte/byt5-small-deASCIIfy-TR")
60
+
61
+ diacritize_app = gr.Interface(fn=try_it, inputs="text", outputs="text")
62
+ diacritize_app.launch(share=True)