# PANH's picture
# Update app.py
# 91291fb verified
# raw
# history blame
# No virus
# 4.88 kB
import os
import requests
import tempfile
import shutil
import torch
from pytorch_lightning import LightningModule
from safetensors.torch import save_file
from torch import nn
import gradio as gr
from modelalign import BERTAlignModel
# ===========================
# Utility Functions
# ===========================
def download_checkpoint(url: str, dest_path: str, timeout: float = 60.0,
                        chunk_size: int = 1 << 20):
    """
    Downloads the checkpoint from the specified URL to the destination path.

    The response body is streamed to disk in chunks, so the whole checkpoint
    is never held in memory at once.

    Args:
        url: HTTP(S) URL of the checkpoint file.
        dest_path: Local file path the downloaded bytes are written to.
        timeout: Seconds ``requests`` waits to connect and between received
            bytes (not a limit on the total download time). Without it a
            stalled server would block this call indefinitely.
        chunk_size: Number of bytes read from the response per iteration.

    Returns:
        ``(True, message)`` on success, or ``(False, error message)`` on any
        network/HTTP error or local file-system error.
    """
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(dest_path, 'wb') as f:
                # iter_content (unlike copying response.raw) undoes any
                # gzip/deflate Content-Encoding, so the file holds the
                # actual checkpoint bytes.
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        return True, "Checkpoint downloaded successfully."
    except (requests.RequestException, OSError) as e:
        return False, f"Failed to download checkpoint: {str(e)}"
def initialize_model(model_name: str, device: str = 'cpu'):
    """
    Build a BERTAlignModel for ``model_name`` and prepare it for inference.

    Args:
        model_name: Forwarded to ``BERTAlignModel`` as ``base_model_name``.
        device: Device the model is moved to (defaults to ``'cpu'``).

    Returns:
        ``(True, model)`` when construction succeeds, otherwise
        ``(False, error message)``.
    """
    try:
        align_model = BERTAlignModel(base_model_name=model_name)
        align_model.to(device)
        # Evaluation mode disables training-only behaviour such as dropout.
        align_model.eval()
    except Exception as err:
        return False, f"Failed to initialize model: {str(err)}"
    return True, align_model
def load_checkpoint(model: LightningModule, checkpoint_path: str, device: str = 'cpu'):
    """
    Loads the checkpoint into the model.

    Accepts either a dict holding the weights under a ``'state_dict'`` key
    (the PyTorch Lightning layout) or a bare state dict. Loading is
    non-strict, so keys present on only one side are tolerated, but their
    counts are reported in the returned message. If *none* of the model's
    keys are found in the checkpoint, the load is reported as a failure.
    Otherwise the model would keep its initial weights while the call
    claimed success.

    Args:
        model: Module the weights are loaded into (modified in place).
        checkpoint_path: Path of the checkpoint file on disk.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        ``(True, message)`` on success, ``(False, error message)`` otherwise.
    """
    try:
        # SECURITY NOTE(review): torch.load unpickles the file. In this app the
        # file comes from a user-supplied URL, so a malicious checkpoint can
        # execute arbitrary code. Consider weights_only=True once it is
        # confirmed that the expected checkpoints load with it.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        result = model.load_state_dict(state_dict, strict=False)
    except Exception as e:
        return False, f"Failed to load checkpoint: {str(e)}"

    # load_state_dict(strict=False) returns the keys it could not match.
    missing_keys = list(getattr(result, "missing_keys", []))
    unexpected_keys = list(getattr(result, "unexpected_keys", []))
    model_key_count = len(model.state_dict())

    if model_key_count and len(missing_keys) == model_key_count:
        return False, (
            "Failed to load checkpoint: none of the model's parameters were "
            "found in the checkpoint (does the model name match the checkpoint?)."
        )
    if missing_keys or unexpected_keys:
        return True, (
            f"Checkpoint loaded with {len(missing_keys)} missing and "
            f"{len(unexpected_keys)} unexpected keys."
        )
    return True, "Checkpoint loaded successfully."
