api / modules /app.py
Muhammad Saqib
Update modules/app.py
daf8462 verified
raw
history blame
9.35 kB
import base64
import io
import json
import logging
import os
import random
import secrets
import string

import aiohttp
from fastapi import FastAPI, HTTPException, Request, Depends
# FastAPI application instance; every route handler in this module registers on it.
app = FastAPI()

# Upstream service configuration, read once from the environment at import time.
# Any variable that is not set resolves to None.
whisper_origin = os.environ.get("WHISPER_ORIGIN")
whisper_base_url = os.environ.get("WHISPER_BASE_URL")

img2location_name = os.environ.get("IMG2LOCATION_NAME")
img2location_origin = os.environ.get("IMG2LOCATION_ORIGIN")
img2location_base_url = os.environ.get("IMG2LOCATION_BASE_URL")

pixart_sigma_base_url = os.environ.get("PIXART_SIGMA_BASE_URL")

# Exact User-Agent value that clients must send to pass verify_user_agent.
allowed_user_agent = os.environ.get("ALLOWED_USER_AGENT")
async def verify_user_agent(request: Request):
    """FastAPI dependency that admits only requests with the expected User-Agent.

    Args:
        request: Incoming request whose ``User-Agent`` header is checked.

    Returns:
        True when the header exactly equals ``allowed_user_agent``.

    Raises:
        HTTPException: 403 when the header is missing or different, or when
            ``allowed_user_agent`` is unset or empty (the check fails closed).
    """
    user_agent = request.headers.get('User-Agent')
    # Fail closed: with ALLOWED_USER_AGENT unset, a request that omits the
    # header would otherwise compare None != None -> False and be let through.
    if not allowed_user_agent or user_agent != allowed_user_agent:
        raise HTTPException(status_code=403, detail="Access denied")
    return True
def generate_hash(length=12):
    """Return a random identifier made of lowercase ASCII letters and digits.

    Args:
        length: Number of characters to generate (default 12).

    Returns:
        A string of ``length`` characters drawn from ``[a-z0-9]``; the empty
        string when ``length`` is 0 or negative.
    """
    alphabet = string.ascii_lowercase + string.digits
    # The result is used as an upstream queue session id, so draw it from the
    # OS CSPRNG instead of `random`, whose Mersenne Twister output is predictable.
    return ''.join(secrets.choice(alphabet) for _ in range(length))
@app.get("/")
async def read_root():
    """Respond to the API root with a fixed greeting payload."""
    greeting = "Saqib's API"
    return dict(message=greeting)
@app.post("/whisper", dependencies=[Depends(verify_user_agent)])
async def whisper(request: Request):
    """Transcribe remote audio with the upstream Whisper inference API.

    The request body must be a JSON object with an ``audio_url`` key. The
    audio at that URL is downloaded, base64-encoded and POSTed to
    ``{whisper_base_url}/v1/inference/openai/whisper-large``; the upstream
    JSON response is returned as-is.

    Raises:
        HTTPException: 400 if the body is not JSON, lacks ``audio_url``, or the
            audio cannot be downloaded; 502 if the Whisper API call fails.
    """
    try:
        data = await request.json()
    except ValueError:  # malformed body (json.JSONDecodeError is a ValueError)
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    if not isinstance(data, dict) or "audio_url" not in data:
        raise HTTPException(status_code=400, detail="audio_url not found in request")
    url = data["audio_url"]
    headers = {
        'Accept': 'application/json, text/plain, */*',
        'Accept-Language': 'en-US,en;q=0.9',
        'Cache-Control': 'no-cache',
        'Connection': 'keep-alive',
        'Content-Type': 'application/json',
        'DNT': '1',
        'Origin': whisper_origin,
        'Pragma': 'no-cache',
        'Referer': f'{whisper_origin}/',
        'Sec-Fetch-Dest': 'empty',
        'Sec-Fetch-Mode': 'cors',
        'Sec-Fetch-Site': 'same-site',
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0',
        'sec-ch-ua': '"Chromium";v="124", "Microsoft Edge";v="124", "Not-A.Brand";v="99"',
        'sec-ch-ua-mobile': '?0',
        'sec-ch-ua-platform': '"Windows"',
    }
    # NOTE(review): `url` comes straight from the client and is fetched
    # server-side, so any host reachable from this service can be requested
    # (SSRF risk). Consider restricting schemes/hosts.
    async with aiohttp.ClientSession() as session:
        # Failures are raised as HTTP errors so clients can tell them apart
        # from a successful transcription (strings used to be returned with 200).
        try:
            async with session.get(url) as resp:
                if resp.status != 200:
                    raise HTTPException(status_code=400, detail=f"Failed to download audio: {resp.status}")
                audio_data = await resp.read()
        except aiohttp.ClientError as exc:
            raise HTTPException(status_code=400, detail="Failed to download audio") from exc

        audio_base64 = base64.b64encode(audio_data).decode("utf-8")
        # Serializes to the same '{"audio": "..."}' bytes as before, without
        # hand-assembling JSON.
        json_data = json.dumps({"audio": audio_base64})
        try:
            async with session.post(
                f'{whisper_base_url}/v1/inference/openai/whisper-large',
                headers=headers,
                data=json_data,
            ) as post_resp:
                if post_resp.status != 200:
                    raise HTTPException(status_code=502, detail=f"API request failed: {post_resp.status}")
                return await post_resp.json()
        except aiohttp.ClientError as exc:
            raise HTTPException(status_code=502, detail="API request failed") from exc
