Spaces:
Saad0KH
/
Running on Zero

Saad0KH commited on
Commit
83d675a
·
verified ·
1 Parent(s): 406bf7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -84
app.py CHANGED
@@ -1,95 +1,246 @@
 
1
  from flask import Flask, request, jsonify
 
 
 
2
  import torch
3
- from diffusers import UNet2DConditionModel # Modification ici
 
 
4
  from transformers import (
5
- AutoTokenizer,
6
- CLIPTextModel,
7
- CLIPTextModelWithProjection,
8
- CLIPVisionModelWithProjection
 
9
  )
10
- from PIL import Image
11
- import base64
12
- from io import BytesIO
 
 
 
 
 
13
 
14
  app = Flask(__name__)
15
 
16
- # Global variables for models to load them once at startup
17
- unet = None
18
- tokenizer_one = None
19
- tokenizer_two = None
20
- noise_scheduler = None
21
- text_encoder_one = None
22
- text_encoder_two = None
23
- image_encoder = None
24
- vae = None
25
- UNet_Encoder = None
26
-
27
- # Load models once at startup
28
- def load_models():
29
- global unet, tokenizer_one, tokenizer_two, noise_scheduler
30
- global text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder
31
 
32
- if unet is None:
33
- # Load models only when required to reduce memory usage
34
- unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-v1-4")
35
-
36
- if tokenizer_one is None:
37
- tokenizer_one = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
38
-
39
- if tokenizer_two is None:
40
- tokenizer_two = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14-336")
41
 
42
- if noise_scheduler is None:
43
- noise_scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4")
44
-
45
- if text_encoder_one is None:
46
- text_encoder_one = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
47
-
48
- if text_encoder_two is None:
49
- text_encoder_two = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14-336")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- if image_encoder is None:
52
- image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
53
-
54
- if vae is None:
55
- vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-v1-4")
56
-
57
- if UNet_Encoder is None:
58
- UNet_Encoder = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-v1-4")
59
-
60
- # Helper function to process base64 image
61
- def decode_image(image_base64):
62
- image_data = base64.b64decode(image_base64)
63
- image = Image.open(BytesIO(image_data)).convert("RGB")
64
- return image
65
-
66
- # Helper function to encode image to base64
67
- def encode_image(image):
68
- buffered = BytesIO()
69
- image.save(buffered, format="PNG")
70
- return base64.b64encode(buffered.getvalue()).decode('utf-8')
71
-
72
- # Route for image processing
73
- @app.route('/process_image', methods=['POST'])
74
- def process_image():
75
- data = request.json
76
-
77
- # Load the models (this will only happen once)
78
- load_models()
79
-
80
- # Extract the image from the request
81
- image_base64 = data.get('image_base64')
82
- if not image_base64:
83
- return jsonify({"error": "No image provided"}), 400
84
-
85
- image = decode_image(image_base64)
86
-
87
- # Perform inference with the models (example, modify as needed)
88
- processed_image = image # Placeholder for actual image processing
89
-
90
- # Return the processed image as base64
91
- processed_image_base64 = encode_image(processed_image)
92
- return jsonify({"processed_image": processed_image_base64})
93
-
94
- if __name__ == '__main__':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  app.run(host='0.0.0.0', port=7860)
 
1
+ import os
2
  from flask import Flask, request, jsonify
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ import base64
6
  import torch
7
+ import requests
8
+ import numpy as np
9
+ import uuid
10
  from transformers import (
11
+ CLIPImageProcessor,
12
+ CLIPVisionModelWithProjection,
13
+ CLIPTextModel,
14
+ CLIPTextModelWithProjection,
15
+ AutoTokenizer
16
  )
17
+ from diffusers import DDPMScheduler, AutoencoderKL
18
+ from utils_mask import get_mask_location
19
+ from torchvision import transforms
20
+ import apply_net
21
+ from preprocess.humanparsing.run_parsing import Parsing
22
+ from preprocess.openpose.run_openpose import OpenPose
23
+ from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
24
+ from torchvision.transforms.functional import to_pil_image
25
 
26
  app = Flask(__name__)
27
 
28
+ # Variables globales pour stocker les modèles
29
+ models_loaded = False
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ def load_models():
32
+ global unet, tokenizer_one, tokenizer_two, noise_scheduler, text_encoder_one, text_encoder_two
33
+ global image_encoder, vae, UNet_Encoder, parsing_model, openpose_model, pipe
 
 
 
 
 
 
34
 
