Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ api/io
api/datasets
api/types
api/backends
api/merge
```
12 changes: 12 additions & 0 deletions docs/api/merge.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Merge

Concatenation of TF-MInDi AnnData objects.

```{eval-rst}
.. currentmodule:: tfmindi

.. autosummary::
:toctree: ../generated

concat
```
2 changes: 2 additions & 0 deletions src/tfmindi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
load_motif_to_dbd,
)
from tfmindi.io import load_h5ad, load_patterns, save_h5ad, save_patterns # noqa: E402
from tfmindi.merge import concat # noqa: E402
from tfmindi.types import Pattern, Seqlet # noqa: E402

__all__ = [
Expand All @@ -43,6 +44,7 @@
"load_h5ad",
"save_patterns",
"load_patterns",
"concat",
]

__version__ = version("tfmindi")
123 changes: 123 additions & 0 deletions src/tfmindi/merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""TF-MInDi anndata merge functionality."""

from _collections_abc import dict_items

import anndata # type: ignore
import numpy as np # type: ignore
from anndata._core.merge import StrategiesLiteral # type: ignore

_INDEX_COLS = ["example_oh_idx", "example_contrib_idx", "example_idx"]


def concat(
adatas: list[anndata.AnnData] | dict[str, anndata.AnnData],
idx_match: bool = False,
index_unique: str = "-",
merge: StrategiesLiteral | None = "same",
**kwargs,
) -> anndata.AnnData:
"""
Concatenate multiple TF-MInDi anndatas preserving data stored in uns['unique_examples'].

Parameters
----------
adatas
The objects to be concatenated. If a dict is passed, keys are used for
the keys argument and values are concatenated.
idx_match
Whether `example_oh_idx`, `example_contrib_idx` and `example_idx`
refer to the same data across adatas or not.
index_unique
Whether to make the index unique by using the keys.
If provided, this is the delimiter between "{orig_idx}{index_unique}{key}".
When None, the original indices are kept.
merge
How elements not aligned to the axis being concatenated along are selected.
See: anndata.concat for more info.
**kwargs
Extra key word arguments passed to anndata.concat
"""
if merge is None:
print("merge is None, vars will not be carried over to concatenated adata!")
if not isinstance(index_unique, str):
raise ValueError("index_unique should be a string.")

if isinstance(adatas, dict):
adatas_iter: dict_items[str, anndata.AnnData] | list[tuple[int, anndata.AnnData]] = adatas.items()
else:
adatas_iter = [(i, a) for i, a in enumerate(adatas)]

def _has_unique_example(adata: anndata.AnnData) -> bool:
return (
"unique_examples" in adata.uns.keys()
and "oh" in adata.uns["unique_examples"].keys()
and "contrib" in adata.uns["unique_examples"].keys()
and "example_oh_idx" in adata.obs.columns
and "example_contrib_idx" in adata.obs.columns
)

if idx_match:
# make sure same data is stored in all unique examples
v_oh = [a.uns["unique_examples"]["oh"] for _, a in adatas_iter if _has_unique_example(a)]
v_co = [a.uns["unique_examples"]["contrib"] for _, a in adatas_iter if _has_unique_example(a)]

if not all(np.array_equal(v_oh[0], arr) for arr in v_oh) or not all(
np.array_equal(v_co[0], arr) for arr in v_co
):
message = (
"All adata.uns['unique_examples']['contrib'] and adata.uns['unique_examples']['oh']"
+ "should be the same across adatas."
)
raise ValueError(message)

if not idx_match:
# the columns representing indices in adata.obs do *not* point to the
# same data. In this case we make the indices unique across all adatas.

# These values will be stored in the combined adata
l_unique_examples_oh: list[np.ndarray] = []
l_unique_examples_co: list[np.ndarray] = []

# Dictionary to keep track of index offsets, adatas will be changed
# in place. After concatenation the original values will be replaced
# in the adatas.
idx_col_offset: dict[str, int] = dict.fromkeys(_INDEX_COLS, 0)
idx_col_offset_per_ad: dict[str | int, dict[str, int]] = {}

for k, adata in adatas_iter:
if _has_unique_example(adata):
idx_col_offset_per_ad[k] = idx_col_offset
l_unique_examples_oh.extend(adata.uns["unique_examples"]["oh"])
l_unique_examples_co.extend(adata.uns["unique_examples"]["contrib"])

# change indeces in place
for col in _INDEX_COLS:
adata.obs[col] += idx_col_offset[col]

# get offset for next iteration
idx_col_offset = {col: adata.obs[col].max() + 1 for col in _INDEX_COLS}

unique_examples_oh = np.array(l_unique_examples_oh)
unique_examples_co = np.array(l_unique_examples_co)

else:
# All values in v_oh and v_co are unique, just take the first
unique_examples_oh = v_oh[0]
unique_examples_co = v_co[0]

adata_concat = anndata.concat(
adatas={str(k): adata for k, adata in adatas_iter}, index_unique=index_unique, merge=merge, **kwargs
)

adata_concat.uns["unique_examples"] = {}
adata_concat.uns["unique_examples"]["oh"] = unique_examples_oh
adata_concat.uns["unique_examples"]["contrib"] = unique_examples_co

# 2. reset example_oh_idx and example_contrib_idx in place to original values
if not idx_match:
for k, adata in adatas_iter:
if _has_unique_example(adata):
for col in _INDEX_COLS:
adata.obs[col] -= idx_col_offset_per_ad[k][col]

return adata_concat
128 changes: 128 additions & 0 deletions tests/test_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Tests for merging functionality."""

import anndata
import numpy as np
import pandas as pd
import pytest


def test_merge_non_match_idx(sample_seqlet_adata: anndata.AnnData):
import tfmindi as tm

adata_a = sample_seqlet_adata
adata_b = sample_seqlet_adata.copy()

adata_c = tm.concat({"a": adata_a, "b": adata_b}, idx_match=False, index_unique="-")

assert adata_c.shape[0] == adata_a.shape[0] + adata_b.shape[0]

assert "unique_examples" in adata_c.uns.keys()
assert "oh" in adata_c.uns["unique_examples"].keys()
assert "contrib" in adata_c.uns["unique_examples"].keys()

pd.testing.assert_frame_equal(adata_c.var, adata_a.var)

assert (
adata_c.uns["unique_examples"]["oh"].shape[0]
== adata_a.uns["unique_examples"]["oh"].shape[0] + adata_b.uns["unique_examples"]["oh"].shape[0]
)

a_index = [f"{orig_idx}-a" for orig_idx in adata_a.obs_names]
b_index = [f"{orig_idx}-b" for orig_idx in adata_b.obs_names]

np.testing.assert_array_equal(
adata_c.uns["unique_examples"]["oh"][adata_c.obs.loc[a_index, "example_oh_idx"]],
adata_a.uns["unique_examples"]["oh"][adata_a.obs["example_oh_idx"]],
)

np.testing.assert_array_equal(
adata_c.uns["unique_examples"]["oh"][adata_c.obs.loc[b_index, "example_oh_idx"]],
adata_b.uns["unique_examples"]["oh"][adata_b.obs["example_oh_idx"]],
)

np.testing.assert_array_equal(
adata_c.uns["unique_examples"]["contrib"][adata_c.obs.loc[a_index, "example_contrib_idx"]],
adata_a.uns["unique_examples"]["contrib"][adata_a.obs["example_contrib_idx"]],
)

np.testing.assert_array_equal(
adata_c.uns["unique_examples"]["contrib"][adata_c.obs.loc[b_index, "example_contrib_idx"]],
adata_b.uns["unique_examples"]["contrib"][adata_b.obs["example_contrib_idx"]],
)


def test_merge_match_idx(sample_seqlet_adata: anndata.AnnData):
import tfmindi as tm

# fake data with some overlap
adata_a = sample_seqlet_adata[0:150].copy()
adata_b = sample_seqlet_adata[100:250].copy()

adata_c = tm.concat({"a": adata_a, "b": adata_b}, idx_match=True, index_unique="-")

assert adata_c.shape[0] == adata_a.shape[0] + adata_b.shape[0]

assert "unique_examples" in adata_c.uns.keys()
assert "oh" in adata_c.uns["unique_examples"].keys()
assert "contrib" in adata_c.uns["unique_examples"].keys()

pd.testing.assert_frame_equal(adata_c.var, adata_a.var)

assert adata_c.uns["unique_examples"]["oh"].shape[0] == adata_a.uns["unique_examples"]["oh"].shape[0]

a_index = [f"{orig_idx}-a" for orig_idx in adata_a.obs_names]
b_index = [f"{orig_idx}-b" for orig_idx in adata_b.obs_names]

np.testing.assert_array_equal(
adata_c.uns["unique_examples"]["oh"][adata_c.obs.loc[a_index, "example_oh_idx"]],
adata_a.uns["unique_examples"]["oh"][adata_a.obs["example_oh_idx"]],
)

np.testing.assert_array_equal(
adata_c.uns["unique_examples"]["oh"][adata_c.obs.loc[b_index, "example_oh_idx"]],
adata_b.uns["unique_examples"]["oh"][adata_b.obs["example_oh_idx"]],
)

np.testing.assert_array_equal(
adata_c.uns["unique_examples"]["contrib"][adata_c.obs.loc[a_index, "example_contrib_idx"]],
adata_a.uns["unique_examples"]["contrib"][adata_a.obs["example_contrib_idx"]],
)

np.testing.assert_array_equal(
adata_c.uns["unique_examples"]["contrib"][adata_c.obs.loc[b_index, "example_contrib_idx"]],
adata_b.uns["unique_examples"]["contrib"][adata_b.obs["example_contrib_idx"]],
)


def test_idx_match_with_mismatched_data(sample_seqlet_adata):
"""Test that ValueError is raised when idx_match=True but data differs"""
import tfmindi as tm

adata_a = sample_seqlet_adata.copy()
adata_b = sample_seqlet_adata.copy()
# Modify unique_examples to be different
adata_b.uns["unique_examples"]["oh"] = np.zeros_like(adata_b.uns["unique_examples"]["oh"])

with pytest.raises(ValueError, match="should be the same across adatas"):
tm.concat([adata_a, adata_b], idx_match=True)


def test_invalid_index_unique_type():
"""Test that non-string index_unique raises ValueError"""
import tfmindi as tm

with pytest.raises(ValueError, match="index_unique should be a string"):
tm.concat([...], index_unique=None)


def test_original_adatas_unchanged(sample_seqlet_adata):
"""Ensure original adatas are restored after concat with idx_match=False"""
import tfmindi as tm

adata_a = sample_seqlet_adata.copy()
adata_b = sample_seqlet_adata.copy()
# Store original values
tm.concat([adata_a, adata_b], idx_match=False)

pd.testing.assert_frame_equal(adata_a.obs, sample_seqlet_adata.obs)
pd.testing.assert_frame_equal(adata_b.obs, sample_seqlet_adata.obs)
Loading