def _parse_degrees(part):
    """Convert a fragment such as ``"40.71° N"`` or ``"-74.0°"`` to a float.

    Only the text before the first ``°`` is parsed as the number. If the text
    after it starts with ``S`` or ``W``, the value is made negative so the
    hemisphere is not lost.

    Raises:
        ValueError: If the numeric part is not a valid float.
    """
    number, _, suffix = part.partition('°')
    value = float(number.strip())
    # NOTE(review): the upstream text format is not visible here. The old code
    # discarded everything after "°", which drops the sign of S/W coordinates.
    if suffix.strip()[:1].upper() in ('S', 'W'):
        value = -abs(value)
    return value


def _extract_coordinates(text):
    """Return ``(latitude, longitude)`` from the first ``Coordinates:`` line.

    Returns None when no line starts with ``Coordinates:`` or when that line
    cannot be parsed as two comma-separated degree values.
    """
    for line in text.split('\n'):
        if not line.startswith("Coordinates:"):
            continue
        coords = line[len("Coordinates:"):].strip()
        try:
            lat, lon = coords.split(',')
            return _parse_degrees(lat), _parse_degrees(lon)
        except ValueError as exc:
            logging.getLogger(__name__).warning("Could not parse coordinates %r: %s", coords, exc)
            return None
    return None


@app.post("/img2location", dependencies=[Depends(verify_user_agent)])
async def img2location(request: Request):
    """Send a remote image to the upstream img2location service.

    The request body must be a JSON object with an ``image_url`` key. The image
    is downloaded and uploaded as multipart form data to
    ``img2location_base_url``. The upstream JSON is returned with its
    ``message`` stripped and, when a ``Coordinates:`` line can be parsed, a
    Google Maps link appended.

    Raises:
        HTTPException: 400 for a bad body, an image that cannot be fetched, or
            an upstream message mentioning ``img2location_name``; 502 when the
            upload fails or the upstream reply has no string ``message``.
    """
    try:
        request_json = await request.json()
    except ValueError:  # malformed body (json.JSONDecodeError is a ValueError)
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    image_url = request_json.get("image_url") if isinstance(request_json, dict) else None
    if not image_url:
        raise HTTPException(status_code=400, detail="image_url not found in request")

    headers = {
        'accept': '*/*',
        'accept-language': 'en-US,en;q=0.9',
        'cache-control': 'no-cache',
        'dnt': '1',
        'origin': img2location_origin,
        'pragma': 'no-cache',
        'priority': 'u=1, i',
        'referer': f'{img2location_origin}/',
        'sec-ch-ua': '"Chromium";v="124", "Microsoft Edge";v="124", "Not-A.Brand";v="99"',
        'sec-ch-ua-mobile': '?0',
        'sec-ch-ua-platform': '"Windows"',
        'sec-fetch-dest': 'empty',
        'sec-fetch-mode': 'cors',
        'sec-fetch-site': 'cross-site',
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0',
    }
    # NOTE(review): `image_url` is client-controlled and fetched server-side
    # (SSRF risk). Consider restricting schemes/hosts.
    async with aiohttp.ClientSession() as session:
        try:
            async with session.get(image_url) as img_response:
                if img_response.status != 200:
                    raise HTTPException(status_code=400, detail=f"Failed to fetch image: HTTP {img_response.status}")
                image_data = await img_response.read()
        except aiohttp.ClientError as exc:
            raise HTTPException(status_code=400, detail="Failed to fetch image") from exc

        data = aiohttp.FormData()
        data.add_field('image', io.BytesIO(image_data), filename="image.png", content_type='image/png')
        try:
            async with session.post(img2location_base_url, headers=headers, data=data) as response:
                if response.status != 200:
                    raise HTTPException(status_code=502, detail=f"Failed to upload image: HTTP {response.status}")
                json_response = await response.json()
        except aiohttp.ClientError as exc:
            raise HTTPException(status_code=502, detail="Failed to upload image") from exc

    message = json_response.get('message') if isinstance(json_response, dict) else None
    if not isinstance(message, str):
        # Log the payload for diagnosis; do not echo upstream internals to the client.
        logging.getLogger(__name__).warning("Unexpected response: %r", json_response)
        raise HTTPException(status_code=502, detail="Unexpected response from upstream service")

    message = message.strip()
    # Filter on the upstream text *before* appending our own link, so the
    # Google Maps URL can never trigger it. Compare case-insensitively on both
    # sides, and skip the check when IMG2LOCATION_NAME is unset (a None name
    # previously raised TypeError on every request).
    if img2location_name and img2location_name.lower() in message.lower():
        raise HTTPException(status_code=400, detail="We are not allowed to process this image. Please try another one.")

    coordinates = _extract_coordinates(message)
    if coordinates is not None:
        latitude, longitude = coordinates
        google_maps = f"https://www.google.com/maps/search/?api=1&query={latitude},{longitude}"
        message += f"\n\nView on Google Maps: {google_maps}"
    json_response['message'] = message
    return json_response
