|
from typing import Literal |
|
|
|
import os |
|
import torch |
|
from itertools import batched |
|
import click |
|
from PIL import Image |
|
from tqdm import tqdm |
|
import json |
|
|
|
from diffusers import StableDiffusionImg2ImgPipeline |
|
|
|
|
|
# Text prompt for each target condition. sdedit() looks up PROMPTS[type] and
# sends the same prompt with every image in a batch.
PROMPTS: dict[str, str] = dict(
    rain="a photo of a city street in the rain",
    snow="a photo of a city street in the snow",
    fog="a photo of a city street in the fog",
    night="a photo of a city street at night",
)
|
|
|
|
|
@click.command()
@click.option(
    "--swim_dir",
    default="datasets/swim_data",
    show_default=True,
    help="Dataset root containing val/labels.json and val/images/.",
)
@click.option(
    "--output",
    required=True,
    help="Directory the edited images are written to (created if missing).",
)
@click.option(
    "--type",
    type=click.Choice(list(PROMPTS)),
    required=True,
    help="Target condition; selects the text prompt.",
)
@click.option(
    "--batch_size",
    default=1,
    type=click.IntRange(min=1),
    show_default=True,
    help="Number of images per pipeline call.",
)
@click.option(
    "--no_night",
    is_flag=True,
    default=False,
    help="Skip source images whose timeofday label is 'night'.",
)
def sdedit(
    swim_dir: str,
    output: str,
    type: Literal["rain", "snow", "fog", "night"],  # name fixed by --type; shadows builtin
    batch_size: int,
    no_night: bool,
):
    """Edit clear-weather SWIM validation images towards a target condition.

    Every image listed in <swim_dir>/val/labels.json whose "weather" is
    "clear" (and, with --no_night, whose "timeofday" is not "night") is run
    through Stable Diffusion v1.5 img2img with the prompt for --type and saved
    to --output under its original file name.

    \f
    Args:
        swim_dir: Dataset root; reads ``val/labels.json`` and ``val/images/``.
        output: Destination directory, created if it does not exist.
        type: Key into ``PROMPTS`` selecting the text prompt.
        batch_size: Number of images per pipeline call (>= 1).
        no_night: Exclude source images labelled ``timeofday == "night"``.
    """
    device = "cuda"
    model_id_or_path = "runwayml/stable-diffusion-v1-5"

    # Read and filter the labels first so a bad --swim_dir fails before the
    # (large) model download/load.
    val_dir = os.path.join(swim_dir, "val")
    with open(os.path.join(val_dir, "labels.json"), encoding="utf-8") as f:
        labels = json.load(f)

    clear_images = [
        label["name"]
        for label in labels
        if label["weather"] == "clear"
        and (not no_night or label["timeofday"] != "night")
    ]
    if not clear_images:
        click.echo("No matching clear-weather images found; nothing to do.")
        return

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_id_or_path, torch_dtype=torch.float16
    )
    pipe = pipe.to(device)

    os.makedirs(output, exist_ok=True)

    prompt = PROMPTS[type]
    image_dir = os.path.join(val_dir, "images")
    for names in tqdm(list(batched(clear_images, batch_size)), desc="Generating"):
        images = []
        for name in names:
            # Normalise to 3-channel RGB (source files may be grayscale, RGBA
            # or paletted); the `with` releases the file handle afterwards.
            with Image.open(os.path.join(image_dir, name)) as img:
                images.append(img.convert("RGB"))

        # Pass by keyword: the pipeline's first positional parameter is
        # `prompt`, not `image`.
        result = pipe(prompt=[prompt] * len(images), image=images)

        # The pipeline returns an output object; the PIL images are in `.images`.
        for name, edited in zip(names, result.images):
            edited.save(os.path.join(output, name))
|
|
|
|
|
# Script entry point: run the click command, which reads its options from the
# command line.
if __name__ == "__main__":
    sdedit()
|
|