Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions bktest/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ def __init__(
price_provider: PriceProvider,
account: Account,
exporters: ExporterCollection,
execution_prices: bool
):
self.quantity_in_decimal = quantity_in_decimal
self.auto_close_others = auto_close_others
self.price_provider = price_provider
self.account = account
self.exporters = exporters
self.execution_prices = execution_prices

def order(
self,
Expand Down Expand Up @@ -54,26 +56,39 @@ def order(
for order in orders:
symbol = order.symbol
percent = order.quantity
price = order.price or self.price_provider.get(price_date, symbol)

if self.execution_prices:
# Use execution price for order placement
price = order.price or self.price_provider.get_execution_price(price_date, symbol)
else:
# Use close price for order placement
price = order.price or self.price_provider.get(price_date, symbol)

holding_cash_value = equity * percent
if price is not None:
quantity = int(holding_cash_value / price)

# Create order with execution price
result = self.account.order_position(Order(symbol, quantity, price))
results.append(result)

if result.success:
others.discard(symbol)
else:
print(f"[warning] order not placed: {symbol} @ {percent}%", file=sys.stderr)
print(f"[warning] order not placed: {symbol} @ {percent}% - {date}", file=sys.stderr)
else:
print(f"[warning] cannot place order: {symbol} @ {percent}%: no price available", file=sys.stderr)
print(f"[warning] cannot place order: {symbol} @ {percent}% - {date}: no execution price available", file=sys.stderr)
else:
for order in orders:
symbol = order.symbol
quantity = order.quantity
price = order.price or self.price_provider.get(price_date, symbol)

if self.execution_prices:
# Use execution price for order placement
price = order.price or self.price_provider.get_execution_price(price_date, symbol)
else:
# Use close price for order placement
price = order.price or self.price_provider.get(price_date, symbol)

if price is not None:
result = self.account.order_position(Order(symbol, quantity, price))
Expand All @@ -82,9 +97,9 @@ def order(
if result.success:
others.discard(symbol)
else:
print(f"[warning] order not placed: {symbol} @ {percent}%", file=sys.stderr)
print(f"[warning] order not placed: {symbol} @ {percent}% - {date}", file=sys.stderr)
else:
print(f"[warning] cannot place order: {symbol} @ {quantity}x: no price available", file=sys.stderr)
print(f"[warning] cannot place order: {symbol} @ {quantity}x - {date}: no execution price available", file=sys.stderr)

if self.auto_close_others:
self._close_all(others, date, results)
Expand All @@ -111,7 +126,7 @@ def _close_all(
if result.success:
closed += 1
else:
print(f"[warning] could not auto-close: {symbol}", file=sys.stderr)
print(f"[warning] could not auto-close: {symbol} - {date}", file=sys.stderr)

total += 1

Expand Down Expand Up @@ -197,7 +212,7 @@ def update_price(self, date):
price = cache[symbol] = self.price_provider.get(date, holding.symbol)

if price is None:
print(f"[warning] price not updated: {holding.symbol}: keeping last: {holding.price}", file=sys.stderr)
print(f"[warning] price not updated: {holding.symbol} - {date}: keeping last: {holding.price}", file=sys.stderr)
holding.up_to_date = False
else:
holding.price = price
Expand Down Expand Up @@ -272,6 +287,7 @@ def __init__(
allow_weekends=False,
allow_holidays=False,
holiday_provider: HolidayProvider = LegacyHolidayProvider(),
execution_prices: bool = False
):
self.order_provider = order_provider
order_dates = order_provider.get_dates()
Expand All @@ -284,7 +300,8 @@ def __init__(
auto_close_others,
self.price_provider,
Account(initial_cash=initial_cash, fee_model=fee_model),
ExporterCollection(exporters)
ExporterCollection(exporters),
execution_prices
)

self.date_iterator = DateIterator(
Expand All @@ -302,7 +319,7 @@ def update_price(self, date):
price = self.price_provider.get(date, holding.symbol)

if price is None:
print(f"[warning] price not updated: {holding.symbol}: keeping last: {holding.price}", file=sys.stderr)
print(f"[warning] price not updated: {holding.symbol} - {date}: keeping last: {holding.price}", file=sys.stderr)
holding.up_to_date = False
else:
holding.price = price
Expand Down Expand Up @@ -332,14 +349,13 @@ def run(self):
result = self.order(skip.date, price_date=date)
self.exporters.fire_snapshot(date, self.account, result, postponned=skip.date)

self.update_price(date)
self.exporters.fire_snapshot(date, self.account, None)

if ordered:
result = self.order(date)

self.exporters.fire_snapshot(date, self.account, result)

self.update_price(date)
self.exporters.fire_snapshot(date, self.account, None)
Comment on lines +356 to +357
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo: prices need to be updated before placing orders

Maybe this is problematic when taking opening prices?


self.price_provider.save()

self.exporters.fire_finalize()
Expand Down
17 changes: 13 additions & 4 deletions bktest/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
@click.option('--file-parquet-column-date', type=str, default="date", show_default=True, help="Specify the column name containing the dates.")
@click.option('--file-parquet-column-symbol', type=str, default="symbol", show_default=True, help="Specify the column name containing the symbols.")
@click.option('--file-parquet-column-price', type=str, default="price", show_default=True, help="Specify the column name containing the prices.")
@click.option('--file-parquet-column-execution-price', type=str, default=None, show_default=True, help="Specify the column name containing the execution prices (e.g., open prices).")
#
@click.pass_context
def cli(ctx: click.Context, **kwargs):
Expand Down Expand Up @@ -138,7 +139,7 @@ def main(
#
factset: bool, factset_username_serial: str, factset_api_key: str,
#
file_parquet, file_parquet_column_date, file_parquet_column_symbol, file_parquet_column_price,
file_parquet, file_parquet_column_date, file_parquet_column_symbol, file_parquet_column_price, file_parquet_column_execution_price
):
logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)

Expand Down Expand Up @@ -203,12 +204,19 @@ def main(
)

if file_parquet:
selected_columns = [file_parquet_column_date,
file_parquet_column_symbol, file_parquet_column_price]
if file_parquet_column_execution_price is not None:
selected_columns.append(file_parquet_column_execution_price)
from .data.source import DataFrameDataSource
file_data_source = DataFrameDataSource(
dataframe=readwrite.read(file_parquet),
dataframe=readwrite.read(file_parquet, columns=selected_columns),
# dataframe=readwrite.read(file_parquet),
date_column=file_parquet_column_date,
symbol_column=file_parquet_column_symbol,
price_column=file_parquet_column_price
price_column=file_parquet_column_price,
execution_price_column=file_parquet_column_execution_price,
order_dataframe=order_provider.dataframe,
)

if data_source is not None:
Expand Down Expand Up @@ -336,7 +344,8 @@ def main(
caching=not no_caching,
allow_weekends=weekends,
allow_holidays=holidays,
holiday_provider=holiday_provider
holiday_provider=holiday_provider,
execution_prices = file_parquet_column_execution_price != None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: is not None

).run()


Expand Down
9 changes: 9 additions & 0 deletions bktest/data/source/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ def fetch_prices(
) -> pandas.DataFrame:
raise NotImplementedError()

@abc.abstractmethod
def fetch_execution_prices(
self,
symbols: typing.Set[str],
start: datetime.date,
end: datetime.date
) -> pandas.DataFrame:
raise NotImplementedError()

def is_closeable(self) -> bool:
"""
Return whether or not the markat has closing hours.
Expand Down
50 changes: 40 additions & 10 deletions bktest/data/source/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ def __init__(
date_column=constants.DEFAULT_DATE_COLUMN,
symbol_column=constants.DEFAULT_SYMBOL_COLUMN,
price_column=constants.DEFAULT_PRICE_COLUMN,
closeable=True
execution_price_column=None,
closeable=True,
order_dataframe=None,
) -> None:
super().__init__()

Expand All @@ -22,32 +24,60 @@ def __init__(
keep="first"
)

dataframe = dataframe.pivot(
# TODO: Prefiltre avant
if order_dataframe is not None:
filter_assets = set(dataframe[symbol_column].unique())
min_date = order_dataframe[date_column].min()
dataframe = dataframe[(dataframe[date_column] >= min_date) & (dataframe[symbol_column].isin(filter_assets))].copy()

self.dataframe = dataframe.pivot(
index=date_column,
columns=symbol_column,
values=price_column
)
self.dataframe.index = pandas.to_datetime(self.dataframe.index)
self.dataframe.index.name = constants.DEFAULT_DATE_COLUMN

if execution_price_column is not None:
self.execution_dataframe = dataframe.pivot(
index=date_column,
columns=symbol_column,
values=execution_price_column
)
self.execution_dataframe.index = pandas.to_datetime(self.execution_dataframe.index)
self.execution_dataframe.index.name = constants.DEFAULT_DATE_COLUMN

dataframe.index = pandas.to_datetime(dataframe.index)
dataframe.index.name = constants.DEFAULT_DATE_COLUMN
self.has_execution_prices = True
else:
# Fallback to using the same prices for execution
self.execution_dataframe = self.dataframe
self.has_execution_prices = False

self.dataframe = dataframe
self.closeable = closeable

def fetch_prices(self, symbols, start, end):
"""Fetch prices for portfolio valuation and return calculation"""
return self._fetch_from_dataframe(self.dataframe, symbols, start, end)

def fetch_execution_prices(self, symbols, start, end):
"""Fetch prices for order execution (e.g., open prices)"""
return self._fetch_from_dataframe(self.execution_dataframe, symbols, start, end)

def _fetch_from_dataframe(self, dataframe, symbols, start, end):
"""Helper method to fetch prices from a specific dataframe"""
symbols = set(symbols)

missings = symbols - set(self.dataframe.columns)
missings = symbols - set(dataframe.columns)
founds = symbols - missings

prices = None
if len(founds):
start = pandas.to_datetime(start)
end = pandas.to_datetime(end)

prices = self.dataframe[
(self.dataframe.index >= start) &
(self.dataframe.index <= end)
prices = dataframe[
(dataframe.index >= start) &
(dataframe.index <= end)
][list(founds)].copy()
else:
prices = pandas.DataFrame(
Expand All @@ -62,4 +92,4 @@ def fetch_prices(self, symbols, start, end):
return prices

def is_closeable(self):
return self.closeable
return self.closeable
25 changes: 25 additions & 0 deletions bktest/price_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,31 @@ def save(self):

self.storage.to_csv(path)

def get_execution_price(self, date: datetime.date, symbol: str):
"""Get execution price (e.g., open price) for a specific date and symbol"""
if symbol not in self.symbols:
raise ValueError(f"{symbol} not available")

symbol = self.mapper.map(symbol)

if hasattr(self.data_source, 'fetch_execution_prices'):
try:
one_day = datetime.timedelta(days=1)
execution_prices = self.data_source.fetch_execution_prices(
# symbols={symbol},
symbols=[symbol],
Comment thread
Caceresenzo marked this conversation as resolved.
start=date - one_day,
end=date + one_day
)
value = execution_prices[symbol][numpy.datetime64(date)]
if not value or numpy.isnan(value):
value = None
return value
except Exception as e:
return self.get(date, symbol)
else:
return self.get(date, symbol)

def is_closeable(self) -> bool:
return self.data_source.is_closeable()

Expand Down
Loading