diff --git a/bktest/backtest.py b/bktest/backtest.py index d260eed..93bf194 100644 --- a/bktest/backtest.py +++ b/bktest/backtest.py @@ -21,12 +21,14 @@ def __init__( price_provider: PriceProvider, account: Account, exporters: ExporterCollection, + execution_prices: bool ): self.quantity_in_decimal = quantity_in_decimal self.auto_close_others = auto_close_others self.price_provider = price_provider self.account = account self.exporters = exporters + self.execution_prices = execution_prices def order( self, @@ -54,26 +56,39 @@ def order( for order in orders: symbol = order.symbol percent = order.quantity - price = order.price or self.price_provider.get(price_date, symbol) + + if self.execution_prices: + # Use execution price for order placement + price = order.price or self.price_provider.get_execution_price(price_date, symbol) + else: + # Use close price for order placement + price = order.price or self.price_provider.get(price_date, symbol) holding_cash_value = equity * percent if price is not None: quantity = int(holding_cash_value / price) + # Create order with execution price result = self.account.order_position(Order(symbol, quantity, price)) results.append(result) if result.success: others.discard(symbol) else: - print(f"[warning] order not placed: {symbol} @ {percent}%", file=sys.stderr) + print(f"[warning] order not placed: {symbol} @ {percent}% - {date}", file=sys.stderr) else: - print(f"[warning] cannot place order: {symbol} @ {percent}%: no price available", file=sys.stderr) + print(f"[warning] cannot place order: {symbol} @ {percent}% - {date}: no execution price available", file=sys.stderr) else: for order in orders: symbol = order.symbol quantity = order.quantity - price = order.price or self.price_provider.get(price_date, symbol) + + if self.execution_prices: + # Use execution price for order placement + price = order.price or self.price_provider.get_execution_price(price_date, symbol) + else: + # Use close price for order placement + price = order.price or self.price_provider.get(price_date, symbol) if price is not None: result = self.account.order_position(Order(symbol, quantity, price)) @@ -82,9 +97,9 @@ def order( if result.success: others.discard(symbol) else: - print(f"[warning] order not placed: {symbol} @ {percent}%", file=sys.stderr) + print(f"[warning] order not placed: {symbol} @ {percent}% - {date}", file=sys.stderr) else: - print(f"[warning] cannot place order: {symbol} @ {quantity}x: no price available", file=sys.stderr) + print(f"[warning] cannot place order: {symbol} @ {quantity}x - {date}: no execution price available", file=sys.stderr) if self.auto_close_others: self._close_all(others, date, results) @@ -111,7 +126,7 @@ def _close_all( if result.success: closed += 1 else: - print(f"[warning] could not auto-close: {symbol}", file=sys.stderr) + print(f"[warning] could not auto-close: {symbol} - {date}", file=sys.stderr) total += 1 @@ -197,7 +212,7 @@ def update_price(self, date): price = cache[symbol] = self.price_provider.get(date, holding.symbol) if price is None: - print(f"[warning] price not updated: {holding.symbol}: keeping last: {holding.price}", file=sys.stderr) + print(f"[warning] price not updated: {holding.symbol} - {date}: keeping last: {holding.price}", file=sys.stderr) holding.up_to_date = False else: holding.price = price @@ -272,6 +287,7 @@ def __init__( allow_weekends=False, allow_holidays=False, holiday_provider: HolidayProvider = LegacyHolidayProvider(), + execution_prices: bool = False ): self.order_provider = order_provider order_dates = order_provider.get_dates() @@ -284,7 +300,8 @@ def __init__( auto_close_others, self.price_provider, Account(initial_cash=initial_cash, fee_model=fee_model), - ExporterCollection(exporters) + ExporterCollection(exporters), + execution_prices ) self.date_iterator = DateIterator( @@ -302,7 +319,7 @@ def update_price(self, date): price = self.price_provider.get(date, holding.symbol) if price is None: - print(f"[warning] price not updated: {holding.symbol}: keeping last: {holding.price}", file=sys.stderr) + print(f"[warning] price not updated: {holding.symbol} - {date}: keeping last: {holding.price}", file=sys.stderr) holding.up_to_date = False else: holding.price = price @@ -332,14 +349,13 @@ def run(self): result = self.order(skip.date, price_date=date) self.exporters.fire_snapshot(date, self.account, result, postponned=skip.date) - self.update_price(date) - self.exporters.fire_snapshot(date, self.account, None) - if ordered: result = self.order(date) - self.exporters.fire_snapshot(date, self.account, result) + self.update_price(date) + self.exporters.fire_snapshot(date, self.account, None) + self.price_provider.save() self.exporters.fire_finalize() diff --git a/bktest/cli.py b/bktest/cli.py index a3b4bb6..d93f764 100644 --- a/bktest/cli.py +++ b/bktest/cli.py @@ -92,6 +92,7 @@ @click.option('--file-parquet-column-date', type=str, default="date", show_default=True, help="Specify the column name containing the dates.") @click.option('--file-parquet-column-symbol', type=str, default="symbol", show_default=True, help="Specify the column name containing the symbols.") @click.option('--file-parquet-column-price', type=str, default="price", show_default=True, help="Specify the column name containing the prices.") +@click.option('--file-parquet-column-execution-price', type=str, default=None, show_default=True, help="Specify the column name containing the execution prices (e.g., open prices).") # @click.pass_context def cli(ctx: click.Context, **kwargs): @@ -138,7 +139,7 @@ def main( # factset: bool, factset_username_serial: str, factset_api_key: str, # - file_parquet, file_parquet_column_date, file_parquet_column_symbol, file_parquet_column_price, + file_parquet, file_parquet_column_date, file_parquet_column_symbol, file_parquet_column_price, file_parquet_column_execution_price ): logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR) @@ -203,12 +204,19 @@ def main( ) if file_parquet: + selected_columns = [file_parquet_column_date, + file_parquet_column_symbol, file_parquet_column_price] + if file_parquet_column_execution_price is not None: + selected_columns.append(file_parquet_column_execution_price) from .data.source import DataFrameDataSource file_data_source = DataFrameDataSource( - dataframe=readwrite.read(file_parquet), + dataframe=readwrite.read(file_parquet, columns=selected_columns), + # dataframe=readwrite.read(file_parquet), date_column=file_parquet_column_date, symbol_column=file_parquet_column_symbol, - price_column=file_parquet_column_price + price_column=file_parquet_column_price, + execution_price_column=file_parquet_column_execution_price, + order_dataframe=order_provider.dataframe, ) if data_source is not None: @@ -336,7 +344,8 @@ def main( caching=not no_caching, allow_weekends=weekends, allow_holidays=holidays, - holiday_provider=holiday_provider + holiday_provider=holiday_provider, + execution_prices = file_parquet_column_execution_price != None ).run() diff --git a/bktest/data/source/base.py b/bktest/data/source/base.py index 6b274d0..564e829 100644 --- a/bktest/data/source/base.py +++ b/bktest/data/source/base.py @@ -16,6 +16,15 @@ def fetch_prices( ) -> pandas.DataFrame: raise NotImplementedError() + @abc.abstractmethod + def fetch_execution_prices( + self, + symbols: typing.Set[str], + start: datetime.date, + end: datetime.date + ) -> pandas.DataFrame: + raise NotImplementedError() + def is_closeable(self) -> bool: """ Return whether or not the markat has closing hours. diff --git a/bktest/data/source/dataframe.py b/bktest/data/source/dataframe.py index f04bc4b..1a5ecb0 100644 --- a/bktest/data/source/dataframe.py +++ b/bktest/data/source/dataframe.py @@ -13,7 +13,9 @@ def __init__( date_column=constants.DEFAULT_DATE_COLUMN, symbol_column=constants.DEFAULT_SYMBOL_COLUMN, price_column=constants.DEFAULT_PRICE_COLUMN, - closeable=True + execution_price_column=None, + closeable=True, + order_dataframe=None, ) -> None: super().__init__() @@ -22,22 +24,50 @@ def __init__( keep="first" ) - dataframe = dataframe.pivot( + # TODO: Prefiltre avant + if order_dataframe is not None: + filter_assets = set(dataframe[symbol_column].unique()) + min_date = order_dataframe[date_column].min() + dataframe = dataframe[(dataframe[date_column] >= min_date) & (dataframe[symbol_column].isin(filter_assets))].copy() + + self.dataframe = dataframe.pivot( index=date_column, columns=symbol_column, values=price_column ) + self.dataframe.index = pandas.to_datetime(self.dataframe.index) + self.dataframe.index.name = constants.DEFAULT_DATE_COLUMN + + if execution_price_column is not None: + self.execution_dataframe = dataframe.pivot( + index=date_column, + columns=symbol_column, + values=execution_price_column + ) + self.execution_dataframe.index = pandas.to_datetime(self.execution_dataframe.index) + self.execution_dataframe.index.name = constants.DEFAULT_DATE_COLUMN - dataframe.index = pandas.to_datetime(dataframe.index) - dataframe.index.name = constants.DEFAULT_DATE_COLUMN + self.has_execution_prices = True + else: + # Fallback to using the same prices for execution + self.execution_dataframe = self.dataframe + self.has_execution_prices = False - self.dataframe = dataframe self.closeable = closeable def fetch_prices(self, symbols, start, end): + """Fetch prices for portfolio valuation and return calculation""" + return self._fetch_from_dataframe(self.dataframe, symbols, start, end) + + def fetch_execution_prices(self, symbols, start, end): + """Fetch prices for order execution (e.g., open prices)""" + return self._fetch_from_dataframe(self.execution_dataframe, symbols, start, end) + + def _fetch_from_dataframe(self, dataframe, symbols, start, end): + """Helper method to fetch prices from a specific dataframe""" symbols = set(symbols) - missings = symbols - set(self.dataframe.columns) + missings = symbols - set(dataframe.columns) founds = symbols - missings prices = None @@ -45,9 +75,9 @@ def fetch_prices(self, symbols, start, end): start = pandas.to_datetime(start) end = pandas.to_datetime(end) - prices = self.dataframe[ - (self.dataframe.index >= start) & - (self.dataframe.index <= end) + prices = dataframe[ + (dataframe.index >= start) & + (dataframe.index <= end) ][list(founds)].copy() else: prices = pandas.DataFrame( @@ -62,4 +92,4 @@ def fetch_prices(self, symbols, start, end): return prices def is_closeable(self): - return self.closeable + return self.closeable \ No newline at end of file diff --git a/bktest/price_provider.py b/bktest/price_provider.py index 0456ec2..eca5228 100644 --- a/bktest/price_provider.py +++ b/bktest/price_provider.py @@ -166,6 +166,30 @@ def save(self): self.storage.to_csv(path) + def get_execution_price(self, date: datetime.date, symbol: str): + """Get execution price (e.g., open price) for a specific date and symbol""" + if symbol not in self.symbols: + raise ValueError(f"{symbol} not available") + + symbol = self.mapper.map(symbol) + + if hasattr(self.data_source, 'fetch_execution_prices'): + try: + one_day = datetime.timedelta(days=1) + execution_prices = self.data_source.fetch_execution_prices( + symbols=[symbol], + start=date - one_day, + end=date + one_day + ) + value = execution_prices[symbol][numpy.datetime64(date)] + if not value or numpy.isnan(value): + value = None + return value + except Exception as e: + return self.get(date, symbol) + else: + return self.get(date, symbol) + def is_closeable(self) -> bool: return self.data_source.is_closeable()