"""Gradio demo for CCSR (Continuous Contrastive Super-Resolution).

On start-up this script:

1. copies the CCSR sources (``dev`` branch of camenduru/CCSR) into the
   current working directory,
2. downloads ``real-world_ccsr.ckpt`` into ``weights/``,
3. builds the model described by ``configs/model/ccsr_stage2.yaml`` and loads
   the checkpoint into it,
4. serves :func:`process` through a Gradio interface.
"""

import os
import shutil
import subprocess
import sys

import gradio as gr
import numpy as np
import requests
import torch
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm

REPO_URL = "https://github.com/camenduru/CCSR.git"
REPO_BRANCH = "dev"
CLONE_DIR = "CCSR_temp"
# The stage-2 model config. It only exists in the cwd once the repository
# contents have been moved there, so it doubles as the "already set up" marker.
MODEL_CONFIG_PATH = os.path.join("configs", "model", "ccsr_stage2.yaml")
CHECKPOINT_URL = "https://huggingface.co./camenduru/CCSR/resolve/main/real-world_ccsr.ckpt"
CHECKPOINT_PATH = os.path.join("weights", "real-world_ccsr.ckpt")
_THIS_SCRIPT = os.path.abspath(__file__)


def download_file(url, filename, chunk_size=1024 * 1024, timeout=60):
    """Stream ``url`` into ``filename`` while showing a tqdm progress bar.

    Data is written to ``filename + ".part"`` and renamed to ``filename`` only
    after the transfer completes. An interrupted download therefore never
    leaves a truncated file that a later run would mistake for a finished one.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path.
        chunk_size: Bytes read per iteration (default 1 MiB).
        timeout: Seconds passed to ``requests.get`` as the connect/read timeout.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    tmp_name = filename + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this check, a 403/404 error page would be saved as the file.
        response.raise_for_status()
        # Content-Length may be missing; tqdm then shows an open-ended bar.
        total_size = int(response.headers.get("content-length", 0)) or None
        with open(tmp_name, "wb") as file, tqdm(
            desc=filename,
            total=total_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress_bar:
            for data in response.iter_content(chunk_size):
                progress_bar.update(file.write(data))
    os.replace(tmp_name, filename)


def _clone_repository():
    """Clone the CCSR repository into CLONE_DIR, raising if git fails."""
    print("Cloning CCSR repository...")
    # check=True: otherwise a failed clone only surfaces later as a confusing
    # FileNotFoundError from os.listdir(CLONE_DIR).
    subprocess.run(
        ["git", "clone", "-b", REPO_BRANCH, REPO_URL, CLONE_DIR], check=True
    )


def _move_repository_into_cwd():
    """Move CLONE_DIR's entries into the cwd (replacing same-named ones), then delete CLONE_DIR."""
    cwd = os.getcwd()
    for item in os.listdir(CLONE_DIR):
        src = os.path.join(CLONE_DIR, item)
        dst = os.path.join(cwd, item)
        # Never clobber the host project's git metadata or this very script.
        if item == ".git" or os.path.abspath(dst) == _THIS_SCRIPT:
            continue
        # Remove by the type of the *destination*, so dir<->file clashes work.
        if os.path.isdir(dst) and not os.path.islink(dst):
            shutil.rmtree(dst)
        elif os.path.lexists(dst):
            os.remove(dst)
        shutil.move(src, dst)
    shutil.rmtree(CLONE_DIR)


def setup_environment():
    """Make the CCSR sources and checkpoint available in the working directory.

    The repository is cloned only when its contents are not already present.
    A leftover CLONE_DIR from an interrupted run is reused. The cwd is then
    added to ``sys.path`` so the repository's packages are importable, and the
    checkpoint is downloaded if it is missing.
    """
    if os.path.exists(CLONE_DIR):
        print("CCSR repository already cloned. Skipping clone step.")
        _move_repository_into_cwd()
    elif os.path.exists(MODEL_CONFIG_PATH):
        print("CCSR repository already set up. Skipping clone step.")
    else:
        _clone_repository()
        _move_repository_into_cwd()

    cwd = os.getcwd()
    if cwd not in sys.path:
        sys.path.append(cwd)

    os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
    if not os.path.exists(CHECKPOINT_PATH):
        print("Downloading model checkpoint...")
        download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
    else:
        print("Model checkpoint already exists. Skipping download.")


