HirCoir's picture
Update download-model.py
7024203 verified
raw
history blame
1.81 kB
import os
import base64
from huggingface_hub import HfApi, hf_hub_download
# Obtener el token de la variable de entorno y descodificarlo
token_base64 = os.getenv("TOKEN")
token = base64.b64decode(token_base64).decode("utf-8")
# Obtener el repo_id de la variable de entorno
repo_id = os.getenv("REPO_ID")
# Crear una instancia de HfApi para listar los archivos del repositorio
api = HfApi()
archivos = api.list_repo_files(repo_id=repo_id, token=token)
# Filtrar archivos por formato ".ckpt"
archivos_ckpt = [archivo for archivo in archivos if archivo.endswith(".ckpt")]
if not archivos_ckpt:
print("No se encontraron archivos .ckpt en el repositorio.")
else:
# Función para extraer la época de un archivo
def get_epoch(archivo: str) -> int:
try:
epoch_str = archivo.split("-")[0].split("=")[1]
return int(epoch_str)
except (IndexError, ValueError):
print(f"El archivo {archivo} no tiene el formato esperado. Ignorando.")
return float('-inf')
# Obtener el archivo con la mayor "epoch"
archivo_con_mayor_epoch = max(archivos_ckpt, key=get_epoch)
# Descargar el archivo con la mayor "epoch" como "model.ckpt"
ruta_archivo_ckpt = hf_hub_download(
repo_id=repo_id,
filename=archivo_con_mayor_epoch,
token=token,
local_dir=".",
)
# Renombrar el archivo descargado a "model.ckpt"
os.rename(ruta_archivo_ckpt, "model.ckpt")
print(f"Archivo {archivo_con_mayor_epoch} descargado y renombrado como model.ckpt.")
# Descargar el archivo config.json desde el repositorio
ruta_archivo_config = hf_hub_download(
repo_id=repo_id,
filename="config.json",
token=token,
local_dir=".",
)
print(f"Archivo config.json descargado a la ruta {ruta_archivo_config}.")