diff --git a/docs/ahds_integration.md b/docs/ahds_integration.md index f6390a8f29..7e2eea3965 100644 --- a/docs/ahds_integration.md +++ b/docs/ahds_integration.md @@ -96,15 +96,72 @@ print(f"Anonymized: {result.text}") ## Authentication -The AHDS de-identification service integration uses Azure's `DefaultAzureCredential`, which supports multiple authentication methods: +The AHDS de-identification service integration uses Azure authentication with a secure-by-default approach: + +### Production Mode (Default) + +By default, the integration uses a restricted credential chain that only includes: + +1. **EnvironmentCredential**: Service Principal via environment variables +2. **WorkloadIdentityCredential**: Workload Identity (Kubernetes) +3. **ManagedIdentityCredential**: Managed Identity (Azure services) + +This credential chain is more secure as it excludes interactive browser logins and developer credentials, making it suitable for production deployments. + +### Development Mode + +For local development, you can enable `DefaultAzureCredential` by setting the environment variable: + +```bash +export ENV=development +``` + +In development mode, the integration uses `DefaultAzureCredential`, which supports additional authentication methods: 1. Environment variables (Service Principal) -2. Managed Identity (when running on Azure) -3. Azure CLI (`az login`) -4. Visual Studio/VS Code credentials -5. Interactive browser login +2. Workload Identity +3. Managed Identity +4. Azure CLI (`az login`) +5. Azure PowerShell +6. Visual Studio/VS Code credentials +7. Interactive browser login + +### Local Development Setup + +For local development with Azure CLI authentication: + +```bash +# Set development mode +export ENV=development + +# Login with Azure CLI +az login + +# Set your AHDS endpoint +export AHDS_ENDPOINT="https://your-ahds-endpoint.api.eus001.deid.azure.com" + +# Now you can run Presidio with AHDS integration +python your_script.py +``` + +### Production Deployment Recommendations + +For production deployments, we recommend: + +1. **Do not set ENV=development** (use default production mode) +2. **Use Managed Identity** when running on Azure services (AKS, App Service, Functions, etc.) +3. **Use Workload Identity** for Kubernetes deployments +4. **Use Service Principal** with environment variables for other scenarios: + ```bash + export AZURE_CLIENT_ID="" + export AZURE_CLIENT_SECRET="" + export AZURE_TENANT_ID="" + ``` + +### Environment Variables Summary -For production deployments, we recommend using Service Principal or Managed Identity. +- `AHDS_ENDPOINT`: Your AHDS de-identification service endpoint (required) +- `ENV`: Set to `development` for local dev, omit for production (optional, default: production mode) ## Troubleshooting diff --git a/docs/samples/python/ahds/index.md b/docs/samples/python/ahds/index.md index e8ee3a7caf..ebadde0932 100644 --- a/docs/samples/python/ahds/index.md +++ b/docs/samples/python/ahds/index.md @@ -19,11 +19,25 @@ and health (PHI) information. A list of all supported entities can be found in t [official documentation](https://learn.microsoft.com/en-us/azure/healthcare-apis/deidentification/overview). ## Prerequisites -To use AHDS De-Identification with Preisido, an Azure De-Identification Service resource should +To use AHDS De-Identification with Presidio, an Azure De-Identification Service resource should first be created under an Azure subscription. Follow the [official documentation](https://learn.microsoft.com/en-us/azure/healthcare-apis/deidentification/quickstart) for instructions. The endpoint, generated once the resource is created, will be used when integrating with AHDS De-Identification, using a Presidio remote recognizer. +### Authentication Setup + +The integration uses a secure-by-default authentication approach: + +**Production Mode (Default)**: Uses a restricted credential chain (EnvironmentCredential, WorkloadIdentityCredential, ManagedIdentityCredential) + +**Development Mode**: Set `ENV=development` to use DefaultAzureCredential for local development with Azure CLI: +```bash +export ENV=development +az login +``` + +For more details, see the [AHDS Integration Authentication documentation](../../../ahds_integration.md#authentication). + ## Azure Health Data Services de-identification Recognizer [The implementation of a `AzureHealthDeid` recognizer can be found here](https://github.com/microsoft/presidio/blob/main/presidio-analyzer/presidio_analyzer/predefined_recognizers/ahds_recognizer.py). diff --git a/presidio-analyzer/presidio_analyzer/predefined_recognizers/third_party/ahds_recognizer.py b/presidio-analyzer/presidio_analyzer/predefined_recognizers/third_party/ahds_recognizer.py index 19961445b2..277036aceb 100644 --- a/presidio-analyzer/presidio_analyzer/predefined_recognizers/third_party/ahds_recognizer.py +++ b/presidio-analyzer/presidio_analyzer/predefined_recognizers/third_party/ahds_recognizer.py @@ -8,14 +8,23 @@ DeidentificationOperationType, PhiCategory, ) - from azure.identity import DefaultAzureCredential, ManagedIdentityCredential + from azure.identity import ( + ChainedTokenCredential, + DefaultAzureCredential, + EnvironmentCredential, + ManagedIdentityCredential, + WorkloadIdentityCredential, + ) except ImportError: DeidentificationClient = None DeidentificationContent = None DeidentificationOperationType = None PhiCategory = None + ChainedTokenCredential = None DefaultAzureCredential = None + EnvironmentCredential = None ManagedIdentityCredential = None + WorkloadIdentityCredential = None from presidio_analyzer import AnalysisExplanation, RecognizerResult, RemoteRecognizer from presidio_analyzer.nlp_engine import NlpArtifacts @@ -64,11 +73,16 @@ def __init__( "Please install azure-health-deidentification and azure-identity." ) - credential = None - if os.getenv('ENV') == 'production': - credential = ManagedIdentityCredential() + # Use ChainedTokenCredential for production (secure by default) + # Only use DefaultAzureCredential in development mode + if os.getenv('ENV') == 'development': + credential = DefaultAzureCredential() # CodeQL [SM05139] OK for dev else: - credential = DefaultAzureCredential() + credential = ChainedTokenCredential( + EnvironmentCredential(), + WorkloadIdentityCredential(), + ManagedIdentityCredential() + ) client = DeidentificationClient(endpoint, credential) self.deid_client = client diff --git a/presidio-analyzer/tests/test_ahds_recognizer_credential_selection.py b/presidio-analyzer/tests/test_ahds_recognizer_credential_selection.py index 4bfc655b1d..26bce9c0f9 100644 --- a/presidio-analyzer/tests/test_ahds_recognizer_credential_selection.py +++ b/presidio-analyzer/tests/test_ahds_recognizer_credential_selection.py @@ -17,6 +17,9 @@ def mock_azure_modules(): # Mock the classes and enums we need mock_deid_client = MagicMock() mock_default_cred = MagicMock() + mock_chained_cred = MagicMock() + mock_env_cred = MagicMock() + mock_workload_cred = MagicMock() mock_managed_cred = MagicMock() # Mock PhiCategory enum for _get_supported_entities @@ -25,11 +28,17 @@ def mock_azure_modules(): with patch('presidio_analyzer.predefined_recognizers.third_party.ahds_recognizer.DeidentificationClient', mock_deid_client), \ patch('presidio_analyzer.predefined_recognizers.third_party.ahds_recognizer.DefaultAzureCredential', mock_default_cred), \ + patch('presidio_analyzer.predefined_recognizers.third_party.ahds_recognizer.ChainedTokenCredential', mock_chained_cred), \ + patch('presidio_analyzer.predefined_recognizers.third_party.ahds_recognizer.EnvironmentCredential', mock_env_cred), \ + patch('presidio_analyzer.predefined_recognizers.third_party.ahds_recognizer.WorkloadIdentityCredential', mock_workload_cred), \ patch('presidio_analyzer.predefined_recognizers.third_party.ahds_recognizer.ManagedIdentityCredential', mock_managed_cred), \ patch('presidio_analyzer.predefined_recognizers.third_party.ahds_recognizer.PhiCategory', mock_phi_category): yield { 'DeidentificationClient': mock_deid_client, 'DefaultAzureCredential': mock_default_cred, + 'ChainedTokenCredential': mock_chained_cred, + 'EnvironmentCredential': mock_env_cred, + 'WorkloadIdentityCredential': mock_workload_cred, 'ManagedIdentityCredential': mock_managed_cred, 'PhiCategory': mock_phi_category, } @@ -39,7 +48,7 @@ class TestAHDSRecognizerCredentialSelection: """Test credential selection based on environment variables.""" def test_uses_default_credential_in_development_environment(self, mock_azure_modules): - """Test that DefaultAzureCredential is used when ENV is not production.""" + """Test that DefaultAzureCredential is used when ENV=development.""" with patch.dict(os.environ, {'ENV': 'development', 'AHDS_ENDPOINT': 'https://test.endpoint.com'}): mock_client_instance = MagicMock() mock_azure_modules['DeidentificationClient'].return_value = mock_client_instance @@ -48,32 +57,36 @@ def test_uses_default_credential_in_development_environment(self, mock_azure_mod # Verify DefaultAzureCredential was called mock_azure_modules['DefaultAzureCredential'].assert_called_once() - mock_azure_modules['ManagedIdentityCredential'].assert_not_called() + mock_azure_modules['ChainedTokenCredential'].assert_not_called() # Verify DeidentificationClient was initialized with DefaultAzureCredential instance mock_azure_modules['DeidentificationClient'].assert_called_once() call_args = mock_azure_modules['DeidentificationClient'].call_args assert call_args[0][1] == mock_azure_modules['DefaultAzureCredential'].return_value - def test_uses_managed_identity_in_production_environment(self, mock_azure_modules): - """Test that ManagedIdentityCredential is used when ENV=production.""" + def test_uses_chained_credential_in_production_environment(self, mock_azure_modules): + """Test that ChainedTokenCredential is used when ENV=production.""" with patch.dict(os.environ, {'ENV': 'production', 'AHDS_ENDPOINT': 'https://test.endpoint.com'}): mock_client_instance = MagicMock() mock_azure_modules['DeidentificationClient'].return_value = mock_client_instance recognizer = AzureHealthDeidRecognizer() - # Verify ManagedIdentityCredential was called + # Verify ChainedTokenCredential was called with the right credentials + mock_azure_modules['EnvironmentCredential'].assert_called_once() + mock_azure_modules['WorkloadIdentityCredential'].assert_called_once() mock_azure_modules['ManagedIdentityCredential'].assert_called_once() + mock_azure_modules['ChainedTokenCredential'].assert_called_once() mock_azure_modules['DefaultAzureCredential'].assert_not_called() - # Verify DeidentificationClient was initialized with ManagedIdentityCredential instance - mock_azure_modules['DeidentificationClient'].assert_called_once() - call_args = mock_azure_modules['DeidentificationClient'].call_args - assert call_args[0][1] == mock_azure_modules['ManagedIdentityCredential'].return_value + # Verify ChainedTokenCredential was called with correct order + call_args = mock_azure_modules['ChainedTokenCredential'].call_args[0] + assert call_args[0] == mock_azure_modules['EnvironmentCredential'].return_value + assert call_args[1] == mock_azure_modules['WorkloadIdentityCredential'].return_value + assert call_args[2] == mock_azure_modules['ManagedIdentityCredential'].return_value - def test_uses_default_credential_when_env_var_not_set(self, mock_azure_modules): - """Test that DefaultAzureCredential is used when ENV is not set.""" + def test_uses_chained_credential_when_env_var_not_set(self, mock_azure_modules): + """Test that ChainedTokenCredential is used when ENV is not set (default).""" # Ensure ENV is not set env_without_presidio = {k: v for k, v in os.environ.items() if k != 'ENV'} env_without_presidio['AHDS_ENDPOINT'] = 'https://test.endpoint.com' @@ -84,25 +97,13 @@ def test_uses_default_credential_when_env_var_not_set(self, mock_azure_modules): recognizer = AzureHealthDeidRecognizer() - # Verify DefaultAzureCredential was called - mock_azure_modules['DefaultAzureCredential'].assert_called_once() - mock_azure_modules['ManagedIdentityCredential'].assert_not_called() - - def test_uses_managed_identity_only_for_production_value(self, mock_azure_modules): - """Test that ManagedIdentityCredential is used only when ENV='production'.""" - with patch.dict(os.environ, {'ENV': 'production', 'AHDS_ENDPOINT': 'https://test.endpoint.com'}): - mock_client_instance = MagicMock() - mock_azure_modules['DeidentificationClient'].return_value = mock_client_instance - - recognizer = AzureHealthDeidRecognizer() - - # Verify ManagedIdentityCredential was called - mock_azure_modules['ManagedIdentityCredential'].assert_called_once() + # Verify ChainedTokenCredential was called (secure by default) + mock_azure_modules['ChainedTokenCredential'].assert_called_once() mock_azure_modules['DefaultAzureCredential'].assert_not_called() - def test_uses_default_credential_for_non_production_environment_values(self, mock_azure_modules): - """Test that DefaultAzureCredential is used for any ENV value other than 'production'.""" - test_environments = ['dev', 'development', 'staging', 'test', 'local', 'PRODUCTION'] + def test_uses_chained_credential_for_non_development_environment_values(self, mock_azure_modules): + """Test that ChainedTokenCredential is used for any ENV value other than 'development'.""" + test_environments = ['prod', 'production', 'staging', 'test', 'local', 'DEVELOPMENT'] for env_value in test_environments: with patch.dict(os.environ, {'ENV': env_value, 'AHDS_ENDPOINT': 'https://test.endpoint.com'}): @@ -111,14 +112,17 @@ def test_uses_default_credential_for_non_production_environment_values(self, moc # Reset mocks mock_azure_modules['DefaultAzureCredential'].reset_mock() + mock_azure_modules['ChainedTokenCredential'].reset_mock() + mock_azure_modules['EnvironmentCredential'].reset_mock() + mock_azure_modules['WorkloadIdentityCredential'].reset_mock() mock_azure_modules['ManagedIdentityCredential'].reset_mock() mock_azure_modules['DeidentificationClient'].reset_mock() recognizer = AzureHealthDeidRecognizer() - # Verify DefaultAzureCredential was called for this environment - mock_azure_modules['DefaultAzureCredential'].assert_called_once(), f"Failed for environment: {env_value}" - mock_azure_modules['ManagedIdentityCredential'].assert_not_called(), f"ManagedIdentityCredential should not be called for environment: {env_value}" + # Verify ChainedTokenCredential was called for this environment + mock_azure_modules['ChainedTokenCredential'].assert_called_once(), f"Failed for environment: {env_value}" + mock_azure_modules['DefaultAzureCredential'].assert_not_called(), f"DefaultAzureCredential should not be called for environment: {env_value}" def test_respects_provided_client_parameter(self, mock_azure_modules): """Test that when a client is provided, no credential creation occurs.""" diff --git a/presidio-anonymizer/presidio_anonymizer/operators/ahds_surrogate.py b/presidio-anonymizer/presidio_anonymizer/operators/ahds_surrogate.py index f21588fe44..2cee625d6a 100644 --- a/presidio-anonymizer/presidio_anonymizer/operators/ahds_surrogate.py +++ b/presidio-anonymizer/presidio_anonymizer/operators/ahds_surrogate.py @@ -16,14 +16,23 @@ TaggedPhiEntities, TextEncodingType, ) - from azure.identity import DefaultAzureCredential, ManagedIdentityCredential + from azure.identity import ( + ChainedTokenCredential, + DefaultAzureCredential, + EnvironmentCredential, + ManagedIdentityCredential, + WorkloadIdentityCredential, + ) except ImportError: DeidentificationClient = None DeidentificationContent = None DeidentificationCustomizationOptions = None DeidentificationResult = None + ChainedTokenCredential = None DefaultAzureCredential = None + EnvironmentCredential = None ManagedIdentityCredential = None + WorkloadIdentityCredential = None SimplePhiEntity = None TaggedPhiEntities = None PhiCategory = None @@ -240,11 +249,16 @@ def operate(self, text: str = None, params: Dict = None) -> str: # Convert analyzer results to AHDS tagged entities tagged_entities = self._convert_to_tagged_entities(entities) - credential = None - if os.getenv('ENV') == 'production': - credential = ManagedIdentityCredential() + # Use ChainedTokenCredential for production + # Only use DefaultAzureCredential in development mode + if os.getenv('ENV') == 'development': + credential = DefaultAzureCredential() # CodeQL [SM05139] OK for dev else: - credential = DefaultAzureCredential() + credential = ChainedTokenCredential( + EnvironmentCredential(), + WorkloadIdentityCredential(), + ManagedIdentityCredential() + ) client = DeidentificationClient(endpoint, credential, api_version="2025-07-15-preview") diff --git a/presidio-anonymizer/tests/test_ahds_surrogate.py b/presidio-anonymizer/tests/test_ahds_surrogate.py index 483d49a909..eab8c5c3fd 100644 --- a/presidio-anonymizer/tests/test_ahds_surrogate.py +++ b/presidio-anonymizer/tests/test_ahds_surrogate.py @@ -2,7 +2,7 @@ import os import importlib -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest import dotenv @@ -310,8 +310,14 @@ def test_service_error_handling(import_modules): {'entity_type': 'PERSON', 'start': 0, 'end': 8, 'text': 'John Doe', 'score': 0.9} ] } - with pytest.raises(InvalidParamError): - operator.operate("John Doe is a patient", params) + # Mock Azure env vars so credentials can initialize + with patch.dict(os.environ, { + 'AZURE_TENANT_ID': 'test-tenant-id', + 'AZURE_CLIENT_ID': 'test-client-id', + 'AZURE_FEDERATED_TOKEN_FILE': '/tmp/test-token-file' + }): + with pytest.raises(InvalidParamError): + operator.operate("John Doe is a patient", params) def test_operator_type(import_modules): operator = AHDSSurrogate() diff --git a/presidio-anonymizer/tests/test_ahds_surrogate_credential_selection.py b/presidio-anonymizer/tests/test_ahds_surrogate_credential_selection.py index 9b7b36bc15..b92e1e01b1 100644 --- a/presidio-anonymizer/tests/test_ahds_surrogate_credential_selection.py +++ b/presidio-anonymizer/tests/test_ahds_surrogate_credential_selection.py @@ -18,6 +18,9 @@ def mock_azure_modules(): # Mock the classes and enums we need mock_deid_client = MagicMock() mock_default_cred = MagicMock() + mock_chained_cred = MagicMock() + mock_env_cred = MagicMock() + mock_workload_cred = MagicMock() mock_managed_cred = MagicMock() # Mock TextEncodingType enum @@ -35,6 +38,9 @@ def mock_azure_modules(): with patch('presidio_anonymizer.operators.ahds_surrogate.DeidentificationClient', mock_deid_client), \ patch('presidio_anonymizer.operators.ahds_surrogate.DefaultAzureCredential', mock_default_cred), \ + patch('presidio_anonymizer.operators.ahds_surrogate.ChainedTokenCredential', mock_chained_cred), \ + patch('presidio_anonymizer.operators.ahds_surrogate.EnvironmentCredential', mock_env_cred), \ + patch('presidio_anonymizer.operators.ahds_surrogate.WorkloadIdentityCredential', mock_workload_cred), \ patch('presidio_anonymizer.operators.ahds_surrogate.ManagedIdentityCredential', mock_managed_cred), \ patch('presidio_anonymizer.operators.ahds_surrogate.TextEncodingType', mock_text_encoding_type), \ patch('presidio_anonymizer.operators.ahds_surrogate.DeidentificationOperationType', mock_operation_type), \ @@ -44,6 +50,9 @@ def mock_azure_modules(): yield { 'DeidentificationClient': mock_deid_client, 'DefaultAzureCredential': mock_default_cred, + 'ChainedTokenCredential': mock_chained_cred, + 'EnvironmentCredential': mock_env_cred, + 'WorkloadIdentityCredential': mock_workload_cred, 'ManagedIdentityCredential': mock_managed_cred, 'TextEncodingType': mock_text_encoding_type, 'DeidentificationOperationType': mock_operation_type, @@ -57,7 +66,7 @@ class TestAHDSCredentialSelection: """Test credential selection based on environment variables.""" def test_uses_default_credential_in_development_environment(self, mock_azure_modules): - """Test that DefaultAzureCredential is used when ENV is not production.""" + """Test that DefaultAzureCredential is used when ENV=development.""" operator = AHDSSurrogate() with patch.dict(os.environ, {'ENV': 'development', 'AHDS_ENDPOINT': 'https://test.endpoint.com'}): @@ -72,15 +81,15 @@ def test_uses_default_credential_in_development_environment(self, mock_azure_mod # Verify DefaultAzureCredential was called mock_azure_modules['DefaultAzureCredential'].assert_called_once() - mock_azure_modules['ManagedIdentityCredential'].assert_not_called() + mock_azure_modules['ChainedTokenCredential'].assert_not_called() # Verify DeidentificationClient was initialized with DefaultAzureCredential instance mock_azure_modules['DeidentificationClient'].assert_called_once() call_args = mock_azure_modules['DeidentificationClient'].call_args assert call_args[0][1] == mock_azure_modules['DefaultAzureCredential'].return_value - def test_uses_managed_identity_in_production_environment(self, mock_azure_modules): - """Test that ManagedIdentityCredential is used when ENV=production.""" + def test_uses_chained_credential_in_production_environment(self, mock_azure_modules): + """Test that ChainedTokenCredential is used when ENV=production.""" operator = AHDSSurrogate() with patch.dict(os.environ, {'ENV': 'production', 'AHDS_ENDPOINT': 'https://test.endpoint.com'}): @@ -93,17 +102,21 @@ def test_uses_managed_identity_in_production_environment(self, mock_azure_module result = operator.operate("test text", {"entities": []}) - # Verify ManagedIdentityCredential was called + # Verify ChainedTokenCredential was called with the right credentials + mock_azure_modules['EnvironmentCredential'].assert_called_once() + mock_azure_modules['WorkloadIdentityCredential'].assert_called_once() mock_azure_modules['ManagedIdentityCredential'].assert_called_once() + mock_azure_modules['ChainedTokenCredential'].assert_called_once() mock_azure_modules['DefaultAzureCredential'].assert_not_called() - # Verify DeidentificationClient was initialized with ManagedIdentityCredential instance - mock_azure_modules['DeidentificationClient'].assert_called_once() - call_args = mock_azure_modules['DeidentificationClient'].call_args - assert call_args[0][1] == mock_azure_modules['ManagedIdentityCredential'].return_value + # Verify ChainedTokenCredential was called with correct order + call_args = mock_azure_modules['ChainedTokenCredential'].call_args[0] + assert call_args[0] == mock_azure_modules['EnvironmentCredential'].return_value + assert call_args[1] == mock_azure_modules['WorkloadIdentityCredential'].return_value + assert call_args[2] == mock_azure_modules['ManagedIdentityCredential'].return_value - def test_uses_default_credential_when_env_var_not_set(self, mock_azure_modules): - """Test that DefaultAzureCredential is used when ENV is not set.""" + def test_uses_chained_credential_when_env_var_not_set(self, mock_azure_modules): + """Test that ChainedTokenCredential is used when ENV is not set (default).""" operator = AHDSSurrogate() # Ensure ENV is not set @@ -120,33 +133,15 @@ def test_uses_default_credential_when_env_var_not_set(self, mock_azure_modules): result = operator.operate("test text", {"entities": []}) - # Verify DefaultAzureCredential was called - mock_azure_modules['DefaultAzureCredential'].assert_called_once() - mock_azure_modules['ManagedIdentityCredential'].assert_not_called() - - def test_uses_managed_identity_only_for_production_value(self, mock_azure_modules): - """Test that ManagedIdentityCredential is used only when ENV='production'.""" - operator = AHDSSurrogate() - - with patch.dict(os.environ, {'ENV': 'production', 'AHDS_ENDPOINT': 'https://test.endpoint.com'}): - with patch.object(operator, '_convert_to_tagged_entities', return_value=[]): - mock_client_instance = MagicMock() - mock_result = MagicMock() - mock_result.output_text = "anonymized text" - mock_client_instance.deidentify_text.return_value = mock_client_instance - mock_azure_modules['DeidentificationClient'].return_value = mock_client_instance - - result = operator.operate("test text", {"entities": []}) - - # Verify ManagedIdentityCredential was called - mock_azure_modules['ManagedIdentityCredential'].assert_called_once() + # Verify ChainedTokenCredential was called (secure by default) + mock_azure_modules['ChainedTokenCredential'].assert_called_once() mock_azure_modules['DefaultAzureCredential'].assert_not_called() - def test_uses_default_credential_for_non_production_environment_values(self, mock_azure_modules): - """Test that DefaultAzureCredential is used for any ENV value other than 'production'.""" + def test_uses_chained_credential_for_non_development_environment_values(self, mock_azure_modules): + """Test that ChainedTokenCredential is used for any ENV value other than 'development'.""" operator = AHDSSurrogate() - test_environments = ['dev', 'development', 'staging', 'test', 'local', 'PRODUCTION'] + test_environments = ['prod', 'production', 'staging', 'test', 'local', 'DEVELOPMENT'] for env_value in test_environments: with patch.dict(os.environ, {'ENV': env_value, 'AHDS_ENDPOINT': 'https://test.endpoint.com'}): @@ -159,11 +154,14 @@ def test_uses_default_credential_for_non_production_environment_values(self, moc # Reset mocks mock_azure_modules['DefaultAzureCredential'].reset_mock() + mock_azure_modules['ChainedTokenCredential'].reset_mock() + mock_azure_modules['EnvironmentCredential'].reset_mock() + mock_azure_modules['WorkloadIdentityCredential'].reset_mock() mock_azure_modules['ManagedIdentityCredential'].reset_mock() mock_azure_modules['DeidentificationClient'].reset_mock() result = operator.operate("test text", {"entities": []}) - # Verify DefaultAzureCredential was called for this environment - mock_azure_modules['DefaultAzureCredential'].assert_called_once(), f"Failed for environment: {env_value}" - mock_azure_modules['ManagedIdentityCredential'].assert_not_called(), f"ManagedIdentityCredential should not be called for environment: {env_value}" \ No newline at end of file + # Verify ChainedTokenCredential was called for this environment + mock_azure_modules['ChainedTokenCredential'].assert_called_once(), f"Failed for environment: {env_value}" + mock_azure_modules['DefaultAzureCredential'].assert_not_called(), f"DefaultAzureCredential should not be called for environment: {env_value}" \ No newline at end of file