import os
import sys
import subprocess

import gradio as gr
import numpy as np
import requests
import torch
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm

# Source of the CCSR code, and the directory it is cloned into when the app is
# not already running from inside a CCSR checkout.
REPO_URL = "https://github.com/camenduru/CCSR.git"
REPO_BRANCH = "dev"
REPO_DIR = "CCSR"
# Path of the stage-2 model config, relative to the CCSR source root.
CONFIG_RELPATH = os.path.join("configs", "model", "ccsr_stage2.yaml")

CKPT_URL = "https://huggingface.co/camenduru/CCSR/resolve/main/real-world_ccsr.ckpt"
CKPT_PATH = os.path.join("weights", "real-world_ccsr.ckpt")


def download_file(url, filename, block_size=1024 * 1024, timeout=60):
    """Stream ``url`` into ``filename`` while showing a tqdm progress bar.

    The payload is first written to ``filename + ".part"`` and renamed into
    place only after the transfer completes. An interrupted or failed download
    therefore never leaves a truncated file at ``filename``. Such a file would
    otherwise be accepted by the existence check in ``setup_environment`` on
    the next start.

    Args:
        url: HTTP(S) URL to fetch.
        filename: destination path.
        block_size: bytes read per chunk.
        timeout: seconds for ``requests`` connect/read timeouts.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection problems or timeouts.
    """
    tmp_path = filename + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this, an error body (e.g. a 404 HTML page) would be
            # saved to disk as if it were the checkpoint.
            response.raise_for_status()
            total_size = int(response.headers.get("content-length", 0))
            with open(tmp_path, "wb") as file, tqdm(
                desc=filename,
                total=total_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for data in response.iter_content(block_size):
                    progress_bar.update(file.write(data))
    except BaseException:
        # Drop the partial file so a retry starts from scratch.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, filename)


def setup_environment():
    """Make the CCSR sources importable and make sure the checkpoint exists.

    If the current directory already contains the CCSR config, it is used as
    the source root. Otherwise the repository is cloned into ``REPO_DIR``
    (once). The chosen root is prepended to ``sys.path``.

    Returns:
        str: absolute path of the CCSR source root.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails.
        requests.RequestException: if the checkpoint download fails.
    """
    if os.path.isfile(CONFIG_RELPATH):
        # Already running from inside a CCSR checkout.
        repo_dir = os.getcwd()
    else:
        repo_dir = os.path.abspath(REPO_DIR)
        if not os.path.isdir(repo_dir):
            print("Cloning CCSR repository...")
            # check=True: a failed clone must stop here instead of
            # resurfacing later as a confusing ImportError.
            subprocess.run(
                ["git", "clone", "-b", REPO_BRANCH, REPO_URL, repo_dir],
                check=True,
            )
    # Prepend (not append) so the repo's top-level packages ("ldm", "model",
    # "utils") take precedence over unrelated installed packages of the same name.
    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)

    os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
    if not os.path.exists(CKPT_PATH):
        print("Downloading model checkpoint...")
        download_file(CKPT_URL, CKPT_PATH)
    else:
        print("Model checkpoint already exists. Skipping download.")
    return repo_dir


REPO_ROOT = setup_environment()

# These modules live in the CCSR source tree put on sys.path above, so they
# can only be imported after setup_environment() has run.
from ldm.xformers_state import disable_xformers  # noqa: E402
from model.q_sampler import SpacedSampler  # noqa: E402
from model.ccsr_stage1 import ControlLDM  # noqa: E402
from utils.common import instantiate_from_config, load_state_dict  # noqa: E402

config = OmegaConf.load(os.path.join(REPO_ROOT, CONFIG_RELPATH))
model = instantiate_from_config(config)
ckpt = torch.load(CKPT_PATH, map_location="cpu")
load_state_dict(model, ckpt, strict=True)
model.freeze()
model.to("cuda")


@torch.no_grad()
def process(image, steps, t_max, t_min, color_fix_type):
    """Run CCSR sampling on one uploaded image.

    Args:
        image: filesystem path of the upload (``gr.Image(type="filepath")``),
            or None when the user submitted without an image.
        steps: number of sampling steps; the slider value is cast to int.
        t_max: forwarded unchanged to ``SpacedSampler.sample_ccsr``.
        t_min: forwarded unchanged to ``SpacedSampler.sample_ccsr``.
        color_fix_type: one of "adain", "wavelet", "none"; forwarded unchanged.

    Returns:
        PIL.Image.Image: the first (only) sample as an 8-bit RGB image.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")

    # Close the file handle promptly; convert() returns an independent copy.
    with Image.open(image) as img:
        rgb = img.convert("RGB")
    # The input is always resized to 256x256, regardless of its aspect ratio.
    rgb = rgb.resize((256, 256), Image.LANCZOS)
    pixels = np.array(rgb)

    sampler = SpacedSampler(model, var_type="fixed_small")
    control = torch.tensor(
        np.stack([pixels]) / 255.0, dtype=torch.float32, device=model.device
    ).clamp_(0, 1)
    # Batch of HWC images -> NCHW (equivalent to "n h w c -> n c h w").
    control = control.permute(0, 3, 1, 2).contiguous()

    model.control_scales = [1.0] * 13
    height, width = control.size(-2), control.size(-1)
    # Noise shape used by this app: 4 channels at 1/8 of the control resolution.
    shape = (1, 4, height // 8, width // 8)
    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)

    samples = sampler.sample_ccsr(
        steps=int(steps),
        t_max=t_max,
        t_min=t_min,
        shape=shape,
        cond_img=control,
        positive_prompt="",
        negative_prompt="",
        x_T=x_T,
        cfg_scale=1.0,
        color_fix_type=color_fix_type,
    )

    x_samples = samples.clamp(0, 1)
    # NCHW -> NHWC, then scale to 0..255 (truncating, as before).
    x_samples = (
        (x_samples.permute(0, 2, 3, 1) * 255)
        .cpu()
        .numpy()
        .clip(0, 255)
        .astype(np.uint8)
    )
    return Image.fromarray(x_samples[0])


interface = gr.Interface(
    fn=process,
    inputs=[
        gr.Image(type="filepath", label="Input image"),
        gr.Slider(minimum=1, maximum=100, step=1, value=45, label="Steps"),
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.6667, label="t_max"),
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.3333, label="t_min"),
        gr.Dropdown(
            choices=["adain", "wavelet", "none"],
            value="adain",
            label="Color fix type",
        ),
    ],
    outputs=gr.Image(type="pil", label="Output"),
    title="CCSR: Content Consistent Super-Resolution",
)

interface.launch()