presidio-de-identify / presidio_helpers.py
awacke1's picture
Update presidio_helpers.py
ee7ba2f verified
"""
Helper methods for the Presidio Streamlit app
"""
from typing import List, Optional, Tuple
import logging
import streamlit as st
from presidio_analyzer import (
AnalyzerEngine,
RecognizerResult,
RecognizerRegistry,
PatternRecognizer,
)
from presidio_anonymizer import AnonymizerEngine
from presidio_anonymizer.entities import OperatorConfig
logger = logging.getLogger("presidio-streamlit")
@st.cache_resource
def nlp_engine_and_registry(
    model_family: str,
    model_path: str,
) -> Tuple[object, RecognizerRegistry]:
    """Load the requested NER model and build a recognizer registry for it.

    Args:
        model_family: Model family name, compared case-insensitively.
            ``"flair"`` and ``"huggingface"`` are the only accepted values.
        model_path: Model identifier/path handed to the family's loader.

    Returns:
        A ``(model, registry)`` pair: the loaded flair ``SequenceTagger`` or
        transformers ``pipeline`` object, and the populated registry.

    Raises:
        RuntimeError: For any failure inside the loading step, including an
            unsupported ``model_family`` (the inner ``ValueError`` is wrapped).
            The original exception is attached as ``__cause__``.
    """
    registry = RecognizerRegistry()
    try:
        family = model_family.lower()
        if family == "flair":
            # Imported inside the branch so flair is only needed when selected.
            from flair.models import SequenceTagger

            model = SequenceTagger.load(model_path)
            recognizer_conf = {
                "name": "flair_recognizer",
                "supported_language": "en",
                "supported_entities": ["PERSON", "LOCATION", "ORGANIZATION"],
                "model": model_path,
                "package": "flair",
            }
        elif family == "huggingface":
            # Imported inside the branch so transformers is only needed when selected.
            from transformers import pipeline

            model = pipeline("ner", model=model_path, tokenizer=model_path)
            recognizer_conf = {
                "name": "huggingface_recognizer",
                "supported_language": "en",
                "supported_entities": ["PERSON", "LOCATION", "ORGANIZATION", "DATE_TIME"],
                "model": model_path,
                "package": "transformers",
            }
        else:
            raise ValueError(f"Model family {model_family} not supported")

        # Both families populate the registry the same way once the model is loaded.
        registry.load_predefined_recognizers()
        # NOTE(review): confirm the installed presidio_analyzer exposes
        # RecognizerRegistry.add_recognizer_from_dict; if it does not, this call
        # fails and surfaces as the RuntimeError below.
        registry.add_recognizer_from_dict(recognizer_conf)
        return model, registry
    except Exception as e:
        logger.error(f"Error loading model {model_path} for {model_family}: {str(e)}")
        # Chain explicitly so the root cause stays attached to the wrapper.
        raise RuntimeError(
            f"Failed to load model: {str(e)}. Ensure model is downloaded and accessible."
        ) from e
@st.cache_resource
def analyzer_engine(
    model_family: str,
    model_path: str,
) -> AnalyzerEngine:
    """Build an ``AnalyzerEngine`` around the registry for the chosen model.

    Args:
        model_family: Model family name forwarded to ``nlp_engine_and_registry``.
        model_path: Model identifier/path forwarded to ``nlp_engine_and_registry``.

    Returns:
        An ``AnalyzerEngine`` constructed with only the ``registry`` argument.
    """
    # Only the registry half of the pair is consumed; the loaded model object
    # is discarded here and is not passed to the AnalyzerEngine.
    _model, recognizer_registry = nlp_engine_and_registry(model_family, model_path)
    return AnalyzerEngine(registry=recognizer_registry)
@st.cache_data
def get_supported_entities(model_family: str, model_path: str) -> List[str]:
    """List the entity types offered for the selected model family.

    Args:
        model_family: Model family name, compared case-insensitively.
        model_path: Not read by the body (presumably it only contributes to
            the Streamlit cache key; confirm against the caller).

    Returns:
        ``PERSON``, ``LOCATION`` and ``ORGANIZATION`` for every family, with
        ``DATE_TIME`` appended for ``"huggingface"``.
    """
    supported = ["PERSON", "LOCATION", "ORGANIZATION"]
    # "flair" and any unrecognised family share the base list; only the
    # Hugging Face family adds DATE_TIME.
    if model_family.lower() == "huggingface":
        supported.append("DATE_TIME")
    return supported
