# A100 Zero GPU
import spaces
# TroL Package
import torch
from PIL import Image
from utils.utils import *
import torch.nn.functional as F
from trol.load_trol import load_trol
from torchvision.transforms.functional import pil_to_tensor
# Gradio Package
import time
import gradio as gr
from threading import Thread
from accelerate import Accelerator
from transformers import TextIteratorStreamer
# flash attention
import subprocess
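# Hedged note: FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells flash-attn's setup to skip
# compiling the CUDA extension locally (typically relying on a prebuilt wheel instead),
# which keeps Space startup fast. Also note env={...} replaces the whole environment
# for the subprocess, not just that one variable.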
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
# accel
accel = Accelerator()
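# accel.device is used below for tensor and weight placement; on this ZeroGPU Space
# it should resolve to the scheduled GPU inside @spaces.GPU-decorated functions.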
# User prompt (example inputs; reused by the offline sketch below threading_function)
prompt_type = "with_image"  # select one of: "text_only", "with_image"
img_path = 'figures/demo.png'
question = "What is the troll doing? Describe the details in the image and imagine what event might be happening."
# load all three model sizes once at startup
model_1_8, tokenizer_1_8 = load_trol(link='TroL-1.8B')
model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
model_7, tokenizer_7 = load_trol(link='TroL-7B')
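# note: the TroL gating weights are loaded per request inside bot_streaming rather
# than here, presumably because ZeroGPU only attaches a GPU inside @spaces.GPU calls.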
def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
    # propagation: build model inputs for the demo prompt format
    _inputs = model.eval_process(inputs=inputs,
                                 data='demo',
                                 tokenizer=tokenizer,
                                 device=device,
                                 img_token_number=image_token_number)
    generation_kwargs = _inputs
    generation_kwargs.update({'streamer': streamer,
                              'do_sample': True,
                              'max_new_tokens': new_max_token,
                              'top_p': top_p,
                              'temperature': temperature,
                              'use_cache': True})
    return model.generate(**generation_kwargs)
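# Hedged example: a minimal offline sanity check that reuses the user-prompt
# variables defined above without the Gradio UI. It mirrors bot_streaming below
# but skips the per-request gating load, float16 cast, and device moves, so it
# assumes that setup has already run for model_7. The helper name
# run_offline_demo is ours, not part of the TroL package.
def run_offline_demo():
    if prompt_type == "with_image":
        image = pil_to_tensor(Image.open(img_path).convert("RGB"))
        image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
        inputs = [{'image': image.to(accel.device), 'question': question}]
    else:
        inputs = [{'question': question}]
    streamer = TextIteratorStreamer(tokenizer_7, skip_special_tokens=True)
    thread = Thread(target=threading_function,
                    kwargs=dict(inputs=inputs, image_token_number=1225, streamer=streamer,
                                device=accel.device, model=model_7, tokenizer=tokenizer_7,
                                temperature=0.9, new_max_token=128, top_p=0.95))
    thread.start()
    # drain the streamer on this thread, then post-process as bot_streaming does
    return output_filtering("".join(streamer), model_7)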
@spaces.GPU
def bot_streaming(message, history, link, temperature, new_max_token, top_p):
    # model selection
    if "1.8B" in link:
        model = model_1_8
        tokenizer = tokenizer_1_8
        path = "BK-Lee/TroL-1.8B"
    elif "3.8B" in link:
        model = model_3_8
        tokenizer = tokenizer_3_8
        path = "BK-Lee/TroL-3.8B"
    elif "7B" in link:
        model = model_7
        tokenizer = tokenizer_7
        path = "BK-Lee/TroL-7B"
    # trol gating load (the module path differs across model sizes)
    from huggingface_hub import hf_hub_download
    try:
        model.model.initialize_trol_gating()
        model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
    except AttributeError:
        model.language_model.model.initialize_trol_gating()
        model.language_model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
    # cast floating-point weights to float16
    for param in model.parameters():
        if 'float32' in str(param.dtype).lower() or 'float16' in str(param.dtype).lower():
            param.data = param.data.to(torch.float16)
    # move any parameters still on CPU to the GPU
    for param in model.parameters():
        if not param.is_cuda:
            param.data = param.to(accel.device)
    try:
        # prompt type -> input prompt
        image_token_number = None
        if len(message['files']) == 1:
            # Image Load
            image = pil_to_tensor(Image.open(message['files'][0]).convert("RGB"))
            if "3.8B" not in link:
                image_token_number = 1225
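                # hedged note: 1225 tokens matches a 490x490 input with a ViT
                # patch size of 14, since (490 // 14) ** 2 == 1225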
                image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
            inputs = [{'image': image.to(accel.device), 'question': message['text']}]
        elif len(message['files']) > 1:
            raise Exception("Only a single image is supported.")
        else:
            inputs = [{'question': message['text']}]
        # Text Generation
        with torch.inference_mode():
            # token streamer consumed on the main thread
            streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
            # Threading generation
            thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
                                                                   image_token_number=image_token_number,
                                                                   streamer=streamer,
                                                                   model=model,
                                                                   tokenizer=tokenizer,
                                                                   device=accel.device,
                                                                   temperature=temperature,
                                                                   new_max_token=new_max_token,
                                                                   top_p=top_p))
            thread.start()
            # accumulate generated text
            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
            # Text decoding
            response = output_filtering(generated_text, model)
    except Exception:
        response = "The input may be in an unsupported format (e.g., PDF, video, audio). Only a single image is supported in this version."
    # private log print
    text = message['text']
    files = message['files']
    print('-----------------------------')
    print(f'Link: {link}')
    print(f'Text: {text}')
    print(f'MM Files: {files}')
    print(f'Response: {response}')
    print('-----------------------------\n')
    # stream the response back character by character
    buffer = ""
    for character in response:
        buffer += character
        time.sleep(0.012)
        yield buffer
demo = gr.ChatInterface(fn=bot_streaming,
                        additional_inputs=[gr.Radio(["1.8B", "3.8B", "7B"], label="Size", info="Select one model size", value="7B"),
                                           gr.Slider(0, 1, 0.9, label="temperature"),
                                           gr.Slider(1, 1024, 128, label="new_max_token"),
                                           gr.Slider(0, 1, 0.95, label="top_p")],
                        additional_inputs_accordion="Generation Hyperparameters",
                        theme=gr.themes.Soft(),
                        title="TroL",
                        description="TroL is a family of efficient 1.8B, 3.8B, and 7B Large Language and Vision Models built on a new propagation strategy. "
                                    "Inference speed depends heavily on being assigned a non-scheduled GPU, so when all GPUs are busy, inference may hang indefinitely. "
                                    "Note that history-based conversation referring to previous dialogue is not supported.",
                        stop_btn="Stop Generation", multimodal=True)
demo.launch()
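# hedged usage note: on Spaces the bare launch() above is sufficient; when running
# locally, demo.launch(share=True) (a standard Gradio option) exposes a public link.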