Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dependencies:
- mkdocs <1.6.0
- mkdocs-material >=7.1.1
- mkdocs-material-extensions
- mkdocstrings
- mkdocstrings < 0.28.0
- mkdocstrings-python
- mkdocs-jupyter
- markdown-include
Expand Down
17 changes: 15 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ dependencies = [
"rdkit"
]

[project.optional-dependencies]
docs = [
    "mkdocs <1.6.0",
"mkdocs-material>=7.1.1",
"mkdocs-material-extensions",
"mkdocstrings < 0.28.0",
"mkdocstrings-python",
"mkdocs-jupyter",
"markdown-include",
"mdx_truly_sane_lists",
"mike >=1.0.0",
]

[project.urls]
"Source Code" = "https://github.com/datamol-io/safe"
"Bug Tracker" = "https://github.com/datamol-io/safe/issues"
Expand Down Expand Up @@ -91,10 +104,10 @@ lint.select = [
"F", # see: https://pypi.org/project/pyflakes
]
lint.extend-select = [
"C4", # see: https://pypi.org/project/flake8-comprehensions
"C4", # see: https://pypi.org/project/flake8-comprehensions
"SIM", # see: https://pypi.org/project/flake8-simplify
"RET", # see: https://pypi.org/project/flake8-return
"PT", # see: https://pypi.org/project/flake8-pytest-style
"PT", # see: https://pypi.org/project/flake8-pytest-style
]
lint.ignore = [
"E731", # Do not assign a lambda expression, use a def
Expand Down
10 changes: 2 additions & 8 deletions safe/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import GPT2DoubleHeadsModel, PretrainedConfig
from transformers.activations import get_activation
from transformers.utils import auto_docstring
from transformers.models.gpt2.modeling_gpt2 import (
_CONFIG_FOR_DOC,
GPT2_INPUTS_DOCSTRING,
GPT2DoubleHeadsModelOutput,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)


Expand Down Expand Up @@ -114,8 +111,7 @@ def __init__(self, config):
del self.multiple_choice_head
self.multiple_choice_head = PropertyHead(config)

@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
@auto_docstring()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down Expand Up @@ -149,8 +145,6 @@ def forward(
mc_labels (`torch.LongTensor` of shape `(batch_size, n_tasks)`, *optional*):
            Labels for computing the supervised loss for regularization.
inputs: List of inputs, put here because the trainer removes information not in signature
Returns:
output (GPT2DoubleHeadsModelOutput): output of the model
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
Expand Down
Loading