asutermo committed on
Commit
bf8ca35
·
1 Parent(s): da1d0d5

nearly working

Browse files
Files changed (2) hide show
  1. cog.yaml +2 -2
  2. predict.py +19 -20
cog.yaml CHANGED
@@ -18,9 +18,9 @@ build:
18
  - torch==2.4.0
19
  - transformers==4.43.3
20
  - datasets==2.20.0
21
- - accelerate==0.26.1
22
  - jupyter==1.0.0
23
- - numpy==1.26.3
24
  - pillow==10.2.0
25
  - peft==0.13.2
26
  - diffusers>=0.32.0
 
18
  - torch==2.4.0
19
  - transformers==4.43.3
20
  - datasets==2.20.0
21
+ - accelerate==1.3.0
22
  - jupyter==1.0.0
23
+ - numpy==1.26.4
24
  - pillow==10.2.0
25
  - peft==0.13.2
26
  - diffusers>=0.32.0
predict.py CHANGED
@@ -1,38 +1,37 @@
1
- import os
2
- import subprocess
3
- import time
4
 
5
  from cog import BasePredictor, Input, Path, Secret
6
- from diffusers.utils import load_image, check_min_version
7
  from diffusers import FluxFillPipeline
8
  from diffusers import FluxTransformer2DModel
9
- import numpy as np
10
  import torch
11
  from torchvision import transforms
12
 
13
  class Predictor(BasePredictor):
14
  def setup(self) -> None:
15
  """Load part of the model into memory to make running multiple predictions efficient"""
16
- self.dtype = torch.bloat16
17
  self.try_on_transformer = FluxTransformer2DModel.from_pretrained("xiaozaa/catvton-flux-beta",
18
- torch_dtype=self.dtype)
19
  self.try_off_transformer = FluxTransformer2DModel.from_pretrained("xiaozaa/cat-tryoff-flux",
20
- torch_dtype=self.dtype)
21
 
22
  def predict(self,
23
  hf_token: Secret = Input(description="Hugging Face API token. Create a write token at https://huggingface.co/settings/token. You also need to approve the Flux Dev terms."),
24
  image: Path = Input(description="Image file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg"),
25
- mask: Path = Input(description="Mask file path", default="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png?raw=true"),
26
  try_on: bool = Input(True, description="Try on or try off"),
27
  garment: Path = Input(description="Garment file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg"),
28
  num_steps: int = Input(50, description="Number of steps to run the model for"),
29
  guidance_scale: float = Input(30, description="Guidance scale for the model"),
30
  seed: int = Input(0, description="Seed for the model"),
31
  width: int = Input(576, description="Width of the output image"),
32
- height: int = Input(768, description="Height of the output image")):
33
 
34
-
35
  size = (width, height)
 
 
 
 
36
  if try_on:
37
  self.transformer = self.try_on_transformer
38
  else:
@@ -41,11 +40,11 @@ class Predictor(BasePredictor):
41
  self.pipe = FluxFillPipeline.from_pretrained(
42
  "black-forest-labs/FLUX.1-dev",
43
  transformer=self.transformer,
44
- torch_dtype=self.dtype,
45
- token=hf_token
46
  ).to("cuda")
47
 
48
- self.pipe.transformer.to(self.dtype)
49
 
50
  transform = transforms.Compose([
51
  transforms.ToTensor(),
@@ -55,10 +54,6 @@ class Predictor(BasePredictor):
55
  transforms.ToTensor()
56
  ])
57
 
58
- i = load_image(image).convert("RGB").resize(size)
59
- m = load_image(mask).convert("RGB").resize(size)
60
- g = load_image(garment).convert("RGB").resize(size)
61
-
62
  # Transform images using the new preprocessing
63
  image_tensor = transform(i)
64
  mask_tensor = mask_transform(m)[:1] # Take only first channel
@@ -94,5 +89,9 @@ class Predictor(BasePredictor):
94
  width = size[0]
95
  garment_result = result.crop((0, 0, width, size[1]))
96
  try_result = result.crop((width, 0, width * 2, size[1]))
97
-
98
- return garment_result, try_result
 
 
 
 
 
1
+ from typing import List
 
 
2
 
3
  from cog import BasePredictor, Input, Path, Secret
4
+ from diffusers.utils import load_image
5
  from diffusers import FluxFillPipeline
6
  from diffusers import FluxTransformer2DModel
 
7
  import torch
8
  from torchvision import transforms
9
 
10
  class Predictor(BasePredictor):
11
  def setup(self) -> None:
12
  """Load part of the model into memory to make running multiple predictions efficient"""
 
13
  self.try_on_transformer = FluxTransformer2DModel.from_pretrained("xiaozaa/catvton-flux-beta",
14
+ torch_dtype=torch.bfloat16)
15
  self.try_off_transformer = FluxTransformer2DModel.from_pretrained("xiaozaa/cat-tryoff-flux",
16
+ torch_dtype=torch.bfloat16)
17
 
18
  def predict(self,
19
  hf_token: Secret = Input(description="Hugging Face API token. Create a write token at https://huggingface.co/settings/token. You also need to approve the Flux Dev terms."),
20
  image: Path = Input(description="Image file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg"),
21
+ mask: Path = Input(description="Mask file path", default="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png"),
22
  try_on: bool = Input(True, description="Try on or try off"),
23
  garment: Path = Input(description="Garment file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg"),
24
  num_steps: int = Input(50, description="Number of steps to run the model for"),
25
  guidance_scale: float = Input(30, description="Guidance scale for the model"),
26
  seed: int = Input(0, description="Seed for the model"),
27
  width: int = Input(576, description="Width of the output image"),
28
+ height: int = Input(768, description="Height of the output image")) -> List[Path]:
29
 
 
30
  size = (width, height)
31
+ i = load_image(str(image)).convert("RGB").resize(size)
32
+ m = load_image(str(mask)).convert("RGB").resize(size)
33
+ g = load_image(str(garment)).convert("RGB").resize(size)
34
+
35
  if try_on:
36
  self.transformer = self.try_on_transformer
37
  else:
 
40
  self.pipe = FluxFillPipeline.from_pretrained(
41
  "black-forest-labs/FLUX.1-dev",
42
  transformer=self.transformer,
43
+ torch_dtype=torch.bfloat16,
44
+ token=hf_token.get_secret_value()
45
  ).to("cuda")
46
 
47
+ self.pipe.transformer.to(torch.bfloat16)
48
 
49
  transform = transforms.Compose([
50
  transforms.ToTensor(),
 
54
  transforms.ToTensor()
55
  ])
56
 
 
 
 
 
57
  # Transform images using the new preprocessing
58
  image_tensor = transform(i)
59
  mask_tensor = mask_transform(m)[:1] # Take only first channel
 
89
  width = size[0]
90
  garment_result = result.crop((0, 0, width, size[1]))
91
  try_result = result.crop((width, 0, width * 2, size[1]))
92
+ out_path = "/tmp/try.png"
93
+ try_result.save(out_path)
94
+ garm_out_path = "/tmp/garment.png"
95
+ garment_result.save(garm_out_path)
96
+ return [Path(out_path), Path(garm_out_path)]
97
+