Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions specfile/sections.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,10 @@ def __contains__(self, id: object) -> bool:
data = super().__getattribute__("data")
except AttributeError:
return False
return any(s.normalized_id == cast(str, id).lower() for s in data)
id_lower = cast(str, id).lower()
return any(
s.normalized_id == id_lower or s.normalized_name == id_lower for s in data
)
Comment on lines +176 to +179
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The __contains__ method currently assumes that the id argument is always a string and calls .lower() on it. Since Sections inherits from UserList[Section], it is expected that __contains__ can also handle Section objects (e.g., when checking section_obj in sections). If a non-string object is passed, this will raise an AttributeError because of the .lower() call.

It is recommended to check if id is a string first, and fall back to a standard list containment check otherwise.

Suggested change
id_lower = cast(str, id).lower()
return any(
s.normalized_id == id_lower or s.normalized_name == id_lower for s in data
)
if not isinstance(id, str):
return id in data
id_lower = id.lower()
return any(
s.normalized_id == id_lower or s.normalized_name == id_lower for s in data
)


def __getattr__(self, id: str) -> Section:
if id not in self:
Expand Down Expand Up @@ -209,8 +212,9 @@ def get(self, id: str) -> Section:
return self.data[self.find(id)]

def find(self, id: str) -> int:
id_lower = id.lower()
for i, section in enumerate(self.data):
if section.normalized_id == id.lower():
if section.normalized_id == id_lower or section.normalized_name == id_lower:
return i
raise ValueError

Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_sections.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,23 @@ def test_get():
sections.get("package foo")


def test_contains_and_getattr():
sections = Sections(
[
Section("package"),
Section("package", Options([Token(TokenType.DEFAULT, "baz")]), " "),
Section("install", delimiter=" "),
]
)
assert "package" in sections
assert "install" in sections
assert "install " in sections
assert sections.package == sections[0]
assert getattr(sections, "package baz") == sections[1]
assert sections.install == sections[-1]
assert getattr(sections, "install ") == sections[-1]


@pytest.mark.parametrize(
"id, existing, name, options, content",
[
Expand Down
Loading