diff --git a/src/poetry/console/commands/init.py b/src/poetry/console/commands/init.py index 7412aab95fe..d5ded69f71c 100644 --- a/src/poetry/console/commands/init.py +++ b/src/poetry/console/commands/init.py @@ -13,9 +13,11 @@ from cleo.helpers import option from packaging.utils import canonicalize_name from tomlkit import inline_table +from tomlkit import parse from poetry.console.commands.command import Command from poetry.console.commands.env_command import EnvCommand +from poetry.factory import Factory from poetry.utils.dependency_specification import RequirementsParser from poetry.utils.env.python import Python @@ -265,6 +267,14 @@ def _init_pyproject( return 1 + # Validate fields before creating pyproject.toml file. If any validations fail, throw an error. + # Convert TOML string to a TOMLDocument (a dict-like object) for validation. + pyproject_dict = parse(pyproject.data.as_string()) + validation_results = self._validate(pyproject_dict) + if validation_results.get("errors"): + self.line_error(f"Validation failed: {validation_results}") + return 1 + pyproject.save() if create_layout: @@ -533,3 +543,13 @@ def _get_pool(self) -> RepositoryPool: self._pool.add_repository(PyPiRepository(pool_size=pool_size)) return self._pool + + @staticmethod + def _validate(pyproject_data: dict[str, Any]) -> dict[str, Any]: + """ + Validates the given pyproject data and returns the validation results. + """ + # Instantiate a new Factory to avoid relying on shared/global state, + # which can cause unexpected behavior in other parts of the codebase or test suite. + factory = Factory() + return factory.validate(pyproject_data) diff --git a/tests/console/commands/test_init.py b/tests/console/commands/test_init.py index 7412d75990e..05e6cd4230e 100644 --- a/tests/console/commands/test_init.py +++ b/tests/console/commands/test_init.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import TYPE_CHECKING +from typing import Any import pytest @@ -1143,3 +1144,78 @@ def test_get_pool(mocker: MockerFixture, source_dir: Path) -> None: assert isinstance(command, InitCommand) pool = command._get_pool() assert pool.repositories + + +def build_pyproject_data( + project_name: str, description: str = "A project" +) -> dict[str, Any]: + return { + "project": { + "name": project_name, + "version": "0.1.0", + "description": description, + "authors": [{"name": "Author Name", "email": "author@example.com"}], + "readme": "README.md", + "requires-python": ">=3.13", + "dependencies": [], + }, + "tool": {}, + "build-system": { + "requires": ["poetry-core>=2.0.0,<3.0.0"], + "build-backend": "poetry.core.masonry.api", + }, + } + + +@pytest.mark.parametrize( + "valid_project_name", + [ + "newproject", + "new_project", + "new-project", + "new.project", + "newproject123", + ], +) +def test_valid_project_name(valid_project_name: str) -> None: + pyproject_data = build_pyproject_data(valid_project_name) + result = InitCommand._validate(pyproject_data) + assert result["errors"] == [] + + +@pytest.mark.parametrize( + "invalid_project_name, reason", + [ + ("new+project", "plus sign"), + ("new/project", "slash"), + ("new@project", "at sign"), + ("new project", "space"), + ("", "empty string"), + (" newproject", "leading space"), + ("newproject ", "trailing space"), + ("new#project", "hash (#)"), + ("new%project", "percent (%)"), + ("new*project", "asterisk (*)"), + ("new(project)", "parentheses"), + ("-newproject", "leading hyphen"), + ("newproject-", "trailing hyphen"), + (".newproject", "leading dot"), + ("newproject.", "trailing dot"), + ( + "_newproject", + "leading underscore (PEP 621 allows, stricter validators may reject)", + ), + ( + "newproject_", + "trailing underscore (PEP 621 allows, stricter validators may reject)", + ), + ("1newproject!", "starts with digit, ends with exclamation"), + (".", "just dot"), + ], +) +def test_invalid_project_name(invalid_project_name: str, reason: str) -> None: + pyproject_data = build_pyproject_data(invalid_project_name) + result = InitCommand._validate(pyproject_data) + + assert "errors" in result, f"Expected error for: {reason}" + assert any("project.name must match pattern" in err for err in result["errors"])