|
import json
import logging
import random
import re
import sys
import time
from logging.handlers import TimedRotatingFileHandler

import requests
from flask import (
    Flask,
    Response,
    has_request_context,
    jsonify,
    request,
    stream_with_context,
)
|
|
|
# Flask application object; the route decorators and the file-logging setup
# in this module attach to it.
app = Flask(__name__)
|
|
|
|
|
class RequestFormatter(logging.Formatter):
    """Log formatter that stamps each record with details of the current request.

    Before formatting, sets three attributes on every record so the format
    string can reference them: ``url``, ``remote_addr`` and ``token`` (the raw
    ``Authorization`` header, or ``'No Token'`` when it is absent).
    """

    def format(self, record):
        """Attach request details to *record* and return the formatted line.

        Always returns a string, as ``logging.Formatter.format`` must. Records
        emitted outside a Flask request context (startup, background work) get
        placeholder values instead of touching ``request``.
        """
        # Any HTTP method is handled here. Returning None (which the old
        # GET/POST-only check did for HEAD/OPTIONS) makes the handler fail
        # when it writes the line.
        if has_request_context():
            record.url = request.url
            record.remote_addr = request.remote_addr
            # NOTE(review): this writes the caller's raw Authorization header
            # (API key(s)) into the log file; consider masking it.
            record.token = request.headers.get('Authorization', 'No Token')
        else:
            # Reading `request` here would raise "Working outside of request
            # context"; keep the format string's fields resolvable instead.
            record.url = '-'
            record.remote_addr = '-'
            record.token = 'No Token'
        return super().format(record)
|
|
|
# File logging for the Flask app: INFO and above go to app.log, which is
# rotated at midnight with 30 rotated files kept. Each line is prefixed with
# the client address and Authorization header supplied by RequestFormatter.
formatter = RequestFormatter(
    fmt='%(remote_addr)s - - [%(asctime)s] - Token: %(token)s - %(message)s',
    datefmt='%d/%b/%Y %H:%M:%S',
)

handler = TimedRotatingFileHandler(
    filename='app.log',
    when="midnight",
    interval=1,
    backupCount=30,
)
handler.setLevel(logging.INFO)
handler.setFormatter(formatter)

app.logger.setLevel(logging.INFO)
app.logger.addHandler(handler)
|
|
|
|
|
# Public model id -> {"provider": owner reported by /ai/v1/models,
#                     "mapping": model name forwarded to SiliconFlow}.
# Built from a (model_id, provider, upstream_name) table; order is preserved.
MODEL_MAPPING = {
    model_id: {"provider": provider, "mapping": upstream_name}
    for model_id, provider, upstream_name in (
        ("flux.1-schnell", "black-forest-labs", "Pro/black-forest-labs/FLUX.1-schnell"),
        ("flux.1-dev", "black-forest-labs", "black-forest-labs/FLUX.1-dev"),
        ("stable-diffusion-3-5-large", "stabilityai", "stabilityai/stable-diffusion-3-5-large"),
        ("stable-diffusion-2-1", "stabilityai", "stabilityai/stable-diffusion-2-1"),
        ("stable-diffusion-3-medium", "stabilityai", "stabilityai/stable-diffusion-3-medium"),
        ("stable-diffusion-xl-base-1.0", "stabilityai", "stabilityai/stable-diffusion-xl-base-1.0"),
    )
}
|
|
|
|
|
def getAuthCookie(req):
    """Return the request's Authorization header if it uses the Bearer scheme.

    Despite the name, no cookie is read. The raw header value (including its
    ``'Bearer '`` prefix) is returned. The result is ``None`` when the header
    is missing, empty, or does not start with exactly ``'Bearer '``.
    """
    header_value = req.headers.get('Authorization')
    has_bearer = bool(header_value) and header_value.startswith('Bearer ')
    return header_value if has_bearer else None
