import ast
import base64
import binascii
import os
import traceback
from typing import Any, Optional, Union

import httpx

import litellm
from litellm._logging import print_verbose, verbose_logger
from litellm.caching.caching import DualCache
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.proxy._types import KeyManagementSystem

oidc_cache = DualCache()


######### Secret Manager ############################
# checks if user has passed in a secret manager client
# if passed in then checks the secret there


def _is_base64(s):
    """
    Returns True if `s` survives a base64 decode -> encode round trip unchanged.

    Input that cannot be base64-decoded returns False instead of raising.
    """
    try:
        return base64.b64encode(base64.b64decode(s)).decode() == s
    except (binascii.Error, ValueError):
        # binascii.Error covers bad padding / characters; a plain ValueError is
        # raised when a `str` argument contains non-ASCII characters.
        return False


def str_to_bool(value: Optional[str]) -> Optional[bool]:
    """
    Converts a string to a boolean if it's a recognized boolean string.

    Matching ignores case and surrounding whitespace; only "true" and "false"
    are recognized.

    :param value: The string to be checked.
    :return: True or False if the string is a recognized boolean, otherwise None
        (also None when `value` is None).
    """
    if value is None:
        return None
    return {"true": True, "false": False}.get(value.strip().lower())


def get_secret_str(
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
) -> Optional[str]:
    """
    Guarantees response from 'get_secret' is either string or none.

    Used for fixing linting errors. Any non-string result from `get_secret`
    (e.g. a value parsed to a bool) is returned as None.
    """
    value = get_secret(secret_name=secret_name, default_value=default_value)
    if value is not None and not isinstance(value, str):
        return None
    return value


def get_secret_bool(
    secret_name: str,
    default_value: Optional[bool] = None,
) -> Optional[bool]:
    """
    Guarantees response from 'get_secret' is either boolean or none.

    Used for fixing linting errors.

    Args:
        secret_name: The name of the secret to get.
        default_value: The default value to return if the secret is not found.

    Returns:
        The secret value as a boolean; `default_value` (None unless given) if the
        secret is not found; None if the value is not a recognized boolean string.
    """
    _secret_value = get_secret(secret_name, default_value)
    if _secret_value is None:
        # get_secret only applies default_value on errors, not on a missing
        # secret, so honor the documented default here.
        return default_value
    if isinstance(_secret_value, bool):
        return _secret_value
    return str_to_bool(_secret_value)


def _get_oidc_token(secret_name: str) -> str:  # noqa: PLR0915
    """
    Resolves an ``oidc/<provider>/<audience>`` secret name to an OIDC token.

    Supported providers: ``google``, ``circleci``, ``circleci_v2``, ``github``,
    ``azure``, ``file``, ``env`` and ``env_path``. For ``file``, ``env`` and
    ``env_path`` the ``<audience>`` part is a file path / environment variable
    name rather than an audience. Google and GitHub tokens are cached in
    ``oidc_cache`` keyed by the full ``secret_name``.

    Raises:
        ValueError: unsupported provider, missing environment configuration, or a
            failed / token-less provider response.
        OSError: a token file could not be opened.
    """
    # Strip only the leading "oidc/" and split once: an audience that itself
    # contains "oidc/" (e.g. an issuer URL) stays intact, and names without an
    # audience part (e.g. "oidc/circleci") still parse, with an empty audience.
    oidc_provider, _, oidc_aud = secret_name[len("oidc/") :].partition("/")

    if oidc_provider == "google":
        oidc_token = oidc_cache.get_cache(key=secret_name)
        if oidc_token is not None:
            return oidc_token

        oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
        # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
        response = oidc_client.get(
            "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
            params={"audience": oidc_aud},
            headers={"Metadata-Flavor": "Google"},
        )
        if response.status_code != 200:
            raise ValueError("Google OIDC provider failed")
        oidc_token = response.text
        # cache for just under an hour (3600 - 60 seconds)
        oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
        return oidc_token

    if oidc_provider == "circleci":
        # https://circleci.com/docs/openid-connect-tokens/
        env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
        if env_secret is None:
            raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
        return env_secret

    if oidc_provider == "circleci_v2":
        # https://circleci.com/docs/openid-connect-tokens/
        env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
        if env_secret is None:
            raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
        return env_secret

    if oidc_provider == "github":
        # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
        actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
        actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
        if (
            actions_id_token_request_url is None
            or actions_id_token_request_token is None
        ):
            raise ValueError(
                "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
            )

        oidc_token = oidc_cache.get_cache(key=secret_name)
        if oidc_token is not None:
            return oidc_token

        oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
        response = oidc_client.get(
            actions_id_token_request_url,
            params={"audience": oidc_aud},
            headers={
                "Authorization": f"Bearer {actions_id_token_request_token}",
                "Accept": "application/json; api-version=2.0",
            },
        )
        if response.status_code != 200:
            raise ValueError("Github OIDC provider failed")
        oidc_token = response.json().get("value", None)
        if oidc_token is None:
            # A 200 response without a token is a failure; don't hand back None.
            raise ValueError("Github OIDC provider failed")
        # cache for just under five minutes (300 - 5 seconds)
        oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
        return oidc_token

    if oidc_provider == "azure":
        # https://azure.github.io/azure-workload-identity/docs/quick-start.html
        azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
        if azure_federated_token_file is None:
            raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
        with open(azure_federated_token_file, "r", encoding="utf-8") as f:
            return f.read()

    if oidc_provider == "file":
        # Load token from a file
        with open(oidc_aud, "r", encoding="utf-8") as f:
            return f.read()

    if oidc_provider == "env":
        # Load token directly from an environment variable
        oidc_token = os.getenv(oidc_aud)
        if oidc_token is None:
            raise ValueError(f"Environment variable {oidc_aud} not found")
        return oidc_token

    if oidc_provider == "env_path":
        # Load token from a file path specified in an environment variable
        token_file_path = os.getenv(oidc_aud)
        if token_file_path is None:
            raise ValueError(f"Environment variable {oidc_aud} not found")
        with open(token_file_path, "r", encoding="utf-8") as f:
            return f.read()

    raise ValueError("Unsupported OIDC provider")


