Commit 50bab77 by sjiang1 · 1 Parent(s): be0ace7

Add a more detailed version of the inference

Add a small code snippet and more information about what is needed

Files changed (1)
  1. README.md +26 -2
README.md CHANGED
@@ -10,6 +10,30 @@ datasets:
  ---

  # Model Card for CodeCSE
- A simple pre-trained model for code and comment sentence embeddings using contrastive learning.

- Instructions for inference can be found at: https://github.com/emu-se/CodeCSE
+ A simple pre-trained model for code and comment sentence embeddings using contrastive learning. This model was pre-trained on [CodeSearchNet](https://huggingface.co/datasets/code_search_net).

+ Please [**clone the CodeCSE repository**](https://github.com/emu-se/CodeCSE) to get `GraphCodeBERTForCL` and the other dependencies needed to use this pre-trained model.
+
+ Detailed instructions are listed in the repository's README.md. Overall, you will need:
+
+ 1. GraphCodeBERT (CodeCSE uses GraphCodeBERT's input format for code)
+ 2. `GraphCodeBERTForCL`, defined in [codecse/codecse](https://github.com/emu-se/CodeCSE/tree/main/codecse/codecse)
+
+ ## Inference example
+ NL input example: `example_nl.json`
+ ```json
+ {
+     "original_string": "",
+     "docstring_tokens": ["Save", "model", "to", "a", "pickle", "located", "at", "path"],
+     "url": "https://github.com/openai/baselines/blob/3301089b48c42b87b396e246ea3f56fa4bfc9678/baselines/deepq/deepq.py#L55-L72"
+ }
+ ```
+
+ Code snippet to get the embedding of an NL document (see the [complete code](https://github.com/emu-se/CodeCSE/blob/a04a025c7048204bdfd908fe259fafc55e2df169/inference.py#L105) for how `tokenizer`, `args`, and `model` are set up):
+ ```python
+ nl_json = load_example("example_nl.json")
+ batch = prepare_inputs(nl_json, tokenizer, args)
+ nl_inputs = batch[3]  # the NL inputs are taken from index 3 of the prepared batch
+ with torch.no_grad():  # inference only, so gradient tracking is disabled
+     nl_vec = model(input_ids=nl_inputs, sent_emb="nl")[1]  # element 1 of the output is used as the NL embedding
+ ```
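
Once an embedding such as `nl_vec` has been computed, a common next step is to compare it against embeddings of code snippets. The sketch below is not taken from the CodeCSE repository. It assumes each embedding has already been converted to a plain Python list of floats (for example with `tensor.tolist()` on a single row), and the helper names `cosine_similarity` and `rank_by_similarity` are hypothetical, introduced here only for illustration.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def rank_by_similarity(
    query: Sequence[float], candidates: Sequence[Sequence[float]]
) -> List[Tuple[int, float]]:
    """Return (candidate index, score) pairs, most similar candidate first."""
    scored = [(i, cosine_similarity(query, c)) for i, c in enumerate(candidates)]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)
```

With an NL embedding as `query` and a list of code embeddings as `candidates`, the first pair returned points at the code snippet whose embedding is closest to the query under cosine similarity.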