35
+ if not models_loaded:
36
+ base_path = 'yisol/IDM-VTON'
37
+ unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16, force_download=False)
38
+ unet.requires_grad_(False)
39
+
40
+ tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False, force_download=False)
41
+ tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False, force_download=False)
42
+
43
+ noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
44
+ text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16, force_download=False)
45
+ text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16, force_download=False)
46
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16, force_download=False)
47
+ vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16, force_download=False)
48
+ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16, force_download=False)
49
+
50
+ parsing_model = Parsing(0)
51
+ openpose_model = OpenPose(0)
52
+
53
+ UNet_Encoder.requires_grad_(False)
54
+ image_encoder.requires_grad_(False)
55
+ vae.requires_grad_(False)
56
+ unet.requires_grad_(False)
57
+ text_encoder_one.requires_grad_(False)
58
+ text_encoder_two.requires_grad_(False)
59
+
60
+ tensor_transfrom = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
61
+
62
+ pipe = TryonPipeline.from_pretrained(
63
+ base_path,
64
+ unet=unet,
65
+ vae=vae,
66
+ feature_extractor=CLIPImageProcessor(),
67
+ text_encoder=text_encoder_one,
68
+ text_encoder_2=text_encoder_two,
69
+ tokenizer=tokenizer_one,
70
+ tokenizer_2=tokenizer_two,
71
+ scheduler=noise_scheduler,
72
+ image_encoder=image_encoder,
73
+ torch_dtype=torch.float16,
74
+ force_download=False
75
+ )
76
+ pipe.unet_encoder = UNet_Encoder
77
+
78
+ global models_loaded
79
+ models_loaded = True
80
+
81
+ def pil_to_binary_mask(pil_image, threshold=0):
82
+ np_image = np.array(pil_image.convert("L")) # Convert to grayscale directly
83
+ binary_mask = np_image > threshold
84
+ mask = np.uint8(binary_mask * 255)
85
+ return Image.fromarray(mask)
86
+
87
+ def get_image_from_url(url):
88
+ try:
89
+ response = requests.get(url)
90
+ response.raise_for_status()
91
+ return Image.open(BytesIO(response.content))
92
+ except Exception as e:
93
+ logging.error(f"Error fetching image from URL: {e}")
94
+ raise
95
+
96
+ def decode_image_from_base64(base64_str):
97
+ try:
98
+ img_data = base64.b64decode(base64_str)
99
+ return Image.open(BytesIO(img_data))
100
+ except Exception as e:
101
+ logging.error(f"Error decoding image: {e}")
102
+ raise
103
+
104
+ def encode_image_to_base64(img):
105
+ try:
106
+ buffered = BytesIO()
107
+ img.save(buffered, format="PNG")
108
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
109
+ except Exception as e:
110
+ logging.error(f"Error encoding image: {e}")
111
+ raise
112
+
113
+ def save_image(img):
114
+ unique_name = f"{uuid.uuid4()}.webp"
115
+ img.save(unique_name, format="WEBP", lossless=True)
116
+ return unique_name
117
+
118
+ def clear_gpu_memory():
119
+ torch.cuda.empty_cache()
120
+ torch.cuda.ipc_collect()
121
+
122
+ @spaces.GPU
123
+ def start_tryon(human_dict, garment_image, garment_description, use_auto_mask, use_auto_crop, denoise_steps, seed, category='upper_body'):
124
+ device = "cuda"
125
+ openpose_model.preprocessor.body_estimation.model.to(device)
126
+ pipe.to(device)
127
+ pipe.unet_encoder.to(device)
128
+
129
+ garment_image = garment_image.convert("RGB").resize((768, 1024))
130
+ human_image_orig = human_dict["background"].convert("RGB")
131
+
132
+ if use_auto_crop:
133
+ width, height = human_image_orig.size
134
+ target_width = int(min(width, height * (3 / 4)))
135
+ target_height = int(min(height, width * (4 / 3)))
136
+ left, top = (width - target_width) / 2, (height - target_height) / 2
137
+ right, bottom = (width + target_width) / 2, (height + target_height) / 2
138
+ cropped_img = human_image_orig.crop((left, top, right, bottom)).resize((768, 1024))
139
+ else:
140
+ cropped_img = human_image_orig.resize((768, 1024))
141
+
142
+ if use_auto_mask:
143
+ keypoints = openpose_model(cropped_img.resize((384, 512)))
144
+ model_parse, _ = parsing_model(cropped_img.resize((384, 512)))
145
+ mask, mask_gray = get_mask_location('hd', category, model_parse, keypoints)
146
+ mask = mask.resize((768, 1024))
147
+ else:
148
+ mask = pil_to_binary_mask(human_dict['layers'][0].convert("RGB").resize((768, 1024)))
149
 
