diacritizeTR / app.py
emircanerol's picture
Update app.py
43a4b35 verified
raw
history blame contribute delete
No virus
2.18 kB
import gradio as gr
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForTokenClassification
def test_mask(model, sample):
    """
    Encodes a string into the byte-level input tensors the model consumes.

    Despite its name, nothing is masked out: the result holds a single
    unpadded sequence, so every position (including the trailing end token)
    gets an attention-mask value of 1.

    Args:
        model (torch.nn.Module): Model whose ``device`` attribute decides
            where the tensors are allocated.
        sample (str): Text to encode.
    Returns:
        tokens (dict): ``{'input_ids': ..., 'attention_mask': ...}``, both
            int64 tensors of shape ``(1, len(sample.encode('utf-8')) + 1)``.
    """
    # Each UTF-8 byte is shifted by 3, presumably reserving ids 0-2 for
    # special tokens (ByT5 convention).
    input_ids = [byte + 3 for byte in sample.encode('utf-8')]
    # Terminating token; ``rewrite`` turns it into a negative value and drops it.
    # NOTE(review): in the stock ByT5 vocabulary id 0 is <pad> and id 1 is
    # </s>. 0 is kept here, presumably to match how the adapter was trained;
    # confirm before changing it.
    input_ids.append(0)
    ids = torch.tensor([input_ids], dtype=torch.int64, device=model.device)
    return {
        'input_ids': ids,
        # ones_like inherits dtype (int64) and device from ``ids``.
        'attention_mask': torch.ones_like(ids),
    }


# The name starts with ``test_`` but this is not a test: stop pytest from
# collecting it (it would fail looking for ``model``/``sample`` fixtures).
test_mask.__test__ = False
# Maps the byte value of each ASCII letter to the UTF-8 bytes of its Turkish
# diacritic counterpart, e.g. ord('U') == 85 -> list('Ü'.encode()) == [195, 156].
# Note the dotted/dotless pair: 'I' -> 'İ' and 'i' -> 'ı'.
_EN2TR = {
    ord(ascii_char): list(turkish_char.encode('utf-8'))
    for ascii_char, turkish_char in zip('UIGSCOuigsco', 'ÜİĞŞÇÖüığşçö')
}


def rewrite(model, data):
    """
    Restores Turkish diacritics in a byte-encoded sample using the model.

    The model yields one label per input position (argmax over the last
    logits dimension). A non-zero label on one of the ASCII letters in
    ``_EN2TR`` replaces that byte with the Turkish letter's UTF-8 bytes; all
    other bytes are copied through unchanged.

    Args:
        model (torch.nn.Module): Token-classification model, called as
            ``model(**data)``; its output must expose ``logits``.
        data (dict): ``'input_ids'`` and ``'attention_mask'`` tensors with
            byte ids shifted by +3. ``squeeze(0)`` below assumes a batch of one.
    Returns:
        output (str): Rewritten text.
    """
    with torch.no_grad():
        logits = model(**data).logits
        labels = torch.argmax(logits, dim=2).squeeze(0).tolist()
    # Undo the +3 shift; ids below 3 (e.g. the trailing end token) go negative.
    byte_values = (data['input_ids'].squeeze(0) - 3).tolist()

    output = []  # raw byte values of the rewritten text
    for value, label in zip(byte_values, labels):
        if value < 0:
            continue  # special token, not part of the text
        if label and value in _EN2TR:
            output.extend(_EN2TR[value])
        else:
            output.append(value)
    return bytes(output).decode('utf-8')
def try_it(text):
    """
    Gradio callback: returns *text* with Turkish diacritics restored.

    Reads the module-level ``model`` global, which is only assigned when
    this file is executed as a script (inside the ``__main__`` guard).
    """
    encoded = test_mask(model, text)
    restored = rewrite(model, encoded)
    return restored
if __name__ == '__main__':
    # Hub id of the fine-tuned PEFT adapter applied on top of the base model.
    # No separate PeftConfig load is needed: PeftModel.from_pretrained reads
    # the adapter config itself.
    adapter_id = "bite-the-byte/byt5-small-deASCIIfy-TR"
    base_model = AutoModelForTokenClassification.from_pretrained("google/byt5-small")
    # ``model`` must remain a module-level global: ``try_it`` reads it.
    model = PeftModel.from_pretrained(base_model, adapter_id)
    diacritize_app = gr.Interface(fn=try_it, inputs="text", outputs="text")
    diacritize_app.launch(share=True)