"""Streamlit app that scores the perplexity of text with a per-language GPT-2 model.

The user enters text and picks English, German or French.  The app reports
the perplexity of the whole text, then of each sentence on its own.
"""
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import nltk
from nltk.tokenize import sent_tokenize
import streamlit as st

# st.cache_resource (Streamlit >= 1.18) keeps a single copy of each resource
# across script reruns.  On older Streamlit, fall back to calling the function
# uncached (the previous behaviour) instead of failing on a missing attribute.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _ensure_punkt():
    """Download the NLTK Punkt sentence-splitter data only if it is missing.

    Older NLTK releases use the "punkt" package; newer ones look for
    "punkt_tab" in sent_tokenize.  A download of a package the installed NLTK
    does not know is reported by nltk.download and does not raise.
    """
    for package in ("punkt", "punkt_tab"):
        try:
            nltk.data.find(f"tokenizers/{package}")
        except LookupError:
            nltk.download(package)


@_cache_resource
def load_model(model_id):
    """Return ``(tokenizer, model)`` for a Hugging Face causal-LM checkpoint."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return tokenizer, model


def perplexity(tokenizer, model, text, stride=None):
    """Return the perplexity of *text* under *model*, or ``None`` if undefined.

    Perplexity is exp(mean next-token negative log-likelihood).  Text that fits
    in the model's context window is scored in a single forward pass, giving
    the same value as ``exp(model(ids, labels=ids).loss)``.  Longer text is
    scored with an overlapping sliding window (default stride: half the
    context), so every token is predicted once and the model's position
    limit is never exceeded.

    Returns None when the text tokenizes to fewer than two tokens: then no
    token has preceding context to be predicted from, and the loss is NaN.
    """
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    seq_len = input_ids.size(1)
    if seq_len < 2:
        return None

    # GPT-2 style configs expose their context length via max_position_embeddings.
    max_len = getattr(model.config, "max_position_embeddings", None) or seq_len
    if stride is None:
        stride = max(1, max_len // 2)
    # A stride wider than the window would leave tokens that are never scored.
    stride = min(stride, max_len)

    nll_sum = 0.0
    n_scored = 0
    prev_end = 0
    with torch.no_grad():
        for begin in range(0, seq_len, stride):
            end = min(begin + max_len, seq_len)
            window = input_ids[:, begin:end]
            labels = window.clone()
            # Tokens before prev_end were already scored by an earlier window.
            # -100 makes the loss ignore them; they still act as context.
            labels[:, : max(0, prev_end - begin)] = -100
            # The model shifts labels by one internally, so window position 0
            # is never a target.  Count only the targets actually scored.
            n_targets = int((labels[:, 1:] != -100).sum())
            if n_targets:
                loss = model(window, labels=labels).loss
                nll_sum += float(loss) * n_targets
                n_scored += n_targets
            prev_end = end
            if end == seq_len:
                break

    if n_scored == 0:
        return None
    return float(torch.exp(torch.tensor(nll_sum / n_scored)))


def _show_perplexity(value):
    """Write a perplexity line, or a note when it is undefined for the input."""
    if value is None:
        st.write("Perplexity: ", "n/a (fewer than two tokens)")
    else:
        st.write("Perplexity: ", round(value, 2))


_ensure_punkt()

# Selectbox label -> (Hugging Face model id, NLTK Punkt language name).
LANGUAGES = {
    "English": ("gpt2", "english"),
    "German": ("dbmdz/german-gpt2", "german"),
    "French": ("asi/gpt-fr-cased-small", "french"),
}
MODELS = {label: load_model(model_id) for label, (model_id, _) in LANGUAGES.items()}

# Module-level names kept for anything that referenced them directly.
tokenizer_fr, model_fr = MODELS["French"]
tokenizer_en, model_en = MODELS["English"]
tokenizer_de, model_de = MODELS["German"]

with st.form(key='Form'):
    text = st.text_area("Enter text here.")
    option = st.selectbox('Select Language', ('English', 'German', 'French'))
    submitted = st.form_submit_button("Submit")

if submitted:
    # Join lines with a space, not '', so the words either side of a line
    # break are not glued into one word.
    text = text.replace('\n', ' ').strip()
    if not text:
        st.warning("Please enter some text.")
    else:
        tokenizer, model = MODELS[option]
        punkt_language = LANGUAGES[option][1]

        st.write("Entire Text")
        _show_perplexity(perplexity(tokenizer, model, text))

        for sentence in sent_tokenize(text, language=punkt_language):
            st.write("________________________")
            st.write(sentence)
            _show_perplexity(perplexity(tokenizer, model, sentence))