---
language:
- en
thumbnail: "https://huggingface.co./whut-zhangwx/SimpleGPT"
tags:
- gpt
license: "mit"
datasets:
- tinyshakespeare
base_model: "whut-zhangwx/SimpleGPT"
---

## Introduction

This is a pre-trained checkpoint for [SimpleGPT](https://github.com/whut-zhangwx/SimpleGPT). It was trained on [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) with the following hyperparameters:

```shell
n_layer: 12
n_head: 12
embed_dim: 768
time_step: 256
bias: False
vocab_size: 65
dropout: 0.0
iter_num: 50000
```

## File Content

`ckpt_iter_50000.pt` contains six items:

```python
checkpoint = {
    'state_dict': raw_model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'model_args': model_args,
    'iter_num': iter_num,
    'best_val_loss': best_val_loss,
    'config': config,
}
```

Use this short script to inspect them:

```python
import os
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ckpt_path = "path/to/ckpt_iter_50000.pt"
assert os.path.exists(ckpt_path), f"{ckpt_path} doesn't exist."

# Load the checkpoint onto the selected device
checkpoint = torch.load(ckpt_path, map_location=device)

# Print the model hyperparameters stored in the checkpoint
model_args = checkpoint['model_args']
print(model_args)

# Print the name and shape of every parameter tensor
state_dict = checkpoint['state_dict']
for layer_name, weight_matrix in state_dict.items():
    print(f"{layer_name}\t{weight_matrix.shape}")
```

## Usage

Clone the repository [SimpleGPT | whut-zhangwx](https://github.com/whut-zhangwx/SimpleGPT), then follow the script [generate.py](https://github.com/whut-zhangwx/SimpleGPT/blob/master/generate.py) to load the checkpoint into the GPT model and generate text.
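
For orientation, a minimal generation sketch is shown below. It is not the repository's `generate.py`: the module path `model`, the class names `GPTConfig` and `GPT`, the `generate(idx, max_new_tokens)` method, and the character-level vocabulary rebuilt from the tinyshakespeare text are assumptions modeled on nanoGPT-style code; consult `generate.py` for the actual interface.

```python
# Hypothetical sketch: GPTConfig, GPT, and generate() are assumed,
# nanoGPT-style names; check generate.py for the real interface.
import torch
from model import GPTConfig, GPT  # assumed module layout of SimpleGPT

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the model from the hyperparameters stored in the checkpoint
checkpoint = torch.load("path/to/ckpt_iter_50000.pt", map_location=device)
model = GPT(GPTConfig(**checkpoint['model_args']))
model.load_state_dict(checkpoint['state_dict'])
model.to(device)
model.eval()

# Character-level vocabulary (vocab_size = 65) rebuilt from the
# tinyshakespeare text; the real script may store this mapping differently.
with open("path/to/input.txt", "r") as f:
    chars = sorted(set(f.read()))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

# Encode a prompt, sample new tokens, and decode back to text
prompt = "ROMEO:"
idx = torch.tensor([[stoi[c] for c in prompt]], dtype=torch.long, device=device)
with torch.no_grad():
    out = model.generate(idx, max_new_tokens=500)
print("".join(itos[int(i)] for i in out[0].tolist()))
```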