TheoBH commited on
Commit
78fb25d
·
verified ·
1 Parent(s): 43c373a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import diffusers
3
+ import tqdm as notebook_tqdm
4
+ from diffusers import StableDiffusionInpaintPipeline
5
+ import cv2
6
+ import math
7
+ import gradio as gr
8
+ import numpy as np
9
+ import os
10
+ import mediapipe as mp
11
+
12
+ from mediapipe.tasks import python
13
+ from mediapipe.tasks.python import vision
14
+ from mediapipe.tasks.python.components import containers
15
+
16
+ from skimage.measure import label, regionprops
17
+ import numpy as np
18
+ import matplotlib.pyplot as plt
19
+ import cv2
20
+
21
+
22
+ from skimage.measure import label
23
+ from skimage.measure import regionprops
24
+
25
+ from PIL import Image
26
+ import torch
27
+
28
+ import numpy as np
29
+ import cv2
30
+ from PIL import Image, ImageDraw
31
+ import mediapipe as mp
32
+ from transformers import pipeline
33
+ from skimage.measure import label, regionprops
34
+ import gradio as gr
35
+
36
+
37
+ import gradio as gr
38
+ import numpy as np
39
+ import cv2
40
+ from PIL import Image, ImageDraw
41
+ import mediapipe as mp
42
+ from transformers import pipeline
43
+ from skimage.measure import label, regionprops
44
+ import matplotlib.pyplot as plt
45
+
46
+
47
+ def _normalized_to_pixel_coordinates(
48
+ normalized_x: float, normalized_y: float, image_width: int, image_height: int):
49
+ """Converts normalized value pair to pixel coordinates."""
50
+
51
+ # Checks if the float value is between 0 and 1.
52
+ def is_valid_normalized_value(value: float) -> bool:
53
+ return (value > 0 or math.isclose(0, value)) and (value < 1 or math.isclose(1, value))
54
+
55
+ if not (is_valid_normalized_value(normalized_x) and is_valid_normalized_value(normalized_y)):
56
+ # TODO: Draw coordinates even if it's outside of the image bounds.
57
+ return None
58
+ x_px = min(math.floor(normalized_x * image_width), image_width - 1)
59
+ y_px = min(math.floor(normalized_y * image_height), image_height - 1)
60
+ return x_px, y_px
61
+
62
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
+
64
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
65
+ "stabilityai/stable-diffusion-2-inpainting",
66
+ torch_dtype=torch.float16,
67
+ ).to(device)
68
+
69
+ #from huggingface_hub import login
70
+ #login()
71
+ #pipe2 = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
72
+ #pipe2.to("cuda")
73
+
74
+ BG_COLOR = (192, 192, 192) # gray
75
+ MASK_COLOR = (255, 255, 255) # white
76
+
77
+ RegionOfInterest = vision.InteractiveSegmenterRegionOfInterest
78
+ NormalizedKeypoint = containers.keypoint.NormalizedKeypoint
79
+
80
+ # Create the options that will be used for InteractiveSegmenter
81
+ base_options = python.BaseOptions(model_asset_path='model.tflite')
82
+ options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
83
+
84
+
85
+ def create_bounding_box_mask(image):
86
+ image = 1 - image
87
+
88
+ # Find the coordinates of the non-background pixels
89
+ y_indices, x_indices = np.nonzero(image)
90
+ if not y_indices.size or not x_indices.size:
91
+ return None # No areas found, you might return an empty mask or raise an error
92
+
93
+ # Calculate the bounding box coordinates
94
+ x_min, x_max = x_indices.min(), x_indices.max()
95
+ y_min, y_max = y_indices.min(), y_indices.max()
96
+
97
+ # Create a new mask for the bounding box
98
+ bounding_mask = np.zeros_like(image, dtype=np.uint8) # Ensure it's a single-channel mask
99
+ bounding_mask[y_min:y_max+1, x_min:x_max+1] = 1 # Fill the bounding box with white 1
100
+
101
+ return bounding_mask
102
+
103
+
104
+
105
+ def segment_2(image_np, coordinates):
106
+ OVERLAY_COLOR = (255, 105, 180) # Rose
107
+
108
+ # Créer le segmenteur
109
+ with python.vision.InteractiveSegmenter.create_from_options(options) as segmenter:
110
+
111
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_np)
112
+
113
+ # Enlever les parenthèses
114
+ coordinates = coordinates.strip("()")
115
+
116
+ # Séparer les valeurs par la virgule
117
+ valeurs = coordinates.split(',')
118
+
119
+ # Convertir les chaînes de caractères en nombres flottants
120
+ x = float(valeurs[0])
121
+ y = float(valeurs[1])
122
+
123
+ # Récupérer les masques de catégorie pour l'image
124
+ roi = RegionOfInterest(format=RegionOfInterest.Format.KEYPOINT,
125
+ keypoint=NormalizedKeypoint(x, y))
126
+ segmentation_result = segmenter.segment(image, roi)
127
+ category_mask = segmentation_result.category_mask
128
+
129
+ # Trouver la boîte englobante de la région segmentée
130
+ mask = (category_mask.numpy_view().astype(np.uint8)*255)
131
+
132
+ # Trouver la boîte englobante de la région segmentée
133
+ x, y, w, h = cv2.boundingRect(mask)
134
+
135
+ # Convertir l'image BGR en RGB
136
+ image_data = cv2.cvtColor(image.numpy_view(), cv2.COLOR_BGR2RGB)
137
+
138
+ # Créer une image d'incrustation avec la couleur désirée (par exemple, (255, 0, 0) pour le rouge)
139
+ overlay_image = np.zeros(image_data.shape, dtype=np.uint8)
140
+ overlay_image[:] = OVERLAY_COLOR
141
+
142
+ # Créer la condition à partir du tableau category_masks
143
+ alpha = np.stack((category_mask.numpy_view(),) * 3, axis=-1) <= 0.1
144
+
145
+ # Créer un canal alpha à partir de la condition avec l'opacité désirée (par exemple, 0.7 pour 70%)
146
+ alpha = alpha.astype(float) * 0.5 # Réduire l'opacité à 50%
147
+
148
+ # Fusionner l'image originale et l'image d'incrustation en fonction du canal alpha
149
+ output_image = image_data * (1 - alpha) + overlay_image * alpha
150
+ output_image = output_image.astype(np.uint8)
151
+
152
+ # Dessiner un point blanc avec une bordure noire pour indiquer le point d'intérêt
153
+ thickness, radius = 6, -1
154
+ keypoint_px = _normalized_to_pixel_coordinates(x, y, image.width, image.height)
155
+ cv2.circle(output_image, keypoint_px, thickness + 5, (0, 0, 0), radius)
156
+ cv2.circle(output_image, keypoint_px, thickness, (255, 255, 255), radius)
157
+
158
+
159
+ image_width, image_height = output_image.shape[:2]
160
+ bounding_mask = create_bounding_box_mask(mask)
161
+ bbox_mask_image = Image.fromarray((bounding_mask * 255).astype(np.uint8))
162
+ bbox_img = bbox_mask_image.convert("RGB")
163
+ bbox_img.resize((image_width, image_height))
164
+
165
+ return output_image,bbox_mask_image
166
+
167
+
168
+ def generate_2(image_file_path, bbox_image, prompt):
169
+
170
+ # Read image
171
+ img = Image.fromarray(image_file_path).convert("RGB")
172
+
173
+ # Generate images using images and prompts
174
+ images = pipe(prompt=prompt,
175
+ image=img,
176
+ mask_image=bbox_image,
177
+ generator=torch.Generator(device="cuda").manual_seed(0),
178
+ num_images_per_prompt=3,
179
+ plms=True).images
180
+
181
+ # Create an image grid
182
+ def image_grid(imgs, rows, cols):
183
+ assert len(imgs) == rows*cols
184
+
185
+ w, h = imgs[0].size
186
+ grid = Image.new('RGB', size=(cols*w, rows*h))
187
+ grid_w, grid_h = grid.size
188
+
189
+ for i, img in enumerate(imgs):
190
+ grid.paste(img, box=(i%cols*w, i//cols*h))
191
+ return grid
192
+
193
+ grid_image = image_grid(images, 1, 3)
194
+ return grid_image
195
+
196
+
197
+ def onclick(evt: gr.SelectData, image):
198
+ if evt:
199
+ x, y = evt.index
200
+ # Normalize the coordinates by 0-1
201
+ normalized_x = round(x / image.shape[1], 2)
202
+ normalized_y = round(y / image.shape[0], 2)
203
+ return normalized_x, normalized_y
204
+ else:
205
+ return None, None
206
+
207
+
208
+
209
+ # Assurez-vous d'importer ou de définir les fonctions segment et generate_2 ici
210
+
211
+ def callback(image, coordinates, prompt):
212
+ # Convertir l'image PIL en chemin de fichier temporaire ou en numpy array si nécessaire
213
+ # Appeler la fonction segment avec les coordonnées et l'image
214
+ segmented_image, bbox_image = segment_2(image, coordinates)
215
+
216
+ # Appeler la fonction generate_2 avec l'image, bbox_image, et le prompt
217
+ grid_image = generate_2(image, bbox_image, prompt)
218
+
219
+ # Retourner les images résultantes pour l'affichage
220
+ return segmented_image, grid_image
221
+
222
+ with gr.Blocks() as demo:
223
+ with gr.Row():
224
+ image_input = gr.Image(type="numpy", label="Upload Image", interactive=True)
225
+ coordinates_output = gr.Textbox(label="Coordinates")
226
+ with gr.Row():
227
+ prompt_input = gr.Textbox(label="What do you want to change?")
228
+ submit_button = gr.Button("Submit")
229
+ with gr.Row():
230
+ segmented_image_output = gr.Image(type="numpy", label="Segmented Image")
231
+ grid_image_output = gr.Image(type="pil", label="Generated Image Grid")
232
+
233
+ image_input.select(onclick, inputs=[image_input], outputs=coordinates_output)
234
+ submit_button.click(fn=callback, inputs=[image_input, coordinates_output, prompt_input], outputs=[segmented_image_output, grid_image_output])
235
+
236
+ demo.launch(debug=True)