import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
from torchvision import transforms
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPModel,
    CLIPProcessor,
    pipeline,
)


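# Projection head: maps a single CLIP image embedding into `width` phi-2 token embeddings.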
class _MLPVectorProjector(nn.Module):
    def __init__(
        self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
    ):
        super().__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        # Each MLP produces one projected token; concatenate along the sequence dimension.
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)


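# Base phi-2 checkpoint and tokenizer; generation below uses the fine-tuned adaptor model.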
model_name = "microsoft/phi-2"

with torch.no_grad():
    phi2_text = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16
    )

tokenizer_text = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

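# Speech-to-text (Whisper) pipeline and CLIP vision encoder, both kept on CPU.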
model_name_audio = "openai/whisper-small"
pipe = pipeline(
    task="automatic-speech-recognition",
    model=model_name_audio,
    chunk_length_s=30,
    device="cpu",
)

model_id_clip = "openai/clip-vit-base-patch16"
model_clip = CLIPModel.from_pretrained(model_id_clip).to("cpu")
processor_clip = CLIPProcessor.from_pretrained(model_id_clip)

print('--------------Loaded CLIP----------------------')

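# Image helpers: load and resize an image, then obtain its CLIP image embedding.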
def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    image = transforms.Resize((224, 224))(image)
    image = transforms.ToTensor()(image)
    return image.unsqueeze(0)


def encode_image(image_path):
    image = preprocess_image(image_path).to("cpu")

    # An empty dummy text is passed alongside the image; only image_embeds is used below.
    dummy_text = ""
    inputs = processor_clip(text=dummy_text, images=image, return_tensors="pt", padding=True)
    outputs = model_clip(**inputs)
    img_embedding = outputs.image_embeds
    return img_embedding

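# Projection head weights (CLIP 512-d -> 4 phi-2 tokens of 2560-d) and the fine-tuned phi-2 adaptor model.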
img_proj_head = _MLPVectorProjector(512, 2560, 1, 4).to("cpu")
img_proj_head.load_state_dict(torch.load('projection_finetuned.pth', map_location=torch.device('cpu')))

print('--------------Loaded proj head----------------------')

with torch.no_grad():
    phi2_finetuned = AutoModelForCausalLM.from_pretrained(
        "phi2_adaptor_fineTuned", trust_remote_code=True).to("cpu")

print('--------------Loaded fine tuned phi2 model----------------------')

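# Inference entry points for the three Gradio tabs; example_inference runs all of them for the Examples row.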
def example_inference(input_text, count, image, img_qn, audio):
    pred_text = textMode(input_text, count)
    pred_text_image = imageMode(image, img_qn)
    pred_text_audio = audioMode(audio)
    return pred_text, pred_text_image, pred_text_audio


def textMode(text, count):
    count = int(count)
    text = "Question: " + text + "Answer: "
    inputs = tokenizer_text(text, return_tensors="pt", return_attention_mask=False)
    prediction = tokenizer_text.batch_decode(
        phi2_finetuned.generate(
            **inputs,
            max_new_tokens=count,
            bos_token_id=tokenizer_text.bos_token_id,
            eos_token_id=tokenizer_text.eos_token_id,
            pad_token_id=tokenizer_text.pad_token_id
        )
    )
    # str.rstrip strips a character set, not a suffix string, so remove the end-of-text token explicitly.
    return prediction[0].replace('<|endoftext|>', '').rstrip("\n")


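# Image mode: CLIP embedding -> projection head -> prepended to the question token embeddings -> phi-2 generation.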
def imageMode(image, question):
    image_embedding = encode_image(image)
    print('-------Image embedding from clip obtained-----------')
    imgToTextEmb = img_proj_head(image_embedding).unsqueeze(0)
    print('-------text embedding from projection obtained-----------')
    question = "Question: " + question + "Answer: "
    Qtokens = torch.tensor(tokenizer_text.encode(question, add_special_tokens=True)).unsqueeze(0)
    Qtoken_embeddings = phi2_finetuned.get_submodule('model.embed_tokens')(Qtokens)
    print('-------question embedding from phi2 obtained-----------')
    # Concatenate the projected image tokens and the question embeddings along the sequence dimension.
    inputs = torch.cat((imgToTextEmb, Qtoken_embeddings), dim=-2)

    prediction = tokenizer_text.batch_decode(
        phi2_finetuned.generate(
            inputs_embeds=inputs,
            max_new_tokens=50,
            bos_token_id=tokenizer_text.bos_token_id,
            eos_token_id=tokenizer_text.eos_token_id,
            pad_token_id=tokenizer_text.pad_token_id
        )
    )
    # As in textMode, remove the end-of-text token as a substring rather than a character set.
    text_pred = prediction[0].replace('<|endoftext|>', '').rstrip("\n")
    return text_pred


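# Audio mode: transcribe the clip with Whisper, then answer the transcript via the text pipeline.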
def audioMode(audio):
    if audio is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    print('---------type of audio--------------')
    print(type(audio))
    print(audio)
    text = pipe(audio, batch_size=8, generate_kwargs={"task": "transcribe"}, return_timestamps=True)["text"]
    pred_text = textMode(text, 50)
    return pred_text


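# Gradio UI: three tabs (text / image / audio) wired to the inference functions above.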
interface_title = "TSAI-ERA-V1 - Capstone - Multimodal GPT Demo"
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(f"## **{interface_title}**")
    gr.Markdown("Choose text mode/image mode/audio mode for generation")
    with gr.Tab("Text mode"):
        text_input = gr.Textbox(placeholder="Enter a prompt", label="Input")
        text_input_count = gr.Textbox(placeholder="Enter the number of tokens to generate", label="Count")
        text_button = gr.Button("Submit")
        text_output = gr.Textbox(label="Chat GPT like text")
    with gr.Tab("Image mode"):
        with gr.Row():
            image_input = gr.Image(type="filepath")
            image_text_input = gr.Textbox(placeholder="Enter a question/prompt around the image", label="Question/Prompt")
        image_button = gr.Button("Submit")
        image_text_output = gr.Textbox(label="Answer")

    with gr.Tab("Audio mode"):
        audio_input = gr.Audio(type="filepath")
        audio_button = gr.Button("Submit")
        audio_text_output = gr.Textbox(label="Chat GPT like text")

    text_button.click(textMode, inputs=[text_input, text_input_count], outputs=text_output)
    image_button.click(imageMode, inputs=[image_input, image_text_input], outputs=image_text_output)
    audio_button.click(audioMode, inputs=audio_input, outputs=audio_text_output)

    gr.Examples(
        examples=[
            ["What is a large language model?", "50", "zebras.png", "Are the zebras walking or standing still in the image?", "WtIsML.m4a"]
        ],
        inputs=[text_input, text_input_count, image_input, image_text_input, audio_input],
        outputs=[text_output, image_text_output, audio_text_output],
        fn=example_inference,
    )

demo.launch()