|
import ast |
|
import base64 |
|
import binascii |
|
import os |
|
import traceback |
|
from typing import Any, Optional, Union |
|
|
|
import httpx |
|
|
|
import litellm |
|
from litellm._logging import print_verbose, verbose_logger |
|
from litellm.caching.caching import DualCache |
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler |
|
from litellm.proxy._types import KeyManagementSystem |
|
|
|
# Module-level cache for OIDC tokens fetched over HTTP by get_secret (the
# "google" and "github" providers). Entries are keyed by the full
# "oidc/<provider>/<argument>" secret name and written with a TTL.
oidc_cache = DualCache()
|
|
|
|
|
|
|
|
|
|
|
def _is_base64(s): |
|
try: |
|
return base64.b64encode(base64.b64decode(s)).decode() == s |
|
except binascii.Error: |
|
return False |
|
|
|
|
|
def str_to_bool(value: Optional[str]) -> Optional[bool]:
    """
    Interpret a string as a boolean when it spells one.

    Surrounding whitespace is ignored and matching is case-insensitive. Only the
    words "true" and "false" are recognized.

    :param value: The string to interpret, or None.
    :return: True for "true", False for "false", and None for anything else
        (including a None input).
    """
    if value is None:
        return None

    recognized = {"true": True, "false": False}
    return recognized.get(value.strip().lower())
|
|
|
|
|
def get_secret_str(
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
) -> Optional[str]:
    """
    Look up a secret via `get_secret` and narrow the result to `str`.

    Any non-string result is reported as None, so callers and type checkers
    only ever see a string or None. This includes a bool that `get_secret`
    parsed from a "true"/"false" value.
    """
    looked_up = get_secret(secret_name=secret_name, default_value=default_value)
    return looked_up if isinstance(looked_up, str) else None
|
|
|
|
|
def get_secret_bool(
    secret_name: str,
    default_value: Optional[bool] = None,
) -> Optional[bool]:
    """
    Look up a secret via `get_secret` and narrow the result to a boolean.

    Args:
        secret_name: The name of the secret to get.
        default_value: Passed through to `get_secret`. As written there, it is
            only returned when the lookup raises, not when the secret is absent.

    Returns:
        The secret as a bool. A bool result is returned unchanged. Any other
        value is parsed with `str_to_bool` ("true"/"false", case-insensitive).
        A missing or unrecognized value yields None.
    """
    raw = get_secret(secret_name, default_value)
    if isinstance(raw, bool):
        return raw
    if raw is None:
        return None
    return str_to_bool(raw)
