Note on adding new elements to the vocabulary
Hi Gemma team!
Thanks for releasing the weights and code for this model.
When adding new tokens to the tokenizer (and thus expanding the softmax matrix), gemma-2b shows a degenerate behavior in which it assigns (almost) all probability mass to the new tokens. So when finetuning with new tokens, the loss starts out huge, which can have an overall detrimental effect on finetuning.
This is jarring for users; see, e.g., this tweet: https://twitter.com/Teknium1/status/1760977919512649890.
This happened a few years ago with the OpenAI GPT-2 models, and I pinpointed the cause: for some models, the logits are all highly negative. I'm not sure why this sometimes happens. The softmax function is invariant to adding the same constant to every logit, so the offset doesn't affect the probabilities... but maybe it is the result of certain activations? It certainly gives the model different properties with respect to additive noise. When the logits are all negative and you randomly initialize a new word embedding, its dot product with the model's hidden states is very likely to be close to 0, so its logit is close to 0 in every context. Because exp(0) is much larger than exp of a large negative number, the new word dominates the softmax. Cool!
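To make the mechanism concrete, here is a minimal pure-Python sketch with made-up logits (the numbers are hypothetical, not measured from gemma-2b or GPT-2). It checks that shifting every logit by a constant leaves the softmax unchanged, and that appending one token with logit 0 to a set of strongly negative logits hands that token nearly all of the probability mass, which is also why the loss on an ordinary target token jumps.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtracting the max does not change the result."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical toy numbers: every existing logit is strongly negative.
existing_logits = [-20.0, -22.0, -25.0, -21.0]

# Adding the same constant to every logit leaves the probabilities unchanged.
p_orig = softmax(existing_logits)
p_shifted = softmax([x + 15.0 for x in existing_logits])

# A randomly initialized new embedding has a dot product near 0 with the
# hidden state, so its logit is roughly 0, far above all the others.
p_expanded = softmax(existing_logits + [0.0])
new_token_prob = p_expanded[-1]

# Cross-entropy on an existing target token (index 0) before and after expansion.
loss_before = -math.log(p_orig[0])
loss_after = -math.log(p_expanded[0])

print(f"new-token probability: {new_token_prob:.10f}")
print(f"loss on token 0: {loss_before:.3f} -> {loss_after:.3f}")
```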
Anyway, you can provably avoid this by initializing new embeddings to be the average of all existing embeddings. I show this here; it's a fun little exploration:
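A sketch of why averaging helps, again with hypothetical toy numbers: because a logit is linear in the embedding, a new row set to the mean of the existing output-embedding rows gets a logit equal to the mean of the existing logits. Since exp is convex, Jensen's inequality gives exp(mean(z)) <= mean(exp(z)), so the new token's probability is at most 1/(n+1), i.e. no more than a uniform distribution over the expanded vocabulary would give it. The code compares this against a small random initialization; the matrix sizes, scales, and seed are assumptions made for illustration.

```python
import math
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mean_embedding(embeddings):
    """Element-wise average of the existing embedding rows."""
    n = len(embeddings)
    return [sum(column) / n for column in zip(*embeddings)]

random.seed(0)
dim, vocab = 8, 50

# Hypothetical hidden state and output-embedding matrix, built so that every
# existing logit is roughly -20 (a small random term plus a fixed -20 offset).
hidden = [random.gauss(0.0, 1.0) for _ in range(dim)]
norm_sq = dot(hidden, hidden)
embeddings = [
    [random.gauss(0.0, 0.1) - 20.0 * h / norm_sq for h in hidden]
    for _ in range(vocab)
]
existing_logits = [dot(e, hidden) for e in embeddings]

# Option 1: small random initialization -> logit near 0.
random_row = [random.gauss(0.0, 0.02) for _ in range(dim)]
p_random = softmax(existing_logits + [dot(random_row, hidden)])[-1]

# Option 2: mean of the existing rows -> logit equals the mean existing logit.
mean_row = mean_embedding(embeddings)
mean_logit = dot(mean_row, hidden)
p_mean = softmax(existing_logits + [mean_logit])[-1]

print(f"random init: p_new = {p_random:.4f}")
print(f"mean init:   p_new = {p_mean:.4f}  (bound 1/(n+1) = {1 / (vocab + 1):.4f})")
```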
https://nlp.stanford.edu/~johnhew/vocab-expansion.html
It might be nice to add a note in the README about this, and/or override (and document) the default vocabulary expansion behavior, so that people get a smoother experience doing this.
I'm happy to implement either/both of these if there's interest.
Hey John, thanks for your tweet and for linking your work; it's a really lovely insight and I've shared it with the team as well :)
I wonder if this effect is exacerbated for us because the vocabulary size is 256K and exp(0) would squash all the other logits, and even more so if we use something like top-k (k=40) or top-p (p=0.9) sampling. If you have time, please send a PR adding this note to the README and I can take a look.
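On the sampling question, one piece that can be checked in isolation is that truncated sampling renormalizes over the kept tokens, so a retained token's share can only go up. The sketch below uses hypothetical logits over a 256K-entry vocabulary (existing logits between -10 and -11, one new token at 0) and simple top-k and top-p filters written for this example; they are not the filtering code of any particular library.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_filter(probs, k):
    """Keep the k most probable token ids and renormalize their probabilities."""
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}

def top_p_filter(probs, p):
    """Keep the smallest prefix of most probable ids whose mass reaches p, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, mass = [], 0.0
    for i in order:
        keep.append(i)
        mass += probs[i]
        if mass >= p:
            break
    return {i: probs[i] / mass for i in keep}

# Hypothetical logits: 256K existing tokens between -10.0 and -10.99,
# plus one newly added token whose logit is 0.
vocab_size = 256_000
logits = [-10.0 - (i % 100) * 0.01 for i in range(vocab_size)] + [0.0]
new_id = vocab_size

probs = softmax(logits)
p_new_full = probs[new_id]
top_k_probs = top_k_filter(probs, 40)
top_p_probs = top_p_filter(probs, 0.9)

print(f"full softmax:  {p_new_full:.3f}")
print(f"top-k (k=40):  {top_k_probs[new_id]:.3f}")
print(f"top-p (p=0.9): {top_p_probs[new_id]:.3f}")
```

With these made-up numbers the new token already leads under the full softmax, and top-k with a small k pushes it close to certainty because only 39 competitors survive the cut.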
@johnhew thank you for this explanation!
> Anyway, you can provably avoid this by initializing new embeddings to be the average of all existing embeddings.
Like this? https://github.com/artidoro/qlora/blob/7f4e95a68dc076bea9b3a413d2b512eca6d004e5/qlora.py#L425