import spacy
import streamlit as st
import re
import logging
from presidio_anonymizer import AnonymizerEngine
from presidio_analyzer import AnalyzerEngine, PatternRecognizer, RecognizerResult, EntityRecognizer
from annotated_text import annotated_text
from flair_recognizer import FlairRecognizer
from detoxify import Detoxify
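# Pipeline overview: Presidio's AnalyzerEngine (extended with a custom Flair
# NER recognizer) finds PII spans, Detoxify screens the input for toxicity,
# and Presidio's AnonymizerEngine redacts the detected spans.
# flair_recognizer is a local module shipped in this Space's files.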
###############################
#### Render Streamlit page ####
###############################

st.title("Anonymise your text!")
st.markdown(
    "This mini-app anonymises text using Flair and Presidio. You can find the code in the Files and Versions tab on the [HuggingFace page](https://huggingface.co./spaces/arogeriogel/anonymise_this)."
)

# Configure logger
logging.basicConfig(format="\n%(asctime)s\n%(message)s", level=logging.INFO, force=True)
##############################
###### Define functions ######
##############################

# @st.cache_resource(show_spinner="Fetching model from cache...")
def analyzer_engine():
    """Return AnalyzerEngine."""
    analyzer = AnalyzerEngine()
    flair_recognizer = FlairRecognizer()
    analyzer.registry.add_recognizer(flair_recognizer)
    return analyzer
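# NOTE: the custom FlairRecognizer is registered alongside Presidio's built-in
# recognizers, so both contribute results. Consider re-enabling the
# st.cache_resource decorator above: without it the engine (and the Flair
# model) is rebuilt on every call.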
def analyze(**kwargs):
    """Analyze input using Analyzer engine and input arguments (kwargs)."""
    if "entities" not in kwargs or "All" in kwargs["entities"]:
        kwargs["entities"] = None
    results = analyzer_engine().analyze(**kwargs)
    st.session_state.analyze_results = results
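# The kwargs are forwarded to AnalyzerEngine.analyze(), which accepts e.g.
# text, language, entities and return_decision_process (see the call in
# analyze_text below). Passing entities=None asks Presidio to look for all
# supported entity types.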
def annotate():
    text = st.session_state.text
    analyze_results = st.session_state.analyze_results
    tokens = []
    starts = []
    # Sort by start index so tokens are assembled left to right.
    results = sorted(analyze_results, key=lambda x: x.start)
    for i, res in enumerate(results):
        # If we already have an entity for this span, don't add another.
        if res.start not in starts:
            if i == 0:
                tokens.append(text[:res.start])
            # Append the entity text and its entity type.
            tokens.append((text[res.start:res.end], res.entity_type))
            # If another entity is coming (we're not at the last result),
            # add the plain text up to the next entity.
            if i != len(results) - 1:
                tokens.append(text[res.end:results[i + 1].start])
            # Otherwise add all remaining text.
            else:
                tokens.append(text[res.end:])
        # Record this start index so we don't repeat results per span.
        starts.append(res.start)
    return tokens
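# annotate() returns the mixed sequence that annotated_text() expects:
# plain strings for untouched text and (text, label) tuples for entities,
# e.g. ["My name is ", ("John", "PERSON"), "."].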
def get_supported_entities():
    """Return supported entities from the Analyzer Engine."""
    return analyzer_engine().get_supported_entities()
def analyze_text():
    if not st.session_state.text:
        st.session_state.text_error = "Please enter your text"
        return
    # Screen the input with Detoxify before analysing it.
    toxicity_results = Detoxify('original').predict(st.session_state.text)
    is_toxic = False
    for k in toxicity_results.keys():
        if k != 'toxicity':
            # Sub-scores (e.g. insult, threat) use a stricter threshold
            # than the overall toxicity score.
            if toxicity_results[k] > 0.5:
                is_toxic = True
        else:
            if toxicity_results[k] > 0.65:
                is_toxic = True
    if is_toxic:
        st.session_state.text_error = "Your text entry was detected as toxic, please re-write it."
        return
    with text_spinner_placeholder:
        with st.spinner("Please wait while your text is being analysed..."):
            logging.info(f"This is the text being analysed: {st.session_state.text}")
            st.session_state.text_error = ""
            st.session_state.n_requests += 1
            analyze(
                text=st.session_state.text,
                entities=st_entities,
                language="en",
                return_decision_process=False,
            )
            if st.session_state.excluded_words:
                exclude_manual_input()
            if st.session_state.allowed_words:
                allow_manual_input()
            logging.info(
                f"analyse results: {st.session_state.analyze_results}\n"
            )
def exclude_manual_input():
    deny_list = [i.strip() for i in st.session_state.excluded_words.split(',')]

    def _deny_list_to_regex(deny_list):
        """
        Convert a list of words to a matching regex,
        to be analysed by the analyze method like any other regex pattern.

        :param deny_list: the list of words to detect
        :return: the regex of the words for detection
        """
        # Escape deny list elements in preparation for the regex.
        escaped_deny_list = [re.escape(element) for element in deny_list]
        regex = r"(?:^|(?<=\W))(" + "|".join(escaped_deny_list) + r")(?:(?=\W)|$)"
        return regex
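    # For example, a deny list of ["John", "Mary"] produces
    # r"(?:^|(?<=\W))(John|Mary)(?:(?=\W)|$)", which matches the words only
    # when preceded and followed by a non-word character (or the start/end
    # of the string), so "Johnson" is not matched.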
    deny_list_pattern = _deny_list_to_regex(deny_list)
    matches = re.finditer(deny_list_pattern, st.session_state.text)
    results = []
    for match in matches:
        start, end = match.span()
        current_match = st.session_state.text[start:end]
        # Skip empty matches.
        if current_match == "":
            continue
        pattern_result = RecognizerResult(
            entity_type='MANUALLY ADDED',
            start=start,
            end=end,
            score=1.0,
        )
        # Only keep the match if the same span wasn't already detected.
        found = False
        for token in st.session_state.analyze_results:
            if token.start == start and token.end == end:
                found = True
        if not found:
            results.append(pattern_result)
    results = EntityRecognizer.remove_duplicates(results)
    st.session_state.analyze_results.extend(results)
    logging.info(
        f"analyse results after adding excluded words: {st.session_state.analyze_results}\n"
    )
def allow_manual_input():
    # Split the comma-separated input into a list; a plain substring check
    # against the raw string would wrongly match parts of other words.
    allowed_words = [i.strip() for i in st.session_state.allowed_words.split(',')]
    analyze_results_filtered = []
    for token in st.session_state.analyze_results:
        if st.session_state.text[token.start:token.end] not in allowed_words:
            analyze_results_filtered.append(token)
    logging.info(
        f"analyse results after removing allowed words: {analyze_results_filtered}\n"
    )
    st.session_state.analyze_results = analyze_results_filtered
# @st.cache_resource(show_spinner="Fetching model from cache...")
def anonymizer_engine():
    """Return AnonymizerEngine."""
    return AnonymizerEngine()
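# With no operators argument, AnonymizerEngine.anonymize() applies Presidio's
# default "replace" operator, substituting each detected span with its entity
# type placeholder, e.g. "<PERSON>".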
def anonymise_text():
    if st.session_state.n_requests >= 50:
        st.session_state.text_error = "Too many requests. Please wait a few seconds before anonymising more text."
        logging.info(f"Session request limit reached: {st.session_state.n_requests}")
        st.session_state.n_requests = 1
        # Bail out so the rate-limit message isn't immediately overwritten.
        return
    st.session_state.text_error = ""
    if not st.session_state.text:
        st.session_state.text_error = "Please enter your text"
        return
    # If the text hasn't been analysed yet, analyse it first.
    if not st.session_state.analyze_results:
        analyze_text()
    with text_spinner_placeholder:
        with st.spinner("Please wait while your text is being anonymised..."):
            anon_results = anonymizer_engine().anonymize(st.session_state.text, st.session_state.analyze_results)
            st.session_state.text_error = ""
            st.session_state.n_requests += 1
            st.session_state.anon_results = anon_results
            logging.info(
                f"text anonymised: {st.session_state.anon_results}"
            )
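# anonymize() returns a result object whose .text attribute holds the
# redacted string; that attribute is what gets rendered in the results
# column below.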
def clear_results():
    st.session_state.anon_results = ""
    st.session_state.analyze_results = ""
#######################################
#### Initialize "global" variables ####
#######################################

if "text_error" not in st.session_state:
    st.session_state.text_error = ""
if "analyze_results" not in st.session_state:
    st.session_state.analyze_results = ""
if "anon_results" not in st.session_state:
    st.session_state.anon_results = ""
if "n_requests" not in st.session_state:
    st.session_state.n_requests = 0
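# These keys must exist before the widgets and callbacks below read them,
# otherwise the first access on a fresh session would fail.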
##############################
####### Page arguments #######
##############################

# Every widget with a key is automatically added to Session State.
# In Streamlit, interacting with a widget triggers a rerun, and variables
# defined in the code are reinitialised on each rerun.
# If a callback function is associated with a widget, a change in the widget
# triggers the following sequence: first the callback function is executed,
# and then the app runs from top to bottom.
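# For example, clicking "Analyse text" below first runs analyze_text() (the
# callback), which stores its output in st.session_state.analyze_results;
# the script then reruns top to bottom, and the results pane at the end of
# the page finds those results already populated.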
st.text_input(
    label="Text",
    placeholder="Write your text here",
    key='text',
    on_change=clear_results,
)
st.text_input(
    label="Data to be redacted (optional)",
    placeholder="John, Mary, London",
    key='excluded_words',
    on_change=clear_results,
)
st.text_input(
    label="Data to be ignored (optional)",
    placeholder="NHS, GEL, Lab",
    key='allowed_words',
    on_change=clear_results,
)
st_entities = st.sidebar.multiselect(
    label="Which entities to look for?",
    options=get_supported_entities(),
    default=list(get_supported_entities()),
)
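# st_entities is read by analyze_text() as a module-level name: by the time
# a button callback fires, the script has already run once, so the
# multiselect's current value is available.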
##############################
######## Page buttons ########
##############################

# st.button returns True on the rerun triggered by its click.
col1, col2 = st.columns(2)
with col1:
    analyze_now = st.button(
        label="Analyse text",
        type="primary",
        on_click=analyze_text,
    )
with col2:
    anonymise_now = st.button(
        label="Anonymise text",
        type="primary",
        on_click=anonymise_text,
    )
##############################
######## Page actions ########
##############################

text_spinner_placeholder = st.empty()
if st.session_state.text_error:
    st.error(st.session_state.text_error)

with col1:
    if st.session_state.analyze_results:
        annotated_tokens = annotate()
        annotated_text(*annotated_tokens)
        st.write(st.session_state.analyze_results)
    if not st.session_state.analyze_results and analyze_now and not st.session_state.text_error:
        st.write("### No PII was found. ###")
with col2:
    if st.session_state.anon_results:
        st.write(st.session_state.anon_results.text)
    if not st.session_state.analyze_results and anonymise_now and not st.session_state.text_error:
        st.write("### No PII was found. ###")