---
license: cc-by-sa-4.0
datasets:
- bigcode/the-stack-dedup
---


# replit-code-v1-3b 

`replit-code-v1-3b` is a 2.7B-parameter model trained on the Stack Dedup v1.2 dataset.



## Model


```python
from transformers import AutoModelForCausalLM

# load model
model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
```

To use the optimized Triton implementation of FlashAttention on GPUs with BF16 precision, move the model to `bfloat16` and use it as follows:

```python
import torch
from transformers import AutoModelForCausalLM

# load model
model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True, attn_impl='triton')
model.to(device='cuda:0', dtype=torch.bfloat16)

# forward pass (token ids stay integer-typed; only the model weights are cast to bfloat16)
x = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
x = x.to(device='cuda:0')
y = model(x)
```

Note that `trust_remote_code=True` is passed to the `from_pretrained` method because ReplitLM is not a class in the
[Transformers](https://huggingface.co./docs/transformers/index) library. 

## Tokenizer

We trained a custom SentencePiece Unigram tokenizer with a 32768-token vocabulary optimized specifically for code.

Note that using this requires the `sentencepiece` library to be installed. 

The tokenizer can be used as follows:

```python
from transformers import AutoTokenizer

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)

# single input encoding + generation
x = tokenizer.encode('def hello():\n  print("hello world")\n', return_tensors='pt')
y = model.generate(x)

# decoding, clean_up_tokenization_spaces=False to ensure syntactical correctness
generated_code = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(generated_code)
```

Note that: 
- `trust_remote_code=True` is passed to the `from_pretrained` method because ReplitLM is not a class in the [Transformers](https://huggingface.co./docs/transformers/index) library. 
- `clean_up_tokenization_spaces=False` is meant to avoid removing spaces in the output, because that would affect the syntactical correctness of the generated code. 
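To illustrate why the cleanup step can be harmful for code, here is a minimal sketch that imitates it in plain Python. The function `approx_cleanup` is hypothetical and only approximates the kind of natural-language spacing fixes applied by `transformers` when `clean_up_tokenization_spaces=True`; the exact replacement rules are an assumption and may differ between library versions.

```python
def approx_cleanup(text: str) -> str:
    """Hypothetical approximation of natural-language spacing cleanup.

    Assumption: the cleanup removes the space before common punctuation
    marks and English contractions. It is not the library's own function.
    """
    replacements = [
        (" .", "."), (" ?", "?"), (" !", "!"), (" ,", ","),
        (" ' ", "'"), (" n't", "n't"), (" 'm", "'m"),
        (" 's", "'s"), (" 've", "'ve"), (" 're", "'re"),
    ]
    for old, new in replacements:
        text = text.replace(old, new)
    return text


# A string literal in generated code that contains a space before a comma
generated = 'sep = " , "'
cleaned = approx_cleanup(generated)

print(generated)  # original generated line
print(cleaned)    # the string literal has silently changed
```

Under these assumed rules, the separator literal loses its leading space, so the cleaned code behaves differently from what the model produced. Passing `clean_up_tokenization_spaces=False` avoids this class of problem.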


## Generation

You can generate code using the `transformers` library as follows:

```python
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
model = transformers.AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)

x = tokenizer.encode('def fibonacci(n): ', return_tensors='pt')
y = model.generate(x, max_length=100, do_sample=True, top_p=0.95, top_k=4, temperature=0.2, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)

# decoding, clean_up_tokenization_spaces=False to ensure syntactical correctness
generated_code = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(generated_code)
```

Experiment with different decoding methods and parameters to get the best results for your use case.

## Post Processing

As with all code generation models, post-processing the generated code is important. In particular, the following post-processing steps are recommended:
- stop generation when the EOS token is encountered
- remove trailing whitespace
- set `max_tokens` to a reasonable value based on your completion use case
- truncate the generation at stop words such as `return`, `def`, "```", and `\n\n\n` to avoid emitting incomplete code when `max_tokens` is larger than the length of the expected generated code.
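The truncation and whitespace steps above could be sketched roughly as follows. This is a minimal illustration, not the authors' implementation: the helper name `postprocess` and its default stop-word list are assumptions, and it expects only the newly generated completion (without the prompt). Add `return` or other stop words to the list if they suit your use case.

```python
from typing import Sequence


def postprocess(
    completion: str,
    stop_words: Sequence[str] = ("\ndef ", "```", "\n\n\n"),  # assumed defaults
) -> str:
    """Cut a completion at the earliest stop word and strip trailing whitespace."""
    cut = len(completion)
    for stop in stop_words:
        idx = completion.find(stop)
        if idx != -1:
            cut = min(cut, idx)  # keep the earliest match across all stop words
    truncated = completion[:cut]

    # Remove trailing whitespace on every line, then at the end of the text
    lines = [line.rstrip() for line in truncated.split("\n")]
    return "\n".join(lines).rstrip()


raw = "    return a + b   \n\n\n\ndef other():\n    pass"
print(repr(postprocess(raw)))
```

Stopping at the EOS token and bounding the output length are handled at generation time (for example through the `eos_token_id` and `max_length` arguments shown in the Generation section), so this helper covers only the text-level steps.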

## Inference
Coming soon.

## Evaluation
Coming soon.

## Model Hash
5bc28ce32c6f9aec935ead7b60ea1c46