ssalb committed on
Commit
d2e0b39
·
1 Parent(s): b990ec8

Update space with latest code and dependencies on Thu Jan 2 21:22:45 UTC 2025

Browse files
README.md CHANGED
@@ -9,8 +9,8 @@ sdk_version: 5.9.1
9
  app_file: app.py
10
  pinned: false
11
  preload_from_hub:
12
- # - meta-llama/Llama-3.2-1B-Instruct # Do not preload llama, as the token is not available at build time
13
- - google-bert/bert-base-uncased
14
  - facebook/bart-large-mnli
15
  license: mit
16
  ---
 
9
  app_file: app.py
10
  pinned: false
11
  preload_from_hub:
12
+ - openai-community/gpt2
13
+ - answerdotai/ModernBERT-base
14
  - facebook/bart-large-mnli
15
  license: mit
16
  ---
app.py CHANGED
@@ -16,6 +16,7 @@ genre_choices = [
16
  "horror",
17
  ]
18
 
 
19
  class InputModel(BaseModel):
20
  prompt: str
21
  genre: str
@@ -28,7 +29,7 @@ def create_story_generation_interface() -> gr.Interface:
28
  # Initialize the story generation system
29
  system = StoryGenerationSystem()
30
  system.initialize()
31
-
32
  def generate_stories(
33
  prompt: str, genre: str, num_stories: int, temperature: float, max_length: int
34
  ) -> Tuple[str, List[str]]:
@@ -39,7 +40,11 @@ def create_story_generation_interface() -> gr.Interface:
39
 
40
  # Validate inputs. Gradio seems to validate choices but not the range of the values
41
  input_values = InputModel(
42
- prompt=prompt, genre=genre, num_stories=num_stories, temperature=temperature, max_length=max_length
 
 
 
 
43
  )
44
 
45
  # Update beam search config with user parameters
@@ -48,7 +53,9 @@ def create_story_generation_interface() -> gr.Interface:
48
 
49
  # Generate and evaluate stories
50
  ranked_stories = system.generate_and_evaluate(
51
- input_values.prompt, input_values.genre, num_stories=input_values.num_stories
 
 
52
  )
53
 
54
  # Format detailed scores
 
16
  "horror",
17
  ]
18
 
19
+
20
  class InputModel(BaseModel):
21
  prompt: str
22
  genre: str
 
29
  # Initialize the story generation system
30
  system = StoryGenerationSystem()
31
  system.initialize()
32
+
33
  def generate_stories(
34
  prompt: str, genre: str, num_stories: int, temperature: float, max_length: int
35
  ) -> Tuple[str, List[str]]:
 
40
 
41
  # Validate inputs. Gradio seems to validate choices but not the range of the values
42
  input_values = InputModel(
43
+ prompt=prompt,
44
+ genre=genre,
45
+ num_stories=num_stories,
46
+ temperature=temperature,
47
+ max_length=max_length,
48
  )
49
 
50
  # Update beam search config with user parameters
 
53
 
54
  # Generate and evaluate stories
55
  ranked_stories = system.generate_and_evaluate(
56
+ input_values.prompt,
57
+ input_values.genre,
58
+ num_stories=input_values.num_stories,
59
  )
60
 
61
  # Format detailed scores
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  aiofiles==23.2.1 ; python_full_version == "3.10.13"
2
  annotated-types==0.7.0 ; python_full_version == "3.10.13"
3
  anyio==4.7.0 ; python_full_version == "3.10.13"
@@ -30,6 +31,7 @@ packaging==24.2 ; python_full_version == "3.10.13"
30
  pandas==2.2.3 ; python_full_version == "3.10.13"
31
  pillow==11.1.0 ; python_full_version == "3.10.13"
32
  protobuf==5.29.2 ; python_full_version == "3.10.13"
 
33
  pydantic-core==2.27.2 ; python_full_version == "3.10.13"
34
  pydantic==2.10.4 ; python_full_version == "3.10.13"
35
  pydub==0.25.1 ; python_full_version == "3.10.13"
@@ -41,9 +43,9 @@ pyyaml==6.0.2 ; python_full_version == "3.10.13"
41
  regex==2024.11.6 ; python_full_version == "3.10.13"
42
  requests==2.32.3 ; python_full_version == "3.10.13"
43
  rich==13.9.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
44
- ruff==0.8.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
45
  safehttpx==0.1.6 ; python_full_version == "3.10.13"
46
- safetensors==0.4.5 ; python_full_version == "3.10.13"
47
  scikit-learn==1.6.0 ; python_full_version == "3.10.13"
