# Snapshot cf791ca (file size: 1,921 bytes).
from typing import Literal
import os
import torch
from itertools import batched
import click
from PIL import Image
from tqdm import tqdm
import json
from diffusers import StableDiffusionImg2ImgPipeline
# Text prompt handed to the img2img pipeline for each supported target
# condition; the keys are the values accepted by the --type option.
PROMPTS: dict[str, str] = {
    "rain": "a photo of a city street in the rain",
    "snow": "a photo of a city street in the snow",
    "fog": "a photo of a city street in the fog",
    "night": "a photo of a city street at night",
}
@click.command()
@click.option(
    "--swim_dir",
    default="datasets/swim_data",
    help="Dataset root containing val/labels.json and val/images/.",
)
@click.option(
    "--output",
    required=True,
    help="Directory the generated images are written to (created if missing).",
)
@click.option(
    "--type",
    type=click.Choice(list(PROMPTS)),
    required=True,
    help="Target condition; selects the text prompt from PROMPTS.",
)
@click.option(
    "--batch_size",
    default=1,
    type=click.IntRange(min=1),
    help="Number of images sent through the pipeline per call.",
)
@click.option(
    "--no_night",
    is_flag=True,
    default=False,
    help="Skip clear-weather images whose time of day is night.",
)
def sdedit(
    swim_dir: str,
    output: str,
    type: Literal["rain", "snow", "fog", "night"],
    batch_size: int,
    no_night: bool,
):
    """Translate clear-weather validation images with Stable Diffusion img2img.

    Reads val/labels.json under SWIM_DIR, keeps the images labelled with
    weather "clear" (optionally dropping night-time ones), runs each batch
    through the img2img pipeline with the prompt for TYPE, and saves every
    result under OUTPUT using the source image's file name.
    """
    # NOTE: the parameter name `type` shadows the builtin inside this function;
    # it is kept because it is the public CLI option / keyword name.

    # Read the labels before loading the model so a wrong dataset path fails
    # immediately rather than after the weights have been loaded.
    # Each entry looks like:
    #   {
    #       "name": "049306.jpg",
    #       "weather": "clear",
    #       "timeofday": "night",
    #       "source": "bdd100k"
    #   }
    with open(f"{swim_dir}/val/labels.json", encoding="utf-8") as f:
        labels = json.load(f)

    clear_images = [
        label["name"]
        for label in labels
        if label["weather"] == "clear"
        and (not no_night or label["timeofday"] != "night")
    ]

    os.makedirs(output, exist_ok=True)

    device = "cuda"
    model_id_or_path = "runwayml/stable-diffusion-v1-5"
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_id_or_path, torch_dtype=torch.float16
    )
    pipe = pipe.to(device)

    prompt = PROMPTS[type]
    # The batches are materialised so tqdm knows the total and can show an ETA;
    # they only hold file names, so the list is cheap.
    for names in tqdm(list(batched(clear_images, batch_size)), desc="Generating"):
        images = []
        for name in names:
            # convert() loads the pixel data and returns a detached copy, so
            # the underlying file handle can be closed right away.
            with Image.open(f"{swim_dir}/val/images/{name}") as source_image:
                images.append(source_image.convert("RGB"))

        prompts = [prompt] * len(images)
        # Pass by keyword: the pipeline's call signature takes the prompt
        # first and the image second. The generated PIL images are on the
        # `.images` attribute of the returned output object.
        processed_images = pipe(prompt=prompts, image=images).images
        for name, processed_image in zip(names, processed_images):
            processed_image.save(f"{output}/{name}")
# Script entry point: invoke the click command only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    sdedit()
|