Adding UniSRec model implemented on a lightweight class hierarchy with PyTorch preprocessing #306
base: main
Changes from all commits
d4c9df0
6c875b3
3cec1e0
aa015f8
e24fec3
7d3850b
d68834f
```
@@ -95,4 +95,9 @@ benchmark_results/
*.dat

# CatBoost
catboost_info/

# Dev artifacts
training_folder/
*.pt
data/*
```
```markdown
@@ -6,6 +6,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added
- `rectools.fast_transformers` module — standalone transformer-based sequential recommenders that work directly with torch tensors, bypassing the `Dataset`/pandas pipeline. GPU-native sequence building via `build_sequences()` gives ~30x preprocessing speedup over `SASRecDataPreparator` on ML-20M
- `FlatSASRec` network and `FlatSASRecModel` — flat SASRec implementation without the ItemNet hierarchy. Pre-norm transformer encoder with id-embeddings, causal masking, softmax and BCE losses. Integrates with RecTools `ModelBase` for compatibility with the standard `fit`/`recommend` API
- `UniSRec` network and `UniSRecModel` — sequential recommender with pretrained text embeddings (e.g. Qwen) and a learnable PCA/BN adaptor. Three-phase training: (1) SASRec warm-up on ID embeddings, (2) adaptor-only with frozen transformer, (3) full fine-tune on pretrained embeddings. Configurable losses (softmax, BCE, gBCE, sampled_softmax), optimizers (Adam, AdamW), cosine warmup scheduler, early stopping, checkpoint save/load. `UniSRecModel.fit()` accepts raw `(user_ids, item_ids, timestamps)` tensors
- `rank_topk()` utility for batched top-k scoring with CSR-based viewed-item filtering and item whitelist support
- `align_embeddings()` for mapping pretrained embedding matrices to internal item ID order
- `GPUBatchDataset` and `make_dataloader()` — lightweight torch Dataset/DataLoader wrappers for sequence training data
- Configurable FFN blocks in `UniSRec`: `conv1d` (original paper), `linear_gelu`, `linear_relu` with adjustable expansion factor
- Tests for all `fast_transformers` submodules (143 tests)
```
Collaborator (on the tests entry): We normally don't add anything to the changelog that doesn't affect the user directly, so there's not much sense in writing about the tests.
```markdown
## [0.18.0] - 21.02.2026

### Added
```
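To make the first entry concrete, here is a minimal sketch of the tensor-native path, using only names and signatures that appear later in this diff (toy data, CPU device, illustrative hyperparameters; `UniSRecModel` and `rank_topk` are not part of the files included below, so they are left out):

```python
import torch
from rectools.fast_transformers import FlatSASRec, build_sequences, make_dataloader

# Toy interaction log: two users, arbitrary integer item ids, integer timestamps.
user_ids = torch.tensor([7, 7, 7, 9, 9, 9])
item_ids = torch.tensor([101, 102, 103, 101, 104, 102])
timestamps = torch.tensor([1, 2, 3, 1, 2, 3])

# GPU-native preprocessing: inputs, targets, item id mapping, kept users.
x, y, unique_items, users = build_sequences(
    user_ids, item_ids, timestamps, max_len=2, device="cpu"
)
loader = make_dataloader(x, y, batch_size=2, shuffle=True)

net = FlatSASRec(n_items=len(unique_items), n_factors=32, n_blocks=1, n_heads=2, session_max_len=2)
for batch in loader:
    logits = net(batch)  # (B, L, n_items) full-catalog logits when no negatives are given
```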
@@ -0,0 +1,23 @@
```python
"""Fast Transformers: flat sequential recommenders without ItemNet hierarchy."""

from .gpu_data import GPUBatchDataset, align_embeddings, build_sequences, hash_item_ids, make_dataloader
from .net import FlatSASRec, SASRecBlock
from .ranking import rank_topk
from .unisrec_lightning import UniSRecLightning
from .unisrec_model import UniSRecModel
from .unisrec_net import FeedForward, UniSRec

__all__ = [
    "build_sequences",
    "align_embeddings",
    "hash_item_ids",
    "GPUBatchDataset",
    "make_dataloader",
    "FlatSASRec",
    "SASRecBlock",
    "rank_topk",
    "UniSRec",
    "FeedForward",
    "UniSRecLightning",
    "UniSRecModel",
]
```
@@ -0,0 +1,151 @@
```python
"""GPU-native sequence building for transformer training. Pure torch, no pandas/numpy."""

import typing as tp

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset


def _splitmix64(x: torch.Tensor) -> torch.Tensor:
    """Vectorized splitmix64 bit-mixer: element-wise int64 hash over a torch tensor.

    Standard library hashes (``hash()``, ``hashlib``) operate on scalar Python objects
    and cannot be vectorized across GPU tensors. Splitmix64 is pure int64 arithmetic,
    so it maps naturally to ``torch.Tensor`` ops and runs on any device.

    Reference: https://xorshift.di.unimi.it/splitmix64.c (Vigna, 2015).
    """
    x = x.long()
    x = (x ^ (x >> 30)) * (-4658895280553007687)  # 0xbf58476d1ce4e5b9 as signed int64
    x = (x ^ (x >> 27)) * (-7723592293110705685)  # 0x94d049bb133111eb as signed int64
    return x ^ (x >> 31)


def hash_item_ids(item_ids: torch.Tensor, dict_size: int) -> torch.Tensor:
    """Map arbitrary integer item IDs to [1, dict_size] via splitmix64 hash."""
    return _splitmix64(item_ids) % dict_size + 1
```
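For illustration, a quick sanity check of the hash helper above (a minimal sketch; the exact hashed values depend only on the splitmix64 constants):

```python
import torch
from rectools.fast_transformers import hash_item_ids

raw_ids = torch.tensor([3, 17, 123_456_789])
hashed = hash_item_ids(raw_ids, dict_size=1000)

# torch's integer modulo follows the sign of the divisor, so results land in [1, dict_size].
assert bool(((hashed >= 1) & (hashed <= 1000)).all())
```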
```python
def build_sequences(
    user_ids: torch.Tensor,
    item_ids: torch.Tensor,
    timestamps: torch.Tensor,
    max_len: int,
    min_interactions: int = 2,
    device: str = "cuda",
    id_mapping: str = "dense",
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
```

Collaborator (on `id_mapping`): Better to use `Literal` for such things.

Collaborator (on the function signature): Please add extensive docstrings for all the public methods, especially those meant to be used stand-alone. It's particularly important here, since you're returning 4 tensors and the user doesn't understand their meaning. It would also be good to add examples.
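A minimal sketch of the first reviewer suggestion above, using `typing.Literal` so type checkers can catch misspelled modes (the alias and function names are illustrative, not part of the PR):

```python
import typing as tp

IdMapping = tp.Literal["dense", "hash"]  # invalid values surface at type-check time

def example(id_mapping: IdMapping = "dense") -> str:
    # Illustrative stand-in for the real signature.
    return id_mapping
```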
```python
    user_ids = user_ids.to(device)
    item_ids = item_ids.to(device)
    timestamps = timestamps.to(device)

    unique_items = torch.unique(item_ids)
    n_unique = len(unique_items)

    if id_mapping == "dense":
        _, item_inv = torch.unique(item_ids, return_inverse=True)
        internal_items = item_inv + 1
    elif id_mapping == "hash":
        internal_items = hash_item_ids(item_ids, n_unique)
    else:
        raise ValueError(f"Unknown id_mapping: {id_mapping}. Use 'dense' or 'hash'")

    unique_users, user_inv = torch.unique(user_ids, return_inverse=True)

    order1 = torch.argsort(timestamps, stable=True)
    order2 = torch.argsort(user_inv[order1], stable=True)
    order = order1[order2]

    sorted_user_inv = user_inv[order]
    sorted_items = internal_items[order]

    changes = torch.where(sorted_user_inv[1:] != sorted_user_inv[:-1])[0] + 1
    starts = torch.cat([torch.tensor([0], device=device), changes])
    ends = torch.cat([changes, torch.tensor([len(sorted_user_inv)], device=device)])
    lengths = ends - starts

    mask = lengths >= min_interactions
    starts = starts[mask]
    ends = ends[mask]
    lengths = lengths[mask]
    n_users = len(starts)

    capped_lens = torch.clamp(lengths, max=max_len + 1)

    effective_lens = torch.clamp(capped_lens - 1, min=0)
    total_elements = effective_lens.sum().item()

    x = torch.zeros(n_users, max_len, dtype=torch.long, device=device)
    y = torch.zeros(n_users, max_len, dtype=torch.long, device=device)

    if total_elements > 0:
        user_indices = torch.repeat_interleave(torch.arange(n_users, device=device), effective_lens)
        cumsum = effective_lens.cumsum(0)
        offsets = torch.arange(total_elements, device=device) - torch.repeat_interleave(
            cumsum - effective_lens, effective_lens
        )

        x_src = torch.repeat_interleave(ends - capped_lens, effective_lens) + offsets
        y_src = x_src + 1
        col_indices = max_len - torch.repeat_interleave(effective_lens, effective_lens) + offsets

        x[user_indices, col_indices] = sorted_items[x_src]
        y[user_indices, col_indices] = sorted_items[y_src]

    valid_user_indices = torch.where(mask)[0]
    result_users = unique_users[valid_user_indices] if len(valid_user_indices) < len(unique_users) else unique_users

    return x, y, unique_items, result_users
```
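A usage sketch along the lines the reviewer asks for, spelling out what the four returned tensors mean (the expected values are traced by hand from the code above; `device="cpu"` only for portability):

```python
import torch
from rectools.fast_transformers import build_sequences

user_ids = torch.tensor([1, 1, 1, 2, 2])
item_ids = torch.tensor([10, 20, 30, 10, 40])
timestamps = torch.tensor([1, 2, 3, 1, 2])

x, y, unique_items, users = build_sequences(
    user_ids, item_ids, timestamps, max_len=3, device="cpu"
)
# x: left-padded input sequences in internal ids (0 = padding), one row per kept user
#    -> tensor([[0, 1, 2], [0, 0, 1]])
# y: next-item targets aligned with x
#    -> tensor([[0, 2, 3], [0, 0, 4]])
# unique_items: original item ids in internal-id order (internal id k = unique_items[k - 1])
#    -> tensor([10, 20, 30, 40])
# users: original ids of the users that passed the min_interactions filter
#    -> tensor([1, 2])
```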
```python
def align_embeddings(
    pretrained: torch.Tensor,
    unique_items: torch.Tensor,
    n_items: int,
    id_mapping: str = "dense",
) -> torch.Tensor:
    idx = unique_items.long().cpu()
    valid = (idx >= 0) & (idx < pretrained.shape[0])

    if pretrained.ndim == 2:
        aligned = torch.zeros(n_items + 1, pretrained.shape[1])
    else:
        aligned = torch.zeros(n_items + 1, pretrained.shape[1], pretrained.shape[2])

    if id_mapping == "dense":
        aligned[1:][valid] = pretrained[idx[valid]]
    elif id_mapping == "hash":
        positions = hash_item_ids(idx, n_items)
        aligned[positions[valid]] = pretrained[idx[valid]]
    else:
        raise ValueError(f"Unknown id_mapping: {id_mapping}. Use 'dense' or 'hash'")

    return aligned
```
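A minimal sketch of how `align_embeddings` pairs with the `unique_items` tensor returned by `build_sequences` (illustrative shapes; row 0 of the result is the padding row):

```python
import torch
from rectools.fast_transformers import align_embeddings

# Pretrained text embeddings indexed by the *original* item id (original ids < 50 here).
pretrained = torch.randn(50, 16)

# unique_items as returned by build_sequences: original ids in internal-id order.
unique_items = torch.tensor([10, 20, 30, 40])

aligned = align_embeddings(pretrained, unique_items, n_items=4)
assert aligned.shape == (5, 16)                 # n_items + 1 rows, row 0 = padding
assert torch.equal(aligned[1], pretrained[10])  # internal id 1 -> original id 10
```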
Collaborator (on the `GPUBatchDataset` class below): I'm not sure the name reflects the purpose. It also sounds quite "universal", even though I'd say it's more task-specific.
```python
class GPUBatchDataset(TorchDataset):
    def __init__(self, x: torch.Tensor, y: torch.Tensor, transform: tp.Optional[tp.Callable] = None):
        self.x = x
        self.y = y
        self.transform = transform

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx: int) -> tp.Dict[str, torch.Tensor]:
        batch = {"x": self.x[idx], "y": self.y[idx]}
        if self.transform:
            batch = self.transform(batch)
        return batch
```
```python
def make_dataloader(
    x: torch.Tensor,
    y: torch.Tensor,
    batch_size: int,
    shuffle: bool = True,
    transform: tp.Optional[tp.Callable] = None,
) -> DataLoader:
    ds = GPUBatchDataset(x, y, transform=transform)
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=0)
```

Collaborator (on the `make_dataloader` signature): I'd recommend adding `**kwargs` here to cover the other `DataLoader` parameters. On the other hand, I'm not sure it makes much sense to wrap two function calls in a separate function.
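A sketch of the `transform` hook in action, assuming a hypothetical `add_random_negatives` helper (not part of this PR) that attaches the `negatives` key consumed by `FlatSASRec.forward` in the next file:

```python
import torch
from rectools.fast_transformers import make_dataloader

x = torch.tensor([[0, 1, 2], [0, 0, 1]])  # left-padded input sequences
y = torch.tensor([[0, 2, 3], [0, 0, 4]])  # next-item targets


def add_random_negatives(batch, n_items=4, n_negatives=2):
    # Hypothetical per-sample transform: uniform random negatives of shape (L, n_negatives);
    # the default collate stacks them into (B, L, n_negatives), as forward() expects.
    seq = batch["x"]
    batch["negatives"] = torch.randint(1, n_items + 1, (*seq.shape, n_negatives))
    return batch


loader = make_dataloader(x, y, batch_size=2, shuffle=False, transform=add_random_negatives)
batch = next(iter(loader))  # keys: "x", "y", "negatives"
```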
@@ -0,0 +1,175 @@
```python
"""Flat SASRec network: pre-norm transformer encoder with plain id embeddings."""

import typing as tp

import torch
from torch import nn


class SASRecBlock(nn.Module):
    """Pre-norm transformer block: LayerNorm -> MHA -> residual -> LayerNorm -> FFN -> residual."""

    def __init__(self, n_factors: int, n_heads: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(n_factors)
        self.mha = nn.MultiheadAttention(n_factors, n_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(n_factors)
        self.ffn = nn.Sequential(
            nn.Linear(n_factors, n_factors * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(n_factors * 4, n_factors),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: tp.Optional[torch.Tensor] = None,
        key_padding_mask: tp.Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        h = self.ln1(x)
        h, _ = self.mha(h, h, h, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
        x = x + h
        h = self.ln2(x)
        x = x + self.ffn(h)
        return x


class FlatSASRec(nn.Module):
    """
    Flat SASRec: sequential recommender with plain id-embedding table
    (no ItemNet hierarchy).

    Parameters
    ----------
    n_items : int
        Total number of items (excluding padding token 0).
    n_factors : int
        Embedding / hidden dimension.
    n_blocks : int
        Number of transformer blocks.
    n_heads : int
        Number of attention heads.
    session_max_len : int
        Maximum sequence length.
    dropout : float
        Dropout rate.
    """

    PADDING_IDX = 0

    def __init__(
        self,
        n_items: int,
        n_factors: int,
        n_blocks: int,
        n_heads: int,
        session_max_len: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.n_items = n_items
        self.n_factors = n_factors
        self.session_max_len = session_max_len

        # +1 for padding at index 0
        self.item_emb = nn.Embedding(n_items + 1, n_factors, padding_idx=self.PADDING_IDX)
        self.pos_emb = nn.Embedding(session_max_len, n_factors)
        self.emb_dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([SASRecBlock(n_factors, n_heads, dropout) for _ in range(n_blocks)])
        self.final_ln = nn.LayerNorm(n_factors)

    def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        return torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode full sequence.

        Parameters
        ----------
        x : LongTensor (B, L)
            Item id sequences (0 = padding).

        Returns
        -------
        Tensor (B, L, D)
        """
        B, L = x.shape
        positions = torch.arange(L, device=x.device).unsqueeze(0)
        h = self.item_emb(x) + self.pos_emb(positions)
        h = self.emb_dropout(h)

        # timeline_mask: zero out padding positions to prevent NaN from attention
        timeline_mask = (x != self.PADDING_IDX).unsqueeze(-1).float()  # (B, L, 1)
        attn_mask = self._causal_mask(L, x.device)
        key_padding_mask = x == self.PADDING_IDX

        for block in self.blocks:
            h = h * timeline_mask
            h = block(h, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
        h = h * timeline_mask
        h = self.final_ln(h)
        return h

    def encode_last(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode and return only the last non-padding position representation.

        Parameters
        ----------
        x : LongTensor (B, L)

        Returns
        -------
        Tensor (B, D)
        """
        h = self.encode(x)  # (B, L, D)
        return h[:, -1, :]  # left-padded: last position is always rightmost

    def all_item_embeddings(self) -> torch.Tensor:
        """
        Return embeddings for all items (1..n_items), excluding padding.

        Returns
        -------
        Tensor (n_items, D)
        """
        ids = torch.arange(1, self.n_items + 1, device=self.item_emb.weight.device)
        return self.item_emb(ids)

    def forward(self, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Training forward pass.

        Parameters
        ----------
        batch : dict
            Must contain 'x' (B, L) and 'y' (B, L).
            Optionally 'negatives' (B, L, N) for candidate-logits branch.

        Returns
        -------
        logits : Tensor
            If negatives present: (B, L, 1 + N) — positive + negative logits.
            Otherwise: (B, L, n_items) — full catalog logits.
        """
        x = batch["x"]  # (B, L)
        y = batch["y"]  # (B, L)

        h = self.encode(x)  # (B, L, D)

        if "negatives" in batch:
            negatives = batch["negatives"]  # (B, L, N)
            pos_emb = self.item_emb(y).unsqueeze(3)  # (B, L, D, 1)
            neg_emb = self.item_emb(negatives)  # (B, L, N, D)
            neg_emb = neg_emb.transpose(2, 3)  # (B, L, D, N)
            all_emb = torch.cat([pos_emb, neg_emb], dim=3)  # (B, L, D, 1+N)
            logits = (h.unsqueeze(2) @ all_emb).squeeze(2)  # (B, L, 1+N)
            # -> shape is (B, L, 1+N) where first column is positive logit
        else:
            item_embs = self.all_item_embeddings()  # (n_items, D)
            logits = h @ item_embs.T  # (B, L, n_items)
        return logits
```
Collaborator: A bit weird name, can we remove it?
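For reference, a minimal inference-style sketch with the `FlatSASRec` network defined above (toy hyperparameters and untrained weights, so the scores are meaningless; it only illustrates the shapes):

```python
import torch
from rectools.fast_transformers import FlatSASRec

net = FlatSASRec(n_items=4, n_factors=32, n_blocks=1, n_heads=2, session_max_len=3)
net.eval()

x = torch.tensor([[0, 1, 2], [0, 0, 1]])  # left-padded sessions in internal ids, 0 = padding

with torch.no_grad():
    session_repr = net.encode_last(x)                     # (2, 32) last-position states
    scores = session_repr @ net.all_item_embeddings().T   # (2, 4) score per catalog item
```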