# Model Overview
## Description:
The NVIDIA Llama 3.1 8B Medusa FP8 model is the quantized, Medusa-enhanced version of the Meta Llama 3.1 8B Instruct model, an auto-regressive language model that uses an optimized transformer architecture. It is an instruction-tuned generative model (text in/text out). For more information, please check [here](https://huggingface.co./meta-llama/Meta-Llama-3.1-8B-Instruct).
The NVIDIA Llama 3.1 8B Medusa FP8 model is enhanced with Medusa speculative decoding and quantized with [TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer).
This model is ready for commercial and non-commercial use.
## Third-Party Community Consideration:
This model is not owned or developed by NVIDIA. This model has been developed and built to a third-party’s requirements for this application and use case; see link to Non-NVIDIA [(Meta-Llama-3.1-8B-Instruct) Model Card](https://huggingface.co./meta-llama/Meta-Llama-3.1-8B-Instruct).
### License/Terms of Use:
GOVERNING TERMS: Use of this model is governed by the [NVIDIA Open Models License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/). ADDITIONAL INFORMATION: [Llama 3.1 Community License Agreement](https://www.llama.com/llama3_1/license/). Built with Meta Llama 3.1.
## Model Architecture:
**Architecture Type:** Transformer
**Network Architecture:** Llama3.1
## Input:
**Input Type(s):** Text
**Input Format(s):** String
**Input Parameters:** 1D; Sequences
**Other Properties Related to Input:** Context length up to 128K
## Output:
**Output Type(s):** Text
**Output Format:** String
**Output Parameters:** 1D; Sequences
## Software Integration
**Supported Runtime Engine(s):**
* TensorRT-LLM
**Supported Hardware Microarchitecture Compatibility:**
* NVIDIA Blackwell
* NVIDIA Hopper
* NVIDIA Lovelace
**Supported Operating System(s):**
* Linux
## Model Version(s):
v0.23.0
# Training and Evaluation Datasets:
## Training Dataset:
**Link:** Daring-Anteater, used for data synthesis; the synthesized data is then used to train the Medusa heads. See [here](https://huggingface.co./datasets/nvidia/Daring-Anteater) for more information regarding the dataset.
**Data Collection Method by dataset:**
* Automated
**Labeling Method by dataset:**
* Synthetic
**Properties:** Synthetically created dataset, 100K rows.
**Link:** cnn_dailymail, used for calibration during FP8 quantization (a minimal loading sketch follows this list). See [here](https://huggingface.co./datasets/abisee/cnn_dailymail) for more information regarding the dataset.
**Data Collection Method by dataset:**
* Unknown
**Labeling Method by dataset:**
* Human
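For illustration, a minimal sketch of loading calibration text with the Hugging Face `datasets` library; the split, sample count, and use of the `article` field are illustrative assumptions, not the exact calibration configuration used for this checkpoint:
```python
# Minimal sketch: pull a handful of cnn_dailymail articles to use as calibration text.
from datasets import load_dataset

# "3.0.0" is the standard cnn_dailymail configuration on the Hugging Face Hub.
dataset = load_dataset("abisee/cnn_dailymail", "3.0.0", split="train")
calib_texts = [sample["article"] for sample in dataset.select(range(512))]
print(len(calib_texts), calib_texts[0][:200])
```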
## Evaluation Dataset:
**Link:** MMLU; for more details, see [here](https://github.com/hendrycks/test).
**Data Collection Method by dataset:**
* Human
**Labeling Method by dataset:**
* Human
## Medusa Speculative Decoding and Post Training Quantization
Synthesized data was obtained from an FP8-quantized version of Meta-Llama-3.1-8B-Instruct and then used to fine-tune the Medusa heads. This model was then obtained by quantizing the weights and activations of Meta-Llama-3.1-8B-Instruct, together with the Medusa heads, to the FP8 data type, ready for inference with TensorRT-LLM in Medusa speculative decoding mode. Only the weights and activations of the linear operators within the transformer blocks and Medusa heads are quantized. This optimization reduces the number of bits per parameter from 16 to 8, cutting disk size and GPU memory requirements by approximately 50%.
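As a rough, back-of-the-envelope illustration of that memory reduction (the parameter count below is an approximation, and the calculation ignores activations, KV cache, and the small Medusa-head parameters):
```python
# Rough weight-memory estimate: 16-bit vs FP8 (8-bit) per-parameter storage.
params = 8.0e9                    # approximate parameter count of an 8B model

bf16_gib = params * 2 / 2**30     # 2 bytes per parameter
fp8_gib = params * 1 / 2**30      # 1 byte per parameter

print(f"16-bit weights: ~{bf16_gib:.1f} GiB")
print(f"FP8 weights:    ~{fp8_gib:.1f} GiB (~{100 * (1 - fp8_gib / bf16_gib):.0f}% smaller)")
```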
Medusa heads are used to predict candidate tokens beyond the next token. In each generation step, every Medusa head produces a distribution over tokens one position further ahead than the previous head. A tree-based attention mechanism then samples candidate sequences for the original model to verify. The longest accepted candidate sequence is selected, so more than one token can be returned per generation step. The average number of tokens generated in each step is called the acceptance rate.
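The sketch below is a simplified, framework-free illustration of how an acceptance rate could be measured: draft tokens are compared against the base model's verified tokens and the accepted prefix length is averaged over steps. It is not the TensorRT-LLM implementation, and the `draft_tokens`/`verified_tokens` inputs are hypothetical stand-ins for the Medusa-head proposals and the base model's outputs.
```python
def accepted_prefix_len(draft_tokens, verified_tokens):
    """Count draft tokens that match the base model's tokens, starting from the front."""
    n = 0
    for d, v in zip(draft_tokens, verified_tokens):
        if d != v:
            break
        n += 1
    return n

def acceptance_rate(steps):
    """Average tokens produced per decoding step: 1 base-model token + accepted draft tokens."""
    produced = [1 + accepted_prefix_len(d, v) for d, v in steps]
    return sum(produced) / len(produced)

# Toy example: two decoding steps with 3 draft tokens each.
steps = [([5, 9, 2], [5, 9, 7]),   # 2 draft tokens accepted -> 3 tokens this step
         ([4, 4, 1], [8, 0, 3])]   # 0 draft tokens accepted -> 1 token this step
print(acceptance_rate(steps))      # 2.0 tokens per step
```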
## Usage
To run inference with [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) (supported from [v0.17](https://github.com/NVIDIA/TensorRT-LLM/tree/v0.17.0)), we recommend using the LLM API as shown in [this example](https://github.com/NVIDIA/TensorRT-LLM/blob/v0.17.0/examples/llm-api/llm_medusa_decoding.py#L34) with `python llm_medusa_decoding.py --use_modelopt_ckpt`, or as shown below. The LLM API abstracts away steps such as checkpoint conversion, engine building, and inference.
```python
# Generate text using Medusa speculative decoding.
from tensorrt_llm.llmapi import (LLM, BuildConfig, MedusaDecodingConfig,
                                 SamplingParams)
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode


def main():
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # The end user can customize the sampling configuration with the SamplingParams class.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # The end user can customize the build configuration with the BuildConfig class.
    build_config = BuildConfig(
        max_batch_size=1,
        max_seq_len=1024,
        max_draft_len=63,
        speculative_decoding_mode=SpeculativeDecodingMode.MEDUSA)

    # The end user can customize the Medusa decoding configuration by specifying the
    # number of Medusa heads and the Medusa choices (tree) with the MedusaDecodingConfig class.
    speculative_config = MedusaDecodingConfig(
        num_medusa_heads=3,
        medusa_choices=[
            [0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3],
            [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1],
            [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3],
            [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7],
            [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2],
            [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [1, 6], [0, 7, 0],
        ])

    llm = LLM(model="nvidia/Llama-3.1-8B-Medusa-FP8",
              build_config=build_config,
              speculative_config=speculative_config)

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    main()
```
Alternatively, you can follow the [sample CLIs for Medusa decoding](https://github.com/NVIDIA/TensorRT-LLM/tree/v0.17.0/examples/medusa#usage) in the TensorRT-LLM GitHub repo.
Support in [TensorRT-LLM benchmarking](https://nvidia.github.io/TensorRT-LLM/performance/perf-benchmarking.html) with `trtllm-bench` is coming soon.
## Evaluation
The accuracy (MMLU, 5-shot) and Medusa acceptance rate (MT-Bench) benchmark results are presented in the table below:
| Precision | [MMLU](https://crfm.stanford.edu/helm/mmlu/latest/) | [MT Bench](https://huggingface.co./spaces/lmsys/mt-bench) Acceptance Rate |
|----------|----------|----------|
| FP8 | 68.3 | 2.07 |
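As a back-of-the-envelope reading of that acceptance rate (assuming, for simplicity, that one Medusa decoding step costs roughly the same as one standard decoding step, which ignores draft and verification overhead):
```python
# Back-of-the-envelope: decoding steps needed to emit 1,000 tokens.
acceptance_rate = 2.07        # average tokens produced per Medusa step (MT-Bench, from the table above)
output_tokens = 1000

baseline_steps = output_tokens / 1.0             # standard auto-regressive decoding: 1 token/step
medusa_steps = output_tokens / acceptance_rate   # Medusa speculative decoding

print(f"baseline: {baseline_steps:.0f} steps, Medusa: {medusa_steps:.0f} steps "
      f"(~{baseline_steps / medusa_steps:.2f}x fewer decoding steps)")
```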
## Inference:
**Engine:** TensorRT-LLM
**Test Hardware:** H100
## Ethical Considerations
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.NVIDIA.com/en-us/support/submit-security-vulnerability/).