Spaces:
Runtime error
Runtime error
import streamlit as st | |
import os | |
import re | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from PyPDF2 import PdfReader | |
from peft import get_peft_model, LoraConfig, TaskType | |
# Force CPU execution: every tensor in this app lives on the CPU.
device = torch.device("cpu")

# IBM Granite instruct model, loaded in float32 for CPU inference.
MODEL_NAME = "ibm-granite/granite-3.1-2b-instruct"

# LoRA adapter configuration applied to the attention query/value projections.
# NOTE(review): no training happens anywhere in this file. Under PEFT's default
# initialisation (LoRA B weights start at zero), these untrained adapters
# should leave the model's outputs unchanged. Confirm the wrap is still wanted.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)


@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer(model_name):
    """Load the Granite model and tokenizer once per server process.

    Streamlit re-executes this script from the top on every widget
    interaction. Without this cache, every click would reload the full model
    from disk and re-wrap it with LoRA. ``show_spinner=False`` keeps the cache
    from rendering anything before ``st.set_page_config`` runs in ``main``.

    Args:
        model_name: Hugging Face model id to load.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is wrapped with
        ``lora_config`` and switched to eval mode.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cpu",  # Force CPU execution
        torch_dtype=torch.float32,  # float32: the host runs on CPU
    )
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    peft_model = get_peft_model(base_model, lora_config)
    peft_model.eval()  # disables LoRA dropout for inference
    return peft_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(MODEL_NAME)
def read_files(file):
    """Extract the text of every page of a PDF.

    Args:
        file: Anything ``PdfReader`` accepts, e.g. a path or a binary
            file-like object such as a Streamlit upload.

    Returns:
        The non-empty page texts joined by newlines, with leading and trailing
        whitespace stripped ("" when no page yields any text).
    """
    pages = PdfReader(file).pages
    # Skip pages whose extracted text is empty or None.
    page_texts = [
        extracted
        for extracted in (page.extract_text() for page in pages)
        if extracted
    ]
    return "\n".join(page_texts).strip()
def format_prompt(system_msg, user_msg, file_context=""):
    """Build the two-turn chat message list sent to the model.

    Args:
        system_msg: Base system instruction.
        user_msg: The user's request.
        file_context: Extracted document text. When non-empty, the system
            prompt is told a contract was supplied, and the text itself is
            appended to the user turn so the model can actually read it.

    Returns:
        ``[{"role": "system", ...}, {"role": "user", ...}]``.
    """
    if file_context:
        system_msg += (
            " The user has provided a contract document. Use its context to generate insights, "
            "but do not repeat or summarize the document itself."
        )
        # Previously the document was only mentioned, never included, so the
        # model was asked to analyse text it had no access to.
        user_msg = f"{user_msg}\n\nContract document:\n{file_context}"
    return [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]
def generate_response(input_text, max_tokens=1000, top_p=0.9, temperature=0.7):
    """Sample a completion for an already chat-templated prompt.

    Args:
        input_text: Prompt string, as produced by
            ``tokenizer.apply_chat_template(..., tokenize=False)``.
        max_tokens: Maximum number of newly generated tokens.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature.

    Returns:
        Only the newly generated text. The prompt tokens are sliced off so
        the system/user turns (and any document text) are not echoed back.
    """
    # The prompt already contains the chat template's special tokens, so
    # don't let the tokenizer add its own a second time.
    model_inputs = tokenizer(
        [input_text], return_tensors="pt", add_special_tokens=False
    ).to(device)
    prompt_length = model_inputs["input_ids"].shape[-1]
    with torch.no_grad():
        output = model.generate(
            **model_inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
    # For a causal LM, generate() returns the prompt followed by the continuation.
    return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
def post_process(text):
    """Tidy raw model output for display.

    Removes every run matched by the cleanup pattern, drops blank lines,
    strips each remaining line, and keeps only the first occurrence of each
    distinct line, preserving order.

    Args:
        text: Raw decoded model output.

    Returns:
        The surviving lines joined by newlines.
    """
    # NOTE(review): this pattern reads like mis-encoded text (a multi-byte
    # character decoded with the wrong code page). Confirm which symbol it
    # was meant to strip.
    scrubbed = re.sub(r'ζ₯+', '', text)
    first_seen = {}
    for raw_line in scrubbed.splitlines():
        stripped = raw_line.strip()
        if stripped:
            first_seen.setdefault(stripped, None)
    return "\n".join(first_seen)
def granite_simple(prompt, file, max_tokens=1000, top_p=0.9, temperature=0.7):
    """Run one contract-analysis request through the Granite model.

    No retrieval step happens here: the full extracted PDF text (if any) is
    passed to ``format_prompt`` together with the request.

    Args:
        prompt: The instruction to send to the model.
        file: PDF source accepted by ``read_files`` (e.g. a Streamlit upload),
            or a falsy value for no document.
        max_tokens: Maximum number of new tokens to generate.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature.

    Returns:
        The model response after ``post_process`` cleanup.
    """
    file_context = read_files(file) if file else ""
    system_message = "You are IBM Granite, a legal AI assistant specializing in contract analysis."
    messages = format_prompt(system_message, prompt, file_context)
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    response = generate_response(
        input_text,
        max_tokens=max_tokens,
        top_p=top_p,
        temperature=temperature,
    )
    return post_process(response)


def main():
    """Render the contract-analysis Streamlit page.

    The sidebar sliders control generation. When the button is pressed and
    a PDF has been uploaded, the model's analysis of that PDF is displayed.
    """
    st.set_page_config(page_title="Contract Analysis AI", page_icon="π")
    # NOTE(review): the emoji prefixes in the UI strings below look
    # mis-encoded; confirm the intended icons.
    st.title("π AI-Powered Contract Analysis Tool")
    st.write("Upload a contract document (PDF) for a detailed AI-driven legal and technical analysis.")

    # Sidebar: sampling settings, forwarded to the generator below.
    with st.sidebar:
        st.header("βοΈ Settings")
        max_tokens = st.slider("Max Tokens", 50, 1000, 250, 50)
        top_p = st.slider("Top P (sampling)", 0.1, 1.0, 0.9, 0.1)
        temperature = st.slider("Temperature (creativity)", 0.1, 1.0, 0.7, 0.1)

    # File upload section.
    uploaded_file = st.file_uploader("π Upload a contract document (PDF)", type="pdf")
    if uploaded_file is not None:
        st.session_state["uploaded_file"] = uploaded_file  # Persist file in session state
        st.success("β File uploaded successfully!")
        st.write("Click the button below to analyze the contract.")

    # Full-width button, rendered whether or not a file has been uploaded.
    st.markdown('<style>div.stButton > button {display: block; width: 100%;}</style>', unsafe_allow_html=True)
    if st.button("π Analyze Document"):
        # The analysis prompt refers to an attached contract; without one
        # the model would be asked to analyse nothing.
        if uploaded_file is None:
            st.warning("Please upload a contract document (PDF) before analyzing.")
            return
        with st.spinner("Analyzing contract document... β³"):
            final_answer = granite_simple(
                "Perform a detailed technical analysis of the attached contract document, highlighting potential risks, legal pitfalls, compliance issues, and areas where contractual terms may lead to future disputes or operational challenges.",
                uploaded_file,
                max_tokens=max_tokens,
                top_p=top_p,
                temperature=temperature,
            )
        # Display the analysis result.
        st.subheader("π Analysis Result")
        st.write(final_answer)
# Launch the app when this file is executed as the main script.
if __name__ == "__main__":
    main()