Rojban commited on
Commit
621523a
Β·
1 Parent(s): 516254f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -2
app.py CHANGED
@@ -1,6 +1,42 @@
1
  import gradio as gr
2
- import os
 
 
3
 
 
4
  token = os.getenv("token")
 
5
 
6
- gr.load("models/Rojban/AutoTrain_Dreambooth3", hf_token=token).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import os
3
+ from diffusers import DiffusionPipeline, StableDiffusionXLImg2ImgPipeline
4
+ import torch
5
 
6
+ # Load your custom Stable Diffusion model
7
  token = os.getenv("token")
8
+ model = gr.load("models/Rojban/AutoTrain_Dreambooth3", hf_token=token)
9
 
10
+ # Define the image generation function
11
+ def generate_image(prompt, seed=42):
12
+ model_name = "stabilityai/stable-diffusion-xl-base-1.0"
13
+ pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
14
+ pipe.to("cuda")
15
+
16
+ refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
17
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
18
+ torch_dtype=torch.float16,
19
+ )
20
+ refiner.to("cuda")
21
+
22
+ generator = torch.Generator("cuda").manual_seed(seed)
23
+ image = pipe(prompt=prompt, generator=generator).images[0]
24
+ image = refiner(prompt=prompt, generator=generator, image=image).images[0]
25
+
26
+ # Save and return the image
27
+ image_path = "generated_image.png"
28
+ image.save(image_path)
29
+ return image_path
30
+
31
+ # Create the Gradio interface
32
+ interface = gr.Interface(
33
+ fn=generate_image,
34
+ inputs=[gr.Textbox(label="Prompt"), gr.Number(label="Seed", default=42)],
35
+ outputs=gr.Image(type="file"),
36
+ title="Custom Stable Diffusion Model",
37
+ description="Generate images using a custom Stable Diffusion model."
38
+ )
39
+
40
+ # Launch the app
41
+ if __name__ == "__main__":
42
+ interface.launch()