Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ You can just get pyproject.toml file from step 6, to start using linters and for

In order to automate checking of the code quality, please run:
```bash
poetry run ruff check .
poetry run black --check --diff -- .
./poetry_wrapper.sh run ruff check .
./poetry_wrapper.sh check
./poetry_wrapper.sh --experimental check
```
Expand Down
9 changes: 8 additions & 1 deletion replay/nn/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from .copy import CopyTransform
from .grouping import GroupTransform
from .negative_sampling import MultiClassNegativeSamplingTransform, UniformNegativeSamplingTransform
from .negative_sampling import (
FrequencyNegativeSamplingTransform,
MultiClassNegativeSamplingTransform,
ThresholdNegativeSamplingTransform,
UniformNegativeSamplingTransform,
)
from .next_token import NextTokenTransform
from .rename import RenameTransform
from .reshape import UnsqueezeTransform
Expand All @@ -10,11 +15,13 @@

__all__ = [
"CopyTransform",
"FrequencyNegativeSamplingTransform",
"GroupTransform",
"MultiClassNegativeSamplingTransform",
"NextTokenTransform",
"RenameTransform",
"SequenceRollTransform",
"ThresholdNegativeSamplingTransform",
"TokenMaskTransform",
"TrimTransform",
"UniformNegativeSamplingTransform",
Expand Down
218 changes: 202 additions & 16 deletions replay/nn/transform/negative_sampling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional
import warnings
from typing import Literal, Optional, cast

import torch
import torch.nn.functional as func


class UniformNegativeSamplingTransform(torch.nn.Module):
Expand Down Expand Up @@ -29,7 +31,7 @@ def __init__(
cardinality: int,
num_negative_samples: int,
*,
out_feature_name: Optional[str] = "negative_labels",
out_feature_name: str = "negative_labels",
sample_distribution: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> None:
Expand All @@ -43,12 +45,19 @@ def __init__(
:param generator: Random number generator to be used for sampling
from the distribution. Default: ``None``.
"""
if sample_distribution is not None and sample_distribution.size(-1) != cardinality:
msg = (
"The sample_distribution parameter has an incorrect size. "
f"Got {sample_distribution.size(-1)}, expected {cardinality}."
)
raise ValueError(msg)
if sample_distribution is not None:
if sample_distribution.ndim != 1:
msg: str = (
f"The `sample_distribution` parameter must be 1D.Got {sample_distribution.ndim}, will be flattened."
)
warnings.warn(msg)
sample_distribution = sample_distribution.flatten()
if sample_distribution.size(-1) != cardinality:
msg: str = (
"The sample_distribution parameter has an incorrect size. "
f"Got {sample_distribution.size(-1)}, expected {cardinality}."
)
raise ValueError(msg)

if num_negative_samples >= cardinality:
msg = (
Expand All @@ -62,10 +71,8 @@ def __init__(
self.out_feature_name = out_feature_name
self.num_negative_samples = num_negative_samples
self.generator = generator
if sample_distribution is not None:
self.sample_distribution = sample_distribution
else:
self.sample_distribution = torch.ones(cardinality)
sample_distribution = sample_distribution if sample_distribution is not None else torch.ones(cardinality)
self.sample_distribution = torch.nn.Buffer(cast(torch.Tensor, sample_distribution))

def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
output_batch = dict(batch.items())
Expand All @@ -77,7 +84,186 @@ def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
generator=self.generator,
)

output_batch[self.out_feature_name] = negatives.to(device=next(iter(output_batch.values())).device)
device = next(iter(output_batch.values())).device
output_batch[self.out_feature_name] = negatives.to(device)
return output_batch


class FrequencyNegativeSamplingTransform(torch.nn.Module):
"""
Transform for global negative sampling.

For every batch, transform generates a vector of size ``(num_negative_samples)``
consisting of random indices sampeled from a range of ``cardinality``.

Indices frequency will be computed and their sampling will be done
according to their respective frequencies.

Example:

.. code-block:: python

>>> _ = torch.manual_seed(0)
>>> input_batch = {"item_id": torch.LongTensor([[1, 0, 4]])}
>>> transform = FrequencyNegativeSamplingTransform(cardinality=4, num_negative_samples=2)
>>> output_batch = transform(input_batch)
>>> output_batch
{'item_id': tensor([[1, 0, 4]]), 'negative_labels': tensor([2, 1])}

"""

def __init__(
self,
cardinality: int,
num_negative_samples: int,
*,
out_feature_name: str = "negative_labels",
generator: Optional[torch.Generator] = None,
mode: Literal["softmax", "softsum"] = "softmax",
) -> None:
"""
:param cardinality: The size of sample vocabulary.
:param num_negative_samples: The size of negatives vector to generate.
:param out_feature_name: The name of result feature in batch.
:param generator: Random number generator to be used for sampling
from the distribution. Default: ``None``.
:param mode: Mode of frequency-based samping for undersampled items.
Default: ``softmax``.
"""
assert num_negative_samples < cardinality

super().__init__()

self.cardinality = cardinality
self.out_feature_name = out_feature_name
self.num_negative_samples = num_negative_samples
self.generator = generator
self.mode = mode

self.frequencies = torch.nn.Buffer(torch.zeros(cardinality, dtype=torch.int64))

def get_probas(self) -> torch.Tensor:
raw: torch.Tensor = 1.0 / (1.0 + self.frequencies)
if self.mode == "softsum":
result: torch.Tensor = raw / torch.sum(raw)
elif self.mode == "softmax":
result: torch.Tensor = func.softmax(raw, dim=-1)
else:
msg: str = f"Unsupported mode: {self.mode}."
raise TypeError(msg)
return result

def update_probas(self, selected: torch.Tensor) -> None:
device = self.frequencies.device
one = torch.ones(1, dtype=torch.int64, device=device)
self.frequencies.index_add_(-1, selected, one.expand(selected.numel()))

def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
output_batch = dict(batch.items())

negatives = torch.multinomial(
input=self.get_probas(),
num_samples=self.num_negative_samples,
replacement=False,
generator=self.generator,
)

self.update_probas(negatives)

device = next(iter(output_batch.values())).device
output_batch[self.out_feature_name] = negatives.to(device)
return output_batch


class ThresholdNegativeSamplingTransform(torch.nn.Module):
"""
Transform for global negative sampling.

For every batch, transform generates a vector of size ``(num_negative_samples)``
consisting of random indices sampeled from a range of ``cardinality``.

Indices that are oversampled at this point will be ignored, while
other samples will be chosen according to their respective frequency.

Example:

.. code-block:: python

>>> _ = torch.manual_seed(0)
>>> input_batch = {"item_id": torch.LongTensor([[1, 0, 4]])}
>>> transform = ThresholdNegativeSamplingTransform(cardinality=4, num_negative_samples=2)
>>> output_batch = transform(input_batch)
>>> output_batch
{'item_id': tensor([[1, 0, 4]]), 'negative_labels': tensor([2, 1])}

"""

def __init__(
self,
cardinality: int,
num_negative_samples: int,
*,
out_feature_name: str = "negative_labels",
generator: Optional[torch.Generator] = None,
mode: Literal["softmax", "softsum"] = "softmax",
) -> None:
"""
:param cardinality: The size of sample vocabulary.
:param num_negative_samples: The size of negatives vector to generate.
:param out_feature_name: The name of result feature in batch.
:param generator: Random number generator to be used for sampling
from the distribution. Default: ``None``.
:param mode: Mode of frequency-based samping for undersampled items.
Default: ``softmax``.
"""
assert num_negative_samples < cardinality

super().__init__()

self.cardinality = cardinality
self.out_feature_name = out_feature_name
self.num_negative_samples = num_negative_samples
self.generator = generator
self.mode = mode

self.frequencies = torch.nn.Buffer(torch.zeros(cardinality, dtype=torch.int64))

def get_probas(self) -> torch.Tensor:
raw: torch.Tensor = 1.0 / (1.0 + self.frequencies)
thr: torch.Tensor = torch.max(self.frequencies)
mask: torch.Tensor = thr != self.frequencies
if self.mode == "softsum":
eps = torch.finfo(raw.dtype).eps
raw = torch.where(mask, raw, eps)
result: torch.Tensor = raw / torch.sum(raw)
elif self.mode == "softmax":
inf = torch.finfo(raw.dtype).min
raw = torch.where(mask, raw, inf)
result: torch.Tensor = func.softmax(raw, dim=-1)
else:
msg: str = f"Unsupported mode: {self.mode}."
raise TypeError(msg)
return result

def update_probas(self, selected: torch.Tensor) -> None:
device = self.frequencies.device
one = torch.ones(1, dtype=torch.int64, device=device)
self.frequencies.index_add_(-1, selected, one.expand(selected.numel()))

def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
output_batch = dict(batch.items())

negatives = torch.multinomial(
input=self.get_probas(),
num_samples=self.num_negative_samples,
replacement=False,
generator=self.generator,
)

self.update_probas(negatives)

device = next(iter(output_batch.values())).device
output_batch[self.out_feature_name] = negatives.to(device)
return output_batch


Expand Down Expand Up @@ -124,8 +310,8 @@ def __init__(
num_negative_samples: int,
sample_mask: torch.Tensor,
*,
negative_selector_name: Optional[str] = "negative_selector",
out_feature_name: Optional[str] = "negative_labels",
negative_selector_name: str = "negative_selector",
out_feature_name: str = "negative_labels",
generator: Optional[torch.Generator] = None,
) -> None:
"""
Expand Down Expand Up @@ -153,7 +339,7 @@ def __init__(

super().__init__()

self.register_buffer("sample_mask", sample_mask.float())
self.sample_mask = torch.nn.Buffer(sample_mask.float())

self.num_negative_samples = num_negative_samples
self.negative_selector_name = negative_selector_name
Expand Down