import torch
import torch.nn as nn
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer


class _MLPVectorProjector(nn.Module):
    """Projects an input embedding into `width` LM-space embeddings using
    `width` parallel MLPs, concatenated along the sequence dimension."""

    def __init__(
        self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
    ):
        super().__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            # First layer maps the input embedding to the LM hidden size.
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            # Each additional layer is a GELU followed by a square Linear.
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        # One projected "token" per MLP, stacked along the sequence axis.
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)
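
# Shape sanity check (sizes here are illustrative assumptions: a 768-d
# vision embedding projected into 4 tokens of phi-2's 2560-d hidden space):
#   projector = _MLPVectorProjector(input_hidden_size=768, lm_hidden_size=2560,
#                                   num_layers=2, width=4)
#   projector(torch.randn(1, 1, 768)).shape  # -> torch.Size([1, 4, 2560])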


model_name = "microsoft/phi-2"

# Load the phi-2 language model and its tokenizer for the text tab.
phi2_text = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
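
# Note: this loads the model in full precision (float32) on CPU. On a GPU,
# half precision is an option (an alternative, not what the demo requires):
#   phi2_text = AutoModelForCausalLM.from_pretrained(
#       model_name, torch_dtype=torch.float16, trust_remote_code=True
#   ).to("cuda")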


def textMode(text, count):
    """Generate up to `count` new tokens from the prompt with phi-2."""
    count = int(count)
    inputs = tokenizer(text, return_tensors="pt", return_attention_mask=False)
    prediction = tokenizer.batch_decode(
        phi2_text.generate(
            **inputs,
            max_new_tokens=count,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # phi-2's tokenizer has no pad token by default; fall back to EOS.
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
    )
    # str.rstrip("<|endoftext|>") would strip a *set of characters*, not the
    # suffix, so remove the end-of-text marker explicitly instead.
    return prediction[0].replace("<|endoftext|>", "").rstrip("\n")
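
# Example: textMode("Write a short poem about the sea", "50") returns the
# prompt continued by up to 50 generated tokens.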


def imageMode(image, question):
    # Placeholder: image question answering is not wired up yet.
    return "In progress"


def audioMode(audio):
    # Placeholder: audio handling is not wired up yet.
    return "In progress"


interface_title = "TSAI-ERA-V1 - Capstone - Multimodal GPT Demo"

with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(f"## **{interface_title}**")
    gr.Markdown("Choose text mode/image mode/audio mode for generation")

    with gr.Tab("Text mode"):
        text_input = gr.Textbox(placeholder="Enter a prompt", label="Input")
        text_input_count = gr.Textbox(
            placeholder="Enter the number of tokens to generate", label="Count"
        )
        text_button = gr.Button("Submit")
        text_output = gr.Textbox(label="ChatGPT-like text")

    with gr.Tab("Image mode"):
        with gr.Row():
            image_input = gr.Image()
            image_text_input = gr.Textbox(
                placeholder="Enter a question/prompt about the image",
                label="Question/Prompt",
            )
        image_button = gr.Button("Submit")
        image_text_output = gr.Textbox(label="Answer")

    with gr.Tab("Audio mode"):
        audio_input = gr.Audio()
        audio_button = gr.Button("Submit")
        audio_text_output = gr.Textbox(label="ChatGPT-like text")

    # Wire each tab's Submit button to its handler (event listeners must be
    # registered inside the Blocks context).
    text_button.click(textMode, inputs=[text_input, text_input_count], outputs=text_output)
    image_button.click(imageMode, inputs=[image_input, image_text_input], outputs=image_text_output)
    audio_button.click(audioMode, inputs=audio_input, outputs=audio_text_output)
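
# demo.launch(share=True) would additionally create a temporary public link
# (optional; not needed when hosting as a Space).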
demo.launch()