File size: 4,883 Bytes
1f76ea6
 
 
 
 
 
 
 
 
 
2d8296f
b8db24f
1f76ea6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8db24f
 
 
 
 
1f76ea6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8db24f
1f76ea6
 
 
 
b8db24f
1f76ea6
 
 
 
 
b8db24f
1f76ea6
 
 
 
b8db24f
1f76ea6
 
 
b8db24f
1f76ea6
b8db24f
1f76ea6
 
 
 
 
 
 
91291fb
1f76ea6
 
91291fb
 
1f76ea6
 
91291fb
1f76ea6
2d8296f
 
1f76ea6
 
b8db24f
 
 
 
 
 
 
 
 
 
1f76ea6
 
b8db24f
 
1f76ea6
 
 
 
2d8296f
 
1f76ea6
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import requests
import tempfile
import shutil

import torch
from pytorch_lightning import LightningModule
from safetensors.torch import save_file
from torch import nn

import gradio as gr
from modelalign import BERTAlignModel

# ===========================
# Utility Functions
# ===========================

def download_checkpoint(url: str, dest_path: str):
    """
    Download the checkpoint at *url* to *dest_path*.

    Streams the response in chunks so large checkpoints never have to fit
    in memory, and applies connect/read timeouts so a stalled server
    cannot hang the Gradio worker indefinitely.

    Returns:
        (True, success_message) on success,
        (False, error_message) if the request or the write fails.
    """
    try:
        # timeout=(connect, read): the original call had no timeout at all,
        # so an unresponsive host would block this worker forever.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            with open(dest_path, 'wb') as f:
                # iter_content (unlike the raw socket object the original
                # copied from) transparently decodes any Content-Encoding
                # such as gzip, so the bytes on disk are the real .ckpt.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        return True, "Checkpoint downloaded successfully."
    except Exception as e:
        return False, f"Failed to download checkpoint: {str(e)}"

def initialize_model(model_name: str, device: str = 'cpu'):
    """
    Construct a BERTAlignModel for *model_name*, move it to *device*,
    and put it into inference mode.

    Returns a (success, payload) pair: payload is the ready model on
    success, or an error string describing the failure.
    """
    try:
        align_model = BERTAlignModel(base_model_name=model_name)
        align_model.to(device)
        # Inference only — disable dropout / use running batch-norm stats.
        align_model.eval()
    except Exception as exc:
        return False, f"Failed to initialize model: {str(exc)}"
    return True, align_model

def load_checkpoint(model: nn.Module, checkpoint_path: str, device: str = 'cpu'):
    """
    Load the weights stored at *checkpoint_path* into *model*.

    Accepts both Lightning-style checkpoints (a dict carrying a
    'state_dict' key) and raw state_dicts. Loading is non-strict so
    extra or missing keys do not abort the conversion.

    The annotation is widened from LightningModule to nn.Module: only
    ``load_state_dict`` is used, which every nn.Module provides, and
    LightningModule is a subclass — existing callers are unaffected.

    Returns:
        (True, message) on success, (False, error_message) on failure.
    """
    try:
        # SECURITY: the checkpoint comes from a user-supplied URL.
        # torch.load's default (weights_only=False) is equivalent to
        # pickle.load on untrusted data — arbitrary code execution.
        # weights_only=True restricts deserialization to tensors and
        # plain containers, which is all a state_dict needs.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'], strict=False)
        else:
            model.load_state_dict(checkpoint, strict=False)
        return True, "Checkpoint loaded successfully."
    except Exception as e:
        return False, f"Failed to load checkpoint: {str(e)}"

def convert_to_safetensors(model: nn.Module, save_path: str):
    """
    Serialize *model*'s parameters and buffers to *save_path* in the
    safetensors format.

    Tensors are detached and made contiguous first: safetensors'
    ``save_file`` raises on non-contiguous storage, and detaching drops
    any autograd history before serialization.

    The annotation is widened from LightningModule to nn.Module: only
    ``state_dict`` is used, so any nn.Module works; LightningModule is a
    subclass, leaving existing callers unaffected.

    NOTE(review): ``save_file`` still rejects tensors that share storage
    (tied weights); if that arises, the shared entries must be cloned or
    dropped explicitly — confirm whether BERTAlignModel ties weights.

    Returns:
        (True, message) on success, (False, error_message) on failure.
    """
    try:
        state_dict = {
            name: tensor.detach().contiguous()
            for name, tensor in model.state_dict().items()
        }
        save_file(state_dict, save_path)
        return True, "Model converted to SafeTensors successfully."
    except Exception as e:
        return False, f"Failed to convert to SafeTensors: {str(e)}"

