from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
import streamlit as st
import fitz  # PyMuPDF
from docx import Document
import re
import nltk
from presidio_analyzer import AnalyzerEngine, PatternRecognizer, RecognizerResult, Pattern

nltk.download('punkt')
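
# NLTK's 'punkt' tokenizer splits input documents into sentences so that PII
# detection and masking can run on one sentence at a time.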
def sentence_tokenize(text):
    sentences = nltk.sent_tokenize(text)
    return sentences

model_dir_large = 'edithram23/Redaction_Personal_info_v1'
tokenizer_large = AutoTokenizer.from_pretrained(model_dir_large)
model_large = AutoModelForSeq2SeqLM.from_pretrained(model_dir_large)
# model_dir_small = 'edithram23/Redaction'
# tokenizer_small = AutoTokenizer.from_pretrained(model_dir_small)
# model_small = AutoModelForSeq2SeqLM.from_pretrained(model_dir_small)

# def small(text, model=model_small, tokenizer=tokenizer_small):
#     inputs = ["Mask Generation: " + text.lower() + '.']
#     inputs = tokenizer(inputs, max_length=256, truncation=True, return_tensors="pt")
#     output = model.generate(**inputs, num_beams=8, do_sample=True, max_length=len(text))
#     decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
#     predicted_title = decoded_output.strip()
#     pattern = r'\[.*?\]'
#     redacted_text = re.sub(pattern, '[redacted]', predicted_title)
#     return redacted_text

# Initialize the analyzer engine
analyzer = AnalyzerEngine()

# Define a custom address recognizer using a regex pattern
address_pattern = Pattern(
    name="address",
    regex=r"\d+\s\w+\s(?:street|st|road|rd|avenue|ave|lane|ln|drive|dr|blvd|boulevard)\s*\w*",
    score=0.5,
)
address_recognizer = PatternRecognizer(supported_entity="ADDRESS", patterns=[address_pattern])

# Add the custom address recognizer to the analyzer
analyzer.registry.add_recognizer(address_recognizer)
# analyzer.get_recognizers()  # optional: inspect the recognizers registered above
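
# Presidio's built-in recognizers (PERSON, PHONE_NUMBER, EMAIL_ADDRESS, LOCATION,
# IN_AADHAAR, ...) handle most of the entity types used below; the custom ADDRESS
# pattern above only supplements them with a simple street-address regex.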

# Define a function to extract entities
def extract_entities(text):
    entities = {
        "NAME": [],
        "PHONE_NUMBER": [],
        "EMAIL": [],
        "ADDRESS": [],
        "LOCATION": [],
        "IN_AADHAAR": [],
    }
    output = []
    # Analyze the text for PII
    results = analyzer.analyze(text=text, language='en')
    for result in results:
        if result.entity_type == "PERSON":
            entities["NAME"].append(text[result.start:result.end])
            output += [text[result.start:result.end]]
        elif result.entity_type == "PHONE_NUMBER":
            entities["PHONE_NUMBER"].append(text[result.start:result.end])
            output += [text[result.start:result.end]]
        elif result.entity_type == "EMAIL_ADDRESS":
            entities["EMAIL"].append(text[result.start:result.end])
            output += [text[result.start:result.end]]
        elif result.entity_type == "ADDRESS":
            entities["ADDRESS"].append(text[result.start:result.end])
            output += [text[result.start:result.end]]
        elif result.entity_type == 'LOCATION':
            entities['LOCATION'].append(text[result.start:result.end])
            output += [text[result.start:result.end]]
        elif result.entity_type == 'IN_AADHAAR':
            entities['IN_AADHAAR'].append(text[result.start:result.end])
            output += [text[result.start:result.end]]
    return entities, output

def mask_generation(text, model=model_large, tokenizer=tokenizer_large):
    if len(text) < 90:
        text = text + '.'
        # return small(text)
    inputs = ["Mask Generation: " + text.lower() + '.']
    inputs = tokenizer(inputs, max_length=512, truncation=True, return_tensors="pt")
    output = model.generate(**inputs, num_beams=8, do_sample=True, max_length=len(text))
    decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    predicted_title = decoded_output.strip()
    pattern = r'\[.*?\]'
    redacted_text = re.sub(pattern, '[redacted]', predicted_title)
    return redacted_text
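
# Note (assumption about the fine-tuned model's behaviour): the seq2seq model is
# expected to emit bracketed placeholders such as "[NAME]" in place of PII; the
# regex above then normalises anything in square brackets to '[redacted]'.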

def redact_text(page, text):
    text_instances = page.search_for(text)
    for inst in text_instances:
        page.add_redact_annot(inst, fill=(0, 0, 0))
    page.apply_redactions()

def read_pdf(file):
    pdf_document = fitz.open(stream=file.read(), filetype="pdf")
    text = ""
    for page_num in range(len(pdf_document)):
        page = pdf_document.load_page(page_num)
        text += page.get_text()
    return text, pdf_document

def read_docx(file):
    doc = Document(file)
    text = "\n".join([para.text for para in doc.paragraphs])
    return text

def read_txt(file):
    text = file.read().decode("utf-8")
    return text

def process_file(file):
    if file.type == "application/pdf":
        return read_pdf(file)
    elif file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        return read_docx(file), None
    elif file.type == "text/plain":
        return read_txt(file), None
    else:
        return "Unsupported file type.", None

st.title("Redaction")
uploaded_file = st.file_uploader("Upload a file", type=["pdf", "docx", "txt"])

if uploaded_file is not None:
    file_contents, pdf_document = process_file(uploaded_file)
    if pdf_document:
        redacted_text = []
        for pg in pdf_document:
            text = pg.get_text('text')
            sentences = sentence_tokenize(text)
            for sent in sentences:
                entities, words_out = extract_entities(sent)
                avai_red = pg.search_for(sent)
                # Split multi-line matches and drop very short fragments
                new = []
                for w in words_out:
                    new += w.split('\n')
                words_out = [i for i in new if len(i) > 2]
                print(words_out)
                for i in avai_red:
                    b = pg.get_text("text", clip=i)
                    # result = [item for item in output if item in b]  # Get elements of 'a' that are in 'b'
                    for j in words_out:
                        new_n = pg.search_for(j, clip=i)
                        for rect in new_n:
                            pg.add_redact_annot(rect, fill=(0, 0, 0))
            pg.apply_redactions()
        output_pdf = "output_redacted.pdf"
        pdf_document.save(output_pdf)
        with open(output_pdf, "rb") as file:
            st.download_button(
                label="Download Processed PDF",
                data=file,
                file_name="processed_file.pdf",
                mime="application/pdf",
            )
    else:
        token = sentence_tokenize(file_contents)
        final = ''
        for i in range(0, len(token)):
            final += mask_generation(token[i]) + '\n'
        processed_text = final
        st.text_area("OUTPUT", processed_text, height=400)
        st.download_button(
            label="Download Processed File",
            data=processed_text,
            file_name="processed_file.txt",
            mime="text/plain",
        )
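
# To run locally (assuming this file is saved as app.py): streamlit run app.py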