Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions cmd/root/flag_suggestions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package root

import (
"fmt"
"strings"

"github.com/spf13/cobra"
"github.com/spf13/pflag"
)

const (
unknownFlagPrefix = "unknown flag: "
unknownShorthandFlagPrefix = "unknown shorthand flag: "
maxSuggestionDistance = 2
)

// levenshteinDistance computes the edit distance between two strings.
func levenshteinDistance(a, b string) int {
if len(a) == 0 {
return len(b)
}
if len(b) == 0 {
return len(a)
}

// Use a single row for the DP table.
prev := make([]int, len(b)+1)
for j := range len(b) + 1 {
prev[j] = j
}

for i := range len(a) {
curr := make([]int, len(b)+1)
curr[0] = i + 1
for j := range len(b) {
cost := 1
if a[i] == b[j] {
cost = 0
}
curr[j+1] = min(
curr[j]+1, // insertion
prev[j+1]+1, // deletion
prev[j]+cost, // substitution
)
}
prev = curr
}

return prev[len(b)]
}

// suggestFlagFromError inspects the error message from Cobra for "unknown flag" patterns.
// If a close match is found among the command's flags, it returns an enhanced error
// with a "Did you mean" suggestion appended. Otherwise it returns the original error.
func suggestFlagFromError(cmd *cobra.Command, err error) error {
msg := err.Error()

if strings.HasPrefix(msg, unknownShorthandFlagPrefix) {
return suggestShorthandFlag(cmd, err, msg)
}

if strings.HasPrefix(msg, unknownFlagPrefix) {
return suggestLongFlag(cmd, err, msg)
}

return err
}

// suggestLongFlag suggests a matching long flag name for an "unknown flag: --xyz" error.
func suggestLongFlag(cmd *cobra.Command, original error, msg string) error {
// Extract the flag name: "unknown flag: --flagname" -> "flagname"
flagName := strings.TrimPrefix(msg, unknownFlagPrefix)
flagName = strings.TrimPrefix(flagName, "--")
if flagName == "" {
return original
}

best, bestDist := findClosestFlag(cmd, flagName)
if best == "" || bestDist > maxSuggestionDistance {
return original
}

return fmt.Errorf("%w\n\nDid you mean \"--%s\"?", original, best)
}

// suggestShorthandFlag suggests a matching shorthand for an
// "unknown shorthand flag: 'x' in -x" error.
func suggestShorthandFlag(cmd *cobra.Command, original error, msg string) error {
// Extract the shorthand character: "unknown shorthand flag: 'x' in -x"
rest := strings.TrimPrefix(msg, unknownShorthandFlagPrefix)
if len(rest) < 3 || rest[0] != '\'' || rest[2] != '\'' {
return original
}
ch := string(rest[1])

best := findClosestShorthand(cmd, ch)
if best == "" {
return original
}

return fmt.Errorf("%w\n\nDid you mean \"-%s\"?", original, best)
}

// findClosestFlag returns the closest non-hidden, non-deprecated long flag name
// and its edit distance from the given misspelled name.
func findClosestFlag(cmd *cobra.Command, name string) (string, int) {
best := ""
bestDist := maxSuggestionDistance + 1

seen := map[string]bool{}
check := func(f *pflag.Flag) {
if f.Hidden || f.Deprecated != "" || f.ShorthandDeprecated != "" {
return
}
if seen[f.Name] {
return
}
seen[f.Name] = true

d := levenshteinDistance(name, f.Name)
if d < bestDist {
bestDist = d
best = f.Name
}
}

cmd.Flags().VisitAll(check)
cmd.InheritedFlags().VisitAll(check)

return best, bestDist
}

// findClosestShorthand returns a case-insensitive exact match for the given
// shorthand character. Levenshtein is not useful for single characters because
// any two distinct characters always have distance 1.
func findClosestShorthand(cmd *cobra.Command, ch string) string {
best := ""
seen := map[string]bool{}
check := func(f *pflag.Flag) {
if f.Hidden || f.Deprecated != "" || f.ShorthandDeprecated != "" || f.Shorthand == "" {
return
}
if seen[f.Shorthand] {
return
}
seen[f.Shorthand] = true
if strings.EqualFold(ch, f.Shorthand) {
best = f.Shorthand
}
}
cmd.Flags().VisitAll(check)
cmd.InheritedFlags().VisitAll(check)
return best
}
169 changes: 169 additions & 0 deletions cmd/root/flag_suggestions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package root