# ===========================
# Gradio Interface Function
# ===========================

def convert_checkpoint_to_safetensors(checkpoint_url: str, model_name: str):
    """
    End-to-end pipeline behind the Gradio button: download a .ckpt,
    initialize the model, load the weights, export to safetensors, and
    hand a durable file path back to Gradio.

    Returns:
        (path_to_safetensors_file, status_message) on success,
        (None, error_message) on any failure.
    """
    # BUG FIX: the original used `with tempfile.TemporaryDirectory()` and
    # returned a path *inside* it — the directory (and the converted file)
    # was deleted the instant the function returned, so gr.File tried to
    # serve a file that no longer existed. We now use a scratch dir we
    # clean up ourselves, after copying the result somewhere durable.
    tmpdir = tempfile.mkdtemp()
    try:
        checkpoint_path = os.path.join(tmpdir, "model.ckpt")
        safetensors_path = os.path.join(tmpdir, "model.safetensors")

        # Step 1: Download the checkpoint.
        success, message = download_checkpoint(checkpoint_url, checkpoint_path)
        if not success:
            return None, message

        # Step 2: Initialize the model.
        success, model_or_msg = initialize_model(model_name)
        if not success:
            return None, model_or_msg
        model = model_or_msg

        # Step 3: Load the checkpoint weights.
        success, message = load_checkpoint(model, checkpoint_path)
        if not success:
            return None, message

        # Step 4: Convert to SafeTensors.
        success, message = convert_to_safetensors(model, safetensors_path)
        if not success:
            return None, message

        # Step 5: Copy the result to a standalone temp file that outlives
        # this call, so Gradio can serve it for download.
        fd, out_path = tempfile.mkstemp(suffix=".safetensors")
        os.close(fd)
        shutil.copyfile(safetensors_path, out_path)
        return out_path, "Conversion successful! Download your SafeTensors file below."
    except Exception as e:
        return None, f"Failed to prepare download: {str(e)}"
    finally:
        # The delivered file has already been copied out; always remove
        # the scratch directory, even on the error paths.
        shutil.rmtree(tmpdir, ignore_errors=True)

# ===========================
# Gradio Interface Setup
# ===========================

# UI copy for the Gradio page. `description` is rendered as Markdown
# under the title; both strings are user-facing, so edit with care.
title = "Checkpoint to SafeTensors Converter"
description = """
Convert your PyTorch Lightning .ckpt checkpoints to the secure safetensors format.

**Inputs**:
- **Checkpoint URL**: Direct link to the .ckpt file.
- **Model Name**: Name of the base model (e.g., roberta-base, bert-base-uncased).

**Output**:
- Downloadable safetensors file.
"""

# Wire the converter into a two-input / two-output Gradio form.
# The outputs map positionally onto convert_checkpoint_to_safetensors's
# (file_path_or_None, status_message) return tuple.
# NOTE(review): `allow_flagging` is deprecated in Gradio 4.x in favor of
# `flagging_mode` — confirm the pinned gradio version before upgrading.
iface = gr.Interface(
    fn=convert_checkpoint_to_safetensors,
    inputs=[
        gr.Textbox(
            lines=2, 
            placeholder="Enter the checkpoint URL here...", 
            label="Checkpoint URL"
        ),
        gr.Textbox(
            lines=1, 
            placeholder="e.g., roberta-base", 
            label="Model Name"
        )
    ],
    outputs=[
        gr.File(label="Download SafeTensors File"),
        gr.Textbox(label="Status")
    ],
    title=title,
    description=description,
    allow_flagging="never"
)

# ===========================
# Launch the Interface
# ===========================

# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()