This model is continually pretrained from meta-llama/Meta-Llama-3-8B using the architecture proposed in M+.
We equip Llama-3 with 10,240 memory tokens in each layer, giving a memory pool of 1.34B parameters. Each layer also has a long-term memory (LTM) pool, which is set to 153,600 tokens in our paper and can be made even longer.
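As a quick check of the 1.34B figure, the memory pool size equals the number of memory tokens per layer times the hidden size times the number of layers. The sketch below assumes Llama-3-8B's standard configuration (hidden size 4096, 32 decoder layers) and assumes each memory token is one hidden-size vector per layer. The bfloat16 footprint is an illustrative extra, not a figure from this card.

```python
# Assumed Llama-3-8B configuration values (from the base model, not stated in this card)
HIDDEN_SIZE = 4096
NUM_LAYERS = 32
MEMORY_TOKENS_PER_LAYER = 10240  # from the model description

# Assumption: each memory token is one hidden-size vector in every layer
memory_params = MEMORY_TOKENS_PER_LAYER * HIDDEN_SIZE * NUM_LAYERS
print(f"{memory_params:,} parameters (~{memory_params / 1e9:.2f}B)")  # 1,342,177,280 parameters (~1.34B)

# Illustrative footprint if the pool is held in bfloat16 (2 bytes per value)
bytes_bf16 = memory_params * 2
print(f"{bytes_bf16 / 2**30:.1f} GiB in bfloat16")  # 2.5 GiB in bfloat16
```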

To use the model, first clone the MemoryLLM repository, which provides the `modeling_mplus` module:

```bash
git clone git@github.com:wangyu-ustc/MemoryLLM.git
cd MemoryLLM
```

Then simply use the following code to load the model:

```python
import torch
from transformers import AutoTokenizer
from modeling_mplus import MPlus  # provided by the MemoryLLM repository cloned above

# Load mplus-8b (currently only the pretrained version is available).
# attn_implementation="flash_attention_2" requires the flash-attn package.
model = MPlus.from_pretrained("YuWangX/mplus-8b", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("YuWangX/mplus-8b")
# Call .to() again so that `inv_freq` in rotary_emb is cast to bfloat16 as well
model = model.to(torch.bfloat16)
# The LTM is stored as modules so that it can be uploaded to Hugging Face;
# for inference, move the LTM to the CPU and cast `ltm_ags` to numpy
model.put_ltm_to_numpy()
```
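The second `.to(torch.bfloat16)` call works because `torch.nn.Module.to` casts floating-point buffers as well as parameters. The toy module below is a hypothetical stand-in for a rotary embedding (`ToyRotary` is not part of the MemoryLLM code): it registers a float32 `inv_freq` buffer the way such layers commonly do and shows that the cast reaches it.

```python
import torch
from torch import nn


class ToyRotary(nn.Module):
    """Hypothetical stand-in for a rotary embedding with an `inv_freq` buffer."""

    def __init__(self, dim: int = 8, base: float = 10000.0):
        super().__init__()
        # Standard RoPE inverse frequencies, computed explicitly in float32
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        # Non-persistent: excluded from the state dict, so it is computed at init, not loaded
        self.register_buffer("inv_freq", inv_freq, persistent=False)


rotary = ToyRotary()
print(rotary.inv_freq.dtype)  # torch.float32

# Module.to(dtype) casts floating-point parameters and buffers, including non-persistent ones
rotary = rotary.to(torch.bfloat16)
print(rotary.inv_freq.dtype)  # torch.bfloat16
```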

How to use the model

Inject a piece of context into the model using the following script:

```python
model = model.cuda()

# Self-update the memory with a new context
ctx = "Last week, John had a wonderful picnic with David. During their conversation, David mentioned multiple times that he likes eating apples. Though he didn't mention any other fruits, John says he can infer that David also likes bananas."

# Make sure the injected context is at least 16 tokens long: this is the hard minimum used during training,
# and injecting fewer than 16 tokens will disturb the memory.
model.inject_memory(tokenizer(ctx, return_tensors='pt', add_special_tokens=False).input_ids.cuda(), update_memory=True)

# Generation
inputs = tokenizer("Question: What fruits does David like? Answer:", return_tensors='pt', add_special_tokens=False).input_ids.cuda()
outputs = model.generate(input_ids=inputs, max_new_tokens=20)
response = tokenizer.decode(outputs[0][inputs.shape[1]:])  # decode only the newly generated tokens
print(response)
```
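Because the injected context must contain at least 16 tokens, it can help to validate its length before calling `model.inject_memory`. The helper below is a hypothetical sketch, not part of the MemoryLLM code: `check_injectable` and `MIN_INJECT_TOKENS` are illustrative names. It works on a plain list of token IDs (for example `tokenizer(ctx, add_special_tokens=False).input_ids`) and raises an error instead of letting a too-short context disturb the memory.

```python
from typing import Sequence

MIN_INJECT_TOKENS = 16  # hard minimum stated in this model card


def check_injectable(token_ids: Sequence[int], min_tokens: int = MIN_INJECT_TOKENS) -> Sequence[int]:
    """Hypothetical guard: return token_ids unchanged if long enough, else raise ValueError."""
    if len(token_ids) < min_tokens:
        raise ValueError(
            f"context has {len(token_ids)} tokens; at least {min_tokens} are required for memory injection"
        )
    return token_ids


# Usage with dummy token IDs (a real call would pass tokenizer output)
ok = check_injectable(list(range(20)))
print(len(ok))  # 20
try:
    check_injectable([1, 2, 3])
except ValueError as err:
    print(err)  # context has 3 tokens; at least 16 are required for memory injection
```

Only after the check passes would the IDs be converted to a tensor and passed to `model.inject_memory`.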