JAX weights converted from the PyTorch checkpoint at facebook/galactica-30b.

(env) ubuntu@vm:~$ JAX_PLATFORM_NAME=cpu python3
>>> import jax
>>> print(jax.devices())   
[CpuDevice(id=0)]  # Ensure that model weights are loaded into CPU RAM, not accelerator memory.
>>> from transformers import FlaxOPTForCausalLM
>>> model = FlaxOPTForCausalLM.from_pretrained("facebook/galactica-30b", from_pt=True)
>>> model.push_to_hub(hf_model_repo)
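The same conversion can be rehearsed on a much smaller checkpoint before committing the host RAM a 30B model needs. This is a minimal sketch, not the exact script used for this repo. It assumes `facebook/galactica-125m` (same OPT architecture, far fewer parameters), a `transformers` release that still ships the Flax model classes, and `torch` installed so `from_pt=True` works. It saves locally with `save_pretrained` instead of calling `push_to_hub`. The helper names `convert_pt_to_flax` and `max_logit_diff` are hypothetical, and `JAX_PLATFORMS` is used on the assumption that your JAX version reads it alongside the older `JAX_PLATFORM_NAME` used above.

```python
import os

# Restrict JAX to the CPU backend before jax is imported, so weights are held in
# host RAM (same intent as JAX_PLATFORM_NAME=cpu above; assumes this JAX version
# reads the JAX_PLATFORMS variable).
os.environ.setdefault("JAX_PLATFORMS", "cpu")


def convert_pt_to_flax(model_id: str, out_dir: str) -> None:
    """Hypothetical helper: load a PyTorch OPT checkpoint as Flax on CPU and save it locally."""
    import jax  # imported lazily so the environment variable above takes effect first
    from transformers import FlaxOPTForCausalLM

    # Fail early if an accelerator is visible, mirroring the jax.devices() check above.
    devices = jax.devices()
    if any(d.platform != "cpu" for d in devices):
        raise RuntimeError(f"expected CPU-only JAX, got {devices}")

    # from_pt=True converts the PyTorch weights into Flax params (requires torch).
    model = FlaxOPTForCausalLM.from_pretrained(model_id, from_pt=True)
    # Writes the config and the Flax weights (flax_model.msgpack) into out_dir.
    model.save_pretrained(out_dir)


def max_logit_diff(model_id: str, flax_dir: str, text: str = "The Transformer architecture") -> float:
    """Hypothetical helper: largest absolute logit gap between the PyTorch and converted Flax models."""
    import numpy as np
    import torch
    from transformers import AutoTokenizer, FlaxOPTForCausalLM, OPTForCausalLM

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Pass input_ids only; OPT models do not accept token_type_ids.
    ids = tokenizer(text, return_tensors="np").input_ids

    pt_model = OPTForCausalLM.from_pretrained(model_id).eval()
    fx_model = FlaxOPTForCausalLM.from_pretrained(flax_dir)

    with torch.no_grad():
        pt_logits = pt_model(torch.from_numpy(ids)).logits.float().numpy()
    fx_logits = np.asarray(fx_model(ids).logits, dtype=np.float32)
    return float(np.abs(pt_logits - fx_logits).max())
```

For example, run `convert_pt_to_flax("facebook/galactica-125m", "galactica-125m-flax")` and then `max_logit_diff("facebook/galactica-125m", "galactica-125m-flax")`, where `galactica-125m-flax` is a placeholder directory name. A small nonzero value is typical from floating-point differences between frameworks; a large one would suggest the conversion went wrong.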

Citation and Attribution

The citation from the original repo is reproduced below, in accordance with the cc-by-nc-4.0 license.

@inproceedings{GALACTICA,
    title={GALACTICA: A Large Language Model for Science},
    author={Ross Taylor and Marcin Kardas and Guillem Cucurull and Thomas Scialom and Anthony Hartshorn and Elvis Saravia and Andrew Poulton and Viktor Kerkez and Robert Stojnic},
    year={2022}
}

Research supported with Cloud TPUs from Google's TPU Research Cloud (TRC).
