Muhammad Saqib commited on
Commit
8ad9edd
1 Parent(s): 08c3120

Update modules/app.py

Browse files
Files changed (1) hide show
  1. modules/app.py +222 -45
modules/app.py CHANGED
@@ -1,51 +1,228 @@
1
- import os
2
- import requests
 
 
 
 
 
3
  import json
4
- from io import BytesIO
5
 
6
- from fastapi import FastAPI
7
- from fastapi.staticfiles import StaticFiles
8
- from fastapi.responses import FileResponse, StreamingResponse
9
 
10
- from modules.inference import infer_t5
11
- from modules.dataset import query_emotion
 
 
 
12
 
13
- # https://huggingface.co/settings/tokens
14
- # https://huggingface.co/spaces/{username}/{space}/settings
15
- API_TOKEN = os.getenv("BIG_GAN_TOKEN")
 
 
 
16
 
17
- app = FastAPI(docs_url=None, redoc_url=None)
18
-
19
- app.mount("/static", StaticFiles(directory="static"), name="static")
20
-
21
-
22
- @app.head("/")
23
  @app.get("/")
24
- def index() -> FileResponse:
25
- return FileResponse(path="static/index.html", media_type="text/html")
26
-
27
-
28
- @app.get("/infer_biggan")
29
- def biggan(input):
30
- output = requests.request(
31
- "POST",
32
- "https://api-inference.huggingface.co/models/osanseviero/BigGAN-deep-128",
33
- headers={"Authorization": f"Bearer {API_TOKEN}"},
34
- data=json.dumps(input),
35
- )
36
-
37
- return StreamingResponse(BytesIO(output.content), media_type="image/png")
38
-
39
-
40
- @app.get("/infer_t5")
41
- def t5(input):
42
- output = infer_t5(input)
43
-
44
- return {"output": output}
45
-
46
-
47
- @app.get("/query_emotion")
48
- def emotion(start, end):
49
- output = query_emotion(int(start), int(end))
50
-
51
- return {"output": output}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request, Depends
2
+ import aiohttp
3
+ import base64
4
+ import io
5
+ import asyncio
6
+ import random
7
+ import string
8
  import json
 
9
 
10
+ app = FastAPI()
 
 
11
 
12
+ async def verify_user_agent(request: Request):
13
+ user_agent = request.headers.get('User-Agent', None)
14
+ if user_agent != "Vercel Edge Functions":
15
+ raise HTTPException(status_code=403, detail="Access denied")
16
+ return True
17
 
18
+ def generate_hash(length=12):
19
+ # Characters that can appear in the hash
20
+ characters = string.ascii_lowercase + string.digits
21
+ # Generate a random string of the specified length
22
+ hash_string = ''.join(random.choice(characters) for _ in range(length))
23
+ return hash_string
24
 
 
 
 
 
 
 
25
  @app.get("/")