def _read_secret_from_secret_manager(  # noqa: PLR0915
    secret_name: str,
    key_management_system: Optional[KeyManagementSystem],
    key_management_settings: Any,
) -> Any:
    """
    Reads `secret_name` through `litellm.secret_manager_client`.

    The backend is selected from `key_management_system`. Azure Key Vault and
    Google KMS are also selected by the client's type. When
    `key_management_settings.hosted_keys` is set and does not list
    `secret_name`, the name is treated as "local" (read from os.environ). The
    Azure/KMS client-type checks still take precedence.

    Returns:
        The secret value, or None if the selected backend produced nothing.

    Raises:
        Whatever the backend raises, plus ValueError/Exception for missing or
        malformed encrypted values. `get_secret` catches these and falls back to
        `os.getenv`.
    """
    client = litellm.secret_manager_client

    key_manager = "local"
    if key_management_system is not None:
        key_manager = key_management_system.value
    if (
        key_management_settings is not None
        and key_management_settings.hosted_keys is not None
        and secret_name not in key_management_settings.hosted_keys
    ):
        # allow user to specify which keys to check in hosted key manager
        key_manager = "local"

    secret: Any = None
    if (
        key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
        or type(client).__module__ + "." + type(client).__name__
        == "azure.keyvault.secrets._client.SecretClient"
    ):
        # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
        secret = client.get_secret(secret_name).value
    elif (
        key_manager == KeyManagementSystem.GOOGLE_KMS.value
        or client.__class__.__name__ == "KeyManagementServiceClient"
    ):
        encrypted_secret = os.getenv(secret_name)
        if encrypted_secret is None:
            raise ValueError(
                "Google KMS requires the encrypted secret to be in the environment!"
            )
        if not _is_base64(encrypted_secret):
            raise ValueError(
                "Google KMS requires the encrypted secret to be encoded in base64"
            )
        ciphertext = base64.b64decode(encrypted_secret)
        # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
        response = client.decrypt(
            request={
                "name": litellm._google_kms_resource_name,
                "ciphertext": ciphertext,
            }
        )
        # assumes the original value was encoded with utf-8
        secret = response.plaintext.decode("utf-8")
    elif key_manager == KeyManagementSystem.AWS_KMS.value:
        # The environment variable holds base64-encoded KMS ciphertext.
        # NOTE(review): an older note said only 'aws_kms/'-prefixed names reach
        # this branch. No such prefix check is visible here; confirm where it is
        # enforced (possibly via hosted_keys).
        encrypted_value = os.getenv(secret_name, None)
        if encrypted_value is None:
            raise Exception(
                "AWS KMS - Encrypted Value of Key={} is None".format(secret_name)
            )
        response = client.decrypt(CiphertextBlob=base64.b64decode(encrypted_value))
        secret = response["Plaintext"].decode("utf-8").strip()
    elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
        from litellm.secret_managers.aws_secret_manager_v2 import (
            AWSSecretsManagerV2,
        )

        if isinstance(client, AWSSecretsManagerV2):
            secret = client.sync_read_secret(secret_name=secret_name)
            # Never write the secret value itself to logs.
            print_verbose(
                f"get_secret_value_response: secret found={secret is not None}"
            )
    elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
        try:
            secret = client.get_secret_from_google_secret_manager(secret_name)
            # Never write the secret value itself to logs.
            print_verbose(
                f"secret from google secret manager: found={secret is not None}"
            )
            if secret is None:
                raise ValueError(
                    f"No secret found in Google Secret Manager for {secret_name}"
                )
        except Exception as e:
            print_verbose(f"An error occurred - {str(e)}")
            raise
    elif key_manager == KeyManagementSystem.HASHICORP_VAULT.value:
        try:
            secret = client.sync_read_secret(secret_name=secret_name)
            if secret is None:
                raise ValueError(
                    f"No secret found in Hashicorp Secret Manager for {secret_name}"
                )
        except Exception as e:
            print_verbose(f"An error occurred - {str(e)}")
            raise
    elif key_manager == "local":
        secret = os.getenv(secret_name)
    else:
        # assume the default is infisicial client
        secret = client.get_secret(secret_name).secret_value
    return secret


