From 21a3d095c10955f0d6e51a1ecae29ddebefdbb4d Mon Sep 17 00:00:00 2001 From: Dan Winship Date: Wed, 4 Mar 2026 10:30:12 -0500 Subject: [PATCH 1/2] Abstract out object type singular/plural helper --- fake.go | 12 +++++++----- nftables.go | 19 ++++++++++++------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/fake.go b/fake.go index 0942fa2..ab6aec8 100644 --- a/fake.go +++ b/fake.go @@ -157,6 +157,8 @@ func (fake *Fake) ListAll(_ context.Context) (map[string][]string, error) { // List is part of Interface. func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) { + objectType = canonicalObjectType(objectType) + fake.RLock() defer fake.RUnlock() if fake.Table == nil { @@ -166,23 +168,23 @@ func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) { var result []string switch objectType { - case "flowtable", "flowtables": + case "flowtable": for name := range fake.Table.Flowtables { result = append(result, name) } - case "chain", "chains": + case "chain": for name := range fake.Table.Chains { result = append(result, name) } - case "set", "sets": + case "set": for name := range fake.Table.Sets { result = append(result, name) } - case "map", "maps": + case "map": for name := range fake.Table.Maps { result = append(result, name) } - case "counter", "counters": + case "counter": for name := range fake.Table.Counters { result = append(result, name) } diff --git a/nftables.go b/nftables.go index df920f1..41ffea4 100644 --- a/nftables.go +++ b/nftables.go @@ -412,18 +412,23 @@ func (nft *realNFTables) ListAll(ctx context.Context) (map[string][]string, erro return result, nil } +// Takes objectType, which can be either singular or plural, and returns the singular +// form. +func canonicalObjectType(objectType string) string { + // All currently-existing nftables object types have plural forms that are just + // the singular form plus 's', and none have singular forms ending in 's'. + if objectType[len(objectType)-1] == 's' { + objectType = objectType[:len(objectType)-1] + } + return objectType +} + // List is part of Interface. func (nft *realNFTables) List(ctx context.Context, objectType string) ([]string, error) { if nft.table == "" { return nil, fmt.Errorf("can't use List() on a knftables.Interface with no associated family/table") } - - // objectType is allowed to be either singular or plural. All currently-existing - // nftables object types have plural forms that are just the singular form plus 's', - // and none have singular forms ending in 's'. - if objectType[len(objectType)-1] == 's' { - objectType = objectType[:len(objectType)-1] - } + objectType = canonicalObjectType(objectType) // We want to restrict nft to looking only at our table, so we have to do "list table" // rather than any variant of "list ". From b5aea6c50461cf00e5576eb2b63ef041ae421d94 Mon Sep 17 00:00:00 2001 From: Dan Winship Date: Wed, 4 Mar 2026 10:30:12 -0500 Subject: [PATCH 2/2] Ensure everyone agrees on what types List() and ListElements() accept --- fake.go | 9 ++++++++- nftables.go | 18 ++++++++++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/fake.go b/fake.go index ab6aec8..d3184a5 100644 --- a/fake.go +++ b/fake.go @@ -158,6 +158,9 @@ func (fake *Fake) ListAll(_ context.Context) (map[string][]string, error) { // List is part of Interface. func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) { objectType = canonicalObjectType(objectType) + if _, ok := listableTypes[objectType]; !ok { + return nil, fmt.Errorf("can't List() type %q", objectType) + } fake.RLock() defer fake.RUnlock() @@ -190,7 +193,7 @@ func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) { } default: - return nil, fmt.Errorf("unsupported object type %q", objectType) + return nil, fmt.Errorf("internal error: missing List() support for %q", objectType) } return result, nil @@ -222,6 +225,10 @@ func (fake *Fake) ListRules(_ context.Context, chain string) ([]*Rule, error) { // ListElements is part of Interface func (fake *Fake) ListElements(_ context.Context, objectType, name string) ([]*Element, error) { + if objectType != "set" && objectType != "map" { + return nil, fmt.Errorf("invalid objectType %q", objectType) + } + fake.RLock() defer fake.RUnlock() if fake.Table == nil { diff --git a/nftables.go b/nftables.go index 41ffea4..5d2d57c 100644 --- a/nftables.go +++ b/nftables.go @@ -46,8 +46,8 @@ type Interface interface { ListAll(ctx context.Context) (map[string][]string, error) // List returns a list of the names of the objects of objectType ("chain", "set", - // "map" or "counter") in the table. If there are no such objects, this will - // return an empty list and no error. + // "map", "counter", or "flowtable" in the table. If there are no such objects, + // this will return an empty list and no error. List(ctx context.Context, objectType string) ([]string, error) // ListRules returns a list of the rules in a chain, in order. If no chain name is @@ -423,12 +423,23 @@ func canonicalObjectType(objectType string) string { return objectType } +var listableTypes = map[string]bool{ + "chain": true, + "set": true, + "map": true, + "counter": true, + "flowtable": true, +} + // List is part of Interface. func (nft *realNFTables) List(ctx context.Context, objectType string) ([]string, error) { if nft.table == "" { return nil, fmt.Errorf("can't use List() on a knftables.Interface with no associated family/table") } objectType = canonicalObjectType(objectType) + if _, ok := listableTypes[objectType]; !ok { + return nil, fmt.Errorf("can't List() type %q", objectType) + } // We want to restrict nft to looking only at our table, so we have to do "list table" // rather than any variant of "list ". @@ -506,6 +517,9 @@ func (nft *realNFTables) ListElements(ctx context.Context, objectType, name stri if nft.table == "" { return nil, fmt.Errorf("can't use ListElements() on a knftables.Interface with no associated family/table") } + if objectType != "set" && objectType != "map" { + return nil, fmt.Errorf("invalid objectType %q", objectType) + } cmd := exec.CommandContext(ctx, nft.path, "--json", "list", objectType, string(nft.family), nft.table, name) out, err := nft.exec.Run(cmd)