# ccsr-upscaler / app.py
# Author: owiedotch — last commit 182f0d5 ("Update app.py", verified)
# Size: 3.52 kB; Hugging Face file-viewer scan status: "No virus"
import os
import sys
import torch
import gradio as gr
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
import subprocess
from tqdm import tqdm
import requests
# Assuming spaces is a valid module
import spaces
def download_file(url, filename, block_size=1024 * 1024, timeout=60):
    """Stream ``url`` into ``filename``, showing a tqdm progress bar.

    The body is first written to ``filename + ".part"`` and moved into place
    with ``os.replace`` only after the transfer finishes. An interrupted or
    failed download therefore never leaves a truncated file at ``filename``.
    This matters because callers use the file's existence to decide whether
    to skip the download.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path.
        block_size: Number of bytes per chunk read from the response.
        timeout: Seconds to wait when connecting and between received bytes.

    Raises:
        requests.HTTPError: if the server answers with an error status
            (instead of silently saving the error page as ``filename``).
        requests.RequestException: on connection errors or timeouts.
    """
    part_path = f"{filename}.part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # A missing or zero Content-Length means "unknown size" for tqdm.
        total_size = int(response.headers.get("content-length", 0)) or None
        try:
            with open(part_path, "wb") as file, tqdm(
                desc=filename,
                total=total_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for data in response.iter_content(block_size):
                    progress_bar.update(file.write(data))
        except BaseException:
            # Best-effort cleanup of the partial file; the original error is re-raised.
            try:
                os.remove(part_path)
            except OSError:
                pass
            raise
    # Atomic on the same filesystem: readers see either no file or the complete one.
    os.replace(part_path, filename)
def setup_environment():
    """Prepare the CCSR checkout and model weights, leaving the process inside it.

    Steps:
    - Clone the ``dev`` branch of https://github.com/camenduru/CCSR.git into
      ``./CCSR``, unless that directory already exists.
    - Change the working directory to ``CCSR``. Append that directory to
      ``sys.path`` (once) so modules inside the checkout import by top-level name.
    - Download ``weights/real-world_ccsr.ckpt`` (relative to the checkout),
      unless it is already present.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with an error.
        Errors raised by ``download_file`` propagate unchanged.
    """
    if not os.path.exists("CCSR"):
        print("Cloning CCSR repository...")
        # check=True: fail here with the git error rather than later in os.chdir.
        subprocess.run(
            ["git", "clone", "-b", "dev", "https://github.com/camenduru/CCSR.git"],
            check=True,
        )
    os.chdir("CCSR")
    repo_dir = os.getcwd()
    if repo_dir not in sys.path:
        sys.path.append(repo_dir)

    os.makedirs("weights", exist_ok=True)
    ckpt_path = "weights/real-world_ccsr.ckpt"
    if not os.path.exists(ckpt_path):
        print("Downloading model checkpoint...")
        download_file(
            # Host was previously "huggingface.co." (stray trailing dot).
            "https://huggingface.co/camenduru/CCSR/resolve/main/real-world_ccsr.ckpt",
            ckpt_path,
        )
    else:
        print("Model checkpoint already exists. Skipping download.")
# Clone the CCSR repo, chdir into it and fetch the checkpoint. This must run
# before the imports and relative-path loads below.
setup_environment()
# Importing from the CCSR folder.
# NOTE(review): the `CCSR` package is presumably resolved via the directory that
# contains the clone (the script's own directory on sys.path) — confirm.
# NOTE(review): disable_xformers and ControlLDM are not referenced in the visible
# code; kept in case something else relies on them.
from CCSR.ldm.xformers_state import disable_xformers
from CCSR.model.q_sampler import SpacedSampler
from CCSR.model.ccsr_stage1 import ControlLDM
from CCSR.utils.common import instantiate_from_config, load_state_dict
# These paths are relative to the CCSR checkout, because setup_environment() chdir'd into it.
config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
model = instantiate_from_config(config)
# SECURITY: unless weights_only=True is in effect (torch version dependent),
# torch.load unpickles the file and can execute code from it. The checkpoint is
# downloaded from a third-party repo, so only point this at sources you trust.
ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
load_state_dict(model, ckpt, strict=True)
model.freeze()
# Loaded once at import time; every process() call reuses this model.
model.to("cuda")
@spaces.GPU  # Hugging Face `spaces` decorator: requests a GPU for each call.
@torch.no_grad()
def process(image, steps, t_max, t_min, color_fix_type):
    """Run CCSR sampling on one uploaded image and return the result.

    Args:
        image: Filesystem path of the uploaded image (the Gradio input uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.
        steps: Number of sampling steps, forwarded to ``sample_ccsr``.
        t_max: Upper timestep bound, forwarded to ``sample_ccsr``
            (presumably a fraction of the noise schedule in [0, 1] — confirm).
        t_min: Lower timestep bound, forwarded to ``sample_ccsr``.
        color_fix_type: ``"adain"``, ``"wavelet"`` or ``"none"``, forwarded
            to ``sample_ccsr``.

    Returns:
        PIL.Image.Image: the first sample as an 8-bit RGB image.

    Raises:
        gradio.Error: if no image was provided (shown to the user in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Close the file handle promptly; convert() returns a new, fully loaded image.
    with Image.open(image) as src:
        rgb = src.convert("RGB")
    # NOTE: the input is resized to a fixed 256x256, ignoring its aspect ratio.
    rgb = rgb.resize((256, 256), Image.LANCZOS)
    pixels = np.array(rgb)  # (H, W, 3) uint8

    sampler = SpacedSampler(model, var_type="fixed_small")
    control = torch.tensor(
        np.stack([pixels]) / 255.0, dtype=torch.float32, device=model.device
    ).clamp_(0, 1)
    # (N, H, W, C) -> (N, C, H, W). This used einops.rearrange, which was never imported.
    control = control.permute(0, 3, 1, 2).contiguous()
    # NOTE(review): 13 presumably matches the number of control outputs — confirm.
    model.control_scales = [1.0] * 13
    height, width = control.size(-2), control.size(-1)
    # Sampling tensor: 4 channels at 1/8 of the pixel resolution.
    shape = (1, 4, height // 8, width // 8)
    # Unseeded starting noise: repeated runs on the same input can differ.
    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
    samples = sampler.sample_ccsr(
        steps=int(steps), t_max=t_max, t_min=t_min, shape=shape, cond_img=control,
        positive_prompt="", negative_prompt="", x_T=x_T,
        cfg_scale=1.0, color_fix_type=color_fix_type,
    )
    x_samples = samples.clamp(0, 1)
    # (B, C, H, W) in [0, 1] -> (B, H, W, C) uint8 for PIL.
    x_samples = (
        (x_samples.permute(0, 2, 3, 1) * 255)
        .cpu()
        .numpy()
        .clip(0, 255)
        .astype(np.uint8)
    )
    return Image.fromarray(x_samples[0])
# Gradio UI. The input components map positionally onto
# process(image, steps, t_max, t_min, color_fix_type).
interface = gr.Interface(
    fn=process,
    inputs=[
        # process() opens the uploaded file from this path itself.
        gr.Image(type="filepath", label="Input image"),
        gr.Slider(minimum=1, maximum=100, step=1, value=45, label="Steps"),
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.6667, label="t_max"),
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.3333, label="t_min"),
        gr.Dropdown(
            choices=["adain", "wavelet", "none"], value="adain", label="Color fix"
        ),
    ],
    outputs=gr.Image(type="pil"),
    # CCSR stands for "Content Consistent Super-Resolution"; the title
    # previously expanded it incorrectly as "Continuous Contrastive".
    title="CCSR: Content Consistent Super-Resolution",
)
interface.launch()