Nondzu committed
Commit 485f9fa
1 Parent(s): 5a03b1a

Upload run-humaneval.py

Files changed (1)
1. run-humaneval.py +70 -0
run-humaneval.py ADDED
@@ -0,0 +1,70 @@
+ import torch
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ from transformers import AutoTokenizer, LlamaForCausalLM
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from evalplus.data import get_human_eval_plus, write_jsonl
+ import os
+ from tqdm import tqdm
+
+ def setup(rank, world_size):
+     # Initialize the default process group. gloo works for this
+     # single-node setup; nccl is the usual backend for GPU workloads.
+     os.environ['MASTER_ADDR'] = 'localhost'
+     os.environ['MASTER_PORT'] = '12355'
+     dist.init_process_group("gloo", rank=rank, world_size=world_size)
+
+ def cleanup():
+     dist.destroy_process_group()
+
+ def generate_one_completion(ddp_model, tokenizer, prompt: str):
+     tokenizer.pad_token = tokenizer.eos_token
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
+
+     # Generate on whichever device holds the wrapped model, so the code
+     # stays correct when world_size > 1 (not just on cuda:0).
+     device = next(ddp_model.module.parameters()).device
+     generate_ids = ddp_model.module.generate(
+         inputs.input_ids.to(device),
+         max_new_tokens=384,
+         do_sample=True,
+         top_p=0.75,
+         top_k=40,
+         temperature=0.1,
+         pad_token_id=tokenizer.eos_token_id,
+     )
+     completion = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     # Drop the echoed prompt and keep only the first generated block.
+     completion = completion.replace(prompt, "").split("\n\n\n")[0]
+
+     print("-------------------")
+     print(completion)
+     return completion
+
+ def run(rank, world_size):
+     setup(rank, world_size)
+
+     model_path = "Nondzu/Mistral-7B-codealpaca-lora"
+     model = LlamaForCausalLM.from_pretrained(model_path, load_in_8bit=True)
+     ddp_model = DDP(model, device_ids=[rank])
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+     problems = get_human_eval_plus()
+     num_samples_per_task = 1
+
+     # One completion per HumanEval+ task, with a tqdm progress bar.
+     samples = [
+         dict(task_id=task_id, completion=generate_one_completion(ddp_model, tokenizer, problems[task_id]["prompt"]))
+         for task_id in tqdm(problems)
+         for _ in range(num_samples_per_task)
+     ]
+     write_jsonl(f"samples-Nondzu-Mistral-7B-codealpaca-lora-rank{rank}.jsonl", samples)
+
+     cleanup()
+
+ def main():
+     world_size = 1
+     mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
+
+ if __name__ == "__main__":
+     main()
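The script writes one JSONL file per rank. To score it, the samples can be passed to the evalplus evaluator; assuming the standard evalplus CLI (flag names per the evalplus docs, so worth verifying against the installed version), a minimal invocation would be:

evalplus.evaluate --dataset humaneval --samples samples-Nondzu-Mistral-7B-codealpaca-lora-rank0.jsonl

This reports pass@k against both the original HumanEval tests and the extended HumanEval+ test set.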