"""Command-line script that runs Stable Diffusion img2img over clear-weather images.

Reads ``<swim_dir>/val/labels.json``, picks the images labelled with clear
weather, pushes them through the img2img pipeline with a prompt describing
the requested condition, and writes the results under the output directory
using the original file names.
"""

import json
import os
from itertools import batched
from typing import Literal

import click
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image
from tqdm import tqdm

# Text prompt handed to the pipeline for each supported target condition.
PROMPTS = {
    "rain": "a photo of a city street in the rain",
    "snow": "a photo of a city street in the snow",
    "fog": "a photo of a city street in the fog",
    "night": "a photo of a city street at night",
}


def _select_clear_images(labels: list[dict], no_night: bool) -> list[str]:
    """Return the file names of clear-weather labels, optionally dropping night shots."""
    return [
        label["name"]
        for label in labels
        if label["weather"] == "clear"
        and (not no_night or label["timeofday"] != "night")
    ]


def _load_rgb(path: str) -> Image.Image:
    """Open *path*, return an in-memory RGB copy, and close the file handle."""
    with Image.open(path) as img:
        # convert() loads the pixel data and returns a new image, so the
        # result stays usable after the underlying file is closed.
        return img.convert("RGB")


@click.command()
@click.option(
    "--swim_dir",
    default="datasets/swim_data",
    help="Dataset root containing val/labels.json and val/images/.",
)
@click.option(
    "--output",
    required=True,
    help="Directory the generated images are written to (created if missing).",
)
@click.option(
    "--type",
    type=click.Choice(list(PROMPTS)),
    required=True,
    help="Target condition; selects the prompt.",
)
@click.option(
    "--batch_size",
    default=1,
    type=click.IntRange(min=1),
    help="Number of images passed to the pipeline per call.",
)
@click.option(
    "--no_night",
    is_flag=True,
    default=False,
    help="Skip source images whose time of day is labelled 'night'.",
)
def sdedit(
    swim_dir: str,
    output: str,
    type: Literal["rain", "snow", "fog", "night"],
    batch_size: int,
    no_night: bool,
) -> None:
    """Generate rain/snow/fog/night variants of the clear-weather validation images."""
    # NOTE: the parameter is named `type` because click derives it from the
    # `--type` option; it shadows the builtin inside this function only.
    device = "cuda"
    model_id_or_path = "runwayml/stable-diffusion-v1-5"
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_id_or_path, torch_dtype=torch.float16
    )
    pipe = pipe.to(device)

    os.makedirs(output, exist_ok=True)

    with open(os.path.join(swim_dir, "val", "labels.json"), encoding="utf-8") as f:
        labels = json.load(f)
    # labels format:
    # [
    #   {
    #     "name": "049306.jpg",
    #     "weather": "clear",
    #     "timeofday": "night",
    #     "source": "bdd100k"
    #   },
    #   ...
    # ]

    clear_images = _select_clear_images(labels, no_night)
    image_dir = os.path.join(swim_dir, "val", "images")
    prompt = PROMPTS[type]

    # Materialise the batches (file names only) so tqdm can show a total.
    for names in tqdm(list(batched(clear_images, batch_size)), desc="Generating"):
        images = [_load_rgb(os.path.join(image_dir, name)) for name in names]
        prompts = [prompt] * len(images)

        # Pass by keyword: the pipeline's first positional parameter is the
        # prompt, not the image. The pipeline already lives on `device`, so
        # no device argument is given here.
        result = pipe(prompt=prompts, image=images)

        # The generated PIL images are on the output object's `.images`.
        for name, processed_image in zip(names, result.images):
            processed_image.save(os.path.join(output, name))


if __name__ == "__main__":
    sdedit()