# med-copilot / llm_calls.py
import json
from typing import Dict, List, Optional

import requests
import simplejson
from openai import OpenAI
from pydantic import BaseModel


class AnswerFormat(BaseModel):
    """Expected structure of the LLM's JSON answer."""

    dataset: List[Dict]
    explanations: str
    references: str
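
# Example of a response that conforms to AnswerFormat (illustrative values
# only; the field contents are placeholders, not data from this application):
# {
#     "dataset": [{"patient_id": "P001", "age": 42}],
#     "explanations": "Rows with missing ages were dropped because ...",
#     "references": "https://example.com/guideline"
# }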


def query_perplexity(
    system_prompt: str,
    user_prompt: str,
    json_data: str,
    api_key: str,
    url: str = "https://api.perplexity.ai/chat/completions",
    model: str = "sonar-pro",
) -> str:
    """Query the Perplexity AI chat completions API for a structured response.

    Args:
        system_prompt (str): System message providing context to the AI.
        user_prompt (str): User's query.
        json_data (str): JSON data representing the current dataset.
        api_key (str): Perplexity AI API key.
        url (str): API endpoint.
        model (str): Perplexity AI model to use.

    Returns:
        str: The JSON-formatted message content returned by the API, or an
        error message if the request fails.
    """
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": f"{system_prompt}\n"
             f"Make sure you add the citations found to the references key"},
            {"role": "user", "content": f"Here is the dataset: {json_data}\n\n"
             f"User query:\n"
             f"{user_prompt}"},
        ],
        "response_format": {
            "type": "json_schema",
            "json_schema": {"schema": AnswerFormat.model_json_schema()},
        },
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    response = requests.post(url, json=payload, headers=headers)
    if response.status_code == 200:
        response_json = response.json()
        return response_json["choices"][0]["message"]["content"]
    else:
        return f"API request failed with status code {response.status_code}, details: {response.text}"


def query_openai(system_prompt: str, user_prompt: str, json_data: str, openai_client: OpenAI) -> str:
    """Query the OpenAI chat completions API for a JSON response.

    Args:
        system_prompt (str): System prompt providing context to the AI.
        user_prompt (str): User's query.
        json_data (str): JSON data representing the current dataset.
        openai_client (OpenAI): OpenAI client instance with the API key set.

    Returns:
        str: JSON message content from the API, or an error message if the
        response contains no choices.
    """
    response = openai_client.chat.completions.create(
        model="gpt-4-turbo",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Here is the dataset: {json_data}"},
            {"role": "user", "content": user_prompt},
        ],
        response_format={"type": "json_object"},
    )
    if len(response.choices) > 0:
        return response.choices[0].message.content
    else:
        return "Bad response from OpenAI"


def validate_llm_response(response: str) -> Optional[dict]:
    """Parse the LLM response string and check that the expected keys are present."""
    # Extract the dict from the JSON string.
    try:
        parsed = json.loads(response)
    except json.JSONDecodeError:
        try:
            parsed = simplejson.loads(response)  # More forgiving JSON parser
        except simplejson.JSONDecodeError:
            return None  # JSON is too broken to fix

    # Validate expected keys (must match the fields declared on AnswerFormat).
    required_keys = {"dataset", "explanations", "references"}
    if not required_keys.issubset(parsed.keys()):
        raise ValueError(f"Missing required keys: {required_keys - parsed.keys()}")
    return parsed  # Return as a structured dictionary
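

# Illustrative usage sketch (not part of the original module): wires the
# Perplexity call and response validation together. The prompts, demo dataset,
# and the PPLX_API_KEY environment-variable name are placeholders, not values
# used by the application.
if __name__ == "__main__":
    import os

    demo_dataset = json.dumps([{"patient_id": "P001", "age": 42}])
    raw = query_perplexity(
        system_prompt="You are a medical data assistant. Respond in JSON.",
        user_prompt="Summarize the dataset and cite your sources.",
        json_data=demo_dataset,
        api_key=os.environ.get("PPLX_API_KEY", ""),
    )
    parsed = validate_llm_response(raw)
    if parsed is None:
        print("Response was not valid JSON:", raw)
    else:
        print(parsed["explanations"])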