build cog for tryon/tryoff
Files changed:
- cog.yaml (+39, -0)
- predict.py (+98, -0)
cog.yaml
ADDED
@@ -0,0 +1,39 @@
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml

build:
  # set to true if your model requires a GPU
  gpu: true

  # a list of ubuntu apt packages to install
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"

  # python version in the form '3.11' or '3.11.4'
  python_version: "3.11"

  # a list of packages in the format <package-name>==<version>
  python_packages:
    - torch==2.4.0
    - transformers==4.43.3
    - datasets==2.20.0
    - accelerate==0.26.1
    - jupyter==1.0.0
    - numpy==1.26.4  # the original listed numpy twice (1.26.3 and 1.26.4); a single pin avoids a resolver conflict
    - pillow==10.2.0
    - peft==0.13.2
    - diffusers>=0.32.0
    - timm==0.9.16
    - torchvision==0.19.0
    - tqdm==4.66.5
    - sentencepiece
    - protobuf

  # commands run after the environment is setup
  run:
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
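Note that the `run` step installs Replicate's pget downloader, yet the committed predict.py fetches all weights through `from_pretrained` at setup time. If cold-start downloads become a bottleneck, pget could prefetch a weights archive instead; a minimal sketch, where the archive URL and destination directory are hypothetical placeholders, not values from this commit:

import subprocess

# Hypothetical prefetch with pget; the "-x" flag extracts a tar archive
# while downloading. WEIGHTS_URL and WEIGHTS_DIR are illustrative only.
WEIGHTS_URL = "https://example.com/catvton-flux-weights.tar"
WEIGHTS_DIR = "/src/weights"

def prefetch_weights() -> None:
    subprocess.check_call(["pget", "-x", WEIGHTS_URL, WEIGHTS_DIR])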
predict.py
ADDED
@@ -0,0 +1,98 @@
from typing import List

import torch
from cog import BasePredictor, Input, Path, Secret
from diffusers import FluxFillPipeline, FluxTransformer2DModel
from diffusers.utils import load_image
from torchvision import transforms


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load part of the model into memory to make running multiple predictions efficient"""
        self.dtype = torch.bfloat16  # was torch.bloat16 (typo)
        self.try_on_transformer = FluxTransformer2DModel.from_pretrained(
            "xiaozaa/catvton-flux-beta", torch_dtype=self.dtype)
        self.try_off_transformer = FluxTransformer2DModel.from_pretrained(
            "xiaozaa/cat-tryoff-flux", torch_dtype=self.dtype)

    def predict(self,
                hf_token: Secret = Input(description="Hugging Face API token. Create a write token at https://huggingface.co/settings/token. You also need to approve the Flux Dev terms."),
                image: Path = Input(description="Image file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg"),
                mask: Path = Input(description="Mask file path", default="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png?raw=true"),
                try_on: bool = Input(default=True, description="Try on or try off"),
                garment: Path = Input(description="Garment file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg"),
                num_steps: int = Input(default=50, description="Number of steps to run the model for"),
                guidance_scale: float = Input(default=30, description="Guidance scale for the model"),
                seed: int = Input(default=0, description="Seed for the model"),
                width: int = Input(default=576, description="Width of the output image"),
                height: int = Input(default=768, description="Height of the output image")) -> List[Path]:

        size = (width, height)
        if try_on:
            self.transformer = self.try_on_transformer
        else:
            self.transformer = self.try_off_transformer

        self.pipe = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            transformer=self.transformer,
            torch_dtype=self.dtype,
            token=hf_token.get_secret_value(),  # Secret must be unwrapped to a plain string
        ).to("cuda")

        self.pipe.transformer.to(self.dtype)

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # for RGB images
        ])
        mask_transform = transforms.Compose([
            transforms.ToTensor()
        ])

        # load_image expects a str (local path or URL) or a PIL image, so cast the cog Paths
        i = load_image(str(image)).convert("RGB").resize(size)
        m = load_image(str(mask)).convert("RGB").resize(size)
        g = load_image(str(garment)).convert("RGB").resize(size)

        # Transform images using the new preprocessing
        image_tensor = transform(i)
        mask_tensor = mask_transform(m)[:1]  # take only the first channel
        garment_tensor = transform(g)

        # Create concatenated images
        inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2)  # concatenate along width
        garment_mask = torch.zeros_like(mask_tensor)

        if try_on:
            extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)
        else:
            extended_mask = torch.cat([1 - garment_mask, mask_tensor], dim=2)

        prompt = "The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \
                 "[IMAGE1] Detailed product shot of a clothing " \
                 "[IMAGE2] The same cloth is worn by a model in a lifestyle setting."

        generator = torch.Generator(device="cuda").manual_seed(seed)
        result = self.pipe(
            height=size[1],
            width=size[0] * 2,
            image=inpaint_image,
            mask_image=extended_mask,
            num_inference_steps=num_steps,
            generator=generator,
            max_sequence_length=512,
            guidance_scale=guidance_scale,
            prompt=prompt,
        ).images[0]

        # Split the side-by-side result and save both halves; cog outputs must be
        # file Paths, not in-memory PIL images
        garment_result = result.crop((0, 0, size[0], size[1]))
        try_result = result.crop((size[0], 0, size[0] * 2, size[1]))

        garment_path = Path("/tmp/garment.png")
        try_path = Path("/tmp/try.png")
        garment_result.save(garment_path)
        try_result.save(try_path)
        return [garment_path, try_path]
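For a quick smoke test outside cog's HTTP server, the Predictor can be driven directly; a minimal sketch, assuming a CUDA machine with the packages from cog.yaml installed and a valid Hugging Face token (the "hf_xxx" value is a placeholder). All arguments are passed explicitly because cog only resolves Input() defaults when it serves the model; plain URL strings work here since load_image handles them:

from cog import Secret
from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads both transformers once

garment_path, tryon_path = predictor.predict(
    hf_token=Secret("hf_xxx"),  # placeholder, not a real credential
    image="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg",
    mask="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png?raw=true",
    try_on=True,
    garment="https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg",
    num_steps=50,
    guidance_scale=30.0,
    seed=0,
    width=576,
    height=768,
)
print(garment_path, tryon_path)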