def convert_to_safetensors(model: LightningModule, save_path: str):
    """
    Converts the model's state_dict to the safetensors format.

    safetensors' ``save_file`` rejects tensors that share storage with
    another entry (e.g. tied weights) and tensors that are not contiguous.
    Any such tensor is first copied into its own contiguous storage, so every
    key of the state dict is still written and the save does not fail.

    Args:
        model: Module whose ``state_dict()`` is exported.
        save_path: Destination path of the ``.safetensors`` file.

    Returns:
        ``(True, message)`` on success, ``(False, error message)`` otherwise.
    """
    try:
        tensors = {}
        seen_storages = set()
        for name, tensor in model.state_dict().items():
            # untyped_storage() exists on torch >= 2.0; storage() is the
            # older equivalent.
            storage = getattr(tensor, "untyped_storage", tensor.storage)()
            storage_ptr = storage.data_ptr()
            if storage_ptr in seen_storages or not tensor.is_contiguous():
                # Give this tensor its own, contiguous memory.
                tensor = tensor.clone(memory_format=torch.contiguous_format)
            else:
                seen_storages.add(storage_ptr)
            tensors[name] = tensor
        save_file(tensors, save_path)
        return True, "Model converted to SafeTensors successfully."
    except Exception as e:
        return False, f"Failed to convert to SafeTensors: {str(e)}"
# ===========================
# Gradio Interface Function
# ===========================
def _run_conversion_steps(checkpoint_url: str, model_name: str,
                          checkpoint_path: str, safetensors_path: str):
    """Run download -> init -> load -> convert, stopping at the first failure.

    Returns the ``(success, message)`` pair of the failing step, or of the
    final conversion step when every step succeeds.
    """
    # Step 1: Download the checkpoint
    success, message = download_checkpoint(checkpoint_url, checkpoint_path)
    if not success:
        return False, message
    # Step 2: Initialize the model
    success, model_or_msg = initialize_model(model_name)
    if not success:
        return False, model_or_msg
    model = model_or_msg
    # Step 3: Load the checkpoint
    success, message = load_checkpoint(model, checkpoint_path)
    if not success:
        return False, message
    # Step 4: Convert to SafeTensors
    return convert_to_safetensors(model, safetensors_path)


def convert_checkpoint_to_safetensors(checkpoint_url: str, model_name: str):
    """
    Orchestrates the download, loading, conversion, and preparation for download.
    Returns the safetensors file or an error message.

    The downloaded checkpoint lives in a scratch directory that is deleted
    before this function returns. The ``.safetensors`` output is written to a
    separate directory created with ``tempfile.mkdtemp``, which is kept on
    success. The caller (Gradio's File output) reads the returned path after
    this function has returned, so the file must still exist at that point.
    On failure the output directory is removed.

    Args:
        checkpoint_url: URL of the checkpoint; surrounding whitespace is ignored.
        model_name: Base model name for ``initialize_model``; surrounding
            whitespace is ignored.

    Returns:
        ``(path_to_safetensors_file, status message)`` on success, or
        ``(None, error message)`` on failure.
    """
    checkpoint_url = (checkpoint_url or "").strip()
    model_name = (model_name or "").strip()
    if not checkpoint_url:
        return None, "Please provide a checkpoint URL."
    if not model_name:
        return None, "Please provide a model name."

    output_dir = tempfile.mkdtemp(prefix="safetensors_")
    safetensors_path = os.path.join(output_dir, "model.safetensors")
    success = False
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = os.path.join(tmpdir, "model.ckpt")
            success, message = _run_conversion_steps(
                checkpoint_url, model_name, checkpoint_path, safetensors_path
            )
    finally:
        # Keep the output only if conversion succeeded; never leak it otherwise.
        if not success:
            shutil.rmtree(output_dir, ignore_errors=True)

    if not success:
        return None, message
    return safetensors_path, "Conversion successful! Download your SafeTensors file below."
# ===========================
# Gradio Interface Setup
# ===========================
title = "Checkpoint to SafeTensors Converter"
description = """
Convert your PyTorch Lightning .ckpt checkpoints to the secure safetensors format.
**Inputs**:
- **Checkpoint URL**: Direct link to the .ckpt file.
- **Model Name**: Name of the base model (e.g., roberta-base, bert-base-uncased).
**Output**:
- Downloadable safetensors file.
"""

# Input widgets, in the same order as the parameters of
# convert_checkpoint_to_safetensors(checkpoint_url, model_name).
_checkpoint_url_box = gr.Textbox(
    label="Checkpoint URL",
    placeholder="Enter the checkpoint URL here...",
    lines=2,
)
_model_name_box = gr.Textbox(
    label="Model Name",
    placeholder="e.g., roberta-base",
    lines=1,
)

# Output widgets, matching the (file path, status message) return pair.
_output_components = [
    gr.File(label="Download SafeTensors File"),
    gr.Textbox(label="Status"),
]

iface = gr.Interface(
    fn=convert_checkpoint_to_safetensors,
    inputs=[_checkpoint_url_box, _model_name_box],
    outputs=_output_components,
    title=title,
    description=description,
    allow_flagging="never",
)

# ===========================
# Entry point: start the web UI only when run as a script, not on import
# ===========================
if __name__ == "__main__":
    iface.launch()