Training diverges when used with Llama 2 70B and 4-bit QLoRA
Posted the issue here, but happy to discuss further if anyone can help. The divergence happens after roughly 20 steps (about six hours of training). Thanks
Hi @alyssavance, have you read this? https://huggingface.co./togethercomputer/LLaMA-2-7B-32K/discussions/2
Since you are doing QLoRA, you might need to set trust_remote_code=False so that HF's Llama implementation is used; the flash attention code only works with float16 or bfloat16.
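For reference, a minimal sketch of what that could look like when loading the model for QLoRA (the quantization settings here are my assumptions, not values from your setup):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "togethercomputer/LLaMA-2-7B-32K"

# Typical 4-bit NF4 QLoRA quantization config with bfloat16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    trust_remote_code=False,  # fall back to HF's built-in Llama implementation
    device_map="auto",
)
```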
@gardner I did. I had some type problems, but fixed them by removing the JIT decorator from rmsnorm. Right now it runs with no type errors and does inference fine; it just gradually diverges after the first few dozen steps.
Hi @alyssavance, did you try a smaller learning rate? Instead of 1e-4, it might be worth trying 2e-5 (the same as in the linear interpolation paper).
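If it helps, here is a minimal sketch of where that change would land in a standard Transformers training setup; every argument other than learning_rate is just a placeholder, not something from your run:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="llama2-70b-qlora",     # placeholder output path
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-5,                # instead of 1e-4
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    bf16=True,
    logging_steps=10,
)
```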