diff --git a/docs/notes/2.32.x.md b/docs/notes/2.32.x.md index 759ba6e1c53..e9b710742de 100644 --- a/docs/notes/2.32.x.md +++ b/docs/notes/2.32.x.md @@ -85,6 +85,8 @@ The `runtime` field of [`aws_python_lambda_layer`](https://www.pantsbuild.org/2. The `grpc-python-plugin` tool now uses an updated `v1.73.1` plugin built from PexBi gen_pex_binary_tgt("subdir.f.py", tags=["overridden"]), gen_pex_binary_tgt("subdir.f:main"), } + + +def test_python_dependency_validation_with_parametrized_resolve_repro() -> None: + rule_runner = RuleRunner( + rules=[ + *target_types_rules.rules(), + *python_sources.rules(), + *python_dependency_inference_rules.rules(), + QueryRule(_TargetParametrizations, [_TargetParametrizationsRequest]), + QueryRule(Addresses, [DependenciesRequest]), + ], + target_types=[PythonSourceTarget, PythonSourcesGeneratorTarget], + objects={"parametrize": Parametrize}, + ) + rule_runner.write_files( + { + "pants.toml": dedent( + """\ + [python] + enable_resolves = true + interpreter_constraints = ["==3.11.*", "==3.14.*"] + interpreter_versions_universe = ["3.11", "3.14"] + default_to_resolve_interpreter_constraints = true + + [python.resolves] + "a" = "a.lock" + "b" = "b.lock" + + [python.resolves_to_interpreter_constraints] + "a" = ["==3.11.*"] + "b" = ["==3.14.*"] + + [python-infer] + imports = true + + [source] + root_patterns = ["/src/python"] + """ + ), + "src/python/pkg/base.py": "class Base:\n pass\n", + "src/python/pkg/derived.py": ( + "from pkg.base import Base\n\nclass Derived(Base):\n pass\n" + ), + "src/python/pkg/BUILD": dedent( + """\ + python_sources( + sources=["base.py", "derived.py"], + **parametrize("b", resolve="b"), + **parametrize("a", resolve="a"), + overrides={ + "derived.py": { + "dependencies": ["./base.py"], + }, + }, + ) + """ + ), + } + ) + rule_runner.set_options(["--pants-config-files=['pants.toml']"]) + generated_targets = [ + tgt + for tgt in rule_runner.request( + _TargetParametrizations, + [ + _TargetParametrizationsRequest( + Address("src/python/pkg"), description_of_origin="tests" + ) + ], + ).parametrizations.values() + if isinstance(tgt, PythonSourceTarget) + ] + assert len(generated_targets) == 4 + base_a = next( + tgt + for tgt in generated_targets + if tgt.address.relative_file_path == "base.py" + and tgt.address.parameters.get("parametrize") == "a" + ) + derived_a = next( + tgt + for tgt in generated_targets + if tgt.address.relative_file_path == "derived.py" + and tgt.address.parameters.get("parametrize") == "a" + ) + + # The generated targets have the correct resolve. + assert base_a[PythonResolveField].value == "a" + assert derived_a[PythonResolveField].value == "a" + validation_field_set = DependencyValidationFieldSet.create(derived_a) + assert validation_field_set.resolve is not None + assert validation_field_set.resolve.value == "a" + + resolved = rule_runner.request( + Addresses, + [DependenciesRequest(derived_a[PythonDependenciesField])], + ) + assert tuple(resolved) == (base_a.address,) + + +class _ReproRequiredField(StringField): + alias = "required" + required = True + + +class _ReproOptionalField(StringField): + alias = "optional" + + +class _ReproTarget(Target): + alias = "repro_target" + core_fields = (_ReproRequiredField,) + + +@dataclass(frozen=True) +class _ReproFieldSet(FieldSet): + required_fields = (_ReproRequiredField,) + + required: _ReproRequiredField + optional: _ReproOptionalField | None = None + + +def test_field_set_optional_union_field_is_none_when_target_lacks_field_repro() -> None: + target = _ReproTarget( + {_ReproRequiredField.alias: "value"}, + Address("src/python/pkg", target_name="repro"), + ) + field_set = _ReproFieldSet.create(target) + + assert field_set.required.value == "value" + assert field_set.optional is None diff --git a/src/python/pants/engine/target.py b/src/python/pants/engine/target.py index 452dbf701ee..8c486ca973c 100644 --- a/src/python/pants/engine/target.py +++ b/src/python/pants/engine/target.py @@ -19,6 +19,7 @@ from enum import Enum from operator import attrgetter from pathlib import PurePath +from types import UnionType from typing import ( AbstractSet, Any, @@ -27,8 +28,11 @@ Protocol, Self, TypeVar, + Union, cast, final, + get_args, + get_origin, get_type_hints, ) @@ -1381,13 +1385,24 @@ def gen_tgt(address: Address, full_fp: str, generated_target_fields: dict[str, A # ----------------------------------------------------------------------------------------------- def _get_field_set_fields_from_target( field_set: type[FieldSet], target: Target -) -> dict[str, Field]: - return { - dataclass_field_name: ( - target[field_cls] if field_cls in field_set.required_fields else target.get(field_cls) - ) - for dataclass_field_name, field_cls in field_set.fields.items() - } +) -> dict[str, Field | None]: + result: dict[str, Field | None] = {} + for dataclass_field_name, field_cls in field_set.fields.items(): + if field_cls in field_set.required_fields: + result[dataclass_field_name] = target[field_cls] + continue + + if dataclass_field_name in field_set.none_on_absence_fields and not target.has_field( + field_cls + ): + # Preserve true optionality for `Field | None` annotations when the target type + # doesn't define that field. + result[dataclass_field_name] = None + continue + + result[dataclass_field_name] = target.get(field_cls) + + return result _FS = TypeVar("_FS", bound="FieldSet") @@ -1399,8 +1414,9 @@ class FieldSet(EngineAwareParameter, metaclass=ABCMeta): Subclasses should declare all the fields they consume as dataclass attributes. They should also indicate which of these are required, rather than optional, through the class property - `required_fields`. When a field is optional, the default constructor for the field will be used - for any targets that do not have that field registered. + `required_fields`. For fields annotated as `Field | None`, Pants will preserve `None` when the + target type does not have that field registered. For non-union `Field` annotations, Pants will + construct the default field value when the target type does not have that field registered. Subclasses must set `@dataclass(frozen=True)` for their declared fields to be recognized. @@ -1472,15 +1488,77 @@ def applicable_target_types( def create(cls: type[_FS], tgt: Target) -> _FS: return cls(address=tgt.address, **_get_field_set_fields_from_target(cls, tgt)) + @final + @memoized_classproperty + def _field_info(cls) -> FrozenDict[str, tuple[type[Field], bool]]: + def field_type_from_annotation(annotation: Any) -> tuple[type[Field], bool] | None: + if isinstance(annotation, type) and issubclass(annotation, Field): + return annotation, False + + origin = get_origin(annotation) + if origin not in (Union, UnionType): + return None + + union_args = get_args(annotation) + field_types = [ + arg for arg in union_args if isinstance(arg, type) and issubclass(arg, Field) + ] + if len(field_types) != 1: + return None + + preserve_none_on_absence = type(None) in union_args + # Only allow optional Field annotations (`Field | None`). + if not preserve_none_on_absence or len(union_args) != 2: + return None + + return field_types[0], True + + type_hints = get_type_hints(cls) + base_dataclass_field_names = {f.name for f in dataclasses.fields(FieldSet)} + parsed: dict[str, tuple[type[Field], bool]] = {} + invalid_dataclass_fields: dict[str, Any] = {} + + for dataclass_field in dataclasses.fields(cls): + if dataclass_field.name in base_dataclass_field_names: + continue + + annotation = type_hints[dataclass_field.name] + parsed_annotation = field_type_from_annotation(annotation) + if parsed_annotation is None: + invalid_dataclass_fields[dataclass_field.name] = annotation + continue + + parsed[dataclass_field.name] = parsed_annotation + + if invalid_dataclass_fields: + field_set_name = getattr(cls, "__name__", type(cls).__name__) + invalid_field_descriptions = ", ".join( + f"{name}: {annotation!r}" + for name, annotation in sorted(invalid_dataclass_fields.items()) + ) + raise TypeError( + f"The FieldSet `{field_set_name}` has invalid dataclass field annotations. " + "Every declared dataclass field on a FieldSet must be annotated with a " + "`Field` subclass or `Field | None`. Invalid fields: " + f"{invalid_field_descriptions}" + ) + + return FrozenDict(parsed) + @final @memoized_classproperty def fields(cls) -> FrozenDict[str, type[Field]]: return FrozenDict( - ( - (name, field_type) - for name, field_type in get_type_hints(cls).items() - if isinstance(field_type, type) and issubclass(field_type, Field) - ) + (field_name, field_type) for field_name, (field_type, _) in cls._field_info.items() + ) + + @final + @memoized_classproperty + def none_on_absence_fields(cls) -> FrozenOrderedSet[str]: + return FrozenOrderedSet( + field_name + for field_name, (_, preserve_none_on_absence) in cls._field_info.items() + if preserve_none_on_absence ) def debug_hint(self) -> str: diff --git a/src/python/pants/engine/target_test.py b/src/python/pants/engine/target_test.py index 21796845c62..ea96a10ec36 100644 --- a/src/python/pants/engine/target_test.py +++ b/src/python/pants/engine/target_test.py @@ -643,10 +643,18 @@ class OptionalFieldSet(FieldSet): def opt_out(cls, tgt: Target) -> bool: return tgt.get(OptOutField).value is True + @dataclass(frozen=True) + class OptionalUnionFieldSet(FieldSet): + required_fields = () + + optional: OptionalField | None = None + required_addr = Address("", target_name="required") required_tgt = TargetWithRequired({RequiredField.alias: "configured"}, required_addr) optional_addr = Address("", target_name="unrelated") optional_tgt = TargetWithoutRequired({OptionalField.alias: "configured"}, optional_addr) + optional_default_addr = Address("", target_name="optional_default") + optional_default_tgt = TargetWithoutRequired({}, optional_default_addr) no_fields_addr = Address("", target_name="no_fields") no_fields_tgt = NoFieldsTarget({}, no_fields_addr) opt_out_addr = Address("", target_name="conditional") @@ -682,6 +690,40 @@ def opt_out(cls, tgt: Target) -> bool: assert OptionalFieldSet.create(optional_tgt).optional.value == "configured" assert OptionalFieldSet.create(no_fields_tgt).optional.value == OptionalField.default + assert OptionalUnionFieldSet.fields == FrozenDict({"optional": OptionalField}) + assert OptionalUnionFieldSet.none_on_absence_fields == FrozenOrderedSet(["optional"]) + optional_union_field = OptionalUnionFieldSet.create(optional_tgt).optional + assert optional_union_field is not None + assert optional_union_field.value == "configured" + default_registered_union_field = OptionalUnionFieldSet.create(optional_default_tgt).optional + assert default_registered_union_field is not None + assert default_registered_union_field.value == OptionalField.default + default_union_field = OptionalUnionFieldSet.create(no_fields_tgt).optional + assert default_union_field is None + + +def test_field_set_invalid_annotations() -> None: + class ValidField(StringField): + alias = "valid" + default = "default" + + @dataclass(frozen=True) + class InvalidStringFieldSet(FieldSet): + required_fields = () + + not_a_field: str + + with pytest.raises(TypeError, match="`Field` subclass or `Field \\| None`"): + _ = InvalidStringFieldSet.fields + + @dataclass(frozen=True) + class InvalidUnionFieldSet(FieldSet): + required_fields = () + + invalid: ValidField | str + + with pytest.raises(TypeError, match="`Field` subclass or `Field \\| None`"): + _ = InvalidUnionFieldSet.none_on_absence_fields # -----------------------------------------------------------------------------------------------