import spaces

import torch
from PIL import Image
from utils.utils import *
import torch.nn.functional as F
from trol.load_trol import load_trol
from torchvision.transforms.functional import pil_to_tensor

import time
import gradio as gr
from threading import Thread
from accelerate import Accelerator
from transformers import TextIteratorStreamer

# Install flash-attn at runtime, skipping its CUDA build (a common workaround on Hugging Face Spaces).
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
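# Accelerator is used only for its device handle (accel.device) when moving model parameters and image tensors to the GPU.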
accel = Accelerator()
|
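# Default demo inputs; these are not referenced by the Gradio app below.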
prompt_type = "with_image"
img_path = 'figures/demo.png'
question = "What is the troll doing? Describe the details in the image and imagine what event might be happening."
|
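# Load all three TroL checkpoints (and their tokenizers) once at startup so the demo can switch between sizes.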
model_1_8, tokenizer_1_8 = load_trol(link='TroL-1.8B')
model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
model_7, tokenizer_7 = load_trol(link='TroL-7B')
|
|
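# Background worker: preprocesses the chat inputs and calls model.generate, pushing
# new tokens into the TextIteratorStreamer consumed by bot_streaming below.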
def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
    # Preprocess the prompt (and optional image) into generation-ready inputs.
    _inputs = model.eval_process(inputs=inputs,
                                 data='demo',
                                 tokenizer=tokenizer,
                                 device=device,
                                 img_token_number=image_token_number)

    # Assemble the sampling configuration and stream tokens through the streamer.
    generation_kwargs = _inputs
    generation_kwargs.update({'streamer': streamer})
    generation_kwargs.update({'do_sample': True})
    generation_kwargs.update({'max_new_tokens': new_max_token})
    generation_kwargs.update({'top_p': top_p})
    generation_kwargs.update({'temperature': temperature})
    generation_kwargs.update({'use_cache': True})
    return model.generate(**generation_kwargs)
|
|
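# Gradio chat callback: @spaces.GPU requests a GPU for the call; the function picks the
# requested model size, prepares the inputs, and streams the answer back to the chat UI.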
@spaces.GPU
def bot_streaming(message, history, link, temperature, new_max_token, top_p):
    # Select the model matching the size chosen in the UI.
    if "1.8B" in link:
        model = model_1_8
        tokenizer = tokenizer_1_8
        path = "BK-Lee/TroL-1.8B"
    elif "3.8B" in link:
        model = model_3_8
        tokenizer = tokenizer_3_8
        path = "BK-Lee/TroL-3.8B"
    elif "7B" in link:
        model = model_7
        tokenizer = tokenizer_7
        path = "BK-Lee/TroL-7B"
|
    # Download and load the TroL gating weights; fall back to the alternative module path if the first fails.
    from huggingface_hub import hf_hub_download
    try:
        model.model.initialize_trol_gating()
        model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
    except:
        model.language_model.model.initialize_trol_gating()
        model.language_model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
|
    # Cast floating-point parameters to half precision.
    for param in model.parameters():
        if 'float32' in str(param.dtype).lower() or 'float16' in str(param.dtype).lower():
            param.data = param.data.to(torch.float16)

    # Move any parameters still on CPU to the accelerator device.
    for param in model.parameters():
        if not param.is_cuda:
            param.data = param.to(accel.device)
|
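    # Build the model inputs and run streaming generation; any failure (e.g., an unsupported
    # file type) falls through to the fallback error message in the except clause.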
    try:
        image_token_number = None
        if len(message['files']) == 1:
            # Single image: convert to a tensor and, except for the 3.8B model,
            # resize to 490x490 and use 1225 image tokens.
            image = pil_to_tensor(Image.open(message['files'][0]).convert("RGB"))
            if "3.8B" not in link:
                image_token_number = 1225
                image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
            inputs = [{'image': image.to(accel.device), 'question': message['text']}]
        elif len(message['files']) > 1:
            raise Exception("No way!")
        else:
            # Text-only query.
            inputs = [{'question': message['text']}]
|
        with torch.inference_mode():
            # Stream tokens from generate() running in a background thread.
            streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
            thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
                                                                   image_token_number=image_token_number,
                                                                   streamer=streamer,
                                                                   model=model,
                                                                   tokenizer=tokenizer,
                                                                   device=accel.device,
                                                                   temperature=temperature,
                                                                   new_max_token=new_max_token,
                                                                   top_p=top_p))
            thread.start()

            # Accumulate the streamed text as it is produced.
            generated_text = ""
            for new_text in streamer:
                generated_text += new_text

            # Post-process the raw generation into the final response.
            response = output_filtering(generated_text, model)
|
    except:
        response = "There may be an unsupported format (e.g., pdf, video, or sound). Only a single image is supported in this version."
|
    # Log the request and response to the console.
    text = message['text']
    files = message['files']
    print('-----------------------------')
    print(f'Link: {link}')
    print(f'Text: {text}')
    print(f'MM Files: {files}')
    print(f'Response: {response}')
    print('-----------------------------\n')
|
    # Yield the response character by character for a typing effect in the chat UI.
    buffer = ""
    for character in response:
        buffer += character
        time.sleep(0.012)
        yield buffer
|
demo = gr.ChatInterface(fn=bot_streaming,
                        additional_inputs=[gr.Radio(["1.8B", "3.8B", "7B"], label="Size", info="Select one model size", value="7B"),
                                           gr.Slider(0, 1, 0.9, label="temperature"),
                                           gr.Slider(1, 1024, 128, label="new_max_token"),
                                           gr.Slider(0, 1, 0.95, label="top_p")],
                        additional_inputs_accordion="Generation Hyperparameters",
                        theme=gr.themes.Soft(),
                        title="TroL",
                        description="TroL is a family of efficient 1.8B, 3.8B, and 7B Large Language and Vision Models built on a new propagation strategy. "
                                    "Inference speed depends heavily on being assigned a non-scheduled GPU, so when all GPUs are busy a request may wait indefinitely. "
                                    "Note that history-based conversation referring to previous dialogue is not supported.",
                        stop_btn="Stop Generation", multimodal=True)

demo.launch()