import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# A Gradio app for English <-> Hebrew translation that also shows a reverse
# (round-trip) translation of the result as a rough quality check.
# The forward and reverse models are Helsinki-NLP/opus-mt-en-he and
# Helsinki-NLP/opus-mt-tc-big-he-en, picked according to the selected direction.

# A dropdown menu for selecting the translation direction; the selected label
# is the value passed to translate() below.
model_name = gr.Dropdown(choices=["English to Hebrew", "Hebrew to English"], label="Model")

# Output text boxes for the translated text and the reverse (round-trip) translation
translation = gr.Textbox(label="Translation")
reverse_translation = gr.Textbox(label="Reverse Translation")

# Load both tokenizer/model pairs once at startup so translate() does not have
# to reload the weights on every request.
en_he_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-he")
en_he_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-he")
he_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-tc-big-he-en")
he_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-tc-big-he-en")

# Translate the input text in the selected direction, then translate the result
# back with the opposite model pair as a round-trip check.
def translate(model_name, text):
    # Pick the tokenizer/model pair for each direction
    if model_name == "English to Hebrew":
        forward_tokenizer, forward_model = en_he_tokenizer, en_he_model
        reverse_tokenizer, reverse_model = he_en_tokenizer, he_en_model
    elif model_name == "Hebrew to English":
        forward_tokenizer, forward_model = he_en_tokenizer, he_en_model
        reverse_tokenizer, reverse_model = en_he_tokenizer, en_he_model
    else:
        raise ValueError("Invalid model name")
    
    # Forward translation
    forward_input_ids = forward_tokenizer.encode(text, return_tensors="pt")
    forward_outputs = forward_model.generate(forward_input_ids)
    forward_translation = forward_tokenizer.decode(forward_outputs[0], skip_special_tokens=True)

    # Reverse translation
    reverse_input_ids = reverse_tokenizer.encode(forward_translation, return_tensors="pt")
    reverse_outputs = reverse_model.generate(reverse_input_ids)
    reverse_translation = reverse_tokenizer.decode(reverse_outputs[0], skip_special_tokens=True)
    
    return forward_translation, reverse_translation
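
# Optional local smoke test, left commented out so it does not slow down startup
# of the hosted app; calling translate() directly returns a (forward, reverse)
# pair of strings:
# print(translate("English to Hebrew", "Good morning"))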

# Build the interface: direction dropdown and input text in, forward and reverse translations out
iface = gr.Interface(fn=translate, inputs=[model_name, "text"], outputs=[translation, reverse_translation])

# Launch the interface
iface.launch(share=False)
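
# Once the app is running, it can also be queried programmatically. A minimal
# sketch using the separate gradio_client package (an assumption: the package
# must be installed, and the URL is whatever iface.launch() prints):
#
# from gradio_client import Client
# client = Client("http://127.0.0.1:7860/")
# forward, reverse = client.predict("English to Hebrew", "Good morning", api_name="/predict")
# print(forward, reverse)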