
Song Finder: Fine-Tuning and Using the Gemma-2B Model with LoRA on Hugging Face


Project Overview

This project demonstrates how to fine-tune the Gemma-2B model for question answering (QA) about songs and how to query the result. Fine-tuning uses LoRA (Low-Rank Adaptation) to reduce memory usage. The training data is a custom-generated set of question-answer pairs built from song lyrics and metadata; it was originally stored as CSV and converted to JSON for fine-tuning. The fine-tuned model can then generate answers to new questions written in the same format.


Dataset Details

1. Data Collection:

  • The dataset was collected by web crawling song lyrics websites, yielding approximately 100,000 song entries. Each entry contains the song lyrics along with metadata such as the song ID and genre ID.

2. Preprocessing:

  • After data collection, we handled missing values to clean the dataset. At this stage, only the fields that could serve as answers (A) were retained.

3. Q-A Pair Generation:

  • Since only answer (A) information was available, corresponding question (Q) information was generated through the following process:
    1. Data Transformation to Tags: Song data was transformed into tags that represent metadata, genre, and other characteristics.
    2. Extracting Tags for Q Generation: Relevant tags were extracted to create questions (Q) corresponding to the answers (A).
    3. Creating Varied Wording Lists: For each tag, lists of possible wording variations were created to generate a variety of question forms. This ensures that the model is trained with diverse question formats for each answer.
    4. Phonological Variations: For song lyrics, different phonological transformations (e.g., consonant/vowel shifts, casual speech) were applied to simulate real-world variations in the questions generated.
    5. OpenAI API for Q-A Generation: After these transformations, the OpenAI API was used to generate the questions (Q), with song metadata and lyrics provided as context, and each generated question was paired with its answer (A). These Q-A pairs were then stored in a JSON file for model fine-tuning.

4. Data Format:

  • The final question-answer pairs were stored in a JSON file, with each question under "q_script" and its answer under "a_script":

    [
      ...
      {
        "q_script": ...,
        "a_script": ...
      },
      {
        "q_script": "What is the name of the song that was released in 2022 with the lyrics 'You got me for days'?",
        "a_script": "The song is 'You Got Me' by Alan Walker, released in 2022."
      },
      {
        "q_script": ...,
        "a_script": ...
      }
      ...
    ]
    
  • The dataset was originally collected in CSV format and converted to JSON to simplify loading it for fine-tuning the Gemma-2B model.
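The question-generation steps described in item 3 can be sketched roughly as follows. This is a minimal illustration, not the project's pipeline: the tag names, wording lists, and casual-speech rewrites are hypothetical, the example is in English, and a fixed sentence template stands in for the OpenAI API call of step 5. It emits records with the same "q_script"/"a_script" keys as the JSON format above.

```python
import json
import random
import re

# Hypothetical wording variations per tag (steps 1-3); the project's real lists are not published.
WORDINGS = {
    "year": ["released in {value}", "that came out around {value}", "from {value}"],
    "genre": ["in the {value} genre", "that sounds like {value}"],
    "artist": ["by {value}", "probably sung by {value}"],
}

# Hypothetical casual-speech rewrites standing in for the phonological variations (step 4).
CASUAL_SPEECH = {"going to": "gonna", "want to": "wanna", "them": "'em"}


def song_to_tags(song: dict) -> dict:
    """Step 1: keep only the non-empty metadata fields that have wording lists."""
    return {key: value for key, value in song.items() if key in WORDINGS and value}


def add_casual_speech(text: str, rng: random.Random, p: float = 0.5) -> str:
    """Step 4 (simplified): apply each casual-speech rewrite with probability p."""
    for formal, casual in CASUAL_SPEECH.items():
        if rng.random() < p:
            text = re.sub(rf"\b{re.escape(formal)}\b", casual, text)
    return text


def make_qa_pair(song: dict, rng: random.Random, n_clues: int = 2) -> dict:
    """Build one {"q_script", "a_script"} record from a song record."""
    tags = song_to_tags(song)
    # Step 2: pick which tags the question will mention.
    chosen = rng.sample(sorted(tags), k=min(n_clues, len(tags)))
    # Step 3: render each chosen tag with a randomly selected wording.
    clues = [rng.choice(WORDINGS[tag]).format(value=tags[tag]) for tag in chosen]
    lyric = add_casual_speech(song["lyric_snippet"], rng)
    question = f"What is the song with the lyrics '{lyric}'"
    if clues:
        question += ", " + ", ".join(clues)
    question += "?"
    answer = f"The song is '{song['title']}' by {song['artist']}, released in {song['year']}."
    return {"q_script": question, "a_script": answer}


# Example with a made-up song record.
song = {
    "title": "Hypothetical Song",
    "artist": "Example Artist",
    "year": 2016,
    "genre": "pop",
    "lyric_snippet": "we are going to dance all night",
}
print(json.dumps(make_qa_pair(song, random.Random(42)), ensure_ascii=False, indent=2))
```

Calling make_qa_pair several times per song with different random seeds is one way to get the multiple question wordings per answer that step 3 describes.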


Key Features:

  • Gemma-2B Model: A pre-trained causal language model that generates answers based on input questions.
  • LoRA Integration: LoRA reduces the number of trainable parameters, optimizing memory usage for fine-tuning large models like Gemma-2B.
  • Custom Dataset: The dataset consists of question-answer pairs created from song metadata, lyrics, and other related information.
  • Complex Queries: This project demonstrates how to use the model to handle complex queries, such as identifying a song or answering questions about specific lyrics or metadata.

Files Included

  • Training Script: A Python script for fine-tuning the Gemma-2B model using Keras and LoRA.
  • Dataset File: JSON file of question-answer pairs used for training.
  • Inference Example: Example prompts to query the fine-tuned model and generate answers.

How to Run

1. Install the Required Libraries:

Ensure you have the required libraries installed:

!pip install transformers keras-nlp pandas datasets huggingface_hub

2. Load the Dataset

The original dataset was in CSV format but was converted into JSON for fine-tuning. The following code demonstrates how to load the CSV file and convert it to a JSON file:

import json

import pandas as pd

# Load the CSV dataset (expected to have "q_script" and "a_script" columns,
# the keys read by the fine-tuning step)
csv_file = "qa_dataset.csv"
df = pd.read_csv(csv_file)

# Convert CSV data to JSON format
json_file = "qa_dataset.json"
data = df.to_dict(orient="records")

# Save the data as a JSON file
with open(json_file, "w", encoding="utf-8") as f:
    json.dump(data, f, ensure_ascii=False, indent=4)

print(f"JSON file saved: {json_file}")
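The conversion above writes every CSV row unchanged. If the missing-value handling described under Preprocessing has not already been applied to the CSV, a minimal cleaning pass such as the following could run before to_dict(). It assumes the columns are named q_script and a_script (the keys the fine-tuning step reads); the whitespace trimming and duplicate removal are illustrative additions, not steps documented by the project.

```python
import pandas as pd


def clean_qa_frame(df: pd.DataFrame) -> pd.DataFrame:
    """Drop rows missing a question or answer, trim whitespace, and remove exact duplicates."""
    cleaned = df.dropna(subset=["q_script", "a_script"]).copy()
    for col in ("q_script", "a_script"):
        cleaned[col] = cleaned[col].astype(str).str.strip()
    # Empty strings survive dropna(), so filter them out explicitly.
    cleaned = cleaned[(cleaned["q_script"] != "") & (cleaned["a_script"] != "")]
    return cleaned.drop_duplicates(subset=["q_script", "a_script"]).reset_index(drop=True)


# Toy frame: one valid row, one missing question, one whitespace duplicate, one empty question.
example = pd.DataFrame({
    "q_script": ["Which song has the lyric 'la la'?", None,
                 "  Which song has the lyric 'la la'?  ", ""],
    "a_script": ["Hypothetical Song", "Orphan answer", "Hypothetical Song", "Empty question"],
})
print(clean_qa_frame(example))
```

In the conversion script, this would mean writing clean_qa_frame(df).to_dict(orient="records") instead of df.to_dict(orient="records").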

3. Fine-Tuning the Model

Use the following code to fine-tune the Gemma-2B model with LoRA:

import json

import keras
from keras_nlp.models import GemmaCausalLM

# Load the fine-tuning dataset from the JSON file created in the previous step
json_file = "qa_dataset.json"
with open(json_file, "r", encoding="utf-8") as f:
    data = json.load(f)

# Format each pair into one training string: the question, a blank line, then the answer
formatted_data = [
    f"Question:\n{item['q_script']}\n\nAnswer:\n{item['a_script']}" for item in data
]

# Load the Gemma 2 2B instruction-tuned preset from KerasNLP
gemma_model_id = "gemma2_instruct_2b_en"
gemma_lm = GemmaCausalLM.from_preset(gemma_model_id)

# Enable LoRA on the backbone; only the low-rank adapter weights are trained
gemma_lm.backbone.enable_lora(rank=4)

# Limit the sequence length (longer examples are truncated) and compile the model
gemma_lm.preprocessor.sequence_length = 128
optimizer = keras.optimizers.AdamW(learning_rate=5e-5, weight_decay=0.01)
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Fine-tune the model on the formatted strings
gemma_lm.fit(formatted_data, epochs=1, batch_size=1)

4. Querying the Model

Once the model is fine-tuned, you can query it with new questions:

# Format a question with the same "Question/Answer" template used for training,
# leaving the answer empty so the model completes it
def ask_question(query: str) -> str:
    template = "Question:\n{question}\n\nAnswer:\n{answer}"
    prompt = template.format(question=query, answer="")
    return prompt

# Ask a question and generate an answer
prompt = ask_question("How can I install Python 3 on an AWS EC2 instance?")
print(gemma_lm.generate(prompt, max_length=512))
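The prompt built by ask_question has to match the training format exactly; otherwise the model sees a template it was not fine-tuned on. One way to keep the two in sync is to derive both from a single template string. The helper names below are hypothetical, and extract_answer assumes that gemma_lm.generate returns the prompt followed by the generated continuation, as it does in KerasNLP's published Gemma examples.

```python
# One template shared by training and inference (hypothetical helpers).
TEMPLATE = "Question:\n{question}\n\nAnswer:\n{answer}"


def format_training_example(question: str, answer: str) -> str:
    """String fed to gemma_lm.fit for one Q-A pair."""
    return TEMPLATE.format(question=question, answer=answer)


def format_prompt(question: str) -> str:
    """Same template with an empty answer, so the prompt is an exact prefix of a training example."""
    return TEMPLATE.format(question=question, answer="")


def extract_answer(generated: str, prompt: str) -> str:
    """Strip the echoed prompt from the generated text, if present."""
    if generated.startswith(prompt):
        return generated[len(prompt):].strip()
    # Fallback: keep only the text after the last "Answer:" marker.
    return generated.rsplit("Answer:\n", 1)[-1].strip()


prompt = format_prompt("Which song goes 'la la la'?")
# With the fine-tuned model this would be:
#   extract_answer(gemma_lm.generate(prompt, max_length=512), prompt)
print(extract_answer(prompt + "Hypothetical Song by Example Artist.", prompt))
```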

LoRA Integration

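LoRA (Low-Rank Adaptation) keeps the pre-trained weights frozen and trains small low-rank matrices added alongside them (rank 4 in this project, set with enable_lora(rank=4)). Because only these adapter weights are updated, far fewer parameters need gradients and optimizer state, which makes fine-tuning a model the size of Gemma-2B feasible on devices with limited memory.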
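To make the savings concrete: for a frozen weight matrix of shape d × k, LoRA learns an update ΔW = A·B with A of shape d × r and B of shape r × k, so it trains r(d + k) parameters instead of d·k. The NumPy sketch below computes these counts for a hypothetical 2048 × 2048 projection at rank 4 (an illustrative shape, not Gemma's actual layer sizes) and implements a LoRA-adapted linear layer. It illustrates the technique and is not KerasNLP's implementation; as in the LoRA paper, the up-projection B starts at zero, so the adapted layer initially reproduces the frozen one.

```python
import numpy as np


def lora_param_counts(d: int, k: int, r: int) -> tuple[int, int]:
    """Trainable parameters for a full d x k update vs. a rank-r LoRA update."""
    return d * k, r * (d + k)


class LoRALinear:
    """Minimal LoRA-adapted linear layer: y = x @ W + (alpha / r) * x @ A @ B."""

    def __init__(self, W: np.ndarray, r: int, alpha: float = 1.0, seed: int = 0):
        rng = np.random.default_rng(seed)
        d_in, d_out = W.shape
        self.W = W                                       # frozen pre-trained weight
        self.A = rng.normal(0.0, 0.01, size=(d_in, r))   # trainable down-projection, random init
        self.B = np.zeros((r, d_out))                    # trainable up-projection, zero init
        self.scale = alpha / r

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return x @ self.W + self.scale * (x @ self.A @ self.B)


full, lora = lora_param_counts(2048, 2048, r=4)
print(full, lora, f"{lora / full:.2%}")  # prints: 4194304 16384 0.39%

layer = LoRALinear(np.random.default_rng(1).normal(size=(64, 32)), r=4)
x = np.ones((2, 64))
print(np.allclose(layer(x), x @ layer.W))  # True: B is zero before training
```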


Uploading the Fine-Tuned Model to Hugging Face

Once the model is fine-tuned, you can upload it to the Hugging Face Hub using the following steps:

1. Install the Hugging Face Hub Client and Log In

!pip install huggingface_hub
!huggingface-cli login

2. Save and Push the Model to Hugging Face

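from huggingface_hub import HfApi

# Save the fine-tuned KerasNLP model (weights, config, tokenizer) to a local preset directory.
# GemmaCausalLM is a Keras model, so it has no Transformers-style save_pretrained().
model_name = "my-fine-tuned-gemma"
gemma_lm.save_to_preset(model_name)

# Create the repository if needed, then push the folder to the Hugging Face Hub
repo_id = "username/my-fine-tuned-gemma"
api = HfApi()
api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
api.upload_folder(
    folder_path=model_name,
    repo_id=repo_id,
    repo_type="model",
)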

Example Inference

Here's an example of querying the fine-tuned model:

# Define a prompt with a specific query
prompt = ask_question("๋…ธ๋ž˜ ์ œ๋ชฉ์ด ๋„๋Œ€์ฒด ๋ญ์˜€๋Š”์ง€ ๊ธฐ์–ต ์•ˆ ๋‚˜์„œ ๋‹ต๋‹ตํ•ด์š”.
2016๋…„๊ฒฝ์— ๋ฐœํ‘œ๋œ ๊ณก ๊ฐ™์•„์š”.
๋Œ€์ค‘ ์Œ์•… ์Šคํƒ€์ผ ๋งž๋Š” ๊ฒƒ ๊ฐ™์•„.
์•„๋งˆ๋„ The Chainsmokers์˜ ๊ณก์ผ ๊ฑฐ์•ผ.")

# Generate the response
print(gemma_lm.generate(prompt, max_length=1012))

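Output (the model answers in Korean; an English translation follows):

질문하신 노래는 ‘The Chainsmokers’의 ‘Closer (Feat. Halsey)’입니다.
해당 노래는 ‘일렉트로니카’ 장르의 노래입니다.
‘2016.11.05’ 발매되었습니다.
‘Collage EP’ 앨범에 수록되었습니다.

Translation:

The song you asked about is ‘Closer (Feat. Halsey)’ by ‘The Chainsmokers’.
It belongs to the ‘Electronica’ genre.
It was released on ‘2016.11.05’.
It is included on the album ‘Collage EP’.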
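The Korean answer above follows a fixed sentence pattern: artist and title, genre, release date, and album, each wrapped in ‘ ’ quotes. Assuming later outputs keep that pattern (an assumption drawn from this single example), a small regex parser could turn the generated text into structured fields. The field names and regular expressions below are hypothetical and tied to the Korean phrasing shown.

```python
import re

# Hypothetical patterns matching the sentence template seen in the example output.
ARTIST_TITLE = re.compile(r"‘([^’]+)’의 ‘([^’]+)’입니다")
FIELD_PATTERNS = {
    "genre": re.compile(r"‘([^’]+)’ 장르"),               # "'<genre>' genre"
    "release_date": re.compile(r"‘(\d{4}\.\d{2}\.\d{2})’ 발매"),  # "released on '<date>'"
    "album": re.compile(r"‘([^’]+)’ 앨범에 수록"),          # "included on album '<album>'"
}


def parse_answer(text: str) -> dict:
    """Extract whichever fields are present; missing fields are simply omitted."""
    result = {}
    match = ARTIST_TITLE.search(text)
    if match:
        result["artist"], result["title"] = match.group(1), match.group(2)
    for field, pattern in FIELD_PATTERNS.items():
        match = pattern.search(text)
        if match:
            result[field] = match.group(1)
    return result


sample = (
    "질문하신 노래는 ‘The Chainsmokers’의 ‘Closer (Feat. Halsey)’입니다.\n"
    "해당 노래는 ‘일렉트로니카’ 장르의 노래입니다.\n"
    "‘2016.11.05’ 발매되었습니다.\n"
    "‘Collage EP’ 앨범에 수록되었습니다."
)
print(parse_answer(sample))
```

Because the patterns are tied to one phrasing, a parser like this would need to tolerate missing fields (as it does here) rather than assume every answer contains all five.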

Acknowledgments

We would like to thank:

  • The Google ML Bootcamp team for providing this opportunity.
  • Hugging Face for providing excellent tools and models.
  • The Keras team for Keras NLP support and integration.
  • All contributors to the Gemma and LoRA projects.

Contributors

This project was developed by:

  • Seohyun Kang
  • Sujin Kim
  • Mingyu Jo