File size: 3,408 Bytes
6381c79
f77813c
6381c79
 
 
 
 
f77813c
182f0d5
 
2c22ca3
182f0d5
6381c79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d527cc3
f77813c
d527cc3
f77813c
d527cc3
f77813c
 
6381c79
 
 
 
 
 
 
 
 
 
 
 
d527cc3
33da899
 
 
 
6381c79
 
182f0d5
6381c79
182f0d5
 
 
 
2c22ca3
182f0d5
 
 
 
 
 
 
 
 
6381c79
182f0d5
6381c79
182f0d5
 
 
 
 
 
 
 
6381c79
 
182f0d5
 
 
 
6381c79
 
182f0d5
6381c79
 
182f0d5
6381c79
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import sys
import torch
import gradio as gr
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
import subprocess
from tqdm import tqdm
import requests
import spaces

def download_file(url, filename):
    """Stream ``url`` to ``filename`` while showing a tqdm progress bar.

    The body is first written to ``filename + ".part"`` and only renamed to
    ``filename`` once the transfer completed, so an interrupted or failed
    download never leaves a truncated file at the final path (callers that
    skip the download when ``filename`` exists would otherwise keep using it).

    Args:
        url: Address to fetch with an HTTP GET.
        filename: Destination path; also used as the progress-bar label.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
        OSError: If the destination cannot be written.
    """
    partial_path = f"{filename}.part"
    block_size = 1024
    try:
        # (connect, read) timeouts so a stalled server cannot hang start-up forever.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            # Fail loudly instead of saving an HTML error page as the payload.
            response.raise_for_status()
            # A missing Content-Length yields total=0; tqdm then just counts bytes.
            total_size = int(response.headers.get('content-length', 0))
            with open(partial_path, 'wb') as file, tqdm(
                desc=filename,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for data in response.iter_content(block_size):
                    size = file.write(data)
                    progress_bar.update(size)
        # Publish the finished file in one step.
        os.replace(partial_path, filename)
    finally:
        # After a successful rename the temp file is gone; on any failure,
        # remove the leftover partial download.
        if os.path.exists(partial_path):
            os.remove(partial_path)

def setup_environment():
    """Make the CCSR sources and model checkpoint available locally.

    Steps, in order:
      1. Clone the ``dev`` branch of the CCSR repository into ``./CCSR``
         unless that directory already exists.
      2. Change the working directory to ``CCSR`` and add it to ``sys.path``.
      3. Download ``weights/real-world_ccsr.ckpt`` unless it already exists.

    Side effects: changes the process working directory and ``sys.path``.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with a
            non-zero status.
    """
    if not os.path.exists("CCSR"):
        print("Cloning CCSR repository...")
        # check=True: stop here with a clear error rather than failing later
        # in os.chdir() on a directory that was never created.
        subprocess.run(
            ["git", "clone", "-b", "dev", "https://github.com/camenduru/CCSR.git"],
            check=True,
        )

    os.chdir("CCSR")
    repo_root = os.getcwd()
    # Avoid stacking duplicate entries if this is ever called more than once.
    if repo_root not in sys.path:
        sys.path.append(repo_root)

    os.makedirs("weights", exist_ok=True)
    if not os.path.exists("weights/real-world_ccsr.ckpt"):
        print("Downloading model checkpoint...")
        # Host written without the stray trailing dot ("huggingface.co.").
        download_file(
            "https://huggingface.co/camenduru/CCSR/resolve/main/real-world_ccsr.ckpt",
            "weights/real-world_ccsr.ckpt"
        )
    else:
        print("Model checkpoint already exists. Skipping download.")

# Must run before the imports below: setup_environment() clones the CCSR
# repository if needed, chdirs into it and appends it to sys.path.
setup_environment()

# Importing from the CCSR folder (resolvable only after the sys.path change above).
from ldm.xformers_state import disable_xformers
from model.q_sampler import SpacedSampler
from model.ccsr_stage1 import ControlLDM
from utils.common import instantiate_from_config, load_state_dict

# Both paths are relative to the CCSR directory, which is the current working
# directory after setup_environment().
config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
model = instantiate_from_config(config)
# The checkpoint is loaded onto the CPU first and the model is moved to the GPU
# afterwards.
# NOTE(review): torch.load unpickles the file, and this checkpoint comes from a
# network download — consider weights_only=True if the checkpoint layout allows it.
ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
load_state_dict(model, ckpt, strict=True)
# presumably puts the model in inference mode with gradients disabled — confirm
# against the CCSR model class.
model.freeze()
model.to("cuda")

@spaces.GPU
@torch.no_grad()
def process(image, steps, t_max, t_min, color_fix_type):
    """Run the CCSR sampler on one image file and return the result.

    The input is converted to RGB and resized to 256x256 before sampling.

    Args:
        image: Path of the image file to process (the UI passes a filepath).
        steps: Forwarded to ``sampler.sample_ccsr`` as ``steps``.
        t_max: Forwarded to ``sampler.sample_ccsr`` as ``t_max``.
        t_min: Forwarded to ``sampler.sample_ccsr`` as ``t_min``.
        color_fix_type: Forwarded to ``sampler.sample_ccsr`` as
            ``color_fix_type``; the UI offers "adain", "wavelet" and "none".

    Returns:
        A ``PIL.Image.Image`` built from the first sample of the batch.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Submitting the form without an upload passes None; show a readable
        # message instead of an opaque failure inside Image.open.
        raise gr.Error("Please upload an image first.")

    # Close the file handle as soon as the pixel data has been copied out.
    with Image.open(image) as source:
        image = source.convert("RGB").resize((256, 256), Image.LANCZOS)
    image = np.array(image)

    sampler = SpacedSampler(model, var_type="fixed_small")
    # Batch of one, scaled from 0..255 to 0..1.
    control = torch.tensor(np.stack([image]) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
    # "n h w c -> n c h w". Uses torch's permute because einops is not
    # imported in this file (the previous einops.rearrange call raised NameError).
    control = control.permute(0, 3, 1, 2).contiguous()

    # NOTE(review): 13 scales of 1.0 — presumably one per control layer of the
    # model; confirm against the CCSR model definition.
    model.control_scales = [1.0] * 13

    height, width = control.size(-2), control.size(-1)
    # Noise tensor has 4 channels at 1/8 of the input resolution.
    shape = (1, 4, height // 8, width // 8)
    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)

    samples = sampler.sample_ccsr(
        steps=steps, t_max=t_max, t_min=t_min, shape=shape, cond_img=control,
        positive_prompt="", negative_prompt="", x_T=x_T,
        cfg_scale=1.0, color_fix_type=color_fix_type
    )

    x_samples = samples.clamp(0, 1)
    # "b c h w -> b h w c", then back to 8-bit pixel values on the CPU.
    x_samples = (x_samples.permute(0, 2, 3, 1) * 255).cpu().numpy().clip(0, 255).astype(np.uint8)

    return Image.fromarray(x_samples[0])

# Gradio UI. The input components are passed positionally to
# process(image, steps, t_max, t_min, color_fix_type), so their order matters.
interface = gr.Interface(
    fn=process,
    inputs=[
        # image: handed to process() as a path on disk.
        gr.Image(type="filepath"),
        # steps
        gr.Slider(minimum=1, maximum=100, step=1, value=45),
        # t_max
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.6667),
        # t_min
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.3333),
        # color_fix_type
        gr.Dropdown(choices=["adain", "wavelet", "none"], value="adain"),
    ],
    outputs=gr.Image(type="pil"),
    # CCSR stands for "Content Consistent Super-Resolution".
    title="CCSR: Content Consistent Super-Resolution",
)

interface.launch()