Spaces:
Runtime error
Runtime error
import subprocess | |
import threading | |
import gradio as gr | |
import websocket | |
import uuid | |
import json | |
import urllib.request | |
import urllib.parse | |
from PIL import Image | |
import io | |
# Chapter 1: Install Necessary Packages
def install_packages():
    """Install the third-party packages this app relies on, using pip.

    Installs ``gradio``, ``websocket-client`` and ``pillow`` one after the
    other. pip is started as ``<current interpreter> -m pip`` instead of a
    bare ``pip`` executable, so the packages land in the environment that is
    actually running this script even when ``pip`` on ``PATH`` belongs to a
    different Python installation or is missing.

    Raises:
        subprocess.CalledProcessError: If a pip invocation exits with a
            non-zero status. Packages listed after the failing one are not
            attempted.
    """
    # Local import: the file's import block does not bring ``sys`` into scope.
    import sys

    packages = [
        "gradio",
        "websocket-client",
        "pillow",
    ]
    for package in packages:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
# Run the installer on a worker thread and wait for it to finish: join() is
# called right after start(), so nothing below executes until the installer
# thread has ended.
# NOTE(review): the imports at the top of the file have already executed by
# the time this runs, so this step cannot supply a module that was missing
# when those imports ran - confirm the packages are provisioned another way.
install_thread = threading.Thread(target=install_packages)
install_thread.start()
install_thread.join()
# Chapter 2: Generate Client ID
# A random UUID4 rendered as a string, created once when the module loads.
# queue_prompt() sends it as "client_id" and generate_image() puts it in the
# websocket URL as the clientId query parameter.
client_id = str(uuid.uuid4())
# Chapter 3: Queue Prompt Function
def queue_prompt(prompt, server_address):
    """Submit a workflow to the server's ``/prompt`` endpoint.

    The workflow is sent together with the module-level ``client_id`` as a
    UTF-8 encoded JSON body.

    Args:
        prompt: JSON-serialisable workflow description.
        server_address: ``host:port`` string interpolated into
            ``http://{server_address}/prompt``.

    Returns:
        The server's reply decoded from JSON.

    Raises:
        urllib.error.URLError: If the request cannot be completed.
    """
    payload = {"prompt": prompt, "client_id": client_id}
    body = json.dumps(payload).encode('utf-8')
    request = urllib.request.Request(f"http://{server_address}/prompt", data=body)
    # The context manager closes the HTTP response (and its socket) even if
    # reading or JSON decoding fails.
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())
# Chapter 4: Get Image Function
def get_image(filename, subfolder, folder_type, server_address):
    """Fetch one file from the server's ``/view`` endpoint.

    Args:
        filename: Sent as the ``filename`` query parameter.
        subfolder: Sent as the ``subfolder`` query parameter.
        folder_type: Sent as the ``type`` query parameter.
        server_address: ``host:port`` string used to build the URL.

    Returns:
        The raw bytes of the response body.
    """
    query = urllib.parse.urlencode(
        [("filename", filename), ("subfolder", subfolder), ("type", folder_type)]
    )
    view_url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(view_url) as response:
        payload = response.read()
    return payload
# Chapter 5: Get History Function
def get_history(prompt_id, server_address):
    """Fetch ``/history/{prompt_id}`` from the server and decode it as JSON.

    Args:
        prompt_id: Identifier appended to the ``/history/`` path.
        server_address: ``host:port`` string used to build the URL.

    Returns:
        The decoded JSON body of the response.
    """
    history_url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(history_url) as response:
        raw_body = response.read()
    return json.loads(raw_body)
# Chapter 6: Get Images Function
def get_images(ws, prompt, server_address):
    """Queue a workflow and collect the binary frames it sends back.

    The workflow is submitted with ``queue_prompt``; the function then reads
    from ``ws`` until an ``executing`` message for this prompt reports a
    ``node`` of ``None``.

    Args:
        ws: Object whose ``recv()`` returns either a ``str`` (a JSON
            message) or a non-``str`` payload (treated as image data).
        prompt: Workflow passed through to ``queue_prompt``.
        server_address: ``host:port`` passed through to ``queue_prompt``.

    Returns:
        A dict mapping a node id to the list of payloads received while that
        node was the current one. Only ``'save_image_websocket_node'`` is
        ever recorded.
    """
    prompt_id = queue_prompt(prompt, server_address)['prompt_id']
    output_images = {}
    current_node = ""

    while True:
        frame = ws.recv()

        if not isinstance(frame, str):
            # Non-text frame: keep it only while the websocket save node is
            # the one executing. The first 8 bytes are dropped
            # (presumably a binary header - TODO confirm).
            if current_node == 'save_image_websocket_node':
                output_images.setdefault(current_node, []).append(frame[8:])
            continue

        message = json.loads(frame)
        if message['type'] != 'executing':
            continue

        data = message['data']
        if data['prompt_id'] != prompt_id:
            continue  # progress report for some other prompt

        if data['node'] is None:
            break  # no node left executing for our prompt
        current_node = data['node']

    return output_images
# Chapter 7: Generate Image Function
def _build_workflow(text_prompt, seed):
    """Return the workflow dict with the positive prompt text and seed filled in."""
    workflow_json = """
    {
        "3": {
            "class_type": "KSampler",
            "inputs": {
                "cfg": 8,
                "denoise": 1,
                "latent_image": ["5", 0],
                "model": ["4", 0],
                "negative": ["7", 0],
                "positive": ["6", 0],
                "sampler_name": "euler",
                "scheduler": "normal",
                "seed": 8566257,
                "steps": 8
            }
        },
        "4": {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {
                "ckpt_name": "v1-5-pruned-emaonly.ckpt"
            }
        },
        "5": {
            "class_type": "EmptyLatentImage",
            "inputs": {
                "batch_size": 1,
                "height": 512,
                "width": 768
            }
        },
        "6": {
            "class_type": "CLIPTextEncode",
            "inputs": {
                "clip": ["4", 1],
                "text": "masterpiece best quality girl"
            }
        },
        "7": {
            "class_type": "CLIPTextEncode",
            "inputs": {
                "clip": ["4", 1],
                "text": "bad hands"
            }
        },
        "8": {
            "class_type": "VAEDecode",
            "inputs": {
                "samples": ["3", 0],
                "vae": ["4", 2]
            }
        },
        "save_image_websocket_node": {
            "class_type": "SaveImageWebsocket",
            "inputs": {
                "images": ["8", 0]
            }
        }
    }
    """
    # Parsing the template on every call yields a fresh dict, so edits made
    # for one request never leak into the next.
    workflow = json.loads(workflow_json)

    # Node "6" is the text encoder wired to the sampler's "positive" input.
    workflow["6"]["inputs"]["text"] = text_prompt

    # A numeric UI field can deliver a float (e.g. 5.0) or None when left
    # empty. Coerce to int, and keep the template's seed when none is given.
    if seed is not None:
        workflow["3"]["inputs"]["seed"] = int(seed)
    return workflow


def generate_image(text_prompt, seed, server):
    """Generate one image on the selected server and return it.

    Builds the workflow, opens a websocket to the server, lets
    ``get_images`` queue the workflow and collect the image frames, and
    decodes the first frame received.

    Args:
        text_prompt: Text placed in the workflow's positive prompt node.
        seed: Sampler seed; converted with ``int()``. ``None`` keeps the
            seed built into the workflow template.
        server: ``"AWS Server"`` selects ``3.14.144.23:8188``; any other
            value selects ``192.168.50.136:8188``.

    Returns:
        A ``PIL.Image.Image`` for the first image received, or ``None`` when
        no image data came back.
    """
    prompt = _build_workflow(text_prompt, seed)
    server_address = "3.14.144.23:8188" if server == "AWS Server" else "192.168.50.136:8188"

    ws = websocket.WebSocket()
    try:
        ws.connect(f"ws://{server_address}/ws?clientId={client_id}")
        images = get_images(ws, prompt, server_address)
    finally:
        # Always release the socket; otherwise each request leaks one
        # connection, including when the request fails part-way.
        ws.close()

    # Return the first image found, skipping nodes with an empty list.
    for node_images in images.values():
        for image_data in node_images:
            return Image.open(io.BytesIO(image_data))
    return None
# Chapter 8: Cancel Request Function
def cancel_request():
    """Return the fixed status text ``"Request Cancelled"``.

    Takes no arguments and does nothing besides producing this string.
    """
    status_message = "Request Cancelled"
    return status_message
# Chapter 9: Gradio Interface
# Build the page: a title and description on top, then one row with the
# input controls on the left and the generated image on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with Websockets API")
    gr.Markdown("Generate images using a Websockets API and SaveImageWebsocket node.")

    with gr.Row():
        # Left column: request parameters and the two action buttons.
        with gr.Column():
            text_prompt = gr.Textbox(
                label="Text Prompt",
                value="masterpiece best quality man",
            )
            seed = gr.Number(
                label="Seed",
                value=5,
            )
            server = gr.Radio(
                label="Server",
                choices=["AWS Server", "Home Server"],
                value="AWS Server",
            )
            generate_button = gr.Button("Generate Image")
            cancel_button = gr.Button("Cancel Request")

        # Right column: where the result of generate_image is displayed.
        with gr.Column():
            output_image = gr.Image(label="Generated Image")

    # Generate: pass the three controls to generate_image and show its result.
    generate_button.click(
        fn=generate_image,
        inputs=[text_prompt, seed, server],
        outputs=output_image,
    )

    # NOTE(review): cancel_request only returns a string and this listener
    # has no outputs, so clicking the button does not appear to stop a
    # running generation. Gradio's `cancels=` listener argument may be what
    # is wanted here - confirm.
    cancel_button.click(
        fn=cancel_request,
        inputs=[],
        outputs=[],
    )

# Start the web server for the interface.
demo.launch()