File size: 3,662 Bytes
26e58a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import gc
import gradio as gr
import torch
from huggingface_hub import snapshot_download, HfApi, notebook_login, create_repo, whoami, login

api = HfApi()


def info_fn(text) -> None:
    """Report a progress message by passing ``text`` to Gradio's ``gr.Info``."""
    gr.Info(text)


def warning_fn(text) -> None:
    """Report a non-fatal problem by passing ``text`` to Gradio's ``gr.Warning``."""
    gr.Warning(text)


def upload(hf_token, base_model_name_or_path, peft_model_path, output_dir):
    """Merge a PEFT adapter into its base model and push the result to the Hub.

    The base model is loaded on CPU in float16, the adapter is applied and
    folded into the weights with ``merge_and_unload``, the merged model and the
    base model's tokenizer are saved to the local folder ``output_dir``, and
    that folder is uploaded to the Hub repository of the same name.

    Args:
        hf_token: Hugging Face access token with write permission.
        base_model_name_or_path: Hub id or local path of the base model.
        peft_model_path: Hub id or local path of the adapter.
        output_dir: Used both as the local save folder and as the repo id.

    Returns:
        Whatever ``api.upload_folder`` returns; Gradio shows it in the output
        textbox.

    Raises:
        gr.Error: If any field is blank, or if any step of the load / merge /
            save / upload pipeline fails (the original exception is chained).
    """
    # Validate up front: a blank field would otherwise only surface much later
    # as an opaque download or upload failure, after the expensive model load.
    hf_token = (hf_token or "").strip()
    base_model_name_or_path = (base_model_name_or_path or "").strip()
    peft_model_path = (peft_model_path or "").strip()
    output_dir = (output_dir or "").strip()
    if not all((hf_token, base_model_name_or_path, peft_model_path, output_dir)):
        raise gr.Error(
            "Please fill in the access token, base model, adapter model "
            "and output model name."
        )

    try:
        # NOTE(review): login() stores the token process-wide; presumably fine
        # while requests are handled one at a time -- confirm for this Space.
        login(hf_token)

        device_arg = {"device_map": "cpu"}

        info_fn(f"Loading base model: {base_model_name_or_path}")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name_or_path, torch_dtype=torch.float16, **device_arg
        )

        info_fn(f"Loading PEFT: {peft_model_path}")
        model = PeftModel.from_pretrained(base_model, peft_model_path, **device_arg)

        info_fn("Running merge_and_unload")
        model = model.merge_and_unload()
        tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)

        info_fn("Saving model..")
        model.save_pretrained(output_dir, safe_serialization=True)

        info_fn("Saving tokenizer...")
        tokenizer.save_pretrained(output_dir)

        info_fn(f"Model saved to {output_dir}")

        # Drop every name that still references the weights; deleting only
        # `model` would leave `base_model` holding them during the upload.
        del model, base_model, tokenizer
        gc.collect()

        info_fn("Creating Repo...")
        # exist_ok=True makes an already-existing repo a no-op, while genuine
        # failures (bad token, missing permission) propagate instead of being
        # misreported as "already exists".
        repo_url = api.create_repo(repo_id=output_dir, exist_ok=True)
        info_fn(repo_url.url)

        info_fn("Uploading to hub...")
        # Upload to the repo id that create_repo resolved rather than the raw
        # user input.
        return api.upload_folder(
            folder_path=output_dir,
            repo_id=repo_url.repo_id,
            repo_type="model",
        )

    except Exception as e:
        gc.collect()
        # gr.Error is an exception: it has to be raised to reach the UI.
        # Merely constructing it (as before) displayed nothing.
        raise gr.Error(str(e)) from e


# Markdown shown at the top of the page. Plain string (not an f-string): it has
# no placeholders to interpolate.
INTRODUCTION_TEXT = """
🎯 This Space allows you to merge your LoRA adapters.

## ❓ What is LoRA?

LoRA: Low-Rank Adaptation of Large Language Models allows you to train LLMs at a low cost. LoRA freezes the pre-trained model weights and injects trainable rank decomposition matrices into each layer of the Transformer architecture, greatly reducing the number of trainable parameters for downstream tasks.
You can learn more about LoRA here:

[📝 LoRA: Low-Rank Adaptation of Large Language Models Arxiv](https://arxiv.org/abs/2106.09685)

## 🛠️ How does this space work?

🛠️ The Space's backend mainly runs the transformers and PEFT libraries.

🤖 The code first loads your original model and then your adapter models.

📚 The code merges your adapter weights using the `merge_and_unload` function from the PEFT library.

📤 The code saves your resulting model temporarily and then pushes the resulting model to the hub.

## 🧮 Required RAM

This space is loading the model to RAM without performing any quantization, so the required RAM is high.

You can merge models up to 13B. (If your adapter weights are too large, it might not work.)
"""


# Page layout: title and introduction on top, then the four input fields on the
# left and the result box on the right, followed by the submit button.
with gr.Blocks() as demo:
    gr.Markdown("""<h1 align="center" id="space-title">🚀 Lora Merge</h1>""")
    gr.Markdown(INTRODUCTION_TEXT)

    with gr.Row():
        with gr.Column(scale=1):
            # Mask the write token so it is not displayed in clear text.
            hf_token = gr.Textbox(
                label="Huggingface Write Access Token", type="password"
            )
            base_model_name_or_path = gr.Textbox(label="Base Model")
            peft_model_path = gr.Textbox(label="Adapter Model")
            output_dir = gr.Textbox(label="Output Model Name")

        with gr.Column(scale=1):
            # Receives upload()'s return value, so it must not reuse the
            # "Output Model Name" label of the input field.
            text = gr.Textbox(label="Output", lines=14)

    submit = gr.Button("Merge lora with adapters")
    submit.click(
        fn=upload,
        inputs=[hf_token, base_model_name_or_path, peft_model_path, output_dir],
        outputs=text,
    )


demo.queue()
# show_error=True is passed so that errors are displayed in the browser.
demo.launch(show_error=True)