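"""Gradio demo that paraphrases figurative language into its literal meaning.

Backends: local Gemma-2B and Falcon-7B-Instruct checkpoints, two fine-tuned
SimpleT5 variants, and the hosted Falcon Inference API. Model URLs, prompt
prefixes, and UI strings are defined in globals.py; download and
post-processing helpers come from utility.py.
"""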
import time

import gradio as gr
import huggingface_hub
import torch
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

import globals
from utility import *
"""set up"""
huggingface_hub.login(token=globals.HF_TOKEN)

# Gemma and Falcon are loaded eagerly at startup; the T5 variants are fetched
# on demand in get_model(). device_map/offload_folder apply only to the model,
# not the tokenizer.
gemma_tokenizer = AutoTokenizer.from_pretrained(globals.gemma_2b_URL)
gemma_model = AutoModelForCausalLM.from_pretrained(globals.gemma_2b_URL)
falcon_tokenizer = AutoTokenizer.from_pretrained(globals.falcon_7b_URL, trust_remote_code=True)
falcon_model = AutoModelForCausalLM.from_pretrained(
    globals.falcon_7b_URL, trust_remote_code=True, torch_dtype=torch.bfloat16,
    device_map=globals.device_map, offload_folder="offload",
)
def get_model(model_typ):
    """Return (model, tokenizer, prefix) for the requested model type."""
    if model_typ not in ["gemma", "falcon", "falcon_api", "simplet5_base", "simplet5_large"]:
        raise ValueError('Invalid model type. Choose "gemma", "falcon", "falcon_api", "simplet5_base" or "simplet5_large".')
    if model_typ == "gemma":
        tokenizer = gemma_tokenizer
        model = gemma_model
        prefix = globals.gemma_PREFIX
    elif model_typ == "falcon_api":
        # The hosted Inference API is queried over HTTP, so no local model is loaded.
        model = None
        tokenizer = None
        prefix = globals.falcon_PREFIX
    elif model_typ == "falcon":
        tokenizer = falcon_tokenizer
        model = falcon_model
        prefix = globals.falcon_PREFIX
    else:  # "simplet5_base" or "simplet5_large"
        prefix = globals.simplet5_PREFIX
        URL = globals.simplet5_base_URL if model_typ == "simplet5_base" else globals.simplet5_large_URL
        T5_MODEL_PATH = f"https://huggingface.co./{URL}/resolve/main/{globals.T5_FILE_NAME}"
        fetch_model(T5_MODEL_PATH, globals.T5_FILE_NAME)
        tokenizer = T5Tokenizer.from_pretrained(URL)
        model = T5ForConditionalGeneration.from_pretrained(URL)
    return model, tokenizer, prefix
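
# Example (hypothetical call, assuming the Gemma weights are available):
#   model, tokenizer, prefix = get_model("gemma")
#   model_input = prefix.replace("{fig}", "She has a heart of gold")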
def single_query(model_typ="gemma", prompt="She has a heart of gold",
                 max_length=256,
                 api_token=""):
    """Paraphrase a figurative sentence into its literal meaning with the chosen model."""
    if api_token == "" and model_typ == "falcon_api":
        return "Warning: aborted; an access token is needed to use the Hugging Face Falcon Inference API"
    model, tokenizer, prefix = get_model(model_typ)
    start_time = time.time()
    # model_input (rather than the built-in name input) holds the filled prompt template.
    model_input = prefix.replace("{fig}", prompt)
    print(f"Input to model:\n{model_input}")
if model_typ == "simplet5_base" or model_typ == "simplet5_large":
inputs = tokenizer(input, return_tensors="pt")
outputs = model.generate(
inputs["input_ids"],
temperature=0.7,
max_length=max_length,
num_beams=5,
top_k=10,
do_sample=True,
num_return_sequences=1,
early_stopping=True
)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
elif model_typ=="gemma":
inputs = tokenizer(input, return_tensors="pt")
generate_ids = model.generate(inputs.input_ids, max_length=max_length)
output= tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(f"Model original output:{output}\n")
answer = post_process(output,input)
# pattern = r"\*\*Literal Meaning:\*\*\s*(.*?)(?:\n\n|$)"
# match = re.search(pattern, output, re.DOTALL)
# if match:
# answer = match.group(1).strip()
# else:
# answer = output
elif model_typ=="falcon":
falcon_pipeline = transformers.pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
)
sequences = falcon_pipeline(
prompt,
max_length=max_length,
do_sample=False, # processing time too long, disable sampling for deterministic output
num_return_sequences=1,
eos_token_id=falcon_tokenizer.eos_token_id,
)
for seq in sequences:
print(f"Result: \n{seq['generated_text']}")
elif model_typ=="falcon_api":
API_URL = "https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct"
headers = {"Authorization": f"Bearer {api_token}"}
payload = {
"inputs": input,
"parameters": {
"temperature": 0.7,
"max_length": max_length,
"num_return_sequences": 1
}
}
output = api_query(API_URL=API_URL,headers=headers,payload=payload)
answer = output[0]["generated_text"]
answer = post_process(answer,input)
    else:
        raise ValueError('Invalid model type. Choose "gemma", "falcon", "falcon_api", "simplet5_base" or "simplet5_large".')
    print(f"Time taken: {time.time() - start_time:.2f} seconds")
    print(f"Processed model output: {answer}")
    return answer
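
# Example (illustrative only; the actual output depends on the checkpoint):
#   single_query(model_typ="simplet5_base", prompt="He spilled the beans")
# prints the filled prompt, then returns a literal paraphrase such as
# "He revealed the secret".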
model_types = ["gemma", "falcon", "falcon_api", "simplet5_base", "simplet5_large"]

single_gradio = gr.Interface(
    fn=single_query,
    inputs=[
        gr.Dropdown(choices=model_types, label="Select Model Type"),
        gr.Textbox(lines=2, placeholder="Enter a sentence...", label="Input Sentence"),
        gr.Slider(minimum=50, maximum=512, step=10, value=256, label="Max Length"),
        gr.Textbox(lines=1, placeholder="Enter your API token", label="HuggingFace Token", value=""),
    ],
    outputs="text",
    theme=gr.themes.Soft(),
    title=globals.TITLE,
    description="Select a model type from the dropdown and input a sentence to get its paraphrased literal meaning.",
    examples=globals.EXAMPLE,
)
if __name__ == '__main__':
    single_gradio.launch()