"""Gradio demo that serves the ``ksabeh/gavi`` T5 checkpoint.

The app exposes two text boxes (a title and a category), feeds them to the
model as a single prompt and shows the decoded generation.
"""

import os

import gradio as gr
import torch  # noqa: F401  (backend required by return_tensors="pt")
from transformers import AutoTokenizer, T5ForConditionalGeneration

model_id = 'ksabeh/gavi'
max_input_length = 512   # prompts longer than this many tokens are truncated
max_target_length = 10   # upper bound passed to generate() as max_length

# Hub access token taken from the TOKEN environment variable; None if unset.
auth_token = os.environ.get('TOKEN')

# Loaded once at import time so every request reuses the same objects.
model = T5ForConditionalGeneration.from_pretrained(
    model_id, use_auth_token=auth_token
)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)


def predict(title: str, category: str) -> str:
    """Generate the model's output for a title/category pair.

    Args:
        title: Free-text title entered in the first text box.
        category: Free-text category entered in the second text box.

    Returns:
        The first decoded sequence produced by the model, with special
        tokens stripped.
    """
    # Treat a missing value as empty text instead of interpolating "None".
    title = "" if title is None else title
    category = "" if category is None else category

    prompt = f"{title} {category} "

    # return_tensors="pt" yields batched (1, seq_len) tensors directly, and a
    # single sequence needs no padding, so the encoder only processes the
    # real tokens rather than a 512-token padded input.
    encoded = tokenizer(
        prompt,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    )

    # NOTE(review): do_sample=True together with num_beams=8 makes the output
    # non-deterministic between calls - confirm this is intended for the demo.
    generated = model.generate(
        **encoded,
        num_beams=8,
        do_sample=True,
        max_length=max_target_length,
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]


iface = gr.Interface(
    predict,
    inputs=["text", "text"],
    outputs=['text'],
    title="GAVI",
    examples=[
        ["Arriba Salsa Garlic and Cilantro, 16 oz", "Food"],
        [
            "MV Verholen Black GPS Ball Mount for BMW K1200S K1200R K1300S K1300R Black GPS Ball Mount VER-4901-10181",
            "Toys",
        ],
    ],
)

# Start the web server only when run as a script, so importing this module
# (e.g. to reuse predict) does not block on launch().
if __name__ == "__main__":
    iface.launch()