KasKniesmeijer commited on
Commit
a259df9
·
1 Parent(s): 7373a84

improved gradio interface

Browse files
Files changed (4) hide show
  1. app.py +47 -8
  2. index.html +0 -25
  3. requirements.txt +4 -1
  4. src/main.js +0 -83
app.py CHANGED
@@ -1,12 +1,19 @@
1
  import torch
2
  from PIL import Image
3
- from transformers import AutoProcessor, AutoModelForVision2Seq
4
- from transformers.image_utils import load_image
 
 
 
 
5
  import numpy as np
6
  import gradio as gr
 
 
7
 
8
  # Set the device (GPU or CPU)
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
10
 
11
  # Initialize processor and model
12
  try:
@@ -16,13 +23,36 @@ try:
16
  torch_dtype=torch.bfloat16,
17
  _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager",
18
  ).to(DEVICE)
 
 
19
  except Exception as e:
20
  print(f"Error loading model or processor: {str(e)}")
21
  exit(1)
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Define the function to answer questions
25
- def answer_question(image, question):
 
 
 
 
26
  # Check if the image is provided
27
  if image is None:
28
  return "Error: Please upload an image."
@@ -65,17 +95,26 @@ def answer_question(image, question):
65
  return f"Error: Failed to generate answer. {str(e)}"
66
 
67
 
68
- # Create Gradio interface
 
 
 
69
  iface = gr.Interface(
70
  fn=answer_question,
71
  inputs=[
72
- gr.Image(type="numpy"),
73
  gr.Textbox(lines=2, placeholder="Enter your question here..."),
 
 
 
 
 
74
  ],
75
  outputs="text",
76
  title="FAAM-demo | Vision Language Model | SmolVLM",
77
- description="Upload an image and ask a question about it.",
 
78
  )
79
 
80
- if __name__ == "__main__":
81
- iface.launch()
 
1
  import torch
2
  from PIL import Image
3
+ from transformers import (
4
+ AutoProcessor,
5
+ AutoModelForVision2Seq,
6
+ Wav2Vec2ForCTC,
7
+ Wav2Vec2Processor,
8
+ )
9
  import numpy as np
10
  import gradio as gr
11
+ import librosa
12
+ from gradio.themes import Citrus
13
 
14
  # Set the device (GPU or CPU)
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
+ print(f"Using device: {DEVICE}")
17
 
18
  # Initialize processor and model
19
  try:
 
23
  torch_dtype=torch.bfloat16,
24
  _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager",
25
  ).to(DEVICE)
26
+ stt_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
27
+ stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(DEVICE)
28
  except Exception as e:
29
  print(f"Error loading model or processor: {str(e)}")
30
  exit(1)
31
 
32
 
33
+ # Define the function to convert speech to text
34
+ def speech_to_text(audio):
35
+ try:
36
+ # Load audio
37
+ audio, rate = librosa.load(audio, sr=16000)
38
+ input_values = stt_processor(
39
+ audio, return_tensors="pt", sampling_rate=16000
40
+ ).input_values.to(DEVICE)
41
+ logits = stt_model(input_values).logits
42
+ predicted_ids = torch.argmax(logits, dim=-1)
43
+ transcription = stt_processor.decode(predicted_ids[0])
44
+ print(f"Detected text: {transcription}")
45
+ return transcription
46
+ except Exception as e:
47
+ return f"Error: Unable to process the audio. {str(e)}"
48
+
49
+
50
  # Define the function to answer questions
51
+ def answer_question(image, question, audio):
52
+ # Convert speech to text if audio is provided
53
+ if audio is not None:
54
+ question = speech_to_text(audio)
55
+
56
  # Check if the image is provided
57
  if image is None:
58
  return "Error: Please upload an image."
 
95
  return f"Error: Failed to generate answer. {str(e)}"
96
 
97
 
98
+ # Customize the Citrus theme with a specific neutral_hue
99
+ custom_citrus = Citrus(neutral_hue="slate")
100
+
101
+ # Define your Gradio interface
102
  iface = gr.Interface(
103
  fn=answer_question,
104
  inputs=[
105
+ gr.Image(type="numpy", value="faam_to_the_future.jpg"),
106
  gr.Textbox(lines=2, placeholder="Enter your question here..."),
107
+ gr.Audio(
108
+ type="filepath",
109
+ sources="microphone",
110
+ label="Upload a recording or record a question",
111
+ ),
112
  ],
113
  outputs="text",
114
  title="FAAM-demo | Vision Language Model | SmolVLM",
115
+ description="Welcome to the FAAM-demo!",
116
+ theme=custom_citrus,
117
  )
