VocabLine / ai_sentence.py
from transformers import AutoModelForCausalLM, AutoTokenizer

# Models available for sentence generation.
MODEL_LIST = [
    "EleutherAI/pythia-410m",
    "EleutherAI/pythia-1b",
    "mistralai/Mistral-7B-Instruct"
]

# Cache of loaded (tokenizer, model) pairs, keyed by model name.
MODEL_CACHE = {}

def load_model(model_name):
    """Load the tokenizer and model for `model_name`, caching them for reuse."""
    if model_name not in MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        MODEL_CACHE[model_name] = (tokenizer, model)
    return MODEL_CACHE[model_name]

def generate_sentence(word, model_name):
    """Generate a short example sentence containing `word` using the chosen model."""
    tokenizer, model = load_model(model_name)
    prompt = f"A simple English sentence with the word '{word}':"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=30)
    # Note: the decoded text includes the prompt followed by the generated continuation.
    sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return sentence
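
# Usage sketch (not part of the original file): running the module directly
# generates one example sentence with the smallest listed model. The word
# "apple" and the choice of model are illustrative assumptions only.
if __name__ == "__main__":
    example = generate_sentence("apple", "EleutherAI/pythia-410m")
    print(example)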