geekyrakshit committed
Commit 5706850 · 1 Parent(s): 780c9f0

update: LlamaGuardFineTuner.train

guardrails_genie/train/llama_guard.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import shutil
 
 import plotly.graph_objects as go
 import streamlit as st
@@ -208,7 +209,15 @@ class LlamaGuardFineTuner:
         )
         return encodings.input_ids, encodings.attention_mask, labels
 
-    def train(self, batch_size: int = 32, lr: float = 5e-6, num_classes: int = 2):
+    def train(
+        self,
+        batch_size: int = 32,
+        lr: float = 5e-6,
+        num_classes: int = 2,
+        log_interval: int = 20,
+        save_interval: int = 1000,
+    ):
+        os.makedirs("checkpoints", exist_ok=True)
         wandb.init(
             project=self.wandb_project,
             entity=self.wandb_entity,
@@ -239,14 +248,16 @@ class LlamaGuardFineTuner:
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-            wandb.log({"loss": loss.item()})
+            if (i + 1) % log_interval == 0:
+                wandb.log({"loss": loss.item()}, step=i + 1)
             if progress_bar:
                 progress_percentage = (i + 1) * 100 // len(data_loader)
                 progress_bar.progress(
                     progress_percentage,
                     text=f"Training batch {i + 1}/{len(data_loader)}, Loss: {loss.item()}",
                 )
-        save_model(self.model, f"{self.model_name}-{self.dataset_name}.safetensors")
-        wandb.log_model(f"{self.model_name}-{self.dataset_name}.safetensors")
+            if (i + 1) % save_interval == 0:
+                save_model(self.model, f"checkpoints/model-{i + 1}.safetensors")
+                wandb.log_model(f"checkpoints/model-{i + 1}.safetensors")
         wandb.finish()
-        os.remove(f"{self.model_name}-{self.dataset_name}.safetensors")
+        shutil.rmtree("checkpoints")
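For reference, the interval-based logging and checkpointing pattern introduced by this commit can be reduced to the standalone sketch below. The model, optimizer, data loader, and W&B project name are placeholders and not part of the LlamaGuardFineTuner API; only the periodic wandb.log call, the save_model/wandb.log_model checkpointing, and the checkpoints/ cleanup mirror the diff above.

    # Minimal sketch of the pattern in this commit; placeholder model/data, not the real fine-tuner.
    import os
    import shutil

    import torch
    import torch.nn as nn
    import wandb
    from safetensors.torch import save_model

    model = nn.Linear(16, 2)  # stand-in for the fine-tuned classifier head
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
    data_loader = [(torch.randn(4, 16), torch.randint(0, 2, (4,))) for _ in range(100)]

    log_interval, save_interval = 20, 50
    os.makedirs("checkpoints", exist_ok=True)  # periodic checkpoints are written here
    wandb.init(project="placeholder-project")

    for i, (inputs, labels) in enumerate(data_loader):
        loss = nn.functional.cross_entropy(model(inputs), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % log_interval == 0:
            # log the loss every `log_interval` batches instead of every step
            wandb.log({"loss": loss.item()}, step=i + 1)
        if (i + 1) % save_interval == 0:
            # save a checkpoint every `save_interval` batches and log it as a W&B model artifact
            save_model(model, f"checkpoints/model-{i + 1}.safetensors")
            wandb.log_model(f"checkpoints/model-{i + 1}.safetensors")

    wandb.finish()
    shutil.rmtree("checkpoints")  # local copies are no longer needed once logged to W&B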