26
+ async def read_root():
27
+ return {"message": "Saqib's API"}
28
+
29
+ @app.post("/whisper", dependencies=[Depends(verify_user_agent)])
30
+ async def whisper(request: Request):
31
+ data = await request.json() # Extracting JSON data from request
32
+ if "audio_url" not in data:
33
+ raise HTTPException(status_code=400, detail="audio_url not found in request")
34
+ url = data["audio_url"]
35
+
36
+ headers = {
37
+ 'Accept': 'application/json, text/plain, */*',
38
+ 'Accept-Language': 'en-US,en;q=0.9',
39
+ 'Cache-Control': 'no-cache',
40
+ 'Connection': 'keep-alive',
41
+ 'Content-Type': 'application/json',
42
+ 'DNT': '1',
43
+ 'Origin': 'https://deepinfra.com',
44
+ 'Pragma': 'no-cache',
45
+ 'Referer': 'https://deepinfra.com/',
46
+ 'Sec-Fetch-Dest': 'empty',
47
+ 'Sec-Fetch-Mode': 'cors',
48
+ 'Sec-Fetch-Site': 'same-site',
49
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0',
50
+ 'sec-ch-ua': '"Chromium";v="124", "Microsoft Edge";v="124", "Not-A.Brand";v="99"',
51
+ 'sec-ch-ua-mobile': '?0',
52
+ 'sec-ch-ua-platform': '"Windows"',
53
+ }
54
+
55
+ # Async HTTP request to get the audio file
56
+ async with aiohttp.ClientSession() as session:
57
+ async with session.get(url) as resp:
58
+ if resp.status != 200:
59
+ return f"Failed to download audio: {resp.status}"
60
+ audio_data = await resp.read()
61
+
62
+ # Encode the audio data to base64
63
+ audio_base64 = base64.b64encode(audio_data).decode("utf-8")
64
+
65
+ json_data = '{"audio": "' + audio_base64 + '"}'
66
+
67
+ # Post request to the API
68
+ async with session.post('https://api.deepinfra.com/v1/inference/openai/whisper-large', headers=headers, data=json_data) as post_resp:
69
+ if post_resp.status != 200:
70
+ return f"API request failed: {post_resp.status}"
71
+ return await post_resp.json()
72
+
73
+ @app.post("/img2location", dependencies=[Depends(verify_user_agent)])
74
+ async def img2location(request: Request):
75
+ request_json = await request.json()
76
+ image_url = request_json.get("image_url", None)
77
+
78
+ if not image_url:
79
+ raise HTTPException(status_code=400, detail="image_url not found in request")
80
+
81
+ def extract_coordinates(text):
82
+ # Split the text into lines
83
+ lines = text.split('\n')
84
+
85
+ # Iterate through each line to find the one that contains the coordinates
86
+ for line in lines:
87
+ try:
88
+ if line.startswith("Coordinates:"):
89
+ # Remove the label and split by comma
90
+ coords = line.replace("Coordinates:", "").strip()
91
+ lat, lon = coords.split(',')
92
+
93
+ # Further split by space to isolate numerical values
94
+ latitude = float(lat.split('°')[0].strip())
95
+ longitude = float(lon.split('°')[0].strip())
96
+
97
+ return latitude, longitude
98
+ except Exception as e:
99
+ print("Error:", e)
100
+ return None
101
+ # Return None if no coordinates are found
102
+ return None
103
+
104
+ headers = {
105
+ 'accept': '*/*',
106
+ 'accept-language': 'en-US,en;q=0.9',
107
+ 'cache-control': 'no-cache',
108
+ 'dnt': '1',
109
+ 'origin': 'https://geospy.ai',
110
+ 'pragma': 'no-cache',
111
+ 'priority': 'u=1, i',
112
+ 'referer': 'https://geospy.ai/',
113
+ 'sec-ch-ua': '"Chromium";v="124", "Microsoft Edge";v="124", "Not-A.Brand";v="99"',
114
+ 'sec-ch-ua-mobile': '?0',
115
+ 'sec-ch-ua-platform': '"Windows"',
116
+ 'sec-fetch-dest': 'empty',
117
+ 'sec-fetch-mode': 'cors',
118
+ 'sec-fetch-site': 'cross-site',
119
+ 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0',
120
+ }
121
+
122
+ async with aiohttp.ClientSession() as session:
123
+ # Fetch the image from the URL
124
+ async with session.get(image_url) as img_response:
125
+ if img_response.status != 200:
126
+ return f"Failed to fetch image: HTTP {img_response.status}"
127
+ image_data = await img_response.read()
128
+
129
+ # Using BytesIO to handle the byte content
130
+ data = aiohttp.FormData()
131
+ data.add_field('image', io.BytesIO(image_data), filename="image.png", content_type='image/png')
132
+
133
+ # Sending the POST request
134
+ async with session.post('https://locate-image-7cs5mab6na-uc.a.run.app/', headers=headers, data=data) as response:
135
+ if response.status != 200:
136
+ return f"Failed to upload image: HTTP {response.status}"
137
+ json_response = await response.json()
138
+
139
+ if 'message' in json_response:
140
+ json_response['message'] = json_response['message'].strip()
141
+ coordinates = extract_coordinates(json_response['message'])
142
+
143
+ if coordinates:
144
+ latitude, longitude = coordinates
145
+ google_maps = f"https://www.google.com/maps/search/?api=1&query={latitude},{longitude}"
146
+
147
+ json_response['message'] += f"\n\nView on Google Maps: {google_maps}"
148
+
149
+ return json_response
150
+
151
+ raise ValueError(f"Unexpected response: {json_response}")
152
+
153
+ @app.post("/pixart-sigma", dependencies=[Depends(verify_user_agent)])
154
+ async def pixart_sigma(request: Request):
155
+ request_json = await request.json()
156
+ prompt = request_json.get("prompt", None)
157
+ negative_prompt = request_json.get("negative_prompt", "")
158
+ style = request_json.get("style", "(No style)")
159
+ use_negative_prompt = request_json.get("use_negative_prompt", True)
160
+ num_imgs = request_json.get("num_imgs", 1)
161
+ seed = request_json.get("seed", 0)
162
+ width = request_json.get("width", 1024)
163
+ height = request_json.get("height", 1024)
164
+ schedule = request_json.get("schedule", "DPM-Solver")
165
+ dpms_guidance_scale = request_json.get("dpms_guidance_scale", 4.5)
166
+ sas_guidance_scale = request_json.get("sas_guidance_scale", 3)
167
+ dpms_inference_steps = request_json.get("dpms_inference_steps", 14)
168
+ sas_inference_steps = request_json.get("sas_inference_steps", 25)
169
+ randomize_seed = request_json.get("randomize_seed", True)
170
+
171
+ hash = generate_hash()
172
+
173
+ headers = {
174
+ 'accept': '*/*'
175
+ }
176
+
177
+ params = {
178
+ '__theme': 'light',
179
+ }
180
+
181
+ json_data = {
182
+ 'data': [
183
+ prompt,
184
+ negative_prompt,
185
+ style,
186
+ use_negative_prompt,
187
+ num_imgs,
188
+ seed,
189
+ width,
190
+ height,
191
+ schedule,
192
+ dpms_guidance_scale,
193
+ sas_guidance_scale,
194
+ dpms_inference_steps,
195
+ sas_inference_steps,
196
+ True,
197
+ ],
198
+ 'event_data': None,
199
+ 'fn_index': 3,
200
+ 'trigger_id': 7,
201
+ 'session_hash': hash,
202
+ }
203
+
204
+ async with aiohttp.ClientSession() as session:
205
+ async with session.post('https://pixart-alpha-pixart-sigma.hf.space/queue/join', params=params, headers=headers, json=json_data) as response:
206
+ print(response.text)
207
+
208
+ params = {
209
+ 'session_hash': hash,
210
+ }
211
+
212
+ async with session.get('https://pixart-alpha-pixart-sigma.hf.space/queue/data', params=params, headers=headers) as response:
213
+ async for line in response.content:
214
+ try:
215
+ if line:
216
+ line = line.decode('utf-8')
217
+ line = line.replace('data: ', '')
218
+ line_json = json.loads(line)
219
+ if line_json["msg"] == "process_completed":
220
+ print(line_json)
221
+ image_url = line_json["output"]["data"][0][0]["image"]["url"]
222
+ return {"image_url": image_url}
223
+ except:
224
+ pass
225
+
226
+ # if __name__ == "__main__":
227
+ # import uvicorn
228
+ # uvicorn.run(app, host="0.0.0.0", port=8000)