setup_environment()

# These modules live in the CCSR repository copied into the cwd above, so they
# can only be imported after setup_environment() has run.
from ldm.xformers_state import disable_xformers  # noqa: E402
from model.q_sampler import SpacedSampler  # noqa: E402
from model.ccsr_stage1 import ControlLDM  # noqa: E402,F401  (not referenced below; kept from the original)
from utils.common import instantiate_from_config, load_state_dict  # noqa: E402

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE == "cpu":
    # NOTE(review): assumes disable_xformers() makes the model use non-xformers
    # attention, which is needed off-GPU — confirm against ldm.xformers_state.
    disable_xformers()

config = OmegaConf.load(MODEL_CONFIG_PATH)
model = instantiate_from_config(config)
# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which may
# reject a full training checkpoint — confirm against the installed torch.
ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
load_state_dict(model, ckpt, strict=True)
# The weights now live in the model; drop the state dict so it does not keep a
# second full copy of the checkpoint in RAM for the lifetime of the app.
del ckpt
model.freeze()
model.to(DEVICE)


@torch.no_grad()
def process(image, steps, t_max, t_min, color_fix_type):
    """Run CCSR sampling on an uploaded image and return the result.

    The input is converted to RGB and resized to 256x256 before sampling, so
    the returned image is 256x256 as well.

    Args:
        image: Filesystem path of the uploaded image (Gradio ``type="filepath"``),
            or ``None`` when nothing was uploaded.
        steps: Number of sampling steps (cast to ``int``).
        t_max: Forwarded unchanged to ``SpacedSampler.sample_ccsr``.
        t_min: Forwarded unchanged to ``SpacedSampler.sample_ccsr``.
        color_fix_type: One of ``"adain"``, ``"wavelet"``, ``"none"``; forwarded
            to the sampler.

    Returns:
        PIL.Image.Image: The restored RGB image.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    with Image.open(image) as img:
        lr_image = img.convert("RGB").resize((256, 256), Image.LANCZOS)
    lr_array = np.array(lr_image)  # (H, W, 3) uint8

    sampler = SpacedSampler(model, var_type="fixed_small")

    # (1, H, W, C) in [0, 255] -> (1, C, H, W) float32 in [0, 1] on the model's device.
    control = torch.tensor(
        np.stack([lr_array]) / 255.0, dtype=torch.float32, device=model.device
    ).clamp_(0, 1)
    control = control.permute(0, 3, 1, 2).contiguous()

    # Kept from the original script: 13 control scales, all 1.0.
    model.control_scales = [1.0] * 13

    height, width = control.size(-2), control.size(-1)
    # Latent is 4 channels at 1/8 spatial resolution, matching the original script.
    shape = (1, 4, height // 8, width // 8)
    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)

    samples = sampler.sample_ccsr(
        steps=int(steps),  # slider values may arrive as floats
        t_max=t_max,
        t_min=t_min,
        shape=shape,
        cond_img=control,
        positive_prompt="",
        negative_prompt="",
        x_T=x_T,
        cfg_scale=1.0,
        color_fix_type=color_fix_type,
    )

    # (B, C, H, W) in [0, 1] -> (B, H, W, C) uint8 for PIL.
    x_samples = samples.clamp(0, 1)
    x_samples = (
        (x_samples.permute(0, 2, 3, 1) * 255)
        .cpu()
        .numpy()
        .clip(0, 255)
        .astype(np.uint8)
    )
    return Image.fromarray(x_samples[0])


interface = gr.Interface(
    fn=process,
    inputs=[
        gr.Image(type="filepath", label="Input image"),
        gr.Slider(minimum=1, maximum=100, step=1, value=45, label="Steps"),
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.6667, label="t_max"),
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.3333, label="t_min"),
        gr.Dropdown(choices=["adain", "wavelet", "none"], value="adain", label="Color fix"),
    ],
    outputs=gr.Image(type="pil"),
    title="CCSR: Continuous Contrastive Super-Resolution",
)

if __name__ == "__main__":
    interface.launch()