Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev-requirements-py36.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dependencies:
- keras>=2.2.4
# for plugins (double check)
- fastparquet
- pyarrow
- zarr
- numcodecs
# singularity
Expand Down
1 change: 1 addition & 0 deletions dev-requirements-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies:
- keras>=2.2.4
# for plugins (double check)
- fastparquet
- pyarrow
- zarr
- numcodecs
- tinydb>=3.12.2
Expand Down
1 change: 1 addition & 0 deletions dev-requirements-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies:
- pandas=1.3.1
- pillow=8.3.1
- pip=21.2.1
- pyarrow
- pybedtools=0.8.2
- pybigwig=0.3.18
- pyfaidx=0.6.1
Expand Down
1 change: 1 addition & 0 deletions dev-requirements-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies:
- pandas
- pillow
- pip
- pyarrow
- pybedtools
- pybigwig
- pyfaidx
Expand Down
142 changes: 139 additions & 3 deletions kipoi/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,130 @@ def close(self):
pass


class ParquetDirBatchWriter(BatchWriter):
    """Write batches as a directory of parquet part-files.

    Each flush writes all buffered rows into a new
    ``part-<n>-<uuid>.parquet`` file inside ``file_path``, so the result
    can be read back as a dataset with ``pd.read_parquet(file_path)``.

    # Arguments
      file_path (str): path of the output directory
      chunk_size (int): number of buffered rows that triggers writing a new
        part file. Defaults to 1,000,000.
      nested_sep (str): separator used when flattening nested batch keys
      append (bool): if True (default), an existing directory is reused and
        new part files are added to it; the per-writer uuid in the file
        names prevents collisions with parts from earlier runs.
    """

    def __init__(
        self,
        file_path,
        chunk_size=None,
        nested_sep="/",
        append=True,
    ):
        if chunk_size is None:
            # 10^6 rows per file by default
            chunk_size = 1000000

        from pathlib import Path
        import uuid
        from datetime import datetime
        self.file_path = Path(file_path)
        # unique per-writer token so successive/concurrent writers never
        # overwrite each other's part files in the same directory
        self.uuid = datetime.now().strftime('%Y%m%d-%H%M-%S.%f') + str(uuid.uuid4())

        self.chunk_size = chunk_size
        self.nested_sep = nested_sep
        self.write_buffer = list()  # flattened pd.DataFrame's awaiting flush
        self.num_rows = 0  # number of rows currently buffered
        self.batch_num = 0  # index of the next part file

        if self.file_path.exists():
            if not append:
                raise FileExistsError(f"'{file_path}' already exists!")
            if not self.file_path.is_dir():
                raise FileExistsError(f"'{file_path}' is not a directory!")
        else:
            self.file_path.mkdir()

    def batch_write(self, batch):
        """Buffer one batch; flush to a new part file once chunk_size rows
        have accumulated.

        # Arguments
          batch: nested dict of arrays as produced by the dataloader/model
        """
        df = pd.DataFrame(flatten_batch(batch, nested_sep=self.nested_sep))
        # sort columns so every part file has an identical column order
        df.sort_index(axis=1, inplace=True)

        self.write_buffer.append(df)
        self.num_rows += df.shape[0]

        if self.num_rows >= self.chunk_size:
            self._flush()

    def _flush(self):
        """Write all buffered rows to the next part file and reset the buffer.

        Must only be called with a non-empty buffer (pd.concat of an empty
        list would raise).
        """
        df_all = pd.concat(self.write_buffer, axis=0)

        part_file = self.file_path / f"part-{self.batch_num}-{self.uuid}.parquet"
        df_all.to_parquet(part_file, index=False)

        self.write_buffer = list()
        self.num_rows = 0
        self.batch_num += 1

    def close(self):
        """Flush any remaining buffered rows."""
        if self.num_rows > 0:
            self._flush()


