Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
exclude: "^notebooks/|^notes/"
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: check-added-large-files
      - id: debug-statements
      - id: detect-private-key
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
      - id: trailing-whitespace
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.6.4
    hooks:
      # Run the linter.
      - id: ruff
        args:
          - --fix
          - --exit-non-zero-on-fix
  - repo: https://github.com/psf/black
    rev: 24.8.0
    hooks:
      - id: black
        language: python
  - repo: local
    # We do not use pre-commit/mirrors-mypy,
    # as it comes with opinionated defaults
    # (like --ignore-missing-imports)
    # and is difficult to configure to run
    # with the dependencies correctly installed.
    hooks:
      - id: mypy
        name: mypy
        entry: mypy
        language: python
        # language_version: python3.12
        additional_dependencies:
          - mypy
          - pandas-stubs
          - pydantic
          - pytest
          - loguru
        types:
          - python
        # use require_serial so that script
        # is only called once per commit
        require_serial: true
        # Print the number of files as a sanity-check
        verbose: true
        exclude: ^docs/tutorials
1 change: 0 additions & 1 deletion haferml/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from haferml.version import __version__
35 changes: 35 additions & 0 deletions haferml/blend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import json
from typing import Optional

from cloudpathlib import AnyPath

from haferml.data.wrangle.misc import (
    update_dict_recursively as _update_dict_recursively,
)


def load_config(config_path, base_folder=None):
Expand Down Expand Up @@ -201,3 +203,36 @@ def __getitem__(self, item):

def __str__(self) -> str:
return f"{self.config}"


class TrainConfig:
"""Config for training stage
"""
Copy link

Copilot AI Aug 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TrainConfig class docstring is incomplete. It should include parameter descriptions and usage examples for better API documentation.

