Trying to get the API working, but it is not working yet
- experimental/clip_api_app.py +33 -33
- experimental/clip_api_app_client.py +35 -29
experimental/clip_api_app.py
CHANGED
@@ -1,12 +1,9 @@
-
-import json
-import os
+from typing import List
 import numpy as np
 import torch
-from starlette.requests import Request
-from PIL import Image
 import ray
 from ray import serve
+from PIL import Image
 from clip_retrieval.load_clip import load_clip, get_tokenizer
 # from clip_retrieval.clip_client import ClipClient, Modality
 
@@ -24,11 +21,14 @@ class CLIPTransform:
 
         print ("using device", self.device)
 
-
-
+    @serve.batch(max_batch_size=32)
+    # def text_to_embeddings(self, prompts: List[str]) -> torch.Tensor:
+    def text_to_embeddings(self, prompts: List[str]) -> List[np.ndarray]:
+        text = self.tokenizer(prompts).to(self.device)
         with torch.no_grad():
             prompt_embededdings = self.model.encode_text(text)
             prompt_embededdings /= prompt_embededdings.norm(dim=-1, keepdim=True)
+            prompt_embededdings = prompt_embededdings.cpu().numpy().tolist()
         return(prompt_embededdings)
 
     def image_to_embeddings(self, input_im):
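The substantive change in this hunk is @serve.batch: instead of encoding one prompt per call, Ray Serve gathers up to max_batch_size concurrent handle calls, passes their single arguments to the method as one list, and routes element i of the returned list back to caller i, so one encode_text() forward pass serves the whole batch. A minimal sketch of that contract with a stand-in deployment (EchoBatcher is illustrative, not part of this commit); note that in the Ray Serve releases I know of, @serve.batch is documented for async def methods, so the plain def above may be one reason things are "not working yet":

from typing import List

from ray import serve

@serve.deployment
class EchoBatcher:
    @serve.batch(max_batch_size=32)
    async def upper(self, prompts: List[str]) -> List[str]:
        # prompts holds one argument per pending caller; the return value
        # must be a list of equal length, one element routed back per caller.
        return [p.upper() for p in prompts]

echo_app = EchoBatcher.bind()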
@@ -45,31 +45,31 @@
         image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
         return(image_embeddings)
 
-    async def __call__(self, http_request: Request) -> str:
-        request = await http_request.json()
-        # print(type(request))
-        # print(str(request))
-        # switch based if we are using text or image
-        embeddings = None
-        if "text" in request:
-            prompt = request["text"]
-            embeddings = self.text_to_embeddings(prompt)
-        elif "image" in request:
-            image_url = request["image_url"]
-            # download image from url
-            import requests
-            from io import BytesIO
-            input_image = Image.open(BytesIO(image_url))
-            input_image = input_image.convert('RGB')
-            input_image = np.array(input_image)
-            embeddings = self.image_to_embeddings(input_image)
-        elif "preprocessed_image" in request:
-            prepro = request["preprocessed_image"]
-            # create torch tensor on the device
-            prepro = torch.tensor(prepro).to(self.device)
-            embeddings = self.preprocessed_image_to_emdeddings(prepro)
-        else:
-            raise Exception("Invalid request")
-        return embeddings.cpu().numpy().tolist()
+    # async def __call__(self, http_request: Request) -> str:
+    #     request = await http_request.json()
+    #     # print(type(request))
+    #     # print(str(request))
+    #     # switch based if we are using text or image
+    #     embeddings = None
+    #     if "text" in request:
+    #         prompt = request["text"]
+    #         embeddings = self.text_to_embeddings(prompt)
+    #     elif "image" in request:
+    #         image_url = request["image_url"]
+    #         # download image from url
+    #         import requests
+    #         from io import BytesIO
+    #         input_image = Image.open(BytesIO(image_url))
+    #         input_image = input_image.convert('RGB')
+    #         input_image = np.array(input_image)
+    #         embeddings = self.image_to_embeddings(input_image)
+    #     elif "preprocessed_image" in request:
+    #         prepro = request["preprocessed_image"]
+    #         # create torch tensor on the device
+    #         prepro = torch.tensor(prepro).to(self.device)
+    #         embeddings = self.preprocessed_image_to_emdeddings(prepro)
+    #     else:
+    #         raise Exception("Invalid request")
+    #     return embeddings.cpu().numpy().tolist()
 
 deployment_graph = CLIPTransform.bind()
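With the HTTP __call__ handler commented out, CLIPTransform is now reachable only through a Serve handle. A hedged sketch of driving the bound graph directly (assumes the Ray 2.x-era API that CLIPTransform.bind() implies; the import path follows this repo's layout, the prompt is arbitrary):

import ray
from ray import serve

from experimental.clip_api_app import deployment_graph  # assumed import path

ray.init()
handle = serve.run(deployment_graph)  # deploys CLIPTransform, returns a handle
ref = handle.text_to_embeddings.remote("a photo of a cat")
print(ray.get(ref))  # the normalized CLIP text embedding as a nested Python list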
experimental/clip_api_app_client.py
CHANGED
@@ -1,10 +1,10 @@
-
-from
-import json
-import os
-import requests
-from concurrent.futures import ThreadPoolExecutor, as_completed
+import ray
+from ray import serve
 import time
+import asyncio
+
+# Create a Semaphore object
+semaphore = asyncio.Semaphore(10)
 
 test_image_url = "https://static.wixstatic.com/media/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg/v1/fill/w_454,h_333,fp_0.50_0.50,q_90/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg"
 english_text = (
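The module-level asyncio.Semaphore(10) is what bounds the load here: however many request tasks get created below, at most ten may hold the semaphore (that is, be in flight to the deployment) at once. A self-contained illustration of the pattern, with sleep standing in for the remote call (all names illustrative):

import asyncio

async def worker(sem: asyncio.Semaphore, i: int) -> int:
    async with sem:               # at most 3 workers inside at a time
        await asyncio.sleep(0.1)  # stand-in for the call to the deployment
        return i

async def main() -> None:
    sem = asyncio.Semaphore(3)
    results = await asyncio.gather(*(worker(sem, i) for i in range(10)))
    print(results)  # [0, 1, ..., 9], finishing in ~4 waves of 3

asyncio.run(main())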
@@ -12,38 +12,44 @@ english_text = (
     "of wisdom, it was the age of foolishness, it was the epoch of belief"
 )
 
-def
-
-
-
-
-
-
-
-
-
-        n_result, result = future.result()
-        result = json.loads(result)
-        print (f"{n_result} : {len(result[0])}")
-
-# def process_text(numbers, max_workers=10):
-#     for n in numbers:
-#         n_result, result = send_text_request(n)
-#         result = json.loads(result)
-#         print (f"{n_result} : {len(result[0])}")
+async def send_text_request(serve_client, number):
+    async with semaphore:
+        # async_handle = serve_client.get_handle("CLIPTransform", sync=False)
+        async_handle = serve.get_deployment("CLIPTransform").get_handle(sync=False)
+        # async_handle = serve.get_deployment("CLIPTransform").get_handle()
+        embeddings = ray.get(await async_handle.text_to_embeddings.remote(english_text))
+        # embeddings = await async_handle.text_to_embeddings.remote(english_text)
+        # embeddings = async_handle.text_to_embeddings.remote(english_text)
+        # embeddings = await ray.get(embeddings)
+        return number, embeddings
+
+# def process_text(server_client, numbers, max_workers=10):
+#     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+#         futures = [executor.submit(send_text_request, server_client, number) for number in numbers]
+#         for future in as_completed(futures):
+#             n_result, result = future.result()
+#             print (f"{n_result} : {len(result[0])}")
+async def process_text(server_client, numbers):
+    tasks = [send_text_request(server_client, number) for number in numbers]
+    for future in asyncio.as_completed(tasks):
+        n_result, result = await future
+        print (f"{n_result} : {len(result[0])}")
 
 if __name__ == "__main__":
     # n_calls = 100000
-    n_calls =
+    n_calls = 1
     numbers = list(range(n_calls))
+    ray.init()
+    server_client = serve.start(detached=True)
     start_time = time.monotonic()
-
+
+    # Run the async function
+    asyncio.run(process_text(server_client, numbers))
+
     end_time = time.monotonic()
     total_time = end_time - start_time
     avg_time_ms = total_time / n_calls * 1000
     calls_per_sec = n_calls / total_time
     print(f"Average time taken: {avg_time_ms:.2f} ms")
    print(f"Number of calls per second: {calls_per_sec:.2f}")
-
+    ray.shutdown()
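On the call chain in send_text_request: with a sync=False handle in this era of Serve, await handle.method.remote(...) resolves to a Ray ObjectRef, and ray.get(...) then blocks the whole event loop until the result arrives, which undercuts the semaphore-bounded concurrency and may be part of why this is "not working yet". ObjectRefs are themselves awaitable, so a non-blocking variant of the same function (same assumed API, illustrative only) would be:

async def send_text_request(serve_client, number):
    async with semaphore:
        async_handle = serve.get_deployment("CLIPTransform").get_handle(sync=False)
        ref = await async_handle.text_to_embeddings.remote(english_text)
        embeddings = await ref  # await the ObjectRef instead of calling ray.get()
        return number, embeddings

Note also that nothing in this client deploys CLIPTransform: serve.start(detached=True) only starts or attaches to a Serve instance, so the app from clip_api_app.py must already be running for the get_deployment lookup to succeed.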
|