Spaces:
Running
on
Zero
Running
on
Zero
nearly working
Browse files- cog.yaml +2 -2
- predict.py +19 -20
cog.yaml
CHANGED
@@ -18,9 +18,9 @@ build:
|
|
18 |
- torch==2.4.0
|
19 |
- transformers==4.43.3
|
20 |
- datasets==2.20.0
|
21 |
-
- accelerate==
|
22 |
- jupyter==1.0.0
|
23 |
-
- numpy==1.26.
|
24 |
- pillow==10.2.0
|
25 |
- peft==0.13.2
|
26 |
- diffusers>=0.32.0
|
|
|
18 |
- torch==2.4.0
|
19 |
- transformers==4.43.3
|
20 |
- datasets==2.20.0
|
21 |
+
- accelerate==1.3.0
|
22 |
- jupyter==1.0.0
|
23 |
+
- numpy==1.26.4
|
24 |
- pillow==10.2.0
|
25 |
- peft==0.13.2
|
26 |
- diffusers>=0.32.0
|
predict.py
CHANGED
@@ -1,38 +1,37 @@
|
|
1 |
-
import
|
2 |
-
import subprocess
|
3 |
-
import time
|
4 |
|
5 |
from cog import BasePredictor, Input, Path, Secret
|
6 |
-
from diffusers.utils import load_image
|
7 |
from diffusers import FluxFillPipeline
|
8 |
from diffusers import FluxTransformer2DModel
|
9 |
-
import numpy as np
|
10 |
import torch
|
11 |
from torchvision import transforms
|
12 |
|
13 |
class Predictor(BasePredictor):
|
14 |
def setup(self) -> None:
|
15 |
"""Load part of the model into memory to make running multiple predictions efficient"""
|
16 |
-
self.dtype = torch.bloat16
|
17 |
self.try_on_transformer = FluxTransformer2DModel.from_pretrained("xiaozaa/catvton-flux-beta",
|
18 |
-
torch_dtype=
|
19 |
self.try_off_transformer = FluxTransformer2DModel.from_pretrained("xiaozaa/cat-tryoff-flux",
|
20 |
-
torch_dtype=
|
21 |
|
22 |
def predict(self,
|
23 |
hf_token: Secret = Input(description="Hugging Face API token. Create a write token at https://huggingface.co/settings/token. You also need to approve the Flux Dev terms."),
|
24 |
image: Path = Input(description="Image file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg"),
|
25 |
-
mask: Path = Input(description="Mask file path", default="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png
|
26 |
try_on: bool = Input(True, description="Try on or try off"),
|
27 |
garment: Path = Input(description="Garment file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg"),
|
28 |
num_steps: int = Input(50, description="Number of steps to run the model for"),
|
29 |
guidance_scale: float = Input(30, description="Guidance scale for the model"),
|
30 |
seed: int = Input(0, description="Seed for the model"),
|
31 |
width: int = Input(576, description="Width of the output image"),
|
32 |
-
height: int = Input(768, description="Height of the output image")):
|
33 |
|
34 |
-
|
35 |
size = (width, height)
|
|
|
|
|
|
|
|
|
36 |
if try_on:
|
37 |
self.transformer = self.try_on_transformer
|
38 |
else:
|
@@ -41,11 +40,11 @@ class Predictor(BasePredictor):
|
|
41 |
self.pipe = FluxFillPipeline.from_pretrained(
|
42 |
"black-forest-labs/FLUX.1-dev",
|
43 |
transformer=self.transformer,
|
44 |
-
torch_dtype=
|
45 |
-
token=hf_token
|
46 |
).to("cuda")
|
47 |
|
48 |
-
self.pipe.transformer.to(
|
49 |
|
50 |
transform = transforms.Compose([
|
51 |
transforms.ToTensor(),
|
@@ -55,10 +54,6 @@ class Predictor(BasePredictor):
|
|
55 |
transforms.ToTensor()
|
56 |
])
|
57 |
|
58 |
-
i = load_image(image).convert("RGB").resize(size)
|
59 |
-
m = load_image(mask).convert("RGB").resize(size)
|
60 |
-
g = load_image(garment).convert("RGB").resize(size)
|
61 |
-
|
62 |
# Transform images using the new preprocessing
|
63 |
image_tensor = transform(i)
|
64 |
mask_tensor = mask_transform(m)[:1] # Take only first channel
|
@@ -94,5 +89,9 @@ class Predictor(BasePredictor):
|
|
94 |
width = size[0]
|
95 |
garment_result = result.crop((0, 0, width, size[1]))
|
96 |
try_result = result.crop((width, 0, width * 2, size[1]))
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
|
|
|
|
2 |
|
3 |
from cog import BasePredictor, Input, Path, Secret
|
4 |
+
from diffusers.utils import load_image
|
5 |
from diffusers import FluxFillPipeline
|
6 |
from diffusers import FluxTransformer2DModel
|
|
|
7 |
import torch
|
8 |
from torchvision import transforms
|
9 |
|
10 |
class Predictor(BasePredictor):
|
11 |
def setup(self) -> None:
|
12 |
"""Load part of the model into memory to make running multiple predictions efficient"""
|
|
|
13 |
self.try_on_transformer = FluxTransformer2DModel.from_pretrained("xiaozaa/catvton-flux-beta",
|
14 |
+
torch_dtype=torch.bfloat16)
|
15 |
self.try_off_transformer = FluxTransformer2DModel.from_pretrained("xiaozaa/cat-tryoff-flux",
|
16 |
+
torch_dtype=torch.bfloat16)
|
17 |
|
18 |
def predict(self,
|
19 |
hf_token: Secret = Input(description="Hugging Face API token. Create a write token at https://huggingface.co/settings/token. You also need to approve the Flux Dev terms."),
|
20 |
image: Path = Input(description="Image file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg"),
|
21 |
+
mask: Path = Input(description="Mask file path", default="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png"),
|
22 |
try_on: bool = Input(True, description="Try on or try off"),
|
23 |
garment: Path = Input(description="Garment file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg"),
|
24 |
num_steps: int = Input(50, description="Number of steps to run the model for"),
|
25 |
guidance_scale: float = Input(30, description="Guidance scale for the model"),
|
26 |
seed: int = Input(0, description="Seed for the model"),
|
27 |
width: int = Input(576, description="Width of the output image"),
|
28 |
+
height: int = Input(768, description="Height of the output image")) -> List[Path]:
|
29 |
|
|
|
30 |
size = (width, height)
|
31 |
+
i = load_image(str(image)).convert("RGB").resize(size)
|
32 |
+
m = load_image(str(mask)).convert("RGB").resize(size)
|
33 |
+
g = load_image(str(garment)).convert("RGB").resize(size)
|
34 |
+
|
35 |
if try_on:
|
36 |
self.transformer = self.try_on_transformer
|
37 |
else:
|
|
|
40 |
self.pipe = FluxFillPipeline.from_pretrained(
|
41 |
"black-forest-labs/FLUX.1-dev",
|
42 |
transformer=self.transformer,
|
43 |
+
torch_dtype=torch.bfloat16,
|
44 |
+
token=hf_token.get_secret_value()
|
45 |
).to("cuda")
|
46 |
|
47 |
+
self.pipe.transformer.to(torch.bfloat16)
|
48 |
|
49 |
transform = transforms.Compose([
|
50 |
transforms.ToTensor(),
|
|
|
54 |
transforms.ToTensor()
|
55 |
])
|
56 |
|
|
|
|
|
|
|
|
|
57 |
# Transform images using the new preprocessing
|
58 |
image_tensor = transform(i)
|
59 |
mask_tensor = mask_transform(m)[:1] # Take only first channel
|
|
|
89 |
width = size[0]
|
90 |
garment_result = result.crop((0, 0, width, size[1]))
|
91 |
try_result = result.crop((width, 0, width * 2, size[1]))
|
92 |
+
out_path = "/tmp/try.png"
|
93 |
+
try_result.save(out_path)
|
94 |
+
garm_out_path = "/tmp/garment.png"
|
95 |
+
garment_result.save(garm_out_path)
|
96 |
+
return [Path(out_path), Path(garm_out_path)]
|
97 |
+
|