class ParquetFileBatchWriter(BatchWriter):
    """Incrementally write batches into a single parquet file using pyarrow.

    Buffers incoming batches and appends them to one
    ``pyarrow.parquet.ParquetWriter`` (as row groups) each time
    ``chunk_size`` rows have accumulated. The file schema is fixed by the
    first flushed chunk, so every batch must flatten to the same
    columns/dtypes.

    # Arguments
      file_path (str): path of the output parquet file; must not exist yet
      chunk_size (int): number of buffered rows that triggers a flush.
        Defaults to 10,000.
      nested_sep (str): separator used when flattening nested batch keys
    """

    def __init__(
        self,
        file_path,
        chunk_size=None,
        nested_sep="/",
    ):
        if chunk_size is None:
            chunk_size = 10000

        # optional import of pyarrow: kept local so the module still imports
        # when pyarrow is not installed
        import pyarrow as pa
        import pyarrow.parquet as pq
        self.pa = pa
        self.pq = pq

        from pathlib import Path
        import uuid
        from datetime import datetime
        self.file_path = Path(file_path)
        # NOTE(review): uuid is never used by this writer (only the
        # directory-based writer puts it into file names); kept for interface
        # symmetry with ParquetDirBatchWriter
        self.uuid = datetime.now().strftime('%Y%m%d-%H%M-%S.%f') + str(uuid.uuid4())

        self.chunk_size = chunk_size
        self.nested_sep = nested_sep
        self.write_buffer = list()  # flattened pd.DataFrame's awaiting flush
        self.num_rows = 0  # number of rows currently buffered
        self.batch_num = 0  # number of chunks flushed so far (informational)

        # created lazily on the first flush, once the schema is known
        self.pq_writer = None

        if self.file_path.exists():
            raise FileExistsError(f"'{file_path}' already exists!")

    def batch_write(self, batch):
        """Buffer one batch; flush to the parquet file once chunk_size rows
        have accumulated.

        # Arguments
          batch: nested dict of arrays as produced by the dataloader/model
        """
        df = pd.DataFrame(flatten_batch(batch, nested_sep=self.nested_sep))
        # sort columns so every chunk matches the schema of the first one
        df.sort_index(axis=1, inplace=True)

        self.write_buffer.append(df)
        self.num_rows += df.shape[0]

        if self.num_rows >= self.chunk_size:
            self._flush()

    def _flush(self):
        """Append all buffered rows to the parquet file and reset the buffer.

        Must only be called with a non-empty buffer (pd.concat of an empty
        list would raise).
        """
        df_all = pd.concat(self.write_buffer, axis=0)
        table = self.pa.Table.from_pandas(df_all, preserve_index=False)

        if self.pq_writer is None:
            # the writer is pinned to the schema of this first chunk
            self.pq_writer = self.pq.ParquetWriter(self.file_path, table.schema)

        self.pq_writer.write_table(table)

        self.write_buffer = list()
        self.num_rows = 0
        self.batch_num += 1

    def close(self):
        """Flush remaining rows and finalize the parquet file.

        Safe to call multiple times: the underlying writer is released after
        the first close so a repeated call is a no-op.
        """
        if self.num_rows > 0:
            self._flush()

        # close the parquet writer (writes the parquet footer); drop the
        # reference so a second close() does not touch a closed writer
        if self.pq_writer:
            self.pq_writer.close()
            self.pq_writer = None


class ParquetBatchWriter(BatchWriter):
"""
Args:
Expand All @@ -192,9 +316,12 @@ class ParquetBatchWriter(BatchWriter):
# Install: conda install -c conda-forge fastparquet
"""

