import time

import gradio as gr
import huggingface_hub
import torch
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

import globals
from utility import api_query, fetch_model, post_process

# Set up: authenticate with the Hugging Face Hub and preload the local models.
huggingface_hub.login(token=globals.HF_TOKEN)

gemma_tokenizer = AutoTokenizer.from_pretrained(globals.gemma_2b_URL)
gemma_model = AutoModelForCausalLM.from_pretrained(globals.gemma_2b_URL)

falcon_tokenizer = AutoTokenizer.from_pretrained(
    globals.falcon_7b_URL,
    trust_remote_code=True,
    device_map=globals.device_map,
    offload_folder="offload",
)
falcon_model = AutoModelForCausalLM.from_pretrained(
    globals.falcon_7b_URL,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map=globals.device_map,
    offload_folder="offload",
)


def get_model(model_typ):
    """Return (model, tokenizer, prefix) for the requested model type."""
    if model_typ not in ["gemma", "falcon", "falcon_api", "simplet5_base", "simplet5_large"]:
        raise ValueError(
            'Invalid model type. Choose "gemma", "falcon", "falcon_api", '
            '"simplet5_base", or "simplet5_large".'
        )
    if model_typ == "gemma":
        tokenizer = gemma_tokenizer
        model = gemma_model
        prefix = globals.gemma_PREFIX
    elif model_typ == "falcon_api":
        # The hosted Inference API is used, so no local model is loaded.
        prefix = globals.falcon_PREFIX
        model = None
        tokenizer = None
    elif model_typ == "falcon":
        tokenizer = falcon_tokenizer
        model = falcon_model
        prefix = globals.falcon_PREFIX
    else:  # "simplet5_base" or "simplet5_large"
        prefix = globals.simplet5_PREFIX
        URL = globals.simplet5_base_URL if model_typ == "simplet5_base" else globals.simplet5_large_URL
        T5_MODEL_PATH = f"https://huggingface.co./{URL}/resolve/main/{globals.T5_FILE_NAME}"
        fetch_model(T5_MODEL_PATH, globals.T5_FILE_NAME)
        tokenizer = T5Tokenizer.from_pretrained(URL)
        model = T5ForConditionalGeneration.from_pretrained(URL)
    return model, tokenizer, prefix


def single_query(model_typ="gemma", prompt="She has a heart of gold", max_length=256, api_token=""):
    """Paraphrase a figurative sentence into its literal meaning using the selected model."""
    model, tokenizer, prefix = get_model(model_typ)
    if api_token == "" and model_typ == "falcon_api":
        return "Warning: aborted; an access token is required to use the Hugging Face Falcon API"
    start_time = time.time()
    # Fill the figurative sentence into the model-specific prompt template.
    # (Named model_input rather than `input` to avoid shadowing the built-in.)
    model_input = prefix.replace("{fig}", prompt)
    print(f"Input to model:\n{model_input}")

    if model_typ in ("simplet5_base", "simplet5_large"):
        inputs = tokenizer(model_input, return_tensors="pt")
        outputs = model.generate(
            inputs["input_ids"],
            temperature=0.7,
            max_length=max_length,
            num_beams=5,
            top_k=10,
            do_sample=True,
            num_return_sequences=1,
            early_stopping=True,
        )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    elif model_typ == "gemma":
        inputs = tokenizer(model_input, return_tensors="pt")
        generate_ids = model.generate(inputs.input_ids, max_length=max_length)
        output = tokenizer.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        print(f"Model original output: {output}\n")
        answer = post_process(output, model_input)
    elif model_typ == "falcon":
        falcon_pipeline = transformers.pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
        )
        sequences = falcon_pipeline(
            model_input,  # use the same prefixed prompt as the other branches
            max_length=max_length,
            do_sample=False,  # processing time is long; greedy decoding keeps output deterministic
            num_return_sequences=1,
            eos_token_id=falcon_tokenizer.eos_token_id,
        )
        for seq in sequences:
            print(f"Result:\n{seq['generated_text']}")
        # Use the first generated sequence as the answer, post-processed like the API branch.
        answer = post_process(sequences[0]["generated_text"], model_input)
    elif model_typ == "falcon_api":
        API_URL = "https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct"
        headers = {"Authorization": f"Bearer {api_token}"}
        payload = {
            "inputs": model_input,
            "parameters": {
                "temperature": 0.7,
                "max_length": max_length,
                "num_return_sequences": 1,
            },
        }
        output = api_query(API_URL=API_URL, headers=headers, payload=payload)
        answer = post_process(output[0]["generated_text"], model_input)
    else:
        # get_model() already validates model_typ; this is a defensive fallback.
        raise ValueError(
            'Invalid model type. Choose "gemma", "falcon", "falcon_api", '
            '"simplet5_base", or "simplet5_large".'
        )

    print(f"Time taken: {time.time() - start_time:.2f} seconds")
    print(f"Processed model output: {answer}")
    return answer


model_types = ["gemma", "falcon", "falcon_api", "simplet5_base", "simplet5_large"]

single_gradio = gr.Interface(
    fn=single_query,
    inputs=[
        gr.Dropdown(choices=model_types, label="Select Model Type"),
        gr.Textbox(lines=2, placeholder="Enter a sentence...", label="Input Sentence"),
        gr.Slider(minimum=50, maximum=512, step=10, value=256, label="Max Length"),
        gr.Textbox(lines=1, placeholder="Enter your API token", label="HuggingFace Token", value=""),
    ],
    outputs="text",
    theme=gr.themes.Soft(),
    title=globals.TITLE,
    description="Select a model type from the dropdown and input a sentence to get its paraphrased literal meaning.",
    examples=globals.EXAMPLE,
)

if __name__ == "__main__":
    single_gradio.launch()