|
|
|
@app.route('/')
def index():
    """Serve a short HTML usage page for the API.

    Returns an ``(html, 200)`` tuple. The model and aspect-ratio lists are
    rendered from MODEL_MAPPING and RATIO_MAP so the page matches what the
    endpoints accept. Literal angle brackets in the text are written as
    ``&lt;``/``&gt;``. Otherwise browsers parse e.g. ``<ratio>`` as an unknown
    tag and display nothing.
    """
    models_html = "\n".join(
        f"<li><code>{model_id}</code></li>" for model_id in MODEL_MAPPING
    )
    ratios_html = ", ".join(f"<code>{ratio}</code>" for ratio in RATIO_MAP)

    # Plain placeholders (not str.format/f-string) because the CSS below is
    # full of literal braces.
    usage = """
<html>
<head>
<meta charset="utf-8">
<title>Text-to-Image API with SiliconFlow</title>
<style>
body { font-family: Arial, sans-serif; line-height: 1.6; padding: 20px; max-width: 800px; margin: 0 auto; }
h1 { color: #333; }
h2 { color: #666; }
pre { background-color: #f4f4f4; padding: 10px; border-radius: 5px; }
code { font-family: Consolas, monospace; }
</style>
</head>
<body>
<h1>Welcome to the Text-to-Image API with SiliconFlow!</h1>

<h2>Usage:</h2>
<ol>
<li>Send a POST request to <code>/ai/v1/chat/completions</code> with the header <code>Authorization: Bearer &lt;token&gt;</code> (several comma-separated tokens may be given; one is picked at random per request)</li>
<li>Include your prompt in the 'content' field of the last message</li>
<li>Optional parameters:
<ul>
<li><code>-s &lt;ratio&gt;</code>: Set image size ratio (e.g., -s 1:1, -s 16:9). Supported ratios: __RATIOS__. The default, and the fallback for unsupported values, is 16:9</li>
<li><code>-o</code>: Use original prompt without translation or enhancement</li>
</ul>
</li>
<li>Add <code>"stream": true</code> to the request body to receive the reply as server-sent events</li>
</ol>

<h2>Available Models:</h2>
<p>Also listed by <code>GET /ai/v1/models</code>.</p>
<ul>
__MODELS__
</ul>

<h2>Example Request:</h2>
<pre><code>
{
  "model": "flux.1-schnell",
  "messages": [
    {
      "role": "user",
      "content": "A beautiful landscape -s 16:9"
    }
  ]
}
</code></pre>

<p>For more details, please refer to the API documentation.</p>
</body>
</html>
"""
    usage = usage.replace("__RATIOS__", ratios_html).replace("__MODELS__", models_html)
    return usage, 200
|
|
|
@app.route('/ai/v1/models', methods=['GET'])
def get_models():
    """List the models this proxy exposes, in OpenAI ``/v1/models`` list format.

    Requires an Authorization header starting with ``'Bearer '``. Only its
    presence and prefix are checked here; the token itself is not validated.
    Responds 401 without it and 500 on unexpected errors.
    """
    try:
        if not getAuthCookie(request):
            app.logger.info('GET /ai/v1/models - 401 Unauthorized')
            return jsonify({"error": "Unauthorized"}), 401

        # One timestamp for the whole listing.
        created = int(time.time())
        models_list = [
            {
                "id": model_id,
                "object": "model",
                "created": created,
                "owned_by": info["provider"],
                "permission": [],
                "root": model_id,
                "parent": None,
            }
            for model_id, info in MODEL_MAPPING.items()
        ]

        app.logger.info('GET /ai/v1/models - 200 OK')
        return jsonify({"object": "list", "data": models_list})

    except Exception as error:
        # Missing credentials are answered with 401 above, so anything that
        # lands here is a server-side fault rather than an authentication
        # failure. Report it as 500, as handle_request does, and keep the
        # traceback in the log.
        app.logger.exception(f"Error: {str(error)}")
        return jsonify({"error": f"Internal Server Error: {str(error)}"}), 500