def analyze(
    analyzer: AnalyzerEngine,
    text: str,
    entities: List[str],
    language: str,
    score_threshold: float,
    return_decision_process: bool,
    allow_list: List[str],
    deny_list: List[str],
) -> List[RecognizerResult]:
    """Analyze text for PHI entities.

    Args:
        analyzer: Engine used to run the detection.
        text: Text to scan.
        entities: Entity types to look for; passed through to the engine.
        language: Language code passed through to the engine.
        score_threshold: Minimum score passed through to the engine.
        return_decision_process: Passed through to the engine.
        allow_list: Terms that suppress a detection. A result is dropped when
            its matched text contains any of these terms (case-insensitive
            substring match). ``None`` and blank entries are ignored.
        deny_list: Terms that must always be flagged. They are detected via the
            ad-hoc recognizer built by ``create_ad_hoc_deny_list_recognizer``.
            ``None`` and blank entries are ignored.

    Returns:
        Every analyzer result (model detections plus deny-list hits) whose
        matched text does not contain an allow-list term.

    Raises:
        Exception: Any error from the recognizer construction or the analyzer
            is logged and re-raised unchanged.
    """
    try:
        # An empty string is a substring of everything, so a blank allow-list
        # entry would discard every result; skip blanks (and tolerate None).
        allowed_terms = [term.lower() for term in (allow_list or []) if term]
        deny_terms = [term for term in (deny_list or []) if term and term.strip()]

        # Previously a non-empty deny list acted as a filter that threw away
        # every detection not containing a deny term. Deny terms are now added
        # as an extra recognizer, so they extend the detections instead of
        # replacing them.
        deny_recognizer = create_ad_hoc_deny_list_recognizer(deny_terms)
        ad_hoc_recognizers = None
        if deny_recognizer is not None:
            ad_hoc_recognizers = [deny_recognizer]
            if entities:
                # When the caller restricts entity types, the deny-list entity
                # must be requested too or its hits would be filtered out.
                # Build a new list so the caller's list is not mutated.
                extra = [
                    entity
                    for entity in deny_recognizer.supported_entities
                    if entity not in entities
                ]
                entities = [*entities, *extra]

        results = analyzer.analyze(
            text=text,
            entities=entities,
            language=language,
            score_threshold=score_threshold,
            return_decision_process=return_decision_process,
            ad_hoc_recognizers=ad_hoc_recognizers,
        )

        # The allow list takes precedence over everything, deny-list hits included.
        return [
            result
            for result in results
            if not any(
                term in text[result.start:result.end].lower()
                for term in allowed_terms
            )
        ]
    except Exception as e:
        logger.error(f"Analysis error: {str(e)}")
        raise
def anonymize(
    text: str,
    operator: str,
    analyze_results: List[RecognizerResult],
    mask_char: str = "*",
    number_of_chars: int = 15,
) -> dict:
    """Anonymize detected PHI entities in the text.

    Args:
        text: Original text that was analyzed.
        operator: Name of the anonymizer operator applied to every entity
            (registered under the ``"DEFAULT"`` key).
        analyze_results: Analyzer results marking the spans to anonymize.
        mask_char: Masking character; only used when ``operator == "mask"``.
        number_of_chars: Number of characters to mask; only used when
            ``operator == "mask"``.

    Returns:
        Whatever ``AnonymizerEngine.anonymize`` returns.
        NOTE(review): in Presidio this is presumably an ``EngineResult``
        rather than a plain ``dict``; confirm before tightening the annotation.

    Raises:
        Exception: Any error raised while building the operator config or
            anonymizing is logged and re-raised unchanged.
    """
    try:
        operator_params = {}
        if operator == "mask":
            operator_params = {
                "masking_char": mask_char,
                "chars_to_mask": number_of_chars,
                # Presidio's mask operator validates `from_end` as a required
                # bool; omitting it makes the mask config fail validation.
                # False masks from the start of the matched text.
                "from_end": False,
            }
        return AnonymizerEngine().anonymize(
            text=text,
            analyzer_results=analyze_results,
            operators={"DEFAULT": OperatorConfig(operator, operator_params)},
        )
    except Exception as e:
        logger.error(f"Anonymization error: {str(e)}")
        raise
def create_ad_hoc_deny_list_recognizer(
    deny_list: Optional[List[str]] = None,
) -> Optional[PatternRecognizer]:
    """Build a ``GENERIC_PII`` pattern recognizer from deny-list terms.

    Args:
        deny_list: Terms to hand to the recognizer as its deny list.

    Returns:
        A ``PatternRecognizer`` for the ``GENERIC_PII`` entity, or ``None``
        when ``deny_list`` is ``None`` or empty.

    Raises:
        Exception: Any error from constructing the recognizer is logged and
            re-raised unchanged.
    """
    if deny_list:
        try:
            return PatternRecognizer(
                supported_entity="GENERIC_PII",
                deny_list=deny_list,
            )
        except Exception as e:
            logger.error(f"Error creating deny list recognizer: {str(e)}")
            raise
    # Nothing to deny: callers get None instead of an empty recognizer.
    return None