diff --git a/docs/api.md b/docs/api.md
index 4980020..4bdb29a 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -14,4 +14,5 @@ api/io
 api/datasets
 api/types
 api/backends
+api/merge
 ```
diff --git a/docs/api/merge.md b/docs/api/merge.md
new file mode 100644
index 0000000..264ef10
--- /dev/null
+++ b/docs/api/merge.md
@@ -0,0 +1,12 @@
+# Merge
+
+Concatenation of TF-MInDi AnnData objects.
+
+```{eval-rst}
+.. currentmodule:: tfmindi
+
+.. autosummary::
+   :toctree: ../generated
+
+   concat
+```
diff --git a/src/tfmindi/__init__.py b/src/tfmindi/__init__.py
index 73f55ab..c2ec022 100644
--- a/src/tfmindi/__init__.py
+++ b/src/tfmindi/__init__.py
@@ -23,6 +23,7 @@
     load_motif_to_dbd,
 )
 from tfmindi.io import load_h5ad, load_patterns, save_h5ad, save_patterns  # noqa: E402
+from tfmindi.merge import concat  # noqa: E402
 from tfmindi.types import Pattern, Seqlet  # noqa: E402
 
 __all__ = [
@@ -43,6 +44,7 @@
     "load_h5ad",
     "save_patterns",
     "load_patterns",
+    "concat",
 ]
 
 __version__ = version("tfmindi")
diff --git a/src/tfmindi/merge.py b/src/tfmindi/merge.py
new file mode 100644
index 0000000..e7e0544
--- /dev/null
+++ b/src/tfmindi/merge.py
@@ -0,0 +1,123 @@
+"""TF-MInDi anndata merge functionality."""
+
+from _collections_abc import dict_items
+
+import anndata  # type: ignore
+import numpy as np  # type: ignore
+from anndata._core.merge import StrategiesLiteral  # type: ignore
+
+_INDEX_COLS = ["example_oh_idx", "example_contrib_idx", "example_idx"]
+
+
+def concat(
+    adatas: list[anndata.AnnData] | dict[str, anndata.AnnData],
+    idx_match: bool = False,
+    index_unique: str = "-",
+    merge: StrategiesLiteral | None = "same",
+    **kwargs,
+) -> anndata.AnnData:
+    """
+    Concatenate multiple TF-MInDi anndatas preserving data stored in uns['unique_examples'].
+
+    Parameters
+    ----------
+    adatas
+        The objects to be concatenated. If a dict is passed, keys are used for
+        the keys argument and values are concatenated.
+    idx_match
+        Whether `example_oh_idx`, `example_contrib_idx` and `example_idx`
+        refer to the same data across adatas or not.
+    index_unique
+        Whether to make the index unique by using the keys.
+        If provided, this is the delimiter between "{orig_idx}{index_unique}{key}".
+        When None, the original indices are kept.
+    merge
+        How elements not aligned to the axis being concatenated along are selected.
+        See: anndata.concat for more info.
+    **kwargs
+        Extra key word arguments passed to anndata.concat
+    """
+    if merge is None:
+        print("merge is None, vars will not be carried over to concatenated adata!")
+    if not isinstance(index_unique, str):
+        raise ValueError("index_unique should be a string.")
+
+    if isinstance(adatas, dict):
+        adatas_iter: dict_items[str, anndata.AnnData] | list[tuple[int, anndata.AnnData]] = adatas.items()
+    else:
+        adatas_iter = [(i, a) for i, a in enumerate(adatas)]
+
+    def _has_unique_example(adata: anndata.AnnData) -> bool:
+        return (
+            "unique_examples" in adata.uns.keys()
+            and "oh" in adata.uns["unique_examples"].keys()
+            and "contrib" in adata.uns["unique_examples"].keys()
+            and "example_oh_idx" in adata.obs.columns
+            and "example_contrib_idx" in adata.obs.columns
+        )
+
+    if idx_match:
+        # make sure same data is stored in all unique examples
+        v_oh = [a.uns["unique_examples"]["oh"] for _, a in adatas_iter if _has_unique_example(a)]
+        v_co = [a.uns["unique_examples"]["contrib"] for _, a in adatas_iter if _has_unique_example(a)]
+
+        if not all(np.array_equal(v_oh[0], arr) for arr in v_oh) or not all(
+            np.array_equal(v_co[0], arr) for arr in v_co
+        ):
+            message = (
+                "All adata.uns['unique_examples']['contrib'] and adata.uns['unique_examples']['oh']"
+                + "should be the same across adatas."
+            )
+            raise ValueError(message)
+
+    if not idx_match:
+        # the columns representing indices in adata.obs do *not* point to the
+        # same data. In this case we make the indices unique across all adatas.
+
+        # These values will be stored in the combined adata
+        l_unique_examples_oh: list[np.ndarray] = []
+        l_unique_examples_co: list[np.ndarray] = []
+
+        # Dictionary to keep track of index offsets, adatas will be changed
+        # in place. After concatenation the original values will be replaced
+        # in the adatas.
+        idx_col_offset: dict[str, int] = dict.fromkeys(_INDEX_COLS, 0)
+        idx_col_offset_per_ad: dict[str | int, dict[str, int]] = {}
+
+        for k, adata in adatas_iter:
+            if _has_unique_example(adata):
+                idx_col_offset_per_ad[k] = idx_col_offset
+                l_unique_examples_oh.extend(adata.uns["unique_examples"]["oh"])
+                l_unique_examples_co.extend(adata.uns["unique_examples"]["contrib"])
+
+                # change indices in place
+                for col in _INDEX_COLS:
+                    adata.obs[col] += idx_col_offset[col]
+
+                # get offset for next iteration
+                idx_col_offset = {col: adata.obs[col].max() + 1 for col in _INDEX_COLS}
+
+        unique_examples_oh = np.array(l_unique_examples_oh)
+        unique_examples_co = np.array(l_unique_examples_co)
+
+    else:
+        # All values in v_oh and v_co are unique, just take the first
+        unique_examples_oh = v_oh[0]
+        unique_examples_co = v_co[0]
+
+    adata_concat = anndata.concat(
+        adatas={str(k): adata for k, adata in adatas_iter}, index_unique=index_unique, merge=merge, **kwargs
+    )
+
+    adata_concat.uns["unique_examples"] = {}
+    adata_concat.uns["unique_examples"]["oh"] = unique_examples_oh
+    adata_concat.uns["unique_examples"]["contrib"] = unique_examples_co
+
+    # 2. reset example_oh_idx and example_contrib_idx in place to original values
+    if not idx_match:
+        for k, adata in adatas_iter:
+            if _has_unique_example(adata):
+                for col in _INDEX_COLS:
+                    adata.obs[col] -= idx_col_offset_per_ad[k][col]
+
+    return adata_concat
diff --git a/tests/test_merge.py b/tests/test_merge.py
new file mode 100644
index 0000000..c72fb10
--- /dev/null
+++ b/tests/test_merge.py
@@ -0,0 +1,128 @@
+"""Tests for merging functionality."""
+
+import anndata
+import numpy as np
+import pandas as pd
+import pytest
+
+
+def test_merge_non_match_idx(sample_seqlet_adata: anndata.AnnData):
+    import tfmindi as tm
+
+    adata_a = sample_seqlet_adata
+    adata_b = sample_seqlet_adata.copy()
+
+    adata_c = tm.concat({"a": adata_a, "b": adata_b}, idx_match=False, index_unique="-")
+
+    assert adata_c.shape[0] == adata_a.shape[0] + adata_b.shape[0]
+
+    assert "unique_examples" in adata_c.uns.keys()
+    assert "oh" in adata_c.uns["unique_examples"].keys()
+    assert "contrib" in adata_c.uns["unique_examples"].keys()
+
+    pd.testing.assert_frame_equal(adata_c.var, adata_a.var)
+
+    assert (
+        adata_c.uns["unique_examples"]["oh"].shape[0]
+        == adata_a.uns["unique_examples"]["oh"].shape[0] + adata_b.uns["unique_examples"]["oh"].shape[0]
+    )
+
+    a_index = [f"{orig_idx}-a" for orig_idx in adata_a.obs_names]
+    b_index = [f"{orig_idx}-b" for orig_idx in adata_b.obs_names]
+
+    np.testing.assert_array_equal(
+        adata_c.uns["unique_examples"]["oh"][adata_c.obs.loc[a_index, "example_oh_idx"]],
+        adata_a.uns["unique_examples"]["oh"][adata_a.obs["example_oh_idx"]],
+    )
+
+    np.testing.assert_array_equal(
+        adata_c.uns["unique_examples"]["oh"][adata_c.obs.loc[b_index, "example_oh_idx"]],
+        adata_b.uns["unique_examples"]["oh"][adata_b.obs["example_oh_idx"]],
+    )
+
+    np.testing.assert_array_equal(
+        adata_c.uns["unique_examples"]["contrib"][adata_c.obs.loc[a_index, "example_contrib_idx"]],
+        adata_a.uns["unique_examples"]["contrib"][adata_a.obs["example_contrib_idx"]],
+    )
+
+    np.testing.assert_array_equal(
+        adata_c.uns["unique_examples"]["contrib"][adata_c.obs.loc[b_index, "example_contrib_idx"]],
+        adata_b.uns["unique_examples"]["contrib"][adata_b.obs["example_contrib_idx"]],
+    )
+
+
+def test_merge_match_idx(sample_seqlet_adata: anndata.AnnData):
+    import tfmindi as tm
+
+    # fake data with some overlap
+    adata_a = sample_seqlet_adata[0:150].copy()
+    adata_b = sample_seqlet_adata[100:250].copy()
+
+    adata_c = tm.concat({"a": adata_a, "b": adata_b}, idx_match=True, index_unique="-")
+
+    assert adata_c.shape[0] == adata_a.shape[0] + adata_b.shape[0]
+
+    assert "unique_examples" in adata_c.uns.keys()
+    assert "oh" in adata_c.uns["unique_examples"].keys()
+    assert "contrib" in adata_c.uns["unique_examples"].keys()
+
+    pd.testing.assert_frame_equal(adata_c.var, adata_a.var)
+
+    assert adata_c.uns["unique_examples"]["oh"].shape[0] == adata_a.uns["unique_examples"]["oh"].shape[0]
+
+    a_index = [f"{orig_idx}-a" for orig_idx in adata_a.obs_names]
+    b_index = [f"{orig_idx}-b" for orig_idx in adata_b.obs_names]
+
+    np.testing.assert_array_equal(
+        adata_c.uns["unique_examples"]["oh"][adata_c.obs.loc[a_index, "example_oh_idx"]],
+        adata_a.uns["unique_examples"]["oh"][adata_a.obs["example_oh_idx"]],
+    )
+
+    np.testing.assert_array_equal(
+        adata_c.uns["unique_examples"]["oh"][adata_c.obs.loc[b_index, "example_oh_idx"]],
+        adata_b.uns["unique_examples"]["oh"][adata_b.obs["example_oh_idx"]],
+    )
+
+    np.testing.assert_array_equal(
+        adata_c.uns["unique_examples"]["contrib"][adata_c.obs.loc[a_index, "example_contrib_idx"]],
+        adata_a.uns["unique_examples"]["contrib"][adata_a.obs["example_contrib_idx"]],
+    )
+
+    np.testing.assert_array_equal(
+        adata_c.uns["unique_examples"]["contrib"][adata_c.obs.loc[b_index, "example_contrib_idx"]],
+        adata_b.uns["unique_examples"]["contrib"][adata_b.obs["example_contrib_idx"]],
+    )
+
+
+def test_idx_match_with_mismatched_data(sample_seqlet_adata):
+    """Test that ValueError is raised when idx_match=True but data differs"""
+    import tfmindi as tm
+
+    adata_a = sample_seqlet_adata.copy()
+    adata_b = sample_seqlet_adata.copy()
+    # Modify unique_examples to be different
+    adata_b.uns["unique_examples"]["oh"] = np.zeros_like(adata_b.uns["unique_examples"]["oh"])
+
+    with pytest.raises(ValueError, match="should be the same across adatas"):
+        tm.concat([adata_a, adata_b], idx_match=True)
+
+
+def test_invalid_index_unique_type():
+    """Test that non-string index_unique raises ValueError"""
+    import tfmindi as tm
+
+    with pytest.raises(ValueError, match="index_unique should be a string"):
+        tm.concat([...], index_unique=None)
+
+
+def test_original_adatas_unchanged(sample_seqlet_adata):
+    """Ensure original adatas are restored after concat with idx_match=False"""
+    import tfmindi as tm
+
+    adata_a = sample_seqlet_adata.copy()
+    adata_b = sample_seqlet_adata.copy()
+    # Store original values
+    tm.concat([adata_a, adata_b], idx_match=False)
+
+    pd.testing.assert_frame_equal(adata_a.obs, sample_seqlet_adata.obs)
+    pd.testing.assert_frame_equal(adata_b.obs, sample_seqlet_adata.obs)