tianleliphoebe commited on
Commit
97b8fe0
·
1 Parent(s): 3167788

update NSFW

Browse files
Files changed (1) hide show
  1. model/model_manager.py +6 -6
model/model_manager.py CHANGED
@@ -30,7 +30,7 @@ class ModelManager:
30
  return pipe
31
 
32
  def load_guard(self):
33
- model_id = "meta-llama/Meta-Llama-Guard-2-8B"
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
  dtype = torch.bfloat16
36
  self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_GUARD'])
@@ -48,7 +48,7 @@ class ModelManager:
48
 
49
  @spaces.GPU(duration=120)
50
  def generate_image_ig(self, prompt, model_name):
51
- if self.NSFW_filter(prompt) == 'safe':
52
  print('The prompt is safe')
53
  pipe = self.load_model_pipe(model_name)
54
  result = pipe(prompt=prompt)
@@ -57,7 +57,7 @@ class ModelManager:
57
  return result
58
 
59
  def generate_image_ig_api(self, prompt, model_name):
60
- if self.NSFW_filter(prompt) == 'safe':
61
  print('The prompt is safe')
62
  pipe = self.load_model_pipe(model_name)
63
  result = pipe(prompt=prompt)
@@ -125,7 +125,7 @@ class ModelManager:
125
 
126
  @spaces.GPU(duration=200)
127
  def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
128
- # if self.NSFW_filter(" ".join([textbox_source, textbox_target, textbox_instruct])) == 'safe':
129
  pipe = self.load_model_pipe(model_name)
130
  result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
131
  # else:
@@ -193,7 +193,7 @@ class ModelManager:
193
 
194
  @spaces.GPU(duration=150)
195
  def generate_video_vg(self, prompt, model_name):
196
- # if self.NSFW_filter(prompt) == 'safe':
197
  pipe = self.load_model_pipe(model_name)
198
  result = pipe(prompt=prompt)
199
  # else:
@@ -201,7 +201,7 @@ class ModelManager:
201
  return result
202
 
203
  def generate_video_vg_api(self, prompt, model_name):
204
- # if self.NSFW_filter(prompt) == 'safe':
205
  pipe = self.load_model_pipe(model_name)
206
  result = pipe(prompt=prompt)
207
  # else:
 
30
  return pipe
31
 
32
  def load_guard(self):
33
+ model_id = "meta-llama/Llama-Guard-3-8B"
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
  dtype = torch.bfloat16
36
  self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_GUARD'])
 
48
 
49
  @spaces.GPU(duration=120)
50
  def generate_image_ig(self, prompt, model_name):
51
+ if 'unsafe' not in self.NSFW_filter(prompt):
52
  print('The prompt is safe')
53
  pipe = self.load_model_pipe(model_name)
54
  result = pipe(prompt=prompt)
 
57
  return result
58
 
59
  def generate_image_ig_api(self, prompt, model_name):
60
+ if 'unsafe' not in self.NSFW_filter(prompt):
61
  print('The prompt is safe')
62
  pipe = self.load_model_pipe(model_name)
63
  result = pipe(prompt=prompt)
 
125
 
126
  @spaces.GPU(duration=200)
127
  def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
128
+ # if 'unsafe' not in self.NSFW_filter(" ".join([textbox_source, textbox_target, textbox_instruct])):
129
  pipe = self.load_model_pipe(model_name)
130
  result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
131
  # else:
 
193
 
194
  @spaces.GPU(duration=150)
195
  def generate_video_vg(self, prompt, model_name):
196
+ # if 'unsafe' not in self.NSFW_filter(prompt):
197
  pipe = self.load_model_pipe(model_name)
198
  result = pipe(prompt=prompt)
199
  # else:
 
201
  return result
202
 
203
  def generate_video_vg_api(self, prompt, model_name):
204
+ # if 'unsafe' not in self.NSFW_filter(prompt):
205
  pipe = self.load_model_pipe(model_name)
206
  result = pipe(prompt=prompt)
207
  # else: