File size: 3,766 Bytes
cab1df1
ff6b5fc
a259df9
 
 
 
 
 
ff6b5fc
 
a259df9
 
cab1df1
ff6b5fc
cab1df1
a259df9
cab1df1
 
460bccf
 
 
 
3a3e2e6
460bccf
 
a259df9
 
460bccf
 
 
cab1df1
 
a259df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cab1df1
a259df9
 
 
 
 
ff6b5fc
 
 
 
cd7c5fe
ff6b5fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd7c5fe
ff6b5fc
 
 
 
 
 
cd7c5fe
ff6b5fc
59ee00b
3a3e2e6
 
ff6b5fc
3a3e2e6
cab1df1
 
a259df9
 
 
 
460bccf
cab1df1
460bccf
533fd96
3a3e2e6
a259df9
 
 
 
 
460bccf
cab1df1
579e033
533fd96
a259df9
cab1df1
 
a259df9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
import librosa
import numpy as np
import torch
from gradio.themes import Citrus
from PIL import Image
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
from transformers.utils import is_flash_attn_2_available

# Pick the compute device once; every model and input tensor below is moved here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Load the vision-language model (SmolVLM) and the speech-to-text model
# (wav2vec2). Any failure here is fatal: the app cannot answer without them.
try:
    processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
    model = AutoModelForVision2Seq.from_pretrained(
        "HuggingFaceTB/SmolVLM-Instruct",
        torch_dtype=torch.bfloat16,
        # FlashAttention-2 needs CUDA *and* the optional flash-attn package;
        # fall back to the reference implementation otherwise instead of
        # failing the whole startup on GPU hosts without flash-attn.
        attn_implementation=(
            "flash_attention_2"
            if DEVICE == "cuda" and is_flash_attn_2_available()
            else "eager"
        ),
    ).to(DEVICE)
    stt_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(DEVICE)
except Exception as e:
    # SystemExit with a non-integer argument prints it to stderr and exits
    # with status 1. Unlike the interactive-shell helper `exit`, it is always
    # available (e.g. under `python -S`).
    raise SystemExit(f"Error loading model or processor: {e}") from e


# Define the function to convert speech to text
def speech_to_text(audio):
    """Transcribe a recorded question with the wav2vec2 CTC model.

    Args:
        audio: Path to an audio file, as produced by
            ``gr.Audio(type="filepath")``. It is loaded and resampled to
            16 kHz, the rate also passed to the processor.

    Returns:
        The greedy transcription text, or a string starting with
        ``"Error:"`` if loading or inference fails. Errors are returned
        rather than raised so the UI can display them.
    """
    try:
        # librosa resamples to the requested rate, so the returned rate is
        # always 16000 and is not needed.
        waveform, _ = librosa.load(audio, sr=16000)
        input_values = stt_processor(
            waveform, return_tensors="pt", sampling_rate=16000
        ).input_values.to(DEVICE)
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.inference_mode():
            logits = stt_model(input_values).logits
        # Greedy decoding: take the most likely token at every frame.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = stt_processor.decode(predicted_ids[0])
        print(f"Detected text: {transcription}")
        return transcription
    except Exception as e:
        return f"Error: Unable to process the audio. {str(e)}"


# Define the function to answer questions
def answer_question(image, question, audio):
    """Answer a question about an image with SmolVLM.

    Args:
        image: The uploaded image. A NumPy array (as delivered by
            ``gr.Image(type="numpy")``) is converted to PIL; any other value
            is passed through unchanged. ``None`` means nothing was uploaded.
        question: Typed question text; may be empty or ``None``.
        audio: Path to a recorded question, or ``None``. When given, its
            transcription replaces ``question``.

    Returns:
        The model's answer text, or a user-facing message starting with
        ``"Error:"`` describing what went wrong.
    """
    # Reject missing images first so no time is spent transcribing audio for
    # a request that will fail anyway.
    if image is None:
        return "Error: Please upload an image."

    # Convert speech to text if audio is provided
    if audio is not None:
        question = speech_to_text(audio)
        # speech_to_text reports failures as an "Error: ..." string. Surface
        # it instead of sending the error text to the model as a question.
        if question.startswith("Error:"):
            return question

    # Ensure a non-blank question is provided (guards against None too).
    if not question or not question.strip():
        return "Error: Please provide a question."

    # Convert NumPy array to PIL Image
    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
    except Exception as e:
        return f"Error: Unable to process the image. {str(e)}"

    # Create input message for the model
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question},
            ],
        },
    ]

    # Apply chat template and prepare inputs
    try:
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)
    except Exception as e:
        return f"Error: Failed to prepare inputs. {str(e)}"

    # Generate answer
    try:
        outputs = model.generate(**inputs, max_new_tokens=400)
        # generate() returns the prompt tokens followed by the new ones.
        # Decode only the new tokens so the chat-template prompt is not
        # echoed back in the answer.
        prompt_length = inputs["input_ids"].shape[1]
        answer = processor.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        return answer.strip()
    except Exception as e:
        return f"Error: Failed to generate answer. {str(e)}"


# Customize the Citrus theme with a slate neutral palette.
custom_citrus = Citrus(neutral_hue="slate")

# Gradio UI: the three inputs map positionally onto
# answer_question(image, question, audio); the answer is shown as plain text.
iface = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Image(type="numpy"),
        gr.Textbox(lines=2, placeholder="Enter your question here..."),
        gr.Audio(
            type="filepath",
            # The label offers both uploading and recording, so enable both
            # input sources to match it.
            sources=["microphone", "upload"],
            label="Upload a recording or record a question",
        ),
    ],
    outputs="text",
    title="FAAM-demo | Vision Language Model | SmolVLM",
    description="Upload an image and ask a question about it.",
    theme=custom_citrus,
)

# Launch the interface only when this file is run as a script (e.g.
# `python app.py`), so importing the module does not start the web server.
if __name__ == "__main__":
    iface.launch()