Commit 5706850
Parent: 780c9f0
update: LlamaGuardFineTuner.train
guardrails_genie/train/llama_guard.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import shutil
 
 import plotly.graph_objects as go
 import streamlit as st
@@ -208,7 +209,15 @@ class LlamaGuardFineTuner:
         )
         return encodings.input_ids, encodings.attention_mask, labels
 
-    def train(
+    def train(
+        self,
+        batch_size: int = 32,
+        lr: float = 5e-6,
+        num_classes: int = 2,
+        log_interval: int = 20,
+        save_interval: int = 1000,
+    ):
+        os.makedirs("checkpoints", exist_ok=True)
         wandb.init(
             project=self.wandb_project,
             entity=self.wandb_entity,
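The rewritten signature promotes the training hyperparameters to keyword arguments. A minimal call-site sketch, assuming an already-constructed LlamaGuardFineTuner whose dataset has been loaded beforehand; the constructor arguments and the loading call below are assumptions for illustration, not part of this diff:

    fine_tuner = LlamaGuardFineTuner(
        wandb_project="guardrails-genie",  # assumed constructor kwargs
        wandb_entity="my-team",
    )
    fine_tuner.load_dataset(...)  # hypothetical loading step; name not from this diff
    fine_tuner.train(
        batch_size=32,       # sequences per optimizer step
        lr=5e-6,             # optimizer learning rate
        num_classes=2,       # e.g. a safe/unsafe classification head
        log_interval=20,     # push loss to W&B every 20 batches
        save_interval=1000,  # checkpoint weights every 1000 batches
    )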
@@ -239,14 +248,16 @@ class LlamaGuardFineTuner:
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-
+            if (i + 1) % log_interval == 0:
+                wandb.log({"loss": loss.item()}, step=i + 1)
             if progress_bar:
                 progress_percentage = (i + 1) * 100 // len(data_loader)
                 progress_bar.progress(
                     progress_percentage,
                     text=f"Training batch {i + 1}/{len(data_loader)}, Loss: {loss.item()}",
                 )
-
-
+            if (i + 1) % save_interval == 0:
+                save_model(self.model, f"checkpoints/model-{i + 1}.safetensors")
+                wandb.log_model(f"checkpoints/model-{i + 1}.safetensors")
         wandb.finish()
-
+        shutil.rmtree("checkpoints")
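Both new guards follow the standard step-gated pattern: (i + 1) % log_interval == 0 sends the batch loss to W&B on a fixed cadence, and (i + 1) % save_interval == 0 serializes the weights with safetensors' save_model and registers the file as a W&B model artifact via wandb.log_model. A self-contained sketch of the same pattern outside the class; the model, data, and intervals are stand-ins:

    import os

    import torch
    import torch.nn as nn
    import wandb
    from safetensors.torch import save_model

    model = nn.Linear(8, 2)  # stand-in for the fine-tuned classifier
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
    loss_fn = nn.CrossEntropyLoss()
    data = [(torch.randn(4, 8), torch.randint(0, 2, (4,))) for _ in range(100)]

    log_interval, save_interval = 20, 50
    os.makedirs("checkpoints", exist_ok=True)
    wandb.init(project="demo", mode="offline")  # offline writes locally; `wandb sync` uploads later
    for i, (x, y) in enumerate(data):
        loss = loss_fn(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % log_interval == 0:
            wandb.log({"loss": loss.item()}, step=i + 1)
        if (i + 1) % save_interval == 0:
            path = f"checkpoints/model-{i + 1}.safetensors"
            save_model(model, path)  # safetensors stores raw tensors, no pickle
            wandb.log_model(path)    # uploads the checkpoint as a model artifact

    wandb.finish()

One detail worth noting: wandb.log(..., step=i + 1) uses the batch index as the global step, so the step counter would restart from zero if train were ever called twice within one run.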
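The checkpoint directory is bracketed by the run: created with exist_ok=True before wandb.init and removed after wandb.finish(), so local copies disappear once W&B holds the uploaded artifacts. One caveat: if an exception escapes the training loop, the shutil.rmtree line never runs. A hedged sketch of the same lifecycle with unconditional cleanup (a variant, not what this commit does):

    import os
    import shutil

    os.makedirs("checkpoints", exist_ok=True)
    try:
        ...  # training loop; checkpoints land under checkpoints/
    finally:
        shutil.rmtree("checkpoints", ignore_errors=True)  # runs even if training raises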