togepi55 commited on
Commit
ccbb8f0
·
verified ·
1 Parent(s): 03db82b

Upload README.md

Browse files
Files changed (1) hide show
  1. README.md +75 -0
README.md CHANGED
@@ -106,3 +106,78 @@ RLHF,DPOを実施していないため不適切な表現が出力される可
106
  * weight_decay=0.01
107
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  * weight_decay=0.01
107
 
108
 
109
+
110
+
111
+
112
+ # tasks-100-tv.jsonlでの出力方法
113
+
114
+ ~~~
115
+ import torch
116
+ from transformers import (
117
+ AutoTokenizer,
118
+ AutoModelForCausalLM,
119
+ BitsAndBytesConfig,
120
+ )
121
+ from peft import LoraConfig, PeftModel
122
+ from datasets import load_dataset
123
+
124
+
125
+ BASE_MODEL = "llm-jp/llm-jp-3-13b"
126
+ PEFT_MODEL = "togepi55/llm-jp-3-13b-it"
127
+
128
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
129
+ bnb_config = BitsAndBytesConfig(
130
+ load_in_4bit=True,
131
+ bnb_4bit_compute_dtype=torch.float16,
132
+ bnb_4bit_quant_type="nf4",
133
+ bnb_4bit_use_double_quant=False,
134
+ )
135
+
136
+ base_model = AutoModelForCausalLM.from_pretrained(
137
+ BASE_MODEL,
138
+ device_map="auto",
139
+ quantization_config=bnb_config,
140
+ torch_dtype="auto",
141
+ trust_remote_code=True,
142
+ )
143
+
144
+ model = PeftModel.from_pretrained(base_model, PEFT_MODEL)
145
+
146
+ # elyza-tasks-100-TV_0.jsonl データの読み込み
147
+ from datasets import load_dataset
148
+
149
+ dataset = load_dataset("json", data_files="./elyza-tasks-100-TV_0.jsonl", split="train")
150
+
151
+
152
+ results = []
153
+
154
+ for num in tqdm(range(100)):
155
+ instruction = dataset["input"][num]
156
+
157
+ prompt = f"<s>以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい\n\n### 指示:\n{instruction}\n\n### 応答:\n"
158
+
159
+ model_input = tokenizer(prompt, return_tensors="pt").to(model.device)
160
+ input_ids = model_input["input_ids"]
161
+
162
+ with torch.no_grad():
163
+ outputs = model.generate(
164
+ input_ids,
165
+ max_new_tokens=300,
166
+ attention_mask = model_input.attention_mask,
167
+ pad_token_id=tokenizer.pad_token_id,
168
+ eos_token_id=tokenizer.eos_token_id,
169
+ do_sample=False,
170
+ repetition_penalty=1.02,
171
+ )[0]
172
+ output = tokenizer.decode(outputs[input_ids.size(1):], skip_special_tokens=True)
173
+ results.append({"task_id": num, "input": instruction, "output": output})
174
+
175
+
176
+
177
+ # 保存する場合
178
+ import json
179
+ with open("output.jsonl", "wt", encoding='utf-8') as f:
180
+ for result in results:
181
+ json.dump(result, f, ensure_ascii=False)
182
+ f.write('\n')
183
+ ~~~