48
  scipy==1.14.1 ; python_full_version == "3.10.13"
49
  semantic-version==2.10.0 ; python_full_version == "3.10.13"
@@ -57,7 +59,7 @@ tokenizers==0.21.0 ; python_full_version == "3.10.13"
57
  tomlkit==0.13.2 ; python_full_version == "3.10.13"
58
  torch==2.4.0 ; python_full_version == "3.10.13"
59
  tqdm==4.67.1 ; python_full_version == "3.10.13"
60
- transformers==4.47.1 ; python_full_version == "3.10.13"
61
  triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
62
  typer==0.15.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
63
  typing-extensions==4.12.2 ; python_full_version == "3.10.13"
 
1
+ accelerate==1.2.1 ; python_full_version == "3.10.13"
2
  aiofiles==23.2.1 ; python_full_version == "3.10.13"
3
  annotated-types==0.7.0 ; python_full_version == "3.10.13"
4
  anyio==4.7.0 ; python_full_version == "3.10.13"
 
31
  pandas==2.2.3 ; python_full_version == "3.10.13"
32
  pillow==11.1.0 ; python_full_version == "3.10.13"
33
  protobuf==5.29.2 ; python_full_version == "3.10.13"
34
+ psutil==6.1.1 ; python_full_version == "3.10.13"
35
  pydantic-core==2.27.2 ; python_full_version == "3.10.13"
36
  pydantic==2.10.4 ; python_full_version == "3.10.13"
37
  pydub==0.25.1 ; python_full_version == "3.10.13"
 
43
  regex==2024.11.6 ; python_full_version == "3.10.13"
44
  requests==2.32.3 ; python_full_version == "3.10.13"
45
  rich==13.9.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
46
+ ruff==0.8.5 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
47
  safehttpx==0.1.6 ; python_full_version == "3.10.13"
48
+ safetensors==0.5.0 ; python_full_version == "3.10.13"
49
  scikit-learn==1.6.0 ; python_full_version == "3.10.13"
50
  scipy==1.14.1 ; python_full_version == "3.10.13"
51
  semantic-version==2.10.0 ; python_full_version == "3.10.13"
 
59
  tomlkit==0.13.2 ; python_full_version == "3.10.13"
60
  torch==2.4.0 ; python_full_version == "3.10.13"
61
  tqdm==4.67.1 ; python_full_version == "3.10.13"
62
+ transformers @ git+https://github.com/huggingface/transformers.git@42865860ec6dc135972d9555753cb7ee17f51fb4 ; python_full_version == "3.10.13"
63
  triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
64
  typer==0.15.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
65
  typing-extensions==4.12.2 ; python_full_version == "3.10.13"
story_beam_search/beam_search.py CHANGED
@@ -51,7 +51,7 @@ class BeamSearchGenerator:
51
  [story[instructions_len:] for story in stories]
52
  )
53
 
54
- stories = [story for story, _ in ranked_stories[:self.config.num_beams]]
55
 
56
  if stories:
57
  for _ in range(self.config.num_iterations):
@@ -64,7 +64,9 @@ class BeamSearchGenerator:
64
  ranked_stories = evaluator.evaluate_multiple(
65
  [story[instructions_len:] for story in all_stories]
66
  )
67
- stories = [story for story, _ in ranked_stories[:self.config.num_beams]]
 
 
68
 
69
  return stories
70
 
 
51
  [story[instructions_len:] for story in stories]
52
  )
53
 
54
+ stories = [story for story, _ in ranked_stories[: self.config.num_beams]]
55
 
56
  if stories:
57
  for _ in range(self.config.num_iterations):
 
64
  ranked_stories = evaluator.evaluate_multiple(
65
  [story[instructions_len:] for story in all_stories]
66
  )
67
+ stories = [
68
+ story for story, _ in ranked_stories[: self.config.num_beams]
69
+ ]
70
 
71
  return stories
72
 
story_beam_search/scoring.py CHANGED
@@ -46,7 +46,9 @@ class CoherenceScorer(StoryScorer):
46
  for sentence in sentences:
47
  inputs = self.tokenizer(sentence, return_tensors="pt").to(self.device)
48
  with torch.no_grad():
49
- emb = self.model.bert(**inputs).last_hidden_state[:, 0, :]
 
 
50
  embeddings.append(emb.cpu().numpy())
51
 
52
  # Calculate cosine similarity between adjacent embeddings
 
46
  for sentence in sentences:
47
  inputs = self.tokenizer(sentence, return_tensors="pt").to(self.device)
48
  with torch.no_grad():
