Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/tfmindi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
load_motif_to_dbd,
)
from tfmindi.io import load_h5ad, load_patterns, save_h5ad, save_patterns # noqa: E402
from tfmindi.merge import concat # noqa: E402
from tfmindi.types import Pattern, Seqlet # noqa: E402

__all__ = [
Expand All @@ -43,6 +44,7 @@
"load_h5ad",
"save_patterns",
"load_patterns",
"concat",
]

__version__ = version("tfmindi")
123 changes: 123 additions & 0 deletions src/tfmindi/merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""TF-MInDi anndata merge functionality."""

from _collections_abc import dict_items

import anndata # type: ignore
import numpy as np # type: ignore
from anndata._core.merge import StrategiesLiteral # type: ignore

# obs columns that hold integer indices into the per-adata example arrays
# (the first two index rows of uns['unique_examples']['oh'] / ['contrib']).
_INDEX_COLS = ["example_oh_idx", "example_contrib_idx", "example_idx"]


def concat(
    adatas: list[anndata.AnnData] | dict[str, anndata.AnnData],
    idx_match: bool = False,
    index_unique: str = "-",
    merge: StrategiesLiteral | None = "same",
    **kwargs,
) -> anndata.AnnData:
    """
    Concatenate multiple TF-MInDi anndatas preserving data stored in uns['unique_examples'].

    Parameters
    ----------
    adatas
        The objects to be concatenated. If a dict is passed, keys are used for
        the keys argument and values are concatenated.
    idx_match
        Whether `example_oh_idx`, `example_contrib_idx` and `example_idx`
        refer to the same data across adatas or not.
    index_unique
        Whether to make the index unique by using the keys.
        If provided, this is the delimiter between "{orig_idx}{index_unique}{key}".
        When None, the original indices are kept.
    merge
        How elements not aligned to the axis being concatenated along are selected.
        See: anndata.concat for more info.
    **kwargs
        Extra key word arguments passed to anndata.concat

    Returns
    -------
    The concatenated AnnData, with uns['unique_examples'] carried over and the
    obs index columns remapped to point into the (possibly stacked) arrays.

    Raises
    ------
    ValueError
        If `index_unique` is not a string, or if `idx_match` is True while the
        `unique_examples` arrays are missing or differ across adatas.
    """
    if merge is None:
        print("merge is None, vars will not be carried over to concatenated adata!")
    if not isinstance(index_unique, str):
        raise ValueError("index_unique should be a string.")

    # Normalize to a list of (key, adata) pairs so we can iterate repeatedly.
    if isinstance(adatas, dict):
        adatas_iter: list[tuple[str | int, anndata.AnnData]] = list(adatas.items())
    else:
        adatas_iter = list(enumerate(adatas))

    def _has_unique_example(adata: anndata.AnnData) -> bool:
        # An adata takes part in unique-example handling only when both the
        # stored arrays and the obs columns indexing into them are present.
        return (
            "unique_examples" in adata.uns.keys()
            and "oh" in adata.uns["unique_examples"].keys()
            and "contrib" in adata.uns["unique_examples"].keys()
            and "example_oh_idx" in adata.obs.columns
            and "example_contrib_idx" in adata.obs.columns
        )

    # Per-adata offsets applied in place to the obs index columns (only when
    # idx_match is False); kept so the mutations can be undone afterwards.
    idx_col_offset_per_ad: dict[str | int, dict[str, int]] = {}

    if idx_match:
        # All adatas must share the exact same unique-example arrays.
        v_oh = [a.uns["unique_examples"]["oh"] for _, a in adatas_iter if _has_unique_example(a)]
        v_co = [a.uns["unique_examples"]["contrib"] for _, a in adatas_iter if _has_unique_example(a)]

        if not v_oh:
            # Original code would raise an opaque IndexError on v_oh[0] here.
            raise ValueError("idx_match=True requires at least one adata with uns['unique_examples'].")

        if not all(np.array_equal(v_oh[0], arr) for arr in v_oh) or not all(
            np.array_equal(v_co[0], arr) for arr in v_co
        ):
            message = (
                "All adata.uns['unique_examples']['contrib'] and adata.uns['unique_examples']['oh'] "
                "should be the same across adatas."
            )
            raise ValueError(message)

        # All stores are identical, just take the first.
        unique_examples_oh = v_oh[0]
        unique_examples_co = v_co[0]
    else:
        # The index columns do *not* point to the same data: stack the arrays
        # and shift each adata's indices (in place, undone below) so they
        # point into the stacked result.
        l_unique_examples_oh: list[np.ndarray] = []
        l_unique_examples_co: list[np.ndarray] = []
        idx_col_offset: dict[str, int] = dict.fromkeys(_INDEX_COLS, 0)

        for k, adata in adatas_iter:
            if not _has_unique_example(adata):
                continue
            idx_col_offset_per_ad[k] = idx_col_offset

            oh = adata.uns["unique_examples"]["oh"]
            co = adata.uns["unique_examples"]["contrib"]
            l_unique_examples_oh.extend(oh)
            l_unique_examples_co.extend(co)

            # Shift indices in place (restored in the finally block below).
            for col in _INDEX_COLS:
                if col in adata.obs.columns:
                    adata.obs[col] += idx_col_offset[col]

            # Offsets for the next adata: the stacked-array columns advance by
            # the number of rows contributed. Using obs[col].max() + 1 here
            # (as before) misaligns the indices whenever an array holds rows
            # that no seqlet currently points at.
            next_offset = dict(idx_col_offset)
            next_offset["example_oh_idx"] += len(oh)
            next_offset["example_contrib_idx"] += len(co)
            if "example_idx" in adata.obs.columns:
                # No stacked array backs example_idx; keep it unique via the
                # running maximum (column is already shifted at this point).
                next_offset["example_idx"] = int(adata.obs["example_idx"].max()) + 1
            idx_col_offset = next_offset

        unique_examples_oh = np.array(l_unique_examples_oh)
        unique_examples_co = np.array(l_unique_examples_co)

    try:
        adata_concat = anndata.concat(
            adatas={str(k): adata for k, adata in adatas_iter}, index_unique=index_unique, merge=merge, **kwargs
        )
    finally:
        # Undo the in-place index shifts so the caller's adatas are unchanged,
        # even if anndata.concat raised.
        for k, adata in adatas_iter:
            if k in idx_col_offset_per_ad:
                for col in _INDEX_COLS:
                    if col in adata.obs.columns:
                        adata.obs[col] -= idx_col_offset_per_ad[k][col]

    adata_concat.uns["unique_examples"] = {"oh": unique_examples_oh, "contrib": unique_examples_co}

    return adata_concat
128 changes: 128 additions & 0 deletions tests/test_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Tests for merging functionality."""