Suggested change
"""
"""
TrainConfig manages configuration files for the training stage.
This class provides methods to initialize, read, update, and save configuration files used during training.
:param config_path: Path to the configuration file. Can be a local path or a cloud path.
:type config_path: AnyPath
Example usage:
>>> from haferml.blend.config import TrainConfig
>>> tc = TrainConfig("train_config.json")
>>> tc.init() # Initializes an empty config file
>>> tc.save({"lr": 0.001, "epochs": 10})
>>> print(tc.config)
{'lr': 0.001, 'epochs': 10}
>>> tc.update({"lr": 0.002})
>>> print(tc.config)
{'lr': 0.002, 'epochs': 10}
"""

Copilot uses AI. Check for mistakes.
def __init__(self, config_path: AnyPath):
    """Store the location of the managed JSON config file.

    :param config_path: path of the config file; AnyPath allows either a
        local or a cloud location (see cloudpathlib)
    """
    self.config_path = config_path

def init(self) -> None:
    """Create (or reset) the config file as an empty JSON object."""
    # Delegating to save() writes exactly `json.dump({}, fp, indent=4)`.
    self.save({})

@property
def config(self):
    """The current contents of the config file, parsed as JSON (a dict)."""
    # Re-read on every access so external updates to the file are visible.
    with open(self.config_path) as source:
        return json.load(source)

def save(self, config: dict) -> None:
    """Serialize *config* to the config file, overwriting any content.

    :param config: configuration mapping to persist
    """
    with open(self.config_path, mode="w") as target:
        json.dump(config, target, indent=4)

def save_as(self, path: AnyPath) -> None:
    """Write a copy of the current config to *path*.

    ``config_path`` itself is left unchanged.

    :param path: destination file for the copy
    """
    # NOTE(review): relies on a module-level `logger` — confirm loguru is
    # imported at the top of this file (the import is outside this diff).
    logger.debug(f"Saving config to {path}...")
    with open(path, "w") as target:
        json.dump(self.config, target, indent=4)

def update(self, data: dict, return_value: bool = False) -> Optional[dict]:
    """Merge *data* into the stored config and persist the result.

    Uses shallow ``dict.update`` semantics: top-level keys in *data*
    overwrite existing keys; nested dicts are replaced, not merged.

    :param data: key/value pairs to merge into the config
    :param return_value: when True, return the merged config
    :return: the merged config if *return_value* is True, otherwise None
    """
    merged = self.config
    merged.update(data)
    self.save(merged)

    # Fix: the original was annotated `-> dict` but fell off the end
    # (returning None) when return_value was False.
    return merged if return_value else None
118 changes: 118 additions & 0 deletions haferml/preprocess/dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations
import abc
import pandas as pd
from typing import List, Optional
from loguru import logger


class DFTransform(abc.ABC):
    """Abstract base class for composable DataFrame transformations.

    Subclasses implement ``__call__``; instances can be composed with
    :meth:`chain` or the ``+`` operator (which enables ``sum`` over a
    list of transformations).
    """

    @abc.abstractmethod
    def __call__(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        """Transform *dataframe* and return the resulting frame."""

    def __add__(self, other: DFTransform) -> Chain:
        """Operator form of composition: ``t1 + t2``."""
        return Chain([self, other])

    def chain(self, other: DFTransform) -> Chain:
        """Compose this transformation with *other* into a :class:`Chain`."""
        return self + other


class Chain(DFTransform):
    """Compose several transformations into a single one.

    Transformations are applied from left to right; nested ``Chain``
    instances are flattened so the stored list contains only plain
    transformations.

    :param transformations: list of `DFTransform` to be iterated over
    """

    def __init__(self, transformations: list[DFTransform]):
        flattened: List[DFTransform] = []
        for step in transformations:
            if isinstance(step, Chain):
                # Absorb members of a nested Chain directly.
                flattened.extend(step.transformations)
            elif isinstance(step, DFTransform):
                flattened.append(step)
            else:
                raise TypeError(
                    f"Expected DFTransform or Chains, got {type(step)}"
                )
        self.transformations = flattened

    def __call__(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        """Apply every transformation in order and return the result."""
        result = dataframe
        for step in self.transformations:
            result = step(result)
        return result


class Identity(DFTransform):
    """A transformation that returns the dataframe unchanged.

    This is useful when summing up a lot of transformations.

    For example, for a given list of `DFTransform`,

    ```python
    transformations = [t_1, t_2, t_3]
    ```

    we can use `sum` to concat them,

    ```python
    transform = sum(transformations, Identity())
    ```
    """

    def __init__(self):
        logger.debug("This transformation does nothing.")

    def __call__(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        # Pure pass-through: no copy is made.
        logger.debug("Returning the original dataframe")
        return dataframe


class ConvertCategoricalType(DFTransform):
    """Encode a column as integer category codes.

    The input dataframe is modified in place: ``target_column`` receives
    the ``cat.codes`` of ``column_name``, and the category labels are
    remembered on the instance for later decoding.

    :param column_name: name of the original column
    :param target_column: name of the new column (defaults to column_name)
    """

    def __init__(self, column_name: str, target_column: Optional[str] = None):
        self.column_name = column_name
        self.target_column = column_name if target_column is None else target_column

    def __call__(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        logger.debug(f"Converting {self.column_name} to categorical")
        as_category = dataframe[self.column_name].astype("category")
        # Keep the labels so the integer codes can be mapped back later.
        self.categories = as_category.cat.categories
        dataframe[self.target_column] = as_category.cat.codes
        return dataframe


class ExpandJSONValues(DFTransform):
    """Expand keys of dict/JSON values stored in one column into new columns.

    For each name ``k`` in ``column_names``, a new column named
    ``{target_column_prefix}_{k}_{json_key}`` is created holding
    ``row[json_key].get(k)`` for every row.

    :param column_names: key name(s) to extract from the JSON values; a
        single string is accepted and treated as a one-element list
    :param json_key: name of the dataframe column that holds the dict values
    :param target_column_prefix: prefix for the generated column names
    """

    def __init__(
        self,
        # Fix: annotation widened to match the explicit str handling below.
        column_names: list[str] | str,
        json_key: str,
        target_column_prefix: str = "",
    ):
        # Allow a bare string for convenience.
        if isinstance(column_names, str):
            column_names = [column_names]
        self.column_names = column_names
        self.json_key = json_key
        self.target_column_prefix = target_column_prefix

    def __call__(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        logger.debug(f"Expanding JSON values from {self.column_names}")
        # assign() returns a new frame; the apply() per key is evaluated
        # eagerly inside the comprehension, so `k` is bound correctly.
        # NOTE(review): rows whose json_key value is None will raise here —
        # presumably upstream guarantees a dict per row; confirm.
        return dataframe.assign(
            **{
                f"{self.target_column_prefix}_{k}_{self.json_key}": dataframe.apply(
                    lambda x: x[self.json_key].get(k), axis=1
                )
                for k in self.column_names
            }
        )
20 changes: 19 additions & 1 deletion haferml/preprocess/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pandas as pd
from haferml.preprocess.ingredients import OrderedProcessor, attributes
from loguru import logger
from pydantic import BaseModel
import abc


class BasePreProcessor(OrderedProcessor):
Expand Down Expand Up @@ -150,6 +152,22 @@ def run(self, datasets, **params):
return dataframe


class SimpleProcessor(abc.ABC):
    """A minimal callable interface for a preprocessing step.

    Concrete subclasses implement ``__call__`` to take a dataframe and
    return the transformed dataframe.

    :param params: a pydantic BaseModel that contains the configurations.
    """

    def __init__(self, params: BaseModel, **kwargs):
        # Keep the config model and any extra options for subclasses.
        self.kwargs = kwargs
        self.params = params

    @abc.abstractmethod
    def __call__(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        raise NotImplementedError("Please implement this method!")


if __name__ == "__main__":

from haferml.preprocess.ingredients import attributes
Expand Down Expand Up @@ -231,4 +249,4 @@ def _filter_columns_and_crossing(self, dataset):
"a": pd.DataFrame([{"names": "Tima Cook", "requirements": "I need it"}]),
"b": pd.DataFrame([{"names": "Time Cook", "requirements": None}]),
}
dp.preprocess(dataset)
dp.run(dataset)
1 change: 0 additions & 1 deletion haferml/version.py

This file was deleted.

9 changes: 2 additions & 7 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,9 @@ plugins:
- mkdocstrings:
handlers:
python:
setup_commands:
- import sys
- sys.path.append("docs")
selection:
docstring_style: "restructured-text"
members: yes
options:
docstring_style: sphinx
filters:
- "^_[^_]"
watch:
- haferml

Expand Down
Loading