def __init__(self,
file_path,
nested_sep="/", **kwargs):
def __init__(
self,
file_path,
nested_sep="/",
**kwargs
):
try:
import fastparquet as fp
except:
Expand Down Expand Up @@ -575,6 +702,15 @@ def get_writer(output_file, metadata_schema=None, **kwargs):
return HDF5BatchWriter(file_path=output_file, chunk_size=kwargs['hdf5_chunk_size'])
else:
return HDF5BatchWriter(file_path=output_file)
elif ending in ["parquet", "pq", "pqt"]:
if 'parquet_chunk_size' in kwargs:
chunk_size = kwargs["parquet_chunk_size"]
else:
chunk_size = None
return ParquetFileBatchWriter(
file_path=output_file,
chunk_size=chunk_size
)
else:
return None

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"wheel",
"jedi",
"epc",
"pyarrow",
"pytest>=3.3.1",
"pytest-xdist", # running tests in parallel
"pytest-pep8", # see https://github.com/kipoi/kipoi/issues/91
Expand Down
60 changes: 56 additions & 4 deletions tests/test_161_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,26 @@
import pytest
from pytest import fixture
from kipoi.metadata import GenomicRanges
from kipoi.writers import (AsyncBatchWriter, BedBatchWriter,
TsvBatchWriter, ZarrBatchWriter, get_zarr_store,
HDF5BatchWriter, BedGraphWriter, MultipleBatchWriter,
ParquetBatchWriter)
from kipoi.writers import (
AsyncBatchWriter,
BedBatchWriter,
TsvBatchWriter,
ZarrBatchWriter,
get_zarr_store,
HDF5BatchWriter,
BedGraphWriter,
MultipleBatchWriter,
ParquetBatchWriter,
ParquetFileBatchWriter,
ParquetDirBatchWriter,
)
from kipoi.readers import HDF5Reader, ZarrReader
from kipoi.cli.main import prepare_batch
import numpy as np
import pandas as pd
from kipoi.specs import DataLoaderSchema, ArraySchema, MetadataStruct, MetadataType
from collections import OrderedDict

from kipoi_utils.utils import get_subsuffix
import zarr

Expand Down Expand Up @@ -112,6 +122,48 @@ def test_TsvBatchWriter_array(dl_batch, pred_batch_array, tmpdir):
assert list(df['metadata/ranges/id']) == [0, 1, 2, 0, 1, 2]


def test_ParquetFileBatchWriter_array(dl_batch, pred_batch_array, tmpdir):
    """Round-trip two identical batches through a single parquet file."""
    out_path = str(tmpdir.mkdir("example").join("out.parquet"))
    batch = prepare_batch(dl_batch, pred_batch_array, keep_metadata=True)

    writer = ParquetFileBatchWriter(out_path)
    for _ in range(2):
        writer.batch_write(batch)
    writer.close()

    df = pd.read_parquet(out_path)

    expected_columns = {
        'metadata/ranges/id',
        'metadata/ranges/strand',
        'metadata/ranges/chr',
        'metadata/ranges/start',
        'metadata/ranges/end',
        'metadata/gene_id',
        'preds/0',
        'preds/1',
        'preds/2',
    }
    assert set(df.columns) == expected_columns
    assert list(df['metadata/ranges/id']) == ['0', '1', '2'] * 2


def test_ParquetDirBatchWriter_array(dl_batch, pred_batch_array, tmpdir):
    """Round-trip two identical batches through a parquet part-file directory."""
    out_path = str(tmpdir.mkdir("example").join("out.parquet"))
    batch = prepare_batch(dl_batch, pred_batch_array, keep_metadata=True)

    writer = ParquetDirBatchWriter(out_path)
    for _ in range(2):
        writer.batch_write(batch)
    writer.close()

    df = pd.read_parquet(out_path)

    expected_columns = {
        'metadata/ranges/id',
        'metadata/ranges/strand',
        'metadata/ranges/chr',
        'metadata/ranges/start',
        'metadata/ranges/end',
        'metadata/gene_id',
        'preds/0',
        'preds/1',
        'preds/2',
    }
    assert set(df.columns) == expected_columns
    assert list(df['metadata/ranges/id']) == ['0', '1', '2'] * 2


# For no good reason this test fails when installing
# from conda even tough this work very fine locally
@pytest.mark.skipif("os.environ.get('CI_JOB_PY_YAML') is not None")
Expand Down