diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/client.py b/ingestion/src/metadata/ingestion/source/database/databricks/client.py index 232c3bc7a332..8577b477d991 100644 --- a/ingestion/src/metadata/ingestion/source/database/databricks/client.py +++ b/ingestion/src/metadata/ingestion/source/database/databricks/client.py @@ -30,8 +30,8 @@ ) from metadata.ingestion.ometa.client import APIError from metadata.ingestion.source.database.databricks.queries import ( - DATABRICKS_GET_COLUMN_LINEAGE_FOR_JOB, - DATABRICKS_GET_TABLE_LINEAGE_FOR_JOB, + DATABRICKS_GET_COLUMN_LINEAGE, + DATABRICKS_GET_TABLE_LINEAGE, ) from metadata.utils.constants import QUERY_WITH_DBT, QUERY_WITH_OM_VERSION from metadata.utils.helpers import datetime_to_ts @@ -73,10 +73,10 @@ def __init__( "Content-Type": "application/json", } self.api_timeout = self.config.connectionTimeout or 120 - self._job_table_lineage_executed: bool = False - self.job_table_lineage: dict[str, list[dict[str, str]]] = defaultdict(list) - self._job_column_lineage_executed: bool = False - self.job_column_lineage: dict[ + self._entity_table_lineage_executed: bool = False + self.entity_table_lineage: dict[str, list[dict[str, str]]] = defaultdict(list) + self._entity_column_lineage_executed: bool = False + self.entity_column_lineage: dict[ str, dict[Tuple[str, str], list[Tuple[str, str]]] ] = defaultdict(lambda: defaultdict(list)) self.engine = engine @@ -101,15 +101,13 @@ def test_lineage_query(self) -> None: with self.engine.connect() as connection: test_table_lineage = connection.execute( text( - DATABRICKS_GET_TABLE_LINEAGE_FOR_JOB.format( - lookback_days=lookback_days - ) + DATABRICKS_GET_TABLE_LINEAGE.format(lookback_days=lookback_days) + " LIMIT 1" ) ) test_column_lineage = connection.execute( text( - DATABRICKS_GET_COLUMN_LINEAGE_FOR_JOB.format( + DATABRICKS_GET_COLUMN_LINEAGE.format( lookback_days=lookback_days ) + " LIMIT 1" @@ -122,8 +120,8 @@ def test_lineage_query(self) -> None: except Exception as exc: logger.debug(f"Error testing lineage queries: {traceback.format_exc()}") raise DatabricksClientException( - f"Failed to test lineage queries Make sure you have access" - "to the tables table_lineage and column_lineage: {exc}" + f"Failed to test lineage queries. Make sure you have access " + f"to the tables table_lineage and column_lineage: {exc}" ) def _run_query_paginator(self, data, result, end_time, response): @@ -280,40 +278,41 @@ def get_job_runs(self, job_id) -> List[dict]: logger.debug(traceback.format_exc()) logger.error(exc) - def get_table_lineage(self, job_id: str) -> List[dict[str, str]]: + def get_table_lineage(self, entity_id: str) -> List[dict[str, str]]: """ - Method returns table lineage for a job by the specified job_id. - On first call, eagerly fetches ALL job lineage in bulk for optimal performance. + Method returns table lineage for a job or pipeline by the specified entity_id. + On first call, eagerly fetches ALL lineage in bulk for optimal performance. """ try: - if not self._job_table_lineage_executed: + if not self._entity_table_lineage_executed: logger.info( - "First lineage request detected - performing bulk lineage fetch for all jobs" + "First lineage request detected - performing bulk lineage fetch for all entities" ) self.cache_lineage() - # Return cached lineage for this specific job - return self.job_table_lineage.get(str(job_id), []) + return self.entity_table_lineage.get(str(entity_id), []) except Exception as exc: logger.debug( - f"Error getting table lineage for job {job_id} due to {traceback.format_exc()}" + f"Error getting table lineage for {entity_id} due to {traceback.format_exc()}" ) logger.error(exc) return [] def get_column_lineage( - self, job_id: str, TableKey: Tuple[str, str] + self, entity_id: str, TableKey: Tuple[str, str] ) -> List[Tuple[str, str]]: """ - Method returns column lineage for a job by the specified job_id and table key + Method returns column lineage for a job or pipeline by the specified entity_id and table key """ try: - if not self._job_column_lineage_executed: - logger.debug("Job column lineage not found. Executing cache_lineage...") + if not self._entity_column_lineage_executed: + logger.debug( + "Entity column lineage not found. Executing cache_lineage..." + ) self.cache_lineage() - return self.job_column_lineage.get(str(job_id), {}).get(TableKey, []) + return self.entity_column_lineage.get(str(entity_id), {}).get(TableKey, []) except Exception as exc: logger.debug( @@ -338,16 +337,16 @@ def run_lineage_query(self, query: str) -> List[dict]: def cache_lineage(self): """ - Method caches table and column lineage for ALL jobs. + Method caches table and column lineage for ALL jobs and pipelines. """ lookback_days = getattr(self.config, "lineageLookBackDays", 90) logger.info(f"Caching table lineage (lookback: {lookback_days} days)") table_lineage = self.run_lineage_query( - DATABRICKS_GET_TABLE_LINEAGE_FOR_JOB.format(lookback_days=lookback_days) + DATABRICKS_GET_TABLE_LINEAGE.format(lookback_days=lookback_days) ) for row in table_lineage or []: try: - self.job_table_lineage[row.job_id].append( + self.entity_table_lineage[row.entity_id].append( { "source_table_full_name": row.source_table_full_name, "target_table_full_name": row.target_table_full_name, @@ -358,13 +357,11 @@ def cache_lineage(self): f"Error parsing row: {row} due to {traceback.format_exc()}" ) continue - self._job_table_lineage_executed = True + self._entity_table_lineage_executed = True - # Not every job has column lineage, so we need to check if the job exists in the column_lineage table - # we will cache the column lineage for jobs that have column lineage logger.info(f"Caching column lineage (lookback: {lookback_days} days)") column_lineage = self.run_lineage_query( - DATABRICKS_GET_COLUMN_LINEAGE_FOR_JOB.format(lookback_days=lookback_days) + DATABRICKS_GET_COLUMN_LINEAGE.format(lookback_days=lookback_days) ) for row in column_lineage or []: try: @@ -377,14 +374,14 @@ def cache_lineage(self): row.target_column_name, ) - self.job_column_lineage[row.job_id][table_key].append(column_pair) + self.entity_column_lineage[row.entity_id][table_key].append(column_pair) except Exception as exc: logger.debug( f"Error parsing row: {row} due to {traceback.format_exc()}" ) continue - self._job_column_lineage_executed = True + self._entity_column_lineage_executed = True logger.debug("Table and column lineage caching completed.") def get_pipeline_details(self, pipeline_id: str) -> Optional[dict]: diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/queries.py b/ingestion/src/metadata/ingestion/source/database/databricks/queries.py index 371c645131c2..d421ea4b9a11 100644 --- a/ingestion/src/metadata/ingestion/source/database/databricks/queries.py +++ b/ingestion/src/metadata/ingestion/source/database/databricks/queries.py @@ -87,28 +87,28 @@ DATABRICKS_DDL = "SHOW CREATE TABLE `{table_name}`" -DATABRICKS_GET_TABLE_LINEAGE_FOR_JOB = """ -SELECT - entity_id AS job_id, +DATABRICKS_GET_TABLE_LINEAGE = """ +SELECT + entity_id, source_table_full_name, target_table_full_name FROM system.access.table_lineage -WHERE entity_type = 'JOB' +WHERE entity_type IN ('JOB', 'PIPELINE') AND event_time >= current_date() - INTERVAL {lookback_days} DAYS AND source_table_full_name IS NOT NULL AND target_table_full_name IS NOT NULL GROUP BY entity_id, source_table_full_name, target_table_full_name """ -DATABRICKS_GET_COLUMN_LINEAGE_FOR_JOB = """ +DATABRICKS_GET_COLUMN_LINEAGE = """ SELECT - entity_id as job_id, + entity_id, source_table_full_name, source_column_name, target_table_full_name, target_column_name FROM system.access.column_lineage -WHERE entity_type = 'JOB' +WHERE entity_type IN ('JOB', 'PIPELINE') AND event_time >= current_date() - INTERVAL {lookback_days} DAYS AND source_table_full_name IS NOT NULL AND target_table_full_name IS NOT NULL diff --git a/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py b/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py index 85b482110e2d..5439c8aaaa24 100644 --- a/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py +++ b/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py @@ -1225,15 +1225,12 @@ def yield_pipeline_lineage_details( # Works automatically - no configuration required! yield from self._yield_kafka_lineage(pipeline_details, pipeline_entity) - if not pipeline_details.job_id: + entity_id = pipeline_details.job_id or pipeline_details.pipeline_id + if not entity_id: return - table_lineage_list = self.client.get_table_lineage( - job_id=pipeline_details.job_id - ) - logger.debug( - f"Processing pipeline lineage for job {pipeline_details.job_id}" - ) + table_lineage_list = self.client.get_table_lineage(entity_id=entity_id) + logger.debug(f"Processing pipeline lineage for {entity_id}") if table_lineage_list: for table_lineage in table_lineage_list: source_table_full_name = table_lineage.get("source_table_full_name") @@ -1296,7 +1293,7 @@ def yield_pipeline_lineage_details( processed_column_lineage = ( self._process_and_validate_column_lineage( column_lineage=self.client.get_column_lineage( - job_id=pipeline_details.job_id, + entity_id=entity_id, TableKey=( source_table_full_name, target_table_full_name, @@ -1330,15 +1327,8 @@ def yield_pipeline_lineage_details( ) ) ) - - else: - logger.debug( - f"No source or target table full name found for job {pipeline_details.job_id}" - ) else: - logger.debug( - f"No table lineage found for job {pipeline_details.job_id}" - ) + logger.debug(f"No table lineage found for {entity_id}") except Exception as exc: yield Either( left=StackTraceError( diff --git a/ingestion/tests/unit/topology/pipeline/test_databricks_pipeline.py b/ingestion/tests/unit/topology/pipeline/test_databricks_pipeline.py index 04454b6ab129..9f23a46cc073 100644 --- a/ingestion/tests/unit/topology/pipeline/test_databricks_pipeline.py +++ b/ingestion/tests/unit/topology/pipeline/test_databricks_pipeline.py @@ -444,3 +444,156 @@ def get_by_name_side_effect(entity, fqn): lineage_details.edge.lineageDetails.columnsLineage, [], ) + + def test_databricks_dlt_pipeline_lineage(self): + dlt_pipeline_id = "115f1983-1e70-46a9-b7fb-dd0150179561" + self.databricks.context.get().__dict__["pipeline"] = dlt_pipeline_id + self.databricks.context.get().__dict__[ + "pipeline_service" + ] = "databricks_pipeline_test" + mock_pipeline = Pipeline( + id=uuid.uuid4(), + name=dlt_pipeline_id, + fullyQualifiedName=f"databricks_pipeline_test.{dlt_pipeline_id}", + service=EntityReference(id=uuid.uuid4(), type="pipelineService"), + ) + + # Create source and target tables + mock_source_table = Table( + id="cced5342-12e8-45fb-b50a-918529d43ed1", + name="table_1", + fullyQualifiedName="local_table.dev.table_1", + database=EntityReference(id=uuid.uuid4(), type="database"), + columns=[ + Column( + name="column_1", + fullyQualifiedName="local_table.dev.table_1.column_1", + dataType="VARCHAR", + ) + ], + databaseSchema=EntityReference(id=uuid.uuid4(), type="databaseSchema"), + ) + + mock_target_table = Table( + id="6f5ad342-12e8-45fb-b50a-918529d43ed1", + name="table_2", + fullyQualifiedName="local_table.dev.table_2", + database=EntityReference(id=uuid.uuid4(), type="database"), + columns=[ + Column( + name="column_2", + fullyQualifiedName="local_table.dev.table_2.column_2", + dataType="VARCHAR", + ) + ], + databaseSchema=EntityReference(id=uuid.uuid4(), type="databaseSchema"), + ) + + dlt_pipeline_details = DataBrickPipelineDetails( + pipeline_id=dlt_pipeline_id, + name="test-dlt-pipeline", + ) + + with patch.object(self.databricks.metadata, "get_by_name") as mock_get_by_name: + + def get_by_name_side_effect(entity, fqn): + if entity == Pipeline: + if fqn == f"databricks_pipeline_test.{dlt_pipeline_id}": + return mock_pipeline + elif entity == Table: + if "table_1" in fqn: + return mock_source_table + elif "table_2" in fqn: + return mock_target_table + return None + + mock_get_by_name.side_effect = get_by_name_side_effect + + with patch.object( + self.databricks.client, "get_table_lineage" + ) as mock_get_table_lineage: + mock_get_table_lineage.return_value = [ + { + "source_table_full_name": "local_table.dev.table_1", + "target_table_full_name": "local_table.dev.table_2", + } + ] + with patch.object( + self.databricks.client, "get_column_lineage" + ) as mock_get_column_lineage: + mock_get_column_lineage.return_value = [ + ("column_1", "column_2"), + ("column_3", "column_4"), + ] + with patch.object( + self.databricks.client, "get_pipeline_details" + ) as mock_get_pipeline_details: + mock_get_pipeline_details.return_value = None + + lineage_details = list( + self.databricks.yield_pipeline_lineage_details( + dlt_pipeline_details + ) + )[0].right + self.assertEqual( + lineage_details.edge.fromEntity.id, + EXPECTED_PIPELINE_LINEAGE.edge.fromEntity.id, + ) + self.assertEqual( + lineage_details.edge.toEntity.id, + EXPECTED_PIPELINE_LINEAGE.edge.toEntity.id, + ) + self.assertEqual( + lineage_details.edge.lineageDetails.columnsLineage, + EXPECTED_PIPELINE_LINEAGE.edge.lineageDetails.columnsLineage, + ) + + with patch.object(self.databricks.metadata, "get_by_name") as mock_get_by_name: + + def get_by_name_side_effect(entity, fqn): + if entity == Pipeline: + if fqn == f"databricks_pipeline_test.{dlt_pipeline_id}": + return mock_pipeline + elif entity == Table: + if "table_1" in fqn: + return mock_source_table + elif "table_2" in fqn: + return mock_target_table + return None + + mock_get_by_name.side_effect = get_by_name_side_effect + + with patch.object( + self.databricks.client, "get_table_lineage" + ) as mock_get_table_lineage: + mock_get_table_lineage.return_value = [ + { + "source_table_full_name": "local_table.dev.table_1", + "target_table_full_name": "local_table.dev.table_2", + } + ] + with patch.object( + self.databricks.client, "get_column_lineage" + ) as mock_get_column_lineage: + mock_get_column_lineage.return_value = [] # No column lineage + with patch.object( + self.databricks.client, "get_pipeline_details" + ) as mock_get_pipeline_details: + mock_get_pipeline_details.return_value = None + lineage_details = list( + self.databricks.yield_pipeline_lineage_details( + dlt_pipeline_details + ) + )[0].right + self.assertEqual( + lineage_details.edge.fromEntity.id, + EXPECTED_PIPELINE_LINEAGE.edge.fromEntity.id, + ) + self.assertEqual( + lineage_details.edge.toEntity.id, + EXPECTED_PIPELINE_LINEAGE.edge.toEntity.id, + ) + self.assertEqual( + lineage_details.edge.lineageDetails.columnsLineage, + [], + )