def get_secret(
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
):
    """
    Resolves a secret by name.

    Resolution order:
    1. A leading ``os.environ/`` is stripped from ``secret_name``.
    2. Names starting with ``oidc/`` are resolved as OIDC tokens. Errors from this
       step propagate, and ``default_value`` is not applied.
    3. If a secret manager client is configured and the key-management access
       mode allows reading, the secret is read from that manager. Any error
       there is logged and falls back to ``os.getenv``. A string that
       ``ast.literal_eval`` parses to a bool is returned as that bool; other
       values are returned unchanged.
    4. Otherwise the value comes from ``os.environ``, and "true"/"false"
       (case-insensitive) are returned as bools.

    Args:
        secret_name: Name of the secret / environment variable, optionally
            prefixed with ``os.environ/`` or ``oidc/<provider>/<audience>``.
        default_value: Returned only when an exception escapes steps 3-4. A
            merely missing secret yields None, not this value.

    Returns:
        The secret (str, bool, or whatever the secret manager returned), or None
        if it is not found.
    """
    if secret_name.startswith("os.environ/"):
        # Remove only the prefix; str.replace would also strip later occurrences.
        secret_name = secret_name[len("os.environ/") :]

    # Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
    if secret_name.startswith("oidc/"):
        return _get_oidc_token(secret_name)

    try:
        if (
            _should_read_secret_from_secret_manager()
            and litellm.secret_manager_client is not None
        ):
            try:
                secret = _read_secret_from_secret_manager(
                    secret_name=secret_name,
                    key_management_system=litellm._key_management_system,
                    key_management_settings=litellm._key_management_settings,
                )
            except Exception as e:
                # check if it's in os.environ
                verbose_logger.error(
                    f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
                )
                secret = os.getenv(secret_name)

            if isinstance(secret, str):
                try:
                    # Only literal syntax such as "True"/"False" becomes a bool;
                    # anything unparsable is returned as the raw string.
                    secret_value_as_bool = ast.literal_eval(secret)
                except Exception:
                    return secret
                if isinstance(secret_value_as_bool, bool):
                    return secret_value_as_bool
            return secret

        secret = os.environ.get(secret_name)
        secret_value_as_bool = str_to_bool(secret)
        if secret_value_as_bool is not None:
            return secret_value_as_bool
        return secret
    except Exception:
        if default_value is not None:
            return default_value
        raise


def _should_read_secret_from_secret_manager() -> bool:
    """
    Returns True if the secret manager should be used to read the secret, False otherwise

    - If the secret manager client is not set, return False
    - If the `_key_management_settings` access mode is "read_only" or "read_and_write", return True
    - Otherwise, return False
    """
    if (
        litellm.secret_manager_client is None
        or litellm._key_management_settings is None
    ):
        return False
    return litellm._key_management_settings.access_mode in (
        "read_only",
        "read_and_write",
    )