Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions src/datarepo/core/tables/deltalake_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
import os
from typing import Any, TypeAlias
from typing import Any, TypeAlias, Dict, Tuple
import warnings

import boto3
Expand All @@ -18,7 +18,7 @@
from datarepo.core.tables.filters import InputFilters, normalize_filters
from datarepo.core.tables.metadata import (
TableMetadata,
TableProtocol,
VersionedTableProtocol,
TableSchema,
TableColumn,
TablePartition,
Expand Down Expand Up @@ -65,7 +65,7 @@ def to_storage_options(self) -> dict[str, Any]:
return opts


class DeltalakeTable(TableProtocol):
class DeltalakeTable(VersionedTableProtocol):
"""A table that is backed by a Delta Lake table."""

def __init__(
Expand Down Expand Up @@ -112,6 +112,47 @@ def __init__(
**(table_metadata_args or {}),
)

# For a Delta table, each version is defined by the combination of a URI and a schema.
self._versions: Dict[str, Tuple[str, pa.Schema]] = {
"v1": (uri, schema),
}

def add_version(
    self, version: str, uri: str, schema: pa.Schema, **kwargs: Any
) -> None:
    """Register a new version of the table and make it the current one.

    Args:
        version (str): The label identifying the new version (e.g. "v2").
        uri (str): The URI of the table for this version.
        schema (pa.Schema): The schema of the table for this version.
        **kwargs: Extra keyword arguments; accepted but not used by this
            implementation.

    Raises:
        ValueError: If ``version`` has already been registered.
    """
    if version in self._versions:
        raise ValueError(f"Version {version} already exists")
    self._versions[version] = (uri, schema)
    # Adding a version also switches the table over to it.
    self.uri = uri
    self.schema = schema

def change_version(self, version: str) -> None:
    """Switch the table to a previously registered version.

    Args:
        version (str): The label of the version to activate.

    Raises:
        ValueError: If ``version`` has not been registered; the message
            lists the available versions.
    """
    if version not in self._versions:
        raise ValueError(
            f"Version {version} does not exist. Available versions: {self.get_versions()}"
        )
    # Each entry is a (uri, schema) pair; restore both together.
    self.uri, self.schema = self._versions[version]

def get_versions(self) -> list[str]:
    """Get all versions of the table.

    Returns:
        list[str]: The registered version labels, in insertion order.
            A new list is returned on every call.
    """
    # Iterating a dict yields its keys, so no explicit .keys() is needed.
    return list(self._versions)

def get_schema(self) -> TableSchema:
"""Generate and return the schema of the table, including partitions and columns.

Expand Down
31 changes: 30 additions & 1 deletion src/datarepo/core/tables/metadata.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Dict, Protocol, TypedDict
from typing import Any, Dict, Protocol, TypedDict, List

from datarepo.core.dataframe.frame import NlkDataFrame
from datarepo.core.tables.util import RoapiOptions
Expand Down Expand Up @@ -54,3 +54,32 @@ def get_schema(self) -> TableSchema:
Returns the schema of the table, used to generate the web catalog.
"""
...


class VersionedTableProtocol(TableProtocol):
    """A table that keeps several named versions."""

    def add_version(self, version: str, **kwargs: Any) -> None:
        """Add a new version of the table.

        Args:
            version (str): The version label of the table.
            **kwargs: Additional arguments to pass for the version.
        """
        # ``**kwargs: Any`` annotates each keyword *value*; the values are
        # implementation-specific, so they are not constrained here.
        ...

    def change_version(self, version: str) -> None:
        """Change the current version of the table.

        Args:
            version (str): The version of the table.
        """
        ...

    def get_versions(self) -> List[str]:  # type: ignore[empty-body]
        """Get all versions of the table.

        Returns:
            list[str]: Versions of the table.
        """
        ...
24 changes: 24 additions & 0 deletions test/tables/test_deltalake_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,30 @@ def test_call_loads_correct_data(
expected_sorted = expected.sort("value")
assert actual_sorted.equals(expected_sorted)

def test_versioned_table(self, delta_table_definition: DeltalakeTable):
    """Test versioning for delta tables.

    Covers adding versions, listing them in insertion order, switching
    back to an earlier version, and the errors raised for a duplicate or
    unknown version.
    """
    table = delta_table_definition

    # Re-registering the current uri/schema under a new label is allowed.
    table.add_version("v2", table.uri, table.schema)
    assert table.get_versions() == ["v1", "v2"]

    new_schema = pa.schema(
        [
            ("implant_id", pa.int64()),
            ("uniq", pa.string()),
            ("value", pa.int64()),
        ]
    )
    table.add_version("v3", table.uri, new_schema)
    assert table.get_versions() == ["v1", "v2", "v3"]

    # Switching back to v1 must no longer expose the v3 schema.
    table.change_version("v1")
    assert table.schema != new_schema

    with pytest.raises(ValueError, match="Version v2 already exists"):
        table.add_version("v2", table.uri, table.schema)

    # Pin the message so an unrelated ValueError cannot satisfy the check.
    with pytest.raises(ValueError, match="Version v4 does not exist"):
        table.change_version("v4")

""" this test is commented out until we upstream delta caching
def test_delta_cache(
self,
Expand Down