"""Call API providers.""" | |
import os | |
import random | |
import time | |
from fastchat.utils import build_logger | |
from fastchat.constants import WORKER_API_TIMEOUT | |
logger = build_logger("gradio_web_server", "gradio_web_server.log") | |


def openai_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
):
    import openai

    # The legacy (pre-1.0) OpenAI SDK is configured through module-level globals.
    openai.api_base = api_base or "https://api.openai.com/v1"
    openai.api_key = api_key or os.environ["OPENAI_API_KEY"]
    if model_name == "gpt-4-turbo":
        model_name = "gpt-4-1106-preview"

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = openai.ChatCompletion.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
        max_tokens=max_new_tokens,
        stream=True,
    )
    # Accumulate streamed deltas and yield the full text generated so far.
    text = ""
    for chunk in res:
        text += chunk["choices"][0]["delta"].get("content", "")
        data = {
            "text": text,
            "error_code": 0,
        }
        yield data
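

# A minimal usage sketch of openai_api_stream_iter, assuming OPENAI_API_KEY is
# set and the legacy openai<1.0 SDK is installed; the model name and prompt
# below are only illustrative.
def _example_openai_stream():
    for data in openai_api_stream_iter(
        model_name="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello."}],
        temperature=0.7,
        top_p=1.0,
        max_new_tokens=64,
    ):
        print(data["text"])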


def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
    import anthropic

    c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = c.completions.create(
        prompt=prompt,
        stop_sequences=[anthropic.HUMAN_PROMPT],
        max_tokens_to_sample=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        model=model_name,
        stream=True,
    )
    # Each streamed chunk carries the next piece of the completion text.
    text = ""
    for chunk in res:
        text += chunk.completion
        data = {
            "text": text,
            "error_code": 0,
        }
        yield data
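

# A minimal usage sketch of anthropic_api_stream_iter, assuming
# ANTHROPIC_API_KEY is set and the prompt follows the SDK's legacy
# HUMAN_PROMPT/AI_PROMPT completions format; the model name is illustrative.
def _example_anthropic_stream():
    import anthropic

    prompt = f"{anthropic.HUMAN_PROMPT} Say hello.{anthropic.AI_PROMPT}"
    for data in anthropic_api_stream_iter(
        "claude-2", prompt, temperature=0.7, top_p=1.0, max_new_tokens=64
    ):
        print(data["text"])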


def init_palm_chat(model_name):
    import vertexai  # pip3 install google-cloud-aiplatform
    from vertexai.preview.language_models import ChatModel

    project_id = os.environ["GCP_PROJECT_ID"]
    location = "us-central1"
    vertexai.init(project=project_id, location=location)

    chat_model = ChatModel.from_pretrained(model_name)
    chat = chat_model.start_chat(examples=[])
    return chat
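

# A minimal sketch of init_palm_chat, assuming GCP_PROJECT_ID names a project
# with Vertex AI enabled and google-cloud-aiplatform installed; the PaLM model
# name is illustrative. The returned ChatSession is what palm_api_stream_iter
# expects as its first argument.
def _example_init_palm_chat():
    return init_palm_chat("chat-bison@001")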


def palm_api_stream_iter(chat, message, temperature, top_p, max_new_tokens):
    parameters = {
        "temperature": temperature,
        "top_p": top_p,
        "max_output_tokens": max_new_tokens,
    }
    gen_params = {
        "model": "palm-2",
        "prompt": message,
    }
    gen_params.update(parameters)
    logger.info(f"==== request ====\n{gen_params}")

    # The chat API returns the whole response at once, so slice it into chunks
    # and sleep between yields to simulate streaming.
    response = chat.send_message(message, **parameters)
    content = response.text

    pos = 0
    while pos < len(content):
        # This is a fancy way to simulate token generation latency combined
        # with a Poisson process.
        pos += random.randint(10, 20)
        time.sleep(random.expovariate(50))
        data = {
            "text": content[:pos],
            "error_code": 0,
        }
        yield data
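

# A minimal usage sketch chaining the two PaLM helpers above, under the same
# assumptions as _example_init_palm_chat; the prompt is illustrative.
def _example_palm_stream():
    chat = init_palm_chat("chat-bison@001")
    for data in palm_api_stream_iter(
        chat, "Say hello.", temperature=0.7, top_p=1.0, max_new_tokens=64
    ):
        print(data["text"])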