Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/poetry/console/commands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
from cleo.helpers import option
from packaging.utils import canonicalize_name
from tomlkit import inline_table
from tomlkit import parse

from poetry.console.commands.command import Command
from poetry.console.commands.env_command import EnvCommand
from poetry.factory import Factory
from poetry.utils.dependency_specification import RequirementsParser
from poetry.utils.env.python import Python

Expand Down Expand Up @@ -265,6 +267,14 @@ def _init_pyproject(

return 1

# Validate fields before creating pyproject.toml file. If any validations fail, throw an error.
# Convert TOML string to a TOMLDocument (a dict-like object) for validation.
pyproject_dict = parse(pyproject.data.as_string())
validation_results = self._validate(pyproject_dict)
if validation_results.get("errors"):
self.line_error(f"<error>Validation failed: {validation_results}</error>")
return 1
Comment on lines +274 to +276
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Consider formatting validation errors for readability

Displaying only relevant error messages or a summary instead of the full dictionary will make the output clearer for users.

Suggested change
if validation_results.get("errors"):
self.line_error(f"<error>Validation failed: {validation_results}</error>")
return 1
if validation_results.get("errors"):
errors = validation_results["errors"]
if isinstance(errors, dict):
error_list = [f"- {field}: {msg}" for field, msg in errors.items()]
elif isinstance(errors, list):
error_list = [f"- {msg}" for msg in errors]
else:
error_list = [str(errors)]
formatted_errors = "\n".join(error_list)
self.line_error(f"<error>Validation failed with the following errors:\n{formatted_errors}</error>")
return 1

Comment on lines +274 to +276
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚨 suggestion (security): Printing the full validation_results dict may expose internal details

Extract and display only the relevant error messages to prevent leaking internal or sensitive information.

Suggested change
if validation_results.get("errors"):
self.line_error(f"<error>Validation failed: {validation_results}</error>")
return 1
if validation_results.get("errors"):
+ error_messages = validation_results.get("errors")
+ if isinstance(error_messages, dict):
+ error_messages = list(error_messages.values())
+ if isinstance(error_messages, list):
+ for error in error_messages:
+ self.line_error(f"<error>Validation error: {error}</error>")
+ else:
+ self.line_error("<error>Validation failed due to unknown error format.</error>")
+ return 1


pyproject.save()

if create_layout:
Expand Down Expand Up @@ -533,3 +543,13 @@ def _get_pool(self) -> RepositoryPool:
self._pool.add_repository(PyPiRepository(pool_size=pool_size))

return self._pool

@staticmethod
def _validate(pyproject_data: dict[str, Any]) -> dict[str, Any]:
"""
Validates the given pyproject data and returns the validation results.
"""
# Instantiate a new Factory to avoid relying on shared/global state,
# which can cause unexpected behavior in other parts of the codebase or test suite.
factory = Factory()
return factory.validate(pyproject_data)
76 changes: 76 additions & 0 deletions tests/console/commands/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any

import pytest

Expand Down Expand Up @@ -1143,3 +1144,78 @@ def test_get_pool(mocker: MockerFixture, source_dir: Path) -> None:
assert isinstance(command, InitCommand)
pool = command._get_pool()
assert pool.repositories


def build_pyproject_data(
project_name: str, description: str = "A project"
) -> dict[str, Any]:
return {
"project": {
"name": project_name,
"version": "0.1.0",
"description": description,
"authors": [{"name": "Author Name", "email": "author@example.com"}],
"readme": "README.md",
"requires-python": ">=3.13",
"dependencies": [],
},
"tool": {},
"build-system": {
"requires": ["poetry-core>=2.0.0,<3.0.0"],
"build-backend": "poetry.core.masonry.api",
},
}


@pytest.mark.parametrize(
"valid_project_name",
[
"newproject",
"new_project",
"new-project",
"new.project",
"newproject123",
],
)
def test_valid_project_name(valid_project_name: str) -> None:
pyproject_data = build_pyproject_data(valid_project_name)
result = InitCommand._validate(pyproject_data)
assert result["errors"] == []


@pytest.mark.parametrize(
"invalid_project_name, reason",
[
("new+project", "plus sign"),
("new/project", "slash"),
("new@project", "at sign"),
("new project", "space"),
("", "empty string"),
(" newproject", "leading space"),
("newproject ", "trailing space"),
("new#project", "hash (#)"),
("new%project", "percent (%)"),
("new*project", "asterisk (*)"),
("new(project)", "parentheses"),
("-newproject", "leading hyphen"),
("newproject-", "trailing hyphen"),
(".newproject", "leading dot"),
("newproject.", "trailing dot"),
(
"_newproject",
"leading underscore (PEP 621 allows, stricter validators may reject)",
),
(
"newproject_",
"trailing underscore (PEP 621 allows, stricter validators may reject)",
),
("1newproject!", "starts with digit, ends with exclamation"),
(".", "just dot"),
],
)
def test_invalid_project_name(invalid_project_name: str, reason: str) -> None:
pyproject_data = build_pyproject_data(invalid_project_name)
result = InitCommand._validate(pyproject_data)

assert "errors" in result, f"Expected error for: {reason}"
assert any("project.name must match pattern" in err for err in result["errors"])