import os
import shutil
import tempfile

import gradio as gr
import requests
import torch
from pytorch_lightning import LightningModule
from safetensors.torch import save_file
from torch import nn

from modelalign import BERTAlignModel

# ===========================
# Utility Functions
# ===========================


def download_checkpoint(url: str, dest_path: str, timeout: float = 60.0):
    """Download the checkpoint at ``url`` and write it to ``dest_path``.

    Args:
        url: Direct link to the checkpoint file.
        dest_path: Local path the downloaded bytes are written to.
        timeout: Seconds to wait for the server to connect / send data
            before giving up, so a stalled server cannot hang the request
            forever.

    Returns:
        A ``(success, message)`` tuple. ``success`` is ``False`` and
        ``message`` describes the error when anything goes wrong.
    """
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(dest_path, 'wb') as f:
                # iter_content() undoes any Content-Encoding (gzip/deflate)
                # applied by the server; copying response.raw directly
                # would write the still-compressed bytes to disk.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
        return True, "Checkpoint downloaded successfully."
    except Exception as e:
        return False, f"Failed to download checkpoint: {str(e)}"


def initialize_model(model_name: str, device: str = 'cpu'):
    """Build a ``BERTAlignModel`` for ``model_name`` in evaluation mode.

    Args:
        model_name: Name of the base model passed to ``BERTAlignModel``.
        device: Device the model is moved to.

    Returns:
        ``(True, model)`` on success, otherwise ``(False, error_message)``.
    """
    try:
        model = BERTAlignModel(base_model_name=model_name)
        model.to(device)
        model.eval()  # Set to evaluation mode
        return True, model
    except Exception as e:
        return False, f"Failed to initialize model: {str(e)}"


def load_checkpoint(model: LightningModule, checkpoint_path: str, device: str = 'cpu'):
    """Load the weights stored at ``checkpoint_path`` into ``model``.

    Args:
        model: Model whose parameters are overwritten in place.
        checkpoint_path: Path of the checkpoint file on disk.
        device: ``map_location`` used when deserializing the tensors.

    Returns:
        A ``(success, message)`` tuple.
    """
    try:
        # SECURITY NOTE: torch.load unpickles the file, and the checkpoint
        # here comes from a user-supplied URL. Unpickling untrusted data can
        # execute arbitrary code; run this tool only in a sandboxed
        # environment or on checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Use the nested 'state_dict' entry when the checkpoint has one;
        # otherwise treat the loaded object itself as the state dict.
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # strict=False tolerates missing/unexpected keys, so a checkpoint
        # that does not match the model loads without raising.
        model.load_state_dict(state_dict, strict=False)
        return True, "Checkpoint loaded successfully."
    except Exception as e:
        return False, f"Failed to load checkpoint: {str(e)}"


def convert_to_safetensors(model: LightningModule, save_path: str):
    """Write the model's ``state_dict`` to ``save_path`` as safetensors.

    Args:
        model: Model whose parameters and buffers are exported.
        save_path: Destination path of the ``.safetensors`` file.

    Returns:
        A ``(success, message)`` tuple.
    """
    try:
        # save_file refuses tensors that share memory (e.g. tied weights)
        # and tensors that are not contiguous. Cloning gives every entry
        # its own contiguous storage so the export keeps all keys.
        state_dict = {
            name: tensor.detach().clone().contiguous()
            for name, tensor in model.state_dict().items()
        }
        save_file(state_dict, save_path)
        return True, "Model converted to SafeTensors successfully."
    except Exception as e:
        return False, f"Failed to convert to SafeTensors: {str(e)}"


# ===========================
# Gradio Interface Function
# ===========================


def convert_checkpoint_to_safetensors(checkpoint_url: str, model_name: str):
    """Download a checkpoint, load it into a model and export safetensors.

    Args:
        checkpoint_url: Direct link to the checkpoint file.
        model_name: Name of the base model used to build the architecture.

    Returns:
        ``(safetensors_path, status_message)`` on success, or
        ``(None, error_message)`` when any step fails.
    """
    # The URL textbox is multi-line, so stray whitespace/newlines are likely.
    checkpoint_url = (checkpoint_url or "").strip()
    model_name = (model_name or "").strip()
    if not checkpoint_url:
        return None, "Please provide a checkpoint URL."
    if not model_name:
        return None, "Please provide a model name."

    # The converted file must outlive this function so the caller can serve
    # it; it therefore goes into its own directory that is only removed when
    # the conversion fails. The downloaded checkpoint, by contrast, lives in
    # a TemporaryDirectory that is always cleaned up.
    output_dir = tempfile.mkdtemp(prefix="safetensors_")
    safetensors_path = os.path.join(output_dir, "model.safetensors")
    converted = False
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = os.path.join(tmpdir, "model.ckpt")

            # Step 1: Download the checkpoint
            success, message = download_checkpoint(checkpoint_url, checkpoint_path)
            if not success:
                return None, message

            # Step 2: Initialize the model
            success, model_or_msg = initialize_model(model_name)
            if not success:
                return None, model_or_msg
            model = model_or_msg

            # Step 3: Load the checkpoint
            success, message = load_checkpoint(model, checkpoint_path)
            if not success:
                return None, message

            # Step 4: Convert to SafeTensors
            success, message = convert_to_safetensors(model, safetensors_path)
            if not success:
                return None, message

        # Step 5: Hand the file over for download
        converted = True
        return safetensors_path, "Conversion successful! Download your SafeTensors file below."
    finally:
        if not converted:
            shutil.rmtree(output_dir, ignore_errors=True)


# ===========================
# Gradio Interface Setup
# ===========================

title = "Checkpoint to SafeTensors Converter"
description = """
Convert your PyTorch Lightning .ckpt checkpoints to the secure safetensors format.

**Inputs**:
- **Checkpoint URL**: Direct link to the .ckpt file.
- **Model Name**: Name of the base model (e.g., roberta-base, bert-base-uncased).

**Output**:
- Downloadable safetensors file.
"""

iface = gr.Interface(
    fn=convert_checkpoint_to_safetensors,
    inputs=[
        gr.Textbox(
            lines=2,
            placeholder="Enter the checkpoint URL here...",
            label="Checkpoint URL"
        ),
        gr.Textbox(
            lines=1,
            placeholder="e.g., roberta-base",
            label="Model Name"
        )
    ],
    outputs=[
        gr.File(label="Download SafeTensors File"),
        gr.Textbox(label="Status")
    ],
    title=title,
    description=description,
    allow_flagging="never"
)

# ===========================
# Launch the Interface
# ===========================

if __name__ == "__main__":
    iface.launch()