Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ func (fake *Fake) ListAll(_ context.Context) (map[string][]string, error) {

// List is part of Interface.
func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) {
objectType = canonicalObjectType(objectType)
if _, ok := listableTypes[objectType]; !ok {
return nil, fmt.Errorf("can't List() type %q", objectType)
}

fake.RLock()
defer fake.RUnlock()
if fake.Table == nil {
Expand All @@ -166,29 +171,29 @@ func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) {
var result []string

switch objectType {
case "flowtable", "flowtables":
case "flowtable":
for name := range fake.Table.Flowtables {
result = append(result, name)
}
case "chain", "chains":
case "chain":
for name := range fake.Table.Chains {
result = append(result, name)
}
case "set", "sets":
case "set":
for name := range fake.Table.Sets {
result = append(result, name)
}
case "map", "maps":
case "map":
for name := range fake.Table.Maps {
result = append(result, name)
}
case "counter", "counters":
case "counter":
for name := range fake.Table.Counters {
result = append(result, name)
}

default:
return nil, fmt.Errorf("unsupported object type %q", objectType)
return nil, fmt.Errorf("internal error: missing List() support for %q", objectType)
}

return result, nil
Expand Down Expand Up @@ -220,6 +225,10 @@ func (fake *Fake) ListRules(_ context.Context, chain string) ([]*Rule, error) {

// ListElements is part of Interface
func (fake *Fake) ListElements(_ context.Context, objectType, name string) ([]*Element, error) {
if objectType != "set" && objectType != "map" {
return nil, fmt.Errorf("invalid objectType %q", objectType)
}

fake.RLock()
defer fake.RUnlock()
if fake.Table == nil {
Expand Down
35 changes: 27 additions & 8 deletions nftables.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ type Interface interface {
ListAll(ctx context.Context) (map[string][]string, error)

// List returns a list of the names of the objects of objectType ("chain", "set",
// "map" or "counter") in the table. If there are no such objects, this will
// return an empty list and no error.
// "map", "counter", or "flowtable" in the table. If there are no such objects,
// this will return an empty list and no error.
List(ctx context.Context, objectType string) ([]string, error)

// ListRules returns a list of the rules in a chain, in order. If no chain name is
Expand Down Expand Up @@ -412,17 +412,33 @@ func (nft *realNFTables) ListAll(ctx context.Context) (map[string][]string, erro
return result, nil
}

// Takes objectType, which can be either singular or plural, and returns the singular
// form.
func canonicalObjectType(objectType string) string {
// All currently-existing nftables object types have plural forms that are just
// the singular form plus 's', and none have singular forms ending in 's'.
if objectType[len(objectType)-1] == 's' {
objectType = objectType[:len(objectType)-1]
}
return objectType
}

var listableTypes = map[string]bool{
"chain": true,
"set": true,
"map": true,
"counter": true,
"flowtable": true,
}

// List is part of Interface.
func (nft *realNFTables) List(ctx context.Context, objectType string) ([]string, error) {
if nft.table == "" {
return nil, fmt.Errorf("can't use List() on a knftables.Interface with no associated family/table")
}

// objectType is allowed to be either singular or plural. All currently-existing
// nftables object types have plural forms that are just the singular form plus 's',
// and none have singular forms ending in 's'.
if objectType[len(objectType)-1] == 's' {
objectType = objectType[:len(objectType)-1]
objectType = canonicalObjectType(objectType)
if _, ok := listableTypes[objectType]; !ok {
return nil, fmt.Errorf("can't List() type %q", objectType)
}

// We want to restrict nft to looking only at our table, so we have to do "list table"
Expand Down Expand Up @@ -501,6 +517,9 @@ func (nft *realNFTables) ListElements(ctx context.Context, objectType, name stri
if nft.table == "" {
return nil, fmt.Errorf("can't use ListElements() on a knftables.Interface with no associated family/table")
}
if objectType != "set" && objectType != "map" {
return nil, fmt.Errorf("invalid objectType %q", objectType)
}

cmd := exec.CommandContext(ctx, nft.path, "--json", "list", objectType, string(nft.family), nft.table, name)
out, err := nft.exec.Run(cmd)
Expand Down
Loading