49
+ outputs = self.model(**inputs)
50
+ last_hidden_state = outputs.hidden_states[-1]
51
+ emb = last_hidden_state[:, 0, :]
52
  embeddings.append(emb.cpu().numpy())
53
 
54
  # Calculate cosine similarity between adjacent embeddings
story_beam_search/stories_generator.py CHANGED
@@ -16,8 +16,8 @@ auth_token = os.getenv("HF_TOKEN", None)
16
 
17
  @dataclass
18
  class ModelConfig:
19
- text_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
20
- bert_name: str = "bert-base-uncased" # "answerdotai/ModernBERT-base"
21
  zero_shot_name: str = "facebook/bart-large-mnli"
22
  device: str = (
23
  "mps"
@@ -54,14 +54,14 @@ class ModelLoader:
54
  self.config.text_model_name, token=auth_token
55
  )
56
  text_model = AutoModelForCausalLM.from_pretrained(
57
- self.config.text_model_name
58
  ).to(self.device)
59
  text_model.eval()
60
 
61
  # Load BERT model for coherence and fluency scoring
62
  print(f"Loading BERT model ({self.config.bert_name})...")
63
  bert_tokenizer = AutoTokenizer.from_pretrained(self.config.bert_name)
64
- bert_model = AutoModelForMaskedLM.from_pretrained(self.config.bert_name).to(
65
  self.device
66
  )
67
  bert_model.eval()
@@ -155,22 +155,23 @@ class StoryGenerationSystem:
155
  """Generate stories and evaluate them."""
156
  if not self.models:
157
  raise RuntimeError("System not initialized. Call initialize() first.")
158
-
159
  # Low effort attempt to detect prompt injections using the zero-shot classifier
160
- prompt_segments = re.split(r'[^a-zA-Z0-9 ]+', prompt)
161
  prompt_segments = list(set(prompt_segments))
162
 
163
  storyness_score = self.storyness.score(prompt)
164
  for segment in prompt_segments:
165
- if segment.strip():
166
  injection_score = self.injection_guard.score(segment)
167
  if storyness_score < 0.2 or injection_score > 0.2:
168
  print("Potential prompt injection detected.")
169
  print(f"storyness_score: {storyness_score}")
170
  print(f"injection_score: {injection_score}")
171
  print("Prompt:", segment)
172
- raise ValueError("Prompt does not seem like a story. Please try again.")
173
-
 
174
 
175
  # Create evaluator with specified genre
176
  evaluator = self.create_evaluator(genre)
 
16
 
17
  @dataclass
18
  class ModelConfig:
19
+ text_model_name: str = "openai-community/gpt2"
20
+ bert_name: str = "answerdotai/ModernBERT-base"
21
  zero_shot_name: str = "facebook/bart-large-mnli"
22
  device: str = (
23
  "mps"
 
54
  self.config.text_model_name, token=auth_token
55
  )
56
  text_model = AutoModelForCausalLM.from_pretrained(
57
+ self.config.text_model_name, device_map="auto", torch_dtype=torch.float16
58
  ).to(self.device)
59
  text_model.eval()
60
 
61
  # Load BERT model for coherence and fluency scoring
62
  print(f"Loading BERT model ({self.config.bert_name})...")
63
  bert_tokenizer = AutoTokenizer.from_pretrained(self.config.bert_name)
64
+ bert_model = AutoModelForMaskedLM.from_pretrained(self.config.bert_name, output_hidden_states=True).to(
65
  self.device
66
  )
67
  bert_model.eval()
 
155
  """Generate stories and evaluate them."""
156
  if not self.models:
157
  raise RuntimeError("System not initialized. Call initialize() first.")
158
+
159
  # Low effort attempt to detect prompt injections using the zero-shot classifier
160
+ prompt_segments = re.split(r"[^a-zA-Z0-9 ]+", prompt)
161
  prompt_segments = list(set(prompt_segments))
162
 
163
  storyness_score = self.storyness.score(prompt)
164
  for segment in prompt_segments:
165
+ if segment.strip():
166
  injection_score = self.injection_guard.score(segment)
167
  if storyness_score < 0.2 or injection_score > 0.2:
168
  print("Potential prompt injection detected.")
169
  print(f"storyness_score: {storyness_score}")
170
  print(f"injection_score: {injection_score}")
171
  print("Prompt:", segment)
172
+ raise ValueError(
173
+ "Prompt does not seem like a story. Please try again."
174
+ )
175
 
176
  # Create evaluator with specified genre
177
  evaluator = self.create_evaluator(genre)