File size: 1,967 Bytes
6bae932
 
 
 
 
 
 
 
6a0af53
6bae932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a0af53
6bae932
 
 
 
 
 
 
4598c6c
1870c21
8e97e43
af2cfcd
6bae932
 
 
 
 
 
 
 
7963de3
6bae932
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# load both base & refiner
import logging
from io import BytesIO

import torch
from diffusers import DiffusionPipeline
from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from cache.local_cache import ttl_cache
from config import settings

# Router collecting the image-generation endpoints defined in this module.
router = APIRouter()

# Base SDXL pipeline in fp16, loaded from the model id configured in settings.
# NOTE(review): this runs at import time — importing the module requires a
# CUDA-capable GPU and loads the full model weights.
base = DiffusionPipeline.from_pretrained(
    settings.base_sd_model, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)

base.to("cuda")
# base.enable_model_cpu_offload()
# Attention slicing lowers peak VRAM usage at some speed cost.
base.enable_attention_slicing()
# Refiner pipeline; reuses the base's second text encoder and VAE so those
# weights are loaded only once.
refiner = DiffusionPipeline.from_pretrained(
    settings.refiner_sd_model,
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
refiner.to("cuda")
# refiner.enable_model_cpu_offload()
refiner.enable_attention_slicing()


@router.get("/generate")
@ttl_cache(key_name='prompt', ttl_secs=20)
async def generate(prompt: str):
    """
    Generate an image for *prompt* with the two-expert SDXL pipeline.

    The base model denoises the first 80% of the schedule and emits
    latents; the refiner completes the remaining 20% and decodes to an
    image. The result is streamed back as a PNG. Responses are cached
    per prompt for 20 seconds via ``ttl_cache``.

    NOTE(review): the pipeline calls below are synchronous GPU work
    inside an ``async def``, so each request blocks the event loop for
    the whole generation — consider offloading to a worker thread if
    concurrent requests matter.

    :param prompt: user-supplied text description of the desired image.
    :return: StreamingResponse with media type ``image/png``.
    """
    # Expert split: 40 total steps, 80% on the base / 20% on the refiner.
    n_steps = 40
    high_noise_frac = 0.8
    negative = "disfigured, ugly, bad, immature, cartoon, anime, 3d, painting, b&w, sketch, blurry, deformed, bad anatomy, poorly drawn face, mutation, multiple people."

    # Wrap the user prompt with fixed style/quality modifiers.
    prompt = f"single image. single model. {prompt}. zoomed in. full-body. real person. realistic. 4k. best quality."
    # Log via the logging framework (lazy %-args) instead of a bare print so
    # the message respects the application's logging configuration.
    logging.getLogger(__name__).info("generating image for prompt: %s", prompt)

    # Stage 1: base model denoises up to high_noise_frac and returns latents.
    image = base(
        prompt=prompt,
        negative_prompt=negative,
        num_inference_steps=n_steps,
        denoising_end=high_noise_frac,
        output_type="latent",
    ).images[0]
    # Stage 2: refiner resumes at the same fraction and finishes the image.
    final_image = refiner(
        prompt=prompt,
        negative_prompt=negative,
        num_inference_steps=n_steps,
        denoising_start=high_noise_frac,
        image=image,
    ).images[0]

    # Encode the PNG fully in memory and stream it without touching disk.
    memory_stream = BytesIO()
    final_image.save(memory_stream, format="PNG")
    memory_stream.seek(0)
    return StreamingResponse(memory_stream, media_type="image/png")