metadata
license: mit
datasets:
  - codeparrot/codeparrot-clean
tags:
  - text-generation
  - code-generation
  - gpt2-large
widget:
  - text: 'def hello_world():'
    example_title: Code Generation Example 1
  - text: 'def get_files_size(filename):'
    example_title: Code Generation Example 2
inference:
  parameters:
    max_new_tokens: 30
    temperature: 0.5
    num_return_sequences: 1

Code Generation using GPT2-Large

This is a GPT2-large model fine-tuned on the CodeParrot clean dataset (codeparrot/codeparrot-clean) with a custom metric focused on code generation.
The tokenizer, initialized from the GPT2-large tokenizer, was also retrained on the same dataset so that its tokenization is better suited to code.
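
The tokenizer retraining mentioned above can be sketched roughly as follows. This is a minimal illustration, not the author's actual training script. It assumes the `transformers` and `datasets` packages are installed, that the source code sits in the `content` field of `codeparrot/codeparrot-clean`, and that a subset of the data and the base vocabulary size are acceptable choices. The helper names `batch_texts` and `retrain_tokenizer` are hypothetical.

```python
from itertools import islice
from typing import Dict, Iterable, Iterator, List


def batch_texts(records: Iterable[Dict[str, str]],
                batch_size: int = 1000,
                text_field: str = "content") -> Iterator[List[str]]:
    """Yield lists of source-code strings, batch_size records at a time."""
    iterator = iter(records)
    while True:
        batch = [record[text_field] for record in islice(iterator, batch_size)]
        if not batch:
            return
        yield batch


def retrain_tokenizer(max_records: int = 100_000):
    """Retrain the GPT2-large tokenizer on code (illustrative settings only)."""
    # Imported lazily so batch_texts can be used without these packages.
    from datasets import load_dataset
    from transformers import AutoTokenizer

    base_tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
    # Streaming avoids downloading the whole dataset up front.
    dataset = load_dataset("codeparrot/codeparrot-clean",
                           split="train", streaming=True)
    # Assumption: a subset of the records is enough for the tokenizer.
    sample = islice(dataset, max_records)
    # Assumption: keep the same vocabulary size as the base tokenizer.
    return base_tokenizer.train_new_from_iterator(
        batch_texts(sample), vocab_size=len(base_tokenizer))
```

Calling `retrain_tokenizer()` downloads the base tokenizer and streams the dataset, so it needs network access. The returned tokenizer can be stored with `save_pretrained`.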

Model description

This model has the same architecture and parameter count as the GPT2-large model. Please refer to the GPT2-large model card for more details.

Intended Use & Limitations

This model is intended to generate the code of a function from a short prompt, such as a function signature or a brief description of the required output.

Note: The model is trained primarily for code generation.
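
One way to pass a "small description" to the model is to place it as a docstring under the function signature. The helper below is a hypothetical sketch of that idea. `build_prompt` is not part of this model or of any library, and whether docstring prompts improve the generated code is an assumption that has not been verified here.

```python
def build_prompt(signature: str, description: str = "") -> str:
    """Return a prompt made of a function signature and an optional docstring."""
    prompt = signature.rstrip()
    # Make sure the signature ends with the colon that opens the function body.
    if not prompt.endswith(":"):
        prompt += ":"
    if description:
        # Put the description where a docstring would normally go.
        prompt += f'\n    """{description.strip()}"""\n'
    return prompt


print(build_prompt("def get_files_size(filename)",
                   "Return the size of the file in bytes."))
```

The resulting string can be passed to the tokenizer in place of a bare signature such as `def hello_world():`.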

Usage

You can use this model directly to generate code:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Select the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the code generator LLM and its tokenizer from the checkpoint
checkpoint = "DeathReaper0965/gpt2_large_code_generator"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# Tokenize the prompt and move the tensors to the same device as the model
inputs = tokenizer("def hello_world():", return_tensors="pt").to(device)

# temperature only takes effect when sampling is enabled (do_sample=True)
outputs = model.generate(**inputs,
                         max_new_tokens=30,
                         do_sample=True,
                         temperature=0.5,
                         num_return_sequences=1)

print(tokenizer.batch_decode(outputs)[0])

###########EXAMPLE OUTPUT (sampling is random, so results may vary)###########
def hello_world():
    return "Hello World!"

@app.route("/hello_world")
def hello_world():
    return "Hello World!"
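
The output above keeps going after the first function is complete, because generation only stops once the token limit is reached. If only the first function is needed, the text can be trimmed afterwards. The helper below is a hypothetical, heuristic sketch: it cuts at the first unindented line after the opening line, so it would also cut early on multi-line strings that contain unindented text.

```python
def truncate_to_first_function(generated: str) -> str:
    """Keep the first top-level block and drop whatever follows it."""
    lines = generated.splitlines()
    kept = lines[:1]
    for line in lines[1:]:
        # A non-blank line with no indentation starts a new top-level statement.
        if line.strip() and not line[0].isspace():
            break
        kept.append(line)
    # Remove blank lines left over just before the cut.
    while kept and not kept[-1].strip():
        kept.pop()
    return "\n".join(kept)


sample = (
    'def hello_world():\n'
    '    return "Hello World!"\n'
    '\n'
    '@app.route("/hello_world")\n'
    'def hello_world():\n'
    '    return "Hello World!"'
)
print(truncate_to_first_function(sample))
# prints:
# def hello_world():
#     return "Hello World!"
```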

Designed and Developed by Praneet | LinkedIn | GitHub