import torch
from diffusers import (
    StableDiffusionControlNetImg2ImgPipeline,
    ControlNetModel,
    DDIMScheduler,
)

from PIL import Image


class QRControlNet:
    """Wrapper around the QR-code ControlNet + Stable Diffusion 1.5 img2img pipeline."""

    def __init__(self, device: str = "cuda"):
        # Use half precision on the GPU to save memory; fall back to float32 on CPU.
        torch_dtype = torch.float16 if device == "cuda" else torch.float32
        # QR-code conditioned ControlNet weights trained for Stable Diffusion 1.5.
        controlnet = ControlNetModel.from_pretrained(
            "DionTimmer/controlnet_qrcode-control_v1p_sd15", torch_dtype=torch_dtype
        )

        pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            # safety_checker=None,
            torch_dtype=torch_dtype,
        ).to(device)

        if device == "cuda":
            pipe.enable_xformers_memory_efficient_attention()
            pipe.enable_model_cpu_offload()

        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        self.pipe = pipe

    def generate_image(
        self,
        source_image: Image.Image,
        control_image: Image.Image,
        prompt: str,
        negative_prompt: str,
        img_size: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 20.0,
        controlnet_conditioning_scale: float = 3.0,
        strength: float = 0.9,
        seed: int = 42,
        **kwargs
    ):
        width = height = img_size
        # Fixed seed for reproducible outputs.
        generator = torch.manual_seed(seed)
        output = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=source_image,
            control_image=control_image,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=generator,
            strength=strength,
            num_inference_steps=num_inference_steps,
            **kwargs
        )

        return output.images[0]
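

# Example usage (illustrative sketch, not part of the original module): the file
# paths and prompts below are placeholders; substitute your own init image and
# QR code image. Both are resized to the 512x512 resolution used by default.
if __name__ == "__main__":
    qr_controlnet = QRControlNet(device="cuda")

    # Placeholder inputs: an init image to stylize and the QR code to embed.
    source = Image.open("init_image.png").convert("RGB").resize((512, 512))
    qr_code = Image.open("qr_code.png").convert("RGB").resize((512, 512))

    result = qr_controlnet.generate_image(
        source_image=source,
        control_image=qr_code,
        prompt="a sprawling futuristic city at night, highly detailed",
        negative_prompt="ugly, disfigured, low quality, blurry",
    )
    result.save("qr_art.png")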