import (
"errors"
"fmt"
"testing"

"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)

func TestLevenshteinDistance(t *testing.T) {
tests := []struct {
a, b string
want int
}{
{"", "", 0},
{"abc", "abc", 0},
{"", "abc", 3},
{"abc", "", 3},
{"kitten", "sitting", 3},
{"output", "outpu", 1}, // deletion
{"output", "ouptut", 2}, // transposition = 2 edits
{"output", "outpux", 1}, // substitution
{"output", "outputx", 1}, // insertion
}

for _, tt := range tests {
t.Run(fmt.Sprintf("%s_%s", tt.a, tt.b), func(t *testing.T) {
assert.Equal(t, tt.want, levenshteinDistance(tt.a, tt.b))
})
}
}

func TestSuggestFlagFromError_LongFlagCloseMatch(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().String("output", "", "output format")

err := errors.New("unknown flag: --outpu")
got := suggestFlagFromError(cmd, err)
assert.Contains(t, got.Error(), `Did you mean "--output"?`)
assert.Contains(t, got.Error(), "unknown flag: --outpu")
}

func TestSuggestFlagFromError_LongFlagNoMatch(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().String("output", "", "output format")

err := errors.New("unknown flag: --zzzzzzz")
got := suggestFlagFromError(cmd, err)
assert.Equal(t, err.Error(), got.Error())
}

func TestSuggestFlagFromError_ShorthandFlag(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().StringP("output", "o", "", "output format")

err := errors.New("unknown shorthand flag: 'O' in -O")
got := suggestFlagFromError(cmd, err)
assert.Contains(t, got.Error(), `Did you mean "-o"?`)
}

func TestSuggestFlagFromError_HiddenFlagsExcluded(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().String("secret", "", "secret flag")
_ = cmd.Flags().MarkHidden("secret")

err := errors.New("unknown flag: --secre")
got := suggestFlagFromError(cmd, err)
assert.NotContains(t, got.Error(), "Did you mean")
}

func TestSuggestFlagFromError_DeprecatedFlagsExcluded(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().String("legacy", "", "old flag")
_ = cmd.Flags().MarkDeprecated("legacy", "use --new instead")

err := errors.New("unknown flag: --legac")
got := suggestFlagFromError(cmd, err)
assert.NotContains(t, got.Error(), "Did you mean")
}

func TestSuggestFlagFromError_InheritedFlags(t *testing.T) {
parent := &cobra.Command{Use: "parent"}
parent.PersistentFlags().String("profile", "", "auth profile")

child := &cobra.Command{Use: "child"}
parent.AddCommand(child)

err := errors.New("unknown flag: --profil")
got := suggestFlagFromError(child, err)
assert.Contains(t, got.Error(), `Did you mean "--profile"?`)
}

func TestSuggestFlagFromError_NonFlagError(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().String("output", "", "output format")

err := errors.New("flag needs an argument: --output")
got := suggestFlagFromError(cmd, err)
assert.Equal(t, err.Error(), got.Error())
}

func TestSuggestFlagFromError_CobraErrorFormats(t *testing.T) {
tests := []struct {
name string
errMsg string
flags map[string]string
contains string
}{
{
name: "long flag with double dash",
errMsg: "unknown flag: --outpu",
flags: map[string]string{"output": ""},
contains: `"--output"`,
},
{
name: "shorthand with no matching flags",
errMsg: "unknown shorthand flag: 'x' in -x",
flags: map[string]string{},
contains: "unknown shorthand flag: 'x' in -x",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
for name, usage := range tt.flags {
cmd.Flags().String(name, "", usage)
}
err := errors.New(tt.errMsg)
got := suggestFlagFromError(cmd, err)
assert.Contains(t, got.Error(), tt.contains)
})
}
}

func TestSuggestFlagFromError_DeduplicatesLocalAndInherited(t *testing.T) {
parent := &cobra.Command{Use: "parent"}
parent.PersistentFlags().String("target", "", "deployment target")

child := &cobra.Command{Use: "child"}
child.Flags().String("target", "", "deployment target")
parent.AddCommand(child)

err := errors.New("unknown flag: --targe")
got := suggestFlagFromError(child, err)

// Should suggest once, not panic or produce duplicate suggestions.
assert.Contains(t, got.Error(), `Did you mean "--target"?`)
}

func TestSuggestFlagFromError_EmptyFlagName(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().String("output", "", "output format")
err := errors.New("unknown flag: --")
got := suggestFlagFromError(cmd, err)
assert.Equal(t, err.Error(), got.Error())
}

func TestSuggestFlagFromError_ShorthandUnrelatedNoSuggestion(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().StringP("output", "o", "", "output format")

err := errors.New("unknown shorthand flag: 'z' in -z")
got := suggestFlagFromError(cmd, err)
assert.NotContains(t, got.Error(), "Did you mean")
assert.Equal(t, err.Error(), got.Error())
}
4 changes: 3 additions & 1 deletion cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ func New(ctx context.Context) *cobra.Command {
return cmd
}

// Wrap flag errors to include the usage string.
// flagErrorFunc wraps flag errors to include the usage string and, for unknown
// flags, a "Did you mean" suggestion based on Levenshtein distance.
func flagErrorFunc(c *cobra.Command, err error) error {
err = suggestFlagFromError(c, err)
return fmt.Errorf("%w\n\n%s", err, c.UsageString())
}

Expand Down
Loading