File size: 3,517 Bytes
6381c79
f77813c
6381c79
 
 
 
 
f77813c
182f0d5
 
 
 
 
6381c79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f77813c
 
 
 
 
 
 
6381c79
 
 
 
 
 
 
 
 
 
 
 
182f0d5
 
 
 
 
6381c79
 
182f0d5
6381c79
182f0d5
 
 
 
 
 
 
 
 
 
 
 
 
 
6381c79
182f0d5
6381c79
182f0d5
 
 
 
 
 
 
 
6381c79
 
182f0d5
 
 
 
6381c79
 
182f0d5
6381c79
 
182f0d5
6381c79
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import sys
import torch
import gradio as gr
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
import subprocess
from tqdm import tqdm
import requests

# Assuming spaces is a valid module
import spaces

def download_file(url, filename):
    """Stream *url* to *filename* on disk, showing a tqdm progress bar.

    The payload is first written to ``<filename>.part`` and only renamed to
    *filename* after the transfer finishes. This means an interrupted download
    never leaves a truncated file behind that a later ``os.path.exists`` check
    would mistake for a complete one.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeouts.
    """
    tmp_path = f"{filename}.part"
    # 1 MiB chunks: the checkpoint is large, and 1 KiB chunks mean millions of
    # tiny writes and progress-bar updates.
    block_size = 1024 * 1024
    # (connect, read) timeout so a stalled server cannot hang startup forever;
    # the read timeout applies per socket read, not to the whole transfer.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        # Fail loudly instead of saving an HTML error page as the checkpoint.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        try:
            with open(tmp_path, 'wb') as file, tqdm(
                desc=filename,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for data in response.iter_content(block_size):
                    size = file.write(data)
                    progress_bar.update(size)
        except BaseException:
            # Remove the partial file so the next run retries the download.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    # Atomic on the same filesystem: *filename* appears only once it is complete.
    os.replace(tmp_path, filename)

def setup_environment():
    """Clone the CCSR code, enter its directory, and fetch the model checkpoint.

    Side effects:
        * Clones ``camenduru/CCSR`` (branch ``dev``) into ``./CCSR`` if that
          path does not exist yet.
        * Changes the process working directory to ``CCSR``. All later
          relative paths (``configs/...``, ``weights/...``) depend on this.
        * Appends that directory to ``sys.path`` (once).
        * Downloads ``weights/real-world_ccsr.ckpt`` if it is missing.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` fails.
    """
    if not os.path.exists("CCSR"):
        print("Cloning CCSR repository...")
        # check=True: a failed clone should stop here with git's error rather
        # than surfacing later as a confusing chdir/import failure.
        subprocess.run(
            ["git", "clone", "-b", "dev", "https://github.com/camenduru/CCSR.git"],
            check=True,
        )

    os.chdir("CCSR")
    repo_dir = os.getcwd()
    # Avoid piling up duplicate entries if this is ever called more than once.
    if repo_dir not in sys.path:
        sys.path.append(repo_dir)

    os.makedirs("weights", exist_ok=True)
    if not os.path.exists("weights/real-world_ccsr.ckpt"):
        print("Downloading model checkpoint...")
        # NOTE(review): the host has a trailing dot ("huggingface.co."). That is a
        # valid fully-qualified domain name but may be a typo; confirm.
        download_file(
            "https://huggingface.co./camenduru/CCSR/resolve/main/real-world_ccsr.ckpt",
            "weights/real-world_ccsr.ckpt"
        )
    else:
        print("Model checkpoint already exists. Skipping download.")

# Clone the repo / download the weights and chdir into CCSR/ before anything
# below runs: the relative paths "configs/..." and "weights/..." are resolved
# from that directory, so this call must stay first.
setup_environment()

# Import model code from the cloned CCSR folder.
# NOTE(review): these resolve ``CCSR`` as a package, presumably via the script's
# own directory on sys.path (the parent of CCSR/) — confirm.
# NOTE(review): disable_xformers and ControlLDM are not used in this file.
from CCSR.ldm.xformers_state import disable_xformers
from CCSR.model.q_sampler import SpacedSampler
from CCSR.model.ccsr_stage1 import ControlLDM
from CCSR.utils.common import instantiate_from_config, load_state_dict

# Build the stage-2 CCSR model from its OmegaConf config, then load the
# downloaded checkpoint into it on the CPU before moving it to the GPU.
config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
model = instantiate_from_config(config)
# NOTE(review): torch.load unpickles the file; only safe because the checkpoint
# comes from a trusted source.
ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
# strict=True presumably requires every checkpoint key to match the model
# (CCSR's load_state_dict helper) — confirm.
load_state_dict(model, ckpt, strict=True)
# Presumably puts the model in inference mode (eval, no grads) — confirm in CCSR.
model.freeze()
model.to("cuda")

@spaces.GPU  # Decorate the inference function with @spaces.GPU
@torch.no_grad()
def process(image, steps, t_max, t_min, color_fix_type):
    """Run CCSR super-resolution on one uploaded image.

    Args:
        image: Path to the input image (from ``gr.Image(type="filepath")``),
            or ``None`` if nothing was uploaded.
        steps: Number of sampling steps (coerced to ``int``).
        t_max: Upper bound of the sampler's timestep range (slider in [0, 1]).
        t_min: Lower bound of the sampler's timestep range; must not exceed
            *t_max*.
        color_fix_type: One of ``"adain"``, ``"wavelet"`` or ``"none"``,
            forwarded to the sampler.

    Returns:
        PIL.Image.Image: The RGB output image.

    Raises:
        gr.Error: If no image was supplied or ``t_min > t_max``.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    if t_min > t_max:
        raise gr.Error("t_min must not be greater than t_max.")
    # Gradio sliders may deliver floats; the sampler expects a step count.
    steps = int(steps)

    # Close the file handle once the pixels have been decoded.
    with Image.open(image) as src:
        lq = src.convert("RGB").resize((256, 256), Image.LANCZOS)
    lq = np.array(lq)

    sampler = SpacedSampler(model, var_type="fixed_small")
    # (1, H, W, C) uint8 -> float in [0, 1] -> (1, C, H, W).
    # Tensor.permute replaces the einops.rearrange calls; einops was never imported.
    control = torch.tensor(np.stack([lq]) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
    control = control.permute(0, 3, 1, 2).contiguous()

    # NOTE(review): presumably one scale per control-injection point (13) — confirm.
    model.control_scales = [1.0] * 13

    # Latent is 4 channels at 1/8 of the pixel resolution.
    height, width = control.size(-2), control.size(-1)
    shape = (1, 4, height // 8, width // 8)
    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)

    samples = sampler.sample_ccsr(
        steps=steps, t_max=t_max, t_min=t_min, shape=shape, cond_img=control,
        positive_prompt="", negative_prompt="", x_T=x_T,
        cfg_scale=1.0, color_fix_type=color_fix_type
    )

    # (B, C, H, W) in [0, 1] -> (B, H, W, C) uint8.
    x_samples = samples.clamp(0, 1)
    x_samples = (x_samples.permute(0, 2, 3, 1) * 255).cpu().numpy().clip(0, 255).astype(np.uint8)

    return Image.fromarray(x_samples[0])

# Gradio UI. The input components map positionally onto
# process(image, steps, t_max, t_min, color_fix_type).
interface = gr.Interface(
    fn=process,
    inputs=[
        gr.Image(type="filepath", label="Input image"),
        gr.Slider(minimum=1, maximum=100, step=1, value=45, label="Steps"),
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.6667, label="t_max"),
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.3333, label="t_min"),
        gr.Dropdown(choices=["adain", "wavelet", "none"], value="adain", label="Color fix type"),
    ],
    outputs=gr.Image(type="pil", label="Output image"),
    # CCSR stands for "Content Consistent Super-Resolution".
    title="CCSR: Content Consistent Super-Resolution",
)

interface.launch()