# Dataset

In [2]:
import datasets

dataset = datasets.load_dataset("tiiuae/falcon-refinedweb", streaming=True, split="train")



In [30]:
dataset

# show the first few examples
for example in dataset.take(5):
    print(example)


{'content': 'these birches can be found in many places in Europe - the photos is from a short trip to Baden-Baden in 2007. the clouds in the background are the messengers of the storm Kyrill. here are some more moments of the trip: Baden-Baden.\n-\n“ast/ray” is a bilingual wordplay: “ast” means “twig” in German. and while “Baden-Baden” sounds like wordplay, too, it is the actual name of a rather well-know spa town that also dates back to Roman times. “Bad” is the German word for “bath”.\nMirror effect turned out nice. I like', 'url': 'http://100parts.wordpress.com/2012/08/04/astray-baden-baden-day-31/', 'timestamp': datetime.datetime(2013, 5, 18, 10, 42), 'dump': 'CC-MAIN-2013-20', 'segment': '1368696382261', 'image_urls': []}
{'content': 'Watch Survivor Redemption Island Season 22 Episode 11: A Mystery Package Online S22e11 Free Stream Megavideo\nArticle by StreamThatSeries\nHorray!time for another dose of very exciting reality series with lots of twists.You must watch survivor redemp

# Model and tokenizer

In [4]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "luodian/llama-7b-hf"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.07s/it]


In [5]:
for example in dataset.take(10):
    token_count = len(tokenizer.encode(example['content']))
    print(f"Number of tokens in sample: {token_count}") 

Number of tokens in sample: 150
Number of tokens in sample: 2731
Number of tokens in sample: 52
Number of tokens in sample: 162
Number of tokens in sample: 78
Number of tokens in sample: 257
Number of tokens in sample: 1074
Number of tokens in sample: 505
Number of tokens in sample: 592
Number of tokens in sample: 932


# Function to calculate validation loss

In [5]:
def calculate_validation_loss(model, dataloader):
    """Return the mean of the per-batch causal-LM losses over `dataloader`."""
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for batch in dataloader:
            # Move every tensor in the batch to the model's device
            batch = {k: v.to(model.device) for k, v in batch.items()}
            # Labels are the inputs themselves; the model shifts them internally
            outputs = model(**batch, labels=batch["input_ids"])
            total_loss += outputs.loss.item()

    return total_loss / len(dataloader)
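
One caveat with passing `labels=batch["input_ids"]`: if the batches are padded, the padded positions are scored too, because the loss skips only labels equal to -100, not positions where `attention_mask` is 0. Below is a minimal sketch of the masking step, written on plain lists so it runs without a GPU. The helper name `mask_padding_labels` and the toy token ids are hypothetical and not part of this notebook; the reported losses were computed without this masking.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def mask_padding_labels(input_ids, attention_mask, ignore_index=IGNORE_INDEX):
    """Copy input_ids into labels, replacing padded positions with ignore_index."""
    return [
        [tok if keep == 1 else ignore_index for tok, keep in zip(ids, mask)]
        for ids, mask in zip(input_ids, attention_mask)
    ]


# Toy batch: token id 2 stands in for the EOS/pad token (an assumption for illustration)
input_ids = [[1, 306, 763, 2, 2], [1, 15043, 2, 2, 2]]
attention_mask = [[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]]

labels = mask_padding_labels(input_ids, attention_mask)
print(labels)
```

With tensors, the same effect should be obtainable with `input_ids.masked_fill(attention_mask == 0, -100)` before passing the result as `labels`.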

# Prepare data

In [9]:
def tokenize_function(examples):
    return tokenizer(examples["content"], padding="max_length", truncation=True, max_length=512) 
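
To get a feel for what `padding="max_length"` and `truncation=True` with `max_length=512` do to this data, the sketch below applies the same arithmetic to the ten token counts printed earlier. The helper `padding_stats` is hypothetical, and ten samples are only a rough indication of the full sample.

```python
# Token counts printed for the first 10 samples earlier in this notebook
token_counts = [150, 2731, 52, 162, 78, 257, 1074, 505, 592, 932]
max_length = 512  # same value as in tokenize_function


def padding_stats(counts, max_length):
    """Return (real_tokens, padded_positions, truncated_tokens) after pad/truncate."""
    real = sum(min(n, max_length) for n in counts)
    padded = sum(max(max_length - n, 0) for n in counts)
    truncated = sum(max(n - max_length, 0) for n in counts)
    return real, padded, truncated


real, padded, truncated = padding_stats(token_counts, max_length)
total = len(token_counts) * max_length

print(f"real tokens kept: {real} of {total} positions")
print(f"padding positions: {padded} ({padded / total:.1%})")
print(f"tokens dropped by truncation: {truncated}")
```

For these ten samples, more than a third of all positions are padding, and roughly half of the original tokens are cut off by truncation.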

In [13]:
sample_size = 1000  # small, for a quick experiment

# Approximate 80/20 split while streaming
train_dataset = []
validation_dataset = []
for i, item in enumerate(dataset):
    if i % 5 == 0:  # Every 5th item goes to validation (approximately 20%)
        validation_dataset.append(item)
    else:
        train_dataset.append(item)
    if len(train_dataset) >= sample_size and len(validation_dataset) >= (sample_size // 4): 
        # Stop once we have enough samples for both training and validation
        break

# Tokenize the training and validation samples
tokenized_train_dataset = []
for item in train_dataset:
    tokenized_item = tokenize_function(item)
    tokenized_train_dataset.append(tokenized_item)

tokenized_validation_dataset = []
for item in validation_dataset:
    tokenized_item = tokenize_function(item)
    tokenized_validation_dataset.append(tokenized_item)

# Convert to Dataset objects if needed
from datasets import Dataset
tokenized_train_dataset = Dataset.from_list(tokenized_train_dataset)
tokenized_validation_dataset = Dataset.from_list(tokenized_validation_dataset)

# Convert the tokenized datasets to PyTorch tensors
tokenized_train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
tokenized_validation_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
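
The stopping condition in the split loop can be checked without downloading anything. The sketch below mirrors the modulo logic on a stand-in stream; `split_stream` is a hypothetical helper, and `itertools.count()` merely plays the role of the streaming dataset.

```python
from itertools import count


def split_stream(stream, sample_size):
    """Mirror the modulo split above on any iterable."""
    train, validation = [], []
    for i, item in enumerate(stream):
        if i % 5 == 0:  # every 5th item goes to validation
            validation.append(item)
        else:
            train.append(item)
        if len(train) >= sample_size and len(validation) >= sample_size // 4:
            break
    return train, validation


# count() is a stand-in for the streaming dataset: an endless source of items
train, validation = split_stream(count(), sample_size=1000)
consumed = len(train) + len(validation)
print(len(train), len(validation), consumed)
```

So `sample_size` is the number of training samples, and the validation set adds a further quarter of that on top, which gives the 80/20 ratio.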


# Before training

## Calculate validation loss

In [14]:
from torch.utils.data import DataLoader

validation_dataloader = DataLoader(tokenized_validation_dataset, batch_size=2) 
loss_before_training = calculate_validation_loss(model, validation_dataloader)
print(f"Validation loss before training: {loss_before_training}")

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


Validation loss before training: 9.958484322547912


## Benchmark

Run this command:

```
accelerate launch -m  lm_eval --model hf \
    --model_args pretrained=luodian/llama-7b-hf,load_in_4bit=True,dtype="bfloat16" \
    --tasks mmlu,hellaswag,truthfulqa \
    --batch_size auto:4 \
    --log_samples \
    --output_path results/before-training
```

Output:

```
hf (pretrained=luodian/llama-7b-hf,load_in_4bit=True,dtype=bfloat16), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto:4 (16,64,64,64)
|                 Tasks                 |Version|Filter|n-shot|  Metric   |   | Value  |   |Stderr|
|---------------------------------------|------:|------|-----:|-----------|---|-------:|---|-----:|
|hellaswag                              |      1|none  |     0|acc        |↑  |  0.5646|±  |0.0049|
|                                       |       |none  |     0|acc_norm   |↑  |  0.7498|±  |0.0043|
|mmlu                                   |      2|none  |      |acc        |↑  |  0.3126|±  |0.0039|
| - humanities                          |      2|none  |      |acc        |↑  |  0.3101|±  |0.0067|
|  - formal_logic                       |      1|none  |     0|acc        |↑  |  0.2778|±  |0.0401|
|  - high_school_european_history       |      1|none  |     0|acc        |↑  |  0.3939|±  |0.0382|
|  - high_school_us_history             |      1|none  |     0|acc        |↑  |  0.3824|±  |0.0341|
|  - high_school_world_history          |      1|none  |     0|acc        |↑  |  0.3671|±  |0.0314|
|  - international_law                  |      1|none  |     0|acc        |↑  |  0.3884|±  |0.0445|
|  - jurisprudence                      |      1|none  |     0|acc        |↑  |  0.3519|±  |0.0462|
|  - logical_fallacies                  |      1|none  |     0|acc        |↑  |  0.3067|±  |0.0362|
|  - moral_disputes                     |      1|none  |     0|acc        |↑  |  0.3555|±  |0.0258|
|  - moral_scenarios                    |      1|none  |     0|acc        |↑  |  0.2380|±  |0.0142|
|  - philosophy                         |      1|none  |     0|acc        |↑  |  0.3441|±  |0.0270|
|  - prehistory                         |      1|none  |     0|acc        |↑  |  0.3364|±  |0.0263|
|  - professional_law                   |      1|none  |     0|acc        |↑  |  0.2803|±  |0.0115|
|  - world_religions                    |      1|none  |     0|acc        |↑  |  0.4503|±  |0.0382|
| - other                               |      2|none  |      |acc        |↑  |  0.3399|±  |0.0084|
|  - business_ethics                    |      1|none  |     0|acc        |↑  |  0.3300|±  |0.0473|
|  - clinical_knowledge                 |      1|none  |     0|acc        |↑  |  0.3208|±  |0.0287|
|  - college_medicine                   |      1|none  |     0|acc        |↑  |  0.2775|±  |0.0341|
|  - global_facts                       |      1|none  |     0|acc        |↑  |  0.3500|±  |0.0479|
|  - human_aging                        |      1|none  |     0|acc        |↑  |  0.3094|±  |0.0310|
|  - management                         |      1|none  |     0|acc        |↑  |  0.2524|±  |0.0430|
|  - marketing                          |      1|none  |     0|acc        |↑  |  0.3718|±  |0.0317|
|  - medical_genetics                   |      1|none  |     0|acc        |↑  |  0.4200|±  |0.0496|
|  - miscellaneous                      |      1|none  |     0|acc        |↑  |  0.4291|±  |0.0177|
|  - nutrition                          |      1|none  |     0|acc        |↑  |  0.3235|±  |0.0268|
|  - professional_accounting            |      1|none  |     0|acc        |↑  |  0.2801|±  |0.0268|
|  - professional_medicine              |      1|none  |     0|acc        |↑  |  0.2500|±  |0.0263|
|  - virology                           |      1|none  |     0|acc        |↑  |  0.2952|±  |0.0355|
| - social sciences                     |      2|none  |      |acc        |↑  |  0.3052|±  |0.0083|
|  - econometrics                       |      1|none  |     0|acc        |↑  |  0.2544|±  |0.0410|
|  - high_school_geography              |      1|none  |     0|acc        |↑  |  0.2828|±  |0.0321|
|  - high_school_government_and_politics|      1|none  |     0|acc        |↑  |  0.3161|±  |0.0336|
|  - high_school_macroeconomics         |      1|none  |     0|acc        |↑  |  0.2538|±  |0.0221|
|  - high_school_microeconomics         |      1|none  |     0|acc        |↑  |  0.2395|±  |0.0277|
|  - high_school_psychology             |      1|none  |     0|acc        |↑  |  0.3358|±  |0.0202|
|  - human_sexuality                    |      1|none  |     0|acc        |↑  |  0.2901|±  |0.0398|
|  - professional_psychology            |      1|none  |     0|acc        |↑  |  0.3333|±  |0.0191|
|  - public_relations                   |      1|none  |     0|acc        |↑  |  0.3000|±  |0.0439|
|  - security_studies                   |      1|none  |     0|acc        |↑  |  0.2286|±  |0.0269|
|  - sociology                          |      1|none  |     0|acc        |↑  |  0.4179|±  |0.0349|
|  - us_foreign_policy                  |      1|none  |     0|acc        |↑  |  0.3900|±  |0.0490|
| - stem                                |      2|none  |      |acc        |↑  |  0.2965|±  |0.0081|
|  - abstract_algebra                   |      1|none  |     0|acc        |↑  |  0.2600|±  |0.0441|
|  - anatomy                            |      1|none  |     0|acc        |↑  |  0.3407|±  |0.0409|
|  - astronomy                          |      1|none  |     0|acc        |↑  |  0.3487|±  |0.0388|
|  - college_biology                    |      1|none  |     0|acc        |↑  |  0.3403|±  |0.0396|
|  - college_chemistry                  |      1|none  |     0|acc        |↑  |  0.2700|±  |0.0446|
|  - college_computer_science           |      1|none  |     0|acc        |↑  |  0.3300|±  |0.0473|
|  - college_mathematics                |      1|none  |     0|acc        |↑  |  0.3000|±  |0.0461|
|  - college_physics                    |      1|none  |     0|acc        |↑  |  0.1765|±  |0.0379|
|  - computer_security                  |      1|none  |     0|acc        |↑  |  0.3500|±  |0.0479|
|  - conceptual_physics                 |      1|none  |     0|acc        |↑  |  0.3106|±  |0.0303|
|  - electrical_engineering             |      1|none  |     0|acc        |↑  |  0.3310|±  |0.0392|
|  - elementary_mathematics             |      1|none  |     0|acc        |↑  |  0.2619|±  |0.0226|
|  - high_school_biology                |      1|none  |     0|acc        |↑  |  0.3645|±  |0.0274|
|  - high_school_chemistry              |      1|none  |     0|acc        |↑  |  0.2512|±  |0.0305|
|  - high_school_computer_science       |      1|none  |     0|acc        |↑  |  0.3800|±  |0.0488|
|  - high_school_mathematics            |      1|none  |     0|acc        |↑  |  0.2481|±  |0.0263|
|  - high_school_physics                |      1|none  |     0|acc        |↑  |  0.2450|±  |0.0351|
|  - high_school_statistics             |      1|none  |     0|acc        |↑  |  0.2639|±  |0.0301|
|  - machine_learning                   |      1|none  |     0|acc        |↑  |  0.3125|±  |0.0440|
|truthfulqa_gen                         |      3|none  |     0|bleu_acc   |↑  |  0.2766|±  |0.0157|
|                                       |       |none  |     0|bleu_diff  |↑  |-10.2902|±  |0.8441|
|                                       |       |none  |     0|bleu_max   |↑  | 26.5005|±  |0.8063|
|                                       |       |none  |     0|rouge1_acc |↑  |  0.2619|±  |0.0154|
|                                       |       |none  |     0|rouge1_diff|↑  |-13.4103|±  |0.8561|
|                                       |       |none  |     0|rouge1_max |↑  | 51.0861|±  |0.8835|
|                                       |       |none  |     0|rouge2_acc |↑  |  0.2240|±  |0.0146|
|                                       |       |none  |     0|rouge2_diff|↑  |-15.4705|±  |1.0517|
|                                       |       |none  |     0|rouge2_max |↑  | 35.0729|±  |1.0250|
|                                       |       |none  |     0|rougeL_acc |↑  |  0.2619|±  |0.0154|
|                                       |       |none  |     0|rougeL_diff|↑  |-13.6375|±  |0.8721|
|                                       |       |none  |     0|rougeL_max |↑  | 48.4303|±  |0.8983|
|truthfulqa_mc1                         |      2|none  |     0|acc        |↑  |  0.2069|±  |0.0142|
|truthfulqa_mc2                         |      2|none  |     0|acc        |↑  |  0.3252|±  |0.0131|

|      Groups      |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|------------------|------:|------|------|------|---|-----:|---|-----:|
|mmlu              |      2|none  |      |acc   |↑  |0.3126|±  |0.0039|
| - humanities     |      2|none  |      |acc   |↑  |0.3101|±  |0.0067|
| - other          |      2|none  |      |acc   |↑  |0.3399|±  |0.0084|
| - social sciences|      2|none  |      |acc   |↑  |0.3052|±  |0.0083|
| - stem           |      2|none  |      |acc   |↑  |0.2965|±  |0.0081|

```



# Training

## Define training loop

In [15]:
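from torch.optim import AdamW  # transformers.AdamW is deprecated in favour of this
from transformers import get_linear_schedule_with_warmup

# eps and weight_decay mirror the defaults of the deprecated transformers.AdamW
optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

num_epochs = 3
# Caution: this counts samples, not optimizer steps. With batch_size=2 the training
# loop takes half this many steps, so the linear schedule only decays the learning
# rate to half its peak and the progress bar stops at 50%.
num_training_steps = num_epochs * len(tokenized_train_dataset)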
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)
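
The scheduler takes one step per batch, so its total should be counted in batches rather than samples. The sketch below works through the arithmetic for this notebook's settings (3 epochs, 1000 samples, and the `batch_size=2` used by the dataloaders). `linear_lr_factor` is a hypothetical re-implementation of the linear warmup/decay multiplier as I understand it, not the library function itself.

```python
import math


def linear_lr_factor(step, num_training_steps, num_warmup_steps=0):
    """Multiplier applied to the base learning rate by a linear warmup/decay schedule."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


num_epochs, num_samples, batch_size = 3, 1000, 2
steps_by_samples = num_epochs * num_samples                          # total sized by samples
steps_by_batches = num_epochs * math.ceil(num_samples / batch_size)  # actual optimizer steps

print(steps_by_samples, steps_by_batches)
# Learning-rate multiplier reached at the final optimizer step under each total
print(linear_lr_factor(steps_by_batches, steps_by_samples))
print(linear_lr_factor(steps_by_batches, steps_by_batches))
```

Sizing the schedule with `num_epochs * len(train_dataloader)` would let the learning rate decay all the way to zero by the end of training.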

## Dataloader for training

In [16]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(tokenized_train_dataset, shuffle=True, batch_size=2) 

## Run training loop

In [20]:
from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(num_epochs):
    total_loss = 0

    for batch in train_dataloader:
        batch = {k: v.to(model.device) for k, v in batch.items()}
        outputs = model(**batch, labels=batch["input_ids"])
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    average_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {average_loss:.4f}")  # Print the loss

 23%|██▎       | 686/3000 [04:50<16:20,  2.36it/s]

Epoch 1/3 - Average Loss: 2.9037
Epoch 2/3 - Average Loss: 2.0474
Epoch 3/3 - Average Loss: 1.4295


## Loss minimisation

The average training loss falls steadily over the 3 epochs of 1,000 samples each:

- Epoch 1/3: 2.9037
- Epoch 2/3: 2.0474
- Epoch 3/3: 1.4295

# Push model to hub

In [22]:
from huggingface_hub import HfApi

api = HfApi()

trained_model_id = "opyate/llama-7b-hf-redefined-3ep-1k"
api.create_repo(repo_id=trained_model_id)
tokenizer.push_to_hub(trained_model_id)
model.push_to_hub(trained_model_id)

model.safetensors: 100%|██████████| 4.17G/4.17G [05:49<00:00, 11.9MB/s]


CommitInfo(commit_url='https://huggingface.co/opyate/llama-7b-hf-redefined-3ep-1k/commit/ba69710ff29bd8ef8e0eb2460fde7cf0b8d1522c', commit_message='Upload LlamaForCausalLM', commit_description='', oid='ba69710ff29bd8ef8e0eb2460fde7cf0b8d1522c', pr_url=None, pr_revision=None, pr_num=None)

# After training

## Calculate validation loss

In [23]:
loss_after_training = calculate_validation_loss(model, validation_dataloader)
print(f"Validation loss before training: {loss_before_training}")
print(f"Validation loss after training: {loss_after_training}")

Validation loss before training: 9.958484322547912
Validation loss after training: 3.7852714624404906


Observation: validation loss also decreased after training, from about 9.96 to about 3.79.
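
Cross-entropy losses are often easier to compare as perplexities. A minimal sketch using the two losses printed above, assuming they are mean per-token cross-entropies in nats. Since the loss here is an average of batch means and includes padded positions, these figures are only comparable to each other, not to published perplexities.

```python
import math

loss_before = 9.958484322547912   # validation loss before training (from the output above)
loss_after = 3.7852714624404906   # validation loss after training (from the output above)


def perplexity(mean_cross_entropy):
    """Perplexity is the exponential of the mean per-token cross-entropy (in nats)."""
    return math.exp(mean_cross_entropy)


print(f"perplexity before: {perplexity(loss_before):,.0f}")
print(f"perplexity after:  {perplexity(loss_after):,.1f}")
```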

## Benchmark

Run this command:

```
accelerate launch -m  lm_eval --model hf \
    --model_args pretrained=opyate/llama-7b-hf-redefined-3ep-1k,load_in_4bit=True,dtype="bfloat16" \
    --tasks mmlu,hellaswag,truthfulqa \
    --batch_size auto:4 \
    --log_samples \
    --output_path results/after-training
```

Output:

```
hf (pretrained=opyate/llama-7b-hf-redefined-3ep-1k,load_in_4bit=True,dtype=bfloat16), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto:4 (16,64,64,64)
|                 Tasks                 |Version|Filter|n-shot|  Metric   |   | Value |   |Stderr|
|---------------------------------------|------:|------|-----:|-----------|---|------:|---|-----:|
|hellaswag                              |      1|none  |     0|acc        |↑  | 0.2922|±  |0.0045|
|                                       |       |none  |     0|acc_norm   |↑  | 0.3176|±  |0.0046|
|mmlu                                   |      2|none  |      |acc        |↑  | 0.2400|±  |0.0036|
| - humanities                          |      2|none  |      |acc        |↑  | 0.2436|±  |0.0063|
|  - formal_logic                       |      1|none  |     0|acc        |↑  | 0.2698|±  |0.0397|
|  - high_school_european_history       |      1|none  |     0|acc        |↑  | 0.2182|±  |0.0323|
|  - high_school_us_history             |      1|none  |     0|acc        |↑  | 0.2843|±  |0.0317|
|  - high_school_world_history          |      1|none  |     0|acc        |↑  | 0.2574|±  |0.0285|
|  - international_law                  |      1|none  |     0|acc        |↑  | 0.2314|±  |0.0385|
|  - jurisprudence                      |      1|none  |     0|acc        |↑  | 0.2778|±  |0.0433|
|  - logical_fallacies                  |      1|none  |     0|acc        |↑  | 0.2638|±  |0.0346|
|  - moral_disputes                     |      1|none  |     0|acc        |↑  | 0.2601|±  |0.0236|
|  - moral_scenarios                    |      1|none  |     0|acc        |↑  | 0.2380|±  |0.0142|
|  - philosophy                         |      1|none  |     0|acc        |↑  | 0.2283|±  |0.0238|
|  - prehistory                         |      1|none  |     0|acc        |↑  | 0.2623|±  |0.0245|
|  - professional_law                   |      1|none  |     0|acc        |↑  | 0.2327|±  |0.0108|
|  - world_religions                    |      1|none  |     0|acc        |↑  | 0.2339|±  |0.0325|
| - other                               |      2|none  |      |acc        |↑  | 0.2594|±  |0.0078|
|  - business_ethics                    |      1|none  |     0|acc        |↑  | 0.3200|±  |0.0469|
|  - clinical_knowledge                 |      1|none  |     0|acc        |↑  | 0.2528|±  |0.0267|
|  - college_medicine                   |      1|none  |     0|acc        |↑  | 0.2486|±  |0.0330|
|  - global_facts                       |      1|none  |     0|acc        |↑  | 0.2900|±  |0.0456|
|  - human_aging                        |      1|none  |     0|acc        |↑  | 0.3857|±  |0.0327|
|  - management                         |      1|none  |     0|acc        |↑  | 0.2524|±  |0.0430|
|  - marketing                          |      1|none  |     0|acc        |↑  | 0.2735|±  |0.0292|
|  - medical_genetics                   |      1|none  |     0|acc        |↑  | 0.2600|±  |0.0441|
|  - miscellaneous                      |      1|none  |     0|acc        |↑  | 0.2414|±  |0.0153|
|  - nutrition                          |      1|none  |     0|acc        |↑  | 0.2451|±  |0.0246|
|  - professional_accounting            |      1|none  |     0|acc        |↑  | 0.2447|±  |0.0256|
|  - professional_medicine              |      1|none  |     0|acc        |↑  | 0.1912|±  |0.0239|
|  - virology                           |      1|none  |     0|acc        |↑  | 0.2892|±  |0.0353|
| - social sciences                     |      2|none  |      |acc        |↑  | 0.2327|±  |0.0076|
|  - econometrics                       |      1|none  |     0|acc        |↑  | 0.2193|±  |0.0389|
|  - high_school_geography              |      1|none  |     0|acc        |↑  | 0.1818|±  |0.0275|
|  - high_school_government_and_politics|      1|none  |     0|acc        |↑  | 0.2021|±  |0.0290|
|  - high_school_macroeconomics         |      1|none  |     0|acc        |↑  | 0.2077|±  |0.0206|
|  - high_school_microeconomics         |      1|none  |     0|acc        |↑  | 0.2353|±  |0.0276|
|  - high_school_psychology             |      1|none  |     0|acc        |↑  | 0.2349|±  |0.0182|
|  - human_sexuality                    |      1|none  |     0|acc        |↑  | 0.2443|±  |0.0377|
|  - professional_psychology            |      1|none  |     0|acc        |↑  | 0.2598|±  |0.0177|
|  - public_relations                   |      1|none  |     0|acc        |↑  | 0.2909|±  |0.0435|
|  - security_studies                   |      1|none  |     0|acc        |↑  | 0.2041|±  |0.0258|
|  - sociology                          |      1|none  |     0|acc        |↑  | 0.2587|±  |0.0310|
|  - us_foreign_policy                  |      1|none  |     0|acc        |↑  | 0.2600|±  |0.0441|
| - stem                                |      2|none  |      |acc        |↑  | 0.2226|±  |0.0074|
|  - abstract_algebra                   |      1|none  |     0|acc        |↑  | 0.3100|±  |0.0465|
|  - anatomy                            |      1|none  |     0|acc        |↑  | 0.2519|±  |0.0375|
|  - astronomy                          |      1|none  |     0|acc        |↑  | 0.1711|±  |0.0306|
|  - college_biology                    |      1|none  |     0|acc        |↑  | 0.2292|±  |0.0351|
|  - college_chemistry                  |      1|none  |     0|acc        |↑  | 0.2200|±  |0.0416|
|  - college_computer_science           |      1|none  |     0|acc        |↑  | 0.1300|±  |0.0338|
|  - college_mathematics                |      1|none  |     0|acc        |↑  | 0.2400|±  |0.0429|
|  - college_physics                    |      1|none  |     0|acc        |↑  | 0.1667|±  |0.0371|
|  - computer_security                  |      1|none  |     0|acc        |↑  | 0.1600|±  |0.0368|
|  - conceptual_physics                 |      1|none  |     0|acc        |↑  | 0.2766|±  |0.0292|
|  - electrical_engineering             |      1|none  |     0|acc        |↑  | 0.2483|±  |0.0360|
|  - elementary_mathematics             |      1|none  |     0|acc        |↑  | 0.2196|±  |0.0213|
|  - high_school_biology                |      1|none  |     0|acc        |↑  | 0.2097|±  |0.0232|
|  - high_school_chemistry              |      1|none  |     0|acc        |↑  | 0.2217|±  |0.0292|
|  - high_school_computer_science       |      1|none  |     0|acc        |↑  | 0.2200|±  |0.0416|
|  - high_school_mathematics            |      1|none  |     0|acc        |↑  | 0.3000|±  |0.0279|
|  - high_school_physics                |      1|none  |     0|acc        |↑  | 0.1457|±  |0.0288|
|  - high_school_statistics             |      1|none  |     0|acc        |↑  | 0.1667|±  |0.0254|
|  - machine_learning                   |      1|none  |     0|acc        |↑  | 0.2768|±  |0.0425|
|truthfulqa_gen                         |      3|none  |     0|bleu_acc   |↑  | 0.0073|±  |0.0030|
|                                       |       |none  |     0|bleu_diff  |↑  | 0.0010|±  |0.0008|
|                                       |       |none  |     0|bleu_max   |↑  | 0.0187|±  |0.0088|
|                                       |       |none  |     0|rouge1_acc |↑  | 0.0220|±  |0.0051|
|                                       |       |none  |     0|rouge1_diff|↑  |-0.0077|±  |0.0277|
|                                       |       |none  |     0|rouge1_max |↑  | 0.0726|±  |0.0213|
|                                       |       |none  |     0|rouge2_acc |↑  | 0.0024|±  |0.0017|
|                                       |       |none  |     0|rouge2_diff|↑  | 0.0031|±  |0.0043|
|                                       |       |none  |     0|rouge2_max |↑  | 0.0048|±  |0.0039|
|                                       |       |none  |     0|rougeL_acc |↑  | 0.0220|±  |0.0051|
|                                       |       |none  |     0|rougeL_diff|↑  |-0.0093|±  |0.0278|
|                                       |       |none  |     0|rougeL_max |↑  | 0.0709|±  |0.0208|
|truthfulqa_mc1                         |      2|none  |     0|acc        |↑  | 0.2277|±  |0.0147|
|truthfulqa_mc2                         |      2|none  |     0|acc        |↑  | 0.4685|±  |0.0171|

|      Groups      |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|------------------|------:|------|------|------|---|-----:|---|-----:|
|mmlu              |      2|none  |      |acc   |↑  |0.2400|±  |0.0036|
| - humanities     |      2|none  |      |acc   |↑  |0.2436|±  |0.0063|
| - other          |      2|none  |      |acc   |↑  |0.2594|±  |0.0078|
| - social sciences|      2|none  |      |acc   |↑  |0.2327|±  |0.0076|
| - stem           |      2|none  |      |acc   |↑  |0.2226|±  |0.0074|
```

# Conclusion

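Most benchmarks got worse after training. For example, `hellaswag` `acc_norm` was reasonable before training (0.7498) and more than halved afterwards (0.3176), and `mmlu` accuracy fell from 0.3126 to 0.2400, roughly chance level for four-option multiple-choice questions. The `truthfulqa_gen` scores collapsed to near zero (e.g. `bleu_max` from 26.5005 to 0.0187). Only the multiple-choice TruthfulQA tasks improved: `truthfulqa_mc1` slightly (0.2069 to 0.2277) and `truthfulqa_mc2` more clearly (0.3252 to 0.4685).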
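
In place of a leaderboard, the headline numbers can be lined up with a few lines of Python. The values are copied from the two `lm-eval` tables above. Dividing the change by the combined standard error is a rough yardstick that treats the two runs as independent, which is an assumption rather than something `lm-eval` reports.

```python
import math

# (metric, before, stderr_before, after, stderr_after), copied from the lm-eval tables above
results = [
    ("hellaswag acc_norm", 0.7498, 0.0043, 0.3176, 0.0046),
    ("mmlu acc",           0.3126, 0.0039, 0.2400, 0.0036),
    ("truthfulqa_mc1 acc", 0.2069, 0.0142, 0.2277, 0.0147),
    ("truthfulqa_mc2 acc", 0.3252, 0.0131, 0.4685, 0.0171),
]


def compare(before, se_before, after, se_after):
    """Return the change and its size in units of the combined standard error."""
    delta = after - before
    combined_se = math.sqrt(se_before ** 2 + se_after ** 2)
    return delta, delta / combined_se


print(f"{'metric':<20}{'before':>8}{'after':>8}{'delta':>9}{'delta/SE':>10}")
for name, before, se_b, after, se_a in results:
    delta, z = compare(before, se_b, after, se_a)
    print(f"{name:<20}{before:>8.4f}{after:>8.4f}{delta:>+9.4f}{z:>+10.1f}")
```

On this yardstick the `truthfulqa_mc1` change is about one combined standard error, so it is hard to distinguish from noise, whereas the other three changes are many standard errors wide.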

There are several possible reasons:
- Quantisation: we trained a 4-bit quantised model. Quantisation reduces weight precision and introduces approximation error, although the baseline was evaluated in 4-bit as well, so this alone does not explain the drop.
- Sample size: 1,000 training samples are very few to learn from.
- Training setup, e.g.:
  - overfitting: the model might have memorised the training data and now performs poorly on unseen data (the benchmark tasks run by `lm-eval`).
  - data preparation: there could be problems with the training data, such as noise, inconsistencies, or biases, that negatively affect learning. In addition, every sample is padded to 512 tokens with the EOS token and the labels are the unmasked `input_ids`, so the padding positions contribute to the loss.
  - hyperparameters: the learning rate, batch size, weight decay, etc. might not be optimal for this task and model; a hyperparameter search could remedy this.

I used quantisation, small batch sizes, and a small sample size to complete this task in good time. I would also have liked to show a [Gradio leaderboard](https://huggingface.co/spaces/freddyaboulton/gradio_leaderboard/tree/main) to compare the benchmark results, but enough can be gleaned from the numbers alone.