def _parse_sse_event(raw_line):
    """Parse one line of the upstream ``/queue/data`` event stream.

    Args:
        raw_line: Raw bytes of a single stream line, e.g. ``b'data: {...}'``.

    Returns:
        The decoded JSON object, or None for blank lines, undecodable or
        non-JSON payloads, and JSON values that are not objects.
    """
    try:
        text = raw_line.decode('utf-8').strip()
        # Strip only the leading "data:" field name. The old str.replace also
        # removed "data: " wherever it appeared inside the JSON payload.
        if text.startswith('data:'):
            text = text[len('data:'):].strip()
        if not text:
            return None
        event = json.loads(text)
    except ValueError:  # UnicodeDecodeError and json.JSONDecodeError
        return None
    return event if isinstance(event, dict) else None


@app.post("/pixart-sigma", dependencies=[Depends(verify_user_agent)])
async def pixart_sigma(request: Request):
    """Generate an image with the upstream PixArt-Sigma app via its queue API.

    The request body is a JSON object. ``prompt`` is required; every other
    generation parameter falls back to the default shown below. The job is
    submitted to ``{pixart_sigma_base_url}/queue/join``, and the result is
    read from the ``/queue/data`` event stream for the same session hash.

    Returns:
        ``{"image_url": ...}`` taken from the ``process_completed`` event.

    Raises:
        HTTPException: 400 for a non-JSON body or missing prompt; 502 when the
            upstream rejects the job, fails, or the stream ends without an image.
    """
    try:
        request_json = await request.json()
    except ValueError:  # malformed body (json.JSONDecodeError is a ValueError)
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    prompt = request_json.get("prompt") if isinstance(request_json, dict) else None
    if not prompt:
        raise HTTPException(status_code=400, detail="prompt not found in request")

    # Sent as a positional list in 'data', so the order here is significant.
    generation_args = [
        prompt,
        request_json.get("negative_prompt", ""),
        request_json.get("style", "(No style)"),
        request_json.get("use_negative_prompt", True),
        request_json.get("num_imgs", 1),
        request_json.get("seed", 0),
        request_json.get("width", 1024),
        request_json.get("height", 1024),
        request_json.get("schedule", "DPM-Solver"),
        request_json.get("dpms_guidance_scale", 4.5),
        request_json.get("sas_guidance_scale", 3),
        request_json.get("dpms_inference_steps", 14),
        request_json.get("sas_inference_steps", 25),
        # Previously hard-coded to True, silently ignoring the client's value.
        request_json.get("randomize_seed", True),
    ]
    session_hash = generate_hash()
    headers = {'accept': '*/*'}
    join_payload = {
        'data': generation_args,
        'event_data': None,
        'fn_index': 3,
        'trigger_id': 7,
        'session_hash': session_hash,
    }
    log = logging.getLogger(__name__)

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f'{pixart_sigma_base_url}/queue/join',
                params={'__theme': 'light'},
                headers=headers,
                json=join_payload,
            ) as join_response:
                if join_response.status != 200:
                    raise HTTPException(status_code=502, detail=f"Failed to join queue: HTTP {join_response.status}")

            async with session.get(
                f'{pixart_sigma_base_url}/queue/data',
                params={'session_hash': session_hash},
                headers=headers,
            ) as data_response:
                async for raw_line in data_response.content:
                    event = _parse_sse_event(raw_line)
                    if event is None or event.get("msg") != "process_completed":
                        continue
                    log.debug("process_completed event: %r", event)
                    try:
                        image_url = event["output"]["data"][0][0]["image"]["url"]
                    except (KeyError, IndexError, TypeError):
                        log.warning("process_completed without an image: %r", event)
                        raise HTTPException(status_code=502, detail="Image generation failed") from None
                    return {"image_url": image_url}
    except aiohttp.ClientError as exc:
        raise HTTPException(status_code=502, detail="Upstream request failed") from exc

    # The stream closed without a completion event; previously this fell
    # through and returned None (HTTP 200 with a null body).
    raise HTTPException(status_code=502, detail="Image generation did not complete")
# if __name__ == "__main__":
# import uvicorn
# uvicorn.run(app, host="0.0.0.0", port=8000)