|
|
|
|
|
def get_secret(
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
):
    """
    Resolve a secret by name from an OIDC provider, the configured secret
    manager, or the process environment.

    Resolution order:
    1. A leading "os.environ/" prefix is stripped from `secret_name`.
    2. Names of the form "oidc/<provider>/<argument>" are resolved as OIDC
       tokens. Supported providers are google, circleci, circleci_v2, github,
       azure, file, env and env_path; any other provider raises ValueError.
       Google and GitHub tokens are cached in `oidc_cache`.
    3. If `_should_read_secret_from_secret_manager()` is True and
       `litellm.secret_manager_client` is set, the secret is read through the
       key manager named by `litellm._key_management_system`. Names not listed
       in `hosted_keys` (when that list is set) are read from the local
       environment instead. Any failure in this path is logged and falls back
       to `os.getenv(secret_name)`. String results that `ast.literal_eval`
       parses to a bool ("True"/"False") are returned as that bool.
    4. Otherwise the value is read from `os.environ`. "true"/"false"
       (case-insensitive, whitespace-trimmed) are returned as bools.

    :param secret_name: The secret's name. It may be prefixed with
        "os.environ/", or be an "oidc/<provider>/<argument>" reference.
    :param default_value: Returned only if an exception escapes steps 3-4.
        It is NOT returned when the secret is simply missing (that yields
        None), and it does not apply to OIDC lookups (step 2), whose errors
        propagate.
    :return: The secret as a str, a bool for recognized boolean strings, or
        None if it is not set.
    :raises ValueError: For unsupported or failed OIDC lookups, or for missing
        OIDC environment variables.
    """
    key_management_system = litellm._key_management_system
    key_management_settings = litellm._key_management_settings
    secret = None

    # "os.environ/<NAME>" is an explicit reference to environment variable NAME.
    if secret_name.startswith("os.environ/"):
        secret_name = secret_name.replace("os.environ/", "")

    # OIDC references look like "oidc/<provider>/<argument>". Everything after
    # the provider's "/" is kept as `oidc_aud`. Depending on the provider, this
    # is an audience, a file path, or an env var name. A reference without a
    # second "/" fails to unpack and raises ValueError.
    if secret_name.startswith("oidc/"):
        secret_name_split = secret_name.replace("oidc/", "")
        oidc_provider, oidc_aud = secret_name_split.split("/", 1)

        if oidc_provider == "google":
            # Identity token for `oidc_aud` from the Google metadata server
            # (URL below). It is cached under the full "oidc/..." name for 3540s,
            # presumably to expire before a one-hour token lifetime — confirm.
            oidc_token = oidc_cache.get_cache(key=secret_name)
            if oidc_token is not None:
                return oidc_token

            # NOTE(review): a new HTTPHandler is built on every cache miss and
            # is never closed here. Confirm whether it should be closed or reused.
            oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))

            response = oidc_client.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
                params={"audience": oidc_aud},
                headers={"Metadata-Flavor": "Google"},
            )
            if response.status_code == 200:
                oidc_token = response.text
                oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
                return oidc_token
            else:
                raise ValueError("Google OIDC provider failed")
        elif oidc_provider == "circleci":
            # The token is read directly from CIRCLE_OIDC_TOKEN.
            # `oidc_aud` is not used for this provider.
            env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
            if env_secret is None:
                raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
            return env_secret
        elif oidc_provider == "circleci_v2":
            # Same as "circleci", but reads CIRCLE_OIDC_TOKEN_V2.
            env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
            if env_secret is None:
                raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
            return env_secret
        elif oidc_provider == "github":
            # GitHub Actions: request an ID token for `oidc_aud` from the URL
            # and bearer token in the two ACTIONS_ID_TOKEN_* env vars. Both must
            # be set. The result is cached for 295s.
            actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
            actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
            if (
                actions_id_token_request_url is None
                or actions_id_token_request_token is None
            ):
                raise ValueError(
                    "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
                )

            oidc_token = oidc_cache.get_cache(key=secret_name)
            if oidc_token is not None:
                return oidc_token

            # NOTE(review): as with "google", this HTTPHandler is never closed here.
            oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
            response = oidc_client.get(
                actions_id_token_request_url,
                params={"audience": oidc_aud},
                headers={
                    "Authorization": f"Bearer {actions_id_token_request_token}",
                    "Accept": "application/json; api-version=2.0",
                },
            )
            if response.status_code == 200:
                # A body without "value" yields None and that None is returned.
                # Cache hits require a non-None value, so such a result is
                # effectively not cached.
                oidc_token = response.json().get("value", None)
                oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
                return oidc_token
            else:
                raise ValueError("Github OIDC provider failed")
        elif oidc_provider == "azure":
            # Read the federated token from the file named by
            # AZURE_FEDERATED_TOKEN_FILE. The contents are returned unstripped.
            azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
            if azure_federated_token_file is None:
                raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
            with open(azure_federated_token_file, "r") as f:
                oidc_token = f.read()
                return oidc_token
        elif oidc_provider == "file":
            # `oidc_aud` is the path of a file whose (unstripped) contents are
            # the token.
            with open(oidc_aud, "r") as f:
                oidc_token = f.read()
                return oidc_token
        elif oidc_provider == "env":
            # `oidc_aud` is the name of an env var that holds the token itself.
            oidc_token = os.getenv(oidc_aud)
            if oidc_token is None:
                raise ValueError(f"Environment variable {oidc_aud} not found")
            return oidc_token
        elif oidc_provider == "env_path":
            # `oidc_aud` names an env var whose value is the path to a token file.
            token_file_path = os.getenv(oidc_aud)
            if token_file_path is None:
                raise ValueError(f"Environment variable {oidc_aud} not found")
            with open(token_file_path, "r") as f:
                oidc_token = f.read()
                return oidc_token
        else:
            raise ValueError("Unsupported OIDC provider")

    try:
        if (
            _should_read_secret_from_secret_manager()
            and litellm.secret_manager_client is not None
        ):
            try:
                client = litellm.secret_manager_client
                key_manager = "local"
                if key_management_system is not None:
                    key_manager = key_management_system.value

                # When `hosted_keys` is set, only the names listed there go to
                # the secret manager. Every other name is read locally.
                if key_management_settings is not None:
                    if (
                        key_management_settings.hosted_keys is not None
                        and secret_name not in key_management_settings.hosted_keys
                    ):
                        key_manager = "local"

                # Azure Key Vault is used when selected explicitly, or when it
                # is detected from the client's fully-qualified class name.
                if (
                    key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
                    or type(client).__module__ + "." + type(client).__name__
                    == "azure.keyvault.secrets._client.SecretClient"
                ):
                    secret = client.get_secret(secret_name).value
                elif (
                    key_manager == KeyManagementSystem.GOOGLE_KMS.value
                    or client.__class__.__name__ == "KeyManagementServiceClient"
                ):
                    # The env var `secret_name` must hold base64 ciphertext.
                    # It is decrypted with `litellm._google_kms_resource_name`.
                    # NOTE(review): on any error here, the fallback below returns
                    # the raw env value (still ciphertext) as the secret.
                    encrypted_secret: Any = os.getenv(secret_name)
                    if encrypted_secret is None:
                        raise ValueError(
                            "Google KMS requires the encrypted secret to be in the environment!"
                        )
                    b64_flag = _is_base64(encrypted_secret)
                    if b64_flag is True:
                        encrypted_secret = base64.b64decode(encrypted_secret)
                        ciphertext = encrypted_secret
                    else:
                        raise ValueError(
                            "Google KMS requires the encrypted secret to be encoded in base64"
                        )
                    response = client.decrypt(
                        request={
                            "name": litellm._google_kms_resource_name,
                            "ciphertext": ciphertext,
                        }
                    )
                    secret = response.plaintext.decode(
                        "utf-8"
                    )
                elif key_manager == KeyManagementSystem.AWS_KMS.value:
                    """
                    Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
                    """
                    # NOTE(review): no 'aws_kms/' prefix check is visible in this
                    # function. The string above may describe a check made by callers.
                    # The env var `secret_name` holds base64 ciphertext. The
                    # decrypted plaintext is whitespace-stripped. As with Google
                    # KMS, an error here makes the fallback return the raw
                    # ciphertext.
                    encrypted_value = os.getenv(secret_name, None)
                    if encrypted_value is None:
                        raise Exception(
                            "AWS KMS - Encrypted Value of Key={} is None".format(
                                secret_name
                            )
                        )

                    ciphertext_blob = base64.b64decode(encrypted_value)

                    params = {"CiphertextBlob": ciphertext_blob}

                    response = client.decrypt(**params)

                    plaintext = response["Plaintext"]
                    secret = plaintext.decode("utf-8")
                    if isinstance(secret, str):
                        secret = secret.strip()
                elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
                    from litellm.secret_managers.aws_secret_manager_v2 import (
                        AWSSecretsManagerV2,
                    )

                    # Only AWSSecretsManagerV2 clients are handled. Any other
                    # client leaves `secret` as None.
                    if isinstance(client, AWSSecretsManagerV2):
                        secret = client.sync_read_secret(secret_name=secret_name)
                        # NOTE(review): this writes the secret value to verbose output.
                        print_verbose(f"get_secret_value_response: {secret}")
                elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
                    try:
                        secret = client.get_secret_from_google_secret_manager(
                            secret_name
                        )
                        # NOTE(review): this writes the secret value to verbose output.
                        print_verbose(f"secret from google secret manager: {secret}")
                        # A missing secret is turned into an error. The handler
                        # below then falls back to os.getenv(secret_name).
                        if secret is None:
                            raise ValueError(
                                f"No secret found in Google Secret Manager for {secret_name}"
                            )
                    except Exception as e:
                        print_verbose(f"An error occurred - {str(e)}")
                        raise e
                elif key_manager == KeyManagementSystem.HASHICORP_VAULT.value:
                    try:
                        secret = client.sync_read_secret(secret_name=secret_name)
                        # A missing secret is turned into an error. The handler
                        # below then falls back to os.getenv(secret_name).
                        if secret is None:
                            raise ValueError(
                                f"No secret found in Hashicorp Secret Manager for {secret_name}"
                            )
                    except Exception as e:
                        print_verbose(f"An error occurred - {str(e)}")
                        raise e
                elif key_manager == "local":
                    secret = os.getenv(secret_name)
                else:
                    # Any other client is assumed to expose
                    # get_secret(name).secret_value — confirm for the remaining
                    # KeyManagementSystem values.
                    secret = client.get_secret(secret_name).secret_value
            except Exception as e:
                # A failure anywhere in the secret-manager path is logged with
                # its traceback, and the env var of the same name is used instead.
                verbose_logger.error(
                    f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
                )
                secret = os.getenv(secret_name)
            # Only strings that are Python bool literals ("True"/"False") become
            # bools here. This differs from the env branch below, which uses
            # str_to_bool ("true"/"false", any case). literal_eval only evaluates
            # literals and does not run arbitrary code; anything it rejects is
            # returned as the original string.
            # NOTE(review): a non-str `secret` (including None) skips every
            # return here, so the function falls through to an implicit
            # `return None`. Confirm this is intended for non-str manager results.
            try:
                if isinstance(secret, str):
                    secret_value_as_bool = ast.literal_eval(secret)
                    if isinstance(secret_value_as_bool, bool):
                        return secret_value_as_bool
                    else:
                        return secret
            except Exception:
                return secret
        else:
            # Plain environment lookup. "true"/"false" (any case, trimmed)
            # are returned as bools; anything else is returned as-is (or None).
            secret = os.environ.get(secret_name)
            secret_value_as_bool = str_to_bool(secret) if secret is not None else None
            if secret_value_as_bool is not None and isinstance(
                secret_value_as_bool, bool
            ):
                return secret_value_as_bool
            else:
                return secret
    except Exception as e:
        # `default_value` applies only when an exception escapes the code above.
        # A secret that is simply missing returns None, not `default_value`.
        if default_value is not None:
            return default_value
        else:
            raise e
|
|
|
|
|
def _should_read_secret_from_secret_manager() -> bool:
    """
    Decide whether secrets should be read through the configured secret manager.

    Returns True only when both of these hold:
    - `litellm.secret_manager_client` is set, and
    - `litellm._key_management_settings` is set with an `access_mode` of
      "read_only" or "read_and_write".
    Otherwise returns False.
    """
    if litellm.secret_manager_client is None:
        return False

    settings = litellm._key_management_settings
    if settings is None:
        return False

    return settings.access_mode in ("read_only", "read_and_write")
|
|