Saad0KH committed on
Commit 3f5f533 · verified · 1 Parent(s): 601528d

Update app.py

Files changed (1)
  1. app.py +76 -104
app.py CHANGED
@@ -34,39 +34,26 @@ from torchvision.transforms.functional import to_pil_image
 
 app = Flask(__name__)
 
-# Chemins de base pour les modèles
+# Base paths for models
 base_path = 'yisol/IDM-VTON'
 
-# Chargement des modèles
-unet = UNet2DConditionModel.from_pretrained(
-    base_path,
-    subfolder="unet",
-    torch_dtype=torch.float16,
-    force_download=False
-)
-tokenizer_one = AutoTokenizer.from_pretrained(
-    base_path,
-    subfolder="tokenizer",
-    use_fast=False,
-    force_download=False
-)
-tokenizer_two = AutoTokenizer.from_pretrained(
-    base_path,
-    subfolder="tokenizer_2",
-    use_fast=False,
-    force_download=False
-)
+# Load models
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16).to(device)
+tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False)
+tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False)
 noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
-text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16)
-text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16)
-image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16)
-vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16)
-UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16)
+text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)
+text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16).to(device)
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16).to(device)
+vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16).to(device)
+UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16).to(device)
 
-parsing_model = Parsing(0)
-openpose_model = OpenPose(0)
+parsing_model = Parsing(0).to(device)
+openpose_model = OpenPose(0).to(device)
 
-# Préparation du pipeline Tryon
+# Prepare Tryon pipeline
 pipe = TryonPipeline.from_pretrained(
     base_path,
     unet=unet,
@@ -79,12 +66,11 @@ pipe = TryonPipeline.from_pretrained(
     scheduler=noise_scheduler,
     image_encoder=image_encoder,
     torch_dtype=torch.float16,
-    force_download=False
-)
+).to(device)
 pipe.unet_encoder = UNet_Encoder
 
-# Utilisation des transformations d'images
-tensor_transfrom = transforms.Compose([
+# Image transformation
+tensor_transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize([0.5], [0.5]),
 ])
@@ -96,13 +82,11 @@ def pil_to_binary_mask(pil_image, threshold=0):
     mask = np.zeros(binary_mask.shape, dtype=np.uint8)
     mask[binary_mask] = 1
     return Image.fromarray((mask * 255).astype(np.uint8))
-
-
 
 def get_image_from_url(url):
     try:
         response = requests.get(url)
-        response.raise_for_status() # Vérifie les erreurs HTTP
+        response.raise_for_status()
         img = Image.open(BytesIO(response.content))
         return img
     except Exception as e:
@@ -133,12 +117,7 @@ def save_image(img):
     return unique_name
 
 @spaces.GPU
-def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, categorie = 'upper_body'):
-    device = "cuda"
-    openpose_model.preprocessor.body_estimation.model.to(device)
-    pipe.to(device)
-    pipe.unet_encoder.to(device)
-
+def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, categorie='upper_body'):
     garm_img = garm_img.convert("RGB").resize((768, 1024))
     human_img_orig = dict["background"].convert("RGB")
 
@@ -159,11 +138,12 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
     if is_checked:
         keypoints = openpose_model(human_img.resize((384, 512)))
         model_parse, _ = parsing_model(human_img.resize((384, 512)))
-        mask, mask_gray = get_mask_location('hd', categorie , model_parse, keypoints)
+        mask, mask_gray = get_mask_location('hd', categorie, model_parse, keypoints)
         mask = mask.resize((768, 1024))
     else:
-        mask = dict['layers'][0].convert("RGB").resize((768, 1024))#pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
-        mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
+        mask = dict['layers'][0].convert("RGB").resize((768, 1024))
+
+    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transform(human_img)
     mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
 
     human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
@@ -178,66 +158,58 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
     with torch.cuda.amp.autocast():
         prompt = "model is wearing " + garment_des
         negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-        with torch.inference_mode():
-            (
-                prompt_embeds,
-                negative_prompt_embeds,
-                pooled_prompt_embeds,
-                negative_pooled_prompt_embeds,
-            ) = pipe.encode_prompt(
-                prompt,
-                num_images_per_prompt=1,
-                do_classifier_free_guidance=True,
-                negative_prompt=negative_prompt,
-            )
-
-            prompt = "a photo of " + garment_des
-            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality , change color"
-            if not isinstance(prompt, list):
-                prompt = [prompt] * 1
-            if not isinstance(negative_prompt, list):
-                negative_prompt = [negative_prompt] * 1
-            with torch.inference_mode():
-                (
-                    prompt_embeds_c,
-                    _,
-                    _,
-                    _,
-                ) = pipe.encode_prompt(
-                    prompt,
-                    num_images_per_prompt=1,
-                    do_classifier_free_guidance=False,
-                    negative_prompt=negative_prompt,
-                )
-
-                pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
-                garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
-                generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
-                images = pipe(
-                    prompt_embeds=prompt_embeds.to(device, torch.float16),
-                    negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
-                    pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
-                    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
-                    num_inference_steps=denoise_steps,
-                    generator=generator,
-                    strength=1.5,
-                    pose_img=pose_img.to(device, torch.float16),
-                    text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
-                    cloth=garm_tensor.to(device, torch.float16),
-                    mask_image=mask,
-                    image=human_img,
-                    height=1024,
-                    width=768,
-                    ip_adapter_image=garm_img.resize((768, 1024)),
-                    guidance_scale=1.5,
-                )[0]
-
-                if is_checked_crop:
-                    out_img = images[0].resize(crop_size)
-                    human_img_orig.paste(out_img, (int(left), int(top)))
-                    return human_img_orig, mask_gray
-                else:
-                    return images[0], mask_gray , mask
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = pipe.encode_prompt(
+            prompt,
+            num_images_per_prompt=1,
+            do_classifier_free_guidance=True,
+            negative_prompt=negative_prompt,
+        )
+
+        prompt_c = "a photo of " + garment_des
+        negative_prompt_c = "monochrome, lowres, bad anatomy, worst quality, low quality, change color"
+        prompt = [prompt_c] if not isinstance(prompt_c, list) else prompt_c
+        negative_prompt = [negative_prompt_c] if not isinstance(negative_prompt_c, list) else negative_prompt_c
+
+        (
+            prompt_embeds_c,
+            _,
+            _,
+            _,
+        ) = pipe.encode_prompt(
+            prompt,
+            num_images_per_prompt=1,
+            do_classifier_free_guidance=False,
+            negative_prompt=negative_prompt,
+        )
+
+        pose_img_tensor = tensor_transform(pose_img).unsqueeze(0).to(device, torch.float16)
+        garm_tensor = tensor_transform(garm_img).unsqueeze(0).to(device, torch.float16)
+        generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+
+        images = pipe(
+            prompt_embeds=prompt_embeds.to(device),
+            negative_prompt_embeds=negative_prompt_embeds.to(device),
+            pooled_prompt_embeds=pooled_prompt_embeds.to(device),
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device),
+            num_inference_steps=denoise_steps,
+            generator=generator,
+            strength=1.5,
+            pose_img=pose_img_tensor.to(device),
+            text_embeds_cloth=prompt_embeds_c.to(device),
+            cloth=garm_tensor.to(device),
+            mask_image=mask,
+            image=human_img,
+            height=1024,
+            width=768,
+        )
+
+        final_image = images[0] if isinstance(images, list) else images
+        return encode_image_to_base64(final_image)
 
 
 @app.route('/tryon-v2', methods=['POST'])
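
For illustration only, a minimal client sketch for exercising the /tryon-v2 route follows. The host/port, the request field names (human_image_url, garment_image_url, description), and the "image" response key are assumptions, since the route handler body is outside this diff; only the route path and the base64 string returned by start_tryon via encode_image_to_base64 appear above.

    import base64
    import requests

    # Hypothetical payload -- the real field names are defined in the
    # /tryon-v2 handler, which is not part of this diff.
    payload = {
        "human_image_url": "https://example.com/person.jpg",
        "garment_image_url": "https://example.com/shirt.jpg",
        "description": "red cotton t-shirt",
    }

    resp = requests.post("http://localhost:5000/tryon-v2", json=payload, timeout=300)
    resp.raise_for_status()

    # Assumes the endpoint returns the base64 string produced by
    # encode_image_to_base64(final_image) under an "image" key.
    with open("tryon_result.png", "wb") as f:
        f.write(base64.b64decode(resp.json()["image"]))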