ombhojane committed
Commit 198ad13 · verified · 1 Parent(s): abc3fc2

Create app.py

Files changed (1):
app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
+ import streamlit as st
+ import torch
+ import numpy as np
+ import requests
+ from io import BytesIO
+ from PIL import Image
+ from transformers import pipeline
+ from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
+
+ # Initialize the depth estimator (used to build the ControlNet conditioning image)
+ depth_estimator = pipeline("depth-estimation")
+
+ # Load an image from a URL and force RGB so it matches the pipeline's expected input
+ def load_image_from_url(url):
+     response = requests.get(url)
+     response.raise_for_status()
+     return Image.open(BytesIO(response.content)).convert("RGB")
+
+ # Turn the estimator's single-channel depth output into a (3, H, W) float tensor in [0, 1]
+ def get_depth_map(image, depth_estimator):
+     image = depth_estimator(image)["depth"]
+     image = np.array(image)
+     image = image[:, :, None]
+     image = np.concatenate([image, image, image], axis=2)
+     detected_map = torch.from_numpy(image).float() / 255.0
+     depth_map = detected_map.permute(2, 0, 1)
+     return depth_map
+
+ # Streamlit UI
+ st.title("Image Modification with ControlNet and Stable Diffusion")
+
+ # User inputs
+ image_url = st.text_input("Enter the URL of a farm image:", "")
+ prompt = st.text_input("Enter your prompt:", "vineyard agrotourism service on the farm")
+
+ if st.button("Generate"):
+     if image_url:
+         # Load the image
+         farm_image = load_image_from_url(image_url)
+
+         # Compute the depth map and add a batch dimension: (1, 3, H, W), fp16, on the GPU
+         depth_map = get_depth_map(farm_image, depth_estimator).unsqueeze(0).half().to("cuda")
+
+         # Load a depth ControlNet (to match the depth conditioning above) and the img2img pipeline
+         controlnet = ControlNetModel.from_pretrained(
+             "lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, use_safetensors=True
+         )
+         pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+             "runwayml/stable-diffusion-v1-5",
+             controlnet=controlnet,
+             torch_dtype=torch.float16,
+             use_safetensors=True,
+         )
+         pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+         # Offloads submodules to the CPU between steps; the pipeline is not moved to CUDA manually
+         pipe.enable_model_cpu_offload()
+
+         # Generate the image, conditioning img2img on the depth map
+         output = pipe(
+             prompt,
+             image=farm_image,
+             control_image=depth_map,
+         ).images[0]
+
+         # Display the original and generated images
+         st.image(farm_image, caption="Original Image")
+         st.image(output, caption="Generated Image")
+     else:
+         st.write("Please enter an image URL.")
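
Streamlit re-executes the whole script on every interaction, so the depth estimator and the diffusion pipeline above are re-created, and the weights re-loaded, each time "Generate" is clicked. Below is a minimal sketch of how the heavy objects could be cached with st.cache_resource; the loader names get_depth_estimator and get_pipeline are illustrative and not part of this commit.

import streamlit as st
import torch
from transformers import pipeline
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler

@st.cache_resource  # runs once per process; later reruns reuse the cached object
def get_depth_estimator():
    return pipeline("depth-estimation")

@st.cache_resource
def get_pipeline():
    # Same models as in app.py, loaded a single time and shared across reruns
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, use_safetensors=True
    )
    pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    return pipe

With this in place, the button handler would call get_pipeline() and get_depth_estimator() instead of re-running from_pretrained on every generation.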