import base64
import io
from typing import Any, Union

import jsonlines
import torch
import wandb
from huggingface_hub import HfApi
from PIL import Image


def get_wandb_artifact(
    artifact_name: str,
    artifact_type: str,
    get_metadata: bool = False,
) -> Union[str, tuple[str, dict[str, Any]]]:
    """Download a Weights & Biases artifact and return its local directory.

    When a run is active the artifact is fetched with ``wandb.use_artifact``;
    otherwise it is fetched through the public ``wandb.Api``.

    Args:
        artifact_name: Name of the artifact to download.
        artifact_type: Artifact type passed to ``wandb.use_artifact``. It is
            not forwarded on the ``wandb.Api`` path.
        get_metadata: If true, also return the artifact's metadata.

    Returns:
        The download directory, or ``(directory, metadata)`` when
        ``get_metadata`` is true.
    """
    if wandb.run:
        artifact = wandb.use_artifact(artifact_name, type=artifact_type)
    else:
        artifact = wandb.Api().artifact(artifact_name)
    artifact_dir = artifact.download()

    if get_metadata:
        return artifact_dir, artifact.metadata
    return artifact_dir


def get_torch_backend() -> str:
    """Return the preferred torch device string: "cuda", "mps" or "cpu".

    A backend is chosen only when it is both available at runtime and built
    into the installed torch; CUDA is checked before MPS.
    """
    if torch.cuda.is_available() and torch.backends.cuda.is_built():
        return "cuda"
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return "mps"
    return "cpu"


def base64_encode_image(image: Image.Image, mimetype: str) -> str:
    """Encode a PIL image as a base64 ``data:`` URL.

    Images whose mode is neither "RGB" nor "RGBA" are converted to "RGB"
    before being serialised.

    Args:
        image: The image to encode.
        mimetype: MIME type written into the ``data:`` URL prefix.

    Returns:
        A string of the form ``data:<mimetype>;base64,<payload>``.
    """
    image.load()
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")

    buffer = io.BytesIO()
    # NOTE(review): the payload is always PNG regardless of `mimetype`, so a
    # non-PNG mimetype yields a mislabelled data URL - confirm with callers.
    image.save(buffer, format="PNG")
    payload = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:{mimetype};base64,{payload}"


def read_jsonl_file(file_path: str) -> list[dict[str, Any]]:
    """Return the first record of a JSON Lines file.

    Only the first record is read; any further lines are ignored. If the
    file contains no records, ``None`` is returned.

    NOTE(review): this matches the declared return type only when the first
    line is itself a JSON array of objects - presumably how these files are
    written; confirm against callers before switching to ``list(reader)``.

    Args:
        file_path: Path of the JSON Lines file to read.

    Returns:
        The first decoded record, or ``None`` when there is none.
    """
    with jsonlines.open(file_path) as reader:
        return next(iter(reader), None)


def save_to_huggingface(
    repo_id: str, local_dir: str, commit_message: str, private: bool = False
) -> None:
    """Upload a local folder to a model repository on the Hugging Face Hub.

    The repository is created with ``exist_ok=True`` first, so an existing
    repository is reused.

    Args:
        repo_id: Identifier of the target repository.
        local_dir: Local folder whose contents are uploaded.
        commit_message: Commit message for the upload.
        private: Passed to ``create_repo`` as the ``private`` flag.
    """
    api = HfApi()
    repo_url = api.create_repo(
        repo_id=repo_id,
        token=api.token,
        private=private,
        repo_type="model",
        exist_ok=True,
    )
    # Use the repo id and type reported back by the Hub for the upload.
    api.upload_folder(
        repo_id=repo_url.repo_id,
        commit_message=commit_message,
        token=api.token,
        folder_path=local_dir,
        repo_type=repo_url.repo_type,
    )


def fetch_from_huggingface(repo_id: str, local_dir: str) -> str:
    """Download a snapshot of a Hugging Face Hub repository.

    Args:
        repo_id: Identifier of the repository to download.
        local_dir: Local directory passed to ``snapshot_download``.

    Returns:
        The value returned by ``snapshot_download``.

    Raises:
        ValueError: If ``repo_info`` or ``snapshot_download`` returns ``None``.
    """
    api = HfApi()

    # NOTE(review): huggingface_hub is expected to raise its own error for a
    # missing repository, so the None guards below are probably unreachable;
    # they are kept as a defensive fallback.
    info = api.repo_info(repo_id)
    if info is None:
        raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.")

    snapshot = api.snapshot_download(repo_id, revision=None, local_dir=local_dir)
    if snapshot is None:
        raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.")
    return snapshot