EaindraKyaw commited on
Commit
9f166d9
·
verified ·
1 Parent(s): 170b9ef

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ !apt-get install espeak
2
+ import io
3
+ import matplotlib.pyplot as plt
4
+ import requests
5
+ import inflect
6
+ from PIL import Image
7
+
8
+ def load_image_from_url(url):
9
+ return Image.open(requests.get(url, stream=True).raw)
10
+
11
+ def render_results_in_image(in_pil_img, in_results):
12
+ plt.figure(figsize=(16, 10))
13
+ plt.imshow(in_pil_img)
14
+
15
+ ax = plt.gca()
16
+
17
+ for prediction in in_results:
18
+
19
+ x, y = prediction['box']['xmin'], prediction['box']['ymin']
20
+ w = prediction['box']['xmax'] - prediction['box']['xmin']
21
+ h = prediction['box']['ymax'] - prediction['box']['ymin']
22
+
23
+ ax.add_patch(plt.Rectangle((x, y),
24
+ w,
25
+ h,
26
+ fill=False,
27
+ color="green",
28
+ linewidth=2))
29
+ ax.text(
30
+ x,
31
+ y,
32
+ f"{prediction['label']}: {round(prediction['score']*100, 1)}%",
33
+ color='red'
34
+ )
35
+
36
+ plt.axis("off")
37
+
38
+ # Save the modified image to a BytesIO object
39
+ img_buf = io.BytesIO()
40
+ plt.savefig(img_buf, format='png',
41
+ bbox_inches='tight',
42
+ pad_inches=0)
43
+ img_buf.seek(0)
44
+ modified_image = Image.open(img_buf)
45
+
46
+ # Close the plot to prevent it from being displayed
47
+ plt.close()
48
+
49
+ return modified_image
50
+
51
+ def summarize_predictions_natural_language(predictions):
52
+ summary = {}
53
+ p = inflect.engine()
54
+
55
+ for prediction in predictions:
56
+ label = prediction['label']
57
+ if label in summary:
58
+ summary[label] += 1
59
+ else:
60
+ summary[label] = 1
61
+
62
+ result_string = "In this image, there are "
63
+ for i, (label, count) in enumerate(summary.items()):
64
+ count_string = p.number_to_words(count)
65
+ result_string += f"{count_string} {label}"
66
+ if count > 1:
67
+ result_string += "s"
68
+
69
+ result_string += " "
70
+
71
+ if i == len(summary) - 2:
72
+ result_string += "and "
73
+
74
+ # Remove the trailing comma and space
75
+ result_string = result_string.rstrip(', ') + "."
76
+
77
+ return result_string
78
+
79
+
80
+ ##### To ignore warnings #####
81
+ import warnings
82
+ import logging
83
+ from transformers import logging as hf_logging
84
+
85
+ def ignore_warnings():
86
+ # Ignore specific Python warnings
87
+ warnings.filterwarnings("ignore", message="Some weights of the model checkpoint")
88
+ warnings.filterwarnings("ignore", message="Could not find image processor class")
89
+ warnings.filterwarnings("ignore", message="The `max_size` parameter is deprecated")
90
+
91
+ # Adjust logging for libraries using the logging module
92
+ logging.basicConfig(level=logging.ERROR)
93
+ hf_logging.set_verbosity_error()
94
+
95
+ ########
96
+
97
+ from transformers import pipeline
98
+ from PIL import Image
99
+ from IPython.display import Audio as IPythonAudio
100
+ import gradio as gr
101
+ import numpy as np
102
+ import io
103
+ import soundfile as sf
104
+
105
+ def processed_image(image):
106
+ # The uploaded image is a PIL image
107
+ od_pipe= pipeline("object-detection", model="facebook/detr-resnet-50")
108
+ pl_out = od_pipe(image)
109
+ processed_image=render_results_in_image(image,pl_out)
110
+ text=summarize_predictions_natural_language(pl_out)
111
+ return processed_image,text
112
+
113
+ iface = gr.Interface(processed_image, # Function to process the image
114
+ inputs=gr.Image(type="pil"), # Image upload input
115
+ outputs=[gr.Image(type="pil"),"text"] # Image output
116
+ )
117
+
118
+ iface.launch()
119
+
120
+ tts_pipe = pipeline("text-to-speech", model="kakao-enterprise/vits-vctk")
121
+ narrated_text=tts_pipe(text)
122
+ from IPython.display import Audio as IPythonAudio
123
+
124
+ IPythonAudio(narrated_text["audio"][0],
125
+ rate=narrated_text["sampling_rate"])