150
+ mask_gray = (1 - transforms.ToTensor()(mask)) * transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])(cropped_img)
151
+ mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
152
+
153
+ human_image_arg = _apply_exif_orientation(cropped_img.resize((384, 512)))
154
+ human_image_arg = convert_PIL_to_numpy(human_image_arg, format="BGR")
155
+
156
+ args = apply_net.create_argument_parser().parse_args(
157
+ ('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
158
+ pose_image = args.func(args, human_image_arg)
159
+ pose_image = Image.fromarray(pose_image[:, :, ::-1]).resize((768, 1024))
160
+
161
+ with torch.no_grad(), torch.cuda.amp.autocast():
162
+ prompt = "model is wearing " + garment_description
163
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
164
+ prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
165
+ prompt, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=negative_prompt
166
+ )
167
+
168
+ prompt_c = "a photo of " + garment_description
169
+ negative_prompt_c = "monochrome, lowres, bad anatomy, worst quality, low quality"
170
+ prompt_embeds_c, _, _, _ = pipe.encode_prompt(
171
+ prompt_c, num_images_per_prompt=1, do_classifier_free_guidance=False, negative_prompt=negative_prompt_c
172
+ )
173
+
174
+ pose_image = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])(pose_image).unsqueeze(0).to(device, torch.float16)
175
+ garment_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])(garment_image).unsqueeze(0).to(device, torch.float16)
176
+
177
+ images = pipe(
178
+ prompt_embeds=prompt_embeds.to(device, torch.float16),
179
+ negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
180
+ pose_image=pose_image,
181
+ garment_image=garment_tensor,
182
+ mask_image=mask_gray.to(device, torch.float16),
183
+ generator=torch.Generator(device).manual_seed(seed),
184
+ num_inference_steps=denoise_steps
185
+ ).images
186
+
187
+ if images:
188
+ output_image = images[0]
189
+ output_base64 = encode_image_to_base64(output_image)
190
+ mask_image = mask
191
+ mask_base64 = encode_image_to_base64(mask_image)
192
+ return output_image, mask_image
193
+ else:
194
+ raise ValueError("Failed to generate image")
195
+
196
+
197
+ # Route pour récupérer l'image générée
198
+ @app.route('/api/get_image/<image_id>', methods=['GET'])
199
+ def get_image(image_id):
200
+ # Construire le chemin complet de l'image
201
+ image_path = image_id # Assurez-vous que le nom de fichier correspond à celui que vous avez utilisé lors de la sauvegarde
202
+
203
+ # Renvoyer l'image
204
+ try:
205
+ return send_file(image_path, mimetype='image/webp')
206
+ except FileNotFoundError:
207
+ return jsonify({'error': 'Image not found'}), 404
208
+
209
+ @app.route('/tryon', methods=['POST'])
210
+ def tryon_handler():
211
+ try:
212
+ data = request.json
213
+ human_image = decode_image_from_base64(data['human_image'])
214
+ garment_image = decode_image_from_base64(data['garment_image'])
215
+ description = data.get('description')
216
+ use_auto_mask = data.get('use_auto_mask', True)
217
+ use_auto_crop = data.get('use_auto_crop', False)
218
+ denoise_steps = int(data.get('denoise_steps', 30))
219
+ seed = int(data.get('seed', 42))
220
+ category = data.get('category', 'upper_body')
221
+
222
+ human_dict = {
223
+ 'background': human_image,
224
+ 'layers': [human_image] if not use_auto_mask else None,
225
+ 'composite': None
226
+ }
227
+ clear_gpu_memory()
228
+
229
+ output_image, mask_image = start_tryon(
230
+ human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed, category
231
+ )
232
+
233
+ output_base64 = encode_image_to_base64(output_image)
234
+ mask_base64 = encode_image_to_base64(mask_image)
235
+
236
+ return jsonify({
237
+ 'output_image': output_base64,
238
+ 'mask_image': mask_base64
239
+ })
240
+ except Exception as e:
241
+ logging.error(f"Error in tryon_handler: {e}")
242
+ return jsonify({'error': str(e)}), 500
243
+
244
+ if __name__ == "__main__":
245
+ load_models() # Charge les modèles au démarrage
246
  app.run(host='0.0.0.0', port=7860)