import torch
import torch.nn as nn
import numpy as np
import librosa
import gradio as gr

from PIL import Image
from gtts import gTTS
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoProcessor,
    BertTokenizer,
    BlipForQuestionAnswering,
    BlipProcessor,
    Pix2StructForConditionalGeneration,
    Pix2StructProcessor,
    ViTImageProcessor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
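
# Question-category classifier: pooled BERT text features and Swin image features
# are concatenated, passed through a small fusion head, and mapped to one of six
# question categories; the category decides which expert VQA model answers.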
class LabelClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.text_encoder = AutoModel.from_pretrained('bert-base-uncased')
        self.image_encoder = AutoModel.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
        self.intermediate_dim = 128
        self.fusion = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size + self.image_encoder.config.hidden_size,
                      self.intermediate_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.classifier = nn.Linear(self.intermediate_dim, 6)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self,
                input_ids: torch.LongTensor,
                pixel_values: torch.FloatTensor,
                attention_mask: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None,
                labels: torch.LongTensor = None):
        encoded_text = self.text_encoder(input_ids=input_ids,
                                         attention_mask=attention_mask,
                                         token_type_ids=token_type_ids)
        encoded_image = self.image_encoder(pixel_values=pixel_values)

        # Concatenate the pooled text and image representations and fuse them.
        fused_state = self.fusion(torch.cat(
            (encoded_text['pooler_output'], encoded_image['pooler_output']), dim=1))

        logits = self.classifier(fused_state)

        out = {"logits": logits}
        if labels is not None:
            out["loss"] = self.criterion(logits, labels)
        return out

# Load the fine-tuned classifier weights and switch to inference mode so the
# dropout in the fusion head is disabled.
model = LabelClassifier()
model.load_state_dict(torch.load('classifier.pth', map_location='cpu'))
model = model.to(device)
model.eval()
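
# The tokenizer and image processor must match the encoders the classifier was
# trained with (bert-base-uncased and the Swin-Tiny checkpoint).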
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
processor = ViTImageProcessor.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
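
# Expert VQA models, one per predicted question category. Each helper reloads its
# pretrained checkpoint from the Hugging Face Hub on every call (simple, but slow;
# caching the processors and models at module level would avoid repeated loading):
#   m1 - BLIP VQA (Salesforce/blip-vqa-capfilt-large)
#   m2 - GIT fine-tuned on TextVQA (microsoft/git-large-textvqa)
#   m3 - Pix2Struct fine-tuned on DocVQA (google/pix2struct-docvqa-large)
#   m4 - MatCha fine-tuned on PlotQA (google/matcha-plotqa-v1)
#   m5 - Pix2Struct fine-tuned on OCR-VQA (google/pix2struct-ocrvqa-large)
#   m6 - the same MatCha PlotQA checkpoint as m4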
def m1(que, image):
    # General visual question answering with BLIP.
    processor3 = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
    model3 = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")

    inputs = processor3(image, que, return_tensors="pt")
    out = model3.generate(**inputs)
    return processor3.decode(out[0], skip_special_tokens=True)

def m2(que, image):
    # GIT echoes the question prompt in its output, so everything up to the
    # first '?' is stripped from the decoded text.
    processor3 = AutoProcessor.from_pretrained("microsoft/git-large-textvqa")
    model3 = AutoModelForCausalLM.from_pretrained("microsoft/git-large-textvqa")

    pixel_values = processor3(images=image, return_tensors="pt").pixel_values

    input_ids = processor3(text=que, add_special_tokens=False).input_ids
    input_ids = [processor3.tokenizer.cls_token_id] + input_ids
    input_ids = torch.tensor(input_ids).unsqueeze(0)

    generated_ids = model3.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
    decoded = processor3.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return decoded.split('?', 1)[-1].strip()

def m3(que, image):
    # Document question answering with Pix2Struct (DocVQA).
    model3 = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-large")
    processor3 = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-large")

    inputs = processor3(images=image, text=que, return_tensors="pt")
    predictions = model3.generate(**inputs)
    return processor3.decode(predictions[0], skip_special_tokens=True)

def m4(que, image):
    # Chart and plot question answering with MatCha (PlotQA).
    processor3 = Pix2StructProcessor.from_pretrained('google/matcha-plotqa-v1')
    model3 = Pix2StructForConditionalGeneration.from_pretrained('google/matcha-plotqa-v1')

    inputs = processor3(images=image, text=que, return_tensors="pt")
    predictions = model3.generate(**inputs, max_new_tokens=512)
    return processor3.decode(predictions[0], skip_special_tokens=True)

def m5(que, image):
    # OCR-heavy question answering with Pix2Struct (OCR-VQA).
    model3 = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-ocrvqa-large")
    processor3 = Pix2StructProcessor.from_pretrained("google/pix2struct-ocrvqa-large")

    inputs = processor3(images=image, text=que, return_tensors="pt")
    predictions = model3.generate(**inputs)
    return processor3.decode(predictions[0], skip_special_tokens=True)

def m6(que, image):
    # Uses the same MatCha PlotQA checkpoint as m4.
    processor3 = Pix2StructProcessor.from_pretrained('google/matcha-plotqa-v1')
    model3 = Pix2StructForConditionalGeneration.from_pretrained('google/matcha-plotqa-v1')

    inputs = processor3(images=image, text=que, return_tensors="pt")
    predictions = model3.generate(**inputs, max_new_tokens=512)
    return processor3.decode(predictions[0], skip_special_tokens=True)
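
# Dispatch: map the predicted category index (0-5) to its expert model.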
def predict_answer(category, que, image):
    if category == 0:
        return m1(que, image)
    elif category == 1:
        return m2(que, image)
    elif category == 2:
        return m3(que, image)
    elif category == 3:
        return m4(que, image)
    elif category == 4:
        return m5(que, image)
    else:
        return m6(que, image)
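
# Speech input: Whisper large-v3 transcribes the recorded question to English text.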
def transcribe_audio(audio):
    # `audio` is the (sampling_rate, data) tuple produced by the Gradio microphone.
    processor2 = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
    model2 = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")

    sampling_rate, audio_data = audio
    audio_data = np.asarray(audio_data)

    # Gradio records integer PCM; convert to float32 in [-1, 1], then resample
    # to the 16 kHz rate Whisper expects.
    audio_data_float = audio_data.astype(np.float32)
    if np.issubdtype(audio_data.dtype, np.integer):
        audio_data_float /= np.iinfo(audio_data.dtype).max
    resampled_audio_data = librosa.resample(audio_data_float, orig_sr=sampling_rate, target_sr=16000)

    input_features = processor2(
        resampled_audio_data, sampling_rate=16000, return_tensors="pt"
    ).input_features

    # Pin the decoder to English transcription.
    forced_decoder_ids = processor2.get_decoder_prompt_ids(language="english", task="transcribe")
    predicted_ids = model2.generate(input_features, forced_decoder_ids=forced_decoder_ids)

    transcription = processor2.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription

def predict_category(que, input_image):
    # Encode the question and image, then run the label classifier to predict
    # the question category (0-5).
    encoded_text = tokenizer(
        text=que,
        padding='longest',
        max_length=24,
        truncation=True,
        return_tensors='pt',
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    encoded_image = processor(input_image, return_tensors='pt')

    with torch.no_grad():
        output = model(
            input_ids=encoded_text['input_ids'].to(device),
            token_type_ids=encoded_text['token_type_ids'].to(device),
            attention_mask=encoded_text['attention_mask'].to(device),
            pixel_values=encoded_image['pixel_values'].to(device),
        )

    preds = output["logits"].argmax(dim=-1).cpu().numpy()
    return preds[0]
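
# End-to-end pipeline: transcribe the spoken question (if any), classify the
# question/image pair, route it to the matching expert VQA model, and read the
# answer back as speech with gTTS.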
def combine(audio, input_image, text_question=""):
    # Prefer the spoken question when audio was recorded, otherwise use the
    # typed question.
    if audio is not None:
        que = transcribe_audio(audio)
    else:
        que = text_question

    image = Image.fromarray(input_image).convert('RGB')
    category = predict_category(que, image)
    answer = predict_answer(category, que, image)

    tts = gTTS(answer)
    tts.save('answer.mp3')

    return que, answer, 'answer.mp3', category
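
# Gradio UI: microphone, image, and optional text question in; the transcribed
# question, the answer (as text and audio), and the predicted category out.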
model_interface = gr.Interface(
    fn=combine,
    inputs=[
        gr.Microphone(label="Ask your question"),
        gr.Image(label="Upload the image"),
        gr.Textbox(label="Text Question"),
    ],
    outputs=[
        gr.Text(label="Transcribed Question"),
        gr.Text(label="Answer"),
        gr.Audio(label="Audio Answer"),
        gr.Text(label="Category"),
    ],
)

model_interface.launch(debug=True)