From 2f29eb4d46242fcfd8bc97aa705919e81ea6ae40 Mon Sep 17 00:00:00 2001 From: Nikita Kulikov Date: Mon, 2 Feb 2026 07:11:40 +0300 Subject: [PATCH 1/5] Gitignore updated to recent examples --- examples/09_sasrec_example_freq_sampled.ipynb | 4768 +++++++++++++++++ examples/09_sasrec_example_sampled.ipynb | 4744 ++++++++++++++++ examples/09_sasrec_example_thr_sampled.ipynb | 4768 +++++++++++++++++ replay/nn/transform/__init__.py | 10 + replay/nn/transform/negative_sampling.py | 224 +- 5 files changed, 14500 insertions(+), 14 deletions(-) create mode 100644 examples/09_sasrec_example_freq_sampled.ipynb create mode 100644 examples/09_sasrec_example_sampled.ipynb create mode 100644 examples/09_sasrec_example_thr_sampled.ipynb diff --git a/examples/09_sasrec_example_freq_sampled.ipynb b/examples/09_sasrec_example_freq_sampled.ipynb new file mode 100644 index 000000000..4988e2512 --- /dev/null +++ b/examples/09_sasrec_example_freq_sampled.ipynb @@ -0,0 +1,4768 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Example of SasRec training/inference with Parquet Module" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 42\n" + ] + } + ], + "source": [ + "from typing import Optional\n", + "\n", + "import lightning as L\n", + "import pandas as pd\n", + "\n", + "L.seed_everything(42)\n", + "\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Preparing data\n", + "In this example, we will be using the MovieLens dataset, namely the 1m subset. It's demonstrated a simple case, so only item ids will be used as model input.\n", + "\n", + "---\n", + "**NOTE**\n", + "\n", + "Current implementation of SasRec is able to handle item and interactions features. It does not take into account user features. 
\n", + "\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "interactions = pd.read_csv(\"./data/ml1m_ratings.dat\", sep=\"\\t\", names=[\"user_id\", \"item_id\",\"rating\",\"timestamp\"])\n", + "interactions = interactions.drop(columns=[\"rating\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idtimestamp
100013860408580
1000153604023841
99987360405932
1000007604019613
1000192604020194
............
82579349582399446
82543849581407447
82572449583264448
82573149582634449
82560349581924450
\n", + "

1000209 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " user_id item_id timestamp\n", + "1000138 6040 858 0\n", + "1000153 6040 2384 1\n", + "999873 6040 593 2\n", + "1000007 6040 1961 3\n", + "1000192 6040 2019 4\n", + "... ... ... ...\n", + "825793 4958 2399 446\n", + "825438 4958 1407 447\n", + "825724 4958 3264 448\n", + "825731 4958 2634 449\n", + "825603 4958 1924 450\n", + "\n", + "[1000209 rows x 3 columns]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions[\"timestamp\"] = interactions[\"timestamp\"].astype(\"int64\")\n", + "interactions = interactions.sort_values(by=\"timestamp\")\n", + "interactions[\"timestamp\"] = interactions.groupby(\"user_id\").cumcount()\n", + "interactions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Encode catagorical data.\n", + "To ensure all categorical data is fit for training, it needs to be encoded using the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorcial column in the dataset that will be used in model. Note that ids of users and ids of items are always used." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
timestampuser_iditem_id
01200
16810
26720
31230
414040
............
10002041445553705
10002059028133705
10002067024043705
10002072558353705
10002083809793705
\n", + "

1000209 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " timestamp user_id item_id\n", + "0 12 0 0\n", + "1 68 1 0\n", + "2 67 2 0\n", + "3 12 3 0\n", + "4 140 4 0\n", + "... ... ... ...\n", + "1000204 14 4555 3705\n", + "1000205 90 2813 3705\n", + "1000206 70 2404 3705\n", + "1000207 25 5835 3705\n", + "1000208 380 979 3705\n", + "\n", + "[1000209 rows x 3 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule\n", + "\n", + "encoder = LabelEncoder(\n", + " [\n", + " LabelEncodingRule(\"user_id\", default_value=\"last\"),\n", + " LabelEncodingRule(\"item_id\", default_value=\"last\"),\n", + " ]\n", + ")\n", + "interactions = interactions.sort_values(by=\"item_id\", ascending=True)\n", + "encoded_interactions = encoder.fit_transform(interactions)\n", + "encoded_interactions" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Split interactions into the train, validation and test datasets using LastNSplitter\n", + "We use widespread splitting strategy Last-One-Out. We filter out cold items and users for simplicity." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.splitters import LastNSplitter\n", + "\n", + "splitter = LastNSplitter(\n", + " N=1,\n", + " divide_column=\"user_id\",\n", + " query_column=\"user_id\",\n", + " strategy=\"interactions\",\n", + " drop_cold_users=True,\n", + " drop_cold_items=True\n", + ")\n", + "\n", + "test_events, test_gt = splitter.split(encoded_interactions)\n", + "validation_events, validation_gt = splitter.split(test_events)\n", + "train_events = validation_events" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Dataset preprocessing (\"baking\")\n", + "SasRec expects each user in the batch to provide their events in form of a sequence. 
For this reason, the event splits must be properly processed using the `groupby_sequences` function provided by RePlay." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.data.nn.utils import groupby_sequences\n", + "\n", + "\n", + "def bake_data(full_data):\n", + " grouped_interactions = groupby_sequences(events=full_data, groupby_col=\"user_id\", sort_col=\"timestamp\")\n", + " return grouped_interactions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idtimestampitem_id
00[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2426, 822, 2733, 2587, 2937, 3618, 2943, 708,...
11[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[3272, 3026, 2760, 851, 346, 3393, 1107, 515, ...
22[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[579, 1140, 1154, 2426, 1524, 1260, 2160, 2621...
33[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1781, 2940, 2468, 890, 948, 106, 593, 309, 49...
44[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1108, 2229, 21, 2435, 2142, 106, 1167, 593, 1...
............
60356035[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2426, 1279, 3151, 3321, 1178, 3301, 2501, 278...
60366036[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1592, 2302, 1633, 1813, 2879, 1482, 2651, 250...
60376037[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1971, 3500, 2077, 1666, 1399, 2651, 2748, 283...
60386038[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1486, 1485, 3384, 3512, 3302, 3126, 3650, 330...
60396039[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2432, 2960, 1848, 2114, 2142, 3091, 3248, 317...
\n", + "

6040 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " user_id timestamp \\\n", + "0 0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "1 1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "2 2 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "3 3 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "4 4 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "... ... ... \n", + "6035 6035 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "6036 6036 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "6037 6037 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "6038 6038 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "6039 6039 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "\n", + " item_id \n", + "0 [2426, 822, 2733, 2587, 2937, 3618, 2943, 708,... \n", + "1 [3272, 3026, 2760, 851, 346, 3393, 1107, 515, ... \n", + "2 [579, 1140, 1154, 2426, 1524, 1260, 2160, 2621... \n", + "3 [1781, 2940, 2468, 890, 948, 106, 593, 309, 49... \n", + "4 [1108, 2229, 21, 2435, 2142, 106, 1167, 593, 1... \n", + "... ... \n", + "6035 [2426, 1279, 3151, 3321, 1178, 3301, 2501, 278... \n", + "6036 [1592, 2302, 1633, 1813, 2879, 1482, 2651, 250... \n", + "6037 [1971, 3500, 2077, 1666, 1399, 2651, 2748, 283... \n", + "6038 [1486, 1485, 3384, 3512, 3302, 3126, 3650, 330... \n", + "6039 [2432, 2960, 1848, 2114, 2142, 3091, 3248, 317... 
\n", + "\n", + "[6040 rows x 3 columns]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_events = bake_data(train_events)\n", + "\n", + "validation_events = bake_data(validation_events)\n", + "validation_gt = bake_data(validation_gt)\n", + "\n", + "test_events = bake_data(test_events)\n", + "test_gt = bake_data(test_gt)\n", + "\n", + "train_events" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To ensure we don't have unknown users in ground truth, we join validation events and validation ground truth (also join test events and test ground truth correspondingly) by user ids to leave only the common ones. " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def add_gt_to_events(events_df, gt_df):\n", + " gt_to_join = gt_df[[\"user_id\", \"item_id\"]].rename(columns={\"item_id\": \"ground_truth\"})\n", + "\n", + " events_df = events_df.merge(gt_to_join, on=\"user_id\", how=\"inner\")\n", + " return events_df\n", + "\n", + "validation_events = add_gt_to_events(validation_events, validation_gt)\n", + "test_events = add_gt_to_events(test_events, test_gt)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "data_dir = Path(\"temp/data/\")\n", + "data_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "TRAIN_PATH = data_dir / \"train.parquet\"\n", + "VAL_PATH = data_dir / \"val.parquet\"\n", + "PREDICT_PATH = data_dir / \"test.parquet\"\n", + "\n", + "ENCODER_PATH = data_dir / \"encoder\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "train_events.to_parquet(TRAIN_PATH)\n", + "validation_events.to_parquet(VAL_PATH)\n", + "test_events.to_parquet(PREDICT_PATH)\n", + "\n", + "encoder.save(ENCODER_PATH)" + ] + }, + { + "attachments": {}, + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "# Prepare to model training\n", + "### Create the tensor schema\n", + "A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model. It is required by the SasRec model to correctly create embeddings for every source column. Note that user_id does not required in `TensorSchema`.\n", + "\n", + "Note that the **padding value** is the next value (item_id) after the last one. **Cardinality** is the number of unique values ​​given the padding value." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.data import FeatureHint, FeatureType\n", + "from replay.data.nn import TensorFeatureInfo, TensorSchema\n", + "\n", + "\n", + "EMBEDDING_DIM = 64\n", + "\n", + "encoder = encoder.load(ENCODER_PATH)\n", + "NUM_UNIQUE_ITEMS = len(encoder.mapping[\"item_id\"])\n", + "\n", + "tensor_schema = TensorSchema(\n", + " [\n", + " TensorFeatureInfo(\n", + " name=\"item_id\",\n", + " is_seq=True,\n", + " padding_value=NUM_UNIQUE_ITEMS,\n", + " cardinality=NUM_UNIQUE_ITEMS + 1, # taking into account padding\n", + " embedding_dim=EMBEDDING_DIM,\n", + " feature_type=FeatureType.CATEGORICAL,\n", + " feature_hint=FeatureHint.ITEM_ID,\n", + " )\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configure ParquetModule and transformation pipelines\n", + "\n", + "The `ParquetModule` class enables training of models on large datasets by reading data in batch-wise way. This class initialized with **paths to every data split, a metadata dict containing information about shape and padding value of every column and a dict of transforms**. `ParquetModule`'s \"transform pipelines\" are stage-specific modules implementing additional preprocessing to be performed on batch level right before the forward pass. 
\n", + "\n", + "For SasRec model, RePlay provides a function that generates a sequence of appropriate transforms for each data split named **make_default_sasrec_transforms**.\n", + "\n", + "Internally this function creates the following transforms:\n", + "1) Training:\n", + " 1. Create a target, which contains the shifted item sequence that represents the next item in the sequence (for the next item prediction task).\n", + " 2. Rename features to match it with expected format by the model during training.\n", + " 3. Unsqueeze target (*positive_labels*) and it's padding mask (*target_padding_mask*) for getting required shape of this tensors for loss computation.\n", + " 4. Group input features to be embed in expected format.\n", + "\n", + "2) Validation/Inference:\n", + " 1. Rename/group features to match it with expected format by the model during valdiation/inference.\n", + "\n", + "If a different set of transforms is required, you can create them yourself and submit them to the ParquetModule in the form of a dictionary where the key is the name of the split, and the value is the list of transforms. Available transforms are in the replay/nn/transforms/.\n", + "\n", + "**Note:** One of the transforms for the training data prepares the initial sequence for the task of Next Item Prediction so it shifts the sequence of items. For the final sequence length to be correct, you need to set shape of item_id in metadata as **model sequence length + shift**. Default shift value is 1." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "\n", + "import torch\n", + "\n", + "from replay.data.nn import TensorSchema\n", + "from replay.nn.transform import GroupTransform, NextTokenTransform, RenameTransform, UnsqueezeTransform, FrequencyNegativeSamplingTransform" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def make_sasrec_transforms(\n", + " tensor_schema: TensorSchema, query_column: str = \"query_id\", num_negative_samples: int = 128,\n", + ") -> dict[str, list[torch.nn.Module]]:\n", + " item_column = tensor_schema.item_id_feature_name\n", + " vocab_size = tensor_schema[item_column].cardinality\n", + " train_transforms = [\n", + " FrequencyNegativeSamplingTransform(vocab_size, num_negative_samples),\n", + " NextTokenTransform(label_field=item_column, query_features=query_column, shift=1),\n", + " RenameTransform(\n", + " {\n", + " query_column: \"query_id\",\n", + " f\"{item_column}_mask\": \"padding_mask\",\n", + " \"positive_labels_mask\": \"target_padding_mask\",\n", + " }\n", + " ),\n", + " UnsqueezeTransform(\"target_padding_mask\", -1),\n", + " UnsqueezeTransform(\"positive_labels\", -1),\n", + " GroupTransform({\"feature_tensors\": [item_column]}),\n", + " ]\n", + "\n", + " val_transforms = [\n", + " RenameTransform({query_column: \"query_id\", f\"{item_column}_mask\": \"padding_mask\"}),\n", + " GroupTransform({\"feature_tensors\": [item_column]}),\n", + " ]\n", + " test_transforms = copy.deepcopy(val_transforms)\n", + "\n", + " predict_transforms = copy.deepcopy(val_transforms)\n", + "\n", + " transforms = {\n", + " \"train\": train_transforms,\n", + " \"validate\": val_transforms,\n", + " \"test\": test_transforms,\n", + " \"predict\": predict_transforms,\n", + " }\n", + "\n", + " return transforms\n", + "\n", + "transforms = make_sasrec_transforms(tensor_schema, query_column=\"user_id\")" + ] + 
}, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "MAX_SEQ_LEN = 50\n", + "\n", + "def create_meta(shape: int, gt_shape: Optional[int] = None):\n", + " meta = {\n", + " \"user_id\": {},\n", + " \"item_id\": {\"shape\": shape, \"padding\": tensor_schema[\"item_id\"].padding_value},\n", + " }\n", + " if gt_shape is not None:\n", + " meta.update({\"ground_truth\": {\"shape\": gt_shape, \"padding\": -1}})\n", + "\n", + " return meta\n", + "\n", + "train_metadata = {\n", + " \"train\": create_meta(shape=MAX_SEQ_LEN+1),\n", + " \"validate\": create_meta(shape=MAX_SEQ_LEN, gt_shape=1),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.data.nn import ParquetModule\n", + "\n", + "BATCH_SIZE = 32\n", + "\n", + "parquet_module = ParquetModule(\n", + " train_path=TRAIN_PATH,\n", + " validate_path=VAL_PATH,\n", + " batch_size=BATCH_SIZE,\n", + " metadata=train_metadata,\n", + " transforms=transforms,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train model\n", + "### Create SasRec model instance and run the training stage using lightning\n", + "We may now train the model using the Lightning trainer class. \n", + "\n", + "RePlay's implementation of SasRec is designed in a modular, **block-based approach**. Instead of passing configuration parameters to the constructor, SasRec is now built by providing fully initialized components that makes the model more flexible and easier to extend.\n", + "\n", + "#### Default Configuration\n", + "\n", + "Default SasRec model may be created quickly via method **from_params**. Default model instance has CE loss, original SasRec transformer layes, and embeddings are aggregated via sum." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.nn.sequential import SasRec\n", + "from typing import Literal\n", + "def make_sasrec(\n", + " schema: TensorSchema,\n", + " embedding_dim: int = 192,\n", + " num_heads: int = 4,\n", + " num_blocks: int = 2,\n", + " max_sequence_length: int = 50,\n", + " dropout: float = 0.3,\n", + " excluded_features: Optional[list[str]] = None,\n", + " categorical_list_feature_aggregation_method: Literal[\"sum\", \"mean\", \"max\"] = \"sum\",\n", + ") -> SasRec:\n", + " from replay.nn.sequential.sasrec import SasRecBody, SasRecTransformerLayer\n", + " from replay.nn.agg import SumAggregator\n", + " from replay.nn.embedding import SequenceEmbedding\n", + " from replay.nn.loss import CE, CESampled\n", + " from replay.nn.mask import DefaultAttentionMask\n", + " from replay.nn.sequential.sasrec.agg import PositionAwareAggregator\n", + " from replay.nn.sequential.sasrec.transformer import SasRecTransformerLayer\n", + " excluded_features = [\n", + " schema.query_id_feature_name,\n", + " schema.timestamp_feature_name,\n", + " *(excluded_features or []),\n", + " ]\n", + " excluded_features = list(set(excluded_features))\n", + " body = SasRecBody(\n", + " embedder=SequenceEmbedding(\n", + " schema=schema,\n", + " categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,\n", + " excluded_features=excluded_features,\n", + " ),\n", + " embedding_aggregator=PositionAwareAggregator(\n", + " embedding_aggregator=SumAggregator(embedding_dim=embedding_dim),\n", + " max_sequence_length=max_sequence_length,\n", + " dropout=dropout,\n", + " ),\n", + " attn_mask_builder=DefaultAttentionMask(\n", + " reference_feature_name=schema.item_id_feature_name,\n", + " num_heads=num_heads,\n", + " ),\n", + " encoder=SasRecTransformerLayer(\n", + " embedding_dim=embedding_dim,\n", + " num_heads=num_heads,\n", + " num_blocks=num_blocks,\n", + " 
dropout=dropout,\n", + " activation=\"relu\",\n", + " ),\n", + " output_normalization=torch.nn.LayerNorm(embedding_dim),\n", + " )\n", + " padding_idx = schema.item_id_features.item().padding_value\n", + " return SasRec(\n", + " body=body,\n", + " loss=CESampled(padding_idx=padding_idx),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "NUM_BLOCKS = 2\n", + "NUM_HEADS = 2\n", + "DROPOUT = 0.3\n", + "\n", + "sasrec = make_sasrec(\n", + " schema=tensor_schema,\n", + " embedding_dim=EMBEDDING_DIM,\n", + " max_sequence_length=MAX_SEQ_LEN,\n", + " num_heads=NUM_HEADS,\n", + " num_blocks=NUM_BLOCKS,\n", + " dropout=DROPOUT,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A universal PyTorch Lightning module is provided. It can work with any NN model." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.nn.lightning.optimizer import OptimizerFactory\n", + "from replay.nn.lightning.scheduler import LRSchedulerFactory\n", + "from replay.nn.lightning import LightningModule\n", + "\n", + "model = LightningModule(\n", + " sasrec,\n", + " optimizer_factory=OptimizerFactory(),\n", + " lr_scheduler_factory=LRSchedulerFactory(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To facilitate training, we add the following callbacks:\n", + "1) `ModelCheckpoint` - to save the best trained model based on its Recall metric. It's a default Lightning Callback.\n", + "1) `ComputeMetricsCallback` - to display a detailed validation metric matrix after each epoch. 
It's a custom RePlay callback for computing recsys metrics on validation and test stages.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params | Mode | FLOPs\n", + "-------------------------------------------------\n", + "0 | model | SasRec | 291 K | train | 0 \n", + "-------------------------------------------------\n", + "291 K Trainable params\n", + "0 Non-trainable params\n", + "291 K Total params\n", + "1.164 Total estimated model params size (MB)\n", + "39 Modules in train mode\n", + "0 Modules in eval mode\n", + "0 Total Flops\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "45904c0e26294a109f736787e2831727", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.figure(figsize = (5, 4), dpi = 120)\n", + "plt.hist(transforms[\"train\"][0].frequencies.cpu().numpy(), bins = 50)\n", + "plt.grid()\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can get the best model path stored in the checkpoint callback." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=96-step=18333.ckpt'" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "best_model_path = checkpoint_callback.best_model_path\n", + "best_model_path" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Inference\n", + "\n", + "To obtain model scores, we will load the weights from the best checkpoint. To do this, we use the `LightningModule`, provide there the path to the checkpoint and the model instance." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "import replay\n", + "torch.serialization.add_safe_globals([\n", + " replay.nn.lightning.optimizer.OptimizerFactory,\n", + " replay.nn.lightning.scheduler.LRSchedulerFactory\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "sasrec = make_sasrec(\n", + " schema=tensor_schema,\n", + " embedding_dim=EMBEDDING_DIM,\n", + " max_sequence_length=MAX_SEQ_LEN,\n", + " num_heads=NUM_HEADS,\n", + " num_blocks=NUM_BLOCKS,\n", + " dropout=DROPOUT,\n", + ")\n", + "\n", + "best_model = LightningModule.load_from_checkpoint(best_model_path, model=sasrec)\n", + "best_model.eval();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Configure `ParquetModule` for inference" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "inference_metadata = {\"predict\": create_meta(shape=MAX_SEQ_LEN)}\n", + "\n", + "parquet_module = ParquetModule(\n", + " predict_path=PREDICT_PATH,\n", + " batch_size=BATCH_SIZE,\n", + " metadata=inference_metadata,\n", + " transforms=transforms,\n", + ")" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "During inference, we can use `TopItemsCallback`. Such callback allows you to get scores for each user throughout the entire catalog and get recommendations in the form of ids of items with the highest score values.\n", + "\n", + "\n", + "Recommendations can be fetched in four formats: PySpark DataFrame, Pandas DataFrame, Polars DataFrame or raw PyTorch tensors. Each of the types corresponds a callback. In this example, we'll be using the `PandasTopItemsCallback`." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9b94b2f1745140e5ae3bc1d63bdbbe4e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Predicting: | | 0/? [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscore
004866.457348
002246.311622
002106.299765
001016.193144
0021426.179489
............
603760398410.029899
6037603926349.860862
6037603914909.738086
6037603926339.715384
6037603924979.684793
\n", + "

120760 rows × 3 columns

\n", + "" + ], + "text/plain": [ + " user_id item_id score\n", + "0 0 486 6.457348\n", + "0 0 224 6.311622\n", + "0 0 210 6.299765\n", + "0 0 101 6.193144\n", + "0 0 2142 6.179489\n", + "... ... ... ...\n", + "6037 6039 84 10.029899\n", + "6037 6039 2634 9.860862\n", + "6037 6039 1490 9.738086\n", + "6037 6039 2633 9.715384\n", + "6037 6039 2497 9.684793\n", + "\n", + "[120760 rows x 3 columns]" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_res" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Calculating metrics\n", + "\n", + "*test_gt* is already encoded, so we can use it for computing metrics." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.metrics import MAP, OfflineMetrics, Precision, Recall\n", + "from replay.metrics.torch_metrics_builder import metrics_to_df" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "result_metrics = OfflineMetrics(\n", + " [Recall(TOPK), Precision(TOPK), MAP(TOPK)],\n", + " query_column=\"user_id\",\n", + " rating_column=\"score\",\n", + ")(pandas_res, test_gt.explode(\"item_id\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
k110205
MAP0.0203710.0611890.0685750.051101
Precision0.0203710.0187310.0147900.022093
Recall0.0203710.1873140.2957930.110467
\n", + "
" + ], + "text/plain": [ + "k 1 10 20 5\n", + "MAP 0.020371 0.061189 0.068575 0.051101\n", + "Precision 0.020371 0.018731 0.014790 0.022093\n", + "Recall 0.020371 0.187314 0.295793 0.110467" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metrics_to_df(result_metrics)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's call the `inverse_transform` encoder's function to get the final dataframe with recommendations" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscore
020125006.457348
020122316.311622
020122166.299765
020121046.193144
0201223356.179489
............
603757278610.029899
6037572728419.860862
6037572716239.738086
6037572728409.715384
6037572727029.684793
\n", + "

120760 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " user_id item_id score\n", + "0 2012 500 6.457348\n", + "0 2012 231 6.311622\n", + "0 2012 216 6.299765\n", + "0 2012 104 6.193144\n", + "0 2012 2335 6.179489\n", + "... ... ... ...\n", + "6037 5727 86 10.029899\n", + "6037 5727 2841 9.860862\n", + "6037 5727 1623 9.738086\n", + "6037 5727 2840 9.715384\n", + "6037 5727 2702 9.684793\n", + "\n", + "[120760 rows x 3 columns]" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "encoder.inverse_transform(pandas_res)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "new_venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/09_sasrec_example_sampled.ipynb b/examples/09_sasrec_example_sampled.ipynb new file mode 100644 index 000000000..6672edd85 --- /dev/null +++ b/examples/09_sasrec_example_sampled.ipynb @@ -0,0 +1,4744 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Example of SasRec training/inference with Parquet Module" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 42\n" + ] + } + ], + "source": [ + "from typing import Optional\n", + "\n", + "import lightning as L\n", + "import pandas as pd\n", + "\n", + "L.seed_everything(42)\n", + "\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + 
"source": [ + "## Preparing data\n", + "In this example, we will be using the MovieLens dataset, namely the 1m subset. It's demonstrated a simple case, so only item ids will be used as model input.\n", + "\n", + "---\n", + "**NOTE**\n", + "\n", + "Current implementation of SasRec is able to handle item and interactions features. It does not take into account user features. \n", + "\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "interactions = pd.read_csv(\"./data/ml1m_ratings.dat\", sep=\"\\t\", names=[\"user_id\", \"item_id\",\"rating\",\"timestamp\"])\n", + "interactions = interactions.drop(columns=[\"rating\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idtimestamp
100013860408580
1000153604023841
99987360405932
1000007604019613
1000192604020194
............
82579349582399446
82543849581407447
82572449583264448
82573149582634449
82560349581924450
\n", + "

1000209 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " user_id item_id timestamp\n", + "1000138 6040 858 0\n", + "1000153 6040 2384 1\n", + "999873 6040 593 2\n", + "1000007 6040 1961 3\n", + "1000192 6040 2019 4\n", + "... ... ... ...\n", + "825793 4958 2399 446\n", + "825438 4958 1407 447\n", + "825724 4958 3264 448\n", + "825731 4958 2634 449\n", + "825603 4958 1924 450\n", + "\n", + "[1000209 rows x 3 columns]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions[\"timestamp\"] = interactions[\"timestamp\"].astype(\"int64\")\n", + "interactions = interactions.sort_values(by=\"timestamp\")\n", + "interactions[\"timestamp\"] = interactions.groupby(\"user_id\").cumcount()\n", + "interactions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Encode catagorical data.\n", + "To ensure all categorical data is fit for training, it needs to be encoded using the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorcial column in the dataset that will be used in model. Note that ids of users and ids of items are always used." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
timestampuser_iditem_id
01200
16810
26720
31230
414040
............
10002041445553705
10002059028133705
10002067024043705
10002072558353705
10002083809793705
\n", + "

1000209 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " timestamp user_id item_id\n", + "0 12 0 0\n", + "1 68 1 0\n", + "2 67 2 0\n", + "3 12 3 0\n", + "4 140 4 0\n", + "... ... ... ...\n", + "1000204 14 4555 3705\n", + "1000205 90 2813 3705\n", + "1000206 70 2404 3705\n", + "1000207 25 5835 3705\n", + "1000208 380 979 3705\n", + "\n", + "[1000209 rows x 3 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule\n", + "\n", + "encoder = LabelEncoder(\n", + " [\n", + " LabelEncodingRule(\"user_id\", default_value=\"last\"),\n", + " LabelEncodingRule(\"item_id\", default_value=\"last\"),\n", + " ]\n", + ")\n", + "interactions = interactions.sort_values(by=\"item_id\", ascending=True)\n", + "encoded_interactions = encoder.fit_transform(interactions)\n", + "encoded_interactions" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Split interactions into the train, validation and test datasets using LastNSplitter\n", + "We use widespread splitting strategy Last-One-Out. We filter out cold items and users for simplicity." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.splitters import LastNSplitter\n", + "\n", + "splitter = LastNSplitter(\n", + " N=1,\n", + " divide_column=\"user_id\",\n", + " query_column=\"user_id\",\n", + " strategy=\"interactions\",\n", + " drop_cold_users=True,\n", + " drop_cold_items=True\n", + ")\n", + "\n", + "test_events, test_gt = splitter.split(encoded_interactions)\n", + "validation_events, validation_gt = splitter.split(test_events)\n", + "train_events = validation_events" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Dataset preprocessing (\"baking\")\n", + "SasRec expects each user in the batch to provide their events in form of a sequence. 
For this reason, the event splits must be properly processed using the `groupby_sequences` function provided by RePlay." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.data.nn.utils import groupby_sequences\n", + "\n", + "\n", + "def bake_data(full_data):\n", + " grouped_interactions = groupby_sequences(events=full_data, groupby_col=\"user_id\", sort_col=\"timestamp\")\n", + " return grouped_interactions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idtimestampitem_id
00[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2426, 822, 2733, 2587, 2937, 3618, 2943, 708,...
11[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[3272, 3026, 2760, 851, 346, 3393, 1107, 515, ...
22[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[579, 1140, 1154, 2426, 1524, 1260, 2160, 2621...
33[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1781, 2940, 2468, 890, 948, 106, 593, 309, 49...
44[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1108, 2229, 21, 2435, 2142, 106, 1167, 593, 1...
............
60356035[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2426, 1279, 3151, 3321, 1178, 3301, 2501, 278...
60366036[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1592, 2302, 1633, 1813, 2879, 1482, 2651, 250...
60376037[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1971, 3500, 2077, 1666, 1399, 2651, 2748, 283...
60386038[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1486, 1485, 3384, 3512, 3302, 3126, 3650, 330...
60396039[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2432, 2960, 1848, 2114, 2142, 3091, 3248, 317...
\n", + "

6040 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " user_id timestamp \\\n", + "0 0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "1 1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "2 2 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "3 3 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "4 4 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "... ... ... \n", + "6035 6035 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "6036 6036 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "6037 6037 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "6038 6038 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "6039 6039 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "\n", + " item_id \n", + "0 [2426, 822, 2733, 2587, 2937, 3618, 2943, 708,... \n", + "1 [3272, 3026, 2760, 851, 346, 3393, 1107, 515, ... \n", + "2 [579, 1140, 1154, 2426, 1524, 1260, 2160, 2621... \n", + "3 [1781, 2940, 2468, 890, 948, 106, 593, 309, 49... \n", + "4 [1108, 2229, 21, 2435, 2142, 106, 1167, 593, 1... \n", + "... ... \n", + "6035 [2426, 1279, 3151, 3321, 1178, 3301, 2501, 278... \n", + "6036 [1592, 2302, 1633, 1813, 2879, 1482, 2651, 250... \n", + "6037 [1971, 3500, 2077, 1666, 1399, 2651, 2748, 283... \n", + "6038 [1486, 1485, 3384, 3512, 3302, 3126, 3650, 330... \n", + "6039 [2432, 2960, 1848, 2114, 2142, 3091, 3248, 317... 
\n", + "\n", + "[6040 rows x 3 columns]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_events = bake_data(train_events)\n", + "\n", + "validation_events = bake_data(validation_events)\n", + "validation_gt = bake_data(validation_gt)\n", + "\n", + "test_events = bake_data(test_events)\n", + "test_gt = bake_data(test_gt)\n", + "\n", + "train_events" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To ensure we don't have unknown users in ground truth, we join validation events and validation ground truth (also join test events and test ground truth correspondingly) by user ids to leave only the common ones. " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def add_gt_to_events(events_df, gt_df):\n", + " gt_to_join = gt_df[[\"user_id\", \"item_id\"]].rename(columns={\"item_id\": \"ground_truth\"})\n", + "\n", + " events_df = events_df.merge(gt_to_join, on=\"user_id\", how=\"inner\")\n", + " return events_df\n", + "\n", + "validation_events = add_gt_to_events(validation_events, validation_gt)\n", + "test_events = add_gt_to_events(test_events, test_gt)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "data_dir = Path(\"temp/data/\")\n", + "data_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "TRAIN_PATH = data_dir / \"train.parquet\"\n", + "VAL_PATH = data_dir / \"val.parquet\"\n", + "PREDICT_PATH = data_dir / \"test.parquet\"\n", + "\n", + "ENCODER_PATH = data_dir / \"encoder\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "train_events.to_parquet(TRAIN_PATH)\n", + "validation_events.to_parquet(VAL_PATH)\n", + "test_events.to_parquet(PREDICT_PATH)\n", + "\n", + "encoder.save(ENCODER_PATH)" + ] + }, + { + "attachments": {}, + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "# Prepare to model training\n", + "### Create the tensor schema\n", + "A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model. It is required by the SasRec model to correctly create embeddings for every source column. Note that user_id does not required in `TensorSchema`.\n", + "\n", + "Note that the **padding value** is the next value (item_id) after the last one. **Cardinality** is the number of unique values ​​given the padding value." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.data import FeatureHint, FeatureType\n", + "from replay.data.nn import TensorFeatureInfo, TensorSchema\n", + "\n", + "\n", + "EMBEDDING_DIM = 64\n", + "\n", + "encoder = encoder.load(ENCODER_PATH)\n", + "NUM_UNIQUE_ITEMS = len(encoder.mapping[\"item_id\"])\n", + "\n", + "tensor_schema = TensorSchema(\n", + " [\n", + " TensorFeatureInfo(\n", + " name=\"item_id\",\n", + " is_seq=True,\n", + " padding_value=NUM_UNIQUE_ITEMS,\n", + " cardinality=NUM_UNIQUE_ITEMS + 1, # taking into account padding\n", + " embedding_dim=EMBEDDING_DIM,\n", + " feature_type=FeatureType.CATEGORICAL,\n", + " feature_hint=FeatureHint.ITEM_ID,\n", + " )\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configure ParquetModule and transformation pipelines\n", + "\n", + "The `ParquetModule` class enables training of models on large datasets by reading data in batch-wise way. This class initialized with **paths to every data split, a metadata dict containing information about shape and padding value of every column and a dict of transforms**. `ParquetModule`'s \"transform pipelines\" are stage-specific modules implementing additional preprocessing to be performed on batch level right before the forward pass. 
\n", + "\n", + "For SasRec model, RePlay provides a function that generates a sequence of appropriate transforms for each data split named **make_default_sasrec_transforms**.\n", + "\n", + "Internally this function creates the following transforms:\n", + "1) Training:\n", + " 1. Create a target, which contains the shifted item sequence that represents the next item in the sequence (for the next item prediction task).\n", + " 2. Rename features to match it with expected format by the model during training.\n", + " 3. Unsqueeze target (*positive_labels*) and it's padding mask (*target_padding_mask*) for getting required shape of this tensors for loss computation.\n", + " 4. Group input features to be embed in expected format.\n", + "\n", + "2) Validation/Inference:\n", + " 1. Rename/group features to match it with expected format by the model during valdiation/inference.\n", + "\n", + "If a different set of transforms is required, you can create them yourself and submit them to the ParquetModule in the form of a dictionary where the key is the name of the split, and the value is the list of transforms. Available transforms are in the replay/nn/transforms/.\n", + "\n", + "**Note:** One of the transforms for the training data prepares the initial sequence for the task of Next Item Prediction so it shifts the sequence of items. For the final sequence length to be correct, you need to set shape of item_id in metadata as **model sequence length + shift**. Default shift value is 1." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "\n", + "import torch\n", + "\n", + "from replay.data.nn import TensorSchema\n", + "from replay.nn.transform import GroupTransform, NextTokenTransform, RenameTransform, UnsqueezeTransform, UniformNegativeSamplingTransform" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "def make_sasrec_transforms(\n", + " tensor_schema: TensorSchema, query_column: str = \"query_id\", num_negative_samples: int = 128,\n", + ") -> dict[str, list[torch.nn.Module]]:\n", + " item_column = tensor_schema.item_id_feature_name\n", + " vocab_size = tensor_schema[item_column].cardinality\n", + " train_transforms = [\n", + " UniformNegativeSamplingTransform(vocab_size, num_negative_samples),\n", + " NextTokenTransform(label_field=item_column, query_features=query_column, shift=1),\n", + " RenameTransform(\n", + " {\n", + " query_column: \"query_id\",\n", + " f\"{item_column}_mask\": \"padding_mask\",\n", + " \"positive_labels_mask\": \"target_padding_mask\",\n", + " }\n", + " ),\n", + " UnsqueezeTransform(\"target_padding_mask\", -1),\n", + " UnsqueezeTransform(\"positive_labels\", -1),\n", + " GroupTransform({\"feature_tensors\": [item_column]}),\n", + " ]\n", + "\n", + " val_transforms = [\n", + " RenameTransform({query_column: \"query_id\", f\"{item_column}_mask\": \"padding_mask\"}),\n", + " GroupTransform({\"feature_tensors\": [item_column]}),\n", + " ]\n", + " test_transforms = copy.deepcopy(val_transforms)\n", + "\n", + " predict_transforms = copy.deepcopy(val_transforms)\n", + "\n", + " transforms = {\n", + " \"train\": train_transforms,\n", + " \"validate\": val_transforms,\n", + " \"test\": test_transforms,\n", + " \"predict\": predict_transforms,\n", + " }\n", + "\n", + " return transforms\n", + "\n", + "transforms = make_sasrec_transforms(tensor_schema, query_column=\"user_id\")" + ] + }, + 
{ + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "MAX_SEQ_LEN = 50\n", + "\n", + "def create_meta(shape: int, gt_shape: Optional[int] = None):\n", + " meta = {\n", + " \"user_id\": {},\n", + " \"item_id\": {\"shape\": shape, \"padding\": tensor_schema[\"item_id\"].padding_value},\n", + " }\n", + " if gt_shape is not None:\n", + " meta.update({\"ground_truth\": {\"shape\": gt_shape, \"padding\": -1}})\n", + "\n", + " return meta\n", + "\n", + "train_metadata = {\n", + " \"train\": create_meta(shape=MAX_SEQ_LEN+1),\n", + " \"validate\": create_meta(shape=MAX_SEQ_LEN, gt_shape=1),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.data.nn import ParquetModule\n", + "\n", + "BATCH_SIZE = 32\n", + "\n", + "parquet_module = ParquetModule(\n", + " train_path=TRAIN_PATH,\n", + " validate_path=VAL_PATH,\n", + " batch_size=BATCH_SIZE,\n", + " metadata=train_metadata,\n", + " transforms=transforms,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train model\n", + "### Create SasRec model instance and run the training stage using lightning\n", + "We may now train the model using the Lightning trainer class. \n", + "\n", + "RePlay's implementation of SasRec is designed in a modular, **block-based approach**. Instead of passing configuration parameters to the constructor, SasRec is now built by providing fully initialized components that makes the model more flexible and easier to extend.\n", + "\n", + "#### Default Configuration\n", + "\n", + "Default SasRec model may be created quickly via method **from_params**. Default model instance has CE loss, original SasRec transformer layes, and embeddings are aggregated via sum." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.nn.sequential import SasRec\n", + "from typing import Literal\n", + "def make_sasrec(\n", + " schema: TensorSchema,\n", + " embedding_dim: int = 192,\n", + " num_heads: int = 4,\n", + " num_blocks: int = 2,\n", + " max_sequence_length: int = 50,\n", + " dropout: float = 0.3,\n", + " excluded_features: Optional[list[str]] = None,\n", + " categorical_list_feature_aggregation_method: Literal[\"sum\", \"mean\", \"max\"] = \"sum\",\n", + ") -> SasRec:\n", + " from replay.nn.sequential.sasrec import SasRecBody, SasRecTransformerLayer\n", + " from replay.nn.agg import SumAggregator\n", + " from replay.nn.embedding import SequenceEmbedding\n", + " from replay.nn.loss import CE, CESampled\n", + " from replay.nn.mask import DefaultAttentionMask\n", + " from replay.nn.sequential.sasrec.agg import PositionAwareAggregator\n", + " from replay.nn.sequential.sasrec.transformer import SasRecTransformerLayer\n", + " excluded_features = [\n", + " schema.query_id_feature_name,\n", + " schema.timestamp_feature_name,\n", + " *(excluded_features or []),\n", + " ]\n", + " excluded_features = list(set(excluded_features))\n", + " body = SasRecBody(\n", + " embedder=SequenceEmbedding(\n", + " schema=schema,\n", + " categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,\n", + " excluded_features=excluded_features,\n", + " ),\n", + " embedding_aggregator=PositionAwareAggregator(\n", + " embedding_aggregator=SumAggregator(embedding_dim=embedding_dim),\n", + " max_sequence_length=max_sequence_length,\n", + " dropout=dropout,\n", + " ),\n", + " attn_mask_builder=DefaultAttentionMask(\n", + " reference_feature_name=schema.item_id_feature_name,\n", + " num_heads=num_heads,\n", + " ),\n", + " encoder=SasRecTransformerLayer(\n", + " embedding_dim=embedding_dim,\n", + " num_heads=num_heads,\n", + " num_blocks=num_blocks,\n", + " 
dropout=dropout,\n", + " activation=\"relu\",\n", + " ),\n", + " output_normalization=torch.nn.LayerNorm(embedding_dim),\n", + " )\n", + " padding_idx = schema.item_id_features.item().padding_value\n", + " return SasRec(\n", + " body=body,\n", + " loss=CESampled(padding_idx=padding_idx),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "NUM_BLOCKS = 2\n", + "NUM_HEADS = 2\n", + "DROPOUT = 0.3\n", + "\n", + "sasrec = make_sasrec(\n", + " schema=tensor_schema,\n", + " embedding_dim=EMBEDDING_DIM,\n", + " max_sequence_length=MAX_SEQ_LEN,\n", + " num_heads=NUM_HEADS,\n", + " num_blocks=NUM_BLOCKS,\n", + " dropout=DROPOUT,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A universal PyTorch Lightning module is provided. It can work with any NN model." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.nn.lightning.optimizer import OptimizerFactory\n", + "from replay.nn.lightning.scheduler import LRSchedulerFactory\n", + "from replay.nn.lightning import LightningModule\n", + "\n", + "model = LightningModule(\n", + " sasrec,\n", + " optimizer_factory=OptimizerFactory(),\n", + " lr_scheduler_factory=LRSchedulerFactory(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To facilitate training, we add the following callbacks:\n", + "1) `ModelCheckpoint` - to save the best trained model based on its Recall metric. It's a default Lightning Callback.\n", + "1) `ComputeMetricsCallback` - to display a detailed validation metric matrix after each epoch. 
It's a custom RePlay callback for computing recsys metrics on validation and test stages.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params | Mode | FLOPs\n", + "-------------------------------------------------\n", + "0 | model | SasRec | 291 K | train | 0 \n", + "-------------------------------------------------\n", + "291 K Trainable params\n", + "0 Non-trainable params\n", + "291 K Total params\n", + "1.164 Total estimated model params size (MB)\n", + "39 Modules in train mode\n", + "0 Modules in eval mode\n", + "0 Total Flops\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bdbe6004cd2b40f09355b6ed047313af", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? 
[00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscore
002247.242342
005726.818249
004866.81148
0013716.534966
002106.52649
............
60376039249710.457304
60376039350310.305973
60376039260110.280416
60376039275010.01198
6037603914909.916577
\n", + "

120760 rows × 3 columns

\n", + "" + ], + "text/plain": [ + " user_id item_id score\n", + "0 0 224 7.242342\n", + "0 0 572 6.818249\n", + "0 0 486 6.81148\n", + "0 0 1371 6.534966\n", + "0 0 210 6.52649\n", + "... ... ... ...\n", + "6037 6039 2497 10.457304\n", + "6037 6039 3503 10.305973\n", + "6037 6039 2601 10.280416\n", + "6037 6039 2750 10.01198\n", + "6037 6039 1490 9.916577\n", + "\n", + "[120760 rows x 3 columns]" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_res" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Calculating metrics\n", + "\n", + "*test_gt* is already encoded, so we can use it for computing metrics." + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.metrics import MAP, OfflineMetrics, Precision, Recall\n", + "from replay.metrics.torch_metrics_builder import metrics_to_df" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "result_metrics = OfflineMetrics(\n", + " [Recall(TOPK), Precision(TOPK), MAP(TOPK)],\n", + " query_column=\"user_id\",\n", + " rating_column=\"score\",\n", + ")(pandas_res, test_gt.explode(\"item_id\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
k110205
MAP0.0160650.0540390.0617490.043969
Precision0.0160650.0176550.0144580.020073
Recall0.0160650.1765490.2891690.100364
\n", + "
" + ], + "text/plain": [ + "k 1 10 20 5\n", + "MAP 0.016065 0.054039 0.061749 0.043969\n", + "Precision 0.016065 0.017655 0.014458 0.020073\n", + "Recall 0.016065 0.176549 0.289169 0.100364" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metrics_to_df(result_metrics)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's call the `inverse_transform` encoder's function to get the final dataframe with recommendations" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscore
020122317.242342
020125866.818249
020125006.81148
0201214856.534966
020122166.52649
............
60375727270210.457304
60375727374510.305973
60375727280610.280416
60375727296110.01198
6037572716239.916577
\n", + "

120760 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " user_id item_id score\n", + "0 2012 231 7.242342\n", + "0 2012 586 6.818249\n", + "0 2012 500 6.81148\n", + "0 2012 1485 6.534966\n", + "0 2012 216 6.52649\n", + "... ... ... ...\n", + "6037 5727 2702 10.457304\n", + "6037 5727 3745 10.305973\n", + "6037 5727 2806 10.280416\n", + "6037 5727 2961 10.01198\n", + "6037 5727 1623 9.916577\n", + "\n", + "[120760 rows x 3 columns]" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "encoder.inverse_transform(pandas_res)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/09_sasrec_example_thr_sampled.ipynb b/examples/09_sasrec_example_thr_sampled.ipynb new file mode 100644 index 000000000..b27ec3f8d --- /dev/null +++ b/examples/09_sasrec_example_thr_sampled.ipynb @@ -0,0 +1,4768 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Example of SasRec training/inference with Parquet Module" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 42\n" + ] + } + ], + "source": [ + "from typing import Optional\n", + "\n", + "import lightning as L\n", + "import pandas as pd\n", + "\n", + "L.seed_everything(42)\n", + "\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "## Preparing data\n", + "In this example, we will be using the MovieLens dataset, namely the 1m subset. It's demonstrated a simple case, so only item ids will be used as model input.\n", + "\n", + "---\n", + "**NOTE**\n", + "\n", + "Current implementation of SasRec is able to handle item and interactions features. It does not take into account user features. \n", + "\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "interactions = pd.read_csv(\"./data/ml1m_ratings.dat\", sep=\"\\t\", names=[\"user_id\", \"item_id\",\"rating\",\"timestamp\"])\n", + "interactions = interactions.drop(columns=[\"rating\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idtimestamp
100013860408580
1000153604023841
99987360405932
1000007604019613
1000192604020194
............
82579349582399446
82543849581407447
82572449583264448
82573149582634449
82560349581924450
\n", + "

1000209 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " user_id item_id timestamp\n", + "1000138 6040 858 0\n", + "1000153 6040 2384 1\n", + "999873 6040 593 2\n", + "1000007 6040 1961 3\n", + "1000192 6040 2019 4\n", + "... ... ... ...\n", + "825793 4958 2399 446\n", + "825438 4958 1407 447\n", + "825724 4958 3264 448\n", + "825731 4958 2634 449\n", + "825603 4958 1924 450\n", + "\n", + "[1000209 rows x 3 columns]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions[\"timestamp\"] = interactions[\"timestamp\"].astype(\"int64\")\n", + "interactions = interactions.sort_values(by=\"timestamp\")\n", + "interactions[\"timestamp\"] = interactions.groupby(\"user_id\").cumcount()\n", + "interactions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Encode catagorical data.\n", + "To ensure all categorical data is fit for training, it needs to be encoded using the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorcial column in the dataset that will be used in model. Note that ids of users and ids of items are always used." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
timestampuser_iditem_id
01200
16810
26720
31230
414040
............
10002041445553705
10002059028133705
10002067024043705
10002072558353705
10002083809793705
\n", + "

1000209 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " timestamp user_id item_id\n", + "0 12 0 0\n", + "1 68 1 0\n", + "2 67 2 0\n", + "3 12 3 0\n", + "4 140 4 0\n", + "... ... ... ...\n", + "1000204 14 4555 3705\n", + "1000205 90 2813 3705\n", + "1000206 70 2404 3705\n", + "1000207 25 5835 3705\n", + "1000208 380 979 3705\n", + "\n", + "[1000209 rows x 3 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule\n", + "\n", + "encoder = LabelEncoder(\n", + " [\n", + " LabelEncodingRule(\"user_id\", default_value=\"last\"),\n", + " LabelEncodingRule(\"item_id\", default_value=\"last\"),\n", + " ]\n", + ")\n", + "interactions = interactions.sort_values(by=\"item_id\", ascending=True)\n", + "encoded_interactions = encoder.fit_transform(interactions)\n", + "encoded_interactions" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Split interactions into the train, validation and test datasets using LastNSplitter\n", + "We use widespread splitting strategy Last-One-Out. We filter out cold items and users for simplicity." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.splitters import LastNSplitter\n", + "\n", + "splitter = LastNSplitter(\n", + " N=1,\n", + " divide_column=\"user_id\",\n", + " query_column=\"user_id\",\n", + " strategy=\"interactions\",\n", + " drop_cold_users=True,\n", + " drop_cold_items=True\n", + ")\n", + "\n", + "test_events, test_gt = splitter.split(encoded_interactions)\n", + "validation_events, validation_gt = splitter.split(test_events)\n", + "train_events = validation_events" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Dataset preprocessing (\"baking\")\n", + "SasRec expects each user in the batch to provide their events in form of a sequence. 
For this reason, the event splits must be properly processed using the `groupby_sequences` function provided by RePlay." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.data.nn.utils import groupby_sequences\n", + "\n", + "\n", + "def bake_data(full_data):\n", + " grouped_interactions = groupby_sequences(events=full_data, groupby_col=\"user_id\", sort_col=\"timestamp\")\n", + " return grouped_interactions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idtimestampitem_id
00[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2426, 822, 2733, 2587, 2937, 3618, 2943, 708,...
11[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[3272, 3026, 2760, 851, 346, 3393, 1107, 515, ...
22[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[579, 1140, 1154, 2426, 1524, 1260, 2160, 2621...
33[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1781, 2940, 2468, 890, 948, 106, 593, 309, 49...
44[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1108, 2229, 21, 2435, 2142, 106, 1167, 593, 1...
............
60356035[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2426, 1279, 3151, 3321, 1178, 3301, 2501, 278...
60366036[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1592, 2302, 1633, 1813, 2879, 1482, 2651, 250...
60376037[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1971, 3500, 2077, 1666, 1399, 2651, 2748, 283...
60386038[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1486, 1485, 3384, 3512, 3302, 3126, 3650, 330...
60396039[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2432, 2960, 1848, 2114, 2142, 3091, 3248, 317...
\n", + "

6040 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " user_id timestamp \\\n", + "0 0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "1 1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "2 2 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "3 3 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "4 4 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "... ... ... \n", + "6035 6035 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "6036 6036 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "6037 6037 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "6038 6038 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "6039 6039 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", + "\n", + " item_id \n", + "0 [2426, 822, 2733, 2587, 2937, 3618, 2943, 708,... \n", + "1 [3272, 3026, 2760, 851, 346, 3393, 1107, 515, ... \n", + "2 [579, 1140, 1154, 2426, 1524, 1260, 2160, 2621... \n", + "3 [1781, 2940, 2468, 890, 948, 106, 593, 309, 49... \n", + "4 [1108, 2229, 21, 2435, 2142, 106, 1167, 593, 1... \n", + "... ... \n", + "6035 [2426, 1279, 3151, 3321, 1178, 3301, 2501, 278... \n", + "6036 [1592, 2302, 1633, 1813, 2879, 1482, 2651, 250... \n", + "6037 [1971, 3500, 2077, 1666, 1399, 2651, 2748, 283... \n", + "6038 [1486, 1485, 3384, 3512, 3302, 3126, 3650, 330... \n", + "6039 [2432, 2960, 1848, 2114, 2142, 3091, 3248, 317... 
\n", + "\n", + "[6040 rows x 3 columns]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_events = bake_data(train_events)\n", + "\n", + "validation_events = bake_data(validation_events)\n", + "validation_gt = bake_data(validation_gt)\n", + "\n", + "test_events = bake_data(test_events)\n", + "test_gt = bake_data(test_gt)\n", + "\n", + "train_events" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To ensure we don't have unknown users in ground truth, we join validation events and validation ground truth (also join test events and test ground truth correspondingly) by user ids to leave only the common ones. " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def add_gt_to_events(events_df, gt_df):\n", + " gt_to_join = gt_df[[\"user_id\", \"item_id\"]].rename(columns={\"item_id\": \"ground_truth\"})\n", + "\n", + " events_df = events_df.merge(gt_to_join, on=\"user_id\", how=\"inner\")\n", + " return events_df\n", + "\n", + "validation_events = add_gt_to_events(validation_events, validation_gt)\n", + "test_events = add_gt_to_events(test_events, test_gt)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "data_dir = Path(\"temp/data/\")\n", + "data_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "TRAIN_PATH = data_dir / \"train.parquet\"\n", + "VAL_PATH = data_dir / \"val.parquet\"\n", + "PREDICT_PATH = data_dir / \"test.parquet\"\n", + "\n", + "ENCODER_PATH = data_dir / \"encoder\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "train_events.to_parquet(TRAIN_PATH)\n", + "validation_events.to_parquet(VAL_PATH)\n", + "test_events.to_parquet(PREDICT_PATH)\n", + "\n", + "encoder.save(ENCODER_PATH)" + ] + }, + { + "attachments": {}, + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "# Preparing for model training\n", + "### Create the tensor schema\n", + "A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model. It is required by the SasRec model to correctly create embeddings for every source column. Note that user_id is not required in `TensorSchema`.\n", + "\n", + "Note that the **padding value** is the next value (item_id) after the last one. **Cardinality** is the number of unique values given the padding value." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.data import FeatureHint, FeatureType\n", + "from replay.data.nn import TensorFeatureInfo, TensorSchema\n", + "\n", + "\n", + "EMBEDDING_DIM = 64\n", + "\n", + "encoder = encoder.load(ENCODER_PATH)\n", + "NUM_UNIQUE_ITEMS = len(encoder.mapping[\"item_id\"])\n", + "\n", + "tensor_schema = TensorSchema(\n", + " [\n", + " TensorFeatureInfo(\n", + " name=\"item_id\",\n", + " is_seq=True,\n", + " padding_value=NUM_UNIQUE_ITEMS,\n", + " cardinality=NUM_UNIQUE_ITEMS + 1, # taking into account padding\n", + " embedding_dim=EMBEDDING_DIM,\n", + " feature_type=FeatureType.CATEGORICAL,\n", + " feature_hint=FeatureHint.ITEM_ID,\n", + " )\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configure ParquetModule and transformation pipelines\n", + "\n", + "The `ParquetModule` class enables training of models on large datasets by reading data in a batch-wise way. This class is initialized with **paths to every data split, a metadata dict containing information about shape and padding value of every column and a dict of transforms**. `ParquetModule`'s \"transform pipelines\" are stage-specific modules implementing additional preprocessing to be performed on batch level right before the forward pass. 
\n", + "\n", + "For the SasRec model, RePlay provides a function that generates a sequence of appropriate transforms for each data split named **make_default_sasrec_transforms**.\n", + "\n", + "Internally this function creates the following transforms:\n", + "1) Training:\n", + " 1. Create a target, which contains the shifted item sequence that represents the next item in the sequence (for the next item prediction task).\n", + " 2. Rename features to match the format expected by the model during training.\n", + " 3. Unsqueeze target (*positive_labels*) and its padding mask (*target_padding_mask*) for getting the required shape of these tensors for loss computation.\n", + " 4. Group input features to be embedded in the expected format.\n", + "\n", + "2) Validation/Inference:\n", + " 1. Rename/group features to match the format expected by the model during validation/inference.\n", + "\n", + "If a different set of transforms is required, you can create them yourself and submit them to the ParquetModule in the form of a dictionary where the key is the name of the split, and the value is the list of transforms. Available transforms are in the replay/nn/transforms/.\n", + "\n", + "**Note:** One of the transforms for the training data prepares the initial sequence for the task of Next Item Prediction so it shifts the sequence of items. For the final sequence length to be correct, you need to set shape of item_id in metadata as **model sequence length + shift**. Default shift value is 1."
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "\n", + "import torch\n", + "\n", + "from replay.data.nn import TensorSchema\n", + "from replay.nn.transform import GroupTransform, NextTokenTransform, RenameTransform, UnsqueezeTransform, ThresholdNegativeSamplingTransform" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def make_sasrec_transforms(\n", + " tensor_schema: TensorSchema, query_column: str = \"query_id\", num_negative_samples: int = 128,\n", + ") -> dict[str, list[torch.nn.Module]]:\n", + " item_column = tensor_schema.item_id_feature_name\n", + " vocab_size = tensor_schema[item_column].cardinality\n", + " train_transforms = [\n", + " ThresholdNegativeSamplingTransform(vocab_size, num_negative_samples),\n", + " NextTokenTransform(label_field=item_column, query_features=query_column, shift=1),\n", + " RenameTransform(\n", + " {\n", + " query_column: \"query_id\",\n", + " f\"{item_column}_mask\": \"padding_mask\",\n", + " \"positive_labels_mask\": \"target_padding_mask\",\n", + " }\n", + " ),\n", + " UnsqueezeTransform(\"target_padding_mask\", -1),\n", + " UnsqueezeTransform(\"positive_labels\", -1),\n", + " GroupTransform({\"feature_tensors\": [item_column]}),\n", + " ]\n", + "\n", + " val_transforms = [\n", + " RenameTransform({query_column: \"query_id\", f\"{item_column}_mask\": \"padding_mask\"}),\n", + " GroupTransform({\"feature_tensors\": [item_column]}),\n", + " ]\n", + " test_transforms = copy.deepcopy(val_transforms)\n", + "\n", + " predict_transforms = copy.deepcopy(val_transforms)\n", + "\n", + " transforms = {\n", + " \"train\": train_transforms,\n", + " \"validate\": val_transforms,\n", + " \"test\": test_transforms,\n", + " \"predict\": predict_transforms,\n", + " }\n", + "\n", + " return transforms\n", + "\n", + "transforms = make_sasrec_transforms(tensor_schema, query_column=\"user_id\")" + ] + 
}, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "MAX_SEQ_LEN = 50\n", + "\n", + "def create_meta(shape: int, gt_shape: Optional[int] = None):\n", + " meta = {\n", + " \"user_id\": {},\n", + " \"item_id\": {\"shape\": shape, \"padding\": tensor_schema[\"item_id\"].padding_value},\n", + " }\n", + " if gt_shape is not None:\n", + " meta.update({\"ground_truth\": {\"shape\": gt_shape, \"padding\": -1}})\n", + "\n", + " return meta\n", + "\n", + "train_metadata = {\n", + " \"train\": create_meta(shape=MAX_SEQ_LEN+1),\n", + " \"validate\": create_meta(shape=MAX_SEQ_LEN, gt_shape=1),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.data.nn import ParquetModule\n", + "\n", + "BATCH_SIZE = 32\n", + "\n", + "parquet_module = ParquetModule(\n", + " train_path=TRAIN_PATH,\n", + " validate_path=VAL_PATH,\n", + " batch_size=BATCH_SIZE,\n", + " metadata=train_metadata,\n", + " transforms=transforms,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train model\n", + "### Create SasRec model instance and run the training stage using lightning\n", + "We may now train the model using the Lightning trainer class. \n", + "\n", + "RePlay's implementation of SasRec is designed in a modular, **block-based approach**. Instead of passing configuration parameters to the constructor, SasRec is now built by providing fully initialized components that make the model more flexible and easier to extend.\n", + "\n", + "#### Default Configuration\n", + "\n", + "A default SasRec model may be created quickly via the method **from_params**. The default model instance has CE loss, original SasRec transformer layers, and embeddings are aggregated via sum."
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.nn.sequential import SasRec\n", + "from typing import Literal\n", + "def make_sasrec(\n", + " schema: TensorSchema,\n", + " embedding_dim: int = 192,\n", + " num_heads: int = 4,\n", + " num_blocks: int = 2,\n", + " max_sequence_length: int = 50,\n", + " dropout: float = 0.3,\n", + " excluded_features: Optional[list[str]] = None,\n", + " categorical_list_feature_aggregation_method: Literal[\"sum\", \"mean\", \"max\"] = \"sum\",\n", + ") -> SasRec:\n", + " from replay.nn.sequential.sasrec import SasRecBody, SasRecTransformerLayer\n", + " from replay.nn.agg import SumAggregator\n", + " from replay.nn.embedding import SequenceEmbedding\n", + " from replay.nn.loss import CE, CESampled\n", + " from replay.nn.mask import DefaultAttentionMask\n", + " from replay.nn.sequential.sasrec.agg import PositionAwareAggregator\n", + " from replay.nn.sequential.sasrec.transformer import SasRecTransformerLayer\n", + " excluded_features = [\n", + " schema.query_id_feature_name,\n", + " schema.timestamp_feature_name,\n", + " *(excluded_features or []),\n", + " ]\n", + " excluded_features = list(set(excluded_features))\n", + " body = SasRecBody(\n", + " embedder=SequenceEmbedding(\n", + " schema=schema,\n", + " categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,\n", + " excluded_features=excluded_features,\n", + " ),\n", + " embedding_aggregator=PositionAwareAggregator(\n", + " embedding_aggregator=SumAggregator(embedding_dim=embedding_dim),\n", + " max_sequence_length=max_sequence_length,\n", + " dropout=dropout,\n", + " ),\n", + " attn_mask_builder=DefaultAttentionMask(\n", + " reference_feature_name=schema.item_id_feature_name,\n", + " num_heads=num_heads,\n", + " ),\n", + " encoder=SasRecTransformerLayer(\n", + " embedding_dim=embedding_dim,\n", + " num_heads=num_heads,\n", + " num_blocks=num_blocks,\n", + " 
dropout=dropout,\n", + " activation=\"relu\",\n", + " ),\n", + " output_normalization=torch.nn.LayerNorm(embedding_dim),\n", + " )\n", + " padding_idx = schema.item_id_features.item().padding_value\n", + " return SasRec(\n", + " body=body,\n", + " loss=CESampled(padding_idx=padding_idx),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "NUM_BLOCKS = 2\n", + "NUM_HEADS = 2\n", + "DROPOUT = 0.3\n", + "\n", + "sasrec = make_sasrec(\n", + " schema=tensor_schema,\n", + " embedding_dim=EMBEDDING_DIM,\n", + " max_sequence_length=MAX_SEQ_LEN,\n", + " num_heads=NUM_HEADS,\n", + " num_blocks=NUM_BLOCKS,\n", + " dropout=DROPOUT,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A universal PyTorch Lightning module is provided. It can work with any NN model." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.nn.lightning.optimizer import OptimizerFactory\n", + "from replay.nn.lightning.scheduler import LRSchedulerFactory\n", + "from replay.nn.lightning import LightningModule\n", + "\n", + "model = LightningModule(\n", + " sasrec,\n", + " optimizer_factory=OptimizerFactory(),\n", + " lr_scheduler_factory=LRSchedulerFactory(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To facilitate training, we add the following callbacks:\n", + "1) `ModelCheckpoint` - to save the best trained model based on its Recall metric. It's a default Lightning Callback.\n", + "1) `ComputeMetricsCallback` - to display a detailed validation metric matrix after each epoch. 
It's a custom RePlay callback for computing recsys metrics on validation and test stages.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params | Mode | FLOPs\n", + "-------------------------------------------------\n", + "0 | model | SasRec | 291 K | train | 0 \n", + "-------------------------------------------------\n", + "291 K Trainable params\n", + "0 Non-trainable params\n", + "291 K Total params\n", + "1.164 Total estimated model params size (MB)\n", + "39 Modules in train mode\n", + "0 Modules in eval mode\n", + "0 Total Flops\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "21ca5ea332d9418fb2481e0296b399a0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.figure(figsize = (5, 4), dpi = 120)\n", + "plt.hist(transforms[\"train\"][0].frequencies.cpu().numpy(), bins = 50)\n", + "plt.grid()\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can get the best model path stored in the checkpoint callback." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=99-step=18900.ckpt'" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "best_model_path = checkpoint_callback.best_model_path\n", + "best_model_path" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Inference\n", + "\n", + "To obtain model scores, we will load the weights from the best checkpoint. To do this, we use the `LightningModule`, provide there the path to the checkpoint and the model instance." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "import replay\n", + "torch.serialization.add_safe_globals([\n", + " replay.nn.lightning.optimizer.OptimizerFactory,\n", + " replay.nn.lightning.scheduler.LRSchedulerFactory\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "sasrec = make_sasrec(\n", + " schema=tensor_schema,\n", + " embedding_dim=EMBEDDING_DIM,\n", + " max_sequence_length=MAX_SEQ_LEN,\n", + " num_heads=NUM_HEADS,\n", + " num_blocks=NUM_BLOCKS,\n", + " dropout=DROPOUT,\n", + ")\n", + "\n", + "best_model = LightningModule.load_from_checkpoint(best_model_path, model=sasrec)\n", + "best_model.eval();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Configure `ParquetModule` for inference" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "inference_metadata = {\"predict\": create_meta(shape=MAX_SEQ_LEN)}\n", + "\n", + "parquet_module = ParquetModule(\n", + " predict_path=PREDICT_PATH,\n", + " batch_size=BATCH_SIZE,\n", + " metadata=inference_metadata,\n", + " transforms=transforms,\n", + ")" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "During inference, we can use `TopItemsCallback`. Such callback allows you to get scores for each user throughout the entire catalog and get recommendations in the form of ids of items with the highest score values.\n", + "\n", + "\n", + "Recommendations can be fetched in four formats: PySpark DataFrame, Pandas DataFrame, Polars DataFrame or raw PyTorch tensors. Each of the types corresponds a callback. In this example, we'll be using the `PandasTopItemsCallback`." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "94a84cd4fca248e7b340909e2048450f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Predicting: | | 0/? [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscore
004867.196251
001017.086551
002106.872698
0016136.834656
003576.555361
............
603760398411.528885
60376039270011.400557
60376039255711.273298
60376039263411.121119
60376039263311.119842
\n", + "

120760 rows × 3 columns

\n", + "" + ], + "text/plain": [ + " user_id item_id score\n", + "0 0 486 7.196251\n", + "0 0 101 7.086551\n", + "0 0 210 6.872698\n", + "0 0 1613 6.834656\n", + "0 0 357 6.555361\n", + "... ... ... ...\n", + "6037 6039 84 11.528885\n", + "6037 6039 2700 11.400557\n", + "6037 6039 2557 11.273298\n", + "6037 6039 2634 11.121119\n", + "6037 6039 2633 11.119842\n", + "\n", + "[120760 rows x 3 columns]" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_res" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Calculating metrics\n", + "\n", + "*test_gt* is already encoded, so we can use it for computing metrics." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "from replay.metrics import MAP, OfflineMetrics, Precision, Recall\n", + "from replay.metrics.torch_metrics_builder import metrics_to_df" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "result_metrics = OfflineMetrics(\n", + " [Recall(TOPK), Precision(TOPK), MAP(TOPK)],\n", + " query_column=\"user_id\",\n", + " rating_column=\"score\",\n", + ")(pandas_res, test_gt.explode(\"item_id\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
k110205
MAP0.0160650.0566370.0640670.045948
Precision0.0160650.0186650.0147980.021133
Recall0.0160650.1866510.2959590.105664
\n", + "
" + ], + "text/plain": [ + "k 1 10 20 5\n", + "MAP 0.016065 0.056637 0.064067 0.045948\n", + "Precision 0.016065 0.018665 0.014798 0.021133\n", + "Recall 0.016065 0.186651 0.295959 0.105664" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metrics_to_df(result_metrics)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's call the `inverse_transform` encoder's function to get the final dataframe with recommendations" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscore
020125007.196251
020121047.086551
020122166.872698
0201217776.834656
020123676.555361
............
603757278611.528885
60375727290711.400557
60375727276211.273298
60375727284111.121119
60375727284011.119842
\n", + "

120760 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " user_id item_id score\n", + "0 2012 500 7.196251\n", + "0 2012 104 7.086551\n", + "0 2012 216 6.872698\n", + "0 2012 1777 6.834656\n", + "0 2012 367 6.555361\n", + "... ... ... ...\n", + "6037 5727 86 11.528885\n", + "6037 5727 2907 11.400557\n", + "6037 5727 2762 11.273298\n", + "6037 5727 2841 11.121119\n", + "6037 5727 2840 11.119842\n", + "\n", + "[120760 rows x 3 columns]" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "encoder.inverse_transform(pandas_res)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "new_venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/replay/nn/transform/__init__.py b/replay/nn/transform/__init__.py index 355cf8351..40c2cdec1 100644 --- a/replay/nn/transform/__init__.py +++ b/replay/nn/transform/__init__.py @@ -1,6 +1,14 @@ from .copy import CopyTransform from .grouping import GroupTransform +<<<<<<< HEAD from .negative_sampling import MultiClassNegativeSamplingTransform, UniformNegativeSamplingTransform +======= +from .negative_sampling import ( + UniformNegativeSamplingTransform, + FrequencyNegativeSamplingTransform, + ThresholdNegativeSamplingTransform, +) +>>>>>>> 36b25601 (Gitignore updated to recent examples) from .next_token import NextTokenTransform from .rename import RenameTransform from .reshape import UnsqueezeTransform @@ -18,5 +26,7 @@ "TokenMaskTransform", "TrimTransform", "UniformNegativeSamplingTransform", + "FrequencyNegativeSamplingTransform", + 
"ThresholdNegativeSamplingTransform", "UnsqueezeTransform", ] diff --git a/replay/nn/transform/negative_sampling.py b/replay/nn/transform/negative_sampling.py index 4d1364eda..34debab8e 100644 --- a/replay/nn/transform/negative_sampling.py +++ b/replay/nn/transform/negative_sampling.py @@ -1,6 +1,9 @@ -from typing import Optional +from typing import Optional, Literal, cast import torch +import torch.nn.functional as func + +import warnings class UniformNegativeSamplingTransform(torch.nn.Module): @@ -29,7 +32,7 @@ def __init__( cardinality: int, num_negative_samples: int, *, - out_feature_name: Optional[str] = "negative_labels", + out_feature_name: str = "negative_labels", sample_distribution: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, ) -> None: @@ -43,12 +46,20 @@ def __init__( :param generator: Random number generator to be used for sampling from the distribution. Default: ``None``. """ - if sample_distribution is not None and sample_distribution.size(-1) != cardinality: - msg = ( - "The sample_distribution parameter has an incorrect size. " - f"Got {sample_distribution.size(-1)}, expected {cardinality}." - ) - raise ValueError(msg) + if sample_distribution is not None: + if sample_distribution.ndim != 1: + msg: str = ( + "The `sample_distribution` parameter must be 1D." + f"Got {sample_distribution.ndim}, will be flattened." + ) + warnings.warn(msg) + sample_distribution = sample_distribution.flatten() + if sample_distribution.size(-1) != cardinality: + msg: str = ( + "The sample_distribution parameter has an incorrect size. " + f"Got {sample_distribution.size(-1)}, expected {cardinality}." 
class FrequencyNegativeSamplingTransform(torch.nn.Module):
    """
    Transform for global, frequency-aware negative sampling.

    For every batch, the transform draws a vector of ``num_negative_samples``
    distinct indices from ``range(cardinality)``. Each index is weighted by
    ``1 / (1 + times_already_drawn)``, so items that have been drawn often
    become less likely over time. The per-item draw counters are kept in the
    persistent buffer ``frequencies`` and are incremented on every forward
    call, which makes the module stateful across batches.
    """

    def __init__(
        self,
        cardinality: int,
        num_negative_samples: int,
        *,
        out_feature_name: str = "negative_labels",
        generator: Optional[torch.Generator] = None,
        mode: Literal["softmax", "softsum"] = "softmax",
    ) -> None:
        """
        :param cardinality: The size of the sample vocabulary.
        :param num_negative_samples: The size of the negatives vector to generate.
            Must be strictly less than ``cardinality`` because sampling is done
            without replacement.
        :param out_feature_name: The name of the result feature in the batch.
        :param generator: Random number generator to be used for sampling
            from the distribution. Default: ``None``.
        :param mode: How the inverse-frequency weights are normalized into a
            probability distribution: ``"softsum"`` divides by their sum,
            ``"softmax"`` applies a softmax. Default: ``"softmax"``.
        :raises ValueError: If ``num_negative_samples >= cardinality`` or
            ``mode`` is not one of ``"softmax"`` / ``"softsum"``.
        """
        # Raise instead of assert: asserts are stripped under ``python -O``.
        if num_negative_samples >= cardinality:
            msg = (
                "Sampling without replacement requires num_negative_samples < cardinality. "
                f"Got {num_negative_samples} >= {cardinality}."
            )
            raise ValueError(msg)
        # Validate eagerly so a bad mode fails at construction, not mid-training.
        if mode not in ("softmax", "softsum"):
            msg = f"Unsupported mode: {mode}."
            raise ValueError(msg)

        super().__init__()

        self.cardinality = cardinality
        self.out_feature_name = out_feature_name
        self.num_negative_samples = num_negative_samples
        self.generator = generator
        self.mode = mode

        # register_buffer keeps the counters on the module's device and in its
        # state_dict; unlike torch.nn.Buffer it works on all torch versions.
        self.register_buffer("frequencies", torch.zeros(cardinality, dtype=torch.int64))

    def get_probas(self) -> torch.Tensor:
        """Return the current sampling distribution over the vocabulary."""
        inverse = 1.0 / (1.0 + self.frequencies)
        if self.mode == "softsum":
            return inverse / inverse.sum()
        if self.mode == "softmax":
            return func.softmax(inverse, dim=-1)
        msg = f"Unsupported mode: {self.mode}."
        raise ValueError(msg)

    def update_probas(self, selected: torch.Tensor) -> None:
        """Increment the draw counters for the just-sampled indices."""
        increments = torch.ones_like(selected, dtype=torch.int64)
        self.frequencies.index_add_(-1, selected, increments)

    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Return a shallow copy of ``batch`` with sampled negatives appended.

        :param batch: Mapping of feature name to tensor; must be non-empty
            (the device of the negatives is taken from the first value).
        """
        output_batch = dict(batch)

        negatives = torch.multinomial(
            input=self.get_probas(),
            num_samples=self.num_negative_samples,
            replacement=False,
            generator=self.generator,
        )
        self.update_probas(negatives)

        device = next(iter(output_batch.values())).device
        output_batch[self.out_feature_name] = negatives.to(device)
        return output_batch


class ThresholdNegativeSamplingTransform(torch.nn.Module):
    """
    Transform for global negative sampling that excludes the most frequently
    drawn items.

    Like :class:`FrequencyNegativeSamplingTransform`, indices are weighted by
    ``1 / (1 + times_already_drawn)``, but additionally every item whose draw
    count equals the current maximum is masked out of the distribution. When
    all counts are equal (e.g. before the first draw), the mask removes every
    item and the distribution degenerates to uniform.
    """

    def __init__(
        self,
        cardinality: int,
        num_negative_samples: int,
        *,
        out_feature_name: str = "negative_labels",
        generator: Optional[torch.Generator] = None,
        mode: Literal["softmax", "softsum"] = "softmax",
    ) -> None:
        """
        :param cardinality: The size of the sample vocabulary.
        :param num_negative_samples: The size of the negatives vector to generate.
            Must be strictly less than ``cardinality`` because sampling is done
            without replacement.
        :param out_feature_name: The name of the result feature in the batch.
        :param generator: Random number generator to be used for sampling
            from the distribution. Default: ``None``.
        :param mode: How the masked inverse-frequency weights are normalized:
            ``"softsum"`` divides by their sum, ``"softmax"`` applies a softmax.
            Default: ``"softmax"``.
        :raises ValueError: If ``num_negative_samples >= cardinality`` or
            ``mode`` is not one of ``"softmax"`` / ``"softsum"``.
        """
        # Raise instead of assert: asserts are stripped under ``python -O``.
        if num_negative_samples >= cardinality:
            msg = (
                "Sampling without replacement requires num_negative_samples < cardinality. "
                f"Got {num_negative_samples} >= {cardinality}."
            )
            raise ValueError(msg)
        if mode not in ("softmax", "softsum"):
            msg = f"Unsupported mode: {mode}."
            raise ValueError(msg)

        super().__init__()

        self.cardinality = cardinality
        self.out_feature_name = out_feature_name
        self.num_negative_samples = num_negative_samples
        self.generator = generator
        self.mode = mode

        # register_buffer keeps the counters on the module's device and in its
        # state_dict; unlike torch.nn.Buffer it works on all torch versions.
        self.register_buffer("frequencies", torch.zeros(cardinality, dtype=torch.int64))

    def get_probas(self) -> torch.Tensor:
        """Return the sampling distribution with the most-drawn items masked out."""
        inverse = 1.0 / (1.0 + self.frequencies)
        # Items whose draw count equals the current maximum are suppressed.
        keep = self.frequencies != self.frequencies.max()
        if self.mode == "softsum":
            # eps (not 0) keeps the distribution valid when everything is masked.
            floor = torch.finfo(inverse.dtype).eps
            masked = torch.where(keep, inverse, floor)
            return masked / masked.sum()
        if self.mode == "softmax":
            # finfo.min underflows to probability 0 inside softmax.
            neg_inf = torch.finfo(inverse.dtype).min
            masked = torch.where(keep, inverse, neg_inf)
            return func.softmax(masked, dim=-1)
        msg = f"Unsupported mode: {self.mode}."
        raise ValueError(msg)

    def update_probas(self, selected: torch.Tensor) -> None:
        """Increment the draw counters for the just-sampled indices."""
        increments = torch.ones_like(selected, dtype=torch.int64)
        self.frequencies.index_add_(-1, selected, increments)

    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Return a shallow copy of ``batch`` with sampled negatives appended.

        :param batch: Mapping of feature name to tensor; must be non-empty
            (the device of the negatives is taken from the first value).
        """
        output_batch = dict(batch)

        negatives = torch.multinomial(
            input=self.get_probas(),
            num_samples=self.num_negative_samples,
            replacement=False,
            generator=self.generator,
        )
        self.update_probas(negatives)

        device = next(iter(output_batch.values())).device
        output_batch[self.out_feature_name] = negatives.to(device)
        return output_batch
a/examples/09_sasrec_example_freq_sampled.ipynb b/examples/09_sasrec_example_freq_sampled.ipynb deleted file mode 100644 index 4988e2512..000000000 --- a/examples/09_sasrec_example_freq_sampled.ipynb +++ /dev/null @@ -1,4768 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Example of SasRec training/inference with Parquet Module" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 42\n" - ] - } - ], - "source": [ - "from typing import Optional\n", - "\n", - "import lightning as L\n", - "import pandas as pd\n", - "\n", - "L.seed_everything(42)\n", - "\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preparing data\n", - "In this example, we will be using the MovieLens dataset, namely the 1m subset. It's demonstrated a simple case, so only item ids will be used as model input.\n", - "\n", - "---\n", - "**NOTE**\n", - "\n", - "Current implementation of SasRec is able to handle item and interactions features. It does not take into account user features. \n", - "\n", - "---" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "interactions = pd.read_csv(\"./data/ml1m_ratings.dat\", sep=\"\\t\", names=[\"user_id\", \"item_id\",\"rating\",\"timestamp\"])\n", - "interactions = interactions.drop(columns=[\"rating\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idtimestamp
100013860408580
1000153604023841
99987360405932
1000007604019613
1000192604020194
............
82579349582399446
82543849581407447
82572449583264448
82573149582634449
82560349581924450
\n", - "

1000209 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " user_id item_id timestamp\n", - "1000138 6040 858 0\n", - "1000153 6040 2384 1\n", - "999873 6040 593 2\n", - "1000007 6040 1961 3\n", - "1000192 6040 2019 4\n", - "... ... ... ...\n", - "825793 4958 2399 446\n", - "825438 4958 1407 447\n", - "825724 4958 3264 448\n", - "825731 4958 2634 449\n", - "825603 4958 1924 450\n", - "\n", - "[1000209 rows x 3 columns]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "interactions[\"timestamp\"] = interactions[\"timestamp\"].astype(\"int64\")\n", - "interactions = interactions.sort_values(by=\"timestamp\")\n", - "interactions[\"timestamp\"] = interactions.groupby(\"user_id\").cumcount()\n", - "interactions" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Encode catagorical data.\n", - "To ensure all categorical data is fit for training, it needs to be encoded using the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorcial column in the dataset that will be used in model. Note that ids of users and ids of items are always used." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
timestampuser_iditem_id
01200
16810
26720
31230
414040
............
10002041445553705
10002059028133705
10002067024043705
10002072558353705
10002083809793705
\n", - "

1000209 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " timestamp user_id item_id\n", - "0 12 0 0\n", - "1 68 1 0\n", - "2 67 2 0\n", - "3 12 3 0\n", - "4 140 4 0\n", - "... ... ... ...\n", - "1000204 14 4555 3705\n", - "1000205 90 2813 3705\n", - "1000206 70 2404 3705\n", - "1000207 25 5835 3705\n", - "1000208 380 979 3705\n", - "\n", - "[1000209 rows x 3 columns]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule\n", - "\n", - "encoder = LabelEncoder(\n", - " [\n", - " LabelEncodingRule(\"user_id\", default_value=\"last\"),\n", - " LabelEncodingRule(\"item_id\", default_value=\"last\"),\n", - " ]\n", - ")\n", - "interactions = interactions.sort_values(by=\"item_id\", ascending=True)\n", - "encoded_interactions = encoder.fit_transform(interactions)\n", - "encoded_interactions" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Split interactions into the train, validation and test datasets using LastNSplitter\n", - "We use widespread splitting strategy Last-One-Out. We filter out cold items and users for simplicity." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.splitters import LastNSplitter\n", - "\n", - "splitter = LastNSplitter(\n", - " N=1,\n", - " divide_column=\"user_id\",\n", - " query_column=\"user_id\",\n", - " strategy=\"interactions\",\n", - " drop_cold_users=True,\n", - " drop_cold_items=True\n", - ")\n", - "\n", - "test_events, test_gt = splitter.split(encoded_interactions)\n", - "validation_events, validation_gt = splitter.split(test_events)\n", - "train_events = validation_events" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Dataset preprocessing (\"baking\")\n", - "SasRec expects each user in the batch to provide their events in form of a sequence. 
For this reason, the event splits must be properly processed using the `groupby_sequences` function provided by RePlay." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.data.nn.utils import groupby_sequences\n", - "\n", - "\n", - "def bake_data(full_data):\n", - " grouped_interactions = groupby_sequences(events=full_data, groupby_col=\"user_id\", sort_col=\"timestamp\")\n", - " return grouped_interactions" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_idtimestampitem_id
00[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2426, 822, 2733, 2587, 2937, 3618, 2943, 708,...
11[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[3272, 3026, 2760, 851, 346, 3393, 1107, 515, ...
22[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[579, 1140, 1154, 2426, 1524, 1260, 2160, 2621...
33[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1781, 2940, 2468, 890, 948, 106, 593, 309, 49...
44[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1108, 2229, 21, 2435, 2142, 106, 1167, 593, 1...
............
60356035[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2426, 1279, 3151, 3321, 1178, 3301, 2501, 278...
60366036[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1592, 2302, 1633, 1813, 2879, 1482, 2651, 250...
60376037[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1971, 3500, 2077, 1666, 1399, 2651, 2748, 283...
60386038[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1486, 1485, 3384, 3512, 3302, 3126, 3650, 330...
60396039[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2432, 2960, 1848, 2114, 2142, 3091, 3248, 317...
\n", - "

6040 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " user_id timestamp \\\n", - "0 0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "1 1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "2 2 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "3 3 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "4 4 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "... ... ... \n", - "6035 6035 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "6036 6036 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "6037 6037 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "6038 6038 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "6039 6039 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "\n", - " item_id \n", - "0 [2426, 822, 2733, 2587, 2937, 3618, 2943, 708,... \n", - "1 [3272, 3026, 2760, 851, 346, 3393, 1107, 515, ... \n", - "2 [579, 1140, 1154, 2426, 1524, 1260, 2160, 2621... \n", - "3 [1781, 2940, 2468, 890, 948, 106, 593, 309, 49... \n", - "4 [1108, 2229, 21, 2435, 2142, 106, 1167, 593, 1... \n", - "... ... \n", - "6035 [2426, 1279, 3151, 3321, 1178, 3301, 2501, 278... \n", - "6036 [1592, 2302, 1633, 1813, 2879, 1482, 2651, 250... \n", - "6037 [1971, 3500, 2077, 1666, 1399, 2651, 2748, 283... \n", - "6038 [1486, 1485, 3384, 3512, 3302, 3126, 3650, 330... \n", - "6039 [2432, 2960, 1848, 2114, 2142, 3091, 3248, 317... 
\n", - "\n", - "[6040 rows x 3 columns]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_events = bake_data(train_events)\n", - "\n", - "validation_events = bake_data(validation_events)\n", - "validation_gt = bake_data(validation_gt)\n", - "\n", - "test_events = bake_data(test_events)\n", - "test_gt = bake_data(test_gt)\n", - "\n", - "train_events" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To ensure we don't have unknown users in ground truth, we join validation events and validation ground truth (also join test events and test ground truth correspondingly) by user ids to leave only the common ones. " - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "def add_gt_to_events(events_df, gt_df):\n", - " gt_to_join = gt_df[[\"user_id\", \"item_id\"]].rename(columns={\"item_id\": \"ground_truth\"})\n", - "\n", - " events_df = events_df.merge(gt_to_join, on=\"user_id\", how=\"inner\")\n", - " return events_df\n", - "\n", - "validation_events = add_gt_to_events(validation_events, validation_gt)\n", - "test_events = add_gt_to_events(test_events, test_gt)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "\n", - "data_dir = Path(\"temp/data/\")\n", - "data_dir.mkdir(parents=True, exist_ok=True)\n", - "\n", - "TRAIN_PATH = data_dir / \"train.parquet\"\n", - "VAL_PATH = data_dir / \"val.parquet\"\n", - "PREDICT_PATH = data_dir / \"test.parquet\"\n", - "\n", - "ENCODER_PATH = data_dir / \"encoder\"" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "train_events.to_parquet(TRAIN_PATH)\n", - "validation_events.to_parquet(VAL_PATH)\n", - "test_events.to_parquet(PREDICT_PATH)\n", - "\n", - "encoder.save(ENCODER_PATH)" - ] - }, - { - "attachments": {}, - "cell_type": 
"markdown", - "metadata": {}, - "source": [ - "# Prepare to model training\n", - "### Create the tensor schema\n", - "A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model. It is required by the SasRec model to correctly create embeddings for every source column. Note that user_id does not required in `TensorSchema`.\n", - "\n", - "Note that the **padding value** is the next value (item_id) after the last one. **Cardinality** is the number of unique values ​​given the padding value." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.data import FeatureHint, FeatureType\n", - "from replay.data.nn import TensorFeatureInfo, TensorSchema\n", - "\n", - "\n", - "EMBEDDING_DIM = 64\n", - "\n", - "encoder = encoder.load(ENCODER_PATH)\n", - "NUM_UNIQUE_ITEMS = len(encoder.mapping[\"item_id\"])\n", - "\n", - "tensor_schema = TensorSchema(\n", - " [\n", - " TensorFeatureInfo(\n", - " name=\"item_id\",\n", - " is_seq=True,\n", - " padding_value=NUM_UNIQUE_ITEMS,\n", - " cardinality=NUM_UNIQUE_ITEMS + 1, # taking into account padding\n", - " embedding_dim=EMBEDDING_DIM,\n", - " feature_type=FeatureType.CATEGORICAL,\n", - " feature_hint=FeatureHint.ITEM_ID,\n", - " )\n", - " ]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Configure ParquetModule and transformation pipelines\n", - "\n", - "The `ParquetModule` class enables training of models on large datasets by reading data in batch-wise way. This class initialized with **paths to every data split, a metadata dict containing information about shape and padding value of every column and a dict of transforms**. `ParquetModule`'s \"transform pipelines\" are stage-specific modules implementing additional preprocessing to be performed on batch level right before the forward pass. 
\n", - "\n", - "For SasRec model, RePlay provides a function that generates a sequence of appropriate transforms for each data split named **make_default_sasrec_transforms**.\n", - "\n", - "Internally this function creates the following transforms:\n", - "1) Training:\n", - " 1. Create a target, which contains the shifted item sequence that represents the next item in the sequence (for the next item prediction task).\n", - " 2. Rename features to match it with expected format by the model during training.\n", - " 3. Unsqueeze target (*positive_labels*) and it's padding mask (*target_padding_mask*) for getting required shape of this tensors for loss computation.\n", - " 4. Group input features to be embed in expected format.\n", - "\n", - "2) Validation/Inference:\n", - " 1. Rename/group features to match it with expected format by the model during valdiation/inference.\n", - "\n", - "If a different set of transforms is required, you can create them yourself and submit them to the ParquetModule in the form of a dictionary where the key is the name of the split, and the value is the list of transforms. Available transforms are in the replay/nn/transforms/.\n", - "\n", - "**Note:** One of the transforms for the training data prepares the initial sequence for the task of Next Item Prediction so it shifts the sequence of items. For the final sequence length to be correct, you need to set shape of item_id in metadata as **model sequence length + shift**. Default shift value is 1." 
- ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "import copy\n", - "\n", - "import torch\n", - "\n", - "from replay.data.nn import TensorSchema\n", - "from replay.nn.transform import GroupTransform, NextTokenTransform, RenameTransform, UnsqueezeTransform, FrequencyNegativeSamplingTransform" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "def make_sasrec_transforms(\n", - " tensor_schema: TensorSchema, query_column: str = \"query_id\", num_negative_samples: int = 128,\n", - ") -> dict[str, list[torch.nn.Module]]:\n", - " item_column = tensor_schema.item_id_feature_name\n", - " vocab_size = tensor_schema[item_column].cardinality\n", - " train_transforms = [\n", - " FrequencyNegativeSamplingTransform(vocab_size, num_negative_samples),\n", - " NextTokenTransform(label_field=item_column, query_features=query_column, shift=1),\n", - " RenameTransform(\n", - " {\n", - " query_column: \"query_id\",\n", - " f\"{item_column}_mask\": \"padding_mask\",\n", - " \"positive_labels_mask\": \"target_padding_mask\",\n", - " }\n", - " ),\n", - " UnsqueezeTransform(\"target_padding_mask\", -1),\n", - " UnsqueezeTransform(\"positive_labels\", -1),\n", - " GroupTransform({\"feature_tensors\": [item_column]}),\n", - " ]\n", - "\n", - " val_transforms = [\n", - " RenameTransform({query_column: \"query_id\", f\"{item_column}_mask\": \"padding_mask\"}),\n", - " GroupTransform({\"feature_tensors\": [item_column]}),\n", - " ]\n", - " test_transforms = copy.deepcopy(val_transforms)\n", - "\n", - " predict_transforms = copy.deepcopy(val_transforms)\n", - "\n", - " transforms = {\n", - " \"train\": train_transforms,\n", - " \"validate\": val_transforms,\n", - " \"test\": test_transforms,\n", - " \"predict\": predict_transforms,\n", - " }\n", - "\n", - " return transforms\n", - "\n", - "transforms = make_sasrec_transforms(tensor_schema, query_column=\"user_id\")" - ] - 
}, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "MAX_SEQ_LEN = 50\n", - "\n", - "def create_meta(shape: int, gt_shape: Optional[int] = None):\n", - " meta = {\n", - " \"user_id\": {},\n", - " \"item_id\": {\"shape\": shape, \"padding\": tensor_schema[\"item_id\"].padding_value},\n", - " }\n", - " if gt_shape is not None:\n", - " meta.update({\"ground_truth\": {\"shape\": gt_shape, \"padding\": -1}})\n", - "\n", - " return meta\n", - "\n", - "train_metadata = {\n", - " \"train\": create_meta(shape=MAX_SEQ_LEN+1),\n", - " \"validate\": create_meta(shape=MAX_SEQ_LEN, gt_shape=1),\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.data.nn import ParquetModule\n", - "\n", - "BATCH_SIZE = 32\n", - "\n", - "parquet_module = ParquetModule(\n", - " train_path=TRAIN_PATH,\n", - " validate_path=VAL_PATH,\n", - " batch_size=BATCH_SIZE,\n", - " metadata=train_metadata,\n", - " transforms=transforms,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train model\n", - "### Create SasRec model instance and run the training stage using lightning\n", - "We may now train the model using the Lightning trainer class. \n", - "\n", - "RePlay's implementation of SasRec is designed in a modular, **block-based approach**. Instead of passing configuration parameters to the constructor, SasRec is now built by providing fully initialized components that makes the model more flexible and easier to extend.\n", - "\n", - "#### Default Configuration\n", - "\n", - "Default SasRec model may be created quickly via method **from_params**. Default model instance has CE loss, original SasRec transformer layes, and embeddings are aggregated via sum." 
- ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.nn.sequential import SasRec\n", - "from typing import Literal\n", - "def make_sasrec(\n", - " schema: TensorSchema,\n", - " embedding_dim: int = 192,\n", - " num_heads: int = 4,\n", - " num_blocks: int = 2,\n", - " max_sequence_length: int = 50,\n", - " dropout: float = 0.3,\n", - " excluded_features: Optional[list[str]] = None,\n", - " categorical_list_feature_aggregation_method: Literal[\"sum\", \"mean\", \"max\"] = \"sum\",\n", - ") -> SasRec:\n", - " from replay.nn.sequential.sasrec import SasRecBody, SasRecTransformerLayer\n", - " from replay.nn.agg import SumAggregator\n", - " from replay.nn.embedding import SequenceEmbedding\n", - " from replay.nn.loss import CE, CESampled\n", - " from replay.nn.mask import DefaultAttentionMask\n", - " from replay.nn.sequential.sasrec.agg import PositionAwareAggregator\n", - " from replay.nn.sequential.sasrec.transformer import SasRecTransformerLayer\n", - " excluded_features = [\n", - " schema.query_id_feature_name,\n", - " schema.timestamp_feature_name,\n", - " *(excluded_features or []),\n", - " ]\n", - " excluded_features = list(set(excluded_features))\n", - " body = SasRecBody(\n", - " embedder=SequenceEmbedding(\n", - " schema=schema,\n", - " categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,\n", - " excluded_features=excluded_features,\n", - " ),\n", - " embedding_aggregator=PositionAwareAggregator(\n", - " embedding_aggregator=SumAggregator(embedding_dim=embedding_dim),\n", - " max_sequence_length=max_sequence_length,\n", - " dropout=dropout,\n", - " ),\n", - " attn_mask_builder=DefaultAttentionMask(\n", - " reference_feature_name=schema.item_id_feature_name,\n", - " num_heads=num_heads,\n", - " ),\n", - " encoder=SasRecTransformerLayer(\n", - " embedding_dim=embedding_dim,\n", - " num_heads=num_heads,\n", - " num_blocks=num_blocks,\n", - " 
dropout=dropout,\n", - " activation=\"relu\",\n", - " ),\n", - " output_normalization=torch.nn.LayerNorm(embedding_dim),\n", - " )\n", - " padding_idx = schema.item_id_features.item().padding_value\n", - " return SasRec(\n", - " body=body,\n", - " loss=CESampled(padding_idx=padding_idx),\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "NUM_BLOCKS = 2\n", - "NUM_HEADS = 2\n", - "DROPOUT = 0.3\n", - "\n", - "sasrec = make_sasrec(\n", - " schema=tensor_schema,\n", - " embedding_dim=EMBEDDING_DIM,\n", - " max_sequence_length=MAX_SEQ_LEN,\n", - " num_heads=NUM_HEADS,\n", - " num_blocks=NUM_BLOCKS,\n", - " dropout=DROPOUT,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "A universal PyTorch Lightning module is provided. It can work with any NN model." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.nn.lightning.optimizer import OptimizerFactory\n", - "from replay.nn.lightning.scheduler import LRSchedulerFactory\n", - "from replay.nn.lightning import LightningModule\n", - "\n", - "model = LightningModule(\n", - " sasrec,\n", - " optimizer_factory=OptimizerFactory(),\n", - " lr_scheduler_factory=LRSchedulerFactory(),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To facilitate training, we add the following callbacks:\n", - "1) `ModelCheckpoint` - to save the best trained model based on its Recall metric. It's a default Lightning Callback.\n", - "1) `ComputeMetricsCallback` - to display a detailed validation metric matrix after each epoch. 
It's a custom RePlay callback for computing recsys metrics on validation and test stages.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", - "\n", - " | Name | Type | Params | Mode | FLOPs\n", - "-------------------------------------------------\n", - "0 | model | SasRec | 291 K | train | 0 \n", - "-------------------------------------------------\n", - "291 K Trainable params\n", - "0 Non-trainable params\n", - "291 K Total params\n", - "1.164 Total estimated model params size (MB)\n", - "39 Modules in train mode\n", - "0 Modules in eval mode\n", - "0 Total Flops\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "45904c0e26294a109f736787e2831727", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: | | 0/? [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "plt.figure(figsize = (5, 4), dpi = 120)\n", - "plt.hist(transforms[\"train\"][0].frequencies.cpu().numpy(), bins = 50)\n", - "plt.grid()\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we can get the best model path stored in the checkpoint callback." 
- ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=96-step=18333.ckpt'" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "best_model_path = checkpoint_callback.best_model_path\n", - "best_model_path" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Inference\n", - "\n", - "To obtain model scores, we will load the weights from the best checkpoint. To do this, we use the `LightningModule`, provide there the path to the checkpoint and the model instance." - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [], - "source": [ - "import replay\n", - "torch.serialization.add_safe_globals([\n", - " replay.nn.lightning.optimizer.OptimizerFactory,\n", - " replay.nn.lightning.scheduler.LRSchedulerFactory\n", - "])" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "sasrec = make_sasrec(\n", - " schema=tensor_schema,\n", - " embedding_dim=EMBEDDING_DIM,\n", - " max_sequence_length=MAX_SEQ_LEN,\n", - " num_heads=NUM_HEADS,\n", - " num_blocks=NUM_BLOCKS,\n", - " dropout=DROPOUT,\n", - ")\n", - "\n", - "best_model = LightningModule.load_from_checkpoint(best_model_path, model=sasrec)\n", - "best_model.eval();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Configure `ParquetModule` for inference" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "inference_metadata = {\"predict\": create_meta(shape=MAX_SEQ_LEN)}\n", - "\n", - "parquet_module = ParquetModule(\n", - " predict_path=PREDICT_PATH,\n", - " batch_size=BATCH_SIZE,\n", - " metadata=inference_metadata,\n", - " transforms=transforms,\n", - ")" - ] - }, - { - "cell_type": "markdown", - 
"metadata": {}, - "source": [ - "During inference, we can use `TopItemsCallback`. Such callback allows you to get scores for each user throughout the entire catalog and get recommendations in the form of ids of items with the highest score values.\n", - "\n", - "\n", - "Recommendations can be fetched in four formats: PySpark DataFrame, Pandas DataFrame, Polars DataFrame or raw PyTorch tensors. Each of the types corresponds a callback. In this example, we'll be using the `PandasTopItemsCallback`." - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9b94b2f1745140e5ae3bc1d63bdbbe4e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscore
004866.457348
002246.311622
002106.299765
001016.193144
0021426.179489
............
603760398410.029899
6037603926349.860862
6037603914909.738086
6037603926339.715384
6037603924979.684793
\n", - "

120760 rows × 3 columns

\n", - "" - ], - "text/plain": [ - " user_id item_id score\n", - "0 0 486 6.457348\n", - "0 0 224 6.311622\n", - "0 0 210 6.299765\n", - "0 0 101 6.193144\n", - "0 0 2142 6.179489\n", - "... ... ... ...\n", - "6037 6039 84 10.029899\n", - "6037 6039 2634 9.860862\n", - "6037 6039 1490 9.738086\n", - "6037 6039 2633 9.715384\n", - "6037 6039 2497 9.684793\n", - "\n", - "[120760 rows x 3 columns]" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pandas_res" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Calculating metrics\n", - "\n", - "*test_gt* is already encoded, so we can use it for computing metrics." - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.metrics import MAP, OfflineMetrics, Precision, Recall\n", - "from replay.metrics.torch_metrics_builder import metrics_to_df" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "result_metrics = OfflineMetrics(\n", - " [Recall(TOPK), Precision(TOPK), MAP(TOPK)],\n", - " query_column=\"user_id\",\n", - " rating_column=\"score\",\n", - ")(pandas_res, test_gt.explode(\"item_id\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
k110205
MAP0.0203710.0611890.0685750.051101
Precision0.0203710.0187310.0147900.022093
Recall0.0203710.1873140.2957930.110467
\n", - "
" - ], - "text/plain": [ - "k 1 10 20 5\n", - "MAP 0.020371 0.061189 0.068575 0.051101\n", - "Precision 0.020371 0.018731 0.014790 0.022093\n", - "Recall 0.020371 0.187314 0.295793 0.110467" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metrics_to_df(result_metrics)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's call the `inverse_transform` encoder's function to get the final dataframe with recommendations" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscore
020125006.457348
020122316.311622
020122166.299765
020121046.193144
0201223356.179489
............
603757278610.029899
6037572728419.860862
6037572716239.738086
6037572728409.715384
6037572727029.684793
\n", - "

120760 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " user_id item_id score\n", - "0 2012 500 6.457348\n", - "0 2012 231 6.311622\n", - "0 2012 216 6.299765\n", - "0 2012 104 6.193144\n", - "0 2012 2335 6.179489\n", - "... ... ... ...\n", - "6037 5727 86 10.029899\n", - "6037 5727 2841 9.860862\n", - "6037 5727 1623 9.738086\n", - "6037 5727 2840 9.715384\n", - "6037 5727 2702 9.684793\n", - "\n", - "[120760 rows x 3 columns]" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "encoder.inverse_transform(pandas_res)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "new_venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/09_sasrec_example_sampled.ipynb b/examples/09_sasrec_example_sampled.ipynb deleted file mode 100644 index 6672edd85..000000000 --- a/examples/09_sasrec_example_sampled.ipynb +++ /dev/null @@ -1,4744 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Example of SasRec training/inference with Parquet Module" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 42\n" - ] - } - ], - "source": [ - "from typing import Optional\n", - "\n", - "import lightning as L\n", - "import pandas as pd\n", - "\n", - "L.seed_everything(42)\n", - "\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - 
"source": [ - "## Preparing data\n", - "In this example, we will be using the MovieLens dataset, namely the 1m subset. It's demonstrated a simple case, so only item ids will be used as model input.\n", - "\n", - "---\n", - "**NOTE**\n", - "\n", - "Current implementation of SasRec is able to handle item and interactions features. It does not take into account user features. \n", - "\n", - "---" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "interactions = pd.read_csv(\"./data/ml1m_ratings.dat\", sep=\"\\t\", names=[\"user_id\", \"item_id\",\"rating\",\"timestamp\"])\n", - "interactions = interactions.drop(columns=[\"rating\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idtimestamp
100013860408580
1000153604023841
99987360405932
1000007604019613
1000192604020194
............
82579349582399446
82543849581407447
82572449583264448
82573149582634449
82560349581924450
\n", - "

1000209 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " user_id item_id timestamp\n", - "1000138 6040 858 0\n", - "1000153 6040 2384 1\n", - "999873 6040 593 2\n", - "1000007 6040 1961 3\n", - "1000192 6040 2019 4\n", - "... ... ... ...\n", - "825793 4958 2399 446\n", - "825438 4958 1407 447\n", - "825724 4958 3264 448\n", - "825731 4958 2634 449\n", - "825603 4958 1924 450\n", - "\n", - "[1000209 rows x 3 columns]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "interactions[\"timestamp\"] = interactions[\"timestamp\"].astype(\"int64\")\n", - "interactions = interactions.sort_values(by=\"timestamp\")\n", - "interactions[\"timestamp\"] = interactions.groupby(\"user_id\").cumcount()\n", - "interactions" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Encode catagorical data.\n", - "To ensure all categorical data is fit for training, it needs to be encoded using the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorcial column in the dataset that will be used in model. Note that ids of users and ids of items are always used." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
timestampuser_iditem_id
01200
16810
26720
31230
414040
............
10002041445553705
10002059028133705
10002067024043705
10002072558353705
10002083809793705
\n", - "

1000209 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " timestamp user_id item_id\n", - "0 12 0 0\n", - "1 68 1 0\n", - "2 67 2 0\n", - "3 12 3 0\n", - "4 140 4 0\n", - "... ... ... ...\n", - "1000204 14 4555 3705\n", - "1000205 90 2813 3705\n", - "1000206 70 2404 3705\n", - "1000207 25 5835 3705\n", - "1000208 380 979 3705\n", - "\n", - "[1000209 rows x 3 columns]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule\n", - "\n", - "encoder = LabelEncoder(\n", - " [\n", - " LabelEncodingRule(\"user_id\", default_value=\"last\"),\n", - " LabelEncodingRule(\"item_id\", default_value=\"last\"),\n", - " ]\n", - ")\n", - "interactions = interactions.sort_values(by=\"item_id\", ascending=True)\n", - "encoded_interactions = encoder.fit_transform(interactions)\n", - "encoded_interactions" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Split interactions into the train, validation and test datasets using LastNSplitter\n", - "We use widespread splitting strategy Last-One-Out. We filter out cold items and users for simplicity." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.splitters import LastNSplitter\n", - "\n", - "splitter = LastNSplitter(\n", - " N=1,\n", - " divide_column=\"user_id\",\n", - " query_column=\"user_id\",\n", - " strategy=\"interactions\",\n", - " drop_cold_users=True,\n", - " drop_cold_items=True\n", - ")\n", - "\n", - "test_events, test_gt = splitter.split(encoded_interactions)\n", - "validation_events, validation_gt = splitter.split(test_events)\n", - "train_events = validation_events" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Dataset preprocessing (\"baking\")\n", - "SasRec expects each user in the batch to provide their events in form of a sequence. 
For this reason, the event splits must be properly processed using the `groupby_sequences` function provided by RePlay." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.data.nn.utils import groupby_sequences\n", - "\n", - "\n", - "def bake_data(full_data):\n", - " grouped_interactions = groupby_sequences(events=full_data, groupby_col=\"user_id\", sort_col=\"timestamp\")\n", - " return grouped_interactions" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_idtimestampitem_id
00[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2426, 822, 2733, 2587, 2937, 3618, 2943, 708,...
11[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[3272, 3026, 2760, 851, 346, 3393, 1107, 515, ...
22[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[579, 1140, 1154, 2426, 1524, 1260, 2160, 2621...
33[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1781, 2940, 2468, 890, 948, 106, 593, 309, 49...
44[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1108, 2229, 21, 2435, 2142, 106, 1167, 593, 1...
............
60356035[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2426, 1279, 3151, 3321, 1178, 3301, 2501, 278...
60366036[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1592, 2302, 1633, 1813, 2879, 1482, 2651, 250...
60376037[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1971, 3500, 2077, 1666, 1399, 2651, 2748, 283...
60386038[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1486, 1485, 3384, 3512, 3302, 3126, 3650, 330...
60396039[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2432, 2960, 1848, 2114, 2142, 3091, 3248, 317...
\n", - "

6040 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " user_id timestamp \\\n", - "0 0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "1 1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "2 2 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "3 3 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "4 4 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "... ... ... \n", - "6035 6035 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "6036 6036 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "6037 6037 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "6038 6038 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "6039 6039 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "\n", - " item_id \n", - "0 [2426, 822, 2733, 2587, 2937, 3618, 2943, 708,... \n", - "1 [3272, 3026, 2760, 851, 346, 3393, 1107, 515, ... \n", - "2 [579, 1140, 1154, 2426, 1524, 1260, 2160, 2621... \n", - "3 [1781, 2940, 2468, 890, 948, 106, 593, 309, 49... \n", - "4 [1108, 2229, 21, 2435, 2142, 106, 1167, 593, 1... \n", - "... ... \n", - "6035 [2426, 1279, 3151, 3321, 1178, 3301, 2501, 278... \n", - "6036 [1592, 2302, 1633, 1813, 2879, 1482, 2651, 250... \n", - "6037 [1971, 3500, 2077, 1666, 1399, 2651, 2748, 283... \n", - "6038 [1486, 1485, 3384, 3512, 3302, 3126, 3650, 330... \n", - "6039 [2432, 2960, 1848, 2114, 2142, 3091, 3248, 317... 
\n", - "\n", - "[6040 rows x 3 columns]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_events = bake_data(train_events)\n", - "\n", - "validation_events = bake_data(validation_events)\n", - "validation_gt = bake_data(validation_gt)\n", - "\n", - "test_events = bake_data(test_events)\n", - "test_gt = bake_data(test_gt)\n", - "\n", - "train_events" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To ensure we don't have unknown users in ground truth, we join validation events and validation ground truth (also join test events and test ground truth correspondingly) by user ids to leave only the common ones. " - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "def add_gt_to_events(events_df, gt_df):\n", - " gt_to_join = gt_df[[\"user_id\", \"item_id\"]].rename(columns={\"item_id\": \"ground_truth\"})\n", - "\n", - " events_df = events_df.merge(gt_to_join, on=\"user_id\", how=\"inner\")\n", - " return events_df\n", - "\n", - "validation_events = add_gt_to_events(validation_events, validation_gt)\n", - "test_events = add_gt_to_events(test_events, test_gt)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "\n", - "data_dir = Path(\"temp/data/\")\n", - "data_dir.mkdir(parents=True, exist_ok=True)\n", - "\n", - "TRAIN_PATH = data_dir / \"train.parquet\"\n", - "VAL_PATH = data_dir / \"val.parquet\"\n", - "PREDICT_PATH = data_dir / \"test.parquet\"\n", - "\n", - "ENCODER_PATH = data_dir / \"encoder\"" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "train_events.to_parquet(TRAIN_PATH)\n", - "validation_events.to_parquet(VAL_PATH)\n", - "test_events.to_parquet(PREDICT_PATH)\n", - "\n", - "encoder.save(ENCODER_PATH)" - ] - }, - { - "attachments": {}, - "cell_type": 
"markdown", - "metadata": {}, - "source": [ - "# Prepare to model training\n", - "### Create the tensor schema\n", - "A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model. It is required by the SasRec model to correctly create embeddings for every source column. Note that user_id does not required in `TensorSchema`.\n", - "\n", - "Note that the **padding value** is the next value (item_id) after the last one. **Cardinality** is the number of unique values ​​given the padding value." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.data import FeatureHint, FeatureType\n", - "from replay.data.nn import TensorFeatureInfo, TensorSchema\n", - "\n", - "\n", - "EMBEDDING_DIM = 64\n", - "\n", - "encoder = encoder.load(ENCODER_PATH)\n", - "NUM_UNIQUE_ITEMS = len(encoder.mapping[\"item_id\"])\n", - "\n", - "tensor_schema = TensorSchema(\n", - " [\n", - " TensorFeatureInfo(\n", - " name=\"item_id\",\n", - " is_seq=True,\n", - " padding_value=NUM_UNIQUE_ITEMS,\n", - " cardinality=NUM_UNIQUE_ITEMS + 1, # taking into account padding\n", - " embedding_dim=EMBEDDING_DIM,\n", - " feature_type=FeatureType.CATEGORICAL,\n", - " feature_hint=FeatureHint.ITEM_ID,\n", - " )\n", - " ]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Configure ParquetModule and transformation pipelines\n", - "\n", - "The `ParquetModule` class enables training of models on large datasets by reading data in batch-wise way. This class initialized with **paths to every data split, a metadata dict containing information about shape and padding value of every column and a dict of transforms**. `ParquetModule`'s \"transform pipelines\" are stage-specific modules implementing additional preprocessing to be performed on batch level right before the forward pass. 
\n", - "\n", - "For SasRec model, RePlay provides a function that generates a sequence of appropriate transforms for each data split named **make_default_sasrec_transforms**.\n", - "\n", - "Internally this function creates the following transforms:\n", - "1) Training:\n", - " 1. Create a target, which contains the shifted item sequence that represents the next item in the sequence (for the next item prediction task).\n", - " 2. Rename features to match it with expected format by the model during training.\n", - " 3. Unsqueeze target (*positive_labels*) and it's padding mask (*target_padding_mask*) for getting required shape of this tensors for loss computation.\n", - " 4. Group input features to be embed in expected format.\n", - "\n", - "2) Validation/Inference:\n", - " 1. Rename/group features to match it with expected format by the model during valdiation/inference.\n", - "\n", - "If a different set of transforms is required, you can create them yourself and submit them to the ParquetModule in the form of a dictionary where the key is the name of the split, and the value is the list of transforms. Available transforms are in the replay/nn/transforms/.\n", - "\n", - "**Note:** One of the transforms for the training data prepares the initial sequence for the task of Next Item Prediction so it shifts the sequence of items. For the final sequence length to be correct, you need to set shape of item_id in metadata as **model sequence length + shift**. Default shift value is 1." 
- ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "import copy\n", - "\n", - "import torch\n", - "\n", - "from replay.data.nn import TensorSchema\n", - "from replay.nn.transform import GroupTransform, NextTokenTransform, RenameTransform, UnsqueezeTransform, UniformNegativeSamplingTransform" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "def make_sasrec_transforms(\n", - " tensor_schema: TensorSchema, query_column: str = \"query_id\", num_negative_samples: int = 128,\n", - ") -> dict[str, list[torch.nn.Module]]:\n", - " item_column = tensor_schema.item_id_feature_name\n", - " vocab_size = tensor_schema[item_column].cardinality\n", - " train_transforms = [\n", - " UniformNegativeSamplingTransform(vocab_size, num_negative_samples),\n", - " NextTokenTransform(label_field=item_column, query_features=query_column, shift=1),\n", - " RenameTransform(\n", - " {\n", - " query_column: \"query_id\",\n", - " f\"{item_column}_mask\": \"padding_mask\",\n", - " \"positive_labels_mask\": \"target_padding_mask\",\n", - " }\n", - " ),\n", - " UnsqueezeTransform(\"target_padding_mask\", -1),\n", - " UnsqueezeTransform(\"positive_labels\", -1),\n", - " GroupTransform({\"feature_tensors\": [item_column]}),\n", - " ]\n", - "\n", - " val_transforms = [\n", - " RenameTransform({query_column: \"query_id\", f\"{item_column}_mask\": \"padding_mask\"}),\n", - " GroupTransform({\"feature_tensors\": [item_column]}),\n", - " ]\n", - " test_transforms = copy.deepcopy(val_transforms)\n", - "\n", - " predict_transforms = copy.deepcopy(val_transforms)\n", - "\n", - " transforms = {\n", - " \"train\": train_transforms,\n", - " \"validate\": val_transforms,\n", - " \"test\": test_transforms,\n", - " \"predict\": predict_transforms,\n", - " }\n", - "\n", - " return transforms\n", - "\n", - "transforms = make_sasrec_transforms(tensor_schema, query_column=\"user_id\")" - ] - }, - 
{ - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "MAX_SEQ_LEN = 50\n", - "\n", - "def create_meta(shape: int, gt_shape: Optional[int] = None):\n", - " meta = {\n", - " \"user_id\": {},\n", - " \"item_id\": {\"shape\": shape, \"padding\": tensor_schema[\"item_id\"].padding_value},\n", - " }\n", - " if gt_shape is not None:\n", - " meta.update({\"ground_truth\": {\"shape\": gt_shape, \"padding\": -1}})\n", - "\n", - " return meta\n", - "\n", - "train_metadata = {\n", - " \"train\": create_meta(shape=MAX_SEQ_LEN+1),\n", - " \"validate\": create_meta(shape=MAX_SEQ_LEN, gt_shape=1),\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.data.nn import ParquetModule\n", - "\n", - "BATCH_SIZE = 32\n", - "\n", - "parquet_module = ParquetModule(\n", - " train_path=TRAIN_PATH,\n", - " validate_path=VAL_PATH,\n", - " batch_size=BATCH_SIZE,\n", - " metadata=train_metadata,\n", - " transforms=transforms,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train model\n", - "### Create SasRec model instance and run the training stage using lightning\n", - "We may now train the model using the Lightning trainer class. \n", - "\n", - "RePlay's implementation of SasRec is designed in a modular, **block-based approach**. Instead of passing configuration parameters to the constructor, SasRec is now built by providing fully initialized components that makes the model more flexible and easier to extend.\n", - "\n", - "#### Default Configuration\n", - "\n", - "Default SasRec model may be created quickly via method **from_params**. Default model instance has CE loss, original SasRec transformer layes, and embeddings are aggregated via sum." 
- ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.nn.sequential import SasRec\n", - "from typing import Literal\n", - "def make_sasrec(\n", - " schema: TensorSchema,\n", - " embedding_dim: int = 192,\n", - " num_heads: int = 4,\n", - " num_blocks: int = 2,\n", - " max_sequence_length: int = 50,\n", - " dropout: float = 0.3,\n", - " excluded_features: Optional[list[str]] = None,\n", - " categorical_list_feature_aggregation_method: Literal[\"sum\", \"mean\", \"max\"] = \"sum\",\n", - ") -> SasRec:\n", - " from replay.nn.sequential.sasrec import SasRecBody, SasRecTransformerLayer\n", - " from replay.nn.agg import SumAggregator\n", - " from replay.nn.embedding import SequenceEmbedding\n", - " from replay.nn.loss import CE, CESampled\n", - " from replay.nn.mask import DefaultAttentionMask\n", - " from replay.nn.sequential.sasrec.agg import PositionAwareAggregator\n", - " from replay.nn.sequential.sasrec.transformer import SasRecTransformerLayer\n", - " excluded_features = [\n", - " schema.query_id_feature_name,\n", - " schema.timestamp_feature_name,\n", - " *(excluded_features or []),\n", - " ]\n", - " excluded_features = list(set(excluded_features))\n", - " body = SasRecBody(\n", - " embedder=SequenceEmbedding(\n", - " schema=schema,\n", - " categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,\n", - " excluded_features=excluded_features,\n", - " ),\n", - " embedding_aggregator=PositionAwareAggregator(\n", - " embedding_aggregator=SumAggregator(embedding_dim=embedding_dim),\n", - " max_sequence_length=max_sequence_length,\n", - " dropout=dropout,\n", - " ),\n", - " attn_mask_builder=DefaultAttentionMask(\n", - " reference_feature_name=schema.item_id_feature_name,\n", - " num_heads=num_heads,\n", - " ),\n", - " encoder=SasRecTransformerLayer(\n", - " embedding_dim=embedding_dim,\n", - " num_heads=num_heads,\n", - " num_blocks=num_blocks,\n", - " 
dropout=dropout,\n", - " activation=\"relu\",\n", - " ),\n", - " output_normalization=torch.nn.LayerNorm(embedding_dim),\n", - " )\n", - " padding_idx = schema.item_id_features.item().padding_value\n", - " return SasRec(\n", - " body=body,\n", - " loss=CESampled(padding_idx=padding_idx),\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "NUM_BLOCKS = 2\n", - "NUM_HEADS = 2\n", - "DROPOUT = 0.3\n", - "\n", - "sasrec = make_sasrec(\n", - " schema=tensor_schema,\n", - " embedding_dim=EMBEDDING_DIM,\n", - " max_sequence_length=MAX_SEQ_LEN,\n", - " num_heads=NUM_HEADS,\n", - " num_blocks=NUM_BLOCKS,\n", - " dropout=DROPOUT,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "A universal PyTorch Lightning module is provided. It can work with any NN model." - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.nn.lightning.optimizer import OptimizerFactory\n", - "from replay.nn.lightning.scheduler import LRSchedulerFactory\n", - "from replay.nn.lightning import LightningModule\n", - "\n", - "model = LightningModule(\n", - " sasrec,\n", - " optimizer_factory=OptimizerFactory(),\n", - " lr_scheduler_factory=LRSchedulerFactory(),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To facilitate training, we add the following callbacks:\n", - "1) `ModelCheckpoint` - to save the best trained model based on its Recall metric. It's a default Lightning Callback.\n", - "1) `ComputeMetricsCallback` - to display a detailed validation metric matrix after each epoch. 
It's a custom RePlay callback for computing recsys metrics on validation and test stages.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", - "\n", - " | Name | Type | Params | Mode | FLOPs\n", - "-------------------------------------------------\n", - "0 | model | SasRec | 291 K | train | 0 \n", - "-------------------------------------------------\n", - "291 K Trainable params\n", - "0 Non-trainable params\n", - "291 K Total params\n", - "1.164 Total estimated model params size (MB)\n", - "39 Modules in train mode\n", - "0 Modules in eval mode\n", - "0 Total Flops\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "bdbe6004cd2b40f09355b6ed047313af", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: | | 0/? 
[00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscore
002247.242342
005726.818249
004866.81148
0013716.534966
002106.52649
............
60376039249710.457304
60376039350310.305973
60376039260110.280416
60376039275010.01198
6037603914909.916577
\n", - "

120760 rows × 3 columns

\n", - "" - ], - "text/plain": [ - " user_id item_id score\n", - "0 0 224 7.242342\n", - "0 0 572 6.818249\n", - "0 0 486 6.81148\n", - "0 0 1371 6.534966\n", - "0 0 210 6.52649\n", - "... ... ... ...\n", - "6037 6039 2497 10.457304\n", - "6037 6039 3503 10.305973\n", - "6037 6039 2601 10.280416\n", - "6037 6039 2750 10.01198\n", - "6037 6039 1490 9.916577\n", - "\n", - "[120760 rows x 3 columns]" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pandas_res" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Calculating metrics\n", - "\n", - "*test_gt* is already encoded, so we can use it for computing metrics." - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.metrics import MAP, OfflineMetrics, Precision, Recall\n", - "from replay.metrics.torch_metrics_builder import metrics_to_df" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "metadata": {}, - "outputs": [], - "source": [ - "result_metrics = OfflineMetrics(\n", - " [Recall(TOPK), Precision(TOPK), MAP(TOPK)],\n", - " query_column=\"user_id\",\n", - " rating_column=\"score\",\n", - ")(pandas_res, test_gt.explode(\"item_id\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
k110205
MAP0.0160650.0540390.0617490.043969
Precision0.0160650.0176550.0144580.020073
Recall0.0160650.1765490.2891690.100364
\n", - "
" - ], - "text/plain": [ - "k 1 10 20 5\n", - "MAP 0.016065 0.054039 0.061749 0.043969\n", - "Precision 0.016065 0.017655 0.014458 0.020073\n", - "Recall 0.016065 0.176549 0.289169 0.100364" - ] - }, - "execution_count": 50, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metrics_to_df(result_metrics)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's call the `inverse_transform` encoder's function to get the final dataframe with recommendations" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscore
020122317.242342
020125866.818249
020125006.81148
0201214856.534966
020122166.52649
............
60375727270210.457304
60375727374510.305973
60375727280610.280416
60375727296110.01198
6037572716239.916577
\n", - "

120760 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " user_id item_id score\n", - "0 2012 231 7.242342\n", - "0 2012 586 6.818249\n", - "0 2012 500 6.81148\n", - "0 2012 1485 6.534966\n", - "0 2012 216 6.52649\n", - "... ... ... ...\n", - "6037 5727 2702 10.457304\n", - "6037 5727 3745 10.305973\n", - "6037 5727 2806 10.280416\n", - "6037 5727 2961 10.01198\n", - "6037 5727 1623 9.916577\n", - "\n", - "[120760 rows x 3 columns]" - ] - }, - "execution_count": 51, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "encoder.inverse_transform(pandas_res)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.14" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/09_sasrec_example_thr_sampled.ipynb b/examples/09_sasrec_example_thr_sampled.ipynb deleted file mode 100644 index b27ec3f8d..000000000 --- a/examples/09_sasrec_example_thr_sampled.ipynb +++ /dev/null @@ -1,4768 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Example of SasRec training/inference with Parquet Module" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 42\n" - ] - } - ], - "source": [ - "from typing import Optional\n", - "\n", - "import lightning as L\n", - "import pandas as pd\n", - "\n", - "L.seed_everything(42)\n", - "\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - 
"metadata": {}, - "source": [ - "## Preparing data\n", - "In this example, we will be using the MovieLens dataset, namely the 1m subset. It's demonstrated a simple case, so only item ids will be used as model input.\n", - "\n", - "---\n", - "**NOTE**\n", - "\n", - "Current implementation of SasRec is able to handle item and interactions features. It does not take into account user features. \n", - "\n", - "---" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "interactions = pd.read_csv(\"./data/ml1m_ratings.dat\", sep=\"\\t\", names=[\"user_id\", \"item_id\",\"rating\",\"timestamp\"])\n", - "interactions = interactions.drop(columns=[\"rating\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idtimestamp
100013860408580
1000153604023841
99987360405932
1000007604019613
1000192604020194
............
82579349582399446
82543849581407447
82572449583264448
82573149582634449
82560349581924450
\n", - "

1000209 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " user_id item_id timestamp\n", - "1000138 6040 858 0\n", - "1000153 6040 2384 1\n", - "999873 6040 593 2\n", - "1000007 6040 1961 3\n", - "1000192 6040 2019 4\n", - "... ... ... ...\n", - "825793 4958 2399 446\n", - "825438 4958 1407 447\n", - "825724 4958 3264 448\n", - "825731 4958 2634 449\n", - "825603 4958 1924 450\n", - "\n", - "[1000209 rows x 3 columns]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "interactions[\"timestamp\"] = interactions[\"timestamp\"].astype(\"int64\")\n", - "interactions = interactions.sort_values(by=\"timestamp\")\n", - "interactions[\"timestamp\"] = interactions.groupby(\"user_id\").cumcount()\n", - "interactions" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Encode catagorical data.\n", - "To ensure all categorical data is fit for training, it needs to be encoded using the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorcial column in the dataset that will be used in model. Note that ids of users and ids of items are always used." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
timestampuser_iditem_id
01200
16810
26720
31230
414040
............
10002041445553705
10002059028133705
10002067024043705
10002072558353705
10002083809793705
\n", - "

1000209 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " timestamp user_id item_id\n", - "0 12 0 0\n", - "1 68 1 0\n", - "2 67 2 0\n", - "3 12 3 0\n", - "4 140 4 0\n", - "... ... ... ...\n", - "1000204 14 4555 3705\n", - "1000205 90 2813 3705\n", - "1000206 70 2404 3705\n", - "1000207 25 5835 3705\n", - "1000208 380 979 3705\n", - "\n", - "[1000209 rows x 3 columns]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule\n", - "\n", - "encoder = LabelEncoder(\n", - " [\n", - " LabelEncodingRule(\"user_id\", default_value=\"last\"),\n", - " LabelEncodingRule(\"item_id\", default_value=\"last\"),\n", - " ]\n", - ")\n", - "interactions = interactions.sort_values(by=\"item_id\", ascending=True)\n", - "encoded_interactions = encoder.fit_transform(interactions)\n", - "encoded_interactions" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Split interactions into the train, validation and test datasets using LastNSplitter\n", - "We use widespread splitting strategy Last-One-Out. We filter out cold items and users for simplicity." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.splitters import LastNSplitter\n", - "\n", - "splitter = LastNSplitter(\n", - " N=1,\n", - " divide_column=\"user_id\",\n", - " query_column=\"user_id\",\n", - " strategy=\"interactions\",\n", - " drop_cold_users=True,\n", - " drop_cold_items=True\n", - ")\n", - "\n", - "test_events, test_gt = splitter.split(encoded_interactions)\n", - "validation_events, validation_gt = splitter.split(test_events)\n", - "train_events = validation_events" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Dataset preprocessing (\"baking\")\n", - "SasRec expects each user in the batch to provide their events in form of a sequence. 
For this reason, the event splits must be properly processed using the `groupby_sequences` function provided by RePlay." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.data.nn.utils import groupby_sequences\n", - "\n", - "\n", - "def bake_data(full_data):\n", - " grouped_interactions = groupby_sequences(events=full_data, groupby_col=\"user_id\", sort_col=\"timestamp\")\n", - " return grouped_interactions" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_idtimestampitem_id
00[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2426, 822, 2733, 2587, 2937, 3618, 2943, 708,...
11[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[3272, 3026, 2760, 851, 346, 3393, 1107, 515, ...
22[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[579, 1140, 1154, 2426, 1524, 1260, 2160, 2621...
33[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1781, 2940, 2468, 890, 948, 106, 593, 309, 49...
44[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1108, 2229, 21, 2435, 2142, 106, 1167, 593, 1...
............
60356035[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2426, 1279, 3151, 3321, 1178, 3301, 2501, 278...
60366036[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1592, 2302, 1633, 1813, 2879, 1482, 2651, 250...
60376037[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1971, 3500, 2077, 1666, 1399, 2651, 2748, 283...
60386038[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[1486, 1485, 3384, 3512, 3302, 3126, 3650, 330...
60396039[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...[2432, 2960, 1848, 2114, 2142, 3091, 3248, 317...
\n", - "

6040 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " user_id timestamp \\\n", - "0 0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "1 1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "2 2 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "3 3 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "4 4 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "... ... ... \n", - "6035 6035 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "6036 6036 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "6037 6037 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "6038 6038 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "6039 6039 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", - "\n", - " item_id \n", - "0 [2426, 822, 2733, 2587, 2937, 3618, 2943, 708,... \n", - "1 [3272, 3026, 2760, 851, 346, 3393, 1107, 515, ... \n", - "2 [579, 1140, 1154, 2426, 1524, 1260, 2160, 2621... \n", - "3 [1781, 2940, 2468, 890, 948, 106, 593, 309, 49... \n", - "4 [1108, 2229, 21, 2435, 2142, 106, 1167, 593, 1... \n", - "... ... \n", - "6035 [2426, 1279, 3151, 3321, 1178, 3301, 2501, 278... \n", - "6036 [1592, 2302, 1633, 1813, 2879, 1482, 2651, 250... \n", - "6037 [1971, 3500, 2077, 1666, 1399, 2651, 2748, 283... \n", - "6038 [1486, 1485, 3384, 3512, 3302, 3126, 3650, 330... \n", - "6039 [2432, 2960, 1848, 2114, 2142, 3091, 3248, 317... 
\n", - "\n", - "[6040 rows x 3 columns]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_events = bake_data(train_events)\n", - "\n", - "validation_events = bake_data(validation_events)\n", - "validation_gt = bake_data(validation_gt)\n", - "\n", - "test_events = bake_data(test_events)\n", - "test_gt = bake_data(test_gt)\n", - "\n", - "train_events" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To ensure we don't have unknown users in ground truth, we join validation events and validation ground truth (also join test events and test ground truth correspondingly) by user ids to leave only the common ones. " - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "def add_gt_to_events(events_df, gt_df):\n", - " gt_to_join = gt_df[[\"user_id\", \"item_id\"]].rename(columns={\"item_id\": \"ground_truth\"})\n", - "\n", - " events_df = events_df.merge(gt_to_join, on=\"user_id\", how=\"inner\")\n", - " return events_df\n", - "\n", - "validation_events = add_gt_to_events(validation_events, validation_gt)\n", - "test_events = add_gt_to_events(test_events, test_gt)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "\n", - "data_dir = Path(\"temp/data/\")\n", - "data_dir.mkdir(parents=True, exist_ok=True)\n", - "\n", - "TRAIN_PATH = data_dir / \"train.parquet\"\n", - "VAL_PATH = data_dir / \"val.parquet\"\n", - "PREDICT_PATH = data_dir / \"test.parquet\"\n", - "\n", - "ENCODER_PATH = data_dir / \"encoder\"" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "train_events.to_parquet(TRAIN_PATH)\n", - "validation_events.to_parquet(VAL_PATH)\n", - "test_events.to_parquet(PREDICT_PATH)\n", - "\n", - "encoder.save(ENCODER_PATH)" - ] - }, - { - "attachments": {}, - "cell_type": 
"markdown", - "metadata": {}, - "source": [ - "# Prepare to model training\n", - "### Create the tensor schema\n", - "A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model. It is required by the SasRec model to correctly create embeddings for every source column. Note that user_id does not required in `TensorSchema`.\n", - "\n", - "Note that the **padding value** is the next value (item_id) after the last one. **Cardinality** is the number of unique values ​​given the padding value." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.data import FeatureHint, FeatureType\n", - "from replay.data.nn import TensorFeatureInfo, TensorSchema\n", - "\n", - "\n", - "EMBEDDING_DIM = 64\n", - "\n", - "encoder = encoder.load(ENCODER_PATH)\n", - "NUM_UNIQUE_ITEMS = len(encoder.mapping[\"item_id\"])\n", - "\n", - "tensor_schema = TensorSchema(\n", - " [\n", - " TensorFeatureInfo(\n", - " name=\"item_id\",\n", - " is_seq=True,\n", - " padding_value=NUM_UNIQUE_ITEMS,\n", - " cardinality=NUM_UNIQUE_ITEMS + 1, # taking into account padding\n", - " embedding_dim=EMBEDDING_DIM,\n", - " feature_type=FeatureType.CATEGORICAL,\n", - " feature_hint=FeatureHint.ITEM_ID,\n", - " )\n", - " ]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Configure ParquetModule and transformation pipelines\n", - "\n", - "The `ParquetModule` class enables training of models on large datasets by reading data in batch-wise way. This class initialized with **paths to every data split, a metadata dict containing information about shape and padding value of every column and a dict of transforms**. `ParquetModule`'s \"transform pipelines\" are stage-specific modules implementing additional preprocessing to be performed on batch level right before the forward pass. 
\n", - "\n", - "For SasRec model, RePlay provides a function that generates a sequence of appropriate transforms for each data split named **make_default_sasrec_transforms**.\n", - "\n", - "Internally this function creates the following transforms:\n", - "1) Training:\n", - " 1. Create a target, which contains the shifted item sequence that represents the next item in the sequence (for the next item prediction task).\n", - " 2. Rename features to match it with expected format by the model during training.\n", - " 3. Unsqueeze target (*positive_labels*) and it's padding mask (*target_padding_mask*) for getting required shape of this tensors for loss computation.\n", - " 4. Group input features to be embed in expected format.\n", - "\n", - "2) Validation/Inference:\n", - " 1. Rename/group features to match it with expected format by the model during valdiation/inference.\n", - "\n", - "If a different set of transforms is required, you can create them yourself and submit them to the ParquetModule in the form of a dictionary where the key is the name of the split, and the value is the list of transforms. Available transforms are in the replay/nn/transforms/.\n", - "\n", - "**Note:** One of the transforms for the training data prepares the initial sequence for the task of Next Item Prediction so it shifts the sequence of items. For the final sequence length to be correct, you need to set shape of item_id in metadata as **model sequence length + shift**. Default shift value is 1." 
- ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "import copy\n", - "\n", - "import torch\n", - "\n", - "from replay.data.nn import TensorSchema\n", - "from replay.nn.transform import GroupTransform, NextTokenTransform, RenameTransform, UnsqueezeTransform, ThresholdNegativeSamplingTransform" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "def make_sasrec_transforms(\n", - " tensor_schema: TensorSchema, query_column: str = \"query_id\", num_negative_samples: int = 128,\n", - ") -> dict[str, list[torch.nn.Module]]:\n", - " item_column = tensor_schema.item_id_feature_name\n", - " vocab_size = tensor_schema[item_column].cardinality\n", - " train_transforms = [\n", - " ThresholdNegativeSamplingTransform(vocab_size, num_negative_samples),\n", - " NextTokenTransform(label_field=item_column, query_features=query_column, shift=1),\n", - " RenameTransform(\n", - " {\n", - " query_column: \"query_id\",\n", - " f\"{item_column}_mask\": \"padding_mask\",\n", - " \"positive_labels_mask\": \"target_padding_mask\",\n", - " }\n", - " ),\n", - " UnsqueezeTransform(\"target_padding_mask\", -1),\n", - " UnsqueezeTransform(\"positive_labels\", -1),\n", - " GroupTransform({\"feature_tensors\": [item_column]}),\n", - " ]\n", - "\n", - " val_transforms = [\n", - " RenameTransform({query_column: \"query_id\", f\"{item_column}_mask\": \"padding_mask\"}),\n", - " GroupTransform({\"feature_tensors\": [item_column]}),\n", - " ]\n", - " test_transforms = copy.deepcopy(val_transforms)\n", - "\n", - " predict_transforms = copy.deepcopy(val_transforms)\n", - "\n", - " transforms = {\n", - " \"train\": train_transforms,\n", - " \"validate\": val_transforms,\n", - " \"test\": test_transforms,\n", - " \"predict\": predict_transforms,\n", - " }\n", - "\n", - " return transforms\n", - "\n", - "transforms = make_sasrec_transforms(tensor_schema, query_column=\"user_id\")" - ] - 
}, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "MAX_SEQ_LEN = 50\n", - "\n", - "def create_meta(shape: int, gt_shape: Optional[int] = None):\n", - " meta = {\n", - " \"user_id\": {},\n", - " \"item_id\": {\"shape\": shape, \"padding\": tensor_schema[\"item_id\"].padding_value},\n", - " }\n", - " if gt_shape is not None:\n", - " meta.update({\"ground_truth\": {\"shape\": gt_shape, \"padding\": -1}})\n", - "\n", - " return meta\n", - "\n", - "train_metadata = {\n", - " \"train\": create_meta(shape=MAX_SEQ_LEN+1),\n", - " \"validate\": create_meta(shape=MAX_SEQ_LEN, gt_shape=1),\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.data.nn import ParquetModule\n", - "\n", - "BATCH_SIZE = 32\n", - "\n", - "parquet_module = ParquetModule(\n", - " train_path=TRAIN_PATH,\n", - " validate_path=VAL_PATH,\n", - " batch_size=BATCH_SIZE,\n", - " metadata=train_metadata,\n", - " transforms=transforms,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train model\n", - "### Create SasRec model instance and run the training stage using lightning\n", - "We may now train the model using the Lightning trainer class. \n", - "\n", - "RePlay's implementation of SasRec is designed in a modular, **block-based approach**. Instead of passing configuration parameters to the constructor, SasRec is now built by providing fully initialized components that makes the model more flexible and easier to extend.\n", - "\n", - "#### Default Configuration\n", - "\n", - "Default SasRec model may be created quickly via method **from_params**. Default model instance has CE loss, original SasRec transformer layes, and embeddings are aggregated via sum." 
- ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.nn.sequential import SasRec\n", - "from typing import Literal\n", - "def make_sasrec(\n", - " schema: TensorSchema,\n", - " embedding_dim: int = 192,\n", - " num_heads: int = 4,\n", - " num_blocks: int = 2,\n", - " max_sequence_length: int = 50,\n", - " dropout: float = 0.3,\n", - " excluded_features: Optional[list[str]] = None,\n", - " categorical_list_feature_aggregation_method: Literal[\"sum\", \"mean\", \"max\"] = \"sum\",\n", - ") -> SasRec:\n", - " from replay.nn.sequential.sasrec import SasRecBody, SasRecTransformerLayer\n", - " from replay.nn.agg import SumAggregator\n", - " from replay.nn.embedding import SequenceEmbedding\n", - " from replay.nn.loss import CE, CESampled\n", - " from replay.nn.mask import DefaultAttentionMask\n", - " from replay.nn.sequential.sasrec.agg import PositionAwareAggregator\n", - " from replay.nn.sequential.sasrec.transformer import SasRecTransformerLayer\n", - " excluded_features = [\n", - " schema.query_id_feature_name,\n", - " schema.timestamp_feature_name,\n", - " *(excluded_features or []),\n", - " ]\n", - " excluded_features = list(set(excluded_features))\n", - " body = SasRecBody(\n", - " embedder=SequenceEmbedding(\n", - " schema=schema,\n", - " categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,\n", - " excluded_features=excluded_features,\n", - " ),\n", - " embedding_aggregator=PositionAwareAggregator(\n", - " embedding_aggregator=SumAggregator(embedding_dim=embedding_dim),\n", - " max_sequence_length=max_sequence_length,\n", - " dropout=dropout,\n", - " ),\n", - " attn_mask_builder=DefaultAttentionMask(\n", - " reference_feature_name=schema.item_id_feature_name,\n", - " num_heads=num_heads,\n", - " ),\n", - " encoder=SasRecTransformerLayer(\n", - " embedding_dim=embedding_dim,\n", - " num_heads=num_heads,\n", - " num_blocks=num_blocks,\n", - " 
dropout=dropout,\n", - " activation=\"relu\",\n", - " ),\n", - " output_normalization=torch.nn.LayerNorm(embedding_dim),\n", - " )\n", - " padding_idx = schema.item_id_features.item().padding_value\n", - " return SasRec(\n", - " body=body,\n", - " loss=CESampled(padding_idx=padding_idx),\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "NUM_BLOCKS = 2\n", - "NUM_HEADS = 2\n", - "DROPOUT = 0.3\n", - "\n", - "sasrec = make_sasrec(\n", - " schema=tensor_schema,\n", - " embedding_dim=EMBEDDING_DIM,\n", - " max_sequence_length=MAX_SEQ_LEN,\n", - " num_heads=NUM_HEADS,\n", - " num_blocks=NUM_BLOCKS,\n", - " dropout=DROPOUT,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "A universal PyTorch Lightning module is provided. It can work with any NN model." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.nn.lightning.optimizer import OptimizerFactory\n", - "from replay.nn.lightning.scheduler import LRSchedulerFactory\n", - "from replay.nn.lightning import LightningModule\n", - "\n", - "model = LightningModule(\n", - " sasrec,\n", - " optimizer_factory=OptimizerFactory(),\n", - " lr_scheduler_factory=LRSchedulerFactory(),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To facilitate training, we add the following callbacks:\n", - "1) `ModelCheckpoint` - to save the best trained model based on its Recall metric. It's a default Lightning Callback.\n", - "1) `ComputeMetricsCallback` - to display a detailed validation metric matrix after each epoch. 
It's a custom RePlay callback for computing recsys metrics on validation and test stages.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", - "\n", - " | Name | Type | Params | Mode | FLOPs\n", - "-------------------------------------------------\n", - "0 | model | SasRec | 291 K | train | 0 \n", - "-------------------------------------------------\n", - "291 K Trainable params\n", - "0 Non-trainable params\n", - "291 K Total params\n", - "1.164 Total estimated model params size (MB)\n", - "39 Modules in train mode\n", - "0 Modules in eval mode\n", - "0 Total Flops\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "21ca5ea332d9418fb2481e0296b399a0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: | | 0/? [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "plt.figure(figsize = (5, 4), dpi = 120)\n", - "plt.hist(transforms[\"train\"][0].frequencies.cpu().numpy(), bins = 50)\n", - "plt.grid()\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we can get the best model path stored in the checkpoint callback." 
- ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=99-step=18900.ckpt'" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "best_model_path = checkpoint_callback.best_model_path\n", - "best_model_path" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Inference\n", - "\n", - "To obtain model scores, we will load the weights from the best checkpoint. To do this, we use the `LightningModule`, provide there the path to the checkpoint and the model instance." - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "import replay\n", - "torch.serialization.add_safe_globals([\n", - " replay.nn.lightning.optimizer.OptimizerFactory,\n", - " replay.nn.lightning.scheduler.LRSchedulerFactory\n", - "])" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "sasrec = make_sasrec(\n", - " schema=tensor_schema,\n", - " embedding_dim=EMBEDDING_DIM,\n", - " max_sequence_length=MAX_SEQ_LEN,\n", - " num_heads=NUM_HEADS,\n", - " num_blocks=NUM_BLOCKS,\n", - " dropout=DROPOUT,\n", - ")\n", - "\n", - "best_model = LightningModule.load_from_checkpoint(best_model_path, model=sasrec)\n", - "best_model.eval();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Configure `ParquetModule` for inference" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [], - "source": [ - "inference_metadata = {\"predict\": create_meta(shape=MAX_SEQ_LEN)}\n", - "\n", - "parquet_module = ParquetModule(\n", - " predict_path=PREDICT_PATH,\n", - " batch_size=BATCH_SIZE,\n", - " metadata=inference_metadata,\n", - " transforms=transforms,\n", - ")" - ] - }, - { - "cell_type": "markdown", - 
"metadata": {}, - "source": [ - "During inference, we can use `TopItemsCallback`. Such callback allows you to get scores for each user throughout the entire catalog and get recommendations in the form of ids of items with the highest score values.\n", - "\n", - "\n", - "Recommendations can be fetched in four formats: PySpark DataFrame, Pandas DataFrame, Polars DataFrame or raw PyTorch tensors. Each of the types corresponds a callback. In this example, we'll be using the `PandasTopItemsCallback`." - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "94a84cd4fca248e7b340909e2048450f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscore
004867.196251
001017.086551
002106.872698
0016136.834656
003576.555361
............
603760398411.528885
60376039270011.400557
60376039255711.273298
60376039263411.121119
60376039263311.119842
\n", - "

120760 rows × 3 columns

\n", - "" - ], - "text/plain": [ - " user_id item_id score\n", - "0 0 486 7.196251\n", - "0 0 101 7.086551\n", - "0 0 210 6.872698\n", - "0 0 1613 6.834656\n", - "0 0 357 6.555361\n", - "... ... ... ...\n", - "6037 6039 84 11.528885\n", - "6037 6039 2700 11.400557\n", - "6037 6039 2557 11.273298\n", - "6037 6039 2634 11.121119\n", - "6037 6039 2633 11.119842\n", - "\n", - "[120760 rows x 3 columns]" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pandas_res" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Calculating metrics\n", - "\n", - "*test_gt* is already encoded, so we can use it for computing metrics." - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "from replay.metrics import MAP, OfflineMetrics, Precision, Recall\n", - "from replay.metrics.torch_metrics_builder import metrics_to_df" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [], - "source": [ - "result_metrics = OfflineMetrics(\n", - " [Recall(TOPK), Precision(TOPK), MAP(TOPK)],\n", - " query_column=\"user_id\",\n", - " rating_column=\"score\",\n", - ")(pandas_res, test_gt.explode(\"item_id\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
k110205
MAP0.0160650.0566370.0640670.045948
Precision0.0160650.0186650.0147980.021133
Recall0.0160650.1866510.2959590.105664
\n", - "
" - ], - "text/plain": [ - "k 1 10 20 5\n", - "MAP 0.016065 0.056637 0.064067 0.045948\n", - "Precision 0.016065 0.018665 0.014798 0.021133\n", - "Recall 0.016065 0.186651 0.295959 0.105664" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metrics_to_df(result_metrics)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's call the `inverse_transform` encoder's function to get the final dataframe with recommendations" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscore
020125007.196251
020121047.086551
020122166.872698
0201217776.834656
020123676.555361
............
603757278611.528885
60375727290711.400557
60375727276211.273298
60375727284111.121119
60375727284011.119842
\n", - "

120760 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " user_id item_id score\n", - "0 2012 500 7.196251\n", - "0 2012 104 7.086551\n", - "0 2012 216 6.872698\n", - "0 2012 1777 6.834656\n", - "0 2012 367 6.555361\n", - "... ... ... ...\n", - "6037 5727 86 11.528885\n", - "6037 5727 2907 11.400557\n", - "6037 5727 2762 11.273298\n", - "6037 5727 2841 11.121119\n", - "6037 5727 2840 11.119842\n", - "\n", - "[120760 rows x 3 columns]" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "encoder.inverse_transform(pandas_res)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "new_venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 5e98b75dbe4d7bc1e3cb1f247ad14f2c0156d1ac Mon Sep 17 00:00:00 2001 From: Nikita Kulikov Date: Mon, 2 Feb 2026 07:24:43 +0300 Subject: [PATCH 3/5] Documentation updated --- replay/nn/transform/__init__.py | 5 +---- replay/nn/transform/negative_sampling.py | 24 ++++++++++++++---------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/replay/nn/transform/__init__.py b/replay/nn/transform/__init__.py index 40c2cdec1..1baa08b32 100644 --- a/replay/nn/transform/__init__.py +++ b/replay/nn/transform/__init__.py @@ -1,14 +1,11 @@ from .copy import CopyTransform from .grouping import GroupTransform -<<<<<<< HEAD -from .negative_sampling import MultiClassNegativeSamplingTransform, UniformNegativeSamplingTransform -======= from .negative_sampling import ( + MultiClassNegativeSamplingTransform, UniformNegativeSamplingTransform, 
    FrequencyNegativeSamplingTransform,
    ThresholdNegativeSamplingTransform,
)
->>>>>>> 36b25601 (Gitignore updated to recent examples)
 from .next_token import NextTokenTransform
 from .rename import RenameTransform
 from .reshape import UnsqueezeTransform
diff --git a/replay/nn/transform/negative_sampling.py b/replay/nn/transform/negative_sampling.py
index 34debab8e..76aa815f3 100644
--- a/replay/nn/transform/negative_sampling.py
+++ b/replay/nn/transform/negative_sampling.py
@@ -101,8 +101,10 @@ class FrequencyNegativeSamplingTransform(torch.nn.Module):
     Transform for global negative sampling.
 
     For every batch, transform generates a vector of size ``(num_negative_samples)``
-    consisting of random indices sampeled from a range of ``cardinality``. Unless a custom sample
-    distribution is provided, the indices are weighted equally.
+    consisting of random indices sampled from a range of ``cardinality``.
+
+    Index frequencies will be computed, and sampling will be done
+    according to the respective frequencies.
 
     Example:
 
@@ -110,7 +112,7 @@
     >>> _ = torch.manual_seed(0)
     >>> input_batch = {"item_id": torch.LongTensor([[1, 0, 4]])}
-    >>> transform = UniformNegativeSamplingTransform(cardinality=4, num_negative_samples=2)
+    >>> transform = FrequencyNegativeSamplingTransform(cardinality=4, num_negative_samples=2)
     >>> output_batch = transform(input_batch)
     >>> output_batch
     {'item_id': tensor([[1, 0, 4]]), 'negative_labels': tensor([2, 1])}
@@ -130,10 +132,10 @@ def __init__(
         :param cardinality: The size of sample vocabulary.
         :param num_negative_samples: The size of negatives vector to generate.
         :param out_feature_name: The name of result feature in batch.
-        :param sample_distribution: The weighs of indices in the vocabulary. If specified, must
-            match the ``cardinality``. Default: ``None``.
         :param generator: Random number generator to be used for sampling from the distribution.
            Default: ``None``.
+        :param mode: Mode of frequency-based sampling for undersampled items.
+            Default: ``softmax``.
         """
         assert num_negative_samples < cardinality
@@ -187,8 +189,10 @@ class ThresholdNegativeSamplingTransform(torch.nn.Module):
     Transform for global negative sampling.
 
     For every batch, transform generates a vector of size ``(num_negative_samples)``
-    consisting of random indices sampeled from a range of ``cardinality``. Unless a custom sample
-    distribution is provided, the indices are weighted equally.
+    consisting of random indices sampled from a range of ``cardinality``.
+
+    Indices that are oversampled at this point will be ignored, while
+    other samples will be chosen according to their respective frequency.
 
     Example:
 
@@ -196,7 +200,7 @@
     >>> _ = torch.manual_seed(0)
     >>> input_batch = {"item_id": torch.LongTensor([[1, 0, 4]])}
-    >>> transform = UniformNegativeSamplingTransform(cardinality=4, num_negative_samples=2)
+    >>> transform = ThresholdNegativeSamplingTransform(cardinality=4, num_negative_samples=2)
     >>> output_batch = transform(input_batch)
     >>> output_batch
     {'item_id': tensor([[1, 0, 4]]), 'negative_labels': tensor([2, 1])}
@@ -216,10 +220,10 @@ def __init__(
         :param cardinality: The size of sample vocabulary.
         :param num_negative_samples: The size of negatives vector to generate.
         :param out_feature_name: The name of result feature in batch.
-        :param sample_distribution: The weighs of indices in the vocabulary. If specified, must
-            match the ``cardinality``. Default: ``None``.
         :param generator: Random number generator to be used for sampling from the distribution.
            Default: ``None``.
+        :param mode: Mode of frequency-based sampling for undersampled items.
+            Default: ``softmax``.
""" assert num_negative_samples < cardinality From bd614aaa61cbba62606b7dc54bee9a3f0222cd66 Mon Sep 17 00:00:00 2001 From: Nikita Kulikov Date: Mon, 2 Feb 2026 07:30:20 +0300 Subject: [PATCH 4/5] Formatted and checked --- CONTRIBUTING.md | 3 +- replay/nn/transform/__init__.py | 6 +-- replay/nn/transform/negative_sampling.py | 51 ++++++++++-------------- 3 files changed, 26 insertions(+), 34 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 638adc03f..236d90953 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -136,8 +136,7 @@ You can just get pyproject.toml file from step 6, to start using linters and for In order to automate checking of the code quality, please run: ```bash -poetry run ruff check . -poetry run black --check --diff -- . +./poetry_wrapper.sh run ruff check . ./poetry_wrapper.sh check ./poetry_wrapper.sh --experimental check ``` diff --git a/replay/nn/transform/__init__.py b/replay/nn/transform/__init__.py index 1baa08b32..4f47dba4b 100644 --- a/replay/nn/transform/__init__.py +++ b/replay/nn/transform/__init__.py @@ -2,9 +2,9 @@ from .grouping import GroupTransform from .negative_sampling import ( MultiClassNegativeSamplingTransform, - UniformNegativeSamplingTransform, FrequencyNegativeSamplingTransform, ThresholdNegativeSamplingTransform, + UniformNegativeSamplingTransform, ) from .next_token import NextTokenTransform from .rename import RenameTransform @@ -15,15 +15,15 @@ __all__ = [ "CopyTransform", + "FrequencyNegativeSamplingTransform", "GroupTransform", "MultiClassNegativeSamplingTransform", "NextTokenTransform", "RenameTransform", "SequenceRollTransform", + "ThresholdNegativeSamplingTransform", "TokenMaskTransform", "TrimTransform", "UniformNegativeSamplingTransform", - "FrequencyNegativeSamplingTransform", - "ThresholdNegativeSamplingTransform", "UnsqueezeTransform", ] diff --git a/replay/nn/transform/negative_sampling.py b/replay/nn/transform/negative_sampling.py index 76aa815f3..c1e9bd5b2 100644 --- 
a/replay/nn/transform/negative_sampling.py +++ b/replay/nn/transform/negative_sampling.py @@ -1,10 +1,9 @@ -from typing import Optional, Literal, cast +import warnings +from typing import Literal, Optional, cast import torch import torch.nn.functional as func -import warnings - class UniformNegativeSamplingTransform(torch.nn.Module): """ @@ -73,11 +72,7 @@ def __init__( self.out_feature_name = out_feature_name self.num_negative_samples = num_negative_samples self.generator = generator - if sample_distribution is not None: - sample_distribution = sample_distribution - else: - sample_distribution = torch.ones(cardinality) - + sample_distribution = sample_distribution if sample_distribution is not None else torch.ones(cardinality) self.sample_distribution = torch.nn.Buffer( cast(torch.Tensor, sample_distribution) ) @@ -153,16 +148,15 @@ def __init__( def get_probas(self) -> torch.Tensor: raw: torch.Tensor = 1.0 / (1.0 + self.frequencies) - match self.mode: - case "softsum": - result: torch.Tensor = raw / torch.sum(raw) - case "softmax": + if self.mode == "softsum": + result: torch.Tensor = raw / torch.sum(raw) + elif self.mode == "softmax": result: torch.Tensor = func.softmax(raw, dim = -1) - case _: - msg: str = f"Unsupported mode: {self.mode}." - raise TypeError(msg) + else: + msg: str = f"Unsupported mode: {self.mode}." 
+ raise TypeError(msg) return result - + def update_probas(self, selected: torch.Tensor) -> None: device = self.frequencies.device one = torch.ones(1, dtype = torch.int64, device = device) @@ -243,20 +237,19 @@ def get_probas(self) -> torch.Tensor: raw: torch.Tensor = 1.0 / (1.0 + self.frequencies) thr: torch.Tensor = torch.max(self.frequencies) mask: torch.Tensor = thr != self.frequencies - match self.mode: - case "softsum": - eps = torch.finfo(raw.dtype).eps - raw = torch.where(mask, raw, eps) - result: torch.Tensor = raw / torch.sum(raw) - case "softmax": - inf = torch.finfo(raw.dtype).min - raw = torch.where(mask, raw, inf) - result: torch.Tensor = func.softmax(raw, dim = -1) - case _: - msg: str = f"Unsupported mode: {self.mode}." - raise TypeError(msg) + if self.mode == "softsum": + eps = torch.finfo(raw.dtype).eps + raw = torch.where(mask, raw, eps) + result: torch.Tensor = raw / torch.sum(raw) + elif self.mode == "softmax": + inf = torch.finfo(raw.dtype).min + raw = torch.where(mask, raw, inf) + result: torch.Tensor = func.softmax(raw, dim = -1) + else: + msg: str = f"Unsupported mode: {self.mode}." 
+ raise TypeError(msg) return result - + def update_probas(self, selected: torch.Tensor) -> None: device = self.frequencies.device one = torch.ones(1, dtype = torch.int64, device = device) From 3e03d8e206364b3d7e619ecbf29bd8e63a858222 Mon Sep 17 00:00:00 2001 From: Nikita Kulikov Date: Mon, 2 Feb 2026 08:13:09 +0300 Subject: [PATCH 5/5] Ruff format applied --- replay/nn/transform/__init__.py | 2 +- replay/nn/transform/negative_sampling.py | 33 ++++++++++-------------- 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/replay/nn/transform/__init__.py b/replay/nn/transform/__init__.py index 4f47dba4b..0ae787bf3 100644 --- a/replay/nn/transform/__init__.py +++ b/replay/nn/transform/__init__.py @@ -1,8 +1,8 @@ from .copy import CopyTransform from .grouping import GroupTransform from .negative_sampling import ( - MultiClassNegativeSamplingTransform, FrequencyNegativeSamplingTransform, + MultiClassNegativeSamplingTransform, ThresholdNegativeSamplingTransform, UniformNegativeSamplingTransform, ) diff --git a/replay/nn/transform/negative_sampling.py b/replay/nn/transform/negative_sampling.py index c1e9bd5b2..ef68936a3 100644 --- a/replay/nn/transform/negative_sampling.py +++ b/replay/nn/transform/negative_sampling.py @@ -48,8 +48,7 @@ def __init__( if sample_distribution is not None: if sample_distribution.ndim != 1: msg: str = ( - "The `sample_distribution` parameter must be 1D." - f"Got {sample_distribution.ndim}, will be flattened." + f"The `sample_distribution` parameter must be 1D.Got {sample_distribution.ndim}, will be flattened." 
) warnings.warn(msg) sample_distribution = sample_distribution.flatten() @@ -73,9 +72,7 @@ def __init__( self.num_negative_samples = num_negative_samples self.generator = generator sample_distribution = sample_distribution if sample_distribution is not None else torch.ones(cardinality) - self.sample_distribution = torch.nn.Buffer( - cast(torch.Tensor, sample_distribution) - ) + self.sample_distribution = torch.nn.Buffer(cast(torch.Tensor, sample_distribution)) def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: output_batch = dict(batch.items()) @@ -91,6 +88,7 @@ def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: output_batch[self.out_feature_name] = negatives.to(device) return output_batch + class FrequencyNegativeSamplingTransform(torch.nn.Module): """ Transform for global negative sampling. @@ -142,16 +140,14 @@ def __init__( self.generator = generator self.mode = mode - self.frequencies = torch.nn.Buffer( - torch.zeros(cardinality, dtype = torch.int64) - ) + self.frequencies = torch.nn.Buffer(torch.zeros(cardinality, dtype=torch.int64)) def get_probas(self) -> torch.Tensor: raw: torch.Tensor = 1.0 / (1.0 + self.frequencies) if self.mode == "softsum": result: torch.Tensor = raw / torch.sum(raw) elif self.mode == "softmax": - result: torch.Tensor = func.softmax(raw, dim = -1) + result: torch.Tensor = func.softmax(raw, dim=-1) else: msg: str = f"Unsupported mode: {self.mode}." 
raise TypeError(msg) @@ -159,14 +155,14 @@ def get_probas(self) -> torch.Tensor: def update_probas(self, selected: torch.Tensor) -> None: device = self.frequencies.device - one = torch.ones(1, dtype = torch.int64, device = device) + one = torch.ones(1, dtype=torch.int64, device=device) self.frequencies.index_add_(-1, selected, one.expand(selected.numel())) def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: output_batch = dict(batch.items()) negatives = torch.multinomial( - input = self.get_probas(), + input=self.get_probas(), num_samples=self.num_negative_samples, replacement=False, generator=self.generator, @@ -178,6 +174,7 @@ def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: output_batch[self.out_feature_name] = negatives.to(device) return output_batch + class ThresholdNegativeSamplingTransform(torch.nn.Module): """ Transform for global negative sampling. @@ -229,9 +226,7 @@ def __init__( self.generator = generator self.mode = mode - self.frequencies = torch.nn.Buffer( - torch.zeros(cardinality, dtype = torch.int64) - ) + self.frequencies = torch.nn.Buffer(torch.zeros(cardinality, dtype=torch.int64)) def get_probas(self) -> torch.Tensor: raw: torch.Tensor = 1.0 / (1.0 + self.frequencies) @@ -244,7 +239,7 @@ def get_probas(self) -> torch.Tensor: elif self.mode == "softmax": inf = torch.finfo(raw.dtype).min raw = torch.where(mask, raw, inf) - result: torch.Tensor = func.softmax(raw, dim = -1) + result: torch.Tensor = func.softmax(raw, dim=-1) else: msg: str = f"Unsupported mode: {self.mode}." 
raise TypeError(msg) @@ -252,14 +247,14 @@ def get_probas(self) -> torch.Tensor: def update_probas(self, selected: torch.Tensor) -> None: device = self.frequencies.device - one = torch.ones(1, dtype = torch.int64, device = device) + one = torch.ones(1, dtype=torch.int64, device=device) self.frequencies.index_add_(-1, selected, one.expand(selected.numel())) def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: output_batch = dict(batch.items()) negatives = torch.multinomial( - input = self.get_probas(), + input=self.get_probas(), num_samples=self.num_negative_samples, replacement=False, generator=self.generator, @@ -344,9 +339,7 @@ def __init__( super().__init__() - self.sample_mask = torch.nn.Buffer( - sample_mask.float() - ) + self.sample_mask = torch.nn.Buffer(sample_mask.float()) self.num_negative_samples = num_negative_samples self.negative_selector_name = negative_selector_name