|
""" |
|
This is a file for the AWS Secret Manager Integration |
|
|
|
Relevant issue: https://github.com/BerriAI/litellm/issues/1883 |
|
|
|
Requires: |
|
* `os.environ["AWS_REGION_NAME"]`
|
* `pip install boto3>=1.28.57` |
|
""" |
|
|
|
import ast |
|
import base64 |
|
import os |
|
import re |
|
from typing import Any, Dict, Optional |
|
|
|
import litellm |
|
from litellm.proxy._types import KeyManagementSystem |
|
|
|
|
|
def validate_environment():
    """
    Check that the AWS region is configured in the process environment.

    Raises:
        ValueError: if `AWS_REGION_NAME` is not set in `os.environ`.
    """
    if "AWS_REGION_NAME" in os.environ:
        return
    raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
|
|
|
|
def load_aws_kms(use_aws_kms: Optional[bool]):
    """
    Create a boto3 KMS client and register it on the `litellm` module.

    When enabled, the client is stored in `litellm.secret_manager_client` and
    `litellm._key_management_system` is set to `KeyManagementSystem.AWS_KMS`.

    Args:
        use_aws_kms: Nothing is done when this is `None` or `False`.

    Returns:
        None in all cases.

    Raises:
        ImportError: if boto3 is not installed.
        ValueError: if `AWS_REGION_NAME` is missing from the environment.
        Exception: any error raised by `boto3.client` propagates unchanged.
    """
    if use_aws_kms is None or use_aws_kms is False:
        return

    # Imported lazily so boto3 is only needed when KMS is actually enabled.
    import boto3

    validate_environment()

    kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))

    litellm.secret_manager_client = kms_client
    litellm._key_management_system = KeyManagementSystem.AWS_KMS
|
|
|
|
|
class AWSKeyManagementService_V2:
    """
    V2 Clean Class for decrypting keys from AWS KeyManagementService

    Construction validates the environment and creates a boto3 KMS client;
    `decrypt_value` then decrypts base64-encoded ciphertext that is stored in
    an environment variable.
    """

    def __init__(self) -> None:
        self.validate_environment()
        self.kms_client = self.load_aws_kms(use_aws_kms=True)

    def validate_environment(
        self,
    ):
        """
        Check that the region and a license entry are present in the environment.

        Raises:
            ValueError: if `AWS_REGION_NAME` is not set, or if neither
                `LITELLM_LICENSE` nor `LITELLM_SECRET_AWS_KMS_LITELLM_LICENSE`
                is set.
        """
        if "AWS_REGION_NAME" not in os.environ:
            raise ValueError("Missing required environment variable - AWS_REGION_NAME")

        # Only the presence of a license variable is checked here, not its value.
        license_env_names = (
            "LITELLM_LICENSE",
            "LITELLM_SECRET_AWS_KMS_LITELLM_LICENSE",
        )
        if all(os.getenv(name) is None for name in license_env_names):
            raise ValueError(
                "AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your envionment."
            )

    def load_aws_kms(self, use_aws_kms: Optional[bool]):
        """
        Build a boto3 KMS client for the region named by `AWS_REGION_NAME`.

        Args:
            use_aws_kms: No client is created when this is `None` or `False`.

        Returns:
            The boto3 KMS client, or None when `use_aws_kms` is `None`/`False`.

        Raises:
            ImportError: if boto3 is not installed.
            ValueError: if `AWS_REGION_NAME` is missing from the environment.
            Exception: any error raised by `boto3.client` propagates unchanged.
        """
        if use_aws_kms is None or use_aws_kms is False:
            return None

        # Imported lazily so boto3 is only needed when a client is requested.
        import boto3

        # This is the module-level region check, not the method on this class
        # (which additionally requires a license variable).
        validate_environment()

        return boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))

    def decrypt_value(self, secret_name: str) -> Any:
        """
        Decrypt the ciphertext stored in the environment variable `secret_name`.

        The variable's value is base64-decoded (after removing any "aws_kms/"
        marker) and passed to the KMS client's `decrypt` call as
        `CiphertextBlob`. The returned plaintext is decoded as UTF-8 and
        stripped of surrounding whitespace.

        Args:
            secret_name: Name of the environment variable holding the ciphertext.

        Returns:
            `True`/`False` when the decrypted text is the Python literal for a
            bool, otherwise the decrypted text as a `str`.

        Raises:
            ValueError: if no KMS client is available.
            Exception: if the environment variable is not set.
        """
        if self.kms_client is None:
            raise ValueError("kms_client is None")

        encrypted_value = os.getenv(secret_name, None)
        if encrypted_value is None:
            raise Exception(
                "AWS KMS - Encrypted Value of Key={} is None".format(secret_name)
            )

        # os.getenv returned a str here, so no isinstance check is needed.
        if encrypted_value.startswith("aws_kms/"):
            encrypted_value = encrypted_value.replace("aws_kms/", "")

        ciphertext_blob = base64.b64decode(encrypted_value)

        response = self.kms_client.decrypt(CiphertextBlob=ciphertext_blob)

        secret = response["Plaintext"].decode("utf-8").strip()

        # Best-effort bool detection: most secrets are not Python literals, so a
        # parse failure simply means "return the text as-is". Only the errors
        # ast.literal_eval is documented to raise are treated that way.
        try:
            parsed = ast.literal_eval(secret)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return secret

        if isinstance(parsed, bool):
            return parsed
        return secret
|
|
|
|
|
""" |
|
- look for all values in the env with `aws_kms/<hashed_key>` |
|
- decrypt keys |
|
- rewrite env var with decrypted key. Note: this environment variable will only be available to the current process and any child processes spawned from it. Once the Python script ends, the environment variable will not persist.
|
""" |
|
|
|
|
|
def decrypt_env_var() -> Dict[str, Any]:
    """
    Decrypt every KMS-marked environment variable and return the results.

    A variable is selected when its name starts with "litellm_secret_aws_kms"
    (case-insensitive) or its value starts with "aws_kms/". The returned key
    is the variable name with "litellm_secret_aws_kms_" removed
    (case-insensitive). `os.environ` itself is not modified by this function.

    Returns:
        Mapping of cleaned variable name to its decrypted value.
    """
    kms_service = AWSKeyManagementService_V2()

    decrypted: Dict[str, Any] = {}
    for env_name, env_value in os.environ.items():
        name_is_marked = isinstance(env_name, str) and env_name.lower().startswith(
            "litellm_secret_aws_kms"
        )
        value_is_marked = isinstance(env_value, str) and env_value.startswith(
            "aws_kms/"
        )
        if not (name_is_marked or value_is_marked):
            continue

        plain_value = kms_service.decrypt_value(secret_name=env_name)

        # Drop the marker from the name so callers see the plain variable name.
        cleaned_name = re.sub(
            "litellm_secret_aws_kms_", "", env_name, flags=re.IGNORECASE
        )
        decrypted[cleaned_name] = plain_value

    return decrypted
|
|