import enum
import heapq
from typing import Optional

from pydantic import BaseModel

from litellm import print_verbose
from litellm.caching.caching import DualCache, RedisCache


class SchedulerCacheKeys(enum.Enum):
    queue = "scheduler:queue"
    default_in_memory_ttl = 5  # cache the queue in-memory for 5s when Redis is used


class DefaultPriorities(enum.Enum):
    High = 0
    Medium = 128
    Low = 255


class FlowItem(BaseModel):
    priority: int  # lower value = higher priority (0 - 255)
    request_id: str
    model_name: str


class Scheduler:
    cache: DualCache

    def __init__(
        self,
        polling_interval: Optional[float] = None,
        redis_cache: Optional[RedisCache] = None,
    ):
        """
        polling_interval: float or null - how often to poll the queue, in seconds. Defaults to 0.03 (30ms).
        """
        self.queue: list = []
        default_in_memory_ttl: Optional[float] = None
        if redis_cache is not None:
            # when a Redis cache is available, keep the in-memory copy short-lived
            # so the queue is re-read from Redis frequently
            default_in_memory_ttl = SchedulerCacheKeys.default_in_memory_ttl.value
        self.cache = DualCache(
            redis_cache=redis_cache, default_in_memory_ttl=default_in_memory_ttl
        )
        self.polling_interval = polling_interval or 0.03

    async def add_request(self, request: FlowItem):
        """Add a request to the priority queue for its model group."""
        # get the queue for the model group
        queue = await self.get_queue(model_name=request.model_name)

        # lower priority value = processed sooner (min-heap)
        heapq.heappush(queue, (request.priority, request.request_id))

        # save the updated queue
        await self.save_queue(queue=queue, model_name=request.model_name)

    async def poll(self, id: str, model_name: str, health_deployments: list) -> bool:
        """
        Return whether the request can be processed.

        Returns:
        - True:
            * If healthy deployments are available
            * OR if the request is at the top of the queue
        - False:
            * If no healthy deployments are available
            * AND the request is not at the top of the queue
        """
        queue = await self.get_queue(model_name=model_name)
        if not queue:
            raise Exception(
                "Incorrectly setup. Queue is invalid. Queue={}".format(queue)
            )

        print_verbose(f"len(health_deployments): {len(health_deployments)}")
        if len(health_deployments) == 0:
            print_verbose(f"queue: {queue}, seeking id={id}")
            # no deployment has capacity - only the request at the top of the heap may proceed
            if queue[0][1] == id:
                # remove the request from the queue
                heapq.heappop(queue)
                print_verbose(f"Popped id: {id}")
                return True
            else:
                return False

        # healthy deployments available - the request can be processed immediately
        return True

    async def peek(self, id: str, model_name: str, health_deployments: list) -> bool:
        """Return whether the id is at the top of the queue, without popping it from the heap."""
        queue = await self.get_queue(model_name=model_name)
        if not queue:
            raise Exception(
                "Incorrectly setup. Queue is invalid. Queue={}".format(queue)
            )

        # check if the id is at the top of the heap
        if queue[0][1] == id:
            return True

        return False

    def get_queue_status(self):
        """Get the status of items in the queue"""
        return self.queue

    async def get_queue(self, model_name: str) -> list:
        """
        Return the queue for that specific model group
        """
        if self.cache is not None:
            _cache_key = "{}:{}".format(SchedulerCacheKeys.queue.value, model_name)
            response = await self.cache.async_get_cache(key=_cache_key)
            if response is None or not isinstance(response, list):
                return []
            return response
        return self.queue

    async def save_queue(self, queue: list, model_name: str) -> None:
        """
        Save the updated queue of the model group
        """
        if self.cache is not None:
            _cache_key = "{}:{}".format(SchedulerCacheKeys.queue.value, model_name)
            await self.cache.async_set_cache(key=_cache_key, value=queue)
        return None
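

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way a caller might enqueue a request
# and poll until it can be processed. The model name, request id, and the
# bare asyncio driver below are assumptions for illustration - they are not
# part of the library, and the litellm Router wires the scheduler up
# differently in practice.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _example() -> None:
        scheduler = Scheduler(polling_interval=0.03)  # in-memory queue, no Redis

        item = FlowItem(
            priority=DefaultPriorities.Medium.value,  # 0 = highest, 255 = lowest
            request_id="req-1",  # hypothetical request id
            model_name="gpt-3.5-turbo",  # hypothetical model group
        )
        await scheduler.add_request(request=item)

        # poll() returns True immediately while healthy deployments exist;
        # otherwise only once "req-1" reaches the top of the heap.
        while not await scheduler.poll(
            id=item.request_id,
            model_name=item.model_name,
            health_deployments=[{"model_name": item.model_name}],
        ):
            await asyncio.sleep(scheduler.polling_interval)

    asyncio.run(_example())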