diff --git a/src/datarepo/core/tables/deltalake_table.py b/src/datarepo/core/tables/deltalake_table.py index 7c9a57a..215eb42 100644 --- a/src/datarepo/core/tables/deltalake_table.py +++ b/src/datarepo/core/tables/deltalake_table.py @@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass import os -from typing import Any, TypeAlias +from typing import Any, TypeAlias, Dict, Tuple import warnings import boto3 @@ -18,7 +18,7 @@ from datarepo.core.tables.filters import InputFilters, normalize_filters from datarepo.core.tables.metadata import ( TableMetadata, - TableProtocol, + VersionedTableProtocol, TableSchema, TableColumn, TablePartition, @@ -65,7 +65,7 @@ def to_storage_options(self) -> dict[str, Any]: return opts -class DeltalakeTable(TableProtocol): +class DeltalakeTable(VersionedTableProtocol): """A table that is backed by a Delta Lake table.""" def __init__( @@ -112,6 +112,47 @@ def __init__( **(table_metadata_args or {}), ) + # for delta table version is formed as combination of uri and schema + self._versions: Dict[str, Tuple[str, pa.Schema]] = { + "v1": (uri, schema), + } + + def add_version( + self, version: str, uri: str, schema: pa.Schema, **kwargs: Any + ) -> None: + """Add a new version of the table. + + Args: + version (str): The version number of the table. + uri (str): The URI of the table. + schema (pa.Schema): The schema of the table. + """ + if version in self._versions: + raise ValueError(f"Version {version} already exists") + self._versions[version] = (uri, schema) + self.uri = uri + self.schema = schema + + def change_version(self, version: str): + """Change the current version of the table. + + Args: + version (str): The version of the table. + """ + if version not in self._versions: + raise ValueError( + f"Version {version} does not exist. Available versions: {self.get_versions()}" + ) + self.uri, self.schema = self._versions[version] + + def get_versions(self) -> list[str]: + """Get all versions of the table. + + Returns: + list[str]: The versions of the table. + """ + return list(self._versions.keys()) + def get_schema(self) -> TableSchema: """Generate and return the schema of the table, including partitions and columns. diff --git a/src/datarepo/core/tables/metadata.py b/src/datarepo/core/tables/metadata.py index 617c130..db6eb7a 100644 --- a/src/datarepo/core/tables/metadata.py +++ b/src/datarepo/core/tables/metadata.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, Protocol, TypedDict +from typing import Any, Dict, Protocol, TypedDict, List from datarepo.core.dataframe.frame import NlkDataFrame from datarepo.core.tables.util import RoapiOptions @@ -54,3 +54,32 @@ def get_schema(self) -> TableSchema: Returns the schema of the table, used to generate the web catalog. """ ... + + +class VersionedTableProtocol(TableProtocol): + """A versioned table.""" + + def add_version(self, version: str, **kwargs: Dict[str, Any]) -> None: + """Add a new version of the table. + + Args: + version (str): The version number of the table. + **kwargs: Additional arguments to pass for the version. + """ + ... + + def change_version(self, version: str): + """Change the current version of the table. + + Args: + version (str): The version of the table. + """ + ... + + def get_versions(self) -> List[str]: # type: ignore[empty-body] + """Get all versions of the table. + + Returns: + list[str]: Versions of the table. + """ + ... diff --git a/test/tables/test_deltalake_table.py b/test/tables/test_deltalake_table.py index 9e49a2f..3dc3a10 100644 --- a/test/tables/test_deltalake_table.py +++ b/test/tables/test_deltalake_table.py @@ -377,6 +377,30 @@ def test_call_loads_correct_data( expected_sorted = expected.sort("value") assert actual_sorted.equals(expected_sorted) + def test_versioned_table(self, delta_table_definition: DeltalakeTable): + """Test versioning for delta tables.""" + delta_table_definition.add_version("v2", delta_table_definition.uri, delta_table_definition.schema) + assert delta_table_definition.get_versions() == ["v1", "v2"] + + new_schema = pa.schema( + [ + ("implant_id", pa.int64()), + ("uniq", pa.string()), + ("value", pa.int64()), + ] + ) + delta_table_definition.add_version("v3", delta_table_definition.uri, new_schema) + assert delta_table_definition.get_versions() == ["v1", "v2", "v3"] + + delta_table_definition.change_version("v1") + assert delta_table_definition.schema != new_schema + + with pytest.raises(ValueError, match="Version v2 already exists"): + delta_table_definition.add_version("v2", delta_table_definition.uri, delta_table_definition.schema) + + with pytest.raises(ValueError): + delta_table_definition.change_version("v4") + """ this test is commented out until we upstream delta caching def test_delta_cache( self,