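# NOTE: the lines below are leftover text from the web file viewer this
# script was copied from; they are not part of the program.
# qninhdt's picture
# cc
# cf791ca
# raw
# history blame
# 1.92 kB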
from typing import Literal
import os
import torch
from itertools import batched
import click
from PIL import Image
from tqdm import tqdm
import json
from diffusers import StableDiffusionImg2ImgPipeline
# Text prompt handed to the img2img pipeline for each target condition.
# The keys are the values looked up via the --type option of `sdedit`.
PROMPTS: dict[str, str] = {
    "rain": "a photo of a city street in the rain",
    "snow": "a photo of a city street in the snow",
    "fog": "a photo of a city street in the fog",
    "night": "a photo of a city street at night",
}
@click.command()
@click.option("--swim_dir", default="datasets/swim_data")
@click.option("--output", required=True)
@click.option("--type", type=click.Choice(list(PROMPTS)), required=True)
@click.option("--batch_size", default=1)
@click.option("--no_night", is_flag=True, default=False)
def sdedit(
    swim_dir: str,
    output: str,
    type: Literal["rain", "snow", "fog", "night"],
    batch_size: int,
    no_night: bool,
):
    """Translate clear-weather validation images with Stable Diffusion img2img.

    Reads ``{swim_dir}/val/labels.json``, keeps the images labelled with
    ``"weather": "clear"``, runs each batch through
    ``StableDiffusionImg2ImgPipeline`` with the prompt selected by ``type`` and
    writes every result to ``{output}/{name}`` under its original file name.

    Args:
        swim_dir: Dataset root containing ``val/labels.json`` and
            ``val/images/``.
        output: Directory that receives the generated images; created if it
            does not exist.
        type: Key into ``PROMPTS`` selecting the target condition. (The name
            shadows the builtin but is kept because it is the CLI option name.)
        batch_size: Number of images sent to the pipeline per call.
        no_night: When set, also skip images whose ``timeofday`` is
            ``"night"``.
    """
    device = "cuda"
    model_id_or_path = "runwayml/stable-diffusion-v1-5"

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_id_or_path, torch_dtype=torch.float16
    )
    pipe = pipe.to(device)

    os.makedirs(output, exist_ok=True)

    with open(f"{swim_dir}/val/labels.json", encoding="utf-8") as f:
        labels = json.load(f)

    # labels format:
    # [
    #     {
    #         "name": "049306.jpg",
    #         "weather": "clear",
    #         "timeofday": "night",
    #         "source": "bdd100k"
    #     },
    #     ...
    # ]
    clear_images = [
        label["name"]
        for label in labels
        if label["weather"] == "clear"
        and (not no_night or label["timeofday"] != "night")
    ]

    prompt = PROMPTS[type]

    # The batches are materialized so tqdm knows the total and can show progress.
    for names in tqdm(list(batched(clear_images, batch_size)), desc="Generating"):
        images = []
        for name in names:
            # convert() loads the pixel data and returns a detached copy, so the
            # file handle can be released immediately; it also guarantees the
            # 3-channel input the pipeline expects.
            with Image.open(f"{swim_dir}/val/images/{name}") as img:
                images.append(img.convert("RGB"))

        prompts = [prompt] * len(images)

        # The pipeline's positional order is (prompt, image), so pass both by
        # keyword. It already lives on `device` via .to(), and __call__ has no
        # `device` parameter.
        result = pipe(prompt=prompts, image=images)

        # The generated PIL images are in the `.images` field of the output
        # object; iterating the output itself does not yield images.
        for name, processed_image in zip(names, result.images):
            processed_image.save(f"{output}/{name}")
# Script entry point: invoke the click command only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    sdedit()