Update app.py
app.py CHANGED
@@ -27,7 +27,7 @@ model_name = "ibm-granite/granite-3.1-2b-instruct"
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
     device_map="balanced",  # Using balanced CPU mapping.
-    torch_dtype=torch.float16  # Use float16 if supported
+    torch_dtype=torch.float16  # Use float16 if supported.
 )
 tokenizer = AutoTokenizer.from_pretrained(model_name)

@@ -64,7 +64,6 @@ def read_file(file_obj):
     """
     Reads content from a file. Supports both file paths (str) and Streamlit uploaded files.
     """
-    # If file_obj is a string path:
     if isinstance(file_obj, str):
         if file_obj in FILE_CACHE:
             return FILE_CACHE[file_obj]
@@ -139,11 +138,17 @@ def read_files(file_objs, max_length=3000):
     SUMMARY_CACHE[cache_key] = summarized
     return summarized

-def
-
-
-
-
+def build_prompt(system_msg, document_content, user_prompt):
+    """
+    Build a unified prompt that explicitly delineates the system instructions,
+    document content, and user prompt.
+    """
+    prompt_parts = []
+    prompt_parts.append("SYSTEM PROMPT:\n" + system_msg.strip())
+    if document_content:
+        prompt_parts.append("\nDOCUMENT CONTENT:\n" + document_content.strip())
+    prompt_parts.append("\nUSER PROMPT:\n" + user_prompt.strip())
+    return "\n\n".join(prompt_parts)

 def speculative_decode(input_text, max_tokens=DEFAULT_MAX_TOKENS, top_p=0.9, temperature=0.7):
     model_inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
@@ -168,18 +173,23 @@ def post_process(text):
         unique_lines.append(clean_line)
     return "\n".join(unique_lines)

-def granite_analysis(
-
-
-
-
-
-        "
-        "
+def granite_analysis(user_prompt, file_objs=None, max_tokens=DEFAULT_MAX_TOKENS, top_p=0.9, temperature=0.7):
+    # Read and summarize document content.
+    document_content = read_files(file_objs) if file_objs else ""
+
+    # Define a clear system prompt.
+    system_prompt = (
+        "You are IBM Granite, an enterprise legal and technical analysis assistant. "
+        "Your task is to critically analyze the contract document provided below. "
+        "Pay special attention to identifying dangerous provisions, legal pitfalls, and potential liabilities. "
+        "Make sure to address both the overall contract structure and specific clauses where applicable."
     )
-
-
-
+
+    # Build a unified prompt with explicit sections.
+    unified_prompt = build_prompt(system_prompt, document_content, user_prompt)
+
+    # Generate the analysis.
+    response = speculative_decode(unified_prompt, max_tokens=max_tokens, top_p=top_p, temperature=temperature)
     final_response = post_process(response)
     return final_response

@@ -207,10 +217,4 @@ if st.button("Analyze Contract"):
     result = granite_analysis(user_prompt, uploaded_files, max_tokens=max_tokens_slider, top_p=top_p_slider, temperature=temperature_slider)
     st.success("Analysis complete!")
     st.markdown("### Analysis Output")
-
-    keyword = "assistant"
-    text_after_keyword = result.rsplit(keyword, 1)[-1].strip()
-
-    st.text_area("Output", text_after_keyword, height=400)
-
-
+    st.text_area("Output", result, height=400)
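For reference, here is a minimal usage sketch of the new build_prompt helper added in this commit; the argument strings and the printed layout below are illustrative assumptions, not output captured from the Space:

# Hypothetical standalone demo of build_prompt (defined in the diff above).
# The sample strings are invented for illustration only.
system_msg = "You are IBM Granite, an enterprise legal and technical analysis assistant."
document_content = "Section 12.3: Licensee shall indemnify Licensor against all claims."
user_prompt = "Flag any one-sided indemnification clauses."

prompt = build_prompt(system_msg, document_content, user_prompt)
print(prompt)
# SYSTEM PROMPT:
# You are IBM Granite, an enterprise legal and technical analysis assistant.
#
#
# DOCUMENT CONTENT:
# Section 12.3: Licensee shall indemnify Licensor against all claims.
#
#
# USER PROMPT:
# Flag any one-sided indemnification clauses.

This explicit sectioning is presumably what lets the commit drop the old rsplit("assistant", 1) trimming step and pass the post-processed response straight to st.text_area.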