|
import argparse |
|
import base64 |
|
from io import BytesIO |
|
|
|
from PIL import Image |
|
|
|
from handler import EndpointHandler, decode_base64_image |
|
|
|
|
|
def local_predict(prompts, encode_image, output_path="local_output.png"):
    """Run the endpoint handler in-process and save the image it returns.

    Args:
        prompts: Prompt text, forwarded to the handler under the ``"inputs"`` key.
        encode_image: Base64-encoded init image as a string, or a falsy value
            (e.g. ``""``) to run without an init image.
        output_path: File path the decoded result image is saved to. Defaults
            to ``"local_output.png"``, the path this function always used.

    Returns:
        The decoded image object (as produced by ``decode_base64_image``)
        that was written to ``output_path``.
    """
    my_handler = EndpointHandler()

    # Only add the "image" key when an init image was supplied, so a
    # prompt-only run sends exactly {"inputs": prompts} as before.
    payload = {"inputs": prompts}
    if encode_image:
        payload["image"] = encode_image

    response = my_handler(payload)

    # The handler's "image" entry is passed through decode_base64_image;
    # presumably it is a base64-encoded image string — confirm in handler.py.
    image = decode_base64_image(response["image"])
    image.save(output_path)
    return image
|
|
|
|
|
# Command-line interface for the local diffuser smoke test.
# The first positional parameter of ArgumentParser is ``prog`` (the program
# name printed in usage lines), so the title text is passed as
# ``description`` instead of positionally.
opt = argparse.ArgumentParser(description="Diffuser local test")

# Both the single-dash and double-dash spellings are accepted so existing
# command lines keep working. Empty-string defaults mean "not provided".
opt.add_argument("-prompts", "--prompts", default="", type=str, help="Diffuser prompts")
opt.add_argument("-image", "--image", default="", type=str, help="Init image")
|
if __name__ == '__main__':
    cli_args = opt.parse_args()

    # An empty string tells local_predict there is no init image. When a
    # path is given, the file's raw bytes are base64-encoded and turned into
    # a str so they can be placed in the handler payload.
    image_b64 = ""
    if cli_args.image:
        with open(cli_args.image, "rb") as fh:
            raw_bytes = fh.read()
        image_b64 = base64.b64encode(raw_bytes).decode()

    local_predict(cli_args.prompts, image_b64)
|
|