Arkay92 committed
Commit 094178c · verified · 1 Parent(s): 6883dbd

Update app.py

Files changed (1)
  1. app.py +29 -25
app.py CHANGED
@@ -27,7 +27,7 @@ model_name = "ibm-granite/granite-3.1-2b-instruct"
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
     device_map="balanced",  # Using balanced CPU mapping.
-    torch_dtype=torch.float16  # Use float16 if supported, otherwise float32.
+    torch_dtype=torch.float16  # Use float16 if supported.
 )
 tokenizer = AutoTokenizer.from_pretrained(model_name)

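The kept comment still reads "if supported", but the call pins float16 unconditionally, which can be slow or unsupported on a CPU-only host. A minimal sketch of an explicit fallback, not part of this commit, assuming the same imports and model_name:

import torch
from transformers import AutoModelForCausalLM

model_name = "ibm-granite/granite-3.1-2b-instruct"

# Hypothetical fallback: half precision only when a CUDA device is present,
# otherwise stay in float32 on CPU.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="balanced",
    torch_dtype=dtype,
)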
@@ -64,7 +64,6 @@ def read_file(file_obj):
     """
     Reads content from a file. Supports both file paths (str) and Streamlit uploaded files.
     """
-    # If file_obj is a string path:
     if isinstance(file_obj, str):
         if file_obj in FILE_CACHE:
             return FILE_CACHE[file_obj]
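The hunk shows only the cache lookup; a sketch of the full read path it implies, assuming FILE_CACHE is a module-level dict and that Streamlit uploads expose read() returning bytes (neither appears in this hunk):

FILE_CACHE = {}

def read_file(file_obj):
    """Reads content from a file path (str) or a Streamlit uploaded file."""
    if isinstance(file_obj, str):
        if file_obj in FILE_CACHE:
            return FILE_CACHE[file_obj]
        # Assumed: read the file from disk and memoize by path.
        with open(file_obj, "r", encoding="utf-8", errors="ignore") as f:
            content = f.read()
        FILE_CACHE[file_obj] = content
        return content
    # Assumed: Streamlit's UploadedFile.read() yields bytes.
    return file_obj.read().decode("utf-8", errors="ignore")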
@@ -139,11 +138,17 @@ def read_files(file_objs, max_length=3000):
     SUMMARY_CACHE[cache_key] = summarized
     return summarized

-def format_prompt(system_msg, user_msg):
-    return [
-        {"role": "system", "content": system_msg},
-        {"role": "user", "content": user_msg}
-    ]
+def build_prompt(system_msg, document_content, user_prompt):
+    """
+    Build a unified prompt that explicitly delineates the system instructions,
+    document content, and user prompt.
+    """
+    prompt_parts = []
+    prompt_parts.append("SYSTEM PROMPT:\n" + system_msg.strip())
+    if document_content:
+        prompt_parts.append("\nDOCUMENT CONTENT:\n" + document_content.strip())
+    prompt_parts.append("\nUSER PROMPT:\n" + user_prompt.strip())
+    return "\n\n".join(prompt_parts)

 def speculative_decode(input_text, max_tokens=DEFAULT_MAX_TOKENS, top_p=0.9, temperature=0.7):
     model_inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
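build_prompt is a pure string function, so the new prompt layout can be checked in isolation (sample inputs, not from the app):

print(build_prompt("Be terse.", "Clause 9: auto-renewal ...", "Flag risky clauses."))
# SYSTEM PROMPT:
# Be terse.
#
#
# DOCUMENT CONTENT:
# Clause 9: auto-renewal ...
#
#
# USER PROMPT:
# Flag risky clauses.

Note the double blank line between sections: the leading "\n" on the later parts stacks with the "\n\n" join separator.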
@@ -168,18 +173,23 @@ def post_process(text):
         unique_lines.append(clean_line)
     return "\n".join(unique_lines)

-def granite_analysis(prompt, file_objs=None, max_tokens=DEFAULT_MAX_TOKENS, top_p=0.9, temperature=0.7):
-    file_context = read_files(file_objs) if file_objs else ""
-    internal_context = f"\n[Internal Context]: {file_context.strip()}" if file_context else ""
-    refined_prompt = prompt + internal_context
-    system_message = (
-        "You are IBM Granite, an enterprise legal and technical analysis assistant. Your task is to critically analyze "
-        "contract documents with a special focus on identifying dangerous provisions, significant legal pitfalls, "
-        "and areas that could expose a party to high risks or liabilities."
+def granite_analysis(user_prompt, file_objs=None, max_tokens=DEFAULT_MAX_TOKENS, top_p=0.9, temperature=0.7):
+    # Read and summarize document content.
+    document_content = read_files(file_objs) if file_objs else ""
+
+    # Define a clear system prompt.
+    system_prompt = (
+        "You are IBM Granite, an enterprise legal and technical analysis assistant. "
+        "Your task is to critically analyze the contract document provided below. "
+        "Pay special attention to identifying dangerous provisions, legal pitfalls, and potential liabilities. "
+        "Make sure to address both the overall contract structure and specific clauses where applicable."
     )
-    messages = format_prompt(system_message, refined_prompt)
-    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    response = speculative_decode(input_text, max_tokens=max_tokens, top_p=top_p, temperature=temperature)
+
+    # Build a unified prompt with explicit sections.
+    unified_prompt = build_prompt(system_prompt, document_content, user_prompt)
+
+    # Generate the analysis.
+    response = speculative_decode(unified_prompt, max_tokens=max_tokens, top_p=top_p, temperature=temperature)
     final_response = post_process(response)
     return final_response

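End to end, the rewritten function runs read_files → build_prompt → speculative_decode → post_process over a single plain-text prompt, instead of routing a chat message list through tokenizer.apply_chat_template. A hypothetical call, mirroring the Streamlit handler in the next hunk (the file name and token count are illustrative):

result = granite_analysis(
    "Summarize the termination-clause risks.",
    file_objs=["contract.txt"],
    max_tokens=512,
    top_p=0.9,
    temperature=0.7,
)
print(result)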
@@ -207,10 +217,4 @@ if st.button("Analyze Contract"):
     result = granite_analysis(user_prompt, uploaded_files, max_tokens=max_tokens_slider, top_p=top_p_slider, temperature=temperature_slider)
     st.success("Analysis complete!")
     st.markdown("### Analysis Output")
-
-    keyword = "assistant"
-    text_after_keyword = result.rsplit(keyword, 1)[-1].strip()
-
-    st.text_area("Output", text_after_keyword, height=400)
-
-
+    st.text_area("Output", result, height=400)
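Dropping the rsplit hack lines up with the prompt change: the old chat-templated output presumably carried a literal role marker before the reply, which the UI had to strip. What the deleted lines were doing, with an illustrative marker (not the exact Granite template text):

raw = "...echoed prompt...assistant The indemnity clause exposes the vendor to ..."
print(raw.rsplit("assistant", 1)[-1].strip())
# -> "The indemnity clause exposes the vendor to ..."

With a plain-text prompt, the raw generation carries no such marker, so result is shown directly.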
 