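"""Minimal client for an (apparently) OpenAI-compatible chat-completions endpoint.

The endpoint URL is read from the AMIGO_BASE_URL environment variable (loaded
via python-dotenv). Requests and responses follow the OpenAI chat schema: a
list of {"role", "content"} messages in, and either a single JSON body or an
SSE "data:" delta stream back, depending on the `stream` flag.
"""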
import json
import os
import uuid
from typing import Dict, Iterator, List, Union

import requests
from dotenv import load_dotenv

load_dotenv()
AVAILABLE_MODELS = [
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    "o1-mini",
    "claude-3-sonnet-20240229",
    "gemini-1.5-pro",
    "gemini-1.5-flash",
    "o1-preview",
    "gpt-4o",
]
def API_Inference(
    messages: List[Dict[str, str]],
    model: str = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    stream: bool = False,
    max_tokens: int = 4000,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> Union[str, Iterator[str], None]:
    if model not in AVAILABLE_MODELS:
        raise ValueError(
            f"Model {model} not available. Available models: {', '.join(AVAILABLE_MODELS)}"
        )
    # The Claude model seems to need a system message on this endpoint:
    # strip any caller-supplied system prompts and insert a minimal placeholder.
    if model == "claude-3-sonnet-20240229":
        messages = [{"role": "system", "content": "."}] + [
            msg for msg in messages if msg["role"] != "system"
        ]
    api_endpoint = os.environ.get("AMIGO_BASE_URL")
    if not api_endpoint:
        raise RuntimeError("AMIGO_BASE_URL environment variable is not set")

    headers = {
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br, zstd",
        # The bearer token is blank in the source; supply one here if the
        # endpoint requires authentication.
        "Authorization": "Bearer ",
        "Content-Type": "application/json",
        "User-Agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/129.0.0.0 Safari/537.36 Edg/129.0.0.0"
        ),
        "X-Device-UUID": str(uuid.uuid4()),
    }
    payload = {
        "messages": messages,
        "model": model,
        "max_tokens": max_tokens,
        "stream": stream,
        "presence_penalty": 0,
        "temperature": temperature,
        "top_p": top_p,
    }
    try:
        response = requests.post(api_endpoint, headers=headers, json=payload, stream=stream)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print("An error occurred while making the request:", e)
        return None
    def process_response() -> Iterator[str]:
        # Parse server-sent events: each payload line looks like "data: {...}".
        for line in response.iter_lines():
            if not line:
                continue
            decoded_line = line.decode("utf-8").strip()
            if not decoded_line.startswith("data: "):
                continue
            data_str = decoded_line[6:]
            if data_str == "[DONE]":
                break
            try:
                data_json = json.loads(data_str)
                # Extract the incremental content from the delta, if any.
                choices = data_json.get("choices", [])
                if choices:
                    content = choices[0].get("delta", {}).get("content", "")
                    if content:
                        yield content
            except json.JSONDecodeError:
                print(f"Received non-JSON data: {data_str}")
    if stream:
        return process_response()
    # With stream=False the server returns a single JSON body, not SSE lines;
    # the OpenAI chat-completions shape assumed here is inferred from the
    # delta format handled above.
    try:
        return response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError):
        print("Unexpected non-streaming response format")
        return None
if __name__ == "__main__":
    # Example conversation in the chat-message format
    conversation = [
        {"role": "system", "content": "You are a helpful and friendly AI assistant."},
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris"},
        {"role": "user", "content": "Who are you? Are you GPT-4o or GPT-3.5?"},
    ]

    # Non-streaming response
    response = API_Inference(conversation, stream=False, model="claude-3-sonnet-20240229")
    print(response)
    print("--" * 50)

    # Streaming response (API_Inference returns None if the request failed)
    stream_iter = API_Inference(conversation, stream=True, model="gpt-4o")
    if stream_iter is not None:
        for chunk in stream_iter:
            print(chunk, end="", flush=True)
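
# To try this script, set AMIGO_BASE_URL in the environment or a .env file,
# e.g. AMIGO_BASE_URL=https://api.example.com/v1/chat/completions (this URL
# is a hypothetical placeholder), then run the file with Python 3.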