118
 
119
+ # Launch the interface
120
+ iface.launch()
index.html DELETED
@@ -1,25 +0,0 @@
1
- <!DOCTYPE html>
2
- <html lang="en">
3
-
4
- <head>
5
- <meta charset="UTF-8">
6
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
- <title>SmolVLM WebGPU</title>
8
- <link rel="stylesheet" href="styles.css">
9
- </head>
10
-
11
- <body>
12
- <h1>SmolVLM - Vision-Language Model</h1>
13
- <div id="app">
14
- <canvas id="webgpu-canvas"></canvas>
15
- <div id="controls">
16
- <input type="file" id="image-upload" accept="image/*">
17
- <input type="text" id="question" placeholder="Ask a question about the image">
18
- <button id="submit-btn">Submit</button>
19
- </div>
20
- <div id="answer">Answer will appear here</div>
21
- </div>
22
- <script type="module" src="./src/main.js"></script>
23
- </body>
24
-
25
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
  torch
2
  transformers
3
- gradio
 
 
 
 
1
  torch
2
  transformers
3
+ gradio
4
+ pillow
5
+ numpy
6
+ librosa
src/main.js CHANGED
@@ -1,83 +0,0 @@
1
- async function initializeWebGPU() {
2
- const canvas = document.getElementById("webgpu-canvas");
3
-
4
- if (!navigator.gpu) {
5
- document.body.innerHTML = "<p>Your browser does not support WebGPU.</p>";
6
- return;
7
- }
8
-
9
- console.log("WebGPU is supported.");
10
-
11
- const adapter = await navigator.gpu.requestAdapter();
12
- if (!adapter) {
13
- console.error("Failed to get GPU adapter.");
14
- return;
15
- }
16
- console.log("GPU adapter obtained.");
17
-
18
- const device = await adapter.requestDevice();
19
- if (!device) {
20
- console.error("Failed to get GPU device.");
21
- return;
22
- }
23
- console.log("GPU device obtained.");
24
-
25
- const context = canvas.getContext("webgpu");
26
- if (!context) {
27
- console.error("Failed to get WebGPU context.");
28
- return;
29
- }
30
- console.log("WebGPU context obtained.");
31
-
32
- context.configure({
33
- device: device,
34
- format: navigator.gpu.getPreferredCanvasFormat(),
35
- alphaMode: "opaque",
36
- });
37
-
38
- console.log("WebGPU initialized and canvas configured.");
39
- }
40
-
41
- // Call the initializeWebGPU function to ensure it runs
42
- initializeWebGPU();
43
-
44
- async function submitQuestion(imageFile, question) {
45
- const formData = new FormData();
46
- formData.append("image", imageFile);
47
- formData.append("text", question);
48
-
49
- try {
50
- const response = await fetch("/predict", {
51
- method: "POST",
52
- body: formData,
53
- });
54
-
55
- if (!response.ok) {
56
- const errorText = await response.text();
57
- console.error("Failed to get a response:", response.status, response.statusText, errorText);
58
- return `Error: Unable to fetch the answer. Status: ${response.status}, ${response.statusText}`;
59
- }
60
-
61
- const result = await response.json();
62
- return result.data[0];
63
- } catch (error) {
64
- console.error("Fetch error:", error);
65
- return `Error: Unable to fetch the answer. ${error.message}`;
66
- }
67
- }
68
-
69
- // Handle user interactions
70
- document.getElementById("submit-btn").addEventListener("click", async () => {
71
- const imageFile = document.getElementById("image-upload").files[0];
72
- if (!imageFile) {
73
- alert("Please upload an image.");
74
- return;
75
- }
76
- const question = document.getElementById("question").value;
77
-
78
- const answer = await submitQuestion(imageFile, question);
79
- document.getElementById("answer").innerText = `Answer: ${answer}`;
80
- });
81
-
82
- // Initialize WebGPU when the page loads
83
- initializeWebGPU();