nyanko7 commited on
Commit
3937dd0
1 Parent(s): 0a5c46d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import gradio as gr
3
+ import requests
4
+ import random
5
+ import io
6
+ import zipfile
7
+ from PIL import Image
8
+ import os
9
+ import numpy as np
10
+ import json
11
+ import boto3
12
+
13
+
14
+ # Create an S3 client
15
+ s3 = boto3.client('s3')
16
+
17
+ def save_to_s3(image_data, payload, file_name):
18
+ # Define the bucket and the path
19
+ bucket_name = 'dataset-novelai'
20
+ folder_name = datetime.datetime.now().strftime("%Y-%m-%d")
21
+ image_key = f'gradio/{folder_name}/{file_name}.webp'
22
+ payload_key = f'gradio/{folder_name}/{file_name}.json'
23
+
24
+ # Save the image
25
+ image_data.seek(0) # Go to the start of the BytesIO object
26
+ s3.upload_fileobj(image_data, bucket_name, image_key, ExtraArgs={'ContentType': 'image/webp'})
27
+
28
+ # Save the payload
29
+ payload_data = io.BytesIO(payload.encode('utf-8'))
30
+ s3.upload_fileobj(payload_data, bucket_name, payload_key, ExtraArgs={'ContentType': 'application/json'})
31
+
32
+
33
+ # Function to handle the NovelAI API request
34
+ def generate_novelai_image(input_text, quality_tags, seed, negative_prompt, scale, sampler):
35
+ jwt_token = os.environ.get('NAI_API_KEY')
36
+
37
+ # Check if quality tags are provided and append to input
38
+ final_input = input_text
39
+ if quality_tags:
40
+ final_input += ", " + quality_tags
41
+
42
+ # Assign a random seed if seed is -1
43
+ if seed == -1:
44
+ seed = random.randint(0, 2**32 - 1)
45
+
46
+ # Define the API URL
47
+ url = "https://api.novelai.net/ai/generate-image"
48
+
49
+ # Set the headers
50
+ headers = {
51
+ "Authorization": f"Bearer {jwt_token}",
52
+ "Content-Type": "application/json",
53
+ "Origin": "https://novelai.net",
54
+ "Referer": "https://novelai.net/"
55
+ }
56
+
57
+ # Define the payload
58
+ payload = {
59
+ "action": "generate",
60
+ "input": final_input,
61
+ "model": "nai-diffusion-3",
62
+ "parameters": {
63
+ "width": 832,
64
+ "height": 1216,
65
+ "scale": scale,
66
+ "sampler": sampler,
67
+ "steps": 28,
68
+ "n_samples": 1,
69
+ "ucPreset": 0,
70
+ "add_original_image": False,
71
+ "cfg_rescale": 0,
72
+ "controlnet_strength": 1,
73
+ "dynamic_thresholding": False,
74
+ "legacy": False,
75
+ "negative_prompt": negative_prompt,
76
+ "noise_schedule": "native",
77
+ "qualityToggle": True,
78
+ "scale": 5,
79
+ "seed": seed,
80
+ "sm": False,
81
+ "sm_dyn": False,
82
+ "steps": 28,
83
+ "ucPreset": 0,
84
+ "uncond_scale": 1,
85
+ "width": 832
86
+ }
87
+ }
88
+
89
+ # Send the POST request
90
+ response = requests.post(url, json=payload, headers=headers)
91
+
92
+ # Process the response
93
+ if response.headers.get('Content-Type') == 'application/x-zip-compressed':
94
+ zipfile_in_memory = io.BytesIO(response.content)
95
+ with zipfile.ZipFile(zipfile_in_memory, 'r') as zip_ref:
96
+ file_names = zip_ref.namelist()
97
+ if file_names:
98
+ with zip_ref.open(file_names[0]) as file:
99
+ image = Image.open(file)
100
+
101
+ # Prepare to save the image to S3
102
+ buffered = io.BytesIO()
103
+ image.save(buffered, format="WEBP", quality=98)
104
+ file_name = str(int(datetime.datetime.now().timestamp()))
105
+ save_to_s3(buffered, json.dumps(payload, indent=4), file_name)
106
+
107
+ return np.array(image), json.dumps(payload, indent=4)
108
+
109
+ else:
110
+ return "No images found in the zip file.", json.dumps(payload, indent=4)
111
+ else:
112
+ return "The response is not a zip file.", json.dumps(payload, indent=4)
113
+
114
+ # Create Gradio interface
115
+ iface = gr.Interface(
116
+ fn=generate_novelai_image,
117
+ inputs=[
118
+ gr.Textbox(label="Input Text"),
119
+ gr.Textbox(label="Quality Tags", value="best quality, amazing quality, very aesthetic, absurdres"),
120
+ gr.Slider(minimum=-1, maximum=2**32 - 1, step=1, value=-1, label="Seed"),
121
+ gr.Textbox(label="Negative Prompt", value="nsfw, lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, watermark, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]"),
122
+ gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Scale"),
123
+ gr.Dropdown(
124
+ choices=[
125
+ "k_euler", "k_euler_ancestral", "k_dpmpp_2s_ancestral",
126
+ "k_dpmpp_2m", "k_dpmpp_sde", "ddim_v3"
127
+ ],
128
+ value="k_euler",
129
+ label="Sampler"
130
+ )
131
+ ],
132
+ outputs=[
133
+ "image",
134
+ gr.Textbox(label="Submitted Payload")
135
+ ]
136
+ )
137
+
138
+ try:
139
+ iface.launch(share=True)
140
+ except RuntimeError: # use in HF spaces
141
+ iface.launch()