|
import os |
|
import base64 |
|
from huggingface_hub import HfApi, hf_hub_download |
|
|
|
|
|
token_base64 = os.getenv("TOKEN") |
|
token = base64.b64decode(token_base64).decode("utf-8") |
|
|
|
|
|
repo_id = os.getenv("REPO_ID") |
|
|
|
|
|
api = HfApi() |
|
archivos = api.list_repo_files(repo_id=repo_id, token=token) |
|
|
|
|
|
archivos_ckpt = [archivo for archivo in archivos if archivo.endswith(".ckpt")] |
|
|
|
if not archivos_ckpt: |
|
print("No se encontraron archivos .ckpt en el repositorio.") |
|
else: |
|
|
|
def get_epoch(archivo: str) -> int: |
|
try: |
|
epoch_str = archivo.split("-")[0].split("=")[1] |
|
return int(epoch_str) |
|
except (IndexError, ValueError): |
|
print(f"El archivo {archivo} no tiene el formato esperado. Ignorando.") |
|
return float('-inf') |
|
|
|
|
|
archivo_con_mayor_epoch = max(archivos_ckpt, key=get_epoch) |
|
|
|
|
|
ruta_archivo_descargado = hf_hub_download( |
|
repo_id=repo_id, |
|
filename=archivo_con_mayor_epoch, |
|
token=token |
|
) |
|
|
|
|
|
os.rename(ruta_archivo_descargado, "model.ckpt") |
|
|
|
print(f"Archivo {archivo_con_mayor_epoch} descargado y renombrado como model.ckpt.") |
|
|