Spaces:
Runtime error
Runtime error
RoniFinTech
committed on
Commit
•
6bae932
1
Parent(s):
6038044
structure
Browse files
- config.py +15 -0
- main.py +12 -63
- requirements.txt +1 -2
- {stable_diffusion → routers}/__init__.py +0 -0
- routers/intference/__init__.py +0 -0
- routers/intference/stable_diffusion.py +69 -0
config.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
|
6 |
+
class Settings(BaseModel):
|
7 |
+
hf_token: str = os.environ.get("hf_token")
|
8 |
+
base_sd_model: str = os.environ.get("base_sd_model", "stabilityai/stable-diffusion-xl-base-1.0")
|
9 |
+
refiner_sd_model: str = os.environ.get("refiner_sd_model", "stabilityai/stable-diffusion-xl-refiner-1.0")
|
10 |
+
version: str = "0.1.0"
|
11 |
+
url_version: str = "v1"
|
12 |
+
prefix: str = "v1/unik-ml"
|
13 |
+
|
14 |
+
|
15 |
+
settings = Settings()
|
main.py
CHANGED
@@ -1,33 +1,19 @@
|
|
1 |
-
from io import BytesIO
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from diffusers import DiffusionPipeline
|
5 |
from fastapi import FastAPI
|
6 |
from fastapi.middleware.cors import CORSMiddleware
|
7 |
-
from
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
12 |
-
)
|
13 |
|
14 |
-
|
15 |
-
# base.enable_model_cpu_offload()
|
16 |
-
base.enable_attention_slicing()
|
17 |
-
refiner = DiffusionPipeline.from_pretrained(
|
18 |
-
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
19 |
-
text_encoder_2=base.text_encoder_2,
|
20 |
-
vae=base.vae,
|
21 |
-
torch_dtype=torch.float16,
|
22 |
-
use_safetensors=True,
|
23 |
-
variant="fp16",
|
24 |
-
)
|
25 |
-
refiner.to("cuda")
|
26 |
-
# refiner.enable_model_cpu_offload()
|
27 |
-
refiner.enable_attention_slicing()
|
28 |
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
app.add_middleware(
|
33 |
CORSMiddleware,
|
@@ -43,41 +29,4 @@ async def root():
|
|
43 |
return {"message": "UNIK ML API"}
|
44 |
|
45 |
|
46 |
-
|
47 |
-
async def generate(text: str):
|
48 |
-
"""
|
49 |
-
generate image
|
50 |
-
"""
|
51 |
-
# Define how many steps and what % of steps to be run on each experts (80/20) here
|
52 |
-
n_steps = 40
|
53 |
-
high_noise_frac = 0.8
|
54 |
-
negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly. bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, malformed hands, blurry, mutated hands and fingers, watermark, watermarked, oversaturated, censored, distorted hands, amputation, missing hands, obese, doubled face, double hands, two women, anime style, cartoon, toon."
|
55 |
-
prompt = "Designs should play with different textures and layering but stick to a monochrome palette. Think leather jackets over mesh tops, or satin draped over matte cotton. in a studio. zoomed-in. single model."
|
56 |
-
|
57 |
-
# run both experts
|
58 |
-
image = base(
|
59 |
-
prompt=prompt,
|
60 |
-
negative_prompt=negative,
|
61 |
-
num_inference_steps=n_steps,
|
62 |
-
denoising_end=high_noise_frac,
|
63 |
-
output_type="latent",
|
64 |
-
).images
|
65 |
-
final_image = refiner(
|
66 |
-
prompt=prompt,
|
67 |
-
negative_prompt=negative,
|
68 |
-
num_inference_steps=n_steps,
|
69 |
-
denoising_start=high_noise_frac,
|
70 |
-
image=image,
|
71 |
-
).images[0]
|
72 |
-
|
73 |
-
# buffer = BytesIO()
|
74 |
-
# final_image.save(buffer, format="PNG")
|
75 |
-
# image_bytes = buffer.getvalue()
|
76 |
-
#
|
77 |
-
# return StreamingResponse(BytesIO(image_bytes), media_type="image/png")
|
78 |
-
#
|
79 |
-
memory_stream = BytesIO()
|
80 |
-
final_image.save(memory_stream, format="PNG")
|
81 |
-
memory_stream.seek(0)
|
82 |
-
return StreamingResponse(memory_stream, media_type="image/png")
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
1 |
from fastapi import FastAPI
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
+
from huggingface_hub import login
|
4 |
|
5 |
+
from config import settings
|
6 |
+
from routers.intference import stable_diffusion
|
|
|
|
|
7 |
|
8 |
+
login(settings.hf_token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
+
app = FastAPI(
|
11 |
+
title="UNIK ML",
|
12 |
+
version=settings.version,
|
13 |
+
openapi_url=f"{settings.prefix}/openapi.json",
|
14 |
+
docs_url=f"{settings.prefix}/docs",
|
15 |
+
redoc_url=f"{settings.prefix}/redoc",
|
16 |
+
swagger_ui_oauth2_redirect_url=f"{settings.prefix}/docs/oauth2-redirect")
|
17 |
|
18 |
app.add_middleware(
|
19 |
CORSMiddleware,
|
|
|
29 |
return {"message": "UNIK ML API"}
|
30 |
|
31 |
|
32 |
+
app.include_router(stable_diffusion.router, prefix=settings.prefix, tags=["Inference", "sd"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -9,5 +9,4 @@ accelerate==0.21.0
|
|
9 |
diffusers==0.19.3
|
10 |
torchvision==0.15.2
|
11 |
safetensors==0.3.1
|
12 |
-
|
13 |
-
# opencv-python-headless==4.8.0.74
|
|
|
9 |
diffusers==0.19.3
|
10 |
torchvision==0.15.2
|
11 |
safetensors==0.3.1
|
12 |
+
huggingface-hub==0.16.4
|
|
{stable_diffusion → routers}/__init__.py
RENAMED
File without changes
|
routers/intference/__init__.py
ADDED
File without changes
|
routers/intference/stable_diffusion.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# load both base & refiner
|
2 |
+
from io import BytesIO
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from diffusers import DiffusionPipeline
|
6 |
+
from fastapi import APIRouter
|
7 |
+
from fastapi.responses import StreamingResponse
|
8 |
+
|
9 |
+
from config import settings
|
10 |
+
|
11 |
+
router = APIRouter()
|
12 |
+
|
13 |
+
base = DiffusionPipeline.from_pretrained(
|
14 |
+
settings.base_sd_model, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
15 |
+
)
|
16 |
+
|
17 |
+
base.to("cuda")
|
18 |
+
# base.enable_model_cpu_offload()
|
19 |
+
base.enable_attention_slicing()
|
20 |
+
refiner = DiffusionPipeline.from_pretrained(
|
21 |
+
settings.refiner_sd_model,
|
22 |
+
text_encoder_2=base.text_encoder_2,
|
23 |
+
vae=base.vae,
|
24 |
+
torch_dtype=torch.float16,
|
25 |
+
use_safetensors=True,
|
26 |
+
variant="fp16",
|
27 |
+
)
|
28 |
+
refiner.to("cuda")
|
29 |
+
# refiner.enable_model_cpu_offload()
|
30 |
+
refiner.enable_attention_slicing()
|
31 |
+
|
32 |
+
|
33 |
+
@router.get("/generate")
|
34 |
+
async def generate(prompt: str):
|
35 |
+
"""
|
36 |
+
generate image
|
37 |
+
"""
|
38 |
+
# Define how many steps and what % of steps to be run on each experts (80/20) here
|
39 |
+
n_steps = 40
|
40 |
+
high_noise_frac = 0.8
|
41 |
+
negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly. bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, malformed hands, blurry, mutated hands and fingers, watermark, watermarked, oversaturated, censored, distorted hands, amputation, missing hands, obese, doubled face, double hands, two women, anime style, cartoon, toon."
|
42 |
+
# prompt = "Designs should play with different textures and layering but stick to a monochrome palette. Think leather jackets over mesh tops, or satin draped over matte cotton. in a studio. zoomed-in. single model."
|
43 |
+
|
44 |
+
# run both experts
|
45 |
+
image = base(
|
46 |
+
prompt=prompt,
|
47 |
+
negative_prompt=negative,
|
48 |
+
num_inference_steps=n_steps,
|
49 |
+
denoising_end=high_noise_frac,
|
50 |
+
output_type="latent",
|
51 |
+
).images
|
52 |
+
final_image = refiner(
|
53 |
+
prompt=prompt,
|
54 |
+
negative_prompt=negative,
|
55 |
+
num_inference_steps=n_steps,
|
56 |
+
denoising_start=high_noise_frac,
|
57 |
+
image=image,
|
58 |
+
).images[0]
|
59 |
+
|
60 |
+
# buffer = BytesIO()
|
61 |
+
# final_image.save(buffer, format="PNG")
|
62 |
+
# image_bytes = buffer.getvalue()
|
63 |
+
#
|
64 |
+
# return StreamingResponse(BytesIO(image_bytes), media_type="image/png")
|
65 |
+
#
|
66 |
+
memory_stream = BytesIO()
|
67 |
+
final_image.save(memory_stream, format="PNG")
|
68 |
+
memory_stream.seek(0)
|
69 |
+
return StreamingResponse(memory_stream, media_type="image/png")
|