|
|
|
def _extract_text_content(content):
    """Return the prompt text held in a chat message's ``content`` value.

    Accepts a plain string, or an OpenAI-style list of content parts. For a
    list, the ``text`` of each ``{"type": "text"}`` part is joined with single
    spaces. Returns ``None`` for any other shape, or for a list with no text
    part.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = [
            part['text']
            for part in content
            if isinstance(part, dict)
            and part.get('type') == 'text'
            and isinstance(part.get('text'), str)
        ]
        if texts:
            return " ".join(texts)
    return None


def _request_image(mapped_model, prompt, image_size, token):
    """Ask SiliconFlow to render *prompt* and return the first image URL.

    Raises ``requests.HTTPError`` on a non-2xx reply (60 s timeout).
    Raises ``ValueError`` when the reply does not contain ``data[0].url``.
    """
    request_body = {
        "model": mapped_model,
        "prompt": prompt,
        "negative_prompt": "",
        "image_size": image_size,
        "batch_size": 1,
        "seed": random.randint(0, 4999999999),
        "num_inference_steps": 20,
        "guidance_scale": 7.5,
    }
    headers = {
        'Authorization': f'Bearer {token}',
        'Content-Type': 'application/json',
    }
    response = requests.post(
        'https://api.siliconflow.cn/v1/images/generations',
        headers=headers,
        json=request_body,
        timeout=60,
    )
    response.raise_for_status()
    response_body = response.json()

    images = response_body.get('data') if isinstance(response_body, dict) else None
    if (
        not isinstance(images, list)
        or not images
        or not isinstance(images[0], dict)
        or 'url' not in images[0]
    ):
        raise ValueError("Unexpected response structure from image generation API")
    return images[0]['url']


@app.route('/ai/v1/chat/completions', methods=['POST'])
def handle_request():
    """Generate an image for the last chat message and reply in OpenAI chat format.

    The body must be a JSON object containing:

    * ``model``: a key of MODEL_MAPPING.
    * ``messages``: a non-empty list whose last entry carries the prompt in
      ``content``, either as a string or as a list of text parts.
    * ``stream`` (optional).

    The prompt may contain ``-s <ratio>`` / ``-o`` flags (see
    extract_params_from_prompt). Unless ``-o`` is given, the prompt is first
    rewritten by translate_and_enhance_prompt.

    Status codes: 400 for malformed input, 401 when the Authorization header
    yields no token, and 500 (with the error text) for upstream or
    unexpected failures.
    """
    try:
        # silent=True yields None for a missing/invalid JSON body instead of
        # raising, so malformed requests get a 400 rather than falling into
        # the generic 500 handler below.
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            app.logger.info("POST /ai/v1/chat/completions - Status: 400 - Bad Request - Body is not a JSON object")
            return jsonify({"error": "Bad Request: Body must be a JSON object"}), 400

        model = body.get('model')
        messages = body.get('messages')
        stream = body.get('stream', False)
        if not model or not isinstance(messages, list) or not messages:
            app.logger.info(f"POST /ai/v1/chat/completions - Status: 400 - Bad Request - Missing required fields")
            return jsonify({"error": "Bad Request: Missing required fields"}), 400

        if model not in MODEL_MAPPING:
            app.logger.info(f"POST /ai/v1/chat/completions - Status: 400 - Bad Request - Model '{model}' not found")
            return jsonify({"error": f"Model '{model}' not found"}), 400
        mapped_model = MODEL_MAPPING[model]['mapping']

        last_message = messages[-1]
        prompt = _extract_text_content(last_message.get('content')) if isinstance(last_message, dict) else None
        if prompt is None:
            app.logger.info("POST /ai/v1/chat/completions - Status: 400 - Bad Request - Last message has no text content")
            return jsonify({"error": "Bad Request: Last message must have text content"}), 400

        image_size, clean_prompt, use_original, size_param = extract_params_from_prompt(prompt)

        random_token = get_random_token(request.headers.get('Authorization'))
        if not random_token:
            app.logger.info(f"POST /ai/v1/chat/completions - Status: 401 - Unauthorized - Invalid or missing Authorization header")
            return jsonify({"error": "Unauthorized: Invalid or missing Authorization header"}), 401

        if use_original:
            enhanced_prompt = clean_prompt
        else:
            enhanced_prompt = translate_and_enhance_prompt(clean_prompt, random_token)

        image_url = _request_image(mapped_model, enhanced_prompt, image_size, random_token)

        unique_id = str(int(time.time() * 1000))
        current_timestamp = int(time.time())
        system_fingerprint = "fp_" + ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=9))
        image_data = {'data': [{'url': image_url}]}

        # Summarise the non-default flags for the log line.
        params = []
        if size_param != "16:9":
            params.append(f"-s {size_param}")
        if use_original:
            params.append("-o")
        params_str = " ".join(params) if params else "no params"

        # NOTE(review): this logs the full upstream API token; consider masking.
        app.logger.info(f'POST /ai/v1/chat/completions - Status: 200 - Token: {random_token} - Model: {mapped_model} - Params: {params_str} - Image URL: {image_url}')

        respond = stream_response if stream else non_stream_response
        return respond(unique_id, image_data, clean_prompt, enhanced_prompt, image_size, current_timestamp, model, system_fingerprint, use_original)

    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone loses.
        app.logger.exception(f"Error: {str(e)}")
        return jsonify({"error": f"Internal Server Error: {str(e)}"}), 500
|
|
|
# Flags only count when they stand alone as whitespace-delimited tokens, so
# hyphenated words such as "state-of-the-art" or "well-organized" are not
# mistaken for "-o". The leading whitespace is part of the match, so removing
# a mid-prompt flag does not leave a double space behind.
_SIZE_FLAG_RE = re.compile(r'(?:^|\s)-s\s+(\S+)')
_ORIGINAL_FLAG_RE = re.compile(r'(?:^|\s)-o(?=\s|$)')


def extract_params_from_prompt(prompt, ratio_map=None, default_ratio="16:9"):
    """Split command-line-style flags off a user prompt.

    Recognised flags (each must be its own whitespace-separated token):

    * ``-s <ratio>``: requested aspect ratio, looked up in *ratio_map*.
    * ``-o``: use the prompt as-is, skipping translation/enhancement.

    Args:
        prompt: Raw text of the user's message.
        ratio_map: Mapping of ratio string -> ``"WxH"`` size. Defaults to the
            module-level ``RATIO_MAP``.
        default_ratio: Ratio used when no ``-s`` flag is given. It is also the
            fallback for ratios missing from *ratio_map*.

    Returns:
        ``(image_size, clean_prompt, use_original, size)``, where:

        * *size* is the ratio string as requested, which may not be a key of
          *ratio_map*.
        * *clean_prompt* is *prompt* with the flags removed. It is stripped
          only when a flag was removed.
    """
    if ratio_map is None:
        ratio_map = RATIO_MAP

    size_match = _SIZE_FLAG_RE.search(prompt)
    if size_match:
        size = size_match.group(1)
        clean_prompt = _SIZE_FLAG_RE.sub('', prompt).strip()
    else:
        size = default_ratio
        clean_prompt = prompt

    use_original = bool(_ORIGINAL_FLAG_RE.search(clean_prompt))
    if use_original:
        clean_prompt = _ORIGINAL_FLAG_RE.sub('', clean_prompt).strip()

    # Unknown ratios silently fall back to the default size.
    image_size = ratio_map.get(size, ratio_map[default_ratio])
    return image_size, clean_prompt, use_original, size
|
|
|
def get_random_token(auth_header):
    """Pick one API token at random from an Authorization header value.

    The header may carry several comma-separated tokens, optionally preceded
    by ``'Bearer '``. Surrounding whitespace is ignored and empty entries are
    dropped. Returns ``None`` when the header is missing/empty or holds no
    token.
    """
    if not auth_header:
        return None

    prefix = 'Bearer '
    raw = auth_header[len(prefix):] if auth_header.startswith(prefix) else auth_header

    candidates = []
    for piece in raw.split(','):
        piece = piece.strip()
        if piece:
            candidates.append(piece)

    return random.choice(candidates) if candidates else None
|
|
|
def translate_and_enhance_prompt(prompt, auth_token):
    """Rewrite *prompt* into an English image-generation prompt via an LLM.

    Sends SYSTEM_ASSISTANT as the system message and *prompt* as the user
    message to SiliconFlow's chat-completions endpoint, using the model
    ``Qwen/Qwen2-72B-Instruct``. Returns the first choice's message text.

    Raises ``requests.HTTPError`` on a non-2xx reply; the request times out
    after 30 seconds.
    """
    chat_messages = [
        {'role': 'system', 'content': SYSTEM_ASSISTANT},
        {'role': 'user', 'content': prompt},
    ]
    reply = requests.post(
        'https://api.siliconflow.cn/v1/chat/completions',
        headers={
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {auth_token}',
        },
        json={'model': 'Qwen/Qwen2-72B-Instruct', 'messages': chat_messages},
        timeout=30,
    )
    reply.raise_for_status()
    first_choice = reply.json()['choices'][0]
    return first_choice['message']['content']
|
|
|
# System prompt that translate_and_enhance_prompt sends as the "system"
# message to the chat model. It is written in Chinese and tells the model to:
#   * act as a Stable Diffusion prompt expert;
#   * expand the user's scene description into a detailed, comma-separated
#     English "positive prompt" built from a prefix (quality tags + style
#     words + effects), a subject and a scene;
#   * optionally raise the weight of key terms with parentheses, e.g.
#     "(flowers:1.5)";
#   * reply with the prompt only (see the worked example at the end).
# NOTE: this text is sent to the model verbatim at runtime.
SYSTEM_ASSISTANT = """作为 Stable Diffusion Prompt 提示词专家,您将从关键词中创建提示,通常来自 Danbooru 等数据库。
提示通常描述图像,使用常见词汇,按重要性排列,并用逗号分隔。避免使用"-"或".",但可以接受空格和自然语言。避免词汇重复。

