diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 86c6f847d9..fc91c197f7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -348,6 +348,11 @@ jobs: - name: Wait for services to be ready run: sleep 60 + - name: Download Ollama model for Ollama LangExtract recognizer tests + run: | + docker exec $(docker ps -qf "name=ollama") ollama pull qwen2.5:1.5b + docker exec $(docker ps -qf "name=ollama") ollama run qwen2.5:1.5b + - name: Run E2E tests working-directory: e2e-tests run: | @@ -431,6 +436,11 @@ jobs: - name: Wait for services to be ready run: sleep 60 + - name: Download Ollama model for Ollama LangExtract recognizer tests + run: | + docker exec $(docker ps -qf "name=ollama") ollama pull qwen2.5:1.5b + docker exec $(docker ps -qf "name=ollama") ollama run qwen2.5:1.5b + - name: Run E2E tests working-directory: e2e-tests run: | diff --git a/docker-compose.yml b/docker-compose.yml index a1d97a1dbd..abce615a5a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,18 @@ services: + ollama: + image: ollama/ollama:latest + ports: + - "127.0.0.1:11434:11434" # or "127.0.0.1:11435:11434" + volumes: + - ollama-data:/root/.ollama + environment: + - OLLAMA_HOST=0.0.0.0 + healthcheck: + test: ["CMD", "ollama", "list"] + interval: 10s + timeout: 10s + retries: 30 # ~5 minutes total + start_period: 60s presidio-anonymizer: image: ${REGISTRY_NAME}/${IMAGE_PREFIX}presidio-anonymizer${TAG} build: @@ -9,6 +23,7 @@ services: - PORT=5001 ports: - "5001:5001" + presidio-analyzer: image: ${REGISTRY_NAME}/${IMAGE_PREFIX}presidio-analyzer${TAG} build: @@ -17,8 +32,13 @@ services: - type=registry,ref=${REGISTRY_NAME}/${IMAGE_PREFIX}presidio-analyzer:latest environment: - PORT=5001 + - OLLAMA_HOST=http://ollama:11434 ports: - "5002:5001" + depends_on: + ollama: + condition: service_healthy + presidio-image-redactor: image: ${REGISTRY_NAME}/${IMAGE_PREFIX}presidio-image-redactor${TAG} build: @@ -29,3 +49,6 @@ services: - PORT=5001 
ports: - "5003:5001" + +volumes: + ollama-data: diff --git a/docs/analyzer/adding_recognizers.md b/docs/analyzer/adding_recognizers.md index c259fe8274..d7a91a41f1 100644 --- a/docs/analyzer/adding_recognizers.md +++ b/docs/analyzer/adding_recognizers.md @@ -165,6 +165,13 @@ On how to integrate Presidio with AHDS De-Identification Protected Health Inform and a sample for a ADHS Remote Recognizer, refer to the [AHDS de-Identification Integration document](../samples/python/ahds/index.md). +### Language Model-based PII/PHI detection recognizer + +Presidio supports language model-based entity detection using LLMs and SLMs for flexible PII/PHI recognition. + +The current implementation uses LangExtract with Ollama (local models). For full setup instructions and usage examples, +see the [Language Model-based PII/PHI Detection guide](../samples/python/langextract/index.md). + ### Creating ad-hoc recognizers In addition to recognizers in code, it is possible to create ad-hoc recognizers via the Presidio Analyzer API for regex and deny-list based logic. 
diff --git a/docs/samples/index.md b/docs/samples/index.md index ff0b2b9142..422b66bc75 100644 --- a/docs/samples/index.md +++ b/docs/samples/index.md @@ -19,6 +19,7 @@ | Usage | Text | Python file | [Azure AI Language as a Remote Recognizer](python/text_analytics/index.md) | | Usage | Text | Python file | [Azure Health Data Services de-identification Service as a Remote Recognizer](python/ahds/index.md) | | Usage | Text | Python file | [AHDS Surrogate Example](python/ahds/example_ahds_surrogate.py) | +| Usage | Text | Python file | [Language Model-based PII/PHI Detection using LangExtract](python/langextract/index.md) | | Usage | CSV | Python file | [Analyze and Anonymize CSV file](https://github.com/microsoft/presidio/blob/main/docs/samples/python/process_csv_file.py) | | Usage | Text | Python | [Using Flair as an external PII model](https://github.com/microsoft/presidio/blob/main/docs/samples/python/flair_recognizer.py)| | Usage | Text | Python file | [Using Span Marker as an external PII model](https://github.com/microsoft/presidio/blob/main/docs/samples/python/span_marker_recognizer.py)| diff --git a/docs/samples/python/langextract/index.md b/docs/samples/python/langextract/index.md new file mode 100644 index 0000000000..c53d39de37 --- /dev/null +++ b/docs/samples/python/langextract/index.md @@ -0,0 +1,181 @@ +# Language Model-based PII/PHI Detection (Experimental Feature) + +## Introduction + +Presidio supports language model-based PII/PHI detection for flexible entity recognition using language models (LLMs, SLMs, etc.). This approach enables detection of both: +- **PII (Personally Identifiable Information)**: Names, emails, phone numbers, SSN, credit cards, etc. +- **PHI (Protected Health Information)**: Medical records, health identifiers, etc. + +(The default approach uses [LangExtract](https://github.com/google/langextract) under the hood to integrate with language model providers.) 
+ +## Entity Detection Capabilities + +Unlike pattern-based recognizers, language model-based detection is flexible and depends on: + +- The language model being used +- The prompt description provided +- The few-shot examples configured + +The default configuration includes examples for common PII/PHI entities such as PERSON, EMAIL_ADDRESS, PHONE_NUMBER, US_SSN, CREDIT_CARD, MEDICAL_LICENSE, and more. +**You can customize the prompts and examples to detect any entity types relevant to your use case**. + +For the default entity mappings and examples, see the [default configuration](https://github.com/microsoft/presidio/blob/main/presidio-analyzer/presidio_analyzer/conf/langextract_config_ollama.yaml). + +## Supported Language Model Providers + +Presidio supports the following language model providers through LangExtract: + +1. **Ollama** - Local language model deployment (open-source models like Gemma, Llama, etc.) +2. **Azure OpenAI** - _Documentation coming soon_ + +## Language Model-based Recognizer Implementation + +Presidio provides a hierarchy of recognizers for language model-based PII/PHI detection: + +- **`LMRecognizer`**: Abstract base class for all language model recognizers (LLMs, SLMs, etc.) +- **`LangExtractRecognizer`**: Abstract base class for LangExtract library integration (model-agnostic) +- **`OllamaLangExtractRecognizer`**: Concrete implementation for Ollama local language models +- **`AzureOpenAILangExtractRecognizer`**: _Documentation coming soon_ + +[OllamaLangExtractRecognizer implementation](https://github.com/microsoft/presidio/blob/main/presidio-analyzer/presidio_analyzer/predefined_recognizers/third_party/ollama_langextract_recognizer.py) + +--- + +## Using Ollama (Local Models) + +### Prerequisites + +1. **Install Presidio with LangExtract support**: + ```sh + pip install presidio-analyzer[langextract] + ``` + +2. 
**Set up Ollama** + +You have two options to set up Ollama: + + **Option 1: Docker Compose** (recommended for CPU) + + This option requires Docker to be installed on your system. + + **Where to run:** From the root presidio directory (where `docker-compose.yml` is located) + + ```bash + docker compose up -d ollama + docker exec presidio-ollama-1 ollama pull qwen2.5:1.5b + docker exec presidio-ollama-1 ollama list + ``` + + **Platform differences:** + - **Linux/Mac**: Commands above work as-is + - **Windows**: Use PowerShell or CMD, commands are the same + + If you don't have Docker installed: + - Linux: Follow [Docker installation guide](https://docs.docker.com/engine/install/) + - Mac: Install [Docker Desktop for Mac](https://docs.docker.com/desktop/install/mac-install/) + - Windows: Install [Docker Desktop for Windows](https://docs.docker.com/desktop/install/windows-install/) + + **Option 2: Native installation** (recommended for GPU acceleration) + + Follow the [official LangExtract Ollama guide](https://github.com/google/langextract?tab=readme-ov-file#using-local-llms-with-ollama). + + After installation, pull and run the model: + ```bash + ollama pull qwen2.5:1.5b + ollama run qwen2.5:1.5b + ``` + + > This option provides better performance with GPU acceleration (e.g., on Mac with Metal Performance Shaders or systems with NVIDIA GPUs). + > The model must be pulled and run before using the recognizer. The default model is `qwen2.5:1.5b`. + +3. 
**Configuration** (optional): Create your own `ollama_config.yaml` or use the [default configuration](https://github.com/microsoft/presidio/blob/main/presidio-analyzer/presidio_analyzer/conf/langextract_config_ollama.yaml) + +### Usage + +**Option 1: Enable in configuration file** + +Enable the recognizer in [`default_recognizers.yaml`](https://github.com/microsoft/presidio/blob/main/presidio-analyzer/presidio_analyzer/conf/default_recognizers.yaml): +```yaml +- name: OllamaLangExtractRecognizer + enabled: true # Change from false to true +``` + +Then load the analyzer using this modified configuration file: + +```python +from presidio_analyzer import AnalyzerEngine +from presidio_analyzer.recognizer_registry import RecognizerRegistryProvider + +# Point to your modified default_recognizers.yaml with Ollama enabled +provider = RecognizerRegistryProvider( + conf_file="/path/to/your/modified/default_recognizers.yaml" +) +registry = provider.create_recognizer_registry() + +# Create analyzer with the registry that includes Ollama recognizer +analyzer = AnalyzerEngine(registry=registry, supported_languages=["en"]) + +# Analyze text - Ollama recognizer will participate in detection +results = analyzer.analyze(text="My email is john.doe@example.com", language="en") +``` + +**Option 2: Add programmatically** + +```python +from presidio_analyzer import AnalyzerEngine +from presidio_analyzer.predefined_recognizers.third_party.ollama_langextract_recognizer import OllamaLangExtractRecognizer + +analyzer = AnalyzerEngine() +analyzer.registry.add_recognizer(OllamaLangExtractRecognizer()) + +results = analyzer.analyze(text="My email is john.doe@example.com", language="en") +``` + +!!! note "Note" + The recognizer is disabled by default in `default_recognizers.yaml` to avoid requiring Ollama for basic Presidio usage. Enable it when you have Ollama set up and running. 
+ +### Custom Configuration + +To use a custom configuration file: + +```python +analyzer.registry.add_recognizer( + OllamaLangExtractRecognizer(config_path="/path/to/custom_config.yaml") +) +``` + +### Configuration Options + +The `langextract_config_ollama.yaml` file supports the following options: + +- **`model_id`**: The Ollama model to use (default: `"qwen2.5:1.5b"`) +- **`model_url`**: Ollama server URL (default: `"http://localhost:11434"`) +- **`temperature`**: Model temperature for generation (default: `null` for model default) +- **`supported_entities`**: PII/PHI entity types to detect +- **`entity_mappings`**: Map LangExtract entity classes to Presidio entity names +- **`min_score`**: Minimum confidence score (default: `0.5`) + +See the [configuration file](https://github.com/microsoft/presidio/blob/main/presidio-analyzer/presidio_analyzer/conf/langextract_config_ollama.yaml) for all options. + +## Troubleshooting + +**ConnectionError: "Ollama server not reachable"** +- Ensure Ollama is running: `docker ps` or check `http://localhost:11434` +- Verify the `model_url` in your configuration matches your Ollama server address + +**RuntimeError: "Model 'qwen2.5:1.5b' not found"** +- Pull the model: `docker exec -it presidio-ollama-1 ollama pull qwen2.5:1.5b` +- Or for manual setup: `ollama pull qwen2.5:1.5b` +- Verify the model name matches the `model_id` in your configuration + +--- + +## Using Azure OpenAI (Cloud Models) + +_Documentation coming soon_ + +--- + +## Choosing Between Ollama and Azure OpenAI + +_Comparison documentation coming soon_ diff --git a/e2e-tests/requirements.txt b/e2e-tests/requirements.txt index f0a761fc23..6ce13cfc7f 100644 --- a/e2e-tests/requirements.txt +++ b/e2e-tests/requirements.txt @@ -1,4 +1,4 @@ requests>=2.32.4 pytest -file:../presidio-analyzer -file:../presidio-anonymizer \ No newline at end of file +-e ../presidio-analyzer[langextract] +-e ../presidio-anonymizer diff --git a/e2e-tests/resources/ollama_test_config.yaml 
b/e2e-tests/resources/ollama_test_config.yaml new file mode 100644 index 0000000000..5ec1910e08 --- /dev/null +++ b/e2e-tests/resources/ollama_test_config.yaml @@ -0,0 +1,51 @@ +# LMRecognizer base configuration +lm_recognizer: + supported_entities: + - PERSON + - LOCATION + - ORGANIZATION + - PHONE_NUMBER + - EMAIL_ADDRESS + - DATE_TIME + - US_SSN + - CREDIT_CARD + - MEDICAL_LICENSE + - IP_ADDRESS + - URL + - IBAN_CODE + + labels_to_ignore: + - payment_status + + enable_generic_consolidation: true + min_score: 0.5 + +langextract: + prompt_file: presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_prompt.j2 + examples_file: presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_examples.yaml + + entity_mappings: + person: PERSON + full_name: PERSON + name_first: PERSON + name_last: PERSON + name_middle: PERSON + location: LOCATION + address: LOCATION + organization: ORGANIZATION + phone: PHONE_NUMBER + phone_number: PHONE_NUMBER + email: EMAIL_ADDRESS + date: DATE_TIME + ssn: US_SSN + identification_number: US_SSN + credit_card: CREDIT_CARD + medical_record: MEDICAL_LICENSE + ip_address: IP_ADDRESS + url: URL + iban: IBAN_CODE + + model: + model_id: qwen2.5:1.5b + model_url: http://localhost:11434 + temperature: 0.0 diff --git a/e2e-tests/resources/test_ollama_enabled_recognizers.yaml b/e2e-tests/resources/test_ollama_enabled_recognizers.yaml new file mode 100644 index 0000000000..390f322533 --- /dev/null +++ b/e2e-tests/resources/test_ollama_enabled_recognizers.yaml @@ -0,0 +1,208 @@ +supported_languages: + - en +global_regex_flags: 26 + +recognizers: + # Recognizers listed here can either be loaded from the recognizers defined in code (type: predefined), + # or created based on the provided configuration (type: custom). + # For predefined: + # - If only a recognizer name is provided, a predefined recognizer with this name and default parameters will be loaded. 
+ # - If a parameter isn't provided, the default one would be loaded. + # For custom: + # - See an example configuration here: https://github.com/microsoft/presidio/blob/main/presidio-analyzer/presidio_analyzer/conf/example_recognizers.yaml + # - Custom pattern recognizers with this configuration can be added to this file, with type: custom + # For recognizers supporting more than one language, an instance of the recognizer for each language will be created. + # For example, see the CreditCardRecognizer definition below: + - name: CreditCardRecognizer + supported_languages: + - language: en + context: [credit, card, visa, mastercard, cc, amex, discover, jcb, diners, maestro, instapayment] + - language: es + context: [tarjeta, credito, visa, mastercard, cc, amex, discover, jcb, diners, maestro, instapayment] + - language: it + - language: pl + type: predefined + + - name: UsBankRecognizer + supported_languages: + - en + type: predefined + + - name: UsLicenseRecognizer + supported_languages: + - en + type: predefined + + - name: UsItinRecognizer + supported_languages: + - en + type: predefined + + - name: UsPassportRecognizer + supported_languages: + - en + type: predefined + + - name: UsSsnRecognizer + supported_languages: + - en + type: predefined + + - name: NhsRecognizer + supported_languages: + - en + type: predefined + + - name: UkNinoRecognizer + supported_languages: + - en + type: predefined + enabled: false + + - name: SgFinRecognizer + supported_languages: + - en + type: predefined + enabled: false + + - name: AuAbnRecognizer + supported_languages: + - en + type: predefined + enabled: false + + - name: AuAcnRecognizer + supported_languages: + - en + type: predefined + enabled: false + + - name: AuTfnRecognizer + supported_languages: + - en + type: predefined + enabled: false + + - name: AuMedicareRecognizer + supported_languages: + - en + type: predefined + enabled: false + + - name: InPanRecognizer + supported_languages: + - en + type: predefined + 
enabled: false + + - name: InAadhaarRecognizer + supported_languages: + - en + type: predefined + enabled: false + + - name: InVehicleRegistrationRecognizer + supported_languages: + - en + type: predefined + enabled: false + + - name: InPassportRecognizer + supported_languages: + - en + type: predefined + enabled: false + + - name: EsNifRecognizer + supported_languages: + - es + type: predefined + + - name: EsNieRecognizer + supported_languages: + - es + type: predefined + + - name: ItDriverLicenseRecognizer + supported_languages: + - it + type: predefined + + - name: ItFiscalCodeRecognizer + supported_languages: + - it + type: predefined + + - name: ItVatCodeRecognizer + supported_languages: + - it + type: predefined + + - name: ItIdentityCardRecognizer + supported_languages: + - it + type: predefined + + - name: ItPassportRecognizer + supported_languages: + - it + type: predefined + + - name: PlPeselRecognizer + supported_languages: + - pl + type: predefined + + - name: KrRrnRecognizer + supported_languages: + - ko + - kr + type: predefined + enabled: false + + - name: ThTninRecognizer + supported_languages: + - th + type: predefined + enabled: false + + - name: CryptoRecognizer + type: predefined + + - name: DateRecognizer + type: predefined + + - name: EmailRecognizer + type: predefined + + - name: IbanRecognizer + type: predefined + + - name: IpRecognizer + type: predefined + + - name: MedicalLicenseRecognizer + type: predefined + + - name: PhoneRecognizer + type: predefined + + - name: UrlRecognizer + type: predefined + + - name: InVoterRecognizer + type: predefined + enabled: false + + - name: InGstinRecognizer + supported_languages: + - en + type: predefined + enabled: false + + - name: OllamaLangExtractRecognizer + supported_languages: + - en + type: predefined + enabled: true + config_path: e2e-tests/resources/ollama_test_config.yaml diff --git a/e2e-tests/tests/test_analyzer.py b/e2e-tests/tests/test_api_analyzer.py similarity index 100% rename from 
e2e-tests/tests/test_analyzer.py rename to e2e-tests/tests/test_api_analyzer.py diff --git a/e2e-tests/tests/test_anonymizer.py b/e2e-tests/tests/test_api_anonymizer.py similarity index 100% rename from e2e-tests/tests/test_anonymizer.py rename to e2e-tests/tests/test_api_anonymizer.py diff --git a/e2e-tests/tests/test_e2e_integration_flows.py b/e2e-tests/tests/test_api_e2e_integration_flows.py similarity index 78% rename from e2e-tests/tests/test_e2e_integration_flows.py rename to e2e-tests/tests/test_api_e2e_integration_flows.py index 1a2ff7ae4a..995039c457 100644 --- a/e2e-tests/tests/test_e2e_integration_flows.py +++ b/e2e-tests/tests/test_api_e2e_integration_flows.py @@ -2,13 +2,9 @@ from pathlib import Path import pytest -from presidio_analyzer import AnalyzerEngine, RecognizerResult -from presidio_analyzer.nlp_engine import NlpEngineProvider from common.assertions import equal_json_strings from common.methods import analyze, anonymize, analyzer_supported_entities -from presidio_anonymizer import AnonymizerEngine -from presidio_anonymizer.entities import EngineResult, OperatorResult def analyze_and_assert(analyzer_request, expected_response): @@ -34,10 +30,10 @@ def test_given_text_with_pii_then_analyze_and_anonymize_successfully(): expected_response = """ [ - {"entity_type": "PERSON", "start": 0, "end": 10, "score": 0.85, + {"entity_type": "PERSON", "start": 0, "end": 10, "score": 0.85, "analysis_explanation": null }, - {"entity_type": "US_DRIVER_LICENSE", "start": 30, "end": 38, "score": 0.6499999999999999, + {"entity_type": "US_DRIVER_LICENSE", "start": 30, "end": 38, "score": 0.6499999999999999, "analysis_explanation": null } ] @@ -164,10 +160,10 @@ def test_given_an_unknown_entity_then_anonymize_uses_defaults(): expected_response = """ [ - {"entity_type": "PERSON", "start": 0, "end": 10, "score": 0.85, + {"entity_type": "PERSON", "start": 0, "end": 10, "score": 0.85, "analysis_explanation": null }, - {"entity_type": "US_DRIVER_LICENSE", "start": 30, 
"end": 38, "score": 0.6499999999999999, + {"entity_type": "US_DRIVER_LICENSE", "start": 30, "end": 38, "score": 0.6499999999999999, "analysis_explanation": null } ] @@ -189,7 +185,6 @@ def test_given_an_unknown_entity_then_anonymize_uses_defaults(): @pytest.mark.integration def test_demo_website_text_returns_correct_anonymized_version(): # Analyzer request info - dir_path = Path(__file__).resolve().parent.parent with open(Path(dir_path, "resources", "demo.txt"), encoding="utf-8") as f: text_into_rows = f.read().split("\n") @@ -206,20 +201,16 @@ def test_demo_website_text_returns_correct_anonymized_version(): } # Call analyzer - analyzer_status_code, analyzer_content = analyze(json.dumps(analyzer_request)) - analyzer_data = json.loads(analyzer_content) # Anonymizer request info - anonymizer_request = { "text": analyzer_request["text"], "analyzer_results": analyzer_data, } # Call anonymizer - anonymizer_status_code, anonymizer_response = anonymize( json.dumps(anonymizer_request) ) @@ -228,7 +219,6 @@ def test_demo_website_text_returns_correct_anonymized_version(): actual_anonymized_text = anonymizer_response_dict["text"] # Expected output: - with open( Path(dir_path, "resources", "demo_anonymized.txt"), encoding="utf-8" ) as f_exp: @@ -238,56 +228,4 @@ def test_demo_website_text_returns_correct_anonymized_version(): expected_anonymized_text = " ".join(text_into_rows) # Assert equal - assert expected_anonymized_text == actual_anonymized_text - - -@pytest.mark.package -def test_given_text_with_pii_using_package_then_analyze_and_anonymize_complete_successfully(): - text_to_test = "John Smith drivers license is AC432223" - - expected_response = [ - RecognizerResult("PERSON", 0, 10, 0.85), - RecognizerResult("US_DRIVER_LICENSE", 30, 38, 0.6499999999999999), - ] - # Create configuration containing engine name and models - configuration = { - "nlp_engine_name": "spacy", - "models": [{"lang_code": "en", "model_name": "en_core_web_lg"}], - } - - # Create NLP engine based on 
configuration - provider = NlpEngineProvider(nlp_configuration=configuration) - nlp_engine = provider.create_engine() - - # Pass the created NLP engine and supported_languages to the AnalyzerEngine - analyzer = AnalyzerEngine(nlp_engine=nlp_engine, supported_languages=["en"]) - analyzer_results = analyzer.analyze(text_to_test, "en") - for i in range(len(analyzer_results)): - assert analyzer_results[i] == expected_response[i] - - expected_response = EngineResult( - text="<PERSON> drivers license is <US_DRIVER_LICENSE>" - ) - expected_response.add_item( - OperatorResult( - operator="replace", - entity_type="US_DRIVER_LICENSE", - start=28, - end=47, - text="<US_DRIVER_LICENSE>", - ) - ) - expected_response.add_item( - OperatorResult( - operator="replace", - entity_type="PERSON", - start=0, - end=8, - text="<PERSON>", - ) - ) - - anonymizer = AnonymizerEngine() - anonymizer_results = anonymizer.anonymize(text_to_test, analyzer_results) - assert anonymizer_results == expected_response diff --git a/e2e-tests/tests/test_image_redactor.py b/e2e-tests/tests/test_api_image_redactor.py similarity index 100% rename from e2e-tests/tests/test_image_redactor.py rename to e2e-tests/tests/test_api_image_redactor.py diff --git a/e2e-tests/tests/test_package_e2e_integration_flows.py b/e2e-tests/tests/test_package_e2e_integration_flows.py new file mode 100644 index 0000000000..0bdcc12fdd --- /dev/null +++ b/e2e-tests/tests/test_package_e2e_integration_flows.py @@ -0,0 +1,194 @@ +import pytest +from presidio_analyzer import AnalyzerEngine, RecognizerResult +from presidio_analyzer.nlp_engine import NlpEngineProvider +from presidio_anonymizer import AnonymizerEngine +from presidio_anonymizer.entities import EngineResult, OperatorResult + +try: + from presidio_analyzer.predefined_recognizers.third_party.\ + ollama_langextract_recognizer import OllamaLangExtractRecognizer + OLLAMA_RECOGNIZER_AVAILABLE = True +except ImportError: + OLLAMA_RECOGNIZER_AVAILABLE = False + OllamaLangExtractRecognizer = None + + +@pytest.mark.package +def 
test_given_text_with_pii_using_package_then_analyze_and_anonymize_successfully(): + text_to_test = "John Smith drivers license is AC432223" + + expected_response = [ + RecognizerResult("PERSON", 0, 10, 0.85), + RecognizerResult("US_DRIVER_LICENSE", 30, 38, 0.6499999999999999), + ] + # Create configuration containing engine name and models + configuration = { + "nlp_engine_name": "spacy", + "models": [{"lang_code": "en", "model_name": "en_core_web_lg"}], + } + + # Create NLP engine based on configuration + provider = NlpEngineProvider(nlp_configuration=configuration) + nlp_engine = provider.create_engine() + + # Pass the created NLP engine and supported_languages to the AnalyzerEngine + analyzer = AnalyzerEngine(nlp_engine=nlp_engine, supported_languages=["en"]) + analyzer_results = analyzer.analyze(text_to_test, "en") + for i in range(len(analyzer_results)): + assert analyzer_results[i] == expected_response[i] + + expected_response = EngineResult( + text="<PERSON> drivers license is <US_DRIVER_LICENSE>" + ) + expected_response.add_item( + OperatorResult( + operator="replace", + entity_type="US_DRIVER_LICENSE", + start=28, + end=47, + text="<US_DRIVER_LICENSE>", + ) + ) + expected_response.add_item( + OperatorResult( + operator="replace", + entity_type="PERSON", + start=0, + end=8, + text="<PERSON>", + ) + ) + + anonymizer = AnonymizerEngine() + anonymizer_results = anonymizer.anonymize(text_to_test, analyzer_results) + assert anonymizer_results == expected_response + + +@pytest.mark.package +def test_given_text_with_pii_using_ollama_recognizer_then_detects_entities(tmp_path): + """Test Ollama LangExtract recognizer detects entities when explicitly added to analyzer.""" + assert OLLAMA_RECOGNIZER_AVAILABLE, "LangExtract must be installed for e2e tests" + + text_to_test = "Patient John Smith, SSN 123-45-6789, email john@example.com, phone 555-123-4567, lives at 123 Main St, works at Acme Corp" + + # Use pre-configured config file with small model (qwen2.5:1.5b) + import os + config_path = os.path.join( + 
os.path.dirname(__file__), "..", "resources", "ollama_test_config.yaml" + ) + + # Create Ollama recognizer with custom config + ollama_recognizer = OllamaLangExtractRecognizer(config_path=config_path) + + # Create analyzer with ONLY Ollama recognizer (no NLP engine, no default recognizers) + from presidio_analyzer.recognizer_registry import RecognizerRegistry + registry = RecognizerRegistry() + registry.add_recognizer(ollama_recognizer) + + analyzer = AnalyzerEngine( + registry=registry, + supported_languages=["en"] + ) + + # Analyze text + results = analyzer.analyze(text_to_test, language="en") + + # Verify at least some entities were detected + assert len(results) > 0, "Expected to detect at least one PII entity" + + # Check which recognizers participated in detection + recognizers_used = set() + langextract_detected_at_least_one = False + + for result in results: + if result.recognition_metadata: + recognizer_name = result.recognition_metadata.get( + RecognizerResult.RECOGNIZER_NAME_KEY, "" + ) + recognizers_used.add(recognizer_name) + + langextract_detected_at_least_one |= ( + recognizer_name == "Ollama LangExtract PII" + ) + + # Verify that Ollama LangExtract recognizer participated in detection + assert langextract_detected_at_least_one, \ + f"Expected 'Ollama LangExtract PII' recognizer to detect at least one entity. Recognizers used: {recognizers_used}" + + +@pytest.mark.package +def test_ollama_recognizer_loads_from_yaml_configuration_when_enabled(): + """ + E2E test to verify Ollama recognizer can be enabled via YAML configuration. + + The test ensures that when enabled=true in the YAML config: + 1. The recognizer loads successfully (handles supported_language and context kwargs) + 2. The config_path is resolved correctly (handles relative paths) + 3. 
The recognizer can detect PII entities + + Prerequisites: + - Ollama service running with qwen2.5:1.5b model + - LangExtract library installed + """ + if not OLLAMA_RECOGNIZER_AVAILABLE: + pytest.skip("LangExtract not installed") + + # Check if Ollama is available + import os + try: + import requests + ollama_url = os.environ.get("OLLAMA_HOST", "http://localhost:11434") + response = requests.get(f"{ollama_url}/api/tags", timeout=2) + if response.status_code != 200: + pytest.skip("Ollama service not available") + except Exception: + pytest.skip("Ollama service not available") + + # Load recognizer registry from YAML config with Ollama enabled + from presidio_analyzer.recognizer_registry import RecognizerRegistryProvider + + config_path = os.path.join( + os.path.dirname(__file__), "..", "resources", "test_ollama_enabled_recognizers.yaml" + ) + + + provider = RecognizerRegistryProvider(conf_file=config_path) + registry = provider.create_recognizer_registry() + + # Verify Ollama recognizer was loaded + ollama_recognizers = [r for r in registry.recognizers if "Ollama" in r.name] + assert len(ollama_recognizers) == 1, \ + f"Expected exactly 1 Ollama recognizer, found {len(ollama_recognizers)}" + + ollama_rec = ollama_recognizers[0] + assert ollama_rec.name == "Ollama LangExtract PII" + assert ollama_rec.supported_language == "en" + assert len(ollama_rec.supported_entities) > 0 + + # Test functionality: analyze text with the loaded recognizer + analyzer = AnalyzerEngine(registry=registry, supported_languages=["en"]) + + text_to_test = "Patient John Smith, SSN 123-45-6789, email john@example.com, phone 555-123-4567, lives at 123 Main St, works at Acme Corp" + results = analyzer.analyze(text_to_test, language="en") + + # Should detect entities + assert len(results) > 0, "Expected to detect at least one PII entity" + + # Check if Ollama recognizer detected anything + ollama_detected = any( + r.recognition_metadata and + "Ollama" in 
r.recognition_metadata.get(RecognizerResult.RECOGNIZER_NAME_KEY, "") + for r in results + ) + + # At minimum, other recognizers should detect common entities + entity_types = {r.entity_type for r in results} + expected_entities = {"EMAIL_ADDRESS", "PERSON", "PHONE_NUMBER", "US_SSN"} + detected_expected = entity_types & expected_entities + + assert len(detected_expected) >= 2, \ + f"Expected at least 2 entities from {expected_entities}, detected: {entity_types}" + + print(f"\n✓ Ollama recognizer loaded successfully from YAML config") + print(f"  Detected entities: {entity_types}") + print(f"  Ollama participated: {ollama_detected}") diff --git a/presidio-analyzer/README.md b/presidio-analyzer/README.md index f60d899571..461850e3c1 100644 --- a/presidio-analyzer/README.md +++ b/presidio-analyzer/README.md @@ -12,6 +12,18 @@ but can easily be extended with other types of custom recognizers. Predefined and custom recognizers leverage regex, Named Entity Recognition and other types of logic to detect PII in unstructured text. +### Language Model-based PII/PHI Detection + +Presidio analyzer supports language model-based PII/PHI detection (LLMs, SLMs) for flexible entity recognition. The current implementation uses [LangExtract](https://github.com/google/langextract) with Ollama for local model deployment. + +```bash +pip install presidio-analyzer[langextract] +``` + +**Note:** The Ollama recognizer does not validate server connectivity or model availability during initialization. Connection errors or missing models will be reported when `analyze()` is first called. Ensure Ollama is running and the required model is installed before analysis. + +See the [Language Model-based PII/PHI Detection guide](https://microsoft.github.io/presidio/samples/python/langextract/) for setup and usage. 
diff --git a/presidio-analyzer/presidio_analyzer/__init__.py b/presidio-analyzer/presidio_analyzer/__init__.py index a6ba0270a9..61de87790e 100644 --- a/presidio-analyzer/presidio_analyzer/__init__.py +++ b/presidio-analyzer/presidio_analyzer/__init__.py @@ -11,6 +11,7 @@ from presidio_analyzer.pattern import Pattern from presidio_analyzer.pattern_recognizer import PatternRecognizer from presidio_analyzer.remote_recognizer import RemoteRecognizer +from presidio_analyzer.lm_recognizer import LMRecognizer from presidio_analyzer.recognizer_registry import RecognizerRegistry from presidio_analyzer.analyzer_engine import AnalyzerEngine from presidio_analyzer.batch_analyzer_engine import BatchAnalyzerEngine @@ -22,13 +23,11 @@ # Define default loggers behavior # 1. presidio_analyzer logger - logging.getLogger("presidio-analyzer").addHandler(logging.NullHandler()) # 2. decision_process logger. -# Setting the decision process trace here as we would want it +# Setting the decision_process trace here as we would want it # to be activated using a parameter to AnalyzeEngine and not by default. 
- decision_process_logger = logging.getLogger("decision_process") ch = logging.StreamHandler() formatter = logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]%(message)s") @@ -44,6 +43,7 @@ "LocalRecognizer", "PatternRecognizer", "RemoteRecognizer", + "LMRecognizer", "RecognizerRegistry", "AnalyzerEngine", "AnalyzerRequest", diff --git a/presidio-analyzer/presidio_analyzer/conf/default_recognizers.yaml b/presidio-analyzer/presidio_analyzer/conf/default_recognizers.yaml index 4565f70770..294a4130fa 100644 --- a/presidio-analyzer/presidio_analyzer/conf/default_recognizers.yaml +++ b/presidio-analyzer/presidio_analyzer/conf/default_recognizers.yaml @@ -199,3 +199,10 @@ recognizers: - en type: predefined enabled: false + + - name: OllamaLangExtractRecognizer + supported_languages: + - en + type: predefined + enabled: false + config_path: presidio-analyzer/presidio_analyzer/conf/langextract_config_ollama.yaml diff --git a/presidio-analyzer/presidio_analyzer/conf/langextract_config_ollama.yaml b/presidio-analyzer/presidio_analyzer/conf/langextract_config_ollama.yaml new file mode 100644 index 0000000000..4c13be8634 --- /dev/null +++ b/presidio-analyzer/presidio_analyzer/conf/langextract_config_ollama.yaml @@ -0,0 +1,46 @@ +# Ollama Configuration +# https://github.com/google/langextract#using-local-llms-with-ollama + +lm_recognizer: + supported_entities: + - PERSON + - EMAIL_ADDRESS + - PHONE_NUMBER + - US_SSN + - LOCATION + - ORGANIZATION + - DATE_TIME + - CREDIT_CARD + - IP_ADDRESS + - URL + + labels_to_ignore: + - payment_status + - metadata + - annotation + + enable_generic_consolidation: true + min_score: 0.5 + +langextract: + prompt_file: "presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_prompt.j2" + examples_file: "presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_examples.yaml" + + model: + model_id: "qwen2.5:1.5b" + model_url: "http://localhost:11434" + temperature: null + + entity_mappings: + person: 
PERSON + name: PERSON + email: EMAIL_ADDRESS + phone: PHONE_NUMBER + ssn: US_SSN + location: LOCATION + address: LOCATION + organization: ORGANIZATION + date: DATE_TIME + credit_card: CREDIT_CARD + ip_address: IP_ADDRESS + url: URL diff --git a/presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_examples.yaml b/presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_examples.yaml new file mode 100644 index 0000000000..4565261f31 --- /dev/null +++ b/presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_examples.yaml @@ -0,0 +1,160 @@ +# LangExtract PII/PHI Extraction Examples for Presidio +# Using uppercase Presidio entity names directly + +examples: + # 1) Person + Email + Phone + - text: "My name is John Doe and my email is john.doe@example.com. Call me at (555) 123-4567." + extractions: + - extraction_class: "PERSON" + extraction_text: "John Doe" + attributes: + type: "full_name" + - extraction_class: "EMAIL_ADDRESS" + extraction_text: "john.doe@example.com" + attributes: + type: "email_address" + - extraction_class: "PHONE_NUMBER" + extraction_text: "(555) 123-4567" + attributes: + type: "phone_number" + + # 2) Street address + - text: "Ship to 742 Evergreen Terrace, Springfield, IL 62704." + extractions: + - extraction_class: "LOCATION" + extraction_text: "742 Evergreen Terrace, Springfield, IL 62704" + attributes: + type: "street_address" + + # 3) SSN + Date of Birth + - text: "SSN: 123-45-6789; Date of Birth: 01/15/1980." + extractions: + - extraction_class: "US_SSN" + extraction_text: "123-45-6789" + attributes: + type: "us_ssn" + - extraction_class: "DATE_TIME" + extraction_text: "01/15/1980" + attributes: + type: "date_of_birth" + format: "MM/DD/YYYY" + + # 4) Credit card variations + - text: "Use card 4532123412341234 or 4532-1234-1234-1234." 
+ extractions: + - extraction_class: "CREDIT_CARD" + extraction_text: "4532123412341234" + attributes: + type: "credit_card_number" + - extraction_class: "CREDIT_CARD" + extraction_text: "4532-1234-1234-1234" + attributes: + type: "credit_card_number" + + # 5) Email with plus-tag + Organization + - text: "Contact jane.smith+billing@sub.acme-corp.com at Acme Corp." + extractions: + - extraction_class: "EMAIL_ADDRESS" + extraction_text: "jane.smith+billing@sub.acme-corp.com" + attributes: + type: "email_address" + - extraction_class: "ORGANIZATION" + extraction_text: "Acme Corp" + attributes: + type: "company" + + # 6) Phone international formats + - text: "EU office: +44 20 7946 0958; US office: +1-415-555-2671." + extractions: + - extraction_class: "PHONE_NUMBER" + extraction_text: "+44 20 7946 0958" + attributes: + type: "phone_number" + region: "GB" + - extraction_class: "PHONE_NUMBER" + extraction_text: "+1-415-555-2671" + attributes: + type: "phone_number" + region: "US" + + # 7) URL + IP addresses + - text: "Visit https://example.com/login or example.org. Server IPs: 192.168.1.100 and fe80::1ff:fe23:4567:890a." + extractions: + - extraction_class: "URL" + extraction_text: "https://example.com/login" + attributes: + type: "website_url" + - extraction_class: "URL" + extraction_text: "example.org" + attributes: + type: "domain" + - extraction_class: "IP_ADDRESS" + extraction_text: "192.168.1.100" + attributes: + type: "ipv4" + - extraction_class: "IP_ADDRESS" + extraction_text: "fe80::1ff:fe23:4567:890a" + attributes: + type: "ipv6" + + # 8) Email without TLD + Phone + - text: "Reach me at alex.roe@mailserver or at the office line 555.987.6543." 
+ extractions: + - extraction_class: "EMAIL_ADDRESS" + extraction_text: "alex.roe@mailserver" + attributes: + type: "email_address" + - extraction_class: "PHONE_NUMBER" + extraction_text: "555.987.6543" + attributes: + type: "phone_number" + + # 9) Organization + Person (healthcare) + - text: "Seen at City General Hospital by Dr. Alice Carter." + extractions: + - extraction_class: "ORGANIZATION" + extraction_text: "City General Hospital" + attributes: + type: "hospital" + - extraction_class: "PERSON" + extraction_text: "Dr. Alice Carter" + attributes: + type: "healthcare_provider" + + # 10) Location (city + address) + - text: "Patient resides at 55 King Street in Boston." + extractions: + - extraction_class: "LOCATION" + extraction_text: "55 King Street" + attributes: + type: "street_address" + - extraction_class: "LOCATION" + extraction_text: "Boston" + attributes: + type: "city" + + # 11) Person + Location (city, state) + - text: "Sarah Johnson lives in Seattle, Washington." + extractions: + - extraction_class: "PERSON" + extraction_text: "Sarah Johnson" + attributes: + type: "full_name" + - extraction_class: "LOCATION" + extraction_text: "Seattle, Washington" + attributes: + type: "city_state" + + # 12) Multiple dates + - text: "Admitted on 2024-03-15 and discharged on 2024-03-20." 
+ extractions: + - extraction_class: "DATE_TIME" + extraction_text: "2024-03-15" + attributes: + type: "admission_date" + format: "YYYY-MM-DD" + - extraction_class: "DATE_TIME" + extraction_text: "2024-03-20" + attributes: + type: "discharge_date" + format: "YYYY-MM-DD" diff --git a/presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_prompt.j2 b/presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_prompt.j2 new file mode 100644 index 0000000000..f5edbffd55 --- /dev/null +++ b/presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_prompt.j2 @@ -0,0 +1,27 @@ +Extract personally identifiable information (PII) and protected health information (PHI) from the input text. + +CRITICAL EXTRACTION RULES: +1. **Source Truth:** ONLY extract from the current input text. Never hallucinate data or infer details not explicitly present. +2. **Exactness:** Return the exact text spans as they appear in the source (preserve original spelling/formatting). +3. **Consolidation:** Do not create overlapping spans. If entities are nested, select the longest, most complete span. +4. **Context:** When populating attributes, infer the entity's role (e.g., 'patient' vs 'provider') based on the surrounding sentence. +5. **Fallback:** Only use GENERIC_PII_ENTITY if the text is clearly sensitive but definitely does not fit any of the specific types listed. + +ENTITY TYPES TO EXTRACT: +{% for entity in supported_entities %} +- {{ entity }} +{%- endfor %} + +{% if enable_generic_consolidation %} +GENERIC_PII_ENTITY (UNKNOWN ENTITIES): +- Label as GENERIC_PII_ENTITY only if the data is sensitive PII/PHI but does not match any specific type above. +- This acts as a catch-all for safety; do not overuse it for known types. 
+{% endif %} + +{% if labels_to_ignore %} +DO NOT EXTRACT: +The following types should be ignored and not extracted: +{% for label in labels_to_ignore %} +- {{ label }} +{%- endfor %} +{% endif %} diff --git a/presidio-analyzer/presidio_analyzer/conf/ollama_config_e2e.yaml b/presidio-analyzer/presidio_analyzer/conf/ollama_config_e2e.yaml new file mode 100644 index 0000000000..1567d7ed50 --- /dev/null +++ b/presidio-analyzer/presidio_analyzer/conf/ollama_config_e2e.yaml @@ -0,0 +1,50 @@ +# Ollama LangExtract E2E Test Configuration +# https://github.com/google/langextract + +lm_recognizer: + supported_entities: + - PERSON + - LOCATION + - ORGANIZATION + - PHONE_NUMBER + - EMAIL_ADDRESS + - DATE_TIME + - US_SSN + - CREDIT_CARD + - MEDICAL_LICENSE + - IP_ADDRESS + - URL + - IBAN_CODE + + labels_to_ignore: + - payment_status + - metadata + + enable_generic_consolidation: true # Consolidates unknown entity types to GENERIC_PII_ENTITY + min_score: 0.5 + +langextract: + prompt_file: "langextract_prompts/default_pii_phi_prompt.j2" + examples_file: "langextract_prompts/default_pii_phi_examples.yaml" + + model: + model_id: "qwen2.5:1.5b" + model_url: "http://ollama:11434" + temperature: null + + entity_mappings: + person: PERSON + full_name: PERSON + location: LOCATION + address: LOCATION + organization: ORGANIZATION + phone: PHONE_NUMBER + email: EMAIL_ADDRESS + date: DATE_TIME + ssn: US_SSN + credit_card: CREDIT_CARD + medical_record: MEDICAL_LICENSE + ip_address: IP_ADDRESS + ip-address: IP_ADDRESS + url: URL + iban: IBAN_CODE \ No newline at end of file diff --git a/presidio-analyzer/presidio_analyzer/llm_utils/__init__.py b/presidio-analyzer/presidio_analyzer/llm_utils/__init__.py new file mode 100644 index 0000000000..3ab3d2b4c6 --- /dev/null +++ b/presidio-analyzer/presidio_analyzer/llm_utils/__init__.py @@ -0,0 +1,59 @@ +"""Utilities for LLM-based recognizers.""" + +from .config_loader import ( + get_conf_path, + get_model_config, + load_yaml_file, + 
resolve_config_path, + validate_config_fields, +) +from .entity_mapper import ( + GENERIC_PII_ENTITY, + consolidate_generic_entities, + ensure_generic_entity_support, + filter_results_by_entities, + filter_results_by_labels, + filter_results_by_score, + skip_unmapped_entities, + validate_result_positions, +) +from .examples_loader import convert_to_langextract_format, load_yaml_examples +from .langextract_helper import ( + calculate_extraction_confidence, + check_langextract_available, + convert_langextract_to_presidio_results, + create_reverse_entity_mapping, + extract_lm_config, + get_supported_entities, + lx, +) +from .prompt_loader import load_file_from_conf, load_prompt_file, render_jinja_template + +__all__ = [ + "get_conf_path", + "get_model_config", + "load_yaml_file", + "resolve_config_path", + "validate_config_fields", + "GENERIC_PII_ENTITY", + "consolidate_generic_entities", + "ensure_generic_entity_support", + "filter_results_by_entities", + "filter_results_by_labels", + "filter_results_by_score", + "skip_unmapped_entities", + "validate_result_positions", + "convert_to_langextract_format", + "load_yaml_examples", + "calculate_extraction_confidence", + "check_langextract_available", + "convert_langextract_to_presidio_results", + "create_reverse_entity_mapping", + "extract_lm_config", + "get_supported_entities", + "lx", + "load_file_from_conf", + "load_prompt_file", + "render_jinja_template", +] + diff --git a/presidio-analyzer/presidio_analyzer/llm_utils/config_loader.py b/presidio-analyzer/presidio_analyzer/llm_utils/config_loader.py new file mode 100644 index 0000000000..d8ad2306ed --- /dev/null +++ b/presidio-analyzer/presidio_analyzer/llm_utils/config_loader.py @@ -0,0 +1,122 @@ +"""Configuration loading utilities for LLM recognizers.""" +import logging +from pathlib import Path +from typing import Dict, List, Union + +import yaml + +logger = logging.getLogger("presidio-analyzer") + +__all__ = [ + "get_conf_path", + "load_yaml_file", + 
"resolve_config_path", + "get_model_config", + "validate_config_fields", +] + + +def get_conf_path(filename: str, conf_subdir: str = "conf") -> Path: + """Get absolute path to file in configuration directory. + + :param filename: Name of the file to locate. + :param conf_subdir: Subdirectory name within package (default: "conf"). + :return: Absolute path to the configuration file. + """ + return Path(__file__).parent.parent / conf_subdir / filename + + +def resolve_config_path(config_path: Union[str, Path]) -> Path: + """Resolve configuration file path to absolute path. + + Handles paths in multiple formats (checked in order): + 1. Absolute paths: returned as-is + 2. Relative paths that exist from CWD: returned as-is + 3. Relative paths resolved from repository root + + :param config_path: Configuration file path (string or Path object). + :return: Resolved absolute path. + """ + config_path_obj = Path(config_path) + + if config_path_obj.is_absolute(): + return config_path_obj + + if config_path_obj.exists(): + return config_path_obj + + presidio_analyzer_root = Path(__file__).parent.parent + repo_root = presidio_analyzer_root.parent.parent + repo_resolved = repo_root / config_path + + return repo_resolved + + +def load_yaml_file(filepath: Union[str, Path]) -> Dict: + """Load and parse YAML configuration file. + + Automatically resolves relative paths from presidio_analyzer package root. + + :param filepath: Path to YAML file (string or Path object). + :return: Parsed YAML content as dictionary. + :raises FileNotFoundError: If file doesn't exist. + :raises ValueError: If YAML parsing fails. 
+ """ + resolved_path = resolve_config_path(filepath) + + if not resolved_path.exists(): + raise FileNotFoundError(f"File not found: {resolved_path}") + + try: + with open(resolved_path) as f: + return yaml.safe_load(f) + except yaml.YAMLError as e: + raise ValueError(f"Failed to parse YAML: {e}") + + +def get_model_config(config: Dict, provider_key: str) -> Dict: + """Extract and validate model configuration from provider config. + + :param config: Full configuration dictionary. + :param provider_key: Provider key (e.g., "openai", "ollama"). + :return: Model configuration dictionary. + :raises ValueError: If required model fields are missing. + """ + validate_config_fields( + config, + [ + (provider_key,), + (provider_key, "model"), + (provider_key, "model", "model_id"), + ] + ) + + return config[provider_key]["model"] + + +def validate_config_fields( + config: Dict, + required_fields: List[Union[str, tuple]], + config_name: str = "Configuration" +) -> None: + """Validate that required fields exist in configuration. + + :param config: Configuration dictionary to validate. + :param required_fields: List of required field names (str) or nested paths (tuple). + :param config_name: Name of config for error messages (default: "Configuration"). + :raises ValueError: If any required field is missing or empty. 
+ """ + for field in required_fields: + if isinstance(field, str): + if not config.get(field): + raise ValueError(f"{config_name} must contain '{field}'") + elif isinstance(field, tuple): + current = config + for i, key in enumerate(field): + if not isinstance(current, dict) or key not in current: + path = ".".join(field[:i+1]) + raise ValueError(f"{config_name} must contain '{path}'") + current = current[key] + if i == len(field) - 1 and not current: + path = ".".join(field) + raise ValueError(f"{config_name} must contain '{path}'") diff --git a/presidio-analyzer/presidio_analyzer/llm_utils/entity_mapper.py b/presidio-analyzer/presidio_analyzer/llm_utils/entity_mapper.py new file mode 100644 index 0000000000..b599073e81 --- /dev/null +++ b/presidio-analyzer/presidio_analyzer/llm_utils/entity_mapper.py @@ -0,0 +1,212 @@ +"""Entity mapping and filtering utilities for LLM recognizers.""" +import logging +from typing import List, Set + +from presidio_analyzer import RecognizerResult + +logger = logging.getLogger("presidio-analyzer") + +GENERIC_PII_ENTITY = "GENERIC_PII_ENTITY" + +__all__ = [ + "GENERIC_PII_ENTITY", + "filter_results_by_labels", + "filter_results_by_score", + "filter_results_by_entities", + "validate_result_positions", + "consolidate_generic_entities", + "skip_unmapped_entities", + "ensure_generic_entity_support", +] + + +def filter_results_by_labels( + results: List[RecognizerResult], + labels_to_ignore: List[str] +) -> List[RecognizerResult]: + """Filter out results with ignored entity labels. + + :param results: List of recognizer results to filter. + :param labels_to_ignore: Entity type labels to exclude (case-insensitive). + :return: Filtered list of results without ignored labels. 
+ """ + labels_to_ignore_lower = [label.lower() for label in labels_to_ignore] + filtered = [] + + for result in results: + if not result.entity_type: + logger.warning("LLM returned result without entity_type, skipping") + continue + + if result.entity_type.lower() in labels_to_ignore_lower: + logger.debug( + "Entity %s at [%d:%d] is in labels_to_ignore, skipping", + result.entity_type, result.start, result.end + ) + continue + + filtered.append(result) + + return filtered + + +def filter_results_by_score( + results: List[RecognizerResult], + min_score: float +) -> List[RecognizerResult]: + """Filter out results below minimum confidence score. + + :param results: List of recognizer results to filter. + :param min_score: Minimum confidence score threshold (0.0-1.0). + :return: Filtered list of results meeting minimum score. + """ + filtered = [] + + for result in results: + if result.score < min_score: + logger.debug( + "Entity %s at [%d:%d] below min_score (%.2f < %.2f), skipping", + result.entity_type, result.start, result.end, + result.score, min_score + ) + continue + + filtered.append(result) + + return filtered + + +def filter_results_by_entities( + results: List[RecognizerResult], + requested_entities: List[str] +) -> List[RecognizerResult]: + """Filter results to only include requested entity types. + + :param results: List of recognizer results to filter. + :param requested_entities: Entity types to include (empty list = include all). + :return: Filtered list containing only requested entity types. 
+ """ + if not requested_entities: + return results + + filtered = [] + + for result in results: + if result.entity_type not in requested_entities: + logger.debug( + "Entity %s at [%d:%d] not in requested entities %s, skipping", + result.entity_type, result.start, result.end, requested_entities + ) + continue + + filtered.append(result) + + return filtered + + +def validate_result_positions( + results: List[RecognizerResult] +) -> List[RecognizerResult]: + """Filter out results with invalid or missing start/end positions. + + :param results: List of recognizer results to validate. + :return: Filtered list with only valid position ranges. + """ + filtered = [] + + for result in results: + if result.start is None or result.end is None: + logger.warning( + "LLM returned result without start/end positions, skipping: %s", result + ) + continue + + filtered.append(result) + + return filtered + + +def consolidate_generic_entities( + results: List[RecognizerResult], + supported_entities: List[str], + generic_entities_logged: Set[str] +) -> List[RecognizerResult]: + """Consolidate unmapped entity types to GENERIC_PII_ENTITY. + + :param results: List of recognizer results to process. + :param supported_entities: List of supported entity type names. + :param generic_entities_logged: Set tracking logged generic entities + (modified in-place). + :return: Results with unmapped entities consolidated to + GENERIC_PII_ENTITY. + """ + processed = [] + + for result in results: + if result.entity_type not in supported_entities: + original_entity_type = result.entity_type + result.entity_type = GENERIC_PII_ENTITY + + if original_entity_type not in generic_entities_logged: + logger.warning( + "Detected unmapped entity '%s', " + "consolidated to GENERIC_PII_ENTITY. 
" + "To map or exclude, update " + "'entity_mappings' or 'labels_to_ignore'.", + original_entity_type, + ) + generic_entities_logged.add(original_entity_type) + + if result.recognition_metadata is None: + result.recognition_metadata = {} + result.recognition_metadata["original_entity_type"] = original_entity_type + + processed.append(result) + + return processed + + +def skip_unmapped_entities( + results: List[RecognizerResult], + supported_entities: List[str] +) -> List[RecognizerResult]: + """Skip unmapped entities instead of consolidating them. + + :param results: List of recognizer results to filter. + :param supported_entities: List of supported entity type names. + :return: Filtered results excluding unmapped entity types. + """ + filtered = [] + + for result in results: + if result.entity_type not in supported_entities: + logger.warning( + "Detected unmapped entity '%s', skipped " + "(enable_generic_consolidation=False). " + "To map or exclude, update " + "'entity_mappings' or 'labels_to_ignore'.", + result.entity_type, + ) + continue + + filtered.append(result) + + return filtered + + +def ensure_generic_entity_support( + supported_entities: List[str], + enable_generic_consolidation: bool +) -> List[str]: + """Ensure GENERIC_PII_ENTITY is in supported entities list if consolidation enabled. + + :param supported_entities: Current list of supported entity types. + :param enable_generic_consolidation: Whether generic consolidation is enabled. + :return: Updated list including GENERIC_PII_ENTITY if needed. 
+ """ + entities = supported_entities.copy() + + if enable_generic_consolidation and GENERIC_PII_ENTITY not in entities: + entities.append(GENERIC_PII_ENTITY) + + return entities diff --git a/presidio-analyzer/presidio_analyzer/llm_utils/examples_loader.py b/presidio-analyzer/presidio_analyzer/llm_utils/examples_loader.py new file mode 100644 index 0000000000..72aa318c31 --- /dev/null +++ b/presidio-analyzer/presidio_analyzer/llm_utils/examples_loader.py @@ -0,0 +1,56 @@ +"""Examples loading utilities for LLM recognizers.""" +import logging +from typing import Dict, List + +from .config_loader import load_yaml_file, resolve_config_path, validate_config_fields +from .langextract_helper import check_langextract_available, lx + +logger = logging.getLogger("presidio-analyzer") + +__all__ = [ + "load_yaml_examples", + "convert_to_langextract_format", +] + + +def load_yaml_examples( + examples_file: str, conf_subdir: str = "conf" +) -> List[Dict]: + """Load and validate examples from YAML configuration file. + + :param examples_file: Path to YAML file with examples (repo-root-relative). + :param conf_subdir: Configuration subdirectory (deprecated, kept for compatibility). + :return: List of example dictionaries. + :raises ValueError: If 'examples' field is missing. + """ + filepath = resolve_config_path(examples_file) + data = load_yaml_file(filepath) + validate_config_fields(data, ["examples"], "Examples file") + return data["examples"] + + +def convert_to_langextract_format(examples_data: List[Dict]) -> List: + """Convert example dictionaries to LangExtract Example objects. + + :param examples_data: List of example dictionaries with text and extractions. + :return: List of LangExtract Example objects. + :raises ImportError: If langextract is not installed. 
+ """ + check_langextract_available() + + langextract_examples = [] + for example in examples_data: + extractions = [ + lx.data.Extraction( + extraction_class=ext["extraction_class"], + extraction_text=ext["extraction_text"], + attributes=ext.get("attributes", {}), + ) + for ext in example.get("extractions", []) + ] + + langextract_examples.append( + lx.data.ExampleData(text=example["text"], extractions=extractions) + ) + + return langextract_examples diff --git a/presidio-analyzer/presidio_analyzer/llm_utils/langextract_helper.py b/presidio-analyzer/presidio_analyzer/llm_utils/langextract_helper.py new file mode 100644 index 0000000000..13abcb6dd9 --- /dev/null +++ b/presidio-analyzer/presidio_analyzer/llm_utils/langextract_helper.py @@ -0,0 +1,199 @@ +"""LangExtract helper utilities.""" +import logging +from typing import Dict, List, Optional + +from presidio_analyzer import AnalysisExplanation, RecognizerResult + +logger = logging.getLogger("presidio-analyzer") + +try: + import langextract as lx +except ImportError: + lx = None + +__all__ = [ + "lx", + "check_langextract_available", + "extract_lm_config", + "get_supported_entities", + "create_reverse_entity_mapping", + "calculate_extraction_confidence", + "convert_langextract_to_presidio_results", +] + + +def check_langextract_available(): + """Check if langextract is available and raise error if not.""" + if not lx: + raise ImportError( + "LangExtract is not installed. " + "Install it with: poetry install --extras langextract" + ) + + +# Default alignment score mappings for LangExtract extractions +DEFAULT_ALIGNMENT_SCORES = { + "MATCH_EXACT": 0.95, + "MATCH_FUZZY": 0.80, + "MATCH_LESSER": 0.70, + "NOT_ALIGNED": 0.60, +} + + +def extract_lm_config(config: Dict) -> Dict: + """Extract LM recognizer configuration section with default values. + + :param config: Full configuration dictionary. + :return: LM recognizer config with keys: supported_entities, min_score, + labels_to_ignore, enable_generic_consolidation. 
+ """ + lm_config_section = config.get("lm_recognizer", {}) + + return { + "supported_entities": lm_config_section.get("supported_entities"), + "min_score": lm_config_section.get("min_score", 0.5), + "labels_to_ignore": lm_config_section.get("labels_to_ignore", []), + "enable_generic_consolidation": lm_config_section.get( + "enable_generic_consolidation", True + ), + } + + +def get_supported_entities( + lm_config: Dict, + langextract_config: Dict +) -> Optional[List[str]]: + """Get supported entities list, checking LM config first then LangExtract config. + + :param lm_config: LM recognizer configuration dictionary. + :param langextract_config: LangExtract configuration dictionary. + :return: List of supported entity types, or None if not specified. + """ + return ( + lm_config.get("supported_entities") + or langextract_config.get("supported_entities") + ) + + +def create_reverse_entity_mapping(entity_mappings: Dict) -> Dict: + """Create reverse mapping from values to keys. + + :param entity_mappings: Original entity mapping dictionary. + :return: Reversed dictionary mapping values to keys. + """ + return {v: k for k, v in entity_mappings.items()} + + +def calculate_extraction_confidence( + extraction, + alignment_scores: Optional[Dict[str, float]] = None +) -> float: + """Calculate confidence score based on extraction alignment status. + + :param extraction: LangExtract extraction object with optional alignment_status. + :param alignment_scores: Custom score mapping for alignment statuses (optional). + :return: Confidence score between 0.0 and 1.0. 
+ """ + default_score = 0.85 + + if alignment_scores is None: + alignment_scores = DEFAULT_ALIGNMENT_SCORES + + if not hasattr(extraction, "alignment_status") or not ( + extraction.alignment_status + ): + return default_score + + status = str(extraction.alignment_status).upper() + for status_key, score in alignment_scores.items(): + if status_key in status: + return score + + return default_score + + +def convert_langextract_to_presidio_results( + langextract_result, + entity_mappings: Dict, + supported_entities: List[str], + enable_generic_consolidation: bool, + recognizer_name: str, + alignment_scores: Optional[Dict[str, float]] = None +) -> List[RecognizerResult]: + """Convert LangExtract extraction results to Presidio RecognizerResult objects. + + :param langextract_result: LangExtract result object with extractions. + :param entity_mappings: Mapping of extraction classes to Presidio entity types. + :param supported_entities: List of supported Presidio entity types. + :param enable_generic_consolidation: Whether to consolidate unknown entities. + :param recognizer_name: Name of recognizer for result metadata. + :param alignment_scores: Custom alignment score mappings (optional). + :return: List of Presidio RecognizerResult objects. 
+ """ + results = [] + if not langextract_result or not langextract_result.extractions: + return results + + supported_entities_set = set(supported_entities) + + for extraction in langextract_result.extractions: + extraction_class = extraction.extraction_class + + if extraction_class in supported_entities_set: + entity_type = extraction_class + else: + extraction_class_lower = extraction_class.lower() + entity_type = entity_mappings.get(extraction_class_lower) + + if not entity_type: + if enable_generic_consolidation: + entity_type = extraction_class.upper() + logger.debug( + "Unknown extraction class '%s' will be consolidated to " + "GENERIC_PII_ENTITY", + extraction_class, + ) + else: + logger.warning( + "Unknown extraction class '%s' not found in entity " + "mappings, skipping", + extraction_class, + ) + continue + + if not extraction.char_interval: + logger.warning("Extraction missing char_interval, skipping") + continue + + confidence = calculate_extraction_confidence(extraction, alignment_scores) + + metadata = {} + if hasattr(extraction, 'attributes') and extraction.attributes: + metadata['attributes'] = extraction.attributes + if hasattr(extraction, 'alignment_status') and extraction.alignment_status: + metadata['alignment'] = str(extraction.alignment_status) + + explanation = AnalysisExplanation( + recognizer=recognizer_name, + original_score=confidence, + textual_explanation=( + f"LangExtract extraction with " + f"{extraction.alignment_status} alignment" + if hasattr(extraction, "alignment_status") + and extraction.alignment_status + else "LangExtract extraction" + ), + ) + + result = RecognizerResult( + entity_type=entity_type, + start=extraction.char_interval.start_pos, + end=extraction.char_interval.end_pos, + score=confidence, + analysis_explanation=explanation, + recognition_metadata=metadata if metadata else None + ) + + results.append(result) + + return results diff --git a/presidio-analyzer/presidio_analyzer/llm_utils/prompt_loader.py 
b/presidio-analyzer/presidio_analyzer/llm_utils/prompt_loader.py new file mode 100644 index 0000000000..83fb3e0741 --- /dev/null +++ b/presidio-analyzer/presidio_analyzer/llm_utils/prompt_loader.py @@ -0,0 +1,60 @@ +"""Prompt loading and rendering utilities for LLM recognizers.""" +import logging + +from .config_loader import resolve_config_path + +logger = logging.getLogger("presidio-analyzer") + +__all__ = [ + "load_file_from_conf", + "load_prompt_file", + "render_jinja_template", +] + + +def load_file_from_conf(filename: str, conf_subdir: str = "conf") -> str: + """Load text file from configuration directory. + + :param filename: Path to file to load (can be repo-root-relative). + :param conf_subdir: Configuration subdirectory (deprecated, kept for compatibility). + :return: File contents as string. + :raises FileNotFoundError: If file doesn't exist. + """ + file_path = resolve_config_path(filename) + + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "r") as f: + return f.read() + + +def load_prompt_file(prompt_file: str, conf_subdir: str = "conf") -> str: + """Load prompt template file from configuration directory. + + :param prompt_file: Path to prompt template file (can be repo-root-relative). + :param conf_subdir: Configuration subdirectory (deprecated, kept for compatibility). + :return: Prompt template contents as string. + :raises FileNotFoundError: If file doesn't exist. + """ + return load_file_from_conf(prompt_file, conf_subdir) + + +def render_jinja_template(template_str: str, **kwargs) -> str: + """Render Jinja2 template with provided variables. + + :param template_str: Jinja2 template string. + :param kwargs: Variables to pass to template rendering. + :return: Rendered template as string. + :raises ImportError: If Jinja2 is not installed. + """ + try: + from jinja2 import Template + except ImportError: + raise ImportError( + "Jinja2 is not installed. 
" + "Install it with: poetry install --extras langextract" + ) + + template = Template(template_str) + return template.render(**kwargs) diff --git a/presidio-analyzer/presidio_analyzer/lm_recognizer.py b/presidio-analyzer/presidio_analyzer/lm_recognizer.py new file mode 100644 index 0000000000..553c881795 --- /dev/null +++ b/presidio-analyzer/presidio_analyzer/lm_recognizer.py @@ -0,0 +1,161 @@ +import logging +from abc import ABC, abstractmethod +from typing import List, Optional + +from presidio_analyzer import RecognizerResult, RemoteRecognizer +from presidio_analyzer.llm_utils import ( + consolidate_generic_entities, + ensure_generic_entity_support, + filter_results_by_entities, + filter_results_by_labels, + filter_results_by_score, + skip_unmapped_entities, + validate_result_positions, +) +from presidio_analyzer.nlp_engine import NlpArtifacts + +logger = logging.getLogger("presidio-analyzer") + + +class LMRecognizer(RemoteRecognizer, ABC): + """ + Base class for language model-based PII recognizers. + + Provides common functionality for LLM-based entity detection. + Subclasses implement _call_llm() for specific LLM providers. + """ + + def __init__( + self, + supported_entities: Optional[List[str]] = None, + supported_language: str = "en", + name: str = "Language Model PII Recognizer", + version: str = "1.0.0", + model_id: Optional[str] = None, + temperature: Optional[float] = None, + min_score: float = 0.5, + labels_to_ignore: Optional[List[str]] = None, + enable_generic_consolidation: bool = True, + **kwargs + ): + """Initialize LM recognizer. + + :param supported_entities: Entity types to detect. + :param labels_to_ignore: Entity labels to skip. + :param enable_generic_consolidation: Consolidate unknown + entities to GENERIC_PII_ENTITY. 
+ """ + if not supported_entities: + raise ValueError( + "LMRecognizer requires at least one entity in 'supported_entities'" + ) + + super().__init__( + supported_entities=supported_entities, + supported_language=supported_language, + name=name, + version=version, + **kwargs + ) + + self.model_id = model_id + self.temperature = temperature + self.min_score = min_score + self.labels_to_ignore = [label.lower() for label in (labels_to_ignore or [])] + self.enable_generic_consolidation = enable_generic_consolidation + + self._generic_entities_logged = set() + + self.supported_entities = ensure_generic_entity_support( + self.supported_entities, enable_generic_consolidation + ) + + @abstractmethod + def _call_llm( + self, + text: str, + entities: List[str], + **kwargs + ) -> List[RecognizerResult]: + """ + Call LLM service and return RecognizerResult objects. + + Subclasses implement this to integrate with specific LLM providers. + + :param text: Text to analyze for PII. + :param entities: Entity types to detect. + :return: List of RecognizerResult objects. + """ + ... 
+ + def analyze( + self, + text: str, + entities: Optional[List[str]] = None, + nlp_artifacts: Optional[NlpArtifacts] = None + ) -> List[RecognizerResult]: + """Analyze text for PII/PHI using LLM.""" + if not text or not text.strip(): + logger.debug("Empty text provided, returning empty results") + return [] + + if entities is None: + requested_entities = self.supported_entities + else: + requested_entities = [e for e in entities if e in self.supported_entities] + + if not requested_entities: + logger.debug( + "No requested entities (%s) match supported entities (%s)", + entities, self.supported_entities + ) + return [] + + results = self._call_llm(text, requested_entities) + + filtered_results = self._filter_and_process_results( + results, requested_entities + ) + + if filtered_results: + logger.debug( + "LLM recognizer found %d entities", + len(filtered_results), + ) + + return filtered_results + + def _filter_and_process_results( + self, + results: List[RecognizerResult], + requested_entities: Optional[List[str]] = None + ) -> List[RecognizerResult]: + """Filter and process results.""" + filtered_results = filter_results_by_labels(results, self.labels_to_ignore) + + if self.enable_generic_consolidation: + filtered_results = consolidate_generic_entities( + filtered_results, + self.supported_entities, + self._generic_entities_logged + ) + else: + filtered_results = skip_unmapped_entities( + filtered_results, + self.supported_entities + ) + + if requested_entities: + filtered_results = filter_results_by_entities( + filtered_results, + requested_entities + ) + + filtered_results = validate_result_positions(filtered_results) + filtered_results = filter_results_by_score(filtered_results, self.min_score) + + return filtered_results + + def get_supported_entities(self) -> List[str]: + """Return list of supported PII entity types.""" + return self.supported_entities diff --git a/presidio-analyzer/presidio_analyzer/predefined_recognizers/__init__.py 
b/presidio-analyzer/presidio_analyzer/predefined_recognizers/__init__.py index ef9f2d2272..9333bb8211 100644 --- a/presidio-analyzer/presidio_analyzer/predefined_recognizers/__init__.py +++ b/presidio-analyzer/presidio_analyzer/predefined_recognizers/__init__.py @@ -85,6 +85,8 @@ # Third-party recognizers from .third_party.azure_ai_language import AzureAILanguageRecognizer +from .third_party.langextract_recognizer import LangExtractRecognizer +from .third_party.ollama_langextract_recognizer import OllamaLangExtractRecognizer PREDEFINED_RECOGNIZERS = [ "PhoneRecognizer", @@ -152,4 +154,6 @@ "AzureHealthDeidRecognizer", "KrRrnRecognizer", "ThTninRecognizer", + "LangExtractRecognizer", + "OllamaLangExtractRecognizer", ] diff --git a/presidio-analyzer/presidio_analyzer/predefined_recognizers/third_party/__init__.py b/presidio-analyzer/presidio_analyzer/predefined_recognizers/third_party/__init__.py index 9714fbaa04..8afd3b63ea 100644 --- a/presidio-analyzer/presidio_analyzer/predefined_recognizers/third_party/__init__.py +++ b/presidio-analyzer/presidio_analyzer/predefined_recognizers/third_party/__init__.py @@ -2,8 +2,10 @@ from .ahds_recognizer import AzureHealthDeidRecognizer from .azure_ai_language import AzureAILanguageRecognizer +from .langextract_recognizer import LangExtractRecognizer __all__ = [ "AzureAILanguageRecognizer", - "AzureHealthDeidRecognizer" + "AzureHealthDeidRecognizer", + "LangExtractRecognizer", ] diff --git a/presidio-analyzer/presidio_analyzer/predefined_recognizers/third_party/langextract_recognizer.py b/presidio-analyzer/presidio_analyzer/predefined_recognizers/third_party/langextract_recognizer.py new file mode 100644 index 0000000000..c744ea895b --- /dev/null +++ b/presidio-analyzer/presidio_analyzer/predefined_recognizers/third_party/langextract_recognizer.py @@ -0,0 +1,136 @@ +import logging +from abc import ABC, abstractmethod +from typing import List + +from presidio_analyzer.llm_utils import ( + check_langextract_available, + 
class LangExtractRecognizer(LMRecognizer, ABC):
    """
    Base class for LangExtract-based PII recognizers.

    Loads model/prompt/example configuration from YAML at construction time;
    subclasses implement _call_langextract() for specific LLM providers.
    """

    def __init__(
        self,
        config_path: str,
        name: str = "LangExtract LLM PII",
        supported_language: str = "en"
    ):
        """Initialize LangExtract recognizer.

        :param config_path: Path to configuration file.
        :param name: Name of the recognizer (provided by subclass).
        :param supported_language: Language this recognizer supports (default: "en").
        :raises ValueError: If required configuration fields are missing.
        """
        # Fail fast when the optional langextract dependency is absent.
        check_langextract_available()

        full_config = load_yaml_file(config_path)
        lm_config = extract_lm_config(full_config)
        langextract_config = full_config.get("langextract", {})

        supported_entities = get_supported_entities(lm_config, langextract_config)
        if not supported_entities:
            raise ValueError(
                "Configuration must contain 'supported_entities' in "
                "'lm_recognizer' or 'langextract'"
            )

        # Required configuration paths, validated before anything is loaded.
        required_fields = [
            ("langextract",),
            ("langextract", "model"),
            ("langextract", "model", "model_id"),
            ("langextract", "entity_mappings"),
            ("langextract", "prompt_file"),
            ("langextract", "examples_file"),
        ]
        validate_config_fields(full_config, required_fields)

        self.config = langextract_config
        model_config = get_model_config(full_config, provider_key="langextract")

        super().__init__(
            supported_entities=supported_entities,
            supported_language=supported_language,
            name=name,
            version="1.0.0",
            model_id=model_config["model_id"],
            temperature=model_config.get("temperature"),
            min_score=lm_config.get("min_score"),
            labels_to_ignore=lm_config.get("labels_to_ignore"),
            enable_generic_consolidation=lm_config.get(
                "enable_generic_consolidation"
            ),
        )

        # Few-shot examples and the prompt are prepared once at construction.
        self.examples = convert_to_langextract_format(
            load_yaml_examples(langextract_config["examples_file"])
        )
        self.prompt_description = render_jinja_template(
            load_prompt_file(langextract_config["prompt_file"]),
            supported_entities=self.supported_entities,
            enable_generic_consolidation=self.enable_generic_consolidation,
            labels_to_ignore=self.labels_to_ignore,
        )

        self.entity_mappings = langextract_config["entity_mappings"]
        self.debug = langextract_config.get("debug", False)

    def _call_llm(self, text: str, entities: List[str], **kwargs):
        """Call LangExtract and convert its output to Presidio results."""
        extract_params = {
            "text": text,
            "prompt": self.prompt_description,
            "examples": self.examples,
            "debug": self.debug,
        }

        # Temperature is only forwarded when explicitly configured.
        if self.temperature is not None:
            extract_params["temperature"] = self.temperature

        extract_params.update(kwargs)

        return convert_langextract_to_presidio_results(
            langextract_result=self._call_langextract(**extract_params),
            entity_mappings=self.entity_mappings,
            supported_entities=self.supported_entities,
            enable_generic_consolidation=self.enable_generic_consolidation,
            recognizer_name=self.__class__.__name__
        )

    @abstractmethod
    def _call_langextract(self, **kwargs):
        """Call provider-specific LangExtract implementation."""
        ...
class OllamaLangExtractRecognizer(LangExtractRecognizer):
    """LangExtract recognizer using Ollama backend."""

    # Default config shipped with the analyzer package.
    DEFAULT_CONFIG_PATH = (
        Path(__file__).parent.parent.parent / "conf" / "langextract_config_ollama.yaml"
    )

    def __init__(
        self,
        config_path: Optional[str] = None,
        supported_language: str = "en",
        context: Optional[list] = None,
        **kwargs,
    ):
        """Initialize Ollama LangExtract recognizer.

        Note: Ollama server availability and model availability are not validated
        during initialization. Any connectivity or model issues will be reported
        when analyze() is first called.

        :param config_path: Path to configuration file (optional).
        :param supported_language: Language this recognizer supports
            (optional, default: "en").
        :param context: List of context words
            (optional, currently not used by LLM recognizers).
        :param kwargs: Additional keyword arguments (unused, allows flexibility
            when instantiated from YAML configuration).
        :raises ValueError: If the model configuration lacks 'model_url'.
        """
        super().__init__(
            config_path=config_path if config_path else str(self.DEFAULT_CONFIG_PATH),
            name="Ollama LangExtract PII",
            supported_language=supported_language
        )

        self.model_url = self.config.get("model", {}).get("model_url")
        if not self.model_url:
            raise ValueError("Ollama model configuration must contain 'model_url'")

    def _call_langextract(self, **kwargs):
        """Call Ollama through LangExtract.

        Remaining kwargs (e.g. debug, temperature) are forwarded verbatim
        to ``lx.extract``; failures are logged with connection details and
        re-raised unchanged.
        """
        try:
            params = {
                "text_or_documents": kwargs.pop("text"),
                "prompt_description": kwargs.pop("prompt"),
                "examples": kwargs.pop("examples"),
                "model_id": self.model_id,
                "model_url": self.model_url,
            }
            params.update(kwargs)

            return lx.extract(**params)
        except Exception:
            logger.exception(
                "LangExtract extraction failed (Ollama at %s, model '%s')",
                self.model_url, self.model_id
            )
            raise
def _write_yaml(content):
    """Write ``content`` to a temporary .yaml file and return its path."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
        f.write(content)
        return f.name


class TestLoadYamlFile:
    """Tests for load_yaml_file function."""

    def test_when_config_file_exists_then_loads_yaml(self):
        """A valid YAML config file is parsed into the expected dict."""
        config_path = _write_yaml("""
lm_recognizer:
  supported_entities: ["PERSON", "EMAIL"]
  min_score: 0.7
  labels_to_ignore: ["metadata"]
""")
        try:
            config = load_yaml_file(config_path)
            assert config is not None
            assert "lm_recognizer" in config
            section = config["lm_recognizer"]
            assert section["supported_entities"] == ["PERSON", "EMAIL"]
            assert section["min_score"] == 0.7
        finally:
            Path(config_path).unlink()

    def test_when_config_file_missing_then_raises_file_not_found_error(self):
        """A nonexistent path raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError, match="File not found"):
            load_yaml_file("/nonexistent/path/config.yaml")

    def test_when_config_has_invalid_yaml_then_raises_value_error(self):
        """Malformed YAML content raises ValueError."""
        config_path = _write_yaml("invalid: yaml: content:\n - [unclosed")
        try:
            with pytest.raises(ValueError, match="Failed to parse YAML"):
                load_yaml_file(config_path)
        finally:
            Path(config_path).unlink()

    def test_when_config_is_empty_then_returns_none(self):
        """An empty YAML file loads as None (yaml.safe_load semantics)."""
        config_path = _write_yaml("")
        try:
            assert load_yaml_file(config_path) is None
        finally:
            Path(config_path).unlink()

    def test_when_config_has_multiple_sections_then_loads_all(self):
        """All top-level sections of a multi-section config are loaded."""
        config_path = _write_yaml("""
lm_recognizer:
  supported_entities: ["PERSON"]

langextract:
  model:
    model_id: "qwen2.5:1.5b"

other_section:
  some_key: "some_value"
""")
        try:
            config = load_yaml_file(config_path)
            for section in ("lm_recognizer", "langextract", "other_section"):
                assert section in config
        finally:
            Path(config_path).unlink()


class TestExtractLmConfig:
    """Tests for extract_lm_config function."""

    def test_when_lm_recognizer_section_exists_then_extracts_with_defaults(self):
        """Missing optional fields fall back to documented defaults."""
        lm_config = extract_lm_config(
            {"lm_recognizer": {"supported_entities": ["PERSON", "EMAIL"]}}
        )

        assert lm_config["supported_entities"] == ["PERSON", "EMAIL"]
        assert lm_config["min_score"] == 0.5
        assert lm_config["labels_to_ignore"] == []
        assert lm_config["enable_generic_consolidation"] is True

    def test_when_all_fields_present_then_uses_provided_values(self):
        """Explicitly provided values override the defaults."""
        lm_config = extract_lm_config({
            "lm_recognizer": {
                "supported_entities": ["PERSON"],
                "min_score": 0.8,
                "labels_to_ignore": ["system", "metadata"],
                "enable_generic_consolidation": False,
            }
        })

        assert lm_config["supported_entities"] == ["PERSON"]
        assert lm_config["min_score"] == 0.8
        assert lm_config["labels_to_ignore"] == ["system", "metadata"]
        assert lm_config["enable_generic_consolidation"] is False

    def test_when_lm_recognizer_missing_then_returns_none_for_entities(self):
        """Without an lm_recognizer section, entities are None and the rest default."""
        lm_config = extract_lm_config({"other_section": {}})

        assert lm_config["supported_entities"] is None
        assert lm_config["min_score"] == 0.5
        assert lm_config["labels_to_ignore"] == []
        assert lm_config["enable_generic_consolidation"] is True

    def test_when_partial_config_then_merges_with_defaults(self):
        """Provided fields are kept; omitted ones take defaults."""
        lm_config = extract_lm_config({
            "lm_recognizer": {
                "supported_entities": ["PHONE"],
                "min_score": 0.6,
                # labels_to_ignore and enable_generic_consolidation omitted
            }
        })

        assert lm_config["supported_entities"] == ["PHONE"]
        assert lm_config["min_score"] == 0.6
        assert lm_config["labels_to_ignore"] == []  # default
        assert lm_config["enable_generic_consolidation"] is True  # default


class TestGetModelConfig:
    """Tests for get_model_config function."""

    def test_when_provider_key_exists_then_extracts_model_config(self):
        """The provider's model section is returned with its values intact."""
        model_config = get_model_config(
            {
                "langextract": {
                    "model": {
                        "model_id": "qwen2.5:1.5b",
                        "temperature": 0.0,
                        "model_url": "http://localhost:11434",
                    }
                }
            },
            "langextract",
        )

        assert model_config is not None
        assert "model_id" in model_config
        assert model_config["model_id"] == "qwen2.5:1.5b"
        assert model_config["temperature"] == 0.0

    def test_when_provider_key_missing_then_raises_value_error(self):
        """An absent provider key is rejected."""
        config = {"other_provider": {"model": {"model_id": "some-model"}}}

        with pytest.raises(ValueError, match="Configuration must contain 'langextract'"):
            get_model_config(config, "langextract")

    def test_when_model_section_missing_then_raises_value_error(self):
        """A provider section without a 'model' mapping is rejected."""
        config = {"langextract": {"other_section": {}}}

        with pytest.raises(ValueError, match="Configuration must contain 'langextract.model'"):
            get_model_config(config, "langextract")

    def test_when_model_id_missing_then_raises_value_error(self):
        """A model section without model_id is rejected."""
        config = {"langextract": {"model": {"temperature": 0.0}}}

        with pytest.raises(ValueError, match="Configuration must contain 'langextract.model.model_id'"):
            get_model_config(config, "langextract")

    def test_when_model_config_has_extra_params_then_includes_all(self):
        """Unknown extra keys in the model section are passed through."""
        model_config = get_model_config(
            {
                "langextract": {
                    "model": {
                        "model_id": "qwen2.5:1.5b",
                        "temperature": 0.1,
                        "model_url": "http://localhost:11434",
                        "custom_param": "custom_value",
                    }
                }
            },
            "langextract",
        )

        assert model_config["model_id"] == "qwen2.5:1.5b"
        assert model_config["temperature"] == 0.1
        assert model_config["model_url"] == "http://localhost:11434"
        assert model_config["custom_param"] == "custom_value"


class TestIntegration:
    """Integration tests for config_loader functions."""

    def test_when_loading_full_config_workflow_then_extracts_correctly(self):
        """End to end: load YAML, extract lm config, then model config."""
        config_path = _write_yaml("""
lm_recognizer:
  supported_entities: ["PERSON", "EMAIL", "PHONE"]
  min_score: 0.75
  labels_to_ignore: ["system"]
  enable_generic_consolidation: false

langextract:
  model:
    model_id: "qwen2.5:1.5b"
    temperature: 0.0
    model_url: "http://localhost:11434"
""")
        try:
            full_config = load_yaml_file(config_path)

            lm_config = extract_lm_config(full_config)
            assert lm_config["supported_entities"] == ["PERSON", "EMAIL", "PHONE"]
            assert lm_config["min_score"] == 0.75
            assert lm_config["labels_to_ignore"] == ["system"]
            assert lm_config["enable_generic_consolidation"] is False

            model_config = get_model_config(full_config, "langextract")
            assert model_config["model_id"] == "qwen2.5:1.5b"
            assert model_config["temperature"] == 0.0
        finally:
            Path(config_path).unlink()
def create_test_result(
    entity_type="PERSON",
    start=0,
    end=10,
    score=0.9,
    recognition_metadata=None
):
    """Helper to create RecognizerResult for testing."""
    return RecognizerResult(
        entity_type=entity_type,
        start=start,
        end=end,
        score=score,
        analysis_explanation=AnalysisExplanation(
            recognizer="TestRecognizer",
            original_score=score,
            textual_explanation="Test"
        ),
        recognition_metadata=recognition_metadata
    )


class TestFilterResultsByLabels:
    """Tests for filter_results_by_labels function."""

    def test_when_no_ignored_labels_then_returns_all_results(self):
        """With an empty ignore list, nothing is removed."""
        items = [
            create_test_result("PERSON"),
            create_test_result("EMAIL_ADDRESS"),
        ]

        assert len(filter_results_by_labels(items, [])) == 2

    def test_when_result_matches_ignored_label_then_filters_out(self):
        """Results whose type is on the ignore list are dropped."""
        items = [
            create_test_result("PERSON"),
            create_test_result("SYSTEM"),
            create_test_result("EMAIL_ADDRESS"),
        ]

        kept = filter_results_by_labels(items, ["system"])

        assert len(kept) == 2
        assert all(r.entity_type != "SYSTEM" for r in kept)

    def test_when_multiple_ignored_labels_then_filters_all(self):
        """Every ignored label is filtered; the rest survive."""
        items = [
            create_test_result("PERSON"),
            create_test_result("SYSTEM"),
            create_test_result("METADATA"),
            create_test_result("EMAIL_ADDRESS"),
        ]

        kept = filter_results_by_labels(items, ["system", "metadata"])

        assert len(kept) == 2
        surviving = {r.entity_type for r in kept}
        assert "PERSON" in surviving
        assert "EMAIL_ADDRESS" in surviving

    def test_when_label_matching_is_case_insensitive(self):
        """Ignore-list matching does not depend on letter case."""
        items = [
            create_test_result("PERSON"),
            create_test_result("System"),
        ]

        kept = filter_results_by_labels(items, ["SYSTEM"])

        assert len(kept) == 1
        assert kept[0].entity_type == "PERSON"

    def test_when_result_has_no_entity_type_then_skips_with_warning(self):
        """Results lacking an entity_type are dropped."""
        untyped = RecognizerResult(
            entity_type=None,
            start=0,
            end=10,
            score=0.9
        )

        kept = filter_results_by_labels(
            [untyped, create_test_result("PERSON")], []
        )

        assert len(kept) == 1
        assert kept[0].entity_type == "PERSON"


class TestFilterResultsByScore:
    """Tests for filter_results_by_score function."""

    def test_when_all_results_above_threshold_then_returns_all(self):
        """Results above the threshold all pass."""
        items = [
            create_test_result("PERSON", score=0.9),
            create_test_result("EMAIL_ADDRESS", score=0.8),
        ]

        assert len(filter_results_by_score(items, min_score=0.5)) == 2

    def test_when_result_below_threshold_then_filters_out(self):
        """Results below the threshold are dropped."""
        items = [
            create_test_result("PERSON", score=0.9),
            create_test_result("EMAIL_ADDRESS", score=0.4),
            create_test_result("PHONE_NUMBER", score=0.7),
        ]

        kept = filter_results_by_score(items, min_score=0.5)

        assert len(kept) == 2
        assert all(r.score >= 0.5 for r in kept)

    def test_when_result_equals_threshold_then_includes(self):
        """A score exactly at the threshold is inclusive."""
        kept = filter_results_by_score(
            [create_test_result("PERSON", score=0.5)], min_score=0.5
        )

        assert len(kept) == 1

    def test_when_min_score_is_zero_then_returns_all(self):
        """min_score=0 lets every result through."""
        items = [
            create_test_result("PERSON", score=0.1),
            create_test_result("EMAIL_ADDRESS", score=0.9),
        ]

        assert len(filter_results_by_score(items, min_score=0.0)) == 2

    def test_when_min_score_is_one_then_only_perfect_scores(self):
        """min_score=1.0 admits only perfect scores."""
        items = [
            create_test_result("PERSON", score=1.0),
            create_test_result("EMAIL_ADDRESS", score=0.99),
        ]

        kept = filter_results_by_score(items, min_score=1.0)

        assert len(kept) == 1
        assert kept[0].entity_type == "PERSON"


class TestFilterResultsByEntities:
    """Tests for filter_results_by_entities function."""

    def test_when_no_requested_entities_then_returns_all(self):
        """An empty request list means no entity filtering."""
        items = [
            create_test_result("PERSON"),
            create_test_result("EMAIL_ADDRESS"),
        ]

        assert len(filter_results_by_entities(items, [])) == 2

    def test_when_result_matches_requested_entity_then_includes(self):
        """Only explicitly requested entity types are kept."""
        items = [
            create_test_result("PERSON"),
            create_test_result("EMAIL_ADDRESS"),
            create_test_result("PHONE_NUMBER"),
        ]

        kept = filter_results_by_entities(items, ["PERSON", "PHONE_NUMBER"])

        assert len(kept) == 2
        assert {r.entity_type for r in kept} == {"PERSON", "PHONE_NUMBER"}

    def test_when_result_not_in_requested_entities_then_filters_out(self):
        """Non-requested entity types are removed."""
        kept = filter_results_by_entities(
            [create_test_result("PERSON"), create_test_result("UNKNOWN_TYPE")],
            ["PERSON"],
        )

        assert len(kept) == 1
        assert kept[0].entity_type == "PERSON"


class TestValidateResultPositions:
    """Tests for validate_result_positions function."""

    def test_when_all_results_have_positions_then_returns_all(self):
        """Results with both start and end set all pass."""
        items = [
            create_test_result("PERSON", start=0, end=10),
            create_test_result("EMAIL_ADDRESS", start=20, end=40),
        ]

        assert len(validate_result_positions(items)) == 2

    def test_when_result_missing_start_then_filters_out(self):
        """A None start position disqualifies a result."""
        ok = create_test_result("PERSON", start=0, end=10)
        bad = RecognizerResult(
            entity_type="EMAIL_ADDRESS",
            start=None,
            end=40,
            score=0.9
        )

        kept = validate_result_positions([ok, bad])

        assert len(kept) == 1
        assert kept[0].entity_type == "PERSON"

    def test_when_result_missing_end_then_filters_out(self):
        """A None end position disqualifies a result."""
        ok = create_test_result("PERSON", start=0, end=10)
        bad = RecognizerResult(
            entity_type="EMAIL_ADDRESS",
            start=20,
            end=None,
            score=0.9
        )

        kept = validate_result_positions([ok, bad])

        assert len(kept) == 1
        assert kept[0].entity_type == "PERSON"

    def test_when_result_missing_both_positions_then_filters_out(self):
        """Results with neither position set are removed."""
        ok = create_test_result("PERSON", start=0, end=10)
        bad = RecognizerResult(
            entity_type="EMAIL_ADDRESS",
            start=None,
            end=None,
            score=0.9
        )

        assert len(validate_result_positions([ok, bad])) == 1


class TestConsolidateGenericEntities:
    """Tests for consolidate_generic_entities function."""

    def test_when_entity_supported_then_keeps_original(self):
        """Supported entity types pass through unchanged."""
        out = consolidate_generic_entities(
            [create_test_result("PERSON")],
            ["PERSON", "EMAIL_ADDRESS"],
            set(),
        )

        assert len(out) == 1
        assert out[0].entity_type == "PERSON"

    def test_when_entity_unsupported_then_consolidates_to_generic(self):
        """Unsupported types are rewritten to GENERIC_PII_ENTITY."""
        out = consolidate_generic_entities(
            [create_test_result("UNKNOWN_TYPE")],
            ["PERSON"],
            set(),
        )

        assert len(out) == 1
        assert out[0].entity_type == GENERIC_PII_ENTITY

    def test_when_entity_consolidated_then_stores_original_in_metadata(self):
        """The original type is preserved in recognition_metadata."""
        out = consolidate_generic_entities(
            [create_test_result("UNKNOWN_TYPE")],
            ["PERSON"],
            set(),
        )

        meta = out[0].recognition_metadata
        assert meta is not None
        assert meta["original_entity_type"] == "UNKNOWN_TYPE"

    def test_when_entity_already_has_metadata_then_adds_original_type(self):
        """Existing metadata keys survive; original_entity_type is added."""
        out = consolidate_generic_entities(
            [create_test_result(
                "UNKNOWN_TYPE", recognition_metadata={"key": "value"}
            )],
            ["PERSON"],
            set(),
        )

        assert out[0].recognition_metadata["key"] == "value"
        assert out[0].recognition_metadata["original_entity_type"] == "UNKNOWN_TYPE"

    def test_when_unknown_entity_first_seen_then_logs_to_set(self):
        """The first sighting of an unknown type is recorded in the set."""
        seen = set()

        consolidate_generic_entities(
            [create_test_result("UNKNOWN_TYPE")], ["PERSON"], seen
        )

        assert "UNKNOWN_TYPE" in seen

    def test_when_unknown_entity_seen_again_then_not_logged_twice(self):
        """Repeated sightings of the same unknown type are recorded once."""
        seen = set()

        consolidate_generic_entities(
            [
                create_test_result("UNKNOWN_TYPE"),
                create_test_result("UNKNOWN_TYPE"),
            ],
            ["PERSON"],
            seen,
        )

        assert len([e for e in seen if e == "UNKNOWN_TYPE"]) == 1
create_test_result("UNKNOWN2"), + ] + supported_entities = ["PERSON"] + + filtered = skip_unmapped_entities(results, supported_entities) + + assert len(filtered) == 1 + + +class TestEnsureGenericEntitySupport: + """Tests for ensure_generic_entity_support function.""" + + def test_when_consolidation_enabled_and_generic_missing_then_adds_generic(self): + """Test that GENERIC_PII_ENTITY is added when consolidation is enabled.""" + supported_entities = ["PERSON", "EMAIL_ADDRESS"] + + result = ensure_generic_entity_support(supported_entities, enable_generic_consolidation=True) + + assert GENERIC_PII_ENTITY in result + assert "PERSON" in result + assert "EMAIL_ADDRESS" in result + + def test_when_consolidation_disabled_then_does_not_add_generic(self): + """Test that GENERIC_PII_ENTITY is not added when consolidation is disabled.""" + supported_entities = ["PERSON", "EMAIL_ADDRESS"] + + result = ensure_generic_entity_support(supported_entities, enable_generic_consolidation=False) + + assert GENERIC_PII_ENTITY not in result + assert len(result) == 2 + + def test_when_generic_already_present_and_consolidation_enabled_then_no_duplicate(self): + """Test that GENERIC_PII_ENTITY is not duplicated.""" + supported_entities = ["PERSON", GENERIC_PII_ENTITY] + + result = ensure_generic_entity_support(supported_entities, enable_generic_consolidation=True) + + # Count occurrences of GENERIC_PII_ENTITY + generic_count = result.count(GENERIC_PII_ENTITY) + assert generic_count == 1 + + def test_when_original_list_not_modified(self): + """Test that original list is not modified (returns a copy).""" + original = ["PERSON", "EMAIL_ADDRESS"] + + result = ensure_generic_entity_support(original, enable_generic_consolidation=True) + + assert GENERIC_PII_ENTITY not in original # Original unchanged + assert GENERIC_PII_ENTITY in result # Result has it + + def test_when_empty_list_and_consolidation_enabled_then_adds_generic(self): + """Test that GENERIC_PII_ENTITY is added to empty list.""" + 
result = ensure_generic_entity_support([], enable_generic_consolidation=True) + + assert len(result) == 1 + assert result[0] == GENERIC_PII_ENTITY + + +class TestGenericPiiEntityConstant: + """Tests for GENERIC_PII_ENTITY constant.""" + + def test_generic_pii_entity_is_string(self): + """Test that GENERIC_PII_ENTITY is a string.""" + assert isinstance(GENERIC_PII_ENTITY, str) + + def test_generic_pii_entity_value(self): + """Test the expected value of GENERIC_PII_ENTITY.""" + assert GENERIC_PII_ENTITY == "GENERIC_PII_ENTITY" diff --git a/presidio-analyzer/tests/test_examples_loader.py b/presidio-analyzer/tests/test_examples_loader.py new file mode 100644 index 0000000000..2f21889163 --- /dev/null +++ b/presidio-analyzer/tests/test_examples_loader.py @@ -0,0 +1,280 @@ +"""Tests for llm_utils.examples_loader module.""" +import pytest +import tempfile +from pathlib import Path +from presidio_analyzer.llm_utils.examples_loader import ( + load_yaml_examples, + convert_to_langextract_format, +) + + +class TestLoadYamlExamples: + """Tests for load_yaml_examples function.""" + + def test_when_examples_file_exists_then_loads_list(self): + """Test loading examples from YAML file returns a list.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write(""" +examples: + - text: "John Doe works at Acme Corp" + extractions: + - extraction_class: "PERSON" + extraction_text: "John Doe" + attributes: {} + + - text: "Contact us at info@example.com" + extractions: + - extraction_class: "EMAIL_ADDRESS" + extraction_text: "info@example.com" + attributes: {} +""") + examples_path = Path(f.name) + + try: + # Use absolute path to load temp file + examples = load_yaml_examples(str(examples_path)) + + assert isinstance(examples, list) + assert len(examples) == 2 + assert examples[0]["text"] == "John Doe works at Acme Corp" + assert examples[1]["text"] == "Contact us at info@example.com" + finally: + examples_path.unlink() + + def 
test_when_examples_file_has_multiple_extractions_per_example_then_loads_all(self): + """Test loading example with multiple extractions.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write(""" +examples: + - text: "John at john@example.com" + extractions: + - extraction_class: "PERSON" + extraction_text: "John" + attributes: {} + - extraction_class: "EMAIL_ADDRESS" + extraction_text: "john@example.com" + attributes: {} +""") + examples_path = Path(f.name) + + try: + examples = load_yaml_examples(str(examples_path)) + + assert len(examples) == 1 + assert len(examples[0]["extractions"]) == 2 + finally: + examples_path.unlink() + + def test_when_examples_file_missing_then_raises_file_not_found_error(self): + """Test that missing examples file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File not found"): + load_yaml_examples("nonexistent_examples.yaml") + + def test_when_examples_missing_from_yaml_then_raises_value_error(self): + """Test that YAML without 'examples' key raises ValueError.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write(""" +other_section: + - text: "some text" +""") + examples_path = Path(f.name) + + try: + with pytest.raises(ValueError, match="Examples file must contain 'examples'"): + load_yaml_examples(str(examples_path)) + finally: + examples_path.unlink() + + def test_when_loading_actual_langextract_examples_file_then_works(self): + """Test loading the actual langextract examples file.""" + # Load the real examples file from conf directory using repo-root-relative path + examples = load_yaml_examples( + "presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_examples.yaml" + ) + + assert isinstance(examples, list) + assert len(examples) > 0 + # Each example should have text and extractions + for example in examples: + assert "text" in example + assert "extractions" in example + + +class 
TestConvertToLangextractFormat: + """Tests for convert_to_langextract_format function.""" + + def test_when_examples_data_valid_then_converts_to_langextract_format(self): + """Test converting examples to LangExtract ExampleData format.""" + examples_data = [ + { + "text": "John Doe works here", + "extractions": [ + { + "extraction_class": "PERSON", + "extraction_text": "John Doe", + "attributes": {} + } + ] + } + ] + + result = convert_to_langextract_format(examples_data) + + assert isinstance(result, list) + assert len(result) == 1 + + # Check it's a LangExtract ExampleData object + example = result[0] + assert hasattr(example, 'text') + assert hasattr(example, 'extractions') + assert example.text == "John Doe works here" + assert len(example.extractions) == 1 + + def test_when_multiple_examples_then_converts_all(self): + """Test converting multiple examples.""" + examples_data = [ + { + "text": "John Doe", + "extractions": [ + {"extraction_class": "PERSON", "extraction_text": "John Doe", "attributes": {}} + ] + }, + { + "text": "info@example.com", + "extractions": [ + {"extraction_class": "EMAIL_ADDRESS", "extraction_text": "info@example.com", "attributes": {}} + ] + } + ] + + result = convert_to_langextract_format(examples_data) + + assert len(result) == 2 + assert result[0].text == "John Doe" + assert result[1].text == "info@example.com" + + def test_when_example_has_multiple_extractions_then_converts_all(self): + """Test converting example with multiple extractions.""" + examples_data = [ + { + "text": "John at john@example.com", + "extractions": [ + {"extraction_class": "PERSON", "extraction_text": "John", "attributes": {}}, + {"extraction_class": "EMAIL_ADDRESS", "extraction_text": "john@example.com", "attributes": {}} + ] + } + ] + + result = convert_to_langextract_format(examples_data) + + assert len(result) == 1 + assert len(result[0].extractions) == 2 + + def test_when_extraction_has_all_fields_then_preserves_data(self): + """Test that all extraction 
fields are preserved.""" + examples_data = [ + { + "text": "Sample text", + "extractions": [ + { + "extraction_class": "PERSON", + "extraction_text": "Sample", + "attributes": {"type": "name"} + } + ] + } + ] + + result = convert_to_langextract_format(examples_data) + extraction = result[0].extractions[0] + + # Check the Extraction object has correct attributes + assert hasattr(extraction, 'extraction_class') + assert hasattr(extraction, 'extraction_text') + assert hasattr(extraction, 'attributes') + assert extraction.extraction_class == "PERSON" + assert extraction.extraction_text == "Sample" + assert extraction.attributes == {"type": "name"} + + def test_when_empty_examples_list_then_returns_empty_list(self): + """Test converting empty examples list.""" + result = convert_to_langextract_format([]) + assert result == [] + + def test_when_example_has_no_extractions_then_creates_empty_extractions(self): + """Test example with empty extractions list.""" + examples_data = [ + { + "text": "No entities here", + "extractions": [] + } + ] + + result = convert_to_langextract_format(examples_data) + + assert len(result) == 1 + assert result[0].text == "No entities here" + assert len(result[0].extractions) == 0 + + +class TestIntegration: + """Integration tests for examples_loader functions.""" + + def test_when_loading_and_converting_workflow_then_works(self): + """Test complete workflow: load YAML → convert to LangExtract format.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write(""" +examples: + - text: "John Doe at john@example.com and Jane Smith" + extractions: + - extraction_class: "PERSON" + extraction_text: "John Doe" + attributes: {} + - extraction_class: "EMAIL_ADDRESS" + extraction_text: "john@example.com" + attributes: {} + - extraction_class: "PERSON" + extraction_text: "Jane Smith" + attributes: {} +""") + examples_path = Path(f.name) + + try: + # Step 1: Load YAML + examples_data = 
load_yaml_examples(str(examples_path)) + assert len(examples_data) == 1 + assert len(examples_data[0]["extractions"]) == 3 + + # Step 2: Convert to LangExtract format + langextract_examples = convert_to_langextract_format(examples_data) + assert len(langextract_examples) == 1 + assert len(langextract_examples[0].extractions) == 3 + + # Verify entity types + entity_types = [e.extraction_class for e in langextract_examples[0].extractions] + assert "PERSON" in entity_types + assert "EMAIL_ADDRESS" in entity_types + assert entity_types.count("PERSON") == 2 + + finally: + examples_path.unlink() + + def test_when_using_actual_langextract_examples_then_converts_correctly(self): + """Test loading and converting the actual langextract examples file.""" + # Load the real examples file using repo-root-relative path + examples_data = load_yaml_examples( + "presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_examples.yaml" + ) + + # Convert to LangExtract format + langextract_examples = convert_to_langextract_format(examples_data) + + assert isinstance(langextract_examples, list) + assert len(langextract_examples) > 0 + + # Verify structure + for example in langextract_examples: + assert hasattr(example, 'text') + assert hasattr(example, 'extractions') + assert isinstance(example.text, str) + assert isinstance(example.extractions, list) diff --git a/presidio-analyzer/tests/test_langextract_helper.py b/presidio-analyzer/tests/test_langextract_helper.py new file mode 100644 index 0000000000..c04901f1c7 --- /dev/null +++ b/presidio-analyzer/tests/test_langextract_helper.py @@ -0,0 +1,285 @@ +"""Tests for llm_utils.langextract_helper module.""" +import pytest +from unittest.mock import Mock, MagicMock, patch +from presidio_analyzer.llm_utils.langextract_helper import ( + extract_lm_config, + get_supported_entities, + create_reverse_entity_mapping, + calculate_extraction_confidence, + DEFAULT_ALIGNMENT_SCORES, +) + + +class TestExtractLmConfig: + """Tests 
for extract_lm_config function.""" + + def test_when_lm_recognizer_section_exists_then_extracts_all_fields(self): + """Test extracting all fields from lm_recognizer section.""" + config = { + "lm_recognizer": { + "supported_entities": ["PERSON", "EMAIL"], + "min_score": 0.7, + "labels_to_ignore": ["system"], + "enable_generic_consolidation": False + } + } + + result = extract_lm_config(config) + + assert result["supported_entities"] == ["PERSON", "EMAIL"] + assert result["min_score"] == 0.7 + assert result["labels_to_ignore"] == ["system"] + assert result["enable_generic_consolidation"] is False + + def test_when_lm_recognizer_missing_then_returns_defaults(self): + """Test that defaults are returned when lm_recognizer section is missing.""" + config = {} + + result = extract_lm_config(config) + + assert result["supported_entities"] is None + assert result["min_score"] == 0.5 + assert result["labels_to_ignore"] == [] + assert result["enable_generic_consolidation"] is True + + def test_when_partial_config_then_uses_defaults_for_missing_fields(self): + """Test that defaults are used for missing fields.""" + config = { + "lm_recognizer": { + "supported_entities": ["PERSON"] + } + } + + result = extract_lm_config(config) + + assert result["supported_entities"] == ["PERSON"] + assert result["min_score"] == 0.5 + assert result["labels_to_ignore"] == [] + assert result["enable_generic_consolidation"] is True + + def test_when_only_min_score_provided_then_uses_defaults_for_others(self): + """Test partial config with only min_score.""" + config = { + "lm_recognizer": { + "min_score": 0.8 + } + } + + result = extract_lm_config(config) + + assert result["supported_entities"] is None + assert result["min_score"] == 0.8 + assert result["labels_to_ignore"] == [] + assert result["enable_generic_consolidation"] is True + + +class TestGetSupportedEntities: + """Tests for get_supported_entities function.""" + + def test_when_lm_config_has_entities_then_returns_from_lm_config(self): + 
"""Test that entities from lm_config are prioritized.""" + lm_config = {"supported_entities": ["PERSON", "EMAIL"]} + langextract_config = {"supported_entities": ["PHONE"]} + + result = get_supported_entities(lm_config, langextract_config) + + assert result == ["PERSON", "EMAIL"] + + def test_when_lm_config_missing_entities_then_returns_from_langextract_config(self): + """Test fallback to langextract_config when lm_config has no entities.""" + lm_config = {"supported_entities": None} + langextract_config = {"supported_entities": ["PHONE", "ADDRESS"]} + + result = get_supported_entities(lm_config, langextract_config) + + assert result == ["PHONE", "ADDRESS"] + + def test_when_both_missing_entities_then_returns_none(self): + """Test that None is returned when both configs lack entities.""" + lm_config = {} + langextract_config = {} + + result = get_supported_entities(lm_config, langextract_config) + + assert result is None + + def test_when_lm_config_has_empty_list_then_falls_back_to_langextract(self): + """Test that empty list in lm_config causes fallback.""" + lm_config = {"supported_entities": []} + langextract_config = {"supported_entities": ["PERSON"]} + + result = get_supported_entities(lm_config, langextract_config) + + # Empty list is falsy, so should fallback + assert result == ["PERSON"] + + +class TestCreateReverseEntityMapping: + """Tests for create_reverse_entity_mapping function.""" + + def test_when_mapping_provided_then_creates_reverse(self): + """Test creating reverse mapping from entity mappings.""" + entity_mappings = { + "person": "PERSON", + "email": "EMAIL_ADDRESS", + "phone": "PHONE_NUMBER" + } + + result = create_reverse_entity_mapping(entity_mappings) + + assert result["PERSON"] == "person" + assert result["EMAIL_ADDRESS"] == "email" + assert result["PHONE_NUMBER"] == "phone" + + def test_when_empty_mapping_then_returns_empty_dict(self): + """Test that empty mapping returns empty dict.""" + result = create_reverse_entity_mapping({}) + + assert 
result == {} + + def test_when_duplicate_values_then_last_wins(self): + """Test behavior with duplicate values in original mapping.""" + entity_mappings = { + "person1": "PERSON", + "person2": "PERSON" # Duplicate value + } + + result = create_reverse_entity_mapping(entity_mappings) + + # Last key-value pair should win + assert result["PERSON"] == "person2" + + +class TestCalculateExtractionConfidence: + """Tests for calculate_extraction_confidence function.""" + + def test_when_extraction_has_match_exact_then_returns_095(self): + """Test confidence for MATCH_EXACT alignment.""" + extraction = Mock() + extraction.alignment_status = "MATCH_EXACT" + + result = calculate_extraction_confidence(extraction) + + assert result == 0.95 + + def test_when_extraction_has_match_fuzzy_then_returns_080(self): + """Test confidence for MATCH_FUZZY alignment.""" + extraction = Mock() + extraction.alignment_status = "MATCH_FUZZY" + + result = calculate_extraction_confidence(extraction) + + assert result == 0.80 + + def test_when_extraction_has_match_lesser_then_returns_070(self): + """Test confidence for MATCH_LESSER alignment.""" + extraction = Mock() + extraction.alignment_status = "MATCH_LESSER" + + result = calculate_extraction_confidence(extraction) + + assert result == 0.70 + + def test_when_extraction_has_not_aligned_then_returns_060(self): + """Test confidence for NOT_ALIGNED alignment.""" + extraction = Mock() + extraction.alignment_status = "NOT_ALIGNED" + + result = calculate_extraction_confidence(extraction) + + assert result == 0.60 + + def test_when_extraction_missing_alignment_status_then_returns_default(self): + """Test that default score is returned when alignment_status is missing.""" + extraction = Mock(spec=[]) # Mock without alignment_status attribute + + result = calculate_extraction_confidence(extraction) + + assert result == 0.85 + + def test_when_alignment_status_is_none_then_returns_default(self): + """Test that default score is returned when 
alignment_status is None.""" + extraction = Mock() + extraction.alignment_status = None + + result = calculate_extraction_confidence(extraction) + + assert result == 0.85 + + def test_when_alignment_status_is_empty_string_then_returns_default(self): + """Test that default score is returned for empty alignment_status.""" + extraction = Mock() + extraction.alignment_status = "" + + result = calculate_extraction_confidence(extraction) + + assert result == 0.85 + + def test_when_unknown_alignment_status_then_returns_default(self): + """Test that default score is returned for unknown alignment status.""" + extraction = Mock() + extraction.alignment_status = "UNKNOWN_STATUS" + + result = calculate_extraction_confidence(extraction) + + assert result == 0.85 + + def test_when_alignment_status_lowercase_then_matches_correctly(self): + """Test that alignment status matching is case-insensitive.""" + extraction = Mock() + extraction.alignment_status = "match_exact" # lowercase + + result = calculate_extraction_confidence(extraction) + + # Should still match because we convert to uppercase + assert result == 0.95 + + def test_when_custom_alignment_scores_provided_then_uses_custom(self): + """Test using custom alignment scores.""" + extraction = Mock() + extraction.alignment_status = "MATCH_EXACT" + + custom_scores = { + "MATCH_EXACT": 0.99, + "MATCH_FUZZY": 0.75 + } + + result = calculate_extraction_confidence(extraction, alignment_scores=custom_scores) + + assert result == 0.99 + + def test_when_custom_scores_missing_status_then_returns_default(self): + """Test that default is returned when custom scores don't have the status.""" + extraction = Mock() + extraction.alignment_status = "MATCH_FUZZY" + + custom_scores = { + "MATCH_EXACT": 0.99 + # MATCH_FUZZY missing + } + + result = calculate_extraction_confidence(extraction, alignment_scores=custom_scores) + + assert result == 0.85 # Default score + + +class TestDefaultAlignmentScores: + """Tests for DEFAULT_ALIGNMENT_SCORES 
constant.""" + + def test_default_alignment_scores_has_all_statuses(self): + """Test that DEFAULT_ALIGNMENT_SCORES contains expected statuses.""" + assert "MATCH_EXACT" in DEFAULT_ALIGNMENT_SCORES + assert "MATCH_FUZZY" in DEFAULT_ALIGNMENT_SCORES + assert "MATCH_LESSER" in DEFAULT_ALIGNMENT_SCORES + assert "NOT_ALIGNED" in DEFAULT_ALIGNMENT_SCORES + + def test_default_alignment_scores_values_are_valid(self): + """Test that all scores are between 0 and 1.""" + for status, score in DEFAULT_ALIGNMENT_SCORES.items(): + assert 0.0 <= score <= 1.0, f"{status} score {score} is not between 0 and 1" + + def test_default_alignment_scores_are_ordered_correctly(self): + """Test that scores are in descending order of confidence.""" + assert DEFAULT_ALIGNMENT_SCORES["MATCH_EXACT"] > DEFAULT_ALIGNMENT_SCORES["MATCH_FUZZY"] + assert DEFAULT_ALIGNMENT_SCORES["MATCH_FUZZY"] > DEFAULT_ALIGNMENT_SCORES["MATCH_LESSER"] + assert DEFAULT_ALIGNMENT_SCORES["MATCH_LESSER"] > DEFAULT_ALIGNMENT_SCORES["NOT_ALIGNED"] diff --git a/presidio-analyzer/tests/test_lm_recognizer.py b/presidio-analyzer/tests/test_lm_recognizer.py new file mode 100644 index 0000000000..e6d4c39823 --- /dev/null +++ b/presidio-analyzer/tests/test_lm_recognizer.py @@ -0,0 +1,243 @@ +"""Tests for LMRecognizer base class.""" +import pytest +from unittest.mock import Mock + +from presidio_analyzer.lm_recognizer import LMRecognizer +from presidio_analyzer import RecognizerResult + + +class ConcreteLMRecognizer(LMRecognizer): + """Concrete implementation for testing.""" + + def __init__(self, **kwargs): + super().__init__( + supported_entities=["PERSON", "EMAIL_ADDRESS"], + name="Test LLM Recognizer", + **kwargs + ) + + def _call_llm(self, text, entities, **kwargs): + """Mock implementation that returns RecognizerResult objects.""" + return self.mock_entities if hasattr(self, 'mock_entities') else [] + + +class TestLMRecognizerAnalyze: + """Test LMRecognizer.analyze() method.""" + + def 
test_when_text_is_empty_then_returns_empty_list(self): + """Test that empty text returns empty results.""" + recognizer = ConcreteLMRecognizer() + + results = recognizer.analyze("") + assert results == [] + + results = recognizer.analyze(" ") + assert results == [] + + def test_when_entities_is_none_then_uses_all_supported_entities(self): + """Test that None entities parameter uses all supported entities.""" + recognizer = ConcreteLMRecognizer() + recognizer.mock_entities = [ + RecognizerResult(entity_type="PERSON", start=0, end=4, score=0.85) + ] + + results = recognizer.analyze("John", entities=None) + assert len(results) == 1 + assert results[0].entity_type == "PERSON" + + def test_when_requested_entity_not_supported_then_returns_empty(self): + """Test that unsupported entities return empty results.""" + recognizer = ConcreteLMRecognizer() + recognizer.mock_entities = [] + + results = recognizer.analyze("test", entities=["CREDIT_CARD"]) + assert results == [] + + def test_when_llm_returns_entities_then_returns_recognizer_results(self): + """Test that LLM returns RecognizerResult objects.""" + recognizer = ConcreteLMRecognizer() + recognizer.mock_entities = [ + RecognizerResult( + entity_type="PERSON", + start=11, + end=19, + score=0.9, + recognition_metadata={"source": "test"} + ) + ] + + results = recognizer.analyze("My name is John Doe", entities=["PERSON"]) + + assert len(results) == 1 + assert isinstance(results[0], RecognizerResult) + assert results[0].entity_type == "PERSON" + assert results[0].start == 11 + assert results[0].end == 19 + assert results[0].score == 0.9 + + def test_when_entity_below_min_score_then_filters_out(self): + """Test that entities below min_score are filtered.""" + recognizer = ConcreteLMRecognizer(min_score=0.8) + recognizer.mock_entities = [ + RecognizerResult(entity_type="PERSON", start=0, end=4, score=0.9), + RecognizerResult(entity_type="PERSON", start=5, end=9, score=0.7), + ] + + results = recognizer.analyze("John Jane") + + 
assert len(results) == 1 + assert results[0].score == 0.9 + + def test_when_llm_raises_exception_then_propagates(self): + """Test that exceptions from _call_llm propagate to the caller.""" + class ErrorLMRecognizer(LMRecognizer): + """Recognizer that raises error in _call_llm.""" + def __init__(self): + super().__init__( + supported_entities=["PERSON"], + name="Error Test Recognizer" + ) + + def _call_llm(self, text, entities, **kwargs): + raise RuntimeError("LLM API error") + + recognizer = ErrorLMRecognizer() + + with pytest.raises(RuntimeError, match="LLM API error"): + recognizer.analyze("test text", entities=["PERSON"]) + + def test_when_entity_missing_required_fields_then_skips(self): + """Test that entities with missing required fields are skipped.""" + recognizer = ConcreteLMRecognizer() + recognizer.mock_entities = [ + RecognizerResult(entity_type="", start=0, end=4, score=0.85), # Missing type + RecognizerResult(entity_type="PERSON", start=None, end=4, score=0.85), # Missing start + ] + + results = recognizer.analyze("test") + assert results == [] + + +class TestLMRecognizerMetadata: + """Test LMRecognizer metadata handling.""" + + def test_when_entity_has_metadata_then_preserves_it(self): + """Test that metadata is preserved in RecognizerResult.""" + recognizer = ConcreteLMRecognizer() + recognizer.mock_entities = [ + RecognizerResult( + entity_type="PERSON", + start=0, + end=4, + score=0.9, + recognition_metadata={"alignment": "MATCH_EXACT", "source": "test"} + ) + ] + + results = recognizer.analyze("John") + + assert len(results) == 1 + assert results[0].recognition_metadata == {"alignment": "MATCH_EXACT", "source": "test"} + assert results[0].score == 0.9 + + def test_when_entity_has_no_metadata_then_works_correctly(self): + """Test that entities without metadata work correctly.""" + recognizer = ConcreteLMRecognizer() + recognizer.mock_entities = [ + RecognizerResult(entity_type="PERSON", start=0, end=4, score=0.85) + ] + + results = 
recognizer.analyze("John") + + assert len(results) == 1 + assert results[0].entity_type == "PERSON" + assert results[0].score == 0.85 + + +class TestLMRecognizerFiltering: + """Test LMRecognizer filtering functionality.""" + + def test_when_entity_in_labels_to_ignore_then_filters_out(self): + """Test that entities in labels_to_ignore are filtered out.""" + recognizer = ConcreteLMRecognizer(labels_to_ignore=["PERSON", "location"]) + recognizer.mock_entities = [ + RecognizerResult(entity_type="PERSON", start=0, end=4, score=0.9), + RecognizerResult(entity_type="EMAIL_ADDRESS", start=5, end=20, score=0.9), + RecognizerResult(entity_type="LOCATION", start=21, end=30, score=0.9), + ] + + results = recognizer.analyze("test text") + + # Only EMAIL_ADDRESS should remain (PERSON and LOCATION ignored, case-insensitive) + assert len(results) == 1 + assert results[0].entity_type == "EMAIL_ADDRESS" + + def test_when_unknown_entity_and_consolidation_enabled_then_creates_generic(self): + """Test unknown entity types get consolidated to GENERIC_PII_ENTITY.""" + recognizer = ConcreteLMRecognizer(enable_generic_consolidation=True) + recognizer.mock_entities = [ + RecognizerResult(entity_type="UNKNOWN_TYPE", start=0, end=10, score=0.9), + RecognizerResult(entity_type="PERSON", start=11, end=15, score=0.9), + ] + + results = recognizer.analyze("test text") + + assert len(results) == 2 + # Unknown type should be converted to GENERIC_PII_ENTITY + assert results[0].entity_type == "GENERIC_PII_ENTITY" + assert results[0].recognition_metadata["original_entity_type"] == "UNKNOWN_TYPE" + # Known type should remain unchanged + assert results[1].entity_type == "PERSON" + + def test_when_unknown_entity_and_consolidation_disabled_then_skips(self): + """Test unknown entities are skipped when consolidation is disabled.""" + recognizer = ConcreteLMRecognizer(enable_generic_consolidation=False) + recognizer.mock_entities = [ + RecognizerResult(entity_type="UNKNOWN_TYPE", start=0, end=10, score=0.9), 
+ RecognizerResult(entity_type="PERSON", start=11, end=15, score=0.9), + ] + + results = recognizer.analyze("test text") + + # Only known entity should be returned + assert len(results) == 1 + assert results[0].entity_type == "PERSON" + + def test_when_unknown_entity_with_existing_metadata_then_preserves_metadata(self): + """Test that existing metadata is preserved when adding original_entity_type.""" + recognizer = ConcreteLMRecognizer(enable_generic_consolidation=True) + recognizer.mock_entities = [ + RecognizerResult( + entity_type="CUSTOM_TYPE", + start=0, + end=10, + score=0.9, + recognition_metadata={"source": "custom", "confidence": "high"} + ), + ] + + results = recognizer.analyze("test text") + + assert len(results) == 1 + assert results[0].entity_type == "GENERIC_PII_ENTITY" + assert results[0].recognition_metadata["original_entity_type"] == "CUSTOM_TYPE" + assert results[0].recognition_metadata["source"] == "custom" + assert results[0].recognition_metadata["confidence"] == "high" + + def test_when_get_supported_entities_called_then_returns_list(self): + """Test get_supported_entities returns the correct list.""" + recognizer = ConcreteLMRecognizer() + entities = recognizer.get_supported_entities() + + assert "PERSON" in entities + assert "EMAIL_ADDRESS" in entities + assert "GENERIC_PII_ENTITY" in entities # Added by default + + def test_when_generic_consolidation_disabled_then_no_generic_in_supported(self): + """Test GENERIC_PII_ENTITY not added when consolidation disabled.""" + recognizer = ConcreteLMRecognizer(enable_generic_consolidation=False) + entities = recognizer.get_supported_entities() + + assert "PERSON" in entities + assert "EMAIL_ADDRESS" in entities + assert "GENERIC_PII_ENTITY" not in entities diff --git a/presidio-analyzer/tests/test_ollama_recognizer.py b/presidio-analyzer/tests/test_ollama_recognizer.py new file mode 100644 index 0000000000..b49d4fb67e --- /dev/null +++ b/presidio-analyzer/tests/test_ollama_recognizer.py @@ -0,0 +1,432 
"""Tests for LangExtract recognizer hierarchy using mocks."""
import pytest
import urllib.error
from unittest.mock import Mock, patch, MagicMock


def create_test_config(
    supported_entities=None,
    entity_mappings=None,
    model_id="qwen2.5:1.5b",
    model_url="http://localhost:11434",
    temperature=0.0,
    min_score=0.5,
    labels_to_ignore=None,
    enable_generic_consolidation=True
):
    """Build a recognizer configuration dict with test-friendly defaults.

    Arguments left as ``None`` fall back to small defaults so individual
    tests only need to spell out the fields they actually exercise.
    """
    if supported_entities is None:
        supported_entities = ["PERSON", "EMAIL_ADDRESS"]
    if entity_mappings is None:
        entity_mappings = {"person": "PERSON", "email": "EMAIL_ADDRESS"}
    if labels_to_ignore is None:
        labels_to_ignore = []

    return {
        "lm_recognizer": {
            "supported_entities": supported_entities,
            "labels_to_ignore": labels_to_ignore,
            "enable_generic_consolidation": enable_generic_consolidation,
            "min_score": min_score,
        },
        "langextract": {
            "prompt_file": "presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_prompt.j2",
            "examples_file": "presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_examples.yaml",
            "model": {
                "model_id": model_id,
                "model_url": model_url,
                "temperature": temperature,
            },
            "entity_mappings": entity_mappings,
        }
    }


class TestOllamaLangExtractRecognizerInitialization:
    """Test OllamaLangExtractRecognizer initialization and configuration loading."""

    @staticmethod
    def _minimal_config(model_section):
        """Return a minimal config; ``model_section`` is the 'model' mapping, or
        None to omit the 'langextract.model' section entirely."""
        langextract = {
            "prompt_file": "presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_prompt.j2",
            "examples_file": "presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_examples.yaml",
            "entity_mappings": {"person": "PERSON"},
        }
        if model_section is not None:
            langextract["model"] = model_section
        return {
            "lm_recognizer": {
                "supported_entities": ["PERSON"],
                "labels_to_ignore": [],
                "enable_generic_consolidation": True,
                "min_score": 0.5,
            },
            "langextract": langextract,
        }

    @staticmethod
    def _write_config(tmp_path, config):
        """Dump ``config`` as YAML under ``tmp_path`` and return the file path."""
        import yaml

        config_file = tmp_path / "bad_config.yaml"
        with open(config_file, 'w') as f:
            yaml.dump(config, f)
        return config_file

    def test_when_langextract_not_installed_then_raises_import_error(self):
        """Test that ImportError is raised when langextract is not installed."""
        # The helper module exposes lx=None when the optional dependency is
        # missing; patching it to None simulates that state.
        with patch(
            'presidio_analyzer.llm_utils.langextract_helper.lx',
            None
        ):
            from presidio_analyzer.predefined_recognizers.third_party.ollama_langextract_recognizer import OllamaLangExtractRecognizer
            with pytest.raises(ImportError, match="LangExtract is not installed"):
                OllamaLangExtractRecognizer()

    def test_when_initialized_with_mocked_ollama_then_succeeds(self, tmp_path):
        """Test OllamaLangExtractRecognizer initialization."""
        import yaml

        config = create_test_config(
            supported_entities=["PERSON", "EMAIL_ADDRESS"],
            entity_mappings={"person": "PERSON", "email": "EMAIL_ADDRESS"},
            model_id="qwen2.5:1.5b",
            model_url="http://localhost:11434",
            temperature=0.0,
            min_score=0.5
        )

        config_file = tmp_path / "test_config.yaml"
        with open(config_file, 'w') as f:
            yaml.dump(config, f)

        # FIX: use new=MagicMock() to replace the lx module with a non-None
        # stand-in.  The previous return_value=Mock() kwarg only configured
        # the replacement mock's *call* result, which was never the intent.
        with patch('presidio_analyzer.llm_utils.langextract_helper.lx',
                   new=MagicMock()):

            from presidio_analyzer.predefined_recognizers.third_party.ollama_langextract_recognizer import OllamaLangExtractRecognizer
            recognizer = OllamaLangExtractRecognizer(config_path=str(config_file))

            # Verify initialization
            assert recognizer.name == "Ollama LangExtract PII"
            assert recognizer.model_id == "qwen2.5:1.5b"
            assert recognizer.model_url == "http://localhost:11434"
            assert len(recognizer.supported_entities) == 3  # PERSON, EMAIL_ADDRESS, GENERIC_PII_ENTITY
            assert "PERSON" in recognizer.supported_entities
            assert "EMAIL_ADDRESS" in recognizer.supported_entities
            assert "GENERIC_PII_ENTITY" in recognizer.supported_entities

            # Verify inheritance hierarchy
            from presidio_analyzer.lm_recognizer import LMRecognizer
            from presidio_analyzer.predefined_recognizers.third_party.langextract_recognizer import LangExtractRecognizer
            assert isinstance(recognizer, OllamaLangExtractRecognizer)
            assert isinstance(recognizer, LangExtractRecognizer)
            assert isinstance(recognizer, LMRecognizer)

    def test_when_config_file_missing_then_raises_file_not_found_error(self):
        """Test FileNotFoundError when config file doesn't exist."""
        with patch('presidio_analyzer.llm_utils.langextract_helper.lx',
                   new=MagicMock()):
            from presidio_analyzer.predefined_recognizers.third_party.ollama_langextract_recognizer import OllamaLangExtractRecognizer
            with pytest.raises(FileNotFoundError, match="File not found"):
                OllamaLangExtractRecognizer(config_path="/nonexistent/path.yaml")

    def test_when_model_section_missing_then_raises_value_error(self, tmp_path):
        """Test ValueError when config missing 'langextract.model' section."""
        config_file = self._write_config(tmp_path, self._minimal_config(None))

        with patch('presidio_analyzer.llm_utils.langextract_helper.lx',
                   new=MagicMock()):
            from presidio_analyzer.predefined_recognizers.third_party.ollama_langextract_recognizer import OllamaLangExtractRecognizer
            with pytest.raises(ValueError, match="Configuration must contain 'langextract.model'"):
                OllamaLangExtractRecognizer(config_path=str(config_file))

    def test_when_model_id_missing_then_raises_value_error(self, tmp_path):
        """Test ValueError when model_id is missing."""
        config_file = self._write_config(
            tmp_path, self._minimal_config({"model_url": "http://localhost:11434"})
        )

        with patch('presidio_analyzer.llm_utils.langextract_helper.lx',
                   new=MagicMock()):
            from presidio_analyzer.predefined_recognizers.third_party.ollama_langextract_recognizer import OllamaLangExtractRecognizer
            with pytest.raises(ValueError, match="Configuration must contain 'langextract.model.model_id'"):
                OllamaLangExtractRecognizer(config_path=str(config_file))

    def test_when_model_url_missing_then_raises_value_error(self, tmp_path):
        """Test ValueError when model_url is missing."""
        config_file = self._write_config(
            tmp_path, self._minimal_config({"model_id": "qwen2.5:1.5b"})
        )

        with patch('presidio_analyzer.llm_utils.langextract_helper.lx',
                   new=MagicMock()):
            from presidio_analyzer.predefined_recognizers.third_party.ollama_langextract_recognizer import OllamaLangExtractRecognizer
            with pytest.raises(ValueError, match="Ollama model configuration must contain 'model_url'"):
                OllamaLangExtractRecognizer(config_path=str(config_file))


class TestOllamaLangExtractRecognizerAnalyze:
    """Test the analyze method with mocked LangExtract."""

    @pytest.fixture
    def mock_recognizer(self, tmp_path):
        """Fixture to create a mocked OllamaLangExtractRecognizer."""
        import yaml

        config = create_test_config(
            supported_entities=["PERSON", "EMAIL_ADDRESS", "PHONE_NUMBER"],
            entity_mappings={
                "person": "PERSON",
                "email": "EMAIL_ADDRESS",
                "phone": "PHONE_NUMBER"
            },
            model_id="qwen2.5:1.5b",
            model_url="http://localhost:11434",
            temperature=0.0,
            min_score=0.5
        )

        config_file = tmp_path / "test_config.yaml"
        with open(config_file, 'w') as f:
            yaml.dump(config, f)

        # new=MagicMock() marks langextract as "installed" for the duration.
        with patch('presidio_analyzer.llm_utils.langextract_helper.lx',
                   new=MagicMock()):
            from presidio_analyzer.predefined_recognizers.third_party.ollama_langextract_recognizer import OllamaLangExtractRecognizer
            return OllamaLangExtractRecognizer(config_path=str(config_file))
"test_config.yaml" + with open(config_file, 'w') as f: + yaml.dump(config, f) + + with patch('presidio_analyzer.llm_utils.langextract_helper.lx', + return_value=Mock()): + from presidio_analyzer.predefined_recognizers.third_party.ollama_langextract_recognizer import OllamaLangExtractRecognizer + return OllamaLangExtractRecognizer(config_path=str(config_file)) + + def test_when_text_contains_person_then_detects_entity(self, mock_recognizer): + """Test analysis detecting a person entity with mocked LangExtract.""" + text = "My name is John Doe" + + # Create mock extraction + mock_extraction = Mock() + mock_extraction.extraction_class = "person" + mock_extraction.extraction_text = "John Doe" + mock_extraction.char_interval = Mock() + mock_extraction.char_interval.start_pos = 11 + mock_extraction.char_interval.end_pos = 19 + mock_extraction.alignment_status = "MATCH_EXACT" + mock_extraction.attributes = {"type": "full_name"} + + mock_result = Mock() + mock_result.extractions = [mock_extraction] + + with patch('langextract.extract', return_value=mock_result): + results = mock_recognizer.analyze(text, entities=["PERSON"]) + + assert len(results) == 1 + assert results[0].entity_type == "PERSON" + assert results[0].start == 11 + assert results[0].end == 19 + assert results[0].score == 0.95 # MATCH_EXACT score + + def test_when_text_contains_multiple_entities_then_detects_all(self, mock_recognizer): + """Test analysis detecting multiple entity types.""" + text = "Contact John Doe at john@example.com or 555-1234" + + # Create mock extractions + person_extraction = Mock() + person_extraction.extraction_class = "person" + person_extraction.extraction_text = "John Doe" + person_extraction.char_interval = Mock(start_pos=8, end_pos=16) + person_extraction.alignment_status = "MATCH_EXACT" + person_extraction.attributes = {} + + email_extraction = Mock() + email_extraction.extraction_class = "email" + email_extraction.extraction_text = "john@example.com" + 
email_extraction.char_interval = Mock(start_pos=20, end_pos=36) + email_extraction.alignment_status = "MATCH_EXACT" + email_extraction.attributes = {} + + phone_extraction = Mock() + phone_extraction.extraction_class = "phone" + phone_extraction.extraction_text = "555-1234" + phone_extraction.char_interval = Mock(start_pos=40, end_pos=48) + phone_extraction.alignment_status = "MATCH_FUZZY" + phone_extraction.attributes = {} + + mock_result = Mock() + mock_result.extractions = [person_extraction, email_extraction, phone_extraction] + + with patch('langextract.extract', return_value=mock_result): + results = mock_recognizer.analyze(text) + + assert len(results) == 3 + assert results[0].entity_type == "PERSON" + assert results[1].entity_type == "EMAIL_ADDRESS" + assert results[2].entity_type == "PHONE_NUMBER" + assert results[2].score == 0.80 # MATCH_FUZZY score + + def test_when_text_is_empty_then_returns_no_results(self, mock_recognizer): + """Test analysis with empty text returns no results.""" + results = mock_recognizer.analyze("") + assert len(results) == 0 + + results = mock_recognizer.analyze(" ") + assert len(results) == 0 + + def test_when_no_entities_match_then_returns_empty_list(self, mock_recognizer): + """Test analysis when requested entities don't match supported entities.""" + text = "Some text here" + + # Request unsupported entity type + results = mock_recognizer.analyze(text, entities=["CREDIT_CARD"]) + assert len(results) == 0 + + def test_when_entities_requested_then_filters_results(self, mock_recognizer): + """Test that analyze filters results based on requested entities.""" + from presidio_analyzer import RecognizerResult + + text = "Contact John Doe at john@example.com" + + # Create RecognizerResult objects (what _call_llm returns) + person_result = RecognizerResult( + entity_type="PERSON", + start=8, + end=16, + score=0.95 + ) + + email_result = RecognizerResult( + entity_type="EMAIL_ADDRESS", + start=20, + end=36, + score=0.95 + ) + + with 
patch.object(mock_recognizer, '_call_llm', return_value=[person_result, email_result]): + # Request only PERSON entities - EMAIL_ADDRESS should be filtered out by analyze() + results = mock_recognizer.analyze(text, entities=["PERSON"]) + + # Should only return PERSON, EMAIL_ADDRESS filtered by analyze() method + assert len(results) == 1 + assert results[0].entity_type == "PERSON" + assert results[0].start == 8 + assert results[0].end == 16 + + def test_when_min_score_set_then_filters_low_confidence_results(self, mock_recognizer): + """Test that results below min_score are filtered out.""" + # Set min_score to 0.5 (default in config) + text = "Some text" + + mock_extraction = Mock() + mock_extraction.extraction_class = "person" + mock_extraction.extraction_text = "Some text" + mock_extraction.char_interval = Mock(start_pos=0, end_pos=9) + mock_extraction.alignment_status = "NOT_ALIGNED" # Score 0.60 + mock_extraction.attributes = {} + + mock_result = Mock() + mock_result.extractions = [mock_extraction] + + with patch('langextract.extract', return_value=mock_result): + results = mock_recognizer.analyze(text) + + # NOT_ALIGNED has score 0.60, which is above min_score 0.5 + assert len(results) == 1 + + def test_when_langextract_raises_exception_then_exception_propagates(self, mock_recognizer): + """Test that exceptions from LangExtract propagate to caller.""" + text = "Some text" + + with patch('langextract.extract', side_effect=Exception("LangExtract error")): + with pytest.raises(Exception, match="LangExtract error"): + mock_recognizer.analyze(text) + + def test_when_entity_has_no_mapping_and_consolidation_enabled_then_creates_generic( + self, mock_recognizer + ): + """Test that extractions with unknown entity classes become GENERIC_PII_ENTITY.""" + text = "Some text" + + mock_extraction = Mock() + mock_extraction.extraction_class = "unknown_type" # Not in entity_mappings + mock_extraction.extraction_text = "Some text" + mock_extraction.char_interval = 
Mock(start_pos=0, end_pos=9) + mock_extraction.alignment_status = "MATCH_EXACT" + mock_extraction.attributes = {} + + mock_result = Mock() + mock_result.extractions = [mock_extraction] + + with patch('langextract.extract', return_value=mock_result): + results = mock_recognizer.analyze(text) + + # Unknown entity type should be consolidated to GENERIC_PII_ENTITY + assert len(results) == 1 + assert results[0].entity_type == "GENERIC_PII_ENTITY" + assert results[0].recognition_metadata["original_entity_type"] == "UNKNOWN_TYPE" + assert results[0].start == 0 + assert results[0].end == 9 + + def test_when_entity_has_no_mapping_and_consolidation_disabled_then_skips(self, tmp_path): + """Test that unknown entities are skipped when consolidation is disabled.""" + import yaml + + # Create config with consolidation disabled + config = create_test_config( + supported_entities=["PERSON", "EMAIL_ADDRESS"], + entity_mappings={"person": "PERSON", "email": "EMAIL_ADDRESS"}, + enable_generic_consolidation=False + ) + + config_file = tmp_path / "test_config.yaml" + with open(config_file, 'w') as f: + yaml.dump(config, f) + + with patch('presidio_analyzer.llm_utils.langextract_helper.lx', + return_value=Mock()): + + from presidio_analyzer.predefined_recognizers.third_party.ollama_langextract_recognizer import OllamaLangExtractRecognizer + recognizer = OllamaLangExtractRecognizer(config_path=str(config_file)) + + text = "Some text" + + mock_extraction = Mock() + mock_extraction.extraction_class = "unknown_type" # Not in entity_mappings + mock_extraction.extraction_text = "Some text" + mock_extraction.char_interval = Mock(start_pos=0, end_pos=9) + mock_extraction.alignment_status = "MATCH_EXACT" + mock_extraction.attributes = {} + + mock_result = Mock() + mock_result.extractions = [mock_extraction] + + with patch('langextract.extract', return_value=mock_result): + results = recognizer.analyze(text) + + # Unknown entity type should be skipped when consolidation is disabled + assert 
# ---------------------------------------------------------------------------
# presidio-analyzer/tests/test_prompt_loader.py (new file in this patch)
# ---------------------------------------------------------------------------
"""Tests for llm_utils.prompt_loader module."""
import pytest
from pathlib import Path
from presidio_analyzer.llm_utils.prompt_loader import (
    load_prompt_file,
    render_jinja_template,
)


class TestLoadPromptFile:
    """Tests for load_prompt_file function."""

    def test_when_prompt_file_exists_then_loads_content(self):
        """The shipped LangExtract prompt template loads from the conf directory."""
        # Repo-root-relative path, matching how the analyzer resolves it.
        content = load_prompt_file(
            "presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_prompt.j2"
        )

        assert content is not None
        assert len(content) > 0
        assert "ENTITY TYPES TO EXTRACT" in content
        assert "Extract personally identifiable information" in content

    def test_when_prompt_file_missing_then_raises_file_not_found_error(self):
        """A nonexistent template path raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError, match="File not found"):
            load_prompt_file("nonexistent_prompt.jinja2")


class TestRenderJinjaTemplate:
    """Tests for render_jinja_template function."""

    def test_when_template_has_variables_then_renders_correctly(self):
        """Simple {{ var }} substitution produces the expected string."""
        source = "Hello {{ name }}, your score is {{ score }}"
        rendered = render_jinja_template(source, name="Alice", score=95)

        assert rendered == "Hello Alice, your score is 95"

    def test_when_template_has_list_iteration_then_renders_correctly(self):
        """A {% for %} loop emits one bullet per list element."""
        source = """Entities:
{% for entity in entities %}
- {{ entity }}
{% endfor %}"""
        rendered = render_jinja_template(source, entities=["PERSON", "EMAIL", "PHONE"])

        assert "- PERSON" in rendered
        assert "- EMAIL" in rendered
        assert "- PHONE" in rendered

    def test_when_template_has_conditionals_then_renders_correctly(self):
        """{% if %}/{% else %} selects the branch matching the flag."""
        source = """{% if enabled %}Feature is ON{% else %}Feature is OFF{% endif %}"""

        on_text = render_jinja_template(source, enabled=True)
        assert on_text.strip() == "Feature is ON"

        off_text = render_jinja_template(source, enabled=False)
        assert off_text.strip() == "Feature is OFF"

    def test_when_template_has_no_variables_then_returns_original(self):
        """A template with no expressions renders unchanged."""
        source = "This is a static prompt with no variables"
        rendered = render_jinja_template(source)

        assert rendered == source

    def test_when_template_has_complex_expressions_then_renders_correctly(self):
        """Filters, loop.index and attribute access all render correctly."""
        source = """Total: {{ items | length }}
{% for item in items %}
{{ loop.index }}. {{ item.name }}: {{ item.value }}
{% endfor %}"""
        data = [
            {"name": "Item1", "value": 10},
            {"name": "Item2", "value": 20}
        ]
        rendered = render_jinja_template(source, items=data)

        assert "Total: 2" in rendered
        assert "1. Item1: 10" in rendered
        assert "2. Item2: 20" in rendered

    def test_when_template_has_filters_then_applies_correctly(self):
        """Built-in Jinja2 filters (upper, int) are applied."""
        source = "{{ text | upper }} and {{ number | int }}"
        rendered = render_jinja_template(source, text="hello", number="42")

        assert "HELLO" in rendered
        assert "42" in rendered

    def test_when_missing_required_variable_then_renders_empty(self):
        """Undefined variables render as empty strings (Jinja2 default Undefined)."""
        source = "Hello {{ undefined_var }}"
        rendered = render_jinja_template(source)

        # Jinja2 renders undefined variables as empty strings by default
        assert rendered == "Hello "

    def test_when_template_has_whitespace_control_then_handles_correctly(self):
        """{%- ... -%} trims surrounding whitespace without dropping content."""
        source = """ENTITIES:
{%- for entity in entities %}
- {{ entity }}
{%- endfor %}
END"""
        rendered = render_jinja_template(source, entities=["PERSON", "EMAIL"])

        assert "ENTITIES:" in rendered
        assert "- PERSON" in rendered
        assert "- EMAIL" in rendered
        assert "END" in rendered

    def test_when_rendering_actual_langextract_template_then_works(self):
        """The shipped prompt template renders with typical parameters."""
        # Load the actual template using repo-root-relative path
        source = load_prompt_file(
            "presidio-analyzer/presidio_analyzer/conf/langextract_prompts/default_pii_phi_prompt.j2"
        )

        rendered = render_jinja_template(
            source,
            supported_entities=["PERSON", "EMAIL_ADDRESS"],
            enable_generic_consolidation=True,
            labels_to_ignore=["metadata"]
        )

        assert "PERSON" in rendered
        assert "EMAIL_ADDRESS" in rendered
        # Non-empty output is the minimal sanity check on the full template
        assert len(rendered) > 0