# Scraped page metadata: file size 1,921 bytes, commit cf791ca.
from typing import Literal

import os
import torch
from itertools import batched
import click
from PIL import Image
from tqdm import tqdm
import json

from diffusers import StableDiffusionImg2ImgPipeline


# Text prompt for each target condition. ``sdedit`` looks the prompt up by
# its ``--type`` value and pairs it with every image in a batch.
PROMPTS = dict(
    rain="a photo of a city street in the rain",
    snow="a photo of a city street in the snow",
    fog="a photo of a city street in the fog",
    night="a photo of a city street at night",
)


def _select_clear_images(labels: list[dict], no_night: bool) -> list[str]:
    """Return the file names of the clear-weather entries in *labels*.

    Args:
        labels: Parsed ``labels.json`` entries. Each entry is read for its
            ``"name"`` and ``"weather"`` keys, and for ``"timeofday"`` when
            *no_night* is set.
        no_night: When true, also drop entries whose ``"timeofday"`` is
            ``"night"``.

    Returns:
        The ``"name"`` values of the selected entries, in input order.
    """
    return [
        label["name"]
        for label in labels
        if label["weather"] == "clear"
        and not (no_night and label["timeofday"] == "night")
    ]


@click.command()
@click.option(
    "--swim_dir",
    default="datasets/swim_data",
    show_default=True,
    help="Dataset root; val/labels.json and val/images/ are read from it.",
)
@click.option(
    "--output", required=True, help="Directory the edited images are written to."
)
@click.option(
    "--type",
    required=True,
    type=click.Choice(list(PROMPTS)),
    help="Target condition; selects the text prompt.",
)
@click.option(
    "--batch_size",
    default=1,
    show_default=True,
    type=click.IntRange(min=1),
    help="Number of images per pipeline call.",
)
@click.option(
    "--no_night",
    is_flag=True,
    default=False,
    help="Skip clear images whose time of day is 'night'.",
)
@click.option(
    "--model_id",
    default="runwayml/stable-diffusion-v1-5",
    show_default=True,
    help="Hub id or local path of the Stable Diffusion checkpoint.",
)
def sdedit(
    swim_dir: str,
    output: str,
    type: Literal["rain", "snow", "fog", "night"],
    batch_size: int,
    no_night: bool,
    model_id: str = "runwayml/stable-diffusion-v1-5",
):
    """Apply Stable Diffusion img2img (SDEdit) to clear-weather SWIM val images.

    Reads <swim_dir>/val/labels.json, keeps the images labelled as clear
    weather (optionally without night-time ones), edits them towards the
    prompt for --type and saves each result as <output>/<name>.
    """
    device = "cuda"
    # NOTE(review): the runwayml repo has reportedly been removed from the
    # Hugging Face Hub; "stable-diffusion-v1-5/stable-diffusion-v1-5" is the
    # community mirror. Pass it via --model_id if loading the default fails.
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to(device)

    os.makedirs(output, exist_ok=True)

    # labels.json is a list of entries such as:
    #   {"name": "049306.jpg", "weather": "clear",
    #    "timeofday": "night", "source": "bdd100k"}
    labels_path = os.path.join(swim_dir, "val", "labels.json")
    with open(labels_path, encoding="utf-8") as f:
        labels = json.load(f)

    clear_images = _select_clear_images(labels, no_night)
    prompt = PROMPTS[type]

    # NOTE(review): the pipeline stacks a batch into one tensor, so presumably
    # all images in a batch must share a resolution when batch_size > 1.
    for names in tqdm(list(batched(clear_images, batch_size)), desc="Generating"):
        images = []
        for name in names:
            # Load eagerly and convert so the file handle is released and the
            # pipeline always receives 3-channel input.
            with Image.open(os.path.join(swim_dir, "val", "images", name)) as img:
                images.append(img.convert("RGB"))

        # prompt comes first and image second in the pipeline's signature, so
        # pass both by keyword. The call returns an output object whose
        # generated PIL images are in `.images`.
        result = pipe(prompt=[prompt] * len(images), image=images)

        for name, processed_image in zip(names, result.images, strict=True):
            processed_image.save(os.path.join(output, name))


if __name__ == "__main__":
    # Script entry point: calling the click command parses sys.argv and
    # runs sdedit with the parsed options.
    sdedit()