asutermo committed
Commit 8421d2f · 1 Parent(s): 43dcb58

build cog for tryon/tryoff

Files changed (2)
  1. cog.yaml +38 -0
  2. predict.py +97 -0
cog.yaml ADDED
@@ -0,0 +1,38 @@
+ # Configuration for Cog ⚙️
+ # Reference: https://cog.run/yaml
+
+ build:
+   # set to true if your model requires a GPU
+   gpu: true
+
+   # a list of ubuntu apt packages to install
+   system_packages:
+     - "libgl1-mesa-glx"
+     - "libglib2.0-0"
+
+   # python version in the form '3.11' or '3.11.4'
+   python_version: "3.11"
+
+   # a list of packages in the format <package-name>==<version>
+   python_packages:
+     - torch==2.4.0
+     - transformers==4.43.3
+     - datasets==2.20.0
+     - accelerate==0.26.1
+     - jupyter==1.0.0
+     - numpy==1.26.4
+     - pillow==10.2.0
+     - peft==0.13.2
+     - diffusers>=0.32.0
+     - timm==0.9.16
+     - torchvision==0.19.0
+     - tqdm==4.66.5
+     - sentencepiece
+     - protobuf
+
+   # commands run after the environment is set up
+   run:
+     - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
+
+ # predict.py defines how predictions are run on your model
+ predict: "predict.py:Predictor"
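
With this config in place, the predictor can be built and exercised through the Cog CLI. A minimal sketch, not part of this commit: the token value is a placeholder, and the image inputs fall back to the example defaults declared in predict.py below.

  # Build the container, then run a try-on prediction with the example defaults
  cog build
  cog predict -i hf_token=hf_xxxxxxxx -i try_on=true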
predict.py ADDED
@@ -0,0 +1,97 @@
+ from cog import BasePredictor, Input, Path, Secret
+ from diffusers import FluxFillPipeline, FluxTransformer2DModel
+ from diffusers.utils import load_image
+ import torch
+ from torchvision import transforms
+
+
+ class Predictor(BasePredictor):
+     def setup(self) -> None:
+         """Load the try-on and try-off transformers into memory to make running multiple predictions efficient"""
+         self.dtype = torch.bfloat16
+         self.try_on_transformer = FluxTransformer2DModel.from_pretrained(
+             "xiaozaa/catvton-flux-beta", torch_dtype=self.dtype
+         )
+         self.try_off_transformer = FluxTransformer2DModel.from_pretrained(
+             "xiaozaa/cat-tryoff-flux", torch_dtype=self.dtype
+         )
+
+     def predict(
+         self,
+         hf_token: Secret = Input(description="Hugging Face API token. Create a write token at https://huggingface.co/settings/tokens. You also need to accept the Flux Dev terms."),
+         image: Path = Input(description="Image file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg"),
+         mask: Path = Input(description="Mask file path", default="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png?raw=true"),
+         try_on: bool = Input(default=True, description="Try on (True) or try off (False)"),
+         garment: Path = Input(description="Garment file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg"),
+         num_steps: int = Input(default=50, description="Number of inference steps to run"),
+         guidance_scale: float = Input(default=30.0, description="Guidance scale for the model"),
+         seed: int = Input(default=0, description="Seed for the random generator"),
+         width: int = Input(default=576, description="Width of the output image"),
+         height: int = Input(default=768, description="Height of the output image"),
+     ) -> list[Path]:
+         size = (width, height)
+         transformer = self.try_on_transformer if try_on else self.try_off_transformer
+
+         self.pipe = FluxFillPipeline.from_pretrained(
+             "black-forest-labs/FLUX.1-dev",
+             transformer=transformer,
+             torch_dtype=self.dtype,
+             token=hf_token.get_secret_value(),
+         ).to("cuda")
+         self.pipe.transformer.to(self.dtype)
+
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize([0.5], [0.5]),  # scale RGB images to [-1, 1]
+         ])
+         mask_transform = transforms.Compose([
+             transforms.ToTensor(),
+         ])
+
+         i = load_image(str(image)).convert("RGB").resize(size)
+         m = load_image(str(mask)).convert("RGB").resize(size)
+         g = load_image(str(garment)).convert("RGB").resize(size)
+
+         # Transform the inputs with the preprocessing defined above
+         image_tensor = transform(i)
+         mask_tensor = mask_transform(m)[:1]  # take only the first channel
+         garment_tensor = transform(g)
+
+         # Concatenate the garment and the person side by side along the width
+         inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2)
+         garment_mask = torch.zeros_like(mask_tensor)
+
+         if try_on:
+             # Keep the garment half; inpaint only the masked region of the person half
+             extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)
+         else:
+             # Inpaint the whole garment half so the model reconstructs the garment
+             extended_mask = torch.cat([1 - garment_mask, mask_tensor], dim=2)
+
+         prompt = (
+             "The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; "
+             "[IMAGE1] Detailed product shot of a clothing; "
+             "[IMAGE2] The same cloth is worn by a model in a lifestyle setting."
+         )
+
+         generator = torch.Generator(device="cuda").manual_seed(seed)
+         result = self.pipe(
+             height=size[1],
+             width=size[0] * 2,
+             image=inpaint_image,
+             mask_image=extended_mask,
+             num_inference_steps=num_steps,
+             generator=generator,
+             max_sequence_length=512,
+             guidance_scale=guidance_scale,
+             prompt=prompt,
+         ).images[0]
+
+         # Split the side-by-side result and save the two halves separately
+         garment_result = result.crop((0, 0, size[0], size[1]))
+         try_result = result.crop((size[0], 0, size[0] * 2, size[1]))
+         garment_path = Path("/tmp/garment.png")
+         try_path = Path("/tmp/try_result.png")
+         garment_result.save(garment_path)
+         try_result.save(try_path)
+         return [garment_path, try_path]
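
For quick iteration outside the Cog runtime, the Predictor can also be driven straight from Python. A rough sketch, assuming a CUDA GPU, the packages pinned in cog.yaml, and a Hugging Face token exported as HF_TOKEN (the environment variable name is this example's choice, not the commit's):

import os

from cog import Secret
from predict import Predictor

predictor = Predictor()
predictor.setup()

# All values are restated explicitly: calling predict() directly bypasses
# cog's handling of the Input(...) defaults. URL strings are accepted here
# because predict() resolves its inputs through load_image().
outputs = predictor.predict(
    hf_token=Secret(os.environ["HF_TOKEN"]),
    image="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg",
    mask="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png?raw=true",
    try_on=True,
    garment="https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg",
    num_steps=50,
    guidance_scale=30.0,
    seed=0,
    width=576,
    height=768,
)
print(outputs)  # [Path('/tmp/garment.png'), Path('/tmp/try_result.png')]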