diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 4822876..3bcf178 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -14,7 +14,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v2 with: - python-version: "3.9" + python-version: "3.10" - name: Install dependencies run: make install diff --git a/README.md b/README.md index 42bda7a..7dfab4f 100755 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ bktest [OPTIONS] | `--order-files-extension` | `` | `csv` | `[csv, parquet, json]` | Change the file extension to use when listing for order files. | | `--initial-cash` | `` | `100_000` | `number` | Change the initial cash to use for the backtesting. | | `--quantity-mode` | `` | `percent` | `[percent, share]` | If the mode is `share`, all quantities will be interpreted as integers. If the mode is `percent`, all values will be multiplied by the current cash value. | +| `--fixed-nav` | | `false` | | When using `percent` mode, always size positions based on `--initial-cash` instead of current equity. This disables equity increase in dollars volume: after each rebalance the NAV is reset to exactly `initial_cash` by extracting profits or injecting capital as needed. Fees are captured in each period's return as `(market_gain − fees) / initial_cash`. | | `--weekends` | | `false` | | Enable ordering on weekends. | | `--holidays` | | `false` | | Enable ordering on holidays. | | `--symbol-mapping` | `` | | `path` (.json) | Specify a custom symbol mapping file enabling vendor-id translation. | diff --git a/bktest/__version__.py b/bktest/__version__.py index c61861f..8a71708 100644 --- a/bktest/__version__.py +++ b/bktest/__version__.py @@ -1,6 +1,6 @@ __title__ = 'bktest' __description__ = 'bktest - A simple backtester by CrunchDAO' -__version__ = '2.1.0' +__version__ = '2.1.1' __author__ = 'Enzo CACERES' __author_email__ = 'enzo.caceres@crunchdao.com' __url__ = 'https://github.com/crunchdao/backtest' diff --git a/bktest/backtest.py b/bktest/backtest.py index d260eed..819f828 100644 --- a/bktest/backtest.py +++ b/bktest/backtest.py @@ -21,12 +21,14 @@ def __init__( price_provider: PriceProvider, account: Account, exporters: ExporterCollection, + fixed_nav: bool = False, ): self.quantity_in_decimal = quantity_in_decimal self.auto_close_others = auto_close_others self.price_provider = price_provider self.account = account self.exporters = exporters + self.fixed_nav = fixed_nav def order( self, @@ -49,7 +51,7 @@ def order( others = self.account.symbols if self.quantity_in_decimal: - equity = self.account.equity + equity = self.account.initial_cash if self.fixed_nav else self.account.equity for order in orders: symbol = order.symbol @@ -82,25 +84,28 @@ def order( if result.success: others.discard(symbol) else: - print(f"[warning] order not placed: {symbol} @ {percent}%", file=sys.stderr) + print(f"[warning] order not placed: {symbol} @ {quantity} shares", file=sys.stderr) else: print(f"[warning] cannot place order: {symbol} @ {quantity}x: no price available", file=sys.stderr) if self.auto_close_others: - self._close_all(others, date, results) + self._close_all(others, price_date, results) + + if self.fixed_nav and self.quantity_in_decimal: + self.account.cash = self.account.initial_cash - self.account.value return results def _close_all( self, symbols: typing.Iterable[str], - date: datetime.date, + price_date: datetime.date, results: OrderResultCollection ): closed, total = 0, 0 for symbol in symbols: - price = self.price_provider.get(date, symbol) + price = self.price_provider.get(price_date, symbol) result = self.account.close_position(symbol, price) if result.missing: @@ -124,13 +129,11 @@ def fire_snapshot( self, date: datetime.date, result: OrderResultCollection, - postponned=None ): self.exporters.fire_snapshot( date, self.account, result, - postponned ) @@ -153,6 +156,7 @@ def __init__( allow_weekends=False, allow_holidays=False, holiday_provider: HolidayProvider = LegacyHolidayProvider(), + fixed_nav: bool = False, ): self.order_provider = order_provider order_dates = order_provider.get_dates() @@ -160,13 +164,19 @@ def __init__( self.price_provider = PriceProvider(start, end, data_source, mapper, caching=caching) + def _make_exporter_collection(index): + ec = ExporterCollection(exporters_factory(index)) + ec.configure(fixed_nav=fixed_nav) + return ec + self.pods = [ _Pod( quantity_in_decimal, auto_close_others, self.price_provider, Account(initial_cash=initial_cash, fee_model=fee_model), - ExporterCollection(exporters_factory(index)) + _make_exporter_collection(index), + fixed_nav=fixed_nav, ) for index in range(n) ] @@ -197,7 +207,7 @@ def update_price(self, date): price = cache[symbol] = self.price_provider.get(date, holding.symbol) if price is None: - print(f"[warning] price not updated: {holding.symbol}: keeping last: {holding.price}", file=sys.stderr) + print(f"[warning] {date}: price not updated: {holding.symbol}: keeping last: {holding.price}", file=sys.stderr) holding.up_to_date = False else: holding.price = price @@ -219,10 +229,7 @@ def order( price_date ) - if price_date: - pod.fire_snapshot(price_date, result, postponned=date) - else: - pod.fire_snapshot(date, result) + pod.fire_snapshot(price_date or date, result) return result @@ -230,14 +237,22 @@ def run(self): self._fire_initialize() for date, ordered, skips in self.date_iterator: + self.update_price(date) + + ordered_in_skip = False for skip in skips: for pod in self.pods: pod.exporters.fire_skip(skip.date, skip.reason, skip.ordered) if skip.ordered: + ordered_in_skip = True + for pod in self.pods: + pod.fire_snapshot(date, None) self.order(skip.date, price_date=date) - self.update_price(date) + if not ordered_in_skip: + for pod in self.pods: + pod.fire_snapshot(date, None) if ordered: self.order(date) @@ -272,6 +287,7 @@ def __init__( allow_weekends=False, allow_holidays=False, holiday_provider: HolidayProvider = LegacyHolidayProvider(), + fixed_nav: bool = False, ): self.order_provider = order_provider order_dates = order_provider.get_dates() @@ -279,12 +295,15 @@ def __init__( self.price_provider = PriceProvider(start, end, data_source, mapper, caching=caching) + exporter_collection = ExporterCollection(exporters) + exporter_collection.configure(fixed_nav=fixed_nav) self.pod = _Pod( quantity_in_decimal, auto_close_others, self.price_provider, Account(initial_cash=initial_cash, fee_model=fee_model), - ExporterCollection(exporters) + exporter_collection, + fixed_nav=fixed_nav, ) self.date_iterator = DateIterator( @@ -302,7 +321,7 @@ def update_price(self, date): price = self.price_provider.get(date, holding.symbol) if price is None: - print(f"[warning] price not updated: {holding.symbol}: keeping last: {holding.price}", file=sys.stderr) + print(f"[warning] {date}: price not updated: {holding.symbol}: keeping last: {holding.price}", file=sys.stderr) holding.up_to_date = False else: holding.price = price @@ -325,15 +344,20 @@ def run(self): self.exporters.fire_initialize() for date, ordered, skips in self.date_iterator: + self.update_price(date) + + ordered_in_skip = False for skip in skips: self.exporters.fire_skip(skip.date, skip.reason, skip.ordered) if skip.ordered: + ordered_in_skip = True + self.exporters.fire_snapshot(date, self.account, None) result = self.order(skip.date, price_date=date) - self.exporters.fire_snapshot(date, self.account, result, postponned=skip.date) + self.exporters.fire_snapshot(date, self.account, result) - self.update_price(date) - self.exporters.fire_snapshot(date, self.account, None) + if not ordered_in_skip: + self.exporters.fire_snapshot(date, self.account, None) if ordered: result = self.order(date) diff --git a/bktest/cli.py b/bktest/cli.py index a3b4bb6..2b60888 100644 --- a/bktest/cli.py +++ b/bktest/cli.py @@ -41,6 +41,7 @@ @click.option('--weekends', is_flag=True, help="Include weekends?") @click.option('--holidays', is_flag=True, help="Include holidays?") @click.option('--symbol-mapping', type=str, required=False, help="Custom symbol mapping file enabling vendor-id translation.") +@click.option('--fixed-nav', is_flag=True, help="Use initial cash for position sizing instead of current equity (disables compounding).") @click.option('--no-caching', is_flag=True, help="Disable price caching.") @click.option('--fee-model', "fee_model_value", type=str, help="Specify a fee model. Must be a constant or an expression.") # @@ -117,6 +118,7 @@ def main( weekends, holidays, symbol_mapping, + fixed_nav, no_caching, fee_model_value, # @@ -145,6 +147,9 @@ def main( if auto_close_others: print("[warning] `--auto-close-others` is deprecated and is forced to `true`", file=sys.stderr) + if fixed_nav and quantity_mode != "percent": + print("[warning] `--fixed-nav` has no effect when `--quantity-mode` is not `percent`", file=sys.stderr) + now = datetime.date.today() quantity_in_decimal = quantity_mode == "percent" @@ -277,6 +282,7 @@ def main( csv_output_file=quantstats_output_file_csv, benchmark_ticker=quantstats_benchmark_ticker, auto_delete=quantstats_auto_delete, + fixed_nav=fixed_nav, )) if specific_return: @@ -336,7 +342,8 @@ def main( caching=not no_caching, allow_weekends=weekends, allow_holidays=holidays, - holiday_provider=holiday_provider + holiday_provider=holiday_provider, + fixed_nav=fixed_nav, ).run() diff --git a/bktest/export/base.py b/bktest/export/base.py index 772ed2e..4dcc18e 100644 --- a/bktest/export/base.py +++ b/bktest/export/base.py @@ -18,6 +18,9 @@ def on_snapshot(self, snapshot: Snapshot) -> None: def finalize(self) -> None: pass + def configure(self, fixed_nav: bool) -> None: + pass + class ExporterCollection: @@ -44,12 +47,15 @@ def fire_skip( for exporter in self.elements: exporter.on_skip(date, reason, ordered) + def configure(self, fixed_nav: bool) -> None: + for exporter in self.elements: + exporter.configure(fixed_nav) + def fire_snapshot( self, date: datetime.date, account: "Account", result: "OrderResultCollection", - postponned=None ): cash = float(account.cash) equity = float(account.equity) @@ -58,7 +64,6 @@ def fire_snapshot( snapshot = Snapshot( date=date, - postponned=postponned, cash=cash, equity=equity, holdings=holdings, diff --git a/bktest/export/console.py b/bktest/export/console.py index ee4bce5..a9bae1d 100644 --- a/bktest/export/console.py +++ b/bktest/export/console.py @@ -12,6 +12,24 @@ class ConsoleDelegate(Exporter): def __init__(self, file): self.file = file + # Track equity across snapshots to compute period return at rebalance time. + # _last_non_ordered_equity: equity just before orders execute (pre-reset). + # _last_ordered_equity: equity just after the previous rebalance (post-reset). + self._last_non_ordered_equity: typing.Optional[float] = None + self._last_ordered_equity: typing.Optional[float] = None + + def _period_return(self) -> typing.Optional[float]: + """Return (pre_reset / prev_post_reset - 1), or None on the first rebalance.""" + if self._last_non_ordered_equity is None or self._last_ordered_equity is None: + return None + return self._last_non_ordered_equity / self._last_ordered_equity - 1 + + def _track(self, snapshot: Snapshot) -> None: + """Update tracking state. Must be called at the end of on_snapshot.""" + if snapshot.ordered: + self._last_ordered_equity = snapshot.equity + else: + self._last_non_ordered_equity = snapshot.equity def _print(self, content): print(content, file=self.file) @@ -63,6 +81,11 @@ def on_snapshot(self, snapshot: Snapshot) -> None: line = f"{date} ({day}) {ordered_color}{ordered_string:20}{self.color_reset} [equity={equity:12.4f}]" if snapshot.ordered: + period_return = self._period_return() + if period_return is not None: + ret_color = self.color_green if period_return >= 0 else self.color_red + line += f" [return={ret_color}{period_return:+.2%}{self.color_reset}]" + holding_count = snapshot.holding_count line += f" [portfolio={holding_count:4}]" @@ -79,15 +102,13 @@ def on_snapshot(self, snapshot: Snapshot) -> None: closed_total = snapshot.closed_total line += f" [closed={closed_count}/{closed_total}]" + self._track(snapshot) self._print(line) def _ordered_to_string(self, snapshot: Snapshot): if snapshot.ordered: out = "ordered" - if snapshot.postponned is not None: - out += f" ({snapshot.postponned})" - return out return "price updated" @@ -128,22 +149,26 @@ def on_skip(self, date: datetime.date, reason: str, ordered: bool) -> None: def on_snapshot(self, snapshot: Snapshot) -> None: self._coma() + period_return = self._period_return() if snapshot.ordered else None + self._print_json({ "event": "SNAPSHOT", "date": str(snapshot.date), "ordered": snapshot.ordered, "cash": snapshot.cash, "equity": snapshot.equity, - "postponned": str(snapshot.postponned) if snapshot.postponned else None, "totalFees": snapshot.total_fees, "successCount": snapshot.success_count, "failedCount": snapshot.failed_count, + "periodReturn": period_return, "closed": { "count": snapshot.closed_count, "total": snapshot.closed_total } }) + self._track(snapshot) + @abc.abstractmethod def finalize(self) -> None: self._print("]") diff --git a/bktest/export/dump.py b/bktest/export/dump.py index 09593e2..758fc4f 100644 --- a/bktest/export/dump.py +++ b/bktest/export/dump.py @@ -62,9 +62,6 @@ def on_snapshot(self, snapshot: Snapshot) -> None: date = snapshot.date self.all_dates.add(date) - if snapshot.postponned is not None: - date = snapshot.postponned - common = [ snapshot.equity, float(snapshot.ordered), diff --git a/bktest/export/model.py b/bktest/export/model.py index 417818c..08974fb 100644 --- a/bktest/export/model.py +++ b/bktest/export/model.py @@ -7,25 +7,20 @@ class Snapshot: date: datetime.date - postponned: typing.Optional[datetime.date] cash: float equity: float holdings: typing.List["Holding"] ordered: bool - + # Only when ordered total_fees: float = 0.0 success_count: int = 0 failed_count: int = 0 - + # None if `--auto-close` is not specified closed_count: int = None closed_total: int = None - + @property def holding_count(self) -> int: return len(self.holdings) - - @property - def real_date(self) -> datetime.date: - return self.postponned if self.postponned else self.date diff --git a/bktest/export/quants.py b/bktest/export/quants.py index 3db8964..593e5d2 100644 --- a/bktest/export/quants.py +++ b/bktest/export/quants.py @@ -19,14 +19,20 @@ def __init__( csv_output_file='report.csv', benchmark_ticker="SPY", auto_delete=False, - auto_override=False + auto_override=False, + fixed_nav=False, ): self.html_output_file = html_output_file self.csv_output_file = csv_output_file self.benchmark_ticker = benchmark_ticker self.auto_delete = auto_delete self.auto_override = auto_override + self.fixed_nav = fixed_nav + # Each row: (date, pre_reset_equity, post_reset_equity). + # For non-rebalance days pre == post. For fixed_nav rebalance days + # pre_reset is the true market equity before cash extraction, and + # post_reset is initial_cash (the denominator for the next day). self.rows = [] warnings.filterwarnings( @@ -35,6 +41,9 @@ def __init__( module=seaborn.__name__ ) + def configure(self, fixed_nav: bool) -> None: + self.fixed_nav = fixed_nav + @abc.abstractmethod def initialize(self) -> None: if self.auto_override: @@ -54,24 +63,45 @@ def initialize(self) -> None: @abc.abstractmethod def on_snapshot(self, snapshot: Snapshot) -> None: + date = snapshot.date + if snapshot.ordered: + if self.fixed_nav and self.rows and self.rows[-1][0] == date: + # Rebalance day: subtract fees paid this period from pre_reset so + # that the return = (market_gain - fees) / initial_cash. + # post_reset is initial_cash (the denominator for the next period). + date_, pre_reset, _ = self.rows[-1] + self.rows[-1] = (date_, pre_reset - snapshot.total_fees, snapshot.equity) return - - date = snapshot.date - if snapshot.postponned is not None: - date = snapshot.postponned - self.rows.append( - (date, snapshot.equity) - ) + # Non-ordered snapshot: pre_reset == post_reset (no cash extraction yet). + self.rows.append((date, snapshot.equity, snapshot.equity)) @abc.abstractmethod def finalize(self) -> None: - self.dataframe = pandas.DataFrame( + df = pandas.DataFrame( self.rows, - columns=["date", "equity"] + columns=["date", "pre_reset_equity", "post_reset_equity"] ).set_index("date") + # Expose a single 'equity' column (post-reset) for backward compatibility. + df["equity"] = df["post_reset_equity"] + + # Return on any day = (pre_reset - prev_post_reset) / prev_post_reset. + # - Non-rebalance days: pre == post, so this is the standard formula. + # - Rebalance days: pre_reset captures the true market move (net of fees); + # post_reset is always initial_cash (the cash-reset formula guarantees + # equity == initial_cash after every rebalance), becoming tomorrow's + # denominator and preventing compounding. On loss weeks capital is + # injected to restore the NAV — the negative return is still recorded + # correctly. Consecutive rebalance days are handled automatically + # because shift(1) always reads post_reset of the previous row. + df["daily_profit_pct"] = ( + df["pre_reset_equity"] / df["post_reset_equity"].shift(1) - 1 + ) + + self.dataframe = df + if not len(self.dataframe): print( "[warning] cannot create tearsheet: dataframe is empty", @@ -82,14 +112,6 @@ def finalize(self) -> None: history_df = self.dataframe.copy() - history_df['profit'] = history_df['equity'] - \ - history_df['equity'].shift(1) - history_df['daily_profit_pct'] = history_df["profit"] / \ - history_df["equity"].shift(1) - - # history_df['profit'].fillna(0, inplace=True) - # history_df['daily_profit_pct'].fillna(0, inplace=True) - history_df.reset_index(inplace=True) history_df['date'] = history_df['date'].astype(str) diff --git a/requirements.txt b/requirements.txt index 6d746c6..d24fec4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,13 +5,13 @@ tqdm >=4.48.0, <4.65.0 numpy >=1.23.0, <1.25.0 py_expression_eval >=0.3.9, <0.3.14 pyarrow >=11.0, <12.0 -quantstats ==0.0.77 +quantstats ==0.0.81 pytest >=7.1.0, <7.3.0 yfinance >=0.2.54 python-dotenv >=0.20, <1.0.0 colorama >=0.4.4, <0.4.6 ipython==8.15.0 -seaborn>=0.12.0, <0.13.0 +seaborn>=0.13.0 python-slugify cached-property fpdf2==2.7.4 diff --git a/tests/test_fixed_nav.py b/tests/test_fixed_nav.py new file mode 100644 index 0000000..554f090 --- /dev/null +++ b/tests/test_fixed_nav.py @@ -0,0 +1,359 @@ +import datetime +import unittest + +import pandas + +import bktest +from bktest.account import Account +from bktest.backtest import _Pod +from bktest.data.source import DataFrameDataSource +from bktest.export import ExporterCollection +from bktest.export.quants import QuantStatsExporter +from bktest.export.model import Snapshot +from bktest.fee import ConstantFeeModel +from bktest.price_provider import PriceProvider + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_price_df(prices_by_date): + """Build a prices DataFrame from {date: {symbol: price}} dict.""" + rows = [] + for date, symbols in prices_by_date.items(): + for symbol, price in symbols.items(): + rows.append({"date": date, "symbol": symbol, "price": price}) + return pandas.DataFrame(rows) + + +def _make_pod(initial_cash, fixed_nav, quantity_in_decimal, prices_df, start, end, + fee_model=None): + """Create a _Pod backed by a DataFrameDataSource. + + Pass fee_model to test fee-related behaviour; defaults to zero fees. + """ + data_source = DataFrameDataSource(prices_df) + price_provider = PriceProvider(start, end, data_source, mapper=None, caching=False) + account = Account(initial_cash=initial_cash, fee_model=fee_model) if fee_model \ + else Account(initial_cash=initial_cash) + return _Pod( + quantity_in_decimal=quantity_in_decimal, + auto_close_others=True, + price_provider=price_provider, + account=account, + exporters=ExporterCollection([]), + fixed_nav=fixed_nav, + ) + + +def _make_exporter(fixed_nav=True): + """Create a QuantStatsExporter with no file outputs and no benchmark.""" + return QuantStatsExporter( + html_output_file=None, + csv_output_file=None, + benchmark_ticker=None, + fixed_nav=fixed_nav, + ) + + +def _feed_snapshots(exporter, *specs): + """Feed (date, equity, ordered[, total_fees]) tuples into the exporter.""" + for spec in specs: + date, equity, ordered = spec[0], spec[1], spec[2] + total_fees = spec[3] if len(spec) > 3 else 0.0 + exporter.on_snapshot(Snapshot( + date=date, cash=0.0, equity=equity, + holdings=[], ordered=ordered, total_fees=total_fees, + )) + + +D0 = datetime.date(2024, 1, 1) +D1 = datetime.date(2024, 1, 2) +D2 = datetime.date(2024, 1, 3) + + +# --------------------------------------------------------------------------- +# Engine behaviour +# --------------------------------------------------------------------------- + +class FixedNavTest(unittest.TestCase): + + def test_sizing_uses_initial_cash_not_current_equity(self): + """fixed_nav=True pins position sizing to initial_cash so that profits + do not grow future positions (no compounding). After each rebalance the + cash reset forces equity back to initial_cash.""" + prices_df = _make_price_df({D0: {"AAPL": 10.0}, D1: {"AAPL": 20.0}}) + pod = _make_pod(100_000, fixed_nav=True, quantity_in_decimal=True, + prices_df=prices_df, start=D0, end=D1) + + # First rebalance at $10: 50% of 100_000 → 5 000 shares + pod.order(D0, [bktest.Order("AAPL", 0.5)]) + self.assertEqual(5_000, pod.account.find_holding("AAPL").quantity) + self.assertAlmostEqual(100_000, pod.account.equity) + + # Price doubles → equity rises to 150 000 before the second rebalance + pod.account.find_holding("AAPL").price = 20.0 + self.assertGreater(pod.account.equity, 100_000) + + # Second rebalance at $20: still 50% of 100_000 → 2 500 shares (not 3 750) + # The excess cash is extracted; equity is reset to initial_cash. + pod.order(D1, [bktest.Order("AAPL", 0.5)]) + self.assertEqual(2_500, pod.account.find_holding("AAPL").quantity) + self.assertAlmostEqual(100_000, pod.account.equity) + + def test_compounding_uses_current_equity_when_fixed_nav_false(self): + """Without fixed_nav the default behaviour compounds: a price gain grows + the equity base and therefore grows positions at the next rebalance.""" + prices_df = _make_price_df({D0: {"AAPL": 10.0}, D1: {"AAPL": 20.0}}) + pod = _make_pod(100_000, fixed_nav=False, quantity_in_decimal=True, + prices_df=prices_df, start=D0, end=D1) + + pod.order(D0, [bktest.Order("AAPL", 0.5)]) + self.assertEqual(5_000, pod.account.find_holding("AAPL").quantity) + + pod.account.find_holding("AAPL").price = 20.0 # equity → 150 000 + + # Second rebalance: 50% of 150 000 at $20 → 3 750 shares (compounded) + pod.order(D1, [bktest.Order("AAPL", 0.5)]) + self.assertEqual(3_750, pod.account.find_holding("AAPL").quantity) + + def test_no_effect_in_share_mode(self): + """fixed_nav has no effect when quantity_in_decimal=False because the + caller provides absolute share counts, not percentages of equity.""" + prices_df = _make_price_df({D0: {"AAPL": 10.0}}) + pod_fixed = _make_pod(100_000, fixed_nav=True, quantity_in_decimal=False, + prices_df=prices_df, start=D0, end=D0) + pod_default = _make_pod(100_000, fixed_nav=False, quantity_in_decimal=False, + prices_df=prices_df, start=D0, end=D0) + + pod_fixed.order(D0, [bktest.Order("AAPL", 100, 10.0)]) + pod_default.order(D0, [bktest.Order("AAPL", 100, 10.0)]) + + self.assertEqual(100, pod_fixed.account.find_holding("AAPL").quantity) + self.assertEqual( + pod_fixed.account.find_holding("AAPL").quantity, + pod_default.account.find_holding("AAPL").quantity, + ) + + +# --------------------------------------------------------------------------- +# Long-short portfolio (primary use case) +# --------------------------------------------------------------------------- + +class LongShortFixedNavTest(unittest.TestCase): + """Validates fixed_nav with dollar-neutral long-short portfolios.""" + + def test_dollar_neutral_equity_resets_after_price_move(self): + """For a 100% long / 100% short portfolio, a gain on the long side + raises equity above initial_cash. The cash reset brings it back to + initial_cash after the next rebalance, extracting the profit.""" + prices_df = _make_price_df({ + D0: {"AAPL": 10.0, "MSFT": 10.0}, + D1: {"AAPL": 12.0, "MSFT": 10.0}, # AAPL up 20 %, MSFT flat + }) + pod = _make_pod(100_000, fixed_nav=True, quantity_in_decimal=True, + prices_df=prices_df, start=D0, end=D1) + + # Rebalance 1 at $10: 10 000 long AAPL, 10 000 short MSFT + pod.order(D0, [bktest.Order("AAPL", 1.0), bktest.Order("MSFT", -1.0)]) + self.assertEqual( 10_000, pod.account.find_holding("AAPL").quantity) + self.assertEqual(-10_000, pod.account.find_holding("MSFT").quantity) + self.assertAlmostEqual(100_000, pod.account.equity) + + # AAPL rises: net value = 20 000 → equity > initial_cash + pod.account.find_holding("AAPL").price = 12.0 + self.assertGreater(pod.account.equity, 100_000) + + # Rebalance 2: positions shrink to fit initial_cash; equity resets + pod.order(D1, [bktest.Order("AAPL", 1.0), bktest.Order("MSFT", -1.0)]) + self.assertEqual( 8_333, pod.account.find_holding("AAPL").quantity) + self.assertEqual(-10_000, pod.account.find_holding("MSFT").quantity) + self.assertAlmostEqual(100_000, pod.account.equity) + + def test_leverage_4_sizing_and_equity_reset(self): + """With leverage 4 (200% long / 200% short), sizing always references + initial_cash even when both sides move in our favour (equity > initial_cash). + Profits are extracted via the cash reset.""" + prices_df = _make_price_df({ + D0: {"AAPL": 10.0, "MSFT": 10.0}, + D1: {"AAPL": 11.0, "MSFT": 9.0}, # long gains, short gains + }) + pod = _make_pod(100_000, fixed_nav=True, quantity_in_decimal=True, + prices_df=prices_df, start=D0, end=D1) + + # Rebalance 1: 200% of 100_000 → 20 000 shares each side + pod.order(D0, [bktest.Order("AAPL", 2.0), bktest.Order("MSFT", -2.0)]) + self.assertEqual( 20_000, pod.account.find_holding("AAPL").quantity) + self.assertEqual(-20_000, pod.account.find_holding("MSFT").quantity) + self.assertAlmostEqual(100_000, pod.account.equity) + + pod.account.find_holding("AAPL").price = 11.0 + pod.account.find_holding("MSFT").price = 9.0 + self.assertGreater(pod.account.equity, 100_000) + + # Rebalance 2: sizes based on initial_cash at new prices, not inflated equity + # AAPL: int(200_000 / 11) = 18 181 MSFT: int(-200_000 / 9) = -22 222 + pod.order(D1, [bktest.Order("AAPL", 2.0), bktest.Order("MSFT", -2.0)]) + self.assertEqual( 18_181, pod.account.find_holding("AAPL").quantity) + self.assertEqual(-22_222, pod.account.find_holding("MSFT").quantity) + self.assertAlmostEqual(100_000, pod.account.equity) + + +# --------------------------------------------------------------------------- +# Fee handling +# --------------------------------------------------------------------------- + +class FeePreservationTest(unittest.TestCase): + + def test_equity_resets_to_initial_cash_regardless_of_fees(self): + """The cash reset always restores equity to initial_cash even when fees + are charged — fees are attributed to the period return, not permanently + deducted from the NAV baseline.""" + prices_df = _make_price_df({D0: {"AAPL": 10.0}, D1: {"AAPL": 20.0}}) + pod = _make_pod(100_000, fixed_nav=True, quantity_in_decimal=True, + prices_df=prices_df, start=D0, end=D1, + fee_model=ConstantFeeModel(500.0)) + + pod.order(D0, [bktest.Order("AAPL", 0.5)]) + self.assertAlmostEqual(100_000, pod.account.equity) + + pod.account.find_holding("AAPL").price = 20.0 + pod.order(D1, [bktest.Order("AAPL", 0.5)]) + self.assertAlmostEqual(100_000, pod.account.equity) + + def test_fees_reduce_period_return(self): + """Fees paid during a rebalance are subtracted from pre_reset_equity so + that period return = (market_gain − fees) / initial_cash, not market_gain + / initial_cash.""" + exporter = _make_exporter() + _feed_snapshots(exporter, (D0, 100_000.0, False)) + # Market grew to 110 000; fees = 2 000; equity reset to 100 000 + _feed_snapshots(exporter, (D1, 110_000.0, False), (D1, 100_000.0, True, 2_000.0)) + exporter.finalize() + + # (110 000 − 2 000) / 100 000 − 1 = 8 % (not 10 %) + self.assertAlmostEqual(0.08, exporter.returns.loc[pandas.Timestamp(D1)], places=5) + + +# --------------------------------------------------------------------------- +# QuantStatsExporter row format and snapshot protocol +# --------------------------------------------------------------------------- + +class QuantStatsExporterTest(unittest.TestCase): + + def test_rebalance_row_stores_pre_and_post_reset(self): + """Each row is a 3-tuple (date, pre_reset_equity, post_reset_equity). + A non-rebalance row has pre == post. On a rebalance day the non-ordered + snapshot writes pre == post first; the ordered snapshot updates post_reset + to initial_cash in place.""" + exporter = _make_exporter() + _feed_snapshots(exporter, (D0, 105_000.0, False)) + self.assertEqual((D0, 105_000.0, 105_000.0), exporter.rows[-1]) + + _feed_snapshots(exporter, (D1, 110_000.0, False)) + self.assertEqual((D1, 110_000.0, 110_000.0), exporter.rows[-1]) + + _feed_snapshots(exporter, (D1, 100_000.0, True)) + self.assertEqual((D1, 110_000.0, 100_000.0), exporter.rows[-1]) + + def test_ordered_snapshot_ignored_without_fixed_nav(self): + """With fixed_nav=False ordered snapshots are silently ignored, preserving + the original pre-order equity in the row (backward-compatible behaviour).""" + exporter = _make_exporter(fixed_nav=False) + _feed_snapshots(exporter, (D0, 150_000.0, False)) + _feed_snapshots(exporter, (D0, 100_000.0, True)) + self.assertEqual([(D0, 150_000.0, 150_000.0)], exporter.rows) + + def test_skip_day_snapshot_updates_correct_row(self): + """When a rebalance is deferred to the next trading day (skip), snapshots + use the effective trading date. The non-ordered snapshot creates the row; + the ordered snapshot finds it by date and updates post_reset.""" + effective_date = datetime.date(2024, 1, 1) + prev_date = datetime.date(2023, 12, 29) + exporter = _make_exporter() + + exporter.on_snapshot(Snapshot( + date=prev_date, cash=0.0, equity=105_000.0, + holdings=[], ordered=False, + )) + exporter.on_snapshot(Snapshot( + date=effective_date, cash=0.0, equity=112_000.0, + holdings=[], ordered=False, + )) + self.assertEqual((effective_date, 112_000.0, 112_000.0), exporter.rows[-1]) + + exporter.on_snapshot(Snapshot( + date=effective_date, cash=0.0, equity=100_000.0, + holdings=[], ordered=True, + )) + self.assertEqual((effective_date, 112_000.0, 100_000.0), exporter.rows[-1]) + + +# --------------------------------------------------------------------------- +# Return series +# --------------------------------------------------------------------------- + +class ReturnCalculationTest(unittest.TestCase): + + def test_return_uses_pre_reset_equity(self): + """On a rebalance day the return must be (pre_reset − prev_post_reset) / + prev_post_reset. Using post_reset would give 0 % every rebalance day + since post_reset == initial_cash == prev post_reset.""" + exporter = _make_exporter() + _feed_snapshots(exporter, (D0, 100_000.0, False)) + _feed_snapshots(exporter, (D1, 110_000.0, False), (D1, 100_000.0, True)) + _feed_snapshots(exporter, (D2, 101_000.0, False)) + exporter.finalize() + + ts = lambda d: pandas.Timestamp(d) + self.assertAlmostEqual(0.10, exporter.returns.loc[ts(D1)], places=5) # 10 % + self.assertAlmostEqual(0.01, exporter.returns.loc[ts(D2)], places=5) # 1 % + + def test_consecutive_rebalances_chain_post_reset_denominators(self): + """When two consecutive days are both rebalance days, the second day's + denominator must be the first day's post_reset (initial_cash), not its + pre_reset. shift(1) on post_reset_equity handles this automatically.""" + exporter = _make_exporter() + _feed_snapshots(exporter, (D0, 100_000.0, False)) + _feed_snapshots(exporter, (D1, 105_000.0, False), (D1, 100_000.0, True)) + _feed_snapshots(exporter, (D2, 107_000.0, False), (D2, 100_000.0, True)) + exporter.finalize() + + ts = lambda d: pandas.Timestamp(d) + self.assertAlmostEqual(0.05, exporter.returns.loc[ts(D1)], places=5) # 5 % + # Denominator is D1's post_reset (100k), NOT its pre_reset (105k) + self.assertAlmostEqual(0.07, exporter.returns.loc[ts(D2)], places=5) # 7 % + + def test_dataframe_exposes_pre_and_post_reset_columns(self): + """After finalize(), dataframe['equity'] is post_reset (the NAV baseline), + dataframe['pre_reset_equity'] is the true market value before extraction, + and dataframe['daily_profit_pct'] is the net-of-extraction return.""" + exporter = _make_exporter() + _feed_snapshots(exporter, (D0, 100_000.0, False)) + _feed_snapshots(exporter, (D1, 110_000.0, False), (D1, 100_000.0, True)) + exporter.finalize() + + self.assertAlmostEqual(100_000.0, exporter.dataframe.loc[D1, 'equity']) + self.assertAlmostEqual(110_000.0, exporter.dataframe.loc[D1, 'pre_reset_equity']) + self.assertAlmostEqual(0.10, exporter.dataframe.loc[D1, 'daily_profit_pct'], places=5) + + +# --------------------------------------------------------------------------- +# configure() propagation +# --------------------------------------------------------------------------- + +class ConfigurePropagationTest(unittest.TestCase): + + def test_configure_overrides_fixed_nav_on_children(self): + """ExporterCollection.configure() is authoritative: it overrides whatever + fixed_nav each child exporter was constructed with, in both directions.""" + exporter = QuantStatsExporter(html_output_file=None, csv_output_file=None, + fixed_nav=False) + collection = ExporterCollection([exporter]) + + collection.configure(fixed_nav=True) + self.assertTrue(exporter.fixed_nav) + + collection.configure(fixed_nav=False) + self.assertFalse(exporter.fixed_nav) diff --git a/tests/test_fixed_nav_integration.py b/tests/test_fixed_nav_integration.py new file mode 100644 index 0000000..0c618bc --- /dev/null +++ b/tests/test_fixed_nav_integration.py @@ -0,0 +1,251 @@ +"""End-to-end integration test for fixed_nav with leverage 4 (2x long, 2x short). + +Simulates a full SimpleBacktester run with 3 weekly rebalancings on Monday close. +Weekly return period is Tuesday through the next Monday (included). + +Verifies that: +- positions are sized from initial_cash (not drifting equity) +- equity resets to initial_cash after each rebalance +- daily returns reflect true market P&L net of fees +- old positions are auto-closed when the portfolio changes +""" +import datetime +import math +import unittest + +import pandas + +from bktest.backtest import SimpleBacktester +from bktest.data.source import DataFrameDataSource +from bktest.export.quants import QuantStatsExporter +from bktest.fee import ExpressionFeeModel +from bktest.order import DataFrameOrderProvider + +INITIAL_CASH = 100_000 +FEE_RATE = 0.001 # 10 bps + +# 4 weeks of weekdays: Mon Jan 8 through Fri Feb 2 +# Rebal on MON1, MON2, MON3. Prices through MON4 to capture week 3 returns. +MON1 = datetime.date(2024, 1, 8) +TUE1 = datetime.date(2024, 1, 9) +WED1 = datetime.date(2024, 1, 10) +THU1 = datetime.date(2024, 1, 11) +FRI1 = datetime.date(2024, 1, 12) +MON2 = datetime.date(2024, 1, 15) +TUE2 = datetime.date(2024, 1, 16) +WED2 = datetime.date(2024, 1, 17) +THU2 = datetime.date(2024, 1, 18) +FRI2 = datetime.date(2024, 1, 19) +MON3 = datetime.date(2024, 1, 22) +TUE3 = datetime.date(2024, 1, 23) +WED3 = datetime.date(2024, 1, 24) +THU3 = datetime.date(2024, 1, 25) +FRI3 = datetime.date(2024, 1, 26) +MON4 = datetime.date(2024, 1, 29) + +# Daily prices for 6 stocks +PRICES = { + MON1: {"AAPL": 50, "MSFT": 40, "GOOG": 100, "META": 80, "NVDA": 25, "AMZN": 60}, + TUE1: {"AAPL": 52, "MSFT": 38, "GOOG": 102, "META": 79, "NVDA": 26, "AMZN": 59}, + WED1: {"AAPL": 51, "MSFT": 39, "GOOG": 105, "META": 78, "NVDA": 27, "AMZN": 58}, + THU1: {"AAPL": 53, "MSFT": 37, "GOOG": 103, "META": 81, "NVDA": 28, "AMZN": 57}, + FRI1: {"AAPL": 54, "MSFT": 36, "GOOG": 108, "META": 82, "NVDA": 29, "AMZN": 56}, + MON2: {"AAPL": 55, "MSFT": 35, "GOOG": 110, "META": 83, "NVDA": 30, "AMZN": 55}, + TUE2: {"AAPL": 56, "MSFT": 34, "GOOG": 112, "META": 84, "NVDA": 31, "AMZN": 54}, + WED2: {"AAPL": 54, "MSFT": 36, "GOOG": 109, "META": 82, "NVDA": 29, "AMZN": 56}, + THU2: {"AAPL": 57, "MSFT": 33, "GOOG": 111, "META": 85, "NVDA": 32, "AMZN": 53}, + FRI2: {"AAPL": 58, "MSFT": 32, "GOOG": 113, "META": 86, "NVDA": 33, "AMZN": 52}, + MON3: {"AAPL": 59, "MSFT": 31, "GOOG": 115, "META": 87, "NVDA": 34, "AMZN": 51}, + TUE3: {"AAPL": 60, "MSFT": 30, "GOOG": 114, "META": 88, "NVDA": 35, "AMZN": 50}, + WED3: {"AAPL": 58, "MSFT": 32, "GOOG": 116, "META": 86, "NVDA": 33, "AMZN": 52}, + THU3: {"AAPL": 61, "MSFT": 29, "GOOG": 118, "META": 89, "NVDA": 36, "AMZN": 49}, + FRI3: {"AAPL": 62, "MSFT": 28, "GOOG": 120, "META": 90, "NVDA": 37, "AMZN": 48}, + MON4: {"AAPL": 63, "MSFT": 27, "GOOG": 122, "META": 91, "NVDA": 38, "AMZN": 47}, +} + +# 3 rebalancings on Monday close, each with a different portfolio +# Week 1 return: TUE1 → MON2 (rebal 2 happens at MON2 close) +# Week 2 return: TUE2 → MON3 (rebal 3 happens at MON3 close) +# Week 3 return: TUE3 → MON4 +ORDERS = [ + # Rebal 1 (MON1): AAPL +2.0, MSFT -2.0 + {"date": MON1, "symbol": "AAPL", "quantity": 2.0}, + {"date": MON1, "symbol": "MSFT", "quantity": -2.0}, + # Rebal 2 (MON2): GOOG +0.8, NVDA +0.7, META +0.5, AMZN -1.2, AAPL -0.8 + {"date": MON2, "symbol": "GOOG", "quantity": 0.8}, + {"date": MON2, "symbol": "NVDA", "quantity": 0.7}, + {"date": MON2, "symbol": "META", "quantity": 0.5}, + {"date": MON2, "symbol": "AMZN", "quantity": -1.2}, + {"date": MON2, "symbol": "AAPL", "quantity": -0.8}, + # Rebal 3 (MON3): MSFT +1.0, NVDA +0.6, META +0.4, GOOG -1.5, AMZN -0.5 + {"date": MON3, "symbol": "MSFT", "quantity": 1.0}, + {"date": MON3, "symbol": "NVDA", "quantity": 0.6}, + {"date": MON3, "symbol": "META", "quantity": 0.4}, + {"date": MON3, "symbol": "GOOG", "quantity": -1.5}, + {"date": MON3, "symbol": "AMZN", "quantity": -0.5}, +] + + +def _target(percent, price): + return int(INITIAL_CASH * percent / price) + + +def _fee(delta, price): + return abs(price * delta) * FEE_RATE + + +def _value(holdings, prices): + """Compute portfolio value: sum(shares * price) for each symbol.""" + return sum(shares * prices[symbol] for symbol, shares in holdings.items()) + + +class FixedNavIntegrationTest(unittest.TestCase): + + def setUp(self): + prices_df = pandas.DataFrame([ + {"date": d, "symbol": s, "price": float(p)} + for d, symbols in PRICES.items() + for s, p in symbols.items() + ]) + + self.exporter = QuantStatsExporter( + html_output_file=None, + csv_output_file=None, + benchmark_ticker=None, + ) + + self.bt = SimpleBacktester( + start=MON1, + end=MON4, + order_provider=DataFrameOrderProvider(pandas.DataFrame(ORDERS)), + initial_cash=INITIAL_CASH, + quantity_in_decimal=True, + data_source=DataFrameDataSource(prices_df), + auto_close_others=True, + exporters=[self.exporter], + fee_model=ExpressionFeeModel(f"abs(price * quantity) * {FEE_RATE}"), + caching=False, + fixed_nav=True, + allow_holidays=True, + ) + + self.bt.run() + + def test_full_scenario(self): + df = self.exporter.dataframe + returns = df["daily_profit_pct"] + + # ================================================================== + # Rebal 1 (MON1 close): AAPL +2.0, MSFT -2.0 + # ================================================================== + r1 = { + "AAPL": _target(2.0, 50), # int(200000/50) = 4000 + "MSFT": _target(-2.0, 40), # int(-200000/40) = -5000 + } + self.assertEqual(4000, r1["AAPL"]) + self.assertEqual(-5000, r1["MSFT"]) + + val_r1 = _value(r1, PRICES[MON1]) + cash_r1 = INITIAL_CASH - val_r1 + + # MON1: first day, NaN return + self.assertTrue(math.isnan(returns.loc[MON1])) + self.assertAlmostEqual(INITIAL_CASH, df.loc[MON1, "post_reset_equity"]) + + # Week 1 daily returns: TUE1 through MON2 (included) + prev_post = INITIAL_CASH + for date in (TUE1, WED1, THU1, FRI1): + equity = cash_r1 + _value(r1, PRICES[date]) + expected = equity / prev_post - 1 + self.assertAlmostEqual(expected, returns.loc[date], places=5, + msg=f"return mismatch on {date}") + prev_post = equity # non-rebalance: pre == post + + # MON2 closes week 1: price update gives pre_reset, then rebal fires + mon2_equity = cash_r1 + _value(r1, PRICES[MON2]) + + # ================================================================== + # Rebal 2 (MON2 close): 5 positions + # ================================================================== + r2 = { + "GOOG": _target(0.8, 110), # int(80000/110) = 727 + "NVDA": _target(0.7, 30), # int(70000/30) = 2333 + "META": _target(0.5, 83), # int(50000/83) = 602 + "AMZN": _target(-1.2, 55), # int(-120000/55) = -2181 + "AAPL": _target(-0.8, 55), # int(-80000/55) = -1454 + } + + # Fees: deltas from r1 to r2 + fees_r2 = ( + _fee(r2["AAPL"] - r1["AAPL"], 55) + # AAPL: 4000 → -1454 + _fee(0 - r1["MSFT"], 35) + # MSFT: -5000 → 0 (closed) + _fee(r2["GOOG"], 110) + # GOOG: 0 → 727 + _fee(r2["NVDA"], 30) + # NVDA: 0 → 2333 + _fee(r2["META"], 83) + # META: 0 → 602 + _fee(r2["AMZN"], 55) # AMZN: 0 → -2181 + ) + + # MON2 return includes week 1 P&L net of rebal fees + expected_mon2 = (mon2_equity - fees_r2) / prev_post - 1 + self.assertAlmostEqual(expected_mon2, returns.loc[MON2], places=5) + self.assertAlmostEqual(INITIAL_CASH, df.loc[MON2, "post_reset_equity"]) + + val_r2 = _value(r2, PRICES[MON2]) + cash_r2 = INITIAL_CASH - val_r2 + + # Week 2 daily returns: TUE2 through MON3 (included) + prev_post = INITIAL_CASH + for date in (TUE2, WED2, THU2, FRI2): + equity = cash_r2 + _value(r2, PRICES[date]) + expected = equity / prev_post - 1 + self.assertAlmostEqual(expected, returns.loc[date], places=5, + msg=f"return mismatch on {date}") + prev_post = equity + + # MON3 closes week 2: price update gives pre_reset, then rebal fires + mon3_equity = cash_r2 + _value(r2, PRICES[MON3]) + + r3 = { + "MSFT": _target(1.0, 31), # int(100000/31) = 3225 + "NVDA": _target(0.6, 34), # int(60000/34) = 1764 + "META": _target(0.4, 87), # int(40000/87) = 459 + "GOOG": _target(-1.5, 115), # int(-150000/115) = -1304 + "AMZN": _target(-0.5, 51), # int(-50000/51) = -980 + } + + # Fees: deltas from r2 to r3 + fees_r3 = ( + _fee(r3["GOOG"] - r2["GOOG"], 115) + # GOOG: 727 → -1304 + _fee(r3["NVDA"] - r2["NVDA"], 34) + # NVDA: 2333 → 1764 + _fee(r3["META"] - r2["META"], 87) + # META: 602 → 459 + _fee(r3["AMZN"] - r2["AMZN"], 51) + # AMZN: -2181 → -980 + _fee(0 - r2["AAPL"], 59) + # AAPL: -1454 → 0 (closed) + _fee(r3["MSFT"], 31) # MSFT: 0 → 3225 + ) + + # MON3 return includes week 2 P&L net of rebal fees + expected_mon3 = (mon3_equity - fees_r3) / prev_post - 1 + self.assertAlmostEqual(expected_mon3, returns.loc[MON3], places=5) + self.assertAlmostEqual(INITIAL_CASH, df.loc[MON3, "post_reset_equity"]) + + val_r3 = _value(r3, PRICES[MON3]) + cash_r3 = INITIAL_CASH - val_r3 + + # Week 3 daily returns: TUE3 through MON4 + prev_post = INITIAL_CASH + for date in (TUE3, WED3, THU3, FRI3, MON4): + equity = cash_r3 + _value(r3, PRICES[date]) + expected = equity / prev_post - 1 + self.assertAlmostEqual(expected, returns.loc[date], places=5, + msg=f"return mismatch on {date}") + prev_post = equity + + # ================================================================== + # Final state: only r3 holdings remain + # ================================================================== + holdings = {h.symbol: h.quantity for h in self.bt.account.holdings} + self.assertEqual(r3, holdings) + + +if __name__ == "__main__": + unittest.main()