Spaces:
Paused
Paused
ssalb
commited on
Commit
·
d2e0b39
1
Parent(s):
b990ec8
Update space with latest code and dependencies on Thu Jan 2 21:22:45 UTC 2025
Browse files- README.md +2 -2
- app.py +10 -3
- requirements.txt +5 -3
- story_beam_search/beam_search.py +4 -2
- story_beam_search/scoring.py +3 -1
- story_beam_search/stories_generator.py +10 -9
README.md
CHANGED
@@ -9,8 +9,8 @@ sdk_version: 5.9.1
|
|
9 |
app_file: app.py
|
10 |
pinned: false
|
11 |
preload_from_hub:
|
12 |
-
|
13 |
-
-
|
14 |
- facebook/bart-large-mnli
|
15 |
license: mit
|
16 |
---
|
|
|
9 |
app_file: app.py
|
10 |
pinned: false
|
11 |
preload_from_hub:
|
12 |
+
- openai-community/gpt2
|
13 |
+
- answerdotai/ModernBERT-base
|
14 |
- facebook/bart-large-mnli
|
15 |
license: mit
|
16 |
---
|
app.py
CHANGED
@@ -16,6 +16,7 @@ genre_choices = [
|
|
16 |
"horror",
|
17 |
]
|
18 |
|
|
|
19 |
class InputModel(BaseModel):
|
20 |
prompt: str
|
21 |
genre: str
|
@@ -28,7 +29,7 @@ def create_story_generation_interface() -> gr.Interface:
|
|
28 |
# Initialize the story generation system
|
29 |
system = StoryGenerationSystem()
|
30 |
system.initialize()
|
31 |
-
|
32 |
def generate_stories(
|
33 |
prompt: str, genre: str, num_stories: int, temperature: float, max_length: int
|
34 |
) -> Tuple[str, List[str]]:
|
@@ -39,7 +40,11 @@ def create_story_generation_interface() -> gr.Interface:
|
|
39 |
|
40 |
# Validate inputs.Gradio seems to validate chioces but not the range of the values
|
41 |
input_values = InputModel(
|
42 |
-
prompt=prompt,
|
|
|
|
|
|
|
|
|
43 |
)
|
44 |
|
45 |
# Update beam search config with user parameters
|
@@ -48,7 +53,9 @@ def create_story_generation_interface() -> gr.Interface:
|
|
48 |
|
49 |
# Generate and evaluate stories
|
50 |
ranked_stories = system.generate_and_evaluate(
|
51 |
-
input_values.prompt,
|
|
|
|
|
52 |
)
|
53 |
|
54 |
# Format detailed scores
|
|
|
16 |
"horror",
|
17 |
]
|
18 |
|
19 |
+
|
20 |
class InputModel(BaseModel):
|
21 |
prompt: str
|
22 |
genre: str
|
|
|
29 |
# Initialize the story generation system
|
30 |
system = StoryGenerationSystem()
|
31 |
system.initialize()
|
32 |
+
|
33 |
def generate_stories(
|
34 |
prompt: str, genre: str, num_stories: int, temperature: float, max_length: int
|
35 |
) -> Tuple[str, List[str]]:
|
|
|
40 |
|
41 |
# Validate inputs.Gradio seems to validate chioces but not the range of the values
|
42 |
input_values = InputModel(
|
43 |
+
prompt=prompt,
|
44 |
+
genre=genre,
|
45 |
+
num_stories=num_stories,
|
46 |
+
temperature=temperature,
|
47 |
+
max_length=max_length,
|
48 |
)
|
49 |
|
50 |
# Update beam search config with user parameters
|
|
|
53 |
|
54 |
# Generate and evaluate stories
|
55 |
ranked_stories = system.generate_and_evaluate(
|
56 |
+
input_values.prompt,
|
57 |
+
input_values.genre,
|
58 |
+
num_stories=input_values.num_stories,
|
59 |
)
|
60 |
|
61 |
# Format detailed scores
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
aiofiles==23.2.1 ; python_full_version == "3.10.13"
|
2 |
annotated-types==0.7.0 ; python_full_version == "3.10.13"
|
3 |
anyio==4.7.0 ; python_full_version == "3.10.13"
|
@@ -30,6 +31,7 @@ packaging==24.2 ; python_full_version == "3.10.13"
|
|
30 |
pandas==2.2.3 ; python_full_version == "3.10.13"
|
31 |
pillow==11.1.0 ; python_full_version == "3.10.13"
|
32 |
protobuf==5.29.2 ; python_full_version == "3.10.13"
|
|
|
33 |
pydantic-core==2.27.2 ; python_full_version == "3.10.13"
|
34 |
pydantic==2.10.4 ; python_full_version == "3.10.13"
|
35 |
pydub==0.25.1 ; python_full_version == "3.10.13"
|
@@ -41,9 +43,9 @@ pyyaml==6.0.2 ; python_full_version == "3.10.13"
|
|
41 |
regex==2024.11.6 ; python_full_version == "3.10.13"
|
42 |
requests==2.32.3 ; python_full_version == "3.10.13"
|
43 |
rich==13.9.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
44 |
-
ruff==0.8.
|
45 |
safehttpx==0.1.6 ; python_full_version == "3.10.13"
|
46 |
-
safetensors==0.
|
47 |
scikit-learn==1.6.0 ; python_full_version == "3.10.13"
|
48 |
scipy==1.14.1 ; python_full_version == "3.10.13"
|
49 |
semantic-version==2.10.0 ; python_full_version == "3.10.13"
|
@@ -57,7 +59,7 @@ tokenizers==0.21.0 ; python_full_version == "3.10.13"
|
|
57 |
tomlkit==0.13.2 ; python_full_version == "3.10.13"
|
58 |
torch==2.4.0 ; python_full_version == "3.10.13"
|
59 |
tqdm==4.67.1 ; python_full_version == "3.10.13"
|
60 |
-
transformers
|
61 |
triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
62 |
typer==0.15.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
63 |
typing-extensions==4.12.2 ; python_full_version == "3.10.13"
|
|
|
1 |
+
accelerate==1.2.1 ; python_full_version == "3.10.13"
|
2 |
aiofiles==23.2.1 ; python_full_version == "3.10.13"
|
3 |
annotated-types==0.7.0 ; python_full_version == "3.10.13"
|
4 |
anyio==4.7.0 ; python_full_version == "3.10.13"
|
|
|
31 |
pandas==2.2.3 ; python_full_version == "3.10.13"
|
32 |
pillow==11.1.0 ; python_full_version == "3.10.13"
|
33 |
protobuf==5.29.2 ; python_full_version == "3.10.13"
|
34 |
+
psutil==6.1.1 ; python_full_version == "3.10.13"
|
35 |
pydantic-core==2.27.2 ; python_full_version == "3.10.13"
|
36 |
pydantic==2.10.4 ; python_full_version == "3.10.13"
|
37 |
pydub==0.25.1 ; python_full_version == "3.10.13"
|
|
|
43 |
regex==2024.11.6 ; python_full_version == "3.10.13"
|
44 |
requests==2.32.3 ; python_full_version == "3.10.13"
|
45 |
rich==13.9.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
46 |
+
ruff==0.8.5 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
47 |
safehttpx==0.1.6 ; python_full_version == "3.10.13"
|
48 |
+
safetensors==0.5.0 ; python_full_version == "3.10.13"
|
49 |
scikit-learn==1.6.0 ; python_full_version == "3.10.13"
|
50 |
scipy==1.14.1 ; python_full_version == "3.10.13"
|
51 |
semantic-version==2.10.0 ; python_full_version == "3.10.13"
|
|
|
59 |
tomlkit==0.13.2 ; python_full_version == "3.10.13"
|
60 |
torch==2.4.0 ; python_full_version == "3.10.13"
|
61 |
tqdm==4.67.1 ; python_full_version == "3.10.13"
|
62 |
+
transformers @ git+https://github.com/huggingface/transformers.git@42865860ec6dc135972d9555753cb7ee17f51fb4 ; python_full_version == "3.10.13"
|
63 |
triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
64 |
typer==0.15.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
65 |
typing-extensions==4.12.2 ; python_full_version == "3.10.13"
|
story_beam_search/beam_search.py
CHANGED
@@ -51,7 +51,7 @@ class BeamSearchGenerator:
|
|
51 |
[story[instructions_len:] for story in stories]
|
52 |
)
|
53 |
|
54 |
-
stories = [story for story, _ in ranked_stories[:self.config.num_beams]]
|
55 |
|
56 |
if stories:
|
57 |
for _ in range(self.config.num_iterations):
|
@@ -64,7 +64,9 @@ class BeamSearchGenerator:
|
|
64 |
ranked_stories = evaluator.evaluate_multiple(
|
65 |
[story[instructions_len:] for story in all_stories]
|
66 |
)
|
67 |
-
stories = [
|
|
|
|
|
68 |
|
69 |
return stories
|
70 |
|
|
|
51 |
[story[instructions_len:] for story in stories]
|
52 |
)
|
53 |
|
54 |
+
stories = [story for story, _ in ranked_stories[: self.config.num_beams]]
|
55 |
|
56 |
if stories:
|
57 |
for _ in range(self.config.num_iterations):
|
|
|
64 |
ranked_stories = evaluator.evaluate_multiple(
|
65 |
[story[instructions_len:] for story in all_stories]
|
66 |
)
|
67 |
+
stories = [
|
68 |
+
story for story, _ in ranked_stories[: self.config.num_beams]
|
69 |
+
]
|
70 |
|
71 |
return stories
|
72 |
|
story_beam_search/scoring.py
CHANGED
@@ -46,7 +46,9 @@ class CoherenceScorer(StoryScorer):
|
|
46 |
for sentence in sentences:
|
47 |
inputs = self.tokenizer(sentence, return_tensors="pt").to(self.device)
|
48 |
with torch.no_grad():
|
49 |
-
|
|
|
|
|
50 |
embeddings.append(emb.cpu().numpy())
|
51 |
|
52 |
# Calculate cosine similarity between adjacent embeddings
|
|
|
46 |
for sentence in sentences:
|
47 |
inputs = self.tokenizer(sentence, return_tensors="pt").to(self.device)
|
48 |
with torch.no_grad():
|
49 |
+
outputs = self.model(**inputs)
|
50 |
+
last_hidden_state = outputs.hidden_states[-1]
|
51 |
+
emb = last_hidden_state[:, 0, :]
|
52 |
embeddings.append(emb.cpu().numpy())
|
53 |
|
54 |
# Calculate cosine similarity between adjacent embeddings
|
story_beam_search/stories_generator.py
CHANGED
@@ -16,8 +16,8 @@ auth_token = os.getenv("HF_TOKEN", None)
|
|
16 |
|
17 |
@dataclass
|
18 |
class ModelConfig:
|
19 |
-
text_model_name: str = "
|
20 |
-
bert_name: str = "
|
21 |
zero_shot_name: str = "facebook/bart-large-mnli"
|
22 |
device: str = (
|
23 |
"mps"
|
@@ -54,14 +54,14 @@ class ModelLoader:
|
|
54 |
self.config.text_model_name, token=auth_token
|
55 |
)
|
56 |
text_model = AutoModelForCausalLM.from_pretrained(
|
57 |
-
self.config.text_model_name
|
58 |
).to(self.device)
|
59 |
text_model.eval()
|
60 |
|
61 |
# Load BERT model for coherence and fluency scoring
|
62 |
print(f"Loading BERT model ({self.config.bert_name})...")
|
63 |
bert_tokenizer = AutoTokenizer.from_pretrained(self.config.bert_name)
|
64 |
-
bert_model = AutoModelForMaskedLM.from_pretrained(self.config.bert_name).to(
|
65 |
self.device
|
66 |
)
|
67 |
bert_model.eval()
|
@@ -155,22 +155,23 @@ class StoryGenerationSystem:
|
|
155 |
"""Generate stories and evaluate them."""
|
156 |
if not self.models:
|
157 |
raise RuntimeError("System not initialized. Call initialize() first.")
|
158 |
-
|
159 |
# Low effort attempt to detect prompt injections using the zero-shot classifier
|
160 |
-
prompt_segments = re.split(r
|
161 |
prompt_segments = list(set(prompt_segments))
|
162 |
|
163 |
storyness_score = self.storyness.score(prompt)
|
164 |
for segment in prompt_segments:
|
165 |
-
if segment.strip():
|
166 |
injection_score = self.injection_guard.score(segment)
|
167 |
if storyness_score < 0.2 or injection_score > 0.2:
|
168 |
print("Potential prompt injection detected.")
|
169 |
print(f"storyness_score: {storyness_score}")
|
170 |
print(f"injection_score: {injection_score}")
|
171 |
print("Prompt:", segment)
|
172 |
-
raise ValueError(
|
173 |
-
|
|
|
174 |
|
175 |
# Create evaluator with specified genre
|
176 |
evaluator = self.create_evaluator(genre)
|
|
|
16 |
|
17 |
@dataclass
|
18 |
class ModelConfig:
|
19 |
+
text_model_name: str = "openai-community/gpt2"
|
20 |
+
bert_name: str = "answerdotai/ModernBERT-base"
|
21 |
zero_shot_name: str = "facebook/bart-large-mnli"
|
22 |
device: str = (
|
23 |
"mps"
|
|
|
54 |
self.config.text_model_name, token=auth_token
|
55 |
)
|
56 |
text_model = AutoModelForCausalLM.from_pretrained(
|
57 |
+
self.config.text_model_name, device_map="auto", torch_dtype=torch.float16
|
58 |
).to(self.device)
|
59 |
text_model.eval()
|
60 |
|
61 |
# Load BERT model for coherence and fluency scoring
|
62 |
print(f"Loading BERT model ({self.config.bert_name})...")
|
63 |
bert_tokenizer = AutoTokenizer.from_pretrained(self.config.bert_name)
|
64 |
+
bert_model = AutoModelForMaskedLM.from_pretrained(self.config.bert_name, output_hidden_states=True).to(
|
65 |
self.device
|
66 |
)
|
67 |
bert_model.eval()
|
|
|
155 |
"""Generate stories and evaluate them."""
|
156 |
if not self.models:
|
157 |
raise RuntimeError("System not initialized. Call initialize() first.")
|
158 |
+
|
159 |
# Low effort attempt to detect prompt injections using the zero-shot classifier
|
160 |
+
prompt_segments = re.split(r"[^a-zA-Z0-9 ]+", prompt)
|
161 |
prompt_segments = list(set(prompt_segments))
|
162 |
|
163 |
storyness_score = self.storyness.score(prompt)
|
164 |
for segment in prompt_segments:
|
165 |
+
if segment.strip():
|
166 |
injection_score = self.injection_guard.score(segment)
|
167 |
if storyness_score < 0.2 or injection_score > 0.2:
|
168 |
print("Potential prompt injection detected.")
|
169 |
print(f"storyness_score: {storyness_score}")
|
170 |
print(f"injection_score: {injection_score}")
|
171 |
print("Prompt:", segment)
|
172 |
+
raise ValueError(
|
173 |
+
"Prompt does not seem like a story. Please try again."
|
174 |
+
)
|
175 |
|
176 |
# Create evaluator with specified genre
|
177 |
evaluator = self.create_evaluator(genre)
|