为了强调关键词,请将其放在括号中以增加其权重。例如,"(flowers)"将'flowers'的权重增加1.1倍,而"(((flowers)))"将其增加1.331倍。使用"(flowers:1.5)"将'flowers'的权重增加1.5倍。只为重要的标签增加权重。

提示包括三个部分:**前缀**(质量标签+风格词+效果器)+ **主题**(图像的主要焦点)+ **场景**(背景、环境)。

* 前缀影响图像质量。像"masterpiece"、"best quality"、"4k"这样的标签可以提高图像的细节。像"illustration"、"lensflare"这样的风格词定义图像的风格。像"bestlighting"、"lensflare"、"depthoffield"这样的效果器会影响光照和深度。

* 主题是图像的主要焦点,如角色或场景。对主题进行详细描述可以确保图像丰富而详细。增加主题的权重以增强其清晰度。对于角色,描述面部、头发、身体、服装、姿势等特征。

* 场景描述环境。没有场景,图像的背景是平淡的,主题显得过大。某些主题本身包含场景(例如建筑物、风景)。像"花草草地"、"阳光"、"河流"这样的环境词可以丰富场景。你的任务是设计图像生成的提示。请按照以下步骤进行操作:

1. 我会发送给您一个图像场景。需要你生成详细的图像描述
2. 图像描述必须是英文,输出为Positive Prompt。

示例:

我发送:二战时期的护士。
您回复只回复:
A WWII-era nurse in a German uniform, holding a wine bottle and stethoscope, sitting at a table in white attire, with a table in the background, masterpiece, best quality, 4k, illustration style, best lighting, depth of field, detailed character, detailed environment.
"""
|
|
|
# Aspect ratio accepted by the "-s <ratio>" prompt flag -> image_size sent to
# the image API, written as "<width>x<height>". Every value's width:height
# equals its key. "4:3" is landscape, and its portrait counterpart is "3:4".
RATIO_MAP = {
    "1:1": "1024x1024",
    "1:2": "1024x2048",
    "3:2": "1536x1024",
    "4:3": "2048x1536",
    "3:4": "1536x2048",
    "16:9": "2048x1152",
    "9:16": "1152x2048",
}
|
|
|
def stream_response(unique_id, image_data, original_prompt, translated_prompt, size, created, model, system_fingerprint, use_original):
    """Wrap generate_stream's events in a ``text/event-stream`` Response.

    stream_with_context keeps the request context available while the
    generator is consumed after this view has returned.
    """
    event_source = generate_stream(
        unique_id, image_data, original_prompt, translated_prompt,
        size, created, model, system_fingerprint, use_original,
    )
    return Response(stream_with_context(event_source), content_type='text/event-stream')
|
|
|
def generate_stream(unique_id, image_data, original_prompt, translated_prompt, size, created, model, system_fingerprint, use_original, delay=0.5):
    """Yield the reply as OpenAI-style ``chat.completion.chunk`` SSE events.

    The image already exists when this runs. The status lines are paced
    *delay* seconds apart to give streaming clients a progressive reply. The
    stream ends with a ``finish_reason: "stop"`` chunk followed by the
    ``data: [DONE]`` sentinel that OpenAI-compatible clients wait for.

    Args:
        unique_id, created, model, system_fingerprint: Copied into every chunk.
        image_data: ``{'data': [{'url': ...}]}``; the first URL is embedded as
            a Markdown image.
        original_prompt: Prompt shown first.
        translated_prompt: Rewritten prompt; shown only if not *use_original*.
        size: Image size string shown to the user.
        use_original: When true, the translated-prompt chunk is skipped.
        delay: Seconds to sleep after each content chunk (0 disables pacing).
    """
    def sse_event(delta, finish_reason=None):
        # One server-sent event carrying a single chat.completion.chunk.
        payload = {
            "id": unique_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "system_fingerprint": system_fingerprint,
            "choices": [{
                "index": 0,
                "delta": delta,
                "logprobs": None,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(payload)}\n\n"

    chunks = [f"原始提示词:\n{original_prompt}\n"]
    if not use_original:
        chunks.append(f"翻译后的提示词:\n{translated_prompt}\n")
    chunks.extend([
        f"图像规格:{size}\n",
        "正在根据提示词生成图像...\n",
        "图像正在处理中...\n",
        "即将完成...\n",
        f"生成成功!\n图像生成完毕,以下是结果:\n\n![生成的图像]({image_data['data'][0]['url']})",
    ])

    for position, chunk in enumerate(chunks):
        if position == 0:
            # OpenAI streams announce the speaker's role in the first delta.
            delta = {"role": "assistant", "content": chunk}
        else:
            delta = {"content": chunk}
        yield sse_event(delta)
        if delay:
            time.sleep(delay)

    yield sse_event({}, finish_reason="stop")
    yield "data: [DONE]\n\n"
|
|
|
def non_stream_response(unique_id, image_data, original_prompt, translated_prompt, size, created, model, system_fingerprint, use_original):
    """Build the complete (non-streaming) OpenAI-style ``chat.completion`` reply.

    The assistant message contains, in order:

    * the original prompt;
    * the translated prompt (omitted when *use_original* is true);
    * the image size;
    * a Markdown link to the first image URL in *image_data*.

    The ``usage`` numbers are character counts of the original prompt and
    the reply text, not model tokens.
    """
    parts = [f"原始提示词:{original_prompt}\n"]
    if not use_original:
        parts.append(f"翻译后的提示词:{translated_prompt}\n")
    image_url = image_data['data'][0]['url']
    parts.append(f"图像规格:{size}\n")
    parts.append("图像生成成功!\n")
    parts.append("以下是结果:\n\n")
    parts.append(f"![生成的图像]({image_url})")
    content = "".join(parts)

    prompt_len = len(original_prompt)
    completion_len = len(content)
    choice = {
        'index': 0,
        'message': {'role': "assistant", 'content': content},
        'finish_reason': "stop",
    }
    return jsonify({
        'id': unique_id,
        'object': "chat.completion",
        'created': created,
        'model': model,
        'system_fingerprint': system_fingerprint,
        'choices': [choice],
        'usage': {
            'prompt_tokens': prompt_len,
            'completion_tokens': completion_len,
            'total_tokens': prompt_len + completion_len,
        },
    })
|
|
|
if __name__ == '__main__':
    # Start Flask's built-in server, listening on all interfaces, port 8000.
    app.run(port=8000, host='0.0.0.0')
|
|