Initial commit of files
- app.py +176 -0
- audio_ex3.mp3 +0 -0
- img69.jpg +0 -0
- projection_finetuned.pth +3 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,176 @@
import torch
import torch.nn as nn
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from PIL import Image

class _MLPVectorProjector(nn.Module):
    def __init__(
        self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
    ):
        super(_MLPVectorProjector, self).__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)

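# Shape sketch for the projector as instantiated further below
# (_MLPVectorProjector(512, 2560, 1, 4)): each of the 4 parallel MLPs maps a
# CLIP image embedding of shape (1, 512) to (1, 2560), and the concatenation
# along dim=-2 stacks them into (4, 2560), i.e. four pseudo-token embeddings
# in phi-2's hidden size. For example:
#     proj = _MLPVectorProjector(512, 2560, 1, 4)
#     proj(torch.randn(1, 512)).shape   # torch.Size([4, 2560])
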
## Text model

model_name = "microsoft/phi-2"

# Base phi-2 in fp16; note that generation below uses the fine-tuned
# checkpoint (phi2_finetuned), so this base model is only loaded, not called.
with torch.no_grad():
    phi2_text = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16)

tokenizer_text = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

## Audio model
model_name_audio = "openai/whisper-small"
pipe = pipeline(task="automatic-speech-recognition", model=model_name_audio,
                chunk_length_s=30, device="cpu",)

## Image model
# CLIP model
model_id_clip = "openai/clip-vit-base-patch16"
model_clip = CLIPModel.from_pretrained(model_id_clip).to("cpu")
processor_clip = CLIPProcessor.from_pretrained(model_id_clip)

print('--------------Loaded CLIP----------------------')

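# Everything above is pinned to CPU (device="cpu" / .to("cpu")); presumably the
# Space runs on CPU-only hardware, so inference is slow but self-contained.
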
# Preprocess the image for CLIP
def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    image = transforms.Resize((224, 224))(image)
    image = transforms.ToTensor()(image)
    return image.unsqueeze(0)

# Get the CLIP encoding
def encode_image(image_path):
    image = preprocess_image(image_path).to("cpu")
    # Dummy input_ids for text
    dummy_text = ""
    inputs = processor_clip(text=dummy_text, images=image, return_tensors="pt", padding=True)
    outputs = model_clip(**inputs)
    img_embedding = outputs.image_embeds
    return img_embedding

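# encode_image returns CLIPModel's pooled `image_embeds` (the projection-layer
# output), which is (1, 512) for clip-vit-base-patch16 and is exactly what the
# projection head loaded below expects as input.
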
# Get the projection head
img_proj_head = _MLPVectorProjector(512, 2560, 1, 4).to("cpu")
img_proj_head.load_state_dict(torch.load('projection_finetuned.pth', map_location=torch.device('cpu')))

print('--------------Loaded proj head----------------------')

# Get the fine-tuned phi-2 model
with torch.no_grad():
    phi2_finetuned = AutoModelForCausalLM.from_pretrained(
        "phi2_adaptor_fine_tuned", trust_remote_code=True).to("cpu")

print('--------------Loaded fine tuned phi2 model----------------------')

def example_inference(input_text, count, image, img_qn, audio):
    pred_text = textMode(input_text, count)
    pred_text_image = imageMode(image, img_qn)
    pred_text_audio = audioMode(audio)
    return pred_text, pred_text_image, pred_text_audio

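# example_inference runs all three modes on one example row; it is wired to the
# gr.Examples block at the bottom of the file.
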
def textMode(text, count):
    count = int(count)
    text = "Question: " + text + "Answer: "
    inputs = tokenizer_text(text, return_tensors="pt", return_attention_mask=False)
    prediction = tokenizer_text.batch_decode(
        phi2_finetuned.generate(
            **inputs,
            max_new_tokens=count,   # `count` is a token budget, not characters
            bos_token_id=tokenizer_text.bos_token_id,
            eos_token_id=tokenizer_text.eos_token_id,
            pad_token_id=tokenizer_text.pad_token_id
        )
    )
    # Remove the end-of-text marker as a literal substring; str.rstrip would
    # strip individual trailing characters rather than the suffix.
    return prediction[0].replace('<|endoftext|>', '').rstrip("\n")

def imageMode(image, question):
    image_embedding = encode_image(image)                       # (1, 512) CLIP embedding
    print('-------Image embedding from clip obtained-----------')
    imgToTextEmb = img_proj_head(image_embedding).unsqueeze(0)  # (1, 4, 2560) pseudo-tokens
    print('-------text embedding from projection obtained-----------')
    question = "Question: " + question + "Answer: "
    Qtokens = torch.tensor(tokenizer_text.encode(question, add_special_tokens=True)).unsqueeze(0)
    Qtoken_embeddings = phi2_finetuned.get_submodule('model.embed_tokens')(Qtokens)
    print('-------question embedding from phi2 obtained-----------')
    inputs = torch.concat((imgToTextEmb, Qtoken_embeddings), axis=-2)

    prediction = tokenizer_text.batch_decode(
        phi2_finetuned.generate(
            inputs_embeds=inputs,
            max_new_tokens=50,
            bos_token_id=tokenizer_text.bos_token_id,
            eos_token_id=tokenizer_text.eos_token_id,
            pad_token_id=tokenizer_text.pad_token_id
        )
    )
    # Remove the end-of-text marker as a literal substring (str.strip would drop
    # matching characters, not the whole token).
    text_pred = prediction[0].replace('<|endoftext|>', '').rstrip("\n")
    return text_pred

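# Note that imageMode conditions generation entirely through `inputs_embeds`:
# the four projected image pseudo-tokens (1, 4, 2560) are prepended to the
# question's token embeddings, so phi-2 attends to the image representation
# before the question text while generating the answer.
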
def audioMode(audio):
    if audio is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    print('---------type of audio--------------')
    print(type(audio))
    print(audio)
    text = pipe(audio, batch_size=8, generate_kwargs={"task": "transcribe"}, return_timestamps=True)["text"]
    pred_text = textMode(text, 50)

    return pred_text

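# audioMode chains the two pipelines: Whisper transcribes the uploaded audio,
# and the transcript is passed to textMode, which generates up to 50 new tokens
# with the fine-tuned phi-2.
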
interface_title = "Multimodal GPT Application"
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(f"## **{interface_title}**")
    gr.Markdown("Choose text mode/image mode/audio mode for text generation")
    with gr.Tab("Text mode"):
        text_input = gr.Textbox(placeholder="Enter a prompt", label="Input")
        text_input_count = gr.Textbox(placeholder="Enter the number of tokens to generate", label="Count")
        text_button = gr.Button("Submit")
        text_output = gr.Textbox(label="ChatGPT-like text")
    with gr.Tab("Image mode"):
        with gr.Row():
            image_input = gr.Image(type="filepath")
            image_text_input = gr.Textbox(placeholder="Enter a question/prompt about the image", label="Question/Prompt")
        image_button = gr.Button("Submit")
        image_text_output = gr.Textbox(label="Answer")

    with gr.Tab("Audio mode"):
        audio_input = gr.Audio(type="filepath")
        audio_button = gr.Button("Submit")
        audio_text_output = gr.Textbox(label="ChatGPT-like text")

    text_button.click(textMode, inputs=[text_input, text_input_count], outputs=text_output)
    image_button.click(imageMode, inputs=[image_input, image_text_input], outputs=image_text_output)
    audio_button.click(audioMode, inputs=audio_input, outputs=audio_text_output)

    gr.Examples(
        examples=[
            ["Briefly explain the geographical features of India?", "50", "img69.jpg", "What is the man behind the counter doing?", "audio_ex3.mp3"]
        ],
        inputs=[text_input, text_input_count, image_input, image_text_input, audio_input],
        outputs=[text_output, image_text_output, audio_text_output],
        fn=example_inference,
    )

demo.launch()
audio_ex3.mp3
ADDED
Binary file (207 kB).
img69.jpg
ADDED
projection_finetuned.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0fcc412caf78f2a82c1e4668244b319efedf819d27ab970d496c084c52086785
size 20973467
requirements.txt
ADDED
@@ -0,0 +1,7 @@
torch
numpy
gradio
transformers
einops
peft
torchvision