import anndata
import numpy as np
import pandas as pd
import pytest


def test_merge_non_match_idx(sample_seqlet_adata: anndata.AnnData):
    """Concatenating without idx_match stacks unique_examples and remaps indices."""
    import tfmindi as tm

    first = sample_seqlet_adata
    second = sample_seqlet_adata.copy()

    combined = tm.concat({"a": first, "b": second}, idx_match=False, index_unique="-")

    # All observations are kept.
    assert combined.shape[0] == first.shape[0] + second.shape[0]

    # unique_examples survives concatenation.
    assert "unique_examples" in combined.uns.keys()
    assert "oh" in combined.uns["unique_examples"].keys()
    assert "contrib" in combined.uns["unique_examples"].keys()

    pd.testing.assert_frame_equal(combined.var, first.var)

    # The one-hot stores were stacked, not deduplicated.
    assert (
        combined.uns["unique_examples"]["oh"].shape[0]
        == first.uns["unique_examples"]["oh"].shape[0] + second.uns["unique_examples"]["oh"].shape[0]
    )

    # Remapped indices must resolve to the same rows as in the originals.
    for key, part in (("a", first), ("b", second)):
        suffixed = [f"{idx}-{key}" for idx in part.obs_names]
        for store, col in (("oh", "example_oh_idx"), ("contrib", "example_contrib_idx")):
            np.testing.assert_array_equal(
                combined.uns["unique_examples"][store][combined.obs.loc[suffixed, col]],
                part.uns["unique_examples"][store][part.obs[col]],
            )


def test_merge_match_idx(sample_seqlet_adata: anndata.AnnData):
    """Concatenating with idx_match keeps a single shared unique_examples store."""
    import tfmindi as tm

    # Two overlapping slices that share the same underlying unique_examples.
    left = sample_seqlet_adata[0:150].copy()
    right = sample_seqlet_adata[100:250].copy()

    merged = tm.concat({"a": left, "b": right}, idx_match=True, index_unique="-")

    assert merged.shape[0] == left.shape[0] + right.shape[0]

    assert "unique_examples" in merged.uns.keys()
    assert "oh" in merged.uns["unique_examples"].keys()
    assert "contrib" in merged.uns["unique_examples"].keys()

    pd.testing.assert_frame_equal(merged.var, left.var)

    # The shared store is carried over once, not duplicated.
    assert merged.uns["unique_examples"]["oh"].shape[0] == left.uns["unique_examples"]["oh"].shape[0]

    # Indices in the merged object must resolve to the same rows as before.
    for key, part in (("a", left), ("b", right)):
        suffixed = [f"{idx}-{key}" for idx in part.obs_names]
        for store, col in (("oh", "example_oh_idx"), ("contrib", "example_contrib_idx")):
            np.testing.assert_array_equal(
                merged.uns["unique_examples"][store][merged.obs.loc[suffixed, col]],
                part.uns["unique_examples"][store][part.obs[col]],
            )


def test_idx_match_with_mismatched_data(sample_seqlet_adata):
    """idx_match=True must reject adatas whose unique_examples arrays differ."""
    import tfmindi as tm

    original = sample_seqlet_adata.copy()
    altered = sample_seqlet_adata.copy()
    # Zero out one store so the two adatas no longer agree.
    altered.uns["unique_examples"]["oh"] = np.zeros_like(altered.uns["unique_examples"]["oh"])

    with pytest.raises(ValueError, match="should be the same across adatas"):
        tm.concat([original, altered], idx_match=True)


def test_invalid_index_unique_type():
    """A non-string index_unique is rejected before any adata is touched."""
    import tfmindi as tm

    # The placeholder list is never iterated: validation happens first.
    with pytest.raises(ValueError, match="index_unique should be a string"):
        tm.concat([...], index_unique=None)


def test_original_adatas_unchanged(sample_seqlet_adata):
    """concat with idx_match=False must leave the inputs' obs untouched."""
    import tfmindi as tm

    copy_one = sample_seqlet_adata.copy()
    copy_two = sample_seqlet_adata.copy()

    tm.concat([copy_one, copy_two], idx_match=False)

    # The temporary index offsets applied during concat must have been undone.
    pd.testing.assert_frame_equal(copy_one.obs, sample_seqlet_adata.obs)
    pd.testing.assert_frame_equal(copy_two.obs, sample_seqlet_adata.obs)
Loading