# app.py (author: rayochoajr)
# Last commit: e9f7e38 "Update app.py" (verified), 6.23 kB
import subprocess
import threading
import gradio as gr
import websocket
import uuid
import json
import urllib.request
import urllib.parse
from PIL import Image
import io
# 🚀 Chapter 1: Install Necessary Packages 🚀
def install_packages():
    """Install the app's runtime dependencies with pip.

    pip is run as ``sys.executable -m pip`` so the packages are installed
    into the same interpreter that is executing this script; a bare ``pip``
    found on PATH may belong to a different Python installation. All
    packages are passed to a single pip invocation.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    import sys

    packages = [
        "gradio",
        "websocket-client",
        "pillow",
    ]
    subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])
# Install dependencies synchronously before the UI is built: the app does not
# continue until pip has finished. A failed install is reported but not fatal,
# because the packages were already imported successfully at the top of this
# file, so the app can still start.
try:
    install_packages()
except (subprocess.CalledProcessError, OSError) as exc:
    print(f"Package installation failed: {exc}")
# 🌟 Chapter 2: Generate Client ID 🌟
# Random UUID4 text generated once per process. It is sent with every queued
# prompt and used as the websocket clientId for this app.
client_id = f"{uuid.uuid4()}"
# 🌟 Chapter 3: Queue Prompt Function 🌟
def queue_prompt(prompt, server_address):
    """POST a workflow to the server's ``/prompt`` endpoint.

    The request body is ``{"prompt": prompt, "client_id": client_id}``,
    where ``client_id`` is the module-level id that this app also uses as
    the websocket ``clientId``.

    Args:
        prompt: Workflow graph (dict of node id -> node spec).
        server_address: ``host:port`` of the server.

    Returns:
        The decoded JSON response body (callers read its ``'prompt_id'``).

    Raises:
        urllib.error.URLError / urllib.error.HTTPError: on connection
            failure or a non-2xx response.
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode("utf-8")
    req = urllib.request.Request(
        f"http://{server_address}/prompt",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    # The context manager closes the HTTP response once it has been read.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
# 🌟 Chapter 4: Get Image Function 🌟
def get_image(filename, subfolder, folder_type, server_address):
    """Download raw bytes from the server's ``/view`` endpoint.

    Args:
        filename: Stored file name.
        subfolder: Subfolder the file lives in.
        folder_type: Value sent as the ``type`` query parameter.
        server_address: ``host:port`` of the server.

    Returns:
        The response body as ``bytes``.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(view_url) as resp:
        payload = resp.read()
    return payload
# 🌟 Chapter 5: Get History Function 🌟
def get_history(prompt_id, server_address):
    """Fetch ``/history/<prompt_id>`` from the server and decode it as JSON.

    Args:
        prompt_id: Id returned when the prompt was queued.
        server_address: ``host:port`` of the server.

    Returns:
        The parsed JSON response.
    """
    history_url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(history_url) as resp:
        raw = resp.read()
    return json.loads(raw)
# 🌟 Chapter 6: Get Images Function 🌟
def get_images(ws, prompt, server_address, output_node="save_image_websocket_node"):
    """Queue *prompt* and collect the image frames streamed back over *ws*.

    The prompt is submitted with ``queue_prompt``. Websocket messages are
    then read until an ``executing`` message for this prompt arrives with
    ``node`` set to ``None``, which marks the end of execution.

    Args:
        ws: An already-connected websocket; its ``clientId`` should match
            the ``client_id`` that ``queue_prompt`` sends.
        prompt: Workflow graph to execute.
        server_address: ``host:port`` of the server.
        output_node: Id of the node whose binary frames are collected.

    Returns:
        Dict mapping *output_node* to the list of image byte strings that
        arrived while that node was executing; empty if none arrived.

    Raises:
        RuntimeError: if the server reports ``execution_error`` or
            ``execution_interrupted`` for this prompt.
    """
    prompt_id = queue_prompt(prompt, server_address)['prompt_id']
    output_images = {}
    current_node = ""
    while True:
        out = ws.recv()
        if isinstance(out, str):
            message = json.loads(out)
            msg_type = message.get('type')
            data = message.get('data') or {}
            # Ignore status/progress traffic and messages about other prompts.
            if data.get('prompt_id') != prompt_id:
                continue
            if msg_type == 'executing':
                if data['node'] is None:
                    break  # execution of this prompt is complete
                current_node = data['node']
            elif msg_type in ('execution_error', 'execution_interrupted'):
                # Surface the failure instead of returning an empty result or
                # waiting for a completion message that may never arrive.
                detail = data.get('exception_message', '')
                raise RuntimeError(f"Prompt {prompt_id} failed ({msg_type}): {detail}")
        elif current_node == output_node:
            # Binary frame: the first 8 bytes are a header (assumed to be the
            # event type + image format); the rest are the encoded image bytes.
            output_images.setdefault(current_node, []).append(out[8:])
    return output_images
# 🌟 Chapter 7: Generate Image Function 🌟
def _build_workflow(text_prompt, seed):
    """Return a fresh copy of the text-to-image workflow with prompt and seed filled in."""
    return {
        "3": {
            "class_type": "KSampler",
            "inputs": {
                "cfg": 8,
                "denoise": 1,
                "latent_image": ["5", 0],
                "model": ["4", 0],
                "negative": ["7", 0],
                "positive": ["6", 0],
                "sampler_name": "euler",
                "scheduler": "normal",
                # gr.Number yields a float unless precision=0; the sampler seed is an integer.
                "seed": int(seed),
                "steps": 8,
            },
        },
        "4": {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {"ckpt_name": "v1-5-pruned-emaonly.ckpt"},
        },
        "5": {
            "class_type": "EmptyLatentImage",
            "inputs": {"batch_size": 1, "height": 512, "width": 768},
        },
        "6": {
            "class_type": "CLIPTextEncode",
            "inputs": {"clip": ["4", 1], "text": text_prompt},
        },
        "7": {
            "class_type": "CLIPTextEncode",
            "inputs": {"clip": ["4", 1], "text": "bad hands"},
        },
        "8": {
            "class_type": "VAEDecode",
            "inputs": {"samples": ["3", 0], "vae": ["4", 2]},
        },
        "save_image_websocket_node": {
            "class_type": "SaveImageWebsocket",
            "inputs": {"images": ["8", 0]},
        },
    }


def _first_image(images):
    """Decode the first image bytes in a ``{node_id: [bytes, ...]}`` mapping, or return None."""
    for image_list in images.values():
        for image_data in image_list:
            return Image.open(io.BytesIO(image_data))
    return None


def generate_image(text_prompt, seed, server):
    """Render *text_prompt* on the selected server and return the first image.

    Args:
        text_prompt: Positive prompt text (node "6").
        seed: Sampler seed; converted with ``int()``.
        server: ``"AWS Server"`` selects the AWS address; any other value
            selects the home-network address.

    Returns:
        A ``PIL.Image.Image`` for the first image received, or ``None`` if
        no image frames arrived.
    """
    prompt = _build_workflow(text_prompt, seed)
    server_address = "3.14.144.23:8188" if server == "AWS Server" else "192.168.50.136:8188"
    ws = websocket.WebSocket()
    ws.connect(f"ws://{server_address}/ws?clientId={client_id}")
    try:
        images = get_images(ws, prompt, server_address)
    finally:
        # Release the socket whether generation succeeded or raised.
        ws.close()
    return _first_image(images)
# 🌟 Chapter 8: Cancel Request Function 🌟
def cancel_request():
    """Produce the status text for the "Cancel Request" button.

    This function performs no cancellation itself; it only returns a fixed
    message.
    """
    status = "Request Cancelled"
    return status
# 🌟 Chapter 9: Gradio Interface 🌟
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with Websockets API")
    gr.Markdown("Generate images using a Websockets API and SaveImageWebsocket node.")
    with gr.Row():
        with gr.Column():
            text_prompt = gr.Textbox(label="Text Prompt", value="masterpiece best quality man")
            # precision=0 makes the component return an int, matching the sampler's integer seed.
            seed = gr.Number(label="Seed", value=5, precision=0)
            server = gr.Radio(label="Server", choices=["AWS Server", "Home Server"], value="AWS Server")
            generate_button = gr.Button("Generate Image")
            cancel_button = gr.Button("Cancel Request")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
    generate_event = generate_button.click(
        fn=generate_image, inputs=[text_prompt, seed, server], outputs=output_image
    )
    # cancels= lets the button cancel the pending generate event. NOTE(review):
    # per Gradio's docs, a call that is already running may be allowed to
    # finish, and the job on the image server is not interrupted.
    cancel_button.click(fn=cancel_request, inputs=[], outputs=[], cancels=[generate_event])
# Event cancellation requires the queue (explicitly in Gradio 3.x; harmless in later versions).
demo.queue()
demo.launch()