diff --git a/.github/workflows/test-smoke.yaml b/.github/workflows/test-smoke.yaml index 90c04ee65..9732dd502 100644 --- a/.github/workflows/test-smoke.yaml +++ b/.github/workflows/test-smoke.yaml @@ -73,6 +73,11 @@ jobs: config: env.toml timeout: 5m working-directory: build/devenv/tests/e2e + - name: TestE2ESmoke_AggregatorChain + run_cmd: TestE2ESmoke_AggregatorChain + config: env.toml + timeout: 10m + working-directory: build/devenv/tests/e2e steps: - name: Enable S3 Cache for Self-Hosted Runners uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 diff --git a/aggregator/cli/chains/commands.go b/aggregator/cli/chains/commands.go new file mode 100644 index 000000000..ab07c1bc2 --- /dev/null +++ b/aggregator/cli/chains/commands.go @@ -0,0 +1,345 @@ +package chains + +import ( + "context" + "fmt" + "os" + "strconv" + "strings" + + "github.com/olekukonko/tablewriter" + "github.com/urfave/cli" + + chainselectors "github.com/smartcontractkit/chain-selectors" + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/chainstatus" + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model" + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +// Deps holds dependencies for the aggregator chains CLI commands. +type Deps struct { + Logger logger.Logger + Store chainstatus.Store + Committee *model.Committee +} + +// InitChainsCommands returns CLI commands for disable, enable, list, get. +func InitChainsCommands(deps Deps) []cli.Command { + return buildChainsCommands(func() Deps { return deps }) +} + +// InitChainsCommandsWithFactory returns the same commands but gets Deps lazily at run time. +func InitChainsCommandsWithFactory(getDeps func() Deps) []cli.Command { + return buildChainsCommands(getDeps) +} + +func buildChainsCommands(getDeps func() Deps) []cli.Command { + return []cli.Command{ + { + Name: "disable", + Usage: "Disable chain processing for the given source/destination selectors", + Action: setStatusActionWithFactory(getDeps, true), + Flags: laneSideFlags(), + }, + { + Name: "enable", + Usage: "Re-enable chain processing for the given source/destination selectors", + Action: setStatusActionWithFactory(getDeps, false), + Flags: laneSideFlags(), + }, + { + Name: "list", + Usage: "List all chain status rows", + Action: listActionWithFactory(getDeps), + Flags: []cli.Flag{ + cli.BoolFlag{ + Name: "only-disabled", + Usage: "Show only disabled chains", + }, + }, + }, + { + Name: "get", + Usage: "Get the status for a specific chain selector and lane side", + Action: getActionWithFactory(getDeps), + Flags: []cli.Flag{ + cli.StringFlag{ + Name: "source", + Usage: "Source chain selector", + Required: false, + }, + cli.StringFlag{ + Name: "destination", + Usage: "Destination chain selector", + Required: false, + }, + }, + }, + } +} + +func laneSideFlags() []cli.Flag { + return []cli.Flag{ + cli.StringFlag{ + Name: "source", + Usage: "Comma-separated source chain selectors", + }, + cli.StringFlag{ + Name: "destination", + Usage: "Comma-separated destination chain selectors", + }, + cli.BoolFlag{ + Name: "all", + Usage: "Apply to all known source and destination chains from the committee config", + }, + } +} + +func setStatusActionWithFactory(getDeps func() Deps, disabled bool) func(c *cli.Context) error { + return func(c *cli.Context) error { + deps := getDeps() + ctx := context.Background() + + useAll := c.Bool("all") + sourceStr := c.String("source") + destStr := c.String("destination") + + if !useAll && sourceStr == "" && destStr == "" { + return 
fmt.Errorf("one of --source, --destination, or --all is required") + } + + action := "disabled" + if !disabled { + action = "enabled" + } + + if useAll { + allSources := committeeSourceSelectors(deps.Committee) + allDests := committeeDestSelectors(deps.Committee) + if err := deps.Store.BatchSetStatus(ctx, chainstatus.LaneSideSource, allSources, disabled); err != nil { + deps.Logger.Errorw("failed to set source statuses", "error", err) + return err + } + if err := deps.Store.BatchSetStatus(ctx, chainstatus.LaneSideDestination, allDests, disabled); err != nil { + deps.Logger.Errorw("failed to set destination statuses", "error", err) + return err + } + fmt.Printf("All %d source(s) and %d destination(s) %s.\n", len(allSources), len(allDests), action) //nolint:forbidigo // CLI user output + return nil + } + + if sourceStr != "" { + selectors, err := parseSelectors(sourceStr) + if err != nil { + return fmt.Errorf("invalid --source: %w", err) + } + if err := deps.Store.BatchSetStatus(ctx, chainstatus.LaneSideSource, selectors, disabled); err != nil { + deps.Logger.Errorw("failed to set source statuses", "error", err) + return err + } + fmt.Printf("Source selector(s) %s %s.\n", sourceStr, action) //nolint:forbidigo // CLI user output + } + + if destStr != "" { + selectors, err := parseSelectors(destStr) + if err != nil { + return fmt.Errorf("invalid --destination: %w", err) + } + if err := deps.Store.BatchSetStatus(ctx, chainstatus.LaneSideDestination, selectors, disabled); err != nil { + deps.Logger.Errorw("failed to set destination statuses", "error", err) + return err + } + fmt.Printf("Destination selector(s) %s %s.\n", destStr, action) //nolint:forbidigo // CLI user output + } + + return nil + } +} + +func listActionWithFactory(getDeps func() Deps) func(c *cli.Context) error { + return func(c *cli.Context) error { + deps := getDeps() + ctx := context.Background() + + dbRows, err := deps.Store.List(ctx) + if err != nil { + deps.Logger.Errorw("list chain statuses failed", "error", err) + return err + } + + type key struct { + sel uint64 + side chainstatus.LaneSide + } + known := make(map[key]chainstatus.ChainStatus, len(dbRows)) + for _, s := range dbRows { + known[key{s.ChainSelector, s.Side}] = s + } + + merged := make([]chainstatus.ChainStatus, 0, len(dbRows)) + merged = append(merged, dbRows...) 
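+		// Backfill committee chains that have no DB row yet: per the Store
+		// contract, absence of a row means "enabled", and synthesizing those
+		// rows keeps `list` covering the whole committee rather than only
+		// chains that have been toggled at least once.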
+ for _, sel := range committeeSourceSelectors(deps.Committee) { + k := key{sel, chainstatus.LaneSideSource} + if _, ok := known[k]; !ok { + merged = append(merged, chainstatus.ChainStatus{ChainSelector: sel, Side: chainstatus.LaneSideSource}) + } + } + for _, sel := range committeeDestSelectors(deps.Committee) { + k := key{sel, chainstatus.LaneSideDestination} + if _, ok := known[k]; !ok { + merged = append(merged, chainstatus.ChainStatus{ChainSelector: sel, Side: chainstatus.LaneSideDestination}) + } + } + + if c.Bool("only-disabled") { + filtered := merged[:0] + for _, s := range merged { + if s.Disabled { + filtered = append(filtered, s) + } + } + merged = filtered + } + + return renderList(merged) + } +} + +func getActionWithFactory(getDeps func() Deps) func(c *cli.Context) error { + return func(c *cli.Context) error { + deps := getDeps() + ctx := context.Background() + + sourceStr := c.String("source") + destStr := c.String("destination") + if sourceStr == "" && destStr == "" { + return fmt.Errorf("one of --source or --destination is required") + } + + var statuses []chainstatus.ChainStatus + if sourceStr != "" { + sel, err := parseSelector(sourceStr) + if err != nil { + return fmt.Errorf("invalid --source: %w", err) + } + s, err := deps.Store.Get(ctx, chainstatus.LaneSideSource, sel) + if err != nil { + deps.Logger.Errorw("get source chain status failed", "error", err) + return err + } + if s != nil { + statuses = append(statuses, *s) + } else { + statuses = append(statuses, chainstatus.ChainStatus{ChainSelector: sel, Side: chainstatus.LaneSideSource, Disabled: false}) + } + } + if destStr != "" { + sel, err := parseSelector(destStr) + if err != nil { + return fmt.Errorf("invalid --destination: %w", err) + } + s, err := deps.Store.Get(ctx, chainstatus.LaneSideDestination, sel) + if err != nil { + deps.Logger.Errorw("get destination chain status failed", "error", err) + return err + } + if s != nil { + statuses = append(statuses, *s) + } else { + statuses = append(statuses, chainstatus.ChainStatus{ChainSelector: sel, Side: chainstatus.LaneSideDestination, Disabled: false}) + } + } + return renderList(statuses) + } +} + +func renderList(statuses []chainstatus.ChainStatus) error { + if len(statuses) == 0 { + fmt.Println("No chain status rows found.") //nolint:forbidigo // CLI user output + return nil + } + table := tablewriter.NewWriter(os.Stdout) + table.SetAutoFormatHeaders(false) + table.SetHeader([]string{"Chain", "Selector", "Side", "Disabled", "Updated At"}) + table.SetBorder(false) + for _, s := range statuses { + name := chainNameFromSelector(s.ChainSelector) + disabledStr := "false" + if s.Disabled { + disabledStr = "true" + } + updatedAt := "" + if !s.UpdatedAt.IsZero() { + updatedAt = s.UpdatedAt.Format("2006-01-02T15:04:05Z07:00") + } + table.Append([]string{name, fmt.Sprintf("%d", s.ChainSelector), string(s.Side), disabledStr, updatedAt}) + } + table.Render() + return nil +} + +func chainNameFromSelector(sel uint64) string { + name, err := chainselectors.GetChainNameFromSelector(sel) + if err != nil { + return "unknown" + } + return name +} + +func committeeSourceSelectors(committee *model.Committee) []uint64 { + if committee == nil { + return nil + } + selectors := make([]uint64, 0, len(committee.QuorumConfigs)) + for selStr := range committee.QuorumConfigs { + sel, err := strconv.ParseUint(selStr, 10, 64) + if err != nil { + continue + } + selectors = append(selectors, sel) + } + return selectors +} + +func committeeDestSelectors(committee *model.Committee) []uint64 { + 
if committee == nil { + return nil + } + selectors := make([]uint64, 0, len(committee.DestinationVerifiers)) + for selStr := range committee.DestinationVerifiers { + sel, err := strconv.ParseUint(selStr, 10, 64) + if err != nil { + continue + } + selectors = append(selectors, sel) + } + return selectors +} + +func parseSelectors(s string) ([]uint64, error) { + parts := strings.Split(s, ",") + selectors := make([]uint64, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + sel, err := parseSelector(p) + if err != nil { + return nil, err + } + selectors = append(selectors, sel) + } + if len(selectors) == 0 { + return nil, fmt.Errorf("no valid selectors provided") + } + return selectors, nil +} + +func parseSelector(s string) (uint64, error) { + u, err := strconv.ParseUint(strings.TrimSpace(s), 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid chain selector %q: %w", s, err) + } + return u, nil +} diff --git a/aggregator/cli/chains/commands_test.go b/aggregator/cli/chains/commands_test.go new file mode 100644 index 000000000..059838be4 --- /dev/null +++ b/aggregator/cli/chains/commands_test.go @@ -0,0 +1,276 @@ +package chains + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/urfave/cli" + + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/chainstatus" + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model" + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +// ---- in-memory Store ------------------------------------------------------ + +type memStore struct { + rows map[storeKey]*chainstatus.ChainStatus +} + +type storeKey struct { + selector uint64 + side chainstatus.LaneSide +} + +func newMemStore() *memStore { + return &memStore{rows: make(map[storeKey]*chainstatus.ChainStatus)} +} + +func (m *memStore) BatchSetStatus(_ context.Context, side chainstatus.LaneSide, selectors []uint64, disabled bool) error { + for _, sel := range selectors { + m.rows[storeKey{sel, side}] = &chainstatus.ChainStatus{ + ChainSelector: sel, + Side: side, + Disabled: disabled, + UpdatedAt: time.Now(), + } + } + return nil +} + +func (m *memStore) List(_ context.Context) ([]chainstatus.ChainStatus, error) { + out := make([]chainstatus.ChainStatus, 0, len(m.rows)) + for _, v := range m.rows { + out = append(out, *v) + } + return out, nil +} + +func (m *memStore) ListDisabled(_ context.Context) ([]chainstatus.ChainStatus, error) { + var out []chainstatus.ChainStatus + for _, v := range m.rows { + if v.Disabled { + out = append(out, *v) + } + } + return out, nil +} + +func (m *memStore) Get(_ context.Context, side chainstatus.LaneSide, selector uint64) (*chainstatus.ChainStatus, error) { + if v, ok := m.rows[storeKey{selector, side}]; ok { + s := *v + return &s, nil + } + return nil, nil +} + +// ---- helpers -------------------------------------------------------------- + +func makeDeps(t *testing.T, store chainstatus.Store, committee *model.Committee) Deps { + t.Helper() + return Deps{Logger: logger.Test(t), Store: store, Committee: committee} +} + +// makeCommittee builds a Committee from explicit source and destination selector slices. 
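+// Values are placeholders (a nil QuorumConfig, a dummy verifier address)
+// because the commands under test only read the map keys.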
+func makeCommittee(sources, dests []uint64) *model.Committee { + c := &model.Committee{ + QuorumConfigs: make(map[string]*model.QuorumConfig, len(sources)), + DestinationVerifiers: make(map[string]string, len(dests)), + } + for _, sel := range sources { + c.QuorumConfigs[strconv.FormatUint(sel, 10)] = nil + } + for _, sel := range dests { + c.DestinationVerifiers[strconv.FormatUint(sel, 10)] = "0x0" + } + return c +} + +// runCLI invokes chains commands with args, captures stdout, and returns combined output + error. +func runCLI(t *testing.T, deps Deps, args []string) (string, error) { + t.Helper() + + old := os.Stdout + r, w, err := os.Pipe() + require.NoError(t, err) + os.Stdout = w + + app := cli.NewApp() + app.Name = "test" + app.Commands = InitChainsCommands(deps) + runErr := app.Run(append([]string{"test"}, args...)) + + w.Close() + os.Stdout = old + + var buf bytes.Buffer + _, _ = io.Copy(&buf, r) + return buf.String(), runErr +} + +// ---- list tests ----------------------------------------------------------- + +func TestList_EmptyDBEmptyCommittee(t *testing.T) { + out, err := runCLI(t, makeDeps(t, newMemStore(), nil), []string{"list"}) + require.NoError(t, err) + assert.Contains(t, out, "No chain status rows found.") +} + +func TestList_EmptyDB_ShowsCommitteeChains(t *testing.T) { + committee := makeCommittee([]uint64{1001, 1002}, []uint64{2001}) + out, err := runCLI(t, makeDeps(t, newMemStore(), committee), []string{"list"}) + require.NoError(t, err) + assert.Contains(t, out, "1001", "source 1001 should appear") + assert.Contains(t, out, "1002", "source 1002 should appear") + assert.Contains(t, out, "2001", "dest 2001 should appear") + // All should show as enabled (no DB row). + assert.NotContains(t, out, "true", "no chain should be disabled") +} + +func TestList_MergesDBRowsWithCommittee(t *testing.T) { + store := newMemStore() + ctx := context.Background() + require.NoError(t, store.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{1001}, true)) + + // Committee has 1001 (in DB as disabled) and 1002 (not in DB). + committee := makeCommittee([]uint64{1001, 1002}, nil) + out, err := runCLI(t, makeDeps(t, store, committee), []string{"list"}) + require.NoError(t, err) + assert.Contains(t, out, "1001") + assert.Contains(t, out, "1002", "1002 has no DB row but should still appear") + assert.Contains(t, out, "true", "1001 should be shown as disabled") +} + +func TestList_NoDuplicates(t *testing.T) { + store := newMemStore() + ctx := context.Background() + require.NoError(t, store.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{1001}, false)) + + // Committee also contains 1001 as a source. + committee := makeCommittee([]uint64{1001}, nil) + out, err := runCLI(t, makeDeps(t, store, committee), []string{"list"}) + require.NoError(t, err) + + // Count occurrences of "1001" in the output — should appear exactly once. + count := bytes.Count([]byte(out), []byte("1001")) + assert.Equal(t, 1, count, "selector 1001 should appear exactly once; got output:\n%s", out) +} + +func TestList_OnlyDisabled_FilterWorks(t *testing.T) { + store := newMemStore() + ctx := context.Background() + require.NoError(t, store.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{1001}, true)) + + // Committee has 1001 (disabled) and 1002 (enabled, no DB row). 
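+	// The synthesized row for 1002 carries Disabled=false, so --only-disabled must drop it.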
+ committee := makeCommittee([]uint64{1001, 1002}, nil) + out, err := runCLI(t, makeDeps(t, store, committee), []string{"list", "--only-disabled"}) + require.NoError(t, err) + assert.Contains(t, out, "1001") + assert.NotContains(t, out, "1002", "enabled chain should be filtered out by --only-disabled") +} + +func TestList_OnlyDisabled_NoneDisabled(t *testing.T) { + committee := makeCommittee([]uint64{1001}, nil) + out, err := runCLI(t, makeDeps(t, newMemStore(), committee), []string{"list", "--only-disabled"}) + require.NoError(t, err) + assert.Contains(t, out, "No chain status rows found.") +} + +// ---- disable / enable tests ----------------------------------------------- + +func TestDisable_Source(t *testing.T) { + store := newMemStore() + out, err := runCLI(t, makeDeps(t, store, nil), []string{"disable", "--source", "1001"}) + require.NoError(t, err) + assert.Contains(t, out, "disabled") + + s, err := store.Get(context.Background(), chainstatus.LaneSideSource, 1001) + require.NoError(t, err) + require.NotNil(t, s) + assert.True(t, s.Disabled) +} + +func TestDisable_Destination(t *testing.T) { + store := newMemStore() + out, err := runCLI(t, makeDeps(t, store, nil), []string{"disable", "--destination", "2001"}) + require.NoError(t, err) + assert.Contains(t, out, "disabled") + + s, err := store.Get(context.Background(), chainstatus.LaneSideDestination, 2001) + require.NoError(t, err) + require.NotNil(t, s) + assert.True(t, s.Disabled) +} + +func TestDisable_All(t *testing.T) { + store := newMemStore() + committee := makeCommittee([]uint64{1001, 1002}, []uint64{2001}) + out, err := runCLI(t, makeDeps(t, store, committee), []string{"disable", "--all"}) + require.NoError(t, err) + assert.Contains(t, out, "disabled") + + ctx := context.Background() + for _, sel := range []uint64{1001, 1002} { + s, err := store.Get(ctx, chainstatus.LaneSideSource, sel) + require.NoError(t, err) + require.NotNil(t, s, fmt.Sprintf("source %d should be disabled", sel)) + assert.True(t, s.Disabled) + } + s, err := store.Get(ctx, chainstatus.LaneSideDestination, 2001) + require.NoError(t, err) + require.NotNil(t, s) + assert.True(t, s.Disabled) +} + +func TestDisable_NoFlags_ReturnsError(t *testing.T) { + _, err := runCLI(t, makeDeps(t, newMemStore(), nil), []string{"disable"}) + require.Error(t, err) +} + +func TestEnable_Source(t *testing.T) { + store := newMemStore() + ctx := context.Background() + require.NoError(t, store.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{1001}, true)) + + out, err := runCLI(t, makeDeps(t, store, nil), []string{"enable", "--source", "1001"}) + require.NoError(t, err) + assert.Contains(t, out, "enabled") + + s, err := store.Get(ctx, chainstatus.LaneSideSource, 1001) + require.NoError(t, err) + require.NotNil(t, s) + assert.False(t, s.Disabled) +} + +// ---- get tests ------------------------------------------------------------ + +func TestGet_ExistingRow(t *testing.T) { + store := newMemStore() + ctx := context.Background() + require.NoError(t, store.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{1001}, true)) + + out, err := runCLI(t, makeDeps(t, store, nil), []string{"get", "--source", "1001"}) + require.NoError(t, err) + assert.Contains(t, out, "1001") + assert.Contains(t, out, "true") +} + +func TestGet_NoRow_ShowsSyntheticEnabled(t *testing.T) { + out, err := runCLI(t, makeDeps(t, newMemStore(), nil), []string{"get", "--source", "9999"}) + require.NoError(t, err) + assert.Contains(t, out, "9999") + assert.Contains(t, out, "false", "unknown selector should 
show as enabled") +} + +func TestGet_NoFlags_ReturnsError(t *testing.T) { + _, err := runCLI(t, makeDeps(t, newMemStore(), nil), []string{"get"}) + require.Error(t, err) +} diff --git a/aggregator/cmd/main.go b/aggregator/cmd/main.go index 6580c5e97..1ee6a5ad0 100644 --- a/aggregator/cmd/main.go +++ b/aggregator/cmd/main.go @@ -3,19 +3,28 @@ package main import ( "context" + "database/sql" "errors" "fmt" "net" "os" "os/signal" + "path/filepath" + "sync" "syscall" "time" + "github.com/jmoiron/sqlx" + _ "github.com/lib/pq" + "github.com/urfave/cli" "go.uber.org/zap/zapcore" + "github.com/smartcontractkit/chainlink-ccv/aggregator/cli/chains" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/common" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/configuration" + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/monitoring" + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/storage/postgres" "github.com/smartcontractkit/chainlink-ccv/protocol" "github.com/smartcontractkit/chainlink-ccv/protocol/common/logging" "github.com/smartcontractkit/chainlink-common/pkg/beholder" @@ -25,7 +34,6 @@ import ( ) func main() { - // Determine log level from environment variable, defaulting to "info" logLevelStr := os.Getenv("LOG_LEVEL") if logLevelStr == "" { logLevelStr = "info" @@ -40,25 +48,84 @@ func main() { panic(fmt.Sprintf("Failed to create logger: %v", err)) } lggr = logger.Named(lggr, "aggregator") - sugaredLggr := logger.Sugared(lggr) - filePath, ok := os.LookupEnv("AGGREGATOR_CONFIG_PATH") - if !ok { - filePath = aggregator.DefaultConfigFile + var ( + loadedConfig *model.AggregatorConfig + chainsDepsOnce sync.Once + chainsDeps chains.Deps + ) + + getChainsDepsFn := func() chains.Deps { + chainsDepsOnce.Do(func() { + db, err := sql.Open("postgres", loadedConfig.Storage.ConnectionURL) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to open database: %v\n", err) + os.Exit(1) + } + sqlxDB := sqlx.NewDb(db, "postgres") + store := postgres.NewDatabaseStorage(sqlxDB, loadedConfig.Storage.PageSize, loadedConfig.Storage.QueryTimeout, sugaredLggr) + chainsDeps = chains.Deps{ + Logger: lggr, + Store: store, + Committee: loadedConfig.Committee, + } + }) + return chainsDeps } - if len(os.Args) > 1 { - filePath = os.Args[1] + + app := cli.NewApp() + app.Name = filepath.Base(os.Args[0]) + app.Usage = "Aggregator service and chain management CLI" + app.Flags = []cli.Flag{ + cli.StringFlag{ + Name: "config, c", + Usage: "Path to config file", + EnvVar: "AGGREGATOR_CONFIG_PATH", + Value: aggregator.DefaultConfigFile, + }, } - config, err := configuration.LoadConfig(filePath, sugaredLggr) + + app.Action = func(c *cli.Context) error { + runServer(c.String("config"), lggr, sugaredLggr) + return nil + } + + app.Commands = []cli.Command{ + { + Name: "chains", + Usage: "Disable, enable, or inspect chain processing status", + Before: func(c *cli.Context) error { + cfg, err := configuration.LoadConfig(c.GlobalString("config"), sugaredLggr) + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + if err := cfg.LoadFromEnvironment(); err != nil { + return fmt.Errorf("failed to load config from environment: %w", err) + } + loadedConfig = cfg + return nil + }, + Subcommands: chains.InitChainsCommandsWithFactory(getChainsDepsFn), + }, + } + + if err := app.Run(os.Args); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } +} + +func runServer(configPath string, lggr logger.Logger, sugaredLggr 
logger.SugaredLogger) { + config, err := configuration.LoadConfig(configPath, sugaredLggr) if err != nil { - lggr.Errorw("Failed to load configuration", "path", filePath, "error", err) + lggr.Errorw("Failed to load configuration", "path", configPath, "error", err) os.Exit(1) } lggr.Infow("Loaded configuration", "config", config) if err := config.LoadFromEnvironment(); err != nil { - lggr.Errorw("Failed to load configuration from environment", "path", filePath, "error", err) + lggr.Errorw("Failed to load configuration from environment", "path", configPath, "error", err) os.Exit(1) } lggr.Infow("Successfully loaded configuration from environment variables") diff --git a/aggregator/migrations/postgres/00004_create_chain_statuses.sql b/aggregator/migrations/postgres/00004_create_chain_statuses.sql new file mode 100644 index 000000000..3ddabe4f7 --- /dev/null +++ b/aggregator/migrations/postgres/00004_create_chain_statuses.sql @@ -0,0 +1,11 @@ +-- +goose Up +CREATE TABLE aggregator_chain_statuses ( + chain_selector BIGINT NOT NULL, + lane_side TEXT NOT NULL CHECK (lane_side IN ('source', 'destination')), + disabled BOOLEAN NOT NULL DEFAULT false, + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (chain_selector, lane_side) +); + +-- +goose Down +DROP TABLE IF EXISTS aggregator_chain_statuses; diff --git a/aggregator/pkg/chainstatus/checker.go b/aggregator/pkg/chainstatus/checker.go new file mode 100644 index 000000000..38e305141 --- /dev/null +++ b/aggregator/pkg/chainstatus/checker.go @@ -0,0 +1,23 @@ +package chainstatus + +// LaneReport exposes the chain selectors for both ends of a lane. +// Both *model.CommitVerificationRecord and *model.CommitAggregatedReport satisfy this interface. +// The interface is intentionally broad so future checks (e.g. token address, off-ramp address) +// can be added without changing IsDisabled's signature. +type LaneReport interface { + // GetSourceChainSelector returns the source chain selector for the lane. + GetSourceChainSelector() uint64 + // GetDestinationSelector returns the destination chain selector for the lane. + GetDestinationSelector() uint64 +} + +// Checker determines whether chain processing is currently disabled for a given lane. +type Checker interface { + // IsDisabled returns true if chain processing is disabled for the given lane report. + IsDisabled(report LaneReport) bool +} + +// NoopChecker never disables any chain. Use in tests and when no registry is wired. +type NoopChecker struct{} + +func (NoopChecker) IsDisabled(_ LaneReport) bool { return false } diff --git a/aggregator/pkg/chainstatus/registry.go b/aggregator/pkg/chainstatus/registry.go new file mode 100644 index 000000000..4810da322 --- /dev/null +++ b/aggregator/pkg/chainstatus/registry.go @@ -0,0 +1,125 @@ +package chainstatus + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +// StatusMetrics is an optional dependency for recording chain-disabled metrics after each refresh. +type StatusMetrics interface { + // SetChainDisabledStatus emits the metrics for the disabled status for a chain. + SetChainDisabledStatus(ctx context.Context, selector uint64, side LaneSide, disabled bool) +} + +// Registry holds the in-memory set of disabled chains, refreshed periodically from the Store. +// It implements Checker. 
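+//
+// Typical wiring (an illustrative sketch; the caller supplies store, lggr,
+// ctx, and report):
+//
+//	reg := chainstatus.NewRegistry(store, lggr)
+//	_ = reg.Refresh(ctx) // warm start; see NewRegistry docs
+//	reg.StartPeriodicRefresh(ctx, 30*time.Second)
+//	if reg.IsDisabled(report) {
+//		// reject the write
+//	}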
+type Registry struct { + store Store + mu sync.RWMutex + disabledSources map[uint64]struct{} + disabledDests map[uint64]struct{} + sourceSels []uint64 // committee source selectors covered by metrics + destSels []uint64 // committee dest selectors covered by metrics + statusMetrics StatusMetrics + lggr logger.SugaredLogger +} + +var _ Checker = (*Registry)(nil) + +// RegistryOption configures optional Registry behavior. +type RegistryOption func(*Registry) + +// WithStatusMetrics attaches a metrics reporter and the full set of committee selectors to +// the Registry. After each successful Refresh the gauge is emitted for every selector. +func WithStatusMetrics(m StatusMetrics, sourceSels, destSels []uint64) RegistryOption { + return func(r *Registry) { + r.statusMetrics = m + r.sourceSels = sourceSels + r.destSels = destSels + } +} + +// NewRegistry creates a registry backed by the given store. Call Refresh before use. +func NewRegistry(store Store, lggr logger.SugaredLogger, opts ...RegistryOption) *Registry { + r := &Registry{ + store: store, + disabledSources: make(map[uint64]struct{}), + disabledDests: make(map[uint64]struct{}), + lggr: lggr, + } + for _, o := range opts { + o(r) + } + return r +} + +// Refresh reloads the disabled chain set from the store and emits status metrics when configured. +func (r *Registry) Refresh(ctx context.Context) error { + statuses, err := r.store.ListDisabled(ctx) + if err != nil { + return fmt.Errorf("failed to list disabled chains: %w", err) + } + + newSources := make(map[uint64]struct{}, len(statuses)) + newDests := make(map[uint64]struct{}, len(statuses)) + for _, s := range statuses { + switch s.Side { + case LaneSideSource: + newSources[s.ChainSelector] = struct{}{} + case LaneSideDestination: + newDests[s.ChainSelector] = struct{}{} + } + } + + r.mu.Lock() + r.disabledSources = newSources + r.disabledDests = newDests + r.mu.Unlock() + + if r.statusMetrics != nil { + for _, sel := range r.sourceSels { + _, disabled := newSources[sel] + r.statusMetrics.SetChainDisabledStatus(ctx, sel, LaneSideSource, disabled) + } + for _, sel := range r.destSels { + _, disabled := newDests[sel] + r.statusMetrics.SetChainDisabledStatus(ctx, sel, LaneSideDestination, disabled) + } + } + + return nil +} + +// IsDisabled returns true if either the source or destination chain in the report is disabled. +func (r *Registry) IsDisabled(report LaneReport) bool { + r.mu.RLock() + defer r.mu.RUnlock() + if _, ok := r.disabledSources[report.GetSourceChainSelector()]; ok { + return true + } + _, ok := r.disabledDests[report.GetDestinationSelector()] + return ok +} + +// StartPeriodicRefresh runs Refresh on a ticker until ctx is canceled. +// Errors are logged but do not stop the refresh loop. 
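+// The first tick only fires after one full interval, so callers that need a
+// warm registry before serving traffic should call Refresh once first.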
+func (r *Registry) StartPeriodicRefresh(ctx context.Context, interval time.Duration) { + ticker := time.NewTicker(interval) + go func() { + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := r.Refresh(ctx); err != nil { + r.lggr.Errorw("Failed to refresh chain disable registry", "error", err) + } + } + } + }() +} diff --git a/aggregator/pkg/chainstatus/registry_test.go b/aggregator/pkg/chainstatus/registry_test.go new file mode 100644 index 000000000..3c6e13f71 --- /dev/null +++ b/aggregator/pkg/chainstatus/registry_test.go @@ -0,0 +1,218 @@ +package chainstatus_test + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/chainstatus" +) + +// fakeStore is a minimal in-memory Store for unit tests. +type fakeStore struct { + disabled []chainstatus.ChainStatus + err error +} + +func (f *fakeStore) BatchSetStatus(_ context.Context, _ chainstatus.LaneSide, _ []uint64, _ bool) error { + return f.err +} + +func (f *fakeStore) List(_ context.Context) ([]chainstatus.ChainStatus, error) { + return f.disabled, f.err +} + +func (f *fakeStore) ListDisabled(_ context.Context) ([]chainstatus.ChainStatus, error) { + return f.disabled, f.err +} + +func (f *fakeStore) Get(_ context.Context, _ chainstatus.LaneSide, _ uint64) (*chainstatus.ChainStatus, error) { + return nil, f.err +} + +func newTestRegistry(t *testing.T, store chainstatus.Store) *chainstatus.Registry { + t.Helper() + return chainstatus.NewRegistry(store, logger.Sugared(logger.Test(t))) +} + +// laneReport is a minimal LaneReport for tests. 
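+// It avoids constructing full model records just to drive the registry.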
+type laneReport struct {
+	source uint64
+	dest   uint64
+}
+
+func (l laneReport) GetSourceChainSelector() uint64 { return l.source }
+func (l laneReport) GetDestinationSelector() uint64 { return l.dest }
+
+func TestRegistry_IsDisabled_EmptyRegistry_AlwaysFalse(t *testing.T) {
+	t.Parallel()
+	reg := newTestRegistry(t, &fakeStore{})
+
+	assert.False(t, reg.IsDisabled(laneReport{source: 1, dest: 2}))
+	assert.False(t, reg.IsDisabled(laneReport{source: 0, dest: 0}))
+}
+
+func TestRegistry_Refresh_LoadsDisabledSources(t *testing.T) {
+	t.Parallel()
+	store := &fakeStore{
+		disabled: []chainstatus.ChainStatus{
+			{ChainSelector: 100, Side: chainstatus.LaneSideSource, Disabled: true},
+		},
+	}
+	reg := newTestRegistry(t, store)
+
+	require.NoError(t, reg.Refresh(context.Background()))
+
+	assert.True(t, reg.IsDisabled(laneReport{source: 100, dest: 999}), "source 100 should be disabled")
+	assert.False(t, reg.IsDisabled(laneReport{source: 200, dest: 999}), "source 200 should be enabled")
+}
+
+func TestRegistry_Refresh_LoadsDisabledDestinations(t *testing.T) {
+	t.Parallel()
+	store := &fakeStore{
+		disabled: []chainstatus.ChainStatus{
+			{ChainSelector: 200, Side: chainstatus.LaneSideDestination, Disabled: true},
+		},
+	}
+	reg := newTestRegistry(t, store)
+	require.NoError(t, reg.Refresh(context.Background()))
+
+	assert.True(t, reg.IsDisabled(laneReport{source: 999, dest: 200}), "dest 200 should be disabled")
+	assert.False(t, reg.IsDisabled(laneReport{source: 999, dest: 100}), "dest 100 should be enabled")
+}
+
+func TestRegistry_IsDisabled_SourceOrDestinationSuffices(t *testing.T) {
+	t.Parallel()
+	store := &fakeStore{
+		disabled: []chainstatus.ChainStatus{
+			{ChainSelector: 10, Side: chainstatus.LaneSideSource, Disabled: true},
+			{ChainSelector: 20, Side: chainstatus.LaneSideDestination, Disabled: true},
+		},
+	}
+	reg := newTestRegistry(t, store)
+	require.NoError(t, reg.Refresh(context.Background()))
+
+	assert.True(t, reg.IsDisabled(laneReport{source: 10, dest: 99}), "source disabled")
+	assert.True(t, reg.IsDisabled(laneReport{source: 99, dest: 20}), "dest disabled")
+	assert.True(t, reg.IsDisabled(laneReport{source: 10, dest: 20}), "both disabled")
+	assert.False(t, reg.IsDisabled(laneReport{source: 99, dest: 99}), "neither disabled")
+}
+
+func TestRegistry_Refresh_Error_PropagatesAndPreservesState(t *testing.T) {
+	t.Parallel()
+	store := &fakeStore{
+		disabled: []chainstatus.ChainStatus{
+			{ChainSelector: 50, Side: chainstatus.LaneSideSource, Disabled: true},
+		},
+	}
+	reg := newTestRegistry(t, store)
+	require.NoError(t, reg.Refresh(context.Background()))
+	assert.True(t, reg.IsDisabled(laneReport{source: 50, dest: 0}))
+
+	// Simulate store error on next refresh
+	store.err = errors.New("db unavailable")
+	require.Error(t, reg.Refresh(context.Background()))
+
+	// Refresh returns before touching the maps on error, so prior state is preserved.
+	assert.True(t, reg.IsDisabled(laneReport{source: 50, dest: 0}), "state should survive a failed refresh")
+}
+
+func TestRegistry_Refresh_ClearsStaleEntries(t *testing.T) {
+	t.Parallel()
+	store := &fakeStore{
+		disabled: []chainstatus.ChainStatus{
+			{ChainSelector: 300, Side: chainstatus.LaneSideSource, Disabled: true},
+		},
+	}
+	reg := newTestRegistry(t, store)
+	require.NoError(t, reg.Refresh(context.Background()))
+	assert.True(t, reg.IsDisabled(laneReport{source: 300, dest: 0}))
+
+	// Chain is re-enabled — no longer in ListDisabled results
+	store.disabled = nil
+	require.NoError(t, reg.Refresh(context.Background()))
+
+	assert.False(t, 
reg.IsDisabled(laneReport{source: 300, dest: 0}), "should be enabled after re-enable") +} + +func TestRegistry_StartPeriodicRefresh_CallsRefreshRepeatedly(t *testing.T) { + t.Parallel() + + var refreshCount atomic.Int32 + store := &countingStore{refreshCount: &refreshCount} + + reg := newTestRegistry(t, store) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + reg.StartPeriodicRefresh(ctx, 20*time.Millisecond) + + require.Eventually(t, func() bool { + return refreshCount.Load() >= 3 + }, 500*time.Millisecond, 5*time.Millisecond, "expected at least 3 periodic refreshes") + + cancel() + // Allow goroutine to exit + time.Sleep(30 * time.Millisecond) +} + +func TestRegistry_StartPeriodicRefresh_StopsOnContextCancel(t *testing.T) { + t.Parallel() + + var refreshCount atomic.Int32 + store := &countingStore{refreshCount: &refreshCount} + + reg := newTestRegistry(t, store) + + ctx, cancel := context.WithCancel(context.Background()) + reg.StartPeriodicRefresh(ctx, 10*time.Millisecond) + + require.Eventually(t, func() bool { + return refreshCount.Load() >= 2 + }, 200*time.Millisecond, 5*time.Millisecond) + + cancel() + countAtCancel := refreshCount.Load() + time.Sleep(50 * time.Millisecond) + + assert.InDelta(t, countAtCancel, refreshCount.Load(), 1, "refresh should stop shortly after context cancel") +} + +func TestNoopChecker_NeverDisables(t *testing.T) { + t.Parallel() + checker := chainstatus.NoopChecker{} + + assert.False(t, checker.IsDisabled(laneReport{source: 1, dest: 2})) + assert.False(t, checker.IsDisabled(laneReport{source: 0, dest: 0})) + assert.False(t, checker.IsDisabled(laneReport{source: ^uint64(0), dest: ^uint64(0)})) +} + +// countingStore counts ListDisabled calls for periodic refresh tests. +type countingStore struct { + refreshCount *atomic.Int32 +} + +func (c *countingStore) BatchSetStatus(_ context.Context, _ chainstatus.LaneSide, _ []uint64, _ bool) error { + return nil +} + +func (c *countingStore) List(_ context.Context) ([]chainstatus.ChainStatus, error) { + return nil, nil +} + +func (c *countingStore) ListDisabled(_ context.Context) ([]chainstatus.ChainStatus, error) { + c.refreshCount.Add(1) + return nil, nil +} + +func (c *countingStore) Get(_ context.Context, _ chainstatus.LaneSide, _ uint64) (*chainstatus.ChainStatus, error) { + return nil, nil +} diff --git a/aggregator/pkg/chainstatus/types.go b/aggregator/pkg/chainstatus/types.go new file mode 100644 index 000000000..f20e65b21 --- /dev/null +++ b/aggregator/pkg/chainstatus/types.go @@ -0,0 +1,36 @@ +package chainstatus + +import ( + "context" + "time" +) + +// LaneSide identifies which side of a lane a chain status applies to. +type LaneSide string + +const ( + LaneSideSource LaneSide = "source" + LaneSideDestination LaneSide = "destination" +) + +// ChainStatus represents the disabled state of a chain for one lane side. +// No row in the DB means enabled; a row with disabled=false is the audit trail after re-enabling. +type ChainStatus struct { + ChainSelector uint64 + Side LaneSide + Disabled bool + UpdatedAt time.Time +} + +// Store persists chain disable/enable state. +type Store interface { + // BatchSetStatus upserts disabled status for the given side and selectors. + // Pass disabled=true to disable, disabled=false to re-enable. + BatchSetStatus(ctx context.Context, side LaneSide, selectors []uint64, disabled bool) error + // List returns all chain status rows, including re-enabled ones (audit trail). 
+ List(ctx context.Context) ([]ChainStatus, error) + // ListDisabled returns only rows where disabled = true. + ListDisabled(ctx context.Context) ([]ChainStatus, error) + // Get returns the status for a specific selector + lane side. Returns nil if no row exists (= enabled). + Get(ctx context.Context, side LaneSide, selector uint64) (*ChainStatus, error) +} diff --git a/aggregator/pkg/common/metrics.go b/aggregator/pkg/common/metrics.go index 8333128ef..5e22eed2f 100644 --- a/aggregator/pkg/common/metrics.go +++ b/aggregator/pkg/common/metrics.go @@ -73,4 +73,7 @@ type AggregatorMetricLabeler interface { // IncrementGRPCErrors increments the counter for gRPC errors by status code. // code should be the gRPC status code string (e.g. "ResourceExhausted", "Internal"). IncrementGRPCErrors(ctx context.Context, code, method string) + // SetChainDisabledStatus records whether a chain is disabled (1) or enabled (0). + // Callers should attach chain_selector, chain_name, and side labels via With(). + SetChainDisabledStatus(ctx context.Context, disabled int64) } diff --git a/aggregator/pkg/configuration/file_configuration_provider.go b/aggregator/pkg/configuration/file_configuration_provider.go index a835bc10a..04818eab8 100644 --- a/aggregator/pkg/configuration/file_configuration_provider.go +++ b/aggregator/pkg/configuration/file_configuration_provider.go @@ -30,7 +30,7 @@ func LoadConfig(filePath string, lggr logger.SugaredLogger) (*model.AggregatorCo return nil, fmt.Errorf("failed to load generated config from %s: %w", generatedPath, err) } config.MergeGeneratedConfig(generated) - lggr.Infow("Merged generated config", "config", config) + lggr.Infow("Merged generated config", "path", generatedPath) } return &config, nil diff --git a/aggregator/pkg/handlers/batch_write_commit_verifier_node_result_test.go b/aggregator/pkg/handlers/batch_write_commit_verifier_node_result_test.go index 88d191a44..2dada1464 100644 --- a/aggregator/pkg/handlers/batch_write_commit_verifier_node_result_test.go +++ b/aggregator/pkg/handlers/batch_write_commit_verifier_node_result_test.go @@ -11,6 +11,7 @@ import ( "google.golang.org/grpc/status" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth" + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/chainstatus" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model" "github.com/smartcontractkit/chainlink-ccv/internal/mocks" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -85,7 +86,7 @@ func TestBatchWriteCommitCCVNodeDataHandler_BatchSizeValidation(t *testing.T) { labeler.EXPECT().With(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(labeler).Maybe() labeler.EXPECT().IncrementVerificationsTotal(mock.Anything).Maybe() - writeHandler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, sig, time.Millisecond) + writeHandler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, sig, time.Millisecond, chainstatus.NoopChecker{}) batchHandler := NewBatchWriteCommitVerifierNodeResultHandler(writeHandler, tc.maxBatchSize) requests := make([]*committeepb.WriteCommitteeVerifierNodeResultRequest, tc.numRequests) @@ -141,7 +142,7 @@ func TestBatchWriteCommitCCVNodeDataHandler_MixedSuccessAndInvalidArgument(t *te labeler.EXPECT().With(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(labeler).Maybe() labeler.EXPECT().IncrementVerificationsTotal(mock.Anything).Maybe() - writeHandler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, 
sig, time.Millisecond) + writeHandler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, sig, time.Millisecond, chainstatus.NoopChecker{}) batchHandler := NewBatchWriteCommitVerifierNodeResultHandler(writeHandler, 10) validReq := makeValidProtoRequest() @@ -192,7 +193,7 @@ func TestBatchWriteCommitCCVNodeDataHandler_NilRequestAtIndexReturnsInvalidArgum labeler.EXPECT().With(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(labeler).Maybe() labeler.EXPECT().IncrementVerificationsTotal(mock.Anything).Maybe() - writeHandler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, sig, time.Millisecond) + writeHandler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, sig, time.Millisecond, chainstatus.NoopChecker{}) batchHandler := NewBatchWriteCommitVerifierNodeResultHandler(writeHandler, 10) validReq := makeValidProtoRequest() @@ -248,7 +249,7 @@ func TestBatchWriteCommitCCVNodeDataHandler_CancelledContextReturnsImmediately(t labeler.EXPECT().With(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(labeler).Maybe() labeler.EXPECT().IncrementVerificationsTotal(mock.Anything).Maybe() - writeHandler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, sig, blockDuration) + writeHandler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, sig, blockDuration, chainstatus.NoopChecker{}) batchHandler := NewBatchWriteCommitVerifierNodeResultHandler(writeHandler, 10) ctx, cancel := context.WithCancel(auth.ToContext(context.Background(), auth.CreateCallerIdentity(testCallerID, false))) diff --git a/aggregator/pkg/handlers/write_commit_verifier_node_result.go b/aggregator/pkg/handlers/write_commit_verifier_node_result.go index 9d9216090..071cc0f3e 100644 --- a/aggregator/pkg/handlers/write_commit_verifier_node_result.go +++ b/aggregator/pkg/handlers/write_commit_verifier_node_result.go @@ -9,6 +9,7 @@ import ( "google.golang.org/grpc/status" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth" + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/chainstatus" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/common" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/scope" @@ -38,6 +39,7 @@ type WriteCommitVerifierNodeResultHandler struct { l logger.SugaredLogger signatureValidator SignatureValidator checkAggregationTimeout time.Duration + chainStatusChecker chainstatus.Checker } func (h *WriteCommitVerifierNodeResultHandler) logger(ctx context.Context) logger.SugaredLogger { @@ -70,6 +72,13 @@ func (h *WriteCommitVerifierNodeResultHandler) Handle(ctx context.Context, req * ctx = scope.WithMessageID(ctx, record.MessageID) reqLogger = h.logger(ctx) + if h.chainStatusChecker.IsDisabled(record) { + reqLogger.Infow("Rejected write: chain processing is disabled") + return &committeepb.WriteCommitteeVerifierNodeResultResponse{ + Status: committeepb.WriteStatus_FAILED, + }, status.Error(codes.FailedPrecondition, "chain processing is disabled") + } + validationResult, err := h.signatureValidator.ValidateSignature(ctx, record) if err != nil { reqLogger.Errorw("signature validation failed", "error", err) @@ -130,7 +139,7 @@ func (h *WriteCommitVerifierNodeResultHandler) Handle(ctx context.Context, req * } // NewWriteCommitCCVNodeDataHandler creates a new instance of WriteCommitCCVNodeDataHandler. 
-func NewWriteCommitCCVNodeDataHandler(store common.CommitVerificationStore, aggregator AggregationTriggerer, m common.AggregatorMonitoring, l logger.SugaredLogger, signatureValidator SignatureValidator, checkAggregationTimeout time.Duration) *WriteCommitVerifierNodeResultHandler { +func NewWriteCommitCCVNodeDataHandler(store common.CommitVerificationStore, aggregator AggregationTriggerer, m common.AggregatorMonitoring, l logger.SugaredLogger, signatureValidator SignatureValidator, checkAggregationTimeout time.Duration, chainStatusChecker chainstatus.Checker) *WriteCommitVerifierNodeResultHandler { return &WriteCommitVerifierNodeResultHandler{ storage: store, aggregator: aggregator, @@ -138,5 +147,6 @@ func NewWriteCommitCCVNodeDataHandler(store common.CommitVerificationStore, aggr l: l, signatureValidator: signatureValidator, checkAggregationTimeout: checkAggregationTimeout, + chainStatusChecker: chainStatusChecker, } } diff --git a/aggregator/pkg/handlers/write_commit_verifier_node_result_test.go b/aggregator/pkg/handlers/write_commit_verifier_node_result_test.go index c639796c3..647708713 100644 --- a/aggregator/pkg/handlers/write_commit_verifier_node_result_test.go +++ b/aggregator/pkg/handlers/write_commit_verifier_node_result_test.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc/status" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth" + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/chainstatus" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/common" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model" ccvcommon "github.com/smartcontractkit/chainlink-ccv/common" @@ -51,6 +52,44 @@ func makeValidProtoRequest() *committeepb.WriteCommitteeVerifierNodeResultReques } } +// alwaysDisabledChecker is a chainstatus.Checker that always reports the chain as disabled. 
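+// It exercises the rejection path without standing up a Registry or database.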
+type alwaysDisabledChecker struct{}
+
+func (alwaysDisabledChecker) IsDisabled(_ chainstatus.LaneReport) bool { return true }
+
+func TestWriteCommitCCVNodeDataHandler_ChainStatusGate(t *testing.T) {
+	t.Parallel()
+
+	const testCallerID = "test-caller"
+
+	lggr := logger.TestSugared(t)
+	store := mocks.NewMockCommitVerificationStore(t)
+	agg := mocks.NewMockAggregationTriggerer(t)
+	sig := mocks.NewMockSignatureValidator(t)
+	mon := mocks.NewMockAggregatorMonitoring(t)
+
+	// None of these should be called when the chain is disabled
+	store.EXPECT().SaveCommitVerification(mock.Anything, mock.Anything, mock.Anything).Maybe()
+	agg.EXPECT().CheckAggregation(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe()
+	sig.EXPECT().ValidateSignature(mock.Anything, mock.Anything).Maybe()
+	sig.EXPECT().DeriveAggregationKey(mock.Anything, mock.Anything).Maybe()
+	mon.EXPECT().Metrics().Maybe()
+
+	handler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, sig, time.Millisecond, alwaysDisabledChecker{})
+	ctx := auth.ToContext(context.Background(), auth.CreateCallerIdentity(testCallerID, false))
+
+	resp, err := handler.Handle(ctx, makeValidProtoRequest())
+
+	require.Error(t, err)
+	require.Equal(t, codes.FailedPrecondition, status.Code(err))
+	require.NotNil(t, resp)
+	require.Equal(t, committeepb.WriteStatus_FAILED, resp.Status)
+
+	// Verify neither storage nor aggregation was touched
+	store.AssertNotCalled(t, "SaveCommitVerification")
+	agg.AssertNotCalled(t, "CheckAggregation")
+}
+
 func TestWriteCommitCCVNodeDataHandler_Handle_Table(t *testing.T) {
 	t.Parallel()
 
@@ -189,7 +228,7 @@
 			labeler.EXPECT().With(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(labeler).Maybe()
 			labeler.EXPECT().IncrementVerificationsTotal(mock.Anything).Maybe()
 
-			handler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, sig, time.Millisecond)
+			handler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, sig, time.Millisecond, chainstatus.NoopChecker{})
 
 			ctx := auth.ToContext(context.Background(), auth.CreateCallerIdentity(testCallerID, false))
 			resp, err := handler.Handle(ctx, tc.req)
diff --git a/aggregator/pkg/model/commit_verification_record.go b/aggregator/pkg/model/commit_verification_record.go
index 2c12a78a3..c52289974 100644
--- a/aggregator/pkg/model/commit_verification_record.go
+++ b/aggregator/pkg/model/commit_verification_record.go
@@ -75,3 +75,21 @@ func (c *CommitVerificationRecord) SetTimestampFromMillis(timestampMillis int64)
 func (c *CommitVerificationRecord) GetTimestamp() time.Time {
 	return c.createdAt
 }
+
+// GetSourceChainSelector returns the source chain selector from the message.
+// Satisfies chainstatus.LaneReport.
+func (c *CommitVerificationRecord) GetSourceChainSelector() uint64 {
+	if c.Message == nil {
+		return 0
+	}
+	return uint64(c.Message.SourceChainSelector)
+}
+
+// GetDestinationSelector returns the destination chain selector from the message.
+// Satisfies chainstatus.LaneReport.
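+// Returns 0 when Message is nil, mirroring GetSourceChainSelector above.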
+func (c *CommitVerificationRecord) GetDestinationSelector() uint64 { + if c.Message == nil { + return 0 + } + return uint64(c.Message.DestChainSelector) +} diff --git a/aggregator/pkg/model/config.go b/aggregator/pkg/model/config.go index 732632d49..a3c099bc9 100644 --- a/aggregator/pkg/model/config.go +++ b/aggregator/pkg/model/config.go @@ -157,6 +157,12 @@ type AggregationConfig struct { MaxConsecutiveErrors uint32 `toml:"maxConsecutiveErrors"` } +// ChainStatusConfig controls the chain-disable registry refresh behavior. +type ChainStatusConfig struct { + // RefreshInterval controls how often the in-memory registry is refreshed from the database. + RefreshInterval time.Duration `toml:"refreshInterval"` +} + type OrphanRecoveryConfig struct { // Enabled controls whether orphan recovery is enabled Enabled bool `toml:"enabled"` @@ -400,6 +406,7 @@ type AggregatorConfig struct { Storage *StorageConfig `toml:"storage"` APIClients []*ClientConfig `toml:"clients"` Aggregation AggregationConfig `toml:"aggregation"` + ChainStatus ChainStatusConfig `toml:"chainStatus"` OrphanRecovery OrphanRecoveryConfig `toml:"orphanRecovery"` RateLimiting RateLimitingConfig `toml:"rateLimiting"` HealthCheck HealthCheckConfig `toml:"healthCheck"` @@ -560,6 +567,11 @@ func (c *AggregatorConfig) SetDefaults() { c.Storage.QueryTimeout = 10 * time.Second } + // Default chain-disable registry refresh: 30 seconds + if c.ChainStatus.RefreshInterval == 0 { + c.ChainStatus.RefreshInterval = 30 * time.Second + } + // Default orphan recovery: enabled with 5 minute interval if c.OrphanRecovery.Interval == 0 { c.OrphanRecovery.Interval = 5 * time.Minute diff --git a/aggregator/pkg/monitoring/metrics.go b/aggregator/pkg/monitoring/metrics.go index 5898f3a7a..0a3a4b290 100644 --- a/aggregator/pkg/monitoring/metrics.go +++ b/aggregator/pkg/monitoring/metrics.go @@ -3,12 +3,15 @@ package monitoring import ( "context" "fmt" + "strconv" "time" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/chainstatus" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/common" + "github.com/smartcontractkit/chainlink-ccv/protocol" "github.com/smartcontractkit/chainlink-common/pkg/beholder" "github.com/smartcontractkit/chainlink-common/pkg/metrics" @@ -54,6 +57,9 @@ type AggregatorMetrics struct { // gRPC transport metrics grpcPayloadSizeBytes metric.Int64Histogram grpcErrorsTotal metric.Int64Counter + + // Chain status metrics + chainDisabledStatus metric.Int64Gauge } // grpcPayloadSizeBuckets defines histogram buckets for gRPC payload sizes in bytes. 
@@ -292,6 +298,15 @@ func InitMetrics() (am *AggregatorMetrics, err error) { return nil, fmt.Errorf("failed to register grpc errors total counter: %w", err) } + am.chainDisabledStatus, err = beholder.GetMeter().Int64Gauge( + "aggregator_chain_disabled_status", + metric.WithDescription("Whether a chain is disabled (1) or enabled (0) for a given lane side"), + metric.WithUnit("1"), + ) + if err != nil { + return nil, fmt.Errorf("failed to register chain disabled status gauge: %w", err) + } + return am, nil } @@ -443,3 +458,31 @@ func (c *AggregatorMetricLabeler) IncrementGRPCErrors(ctx context.Context, code, attribute.String("method", method), }...), metric.WithAttributes(otelLabels...)) } + +func (c *AggregatorMetricLabeler) SetChainDisabledStatus(ctx context.Context, disabled int64) { + otelLabels := beholder.OtelAttributes(c.Labels).AsStringAttributes() + c.am.chainDisabledStatus.Record(ctx, disabled, metric.WithAttributes(otelLabels...)) +} + +// ChainStatusMetrics implements chainstatus.StatusMetrics using the aggregator gauge. +type ChainStatusMetrics struct { + m common.AggregatorMonitoring +} + +// NewChainStatusMetrics creates a ChainStatusMetrics that emits the aggregator_chain_disabled_status +// gauge. It is passed to chainstatus.WithStatusMetrics when constructing the Registry. +func NewChainStatusMetrics(m common.AggregatorMonitoring) *ChainStatusMetrics { + return &ChainStatusMetrics{m: m} +} + +func (c *ChainStatusMetrics) SetChainDisabledStatus(ctx context.Context, selector uint64, side chainstatus.LaneSide, disabled bool) { + val := int64(0) + if disabled { + val = 1 + } + c.m.Metrics().With( + "chain_selector", strconv.FormatUint(selector, 10), + "chain_name", protocol.ChainSelector(selector).ChainName(), + "side", string(side), + ).SetChainDisabledStatus(ctx, val) +} diff --git a/aggregator/pkg/monitoring/noop.go b/aggregator/pkg/monitoring/noop.go index c6bda0695..094e5ad9d 100644 --- a/aggregator/pkg/monitoring/noop.go +++ b/aggregator/pkg/monitoring/noop.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/chainstatus" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/common" commonmetrics "github.com/smartcontractkit/chainlink-ccv/common/metrics" ) @@ -136,3 +137,16 @@ func (c *NoopAggregatorMetricLabeler) RecordGRPCPayloadSize(_ context.Context, _ func (c *NoopAggregatorMetricLabeler) IncrementGRPCErrors(_ context.Context, _, _ string) { // No-op } + +func (c *NoopAggregatorMetricLabeler) SetChainDisabledStatus(_ context.Context, _ int64) { + // No-op +} + +type NoopChainStatusMetrics struct{} + +func NewNoopChainStatusMetrics() *NoopChainStatusMetrics { + return &NoopChainStatusMetrics{} +} + +func (c *NoopChainStatusMetrics) SetChainDisabledStatus(_ context.Context, _ uint64, _ chainstatus.LaneSide, _ bool) { +} diff --git a/aggregator/pkg/server.go b/aggregator/pkg/server.go index 374d8ab47..2a74bb69a 100644 --- a/aggregator/pkg/server.go +++ b/aggregator/pkg/server.go @@ -10,6 +10,7 @@ import ( "os" "os/signal" "runtime/debug" + "strconv" "sync" "syscall" "time" @@ -24,12 +25,14 @@ import ( "google.golang.org/grpc/status" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/aggregation" + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/chainstatus" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/common" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/handlers" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/health" 
"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/heartbeat" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/middlewares" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model" + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/monitoring" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/quorum" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/storage" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -55,6 +58,7 @@ type Server struct { l logger.SugaredLogger config *model.AggregatorConfig store common.CommitVerificationStore + chainStatusRegistry *chainstatus.Registry aggregator *aggregation.CommitReportAggregator recoverer *OrphanRecoverer readCommitVerifierNodeResultHandler *handlers.ReadCommitVerifierNodeResultHandler @@ -164,6 +168,16 @@ func (s *Server) Start(lis net.Listener) error { aggregatorCancel() }) + // Periodically refresh the chain-disable registry from the database. + chainStatusCtx, chainStatusCancel := context.WithCancel(context.Background()) + g.Add(func() error { + s.chainStatusRegistry.StartPeriodicRefresh(chainStatusCtx, s.config.ChainStatus.RefreshInterval) + <-chainStatusCtx.Done() + return nil + }, func(error) { + chainStatusCancel() + }) + if s.config.OrphanRecovery.Enabled && s.recoverer != nil { recovererCtx, recovererCancel := context.WithCancel(context.Background()) g.Add(func() error { @@ -300,18 +314,50 @@ func NewServer(l logger.SugaredLogger, config *model.AggregatorConfig, aggMonito ) factory := storage.NewStorageFactory(l) - store, err := factory.CreateStorage(config.Storage, aggMonitoring) + rawStore, err := factory.CreateStorage(config.Storage, aggMonitoring) if err != nil { l.Fatalf("Failed to create storage: %v", err) return nil } - store = storage.WrapWithMetrics(store, aggMonitoring, l) + // Build the chain-disable registry from the raw store before metrics wrapping. + // DatabaseStorage implements chainstatus.Store; the metrics wrapper does not need to. + chainStatusStore, ok := rawStore.(chainstatus.Store) + if !ok { + l.Fatalf("Storage does not implement chainstatus.Store") + return nil + } + chainStatusOpts := []chainstatus.RegistryOption{} + if config.Committee != nil { + sourceSels := make([]uint64, 0, len(config.Committee.QuorumConfigs)) + for selStr := range config.Committee.QuorumConfigs { + if sel, err := strconv.ParseUint(selStr, 10, 64); err == nil { + sourceSels = append(sourceSels, sel) + } + } + destSels := make([]uint64, 0, len(config.Committee.DestinationVerifiers)) + for selStr := range config.Committee.DestinationVerifiers { + if sel, err := strconv.ParseUint(selStr, 10, 64); err == nil { + destSels = append(destSels, sel) + } + } + chainStatusOpts = append(chainStatusOpts, chainstatus.WithStatusMetrics( + monitoring.NewChainStatusMetrics(aggMonitoring), + sourceSels, + destSels, + )) + } + chainStatusRegistry := chainstatus.NewRegistry(chainStatusStore, l, chainStatusOpts...) 
+ if err := chainStatusRegistry.Refresh(context.Background()); err != nil { + l.Warnw("Failed initial chain-disable registry refresh", "error", err) + } + + store := storage.WrapWithMetrics(rawStore, aggMonitoring, l) validator := quorum.NewQuorumValidator(config, l) agg := createAggregator(store, store, store, validator, config, l, aggMonitoring) - writeCommitVerifierNodeResultHandler := handlers.NewWriteCommitCCVNodeDataHandler(store, agg, aggMonitoring, l, validator, config.Aggregation.CheckAggregationTimeout) + writeCommitVerifierNodeResultHandler := handlers.NewWriteCommitCCVNodeDataHandler(store, agg, aggMonitoring, l, validator, config.Aggregation.CheckAggregationTimeout, chainStatusRegistry) readCommitVerifierNodeResultHandler := handlers.NewReadCommitVerifierNodeResultHandler(store, l) getMessagesSinceHandler := handlers.NewGetMessagesSinceHandler(store, config.Committee, l, aggMonitoring) getVerifierResultsForMessageHandler := handlers.NewGetVerifierResultsForMessageHandler(store, config.Committee, config.MaxMessageIDsPerBatch, l) @@ -410,6 +456,7 @@ func NewServer(l logger.SugaredLogger, config *model.AggregatorConfig, aggMonito l: l, config: config, store: store, + chainStatusRegistry: chainStatusRegistry, aggregator: agg, readCommitVerifierNodeResultHandler: readCommitVerifierNodeResultHandler, writeCommitVerifierNodeResultHandler: writeCommitVerifierNodeResultHandler, diff --git a/aggregator/pkg/storage/postgres/database_storage_chain_statuses.go b/aggregator/pkg/storage/postgres/database_storage_chain_statuses.go new file mode 100644 index 000000000..2a9dc0f22 --- /dev/null +++ b/aggregator/pkg/storage/postgres/database_storage_chain_statuses.go @@ -0,0 +1,109 @@ +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/chainstatus" +) + +var _ chainstatus.Store = (*DatabaseStorage)(nil) + +type chainStatusRow struct { + ChainSelector uint64 `db:"chain_selector"` + LaneSide string `db:"lane_side"` + Disabled bool `db:"disabled"` + UpdatedAt time.Time `db:"updated_at"` +} + +func rowToChainStatus(r chainStatusRow) chainstatus.ChainStatus { + return chainstatus.ChainStatus{ + ChainSelector: r.ChainSelector, + Side: chainstatus.LaneSide(r.LaneSide), + Disabled: r.Disabled, + UpdatedAt: r.UpdatedAt, + } +} + +// BatchSetStatus upserts the disabled flag for the given lane side and selectors. +func (d *DatabaseStorage) BatchSetStatus(ctx context.Context, side chainstatus.LaneSide, selectors []uint64, disabled bool) error { + ctx, cancel := d.withTimeout(ctx) + defer cancel() + + stmt := `INSERT INTO aggregator_chain_statuses (chain_selector, lane_side, disabled) + VALUES ($1, $2, $3) + ON CONFLICT (chain_selector, lane_side) DO UPDATE SET disabled = $3, updated_at = NOW()` + + for _, sel := range selectors { + if _, err := d.ds.ExecContext(ctx, stmt, sel, string(side), disabled); err != nil { + return fmt.Errorf("failed to set status disabled=%v for %s selector %d: %w", disabled, side, sel, err) + } + } + return nil +} + +// List returns all chain status rows (including re-enabled ones for audit trail). 
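+// Rows come back ordered by lane side, then chain selector, so output is
+// stable across calls.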
+func (d *DatabaseStorage) List(ctx context.Context) ([]chainstatus.ChainStatus, error) { + ctx, cancel := d.withTimeout(ctx) + defer cancel() + + stmt := `SELECT chain_selector, lane_side, disabled, updated_at + FROM aggregator_chain_statuses + ORDER BY lane_side, chain_selector` + + var rows []chainStatusRow + if err := d.ds.SelectContext(ctx, &rows, stmt); err != nil { + return nil, fmt.Errorf("failed to list chain statuses: %w", err) + } + + statuses := make([]chainstatus.ChainStatus, len(rows)) + for i, r := range rows { + statuses[i] = rowToChainStatus(r) + } + return statuses, nil +} + +// ListDisabled returns only rows where disabled = true. +func (d *DatabaseStorage) ListDisabled(ctx context.Context) ([]chainstatus.ChainStatus, error) { + ctx, cancel := d.withTimeout(ctx) + defer cancel() + + stmt := `SELECT chain_selector, lane_side, disabled, updated_at + FROM aggregator_chain_statuses + WHERE disabled = true + ORDER BY lane_side, chain_selector` + + var rows []chainStatusRow + if err := d.ds.SelectContext(ctx, &rows, stmt); err != nil { + return nil, fmt.Errorf("failed to list disabled chain statuses: %w", err) + } + + statuses := make([]chainstatus.ChainStatus, len(rows)) + for i, r := range rows { + statuses[i] = rowToChainStatus(r) + } + return statuses, nil +} + +// Get returns the status for a specific selector + lane side. Returns nil if no row exists (= enabled). +func (d *DatabaseStorage) Get(ctx context.Context, side chainstatus.LaneSide, selector uint64) (*chainstatus.ChainStatus, error) { + ctx, cancel := d.withTimeout(ctx) + defer cancel() + + stmt := `SELECT chain_selector, lane_side, disabled, updated_at + FROM aggregator_chain_statuses + WHERE chain_selector = $1 AND lane_side = $2` + + var row chainStatusRow + if err := d.ds.GetContext(ctx, &row, stmt, selector, string(side)); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("failed to get chain status for %s selector %d: %w", side, selector, err) + } + status := rowToChainStatus(row) + return &status, nil +} diff --git a/aggregator/pkg/storage/postgres/database_storage_chain_statuses_test.go b/aggregator/pkg/storage/postgres/database_storage_chain_statuses_test.go new file mode 100644 index 000000000..b3f60d0d0 --- /dev/null +++ b/aggregator/pkg/storage/postgres/database_storage_chain_statuses_test.go @@ -0,0 +1,242 @@ +package postgres + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/chainstatus" + "github.com/smartcontractkit/chainlink-ccv/aggregator/testutil" + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +func setupChainStatusTestDB(t *testing.T) (*DatabaseStorage, func()) { + t.Helper() + ds, cleanup := testutil.SetupTestPostgresDB(t) + if err := RunMigrations(ds, "postgres"); err != nil { + cleanup() + t.Fatalf("run migrations: %v", err) + } + return NewDatabaseStorage(ds, 10, 10*time.Second, logger.Sugared(logger.Test(t))), cleanup +} + +func TestDatabaseStorage_BatchSetStatus_Disable(t *testing.T) { + t.Parallel() + storage, cleanup := setupChainStatusTestDB(t) + defer cleanup() + ctx := context.Background() + + err := storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{1001, 1002}, true) + require.NoError(t, err) + + s1, err := storage.Get(ctx, chainstatus.LaneSideSource, 1001) + require.NoError(t, err) + require.NotNil(t, s1) + assert.True(t, s1.Disabled) + assert.Equal(t, 
chainstatus.LaneSideSource, s1.Side) + assert.Equal(t, uint64(1001), s1.ChainSelector) + assert.False(t, s1.UpdatedAt.IsZero()) + + s2, err := storage.Get(ctx, chainstatus.LaneSideSource, 1002) + require.NoError(t, err) + require.NotNil(t, s2) + assert.True(t, s2.Disabled) +} + +func TestDatabaseStorage_BatchSetStatus_Enable(t *testing.T) { + t.Parallel() + storage, cleanup := setupChainStatusTestDB(t) + defer cleanup() + ctx := context.Background() + + require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{2001}, true)) + require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{2001}, false)) + + s, err := storage.Get(ctx, chainstatus.LaneSideSource, 2001) + require.NoError(t, err) + require.NotNil(t, s, "row should exist for audit trail after re-enable") + assert.False(t, s.Disabled) +} + +func TestDatabaseStorage_BatchSetStatus_Idempotent(t *testing.T) { + t.Parallel() + storage, cleanup := setupChainStatusTestDB(t) + defer cleanup() + ctx := context.Background() + + require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{3001}, true)) + require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{3001}, true)) + + s, err := storage.Get(ctx, chainstatus.LaneSideSource, 3001) + require.NoError(t, err) + require.NotNil(t, s) + assert.True(t, s.Disabled) +} + +func TestDatabaseStorage_BatchSetStatus_EmptySelectors(t *testing.T) { + t.Parallel() + storage, cleanup := setupChainStatusTestDB(t) + defer cleanup() + ctx := context.Background() + + err := storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{}, true) + require.NoError(t, err) +} + +func TestDatabaseStorage_BatchSetStatus_SourceAndDestinationAreIndependent(t *testing.T) { + t.Parallel() + storage, cleanup := setupChainStatusTestDB(t) + defer cleanup() + ctx := context.Background() + + require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{4001}, true)) + require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideDestination, []uint64{4001}, true)) + + src, err := storage.Get(ctx, chainstatus.LaneSideSource, 4001) + require.NoError(t, err) + require.NotNil(t, src) + assert.True(t, src.Disabled) + assert.Equal(t, chainstatus.LaneSideSource, src.Side) + + dst, err := storage.Get(ctx, chainstatus.LaneSideDestination, 4001) + require.NoError(t, err) + require.NotNil(t, dst) + assert.True(t, dst.Disabled) + assert.Equal(t, chainstatus.LaneSideDestination, dst.Side) +} + +func TestDatabaseStorage_Get_NoRow_ReturnsNil(t *testing.T) { + t.Parallel() + storage, cleanup := setupChainStatusTestDB(t) + defer cleanup() + ctx := context.Background() + + s, err := storage.Get(ctx, chainstatus.LaneSideSource, 9999) + require.NoError(t, err) + assert.Nil(t, s, "no row should return nil (= enabled by default)") +} + +func TestDatabaseStorage_Get_WrongSide_ReturnsNil(t *testing.T) { + t.Parallel() + storage, cleanup := setupChainStatusTestDB(t) + defer cleanup() + ctx := context.Background() + + require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{5001}, true)) + + s, err := storage.Get(ctx, chainstatus.LaneSideDestination, 5001) + require.NoError(t, err) + assert.Nil(t, s) +} + +func TestDatabaseStorage_List_Empty(t *testing.T) { + t.Parallel() + storage, cleanup := setupChainStatusTestDB(t) + defer cleanup() + ctx := context.Background() + + statuses, err := storage.List(ctx) + require.NoError(t, err) + assert.Empty(t, statuses) +} + +func 
TestDatabaseStorage_List_ReturnsAllRows(t *testing.T) {
+	t.Parallel()
+	storage, cleanup := setupChainStatusTestDB(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{6001}, true))
+	require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideDestination, []uint64{6002}, true))
+	// Re-enable one — should still appear in List
+	require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{6001}, false))
+
+	statuses, err := storage.List(ctx)
+	require.NoError(t, err)
+	require.Len(t, statuses, 2, "List should return all rows including re-enabled ones")
+
+	// Verify the re-enabled chain appears with Disabled=false and the
+	// still-disabled one with Disabled=true.
+	found6001 := false
+	found6002 := false
+	for _, s := range statuses {
+		if s.ChainSelector == 6001 && s.Side == chainstatus.LaneSideSource {
+			assert.False(t, s.Disabled)
+			found6001 = true
+		}
+		if s.ChainSelector == 6002 && s.Side == chainstatus.LaneSideDestination {
+			assert.True(t, s.Disabled)
+			found6002 = true
+		}
+	}
+	assert.True(t, found6001)
+	assert.True(t, found6002)
+}
+
+func TestDatabaseStorage_ListDisabled_OnlyDisabled(t *testing.T) {
+	t.Parallel()
+	storage, cleanup := setupChainStatusTestDB(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{7001, 7002}, true))
+	require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideDestination, []uint64{7003}, true))
+	// Re-enable 7002 — should not appear in ListDisabled
+	require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{7002}, false))
+
+	statuses, err := storage.ListDisabled(ctx)
+	require.NoError(t, err)
+
+	for _, s := range statuses {
+		assert.True(t, s.Disabled, "ListDisabled must only return disabled rows")
+		assert.NotEqual(t, uint64(7002), s.ChainSelector, "re-enabled chain should not appear")
+	}
+
+	selectors := make(map[uint64]bool)
+	for _, s := range statuses {
+		selectors[s.ChainSelector] = true
+	}
+	assert.True(t, selectors[7001])
+	assert.True(t, selectors[7003])
+	assert.Len(t, statuses, 2)
+}
+
+func TestDatabaseStorage_ListDisabled_Empty(t *testing.T) {
+	t.Parallel()
+	storage, cleanup := setupChainStatusTestDB(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	statuses, err := storage.ListDisabled(ctx)
+	require.NoError(t, err)
+	assert.Empty(t, statuses)
+}
+
+func TestDatabaseStorage_ChainStatus_UpdatedAt_ChangesOnUpdate(t *testing.T) {
+	t.Parallel()
+	storage, cleanup := setupChainStatusTestDB(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{8001}, true))
+	s1, err := storage.Get(ctx, chainstatus.LaneSideSource, 8001)
+	require.NoError(t, err)
+	require.NotNil(t, s1)
+
+	// Small sleep to ensure updated_at changes
+	time.Sleep(10 * time.Millisecond)
+	require.NoError(t, storage.BatchSetStatus(ctx, chainstatus.LaneSideSource, []uint64{8001}, false))
+	s2, err := storage.Get(ctx, chainstatus.LaneSideSource, 8001)
+	require.NoError(t, err)
+	require.NotNil(t, s2)
+
+	assert.True(t, s2.UpdatedAt.After(s1.UpdatedAt) || s2.UpdatedAt.Equal(s1.UpdatedAt),
+		"updated_at should not go backwards")
+}
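-- Editor's sketch (not part of this diff): the table shape implied by the
-- queries above. The migration itself is not shown, so the column types and
-- constraint are inferred; ON CONFLICT (chain_selector, lane_side) requires a
-- unique or primary key over that pair, and chain selectors exceed the signed
-- 64-bit range, hence NUMERIC rather than BIGINT here.
CREATE TABLE aggregator_chain_statuses (
    chain_selector NUMERIC(20, 0) NOT NULL,
    lane_side      TEXT           NOT NULL,
    disabled       BOOLEAN        NOT NULL,
    updated_at     TIMESTAMPTZ    NOT NULL DEFAULT NOW(),
    PRIMARY KEY (chain_selector, lane_side)
);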
diff --git a/build/devenv/services/aggregator.go b/build/devenv/services/aggregator.go
index a14525e5a..7393ad0eb 100644
--- a/build/devenv/services/aggregator.go
+++ b/build/devenv/services/aggregator.go
@@ -126,8 +126,13 @@ type AggregatorInput struct {
 }
 
 type AggregatorOutput struct {
-	UseCache      bool   `toml:"use_cache"`
-	ContainerName string `toml:"container_name"`
+	UseCache bool `toml:"use_cache"`
+	// AggregatorContainerName is the container running the aggregator binary.
+	// Use this for docker exec CLI invocations (e.g. `aggregator chains disable`).
+	AggregatorContainerName string `toml:"aggregator_container_name"`
+	// NginxContainerName is the nginx TLS proxy container fronting the aggregator.
+	// Use this for chaos/HA tests that kill the proxy or for connectivity checks.
+	NginxContainerName string `toml:"nginx_container_name"`
 	Address          string `toml:"address"`
 	ExternalHTTPUrl  string `toml:"external_http_url"`
 	ExternalHTTPSUrl string `toml:"external_https_url"`
@@ -581,13 +586,14 @@ func NewAggregator(in *AggregatorInput) (*AggregatorOutput, error) {
 	}
 
 	in.Out = &AggregatorOutput{
-		ContainerName:      nginxContainerName,
-		Address:            fmt.Sprintf("%s:443", nginxContainerName),
-		ExternalHTTPUrl:    fmt.Sprintf("%s:%d", aggregatorContainerName, DefaultAggregatorGRPCPort),
-		ExternalHTTPSUrl:   fmt.Sprintf("%s:%d", host, in.HostPort),
-		TLSCACertFile:      tlsCerts.CACertFile,
-		ClientCredentials:  clientCredentials,
-		GeneratedCommittee: in.GeneratedCommittee,
+		AggregatorContainerName: aggregatorContainerName,
+		NginxContainerName:      nginxContainerName,
+		Address:                 fmt.Sprintf("%s:443", nginxContainerName),
+		ExternalHTTPUrl:         fmt.Sprintf("%s:%d", aggregatorContainerName, DefaultAggregatorGRPCPort),
+		ExternalHTTPSUrl:        fmt.Sprintf("%s:%d", host, in.HostPort),
+		TLSCACertFile:           tlsCerts.CACertFile,
+		ClientCredentials:       clientCredentials,
+		GeneratedCommittee:      in.GeneratedCommittee,
 	}
 	return in.Out, nil
 }
diff --git a/build/devenv/services/aggregator.template.toml b/build/devenv/services/aggregator.template.toml
index 112352dfe..aa3a8cc26 100644
--- a/build/devenv/services/aggregator.template.toml
+++ b/build/devenv/services/aggregator.template.toml
@@ -22,6 +22,9 @@ trustedProxies = [
   "192.168.0.0/16",
 ]
 
+[chainStatus]
+refreshInterval = "2s"
+
 [orphanRecovery]
 enabled = false
 interval = "5m"
diff --git a/build/devenv/tests/e2e/aggregatorcli/chains.go b/build/devenv/tests/e2e/aggregatorcli/chains.go
new file mode 100644
index 000000000..4e4d9082c
--- /dev/null
+++ b/build/devenv/tests/e2e/aggregatorcli/chains.go
@@ -0,0 +1,48 @@
+package aggregatorcli
+
+import (
+	"context"
+	"strconv"
+)
+
+// ChainsSubcommand is the CLI path used to reach the chains commands:
+// `aggregator chains ...`.
+var ChainsSubcommand = []string{"chains"}
+
+// ChainSelector is a decimal-encoded chain selector as expected by the
+// CLI's --source / --destination flags.
+type ChainSelector string
+
+// FormatChainSelector renders sel for use with the CLI.
+func FormatChainSelector(sel uint64) ChainSelector {
+	return ChainSelector(strconv.FormatUint(sel, 10))
+}
+
+// ChainsClient is the thin wrapper around the aggregator chains CLI group.
+// Obtain via (*Client).Chains().
+type ChainsClient struct {
+	client *Client
+}
+
+// Chains returns a sub-client for the chains CLI. The returned value is a
+// tiny struct; constructing one is free.
+func (c *Client) Chains() ChainsClient {
+	return ChainsClient{client: c}
+}
+
+// List runs `chains list` and returns the raw table output.
+func (cs ChainsClient) List(ctx context.Context) (string, error) {
+	return cs.client.CLI(ctx, ChainsSubcommand, "list")
+}
+
+// Disable runs `chains disable <flags>`. Pass flag pairs such as
+// "--source", "12345" or "--all".
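+//
+// A usage sketch (the selector value is illustrative):
+//
+//	out, err := client.Chains().Disable(ctx, "--source", "12345")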
+func (cs ChainsClient) Disable(ctx context.Context, args ...string) (string, error) {
+	return cs.client.CLI(ctx, ChainsSubcommand, append([]string{"disable"}, args...)...)
+}
+
+// Enable runs `chains enable <flags>`. Pass flag pairs such as
+// "--source", "12345" or "--all".
+func (cs ChainsClient) Enable(ctx context.Context, args ...string) (string, error) {
+	return cs.client.CLI(ctx, ChainsSubcommand, append([]string{"enable"}, args...)...)
+}
diff --git a/build/devenv/tests/e2e/aggregatorcli/client.go b/build/devenv/tests/e2e/aggregatorcli/client.go
new file mode 100644
index 000000000..043307b1e
--- /dev/null
+++ b/build/devenv/tests/e2e/aggregatorcli/client.go
@@ -0,0 +1,73 @@
+// Package aggregatorcli is a test-only client for the aggregator CLI exposed
+// by the aggregator binary inside a running container. It wraps raw `docker
+// exec` invocations and splits the CLI surface into sub-clients (chains) so
+// individual tests can ask for just the capability they need.
+//
+// All methods are synchronous and return the raw stdout + stderr as a single
+// string. The aggregator CLI writes directly to the database and the
+// in-memory registry refreshes periodically, so Pause/Resume/Restart are not
+// needed — just call the CLI and wait for the next refresh.
+package aggregatorcli
+
+import (
+	"context"
+	"fmt"
+	"os/exec"
+	"strings"
+)
+
+const (
+	// DefaultBinaryPath is the in-container path of the aggregator binary.
+	// In devenv the binary is built by air into /tmp/aggregator (see aggregator/air.toml).
+	DefaultBinaryPath = "/tmp/aggregator"
+)
+
+// Client talks to a single aggregator container. It is cheap to construct
+// and safe to share across subtests that target the same container.
+type Client struct {
+	containerName string
+	binaryPath    string
+}
+
+// Option configures a Client.
+type Option func(*Client)
+
+// WithBinaryPath overrides the in-container path of the aggregator binary.
+func WithBinaryPath(path string) Option {
+	return func(c *Client) { c.binaryPath = path }
+}
+
+// NewClient returns a Client bound to containerName. Any leading slash is
+// stripped so callers can pass the name through unchanged.
+func NewClient(containerName string, opts ...Option) *Client {
+	c := &Client{
+		containerName: strings.TrimPrefix(containerName, "/"),
+		binaryPath:    DefaultBinaryPath,
+	}
+	for _, opt := range opts {
+		opt(c)
+	}
+	return c
+}
+
+// Container returns the container name this client is bound to.
+func (c *Client) Container() string { return c.containerName }
+
+// Exec runs `docker exec <container> <args...>` and returns combined output.
+func (c *Client) Exec(ctx context.Context, args ...string) (string, error) {
+	full := append([]string{"exec", c.containerName}, args...)
+	cmd := exec.CommandContext(ctx, "docker", full...)
+	out, err := cmd.CombinedOutput()
+	if err != nil {
+		return string(out), fmt.Errorf("docker exec %s %v: %w (output: %s)", c.containerName, args, err, string(out))
+	}
+	return string(out), nil
+}
+
+// CLI runs the aggregator CLI subcommand tree. Prefer the sub-clients —
+// Chains — which compose these for you.
+func (c *Client) CLI(ctx context.Context, subcommand []string, args ...string) (string, error) {
+	full := append([]string{c.binaryPath}, subcommand...)
+	full = append(full, args...)
+	return c.Exec(ctx, full...)
+} diff --git a/build/devenv/tests/e2e/chaos_test.go b/build/devenv/tests/e2e/chaos_test.go index 43f069819..2b5e43d89 100644 --- a/build/devenv/tests/e2e/chaos_test.go +++ b/build/devenv/tests/e2e/chaos_test.go @@ -34,7 +34,7 @@ func TestChaos_AggregatorOutageRecovery(t *testing.T) { var defaultAggregatorContainerName string for _, agg := range setup.in.Aggregator { if agg.CommitteeName == devenvcommon.DefaultCommitteeVerifierQualifier { - defaultAggregatorContainerName = agg.Out.ContainerName + defaultAggregatorContainerName = agg.Out.NginxContainerName break } } diff --git a/build/devenv/tests/e2e/ha_test.go b/build/devenv/tests/e2e/ha_test.go index c9c6e3101..ebac12feb 100644 --- a/build/devenv/tests/e2e/ha_test.go +++ b/build/devenv/tests/e2e/ha_test.go @@ -78,10 +78,10 @@ func (s *haTestSetup) survivingAggClient(committee string, killedContainers ...s if agg.CommitteeName != committee || agg.Out == nil { continue } - if killed[agg.Out.ContainerName] { + if killed[agg.Out.NginxContainerName] { continue } - if client, ok := s.aggClients[agg.Out.ContainerName]; ok { + if client, ok := s.aggClients[agg.Out.NginxContainerName]; ok { return client } } @@ -249,7 +249,7 @@ func TestHA_SingleAggregatorDown(t *testing.T) { committeeAggs := setup.aggsByCommittee(haCommittee) require.Len(t, committeeAggs, 2, "need 2 aggregators in %q for this test", haCommittee) - killedAgg := committeeAggs[0].Out.ContainerName + killedAgg := committeeAggs[0].Out.NginxContainerName require.NotEmpty(t, killedAgg) // Phase 1: Kill one aggregator, send a message, assert it flows via the survivor. @@ -298,7 +298,7 @@ func TestHA_CrossComponentDown(t *testing.T) { committeeAggs := setup.aggsByCommittee(haCommittee) require.Len(t, committeeAggs, 2, "need 2 aggregators in %q for this test", haCommittee) - killedAgg := committeeAggs[0].Out.ContainerName + killedAgg := committeeAggs[0].Out.NginxContainerName killedIdx := setup.in.Indexer[0].Out.ContainerName require.NotEmpty(t, killedAgg) require.NotEmpty(t, killedIdx) @@ -349,12 +349,12 @@ func setupHATest(t *testing.T) *haTestSetup { "aggregator output is nil — was the environment started?") client, err := ccv.NewAggregatorClient( l.With().Str("component", - fmt.Sprintf("agg-client-%s", agg.Out.ContainerName)).Logger(), + fmt.Sprintf("agg-client-%s", agg.Out.NginxContainerName)).Logger(), agg.Out.ExternalHTTPSUrl, agg.Out.TLSCACertFile, ) require.NoError(t, err) - aggClients[agg.Out.ContainerName] = client + aggClients[agg.Out.NginxContainerName] = client t.Cleanup(func() { client.Close() }) } diff --git a/build/devenv/tests/e2e/smoke_aggregator_chain_disable_test.go b/build/devenv/tests/e2e/smoke_aggregator_chain_disable_test.go new file mode 100644 index 000000000..afca33562 --- /dev/null +++ b/build/devenv/tests/e2e/smoke_aggregator_chain_disable_test.go @@ -0,0 +1,178 @@ +package e2e + +import ( + "context" + "fmt" + "strconv" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + + chain_selectors "github.com/smartcontractkit/chain-selectors" + ccv "github.com/smartcontractkit/chainlink-ccv/build/devenv" + "github.com/smartcontractkit/chainlink-ccv/build/devenv/cciptestinterfaces" + devenvcommon "github.com/smartcontractkit/chainlink-ccv/build/devenv/common" + "github.com/smartcontractkit/chainlink-ccv/build/devenv/tests/e2e/aggregatorcli" + "github.com/smartcontractkit/chainlink-testing-framework/framework" +) + +// aggregatorRefreshBuffer is the time to wait after a CLI disable/enable for +// the aggregator registry to pick up 
the DB change. The devenv template sets
+// chainStatus.refreshInterval = "2s", so 5s gives a comfortable margin.
+const aggregatorRefreshBuffer = 5 * time.Second
+
+// TestE2ESmoke_AggregatorChainsCLI exercises the aggregator chains CLI surface
+// (list, disable, enable) without going through the full message flow.
+func TestE2ESmoke_AggregatorChainsCLI(t *testing.T) {
+	smokeTestConfig := GetSmokeTestConfig()
+	in, err := ccv.LoadOutput[ccv.Cfg](smokeTestConfig)
+	require.NoError(t, err)
+
+	require.GreaterOrEqual(t, len(in.Aggregator), 1, "expected at least one aggregator in the environment")
+	require.NotNil(t, in.Aggregator[0].Out, "first aggregator must have output")
+	require.NotEmpty(t, in.Aggregator[0].Out.AggregatorContainerName, "aggregator container name must be set")
+
+	ac := aggregatorcli.NewClient(in.Aggregator[0].Out.AggregatorContainerName)
+	ctx := ccv.Plog.WithContext(t.Context())
+
+	t.Cleanup(func() {
+		_, _ = framework.SaveContainerLogs(fmt.Sprintf("%s-%s", framework.DefaultCTFLogsDir, t.Name()))
+	})
+
+	listOutput, err := ac.Chains().List(ctx)
+	require.NoError(t, err, "list should succeed: %s", listOutput)
+	require.Contains(t, listOutput, "Chain", "output must contain Chain header; got: %s", listOutput)
+
+	// Pick an arbitrary selector to exercise disable/enable.
+	lib, err := ccv.NewLib(zerolog.Ctx(ctx), smokeTestConfig, chain_selectors.FamilyEVM)
+	require.NoError(t, err)
+	chains, err := lib.Chains(ctx)
+	require.NoError(t, err)
+	require.GreaterOrEqual(t, len(chains), 1)
+	srcSelector := strconv.FormatUint(chains[0].Details.ChainSelector, 10)
+
+	_, err = ac.Chains().Disable(ctx, "--source", srcSelector)
+	require.NoError(t, err, "disable should succeed")
+
+	_, err = ac.Chains().Enable(ctx, "--source", srcSelector)
+	require.NoError(t, err, "enable should succeed")
+}
+
+// TestE2ESmoke_AggregatorChainDisableEnable validates the full user-visible
+// behavior of the aggregator chain kill switch across three phases:
+//
+// 1. Non-disabled lane — while chains[0] is disabled as a SOURCE, messages on
+//    the reverse lane (chains[1] → chains[0]) are unaffected because chains[0]
+//    is only blocked as a source, not as a destination.
+// 2. Disabled — messages from chains[0] (the disabled source) are rejected by
+//    the aggregator with FailedPrecondition and never reach the result store.
+// 3. Recovery — re-enabling chains[0] restores normal processing for that lane.
+func TestE2ESmoke_AggregatorChainDisableEnable(t *testing.T) {
+	smokeTestConfig := GetSmokeTestConfig()
+	in, err := ccv.LoadOutput[ccv.Cfg](smokeTestConfig)
+	require.NoError(t, err)
+
+	ctx := ccv.Plog.WithContext(t.Context())
+	lib, err := ccv.NewLib(zerolog.Ctx(ctx), smokeTestConfig, chain_selectors.FamilyEVM)
+	require.NoError(t, err)
+	chains, err := lib.Chains(ctx)
+	require.NoError(t, err)
+	require.GreaterOrEqual(t, len(chains), 2, "expected at least 2 chains")
+
+	require.GreaterOrEqual(t, len(in.Aggregator), 1)
+	require.NotNil(t, in.Aggregator[0].Out)
+	require.NotEmpty(t, in.Aggregator[0].Out.AggregatorContainerName, "aggregator container name must be set")
+
+	aggregatorClient, err := in.NewAggregatorClientForCommittee(
+		zerolog.Ctx(ctx).With().Str("component", "aggregator-client").Logger(),
+		devenvcommon.DefaultCommitteeVerifierQualifier)
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = aggregatorClient.Close() })
+
+	// chains[0] will be disabled as a source; chains[1] will remain fully enabled.
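+	// All CLI selector arguments are decimal-encoded uint64 chain selectors;
+	// disabling a SOURCE only gates messages originating on that chain, so
+	// traffic toward chains[0] as a destination (Phase A) must keep flowing.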
+ disabledSrc := chains[0] + otherSrc := chains[1] + disabledSrcSelector := disabledSrc.Details.ChainSelector + otherSrcSelector := otherSrc.Details.ChainSelector + + receiverOnOtherSrc := mustGetEOAReceiverAddress(t, otherSrc) + receiverOnDisabledSrc := mustGetEOAReceiverAddress(t, disabledSrc) + + ac := aggregatorcli.NewClient(in.Aggregator[0].Out.AggregatorContainerName) + cliCtx := context.Background() + + t.Cleanup(func() { + // Best-effort re-enable so the environment is clean for subsequent tests. + _, _ = ac.Chains().Enable(cliCtx, "--source", strconv.FormatUint(disabledSrcSelector, 10)) + _, _ = framework.SaveContainerLogs(fmt.Sprintf("%s-%s", framework.DefaultCTFLogsDir, t.Name())) + }) + + _, err = ac.Chains().Disable(cliCtx, "--source", strconv.FormatUint(disabledSrcSelector, 10)) + require.NoError(t, err, "CLI disable should succeed") + + // Wait for the registry to refresh so the gate is active. + time.Sleep(aggregatorRefreshBuffer) + + // ------------------------------------------------------------------------- + // Phase A — Non-disabled lane: chains[1] → chains[0]. + // chains[0] is only disabled as a SOURCE; it is still a valid DESTINATION, + // so this lane must continue to be processed normally. + // ------------------------------------------------------------------------- + seqNoAlt, err := otherSrc.GetExpectedNextSequenceNumber(ctx, disabledSrcSelector) + require.NoError(t, err) + _, err = otherSrc.SendMessage(ctx, disabledSrcSelector, + cciptestinterfaces.MessageFields{Receiver: receiverOnDisabledSrc}, + cciptestinterfaces.MessageOptions{Version: 3}) + require.NoError(t, err) + sentEvtAlt, err := otherSrc.ConfirmSendOnSource(ctx, disabledSrcSelector, cciptestinterfaces.MessageEventKey{SeqNum: seqNoAlt}, defaultSentTimeout) + require.NoError(t, err) + + nonDisabledCtx, cancelNonDisabled := context.WithTimeout(ctx, 45*time.Second) + defer cancelNonDisabled() + _, err = aggregatorClient.WaitForVerifierResultForMessage(nonDisabledCtx, sentEvtAlt.MessageID, 500*time.Millisecond) + require.NoError(t, err, "message on non-disabled lane should still reach the aggregator") + + // ------------------------------------------------------------------------- + // Phase B — Disabled: chains[0] → chains[1] is rejected. + // ------------------------------------------------------------------------- + seqNo, err := disabledSrc.GetExpectedNextSequenceNumber(ctx, otherSrcSelector) + require.NoError(t, err) + _, err = disabledSrc.SendMessage(ctx, otherSrcSelector, + cciptestinterfaces.MessageFields{Receiver: receiverOnOtherSrc}, + cciptestinterfaces.MessageOptions{Version: 3}) + require.NoError(t, err) + sentEvt, err := disabledSrc.ConfirmSendOnSource(ctx, otherSrcSelector, cciptestinterfaces.MessageEventKey{SeqNum: seqNo}, defaultSentTimeout) + require.NoError(t, err) + + // Give verifiers enough time to attempt — and fail — writing their results. + time.Sleep(20 * time.Second) + notProcessedCtx, cancelNotProcessed := context.WithTimeout(ctx, 5*time.Second) + defer cancelNotProcessed() + _, err = aggregatorClient.GetVerifierResultForMessage(notProcessedCtx, sentEvt.MessageID) + require.Error(t, err, "message should not be in aggregator while source chain is disabled") + + // ------------------------------------------------------------------------- + // Phase C — Recovery: re-enabling chains[0] restores the lane. 
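+	// Re-enable issues the same upsert with disabled=false; the status row is
+	// kept for the audit trail rather than deleted.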
+ // ------------------------------------------------------------------------- + _, err = ac.Chains().Enable(cliCtx, "--source", strconv.FormatUint(disabledSrcSelector, 10)) + require.NoError(t, err, "CLI enable should succeed") + + // Wait for the registry to refresh so the gate is lifted. + time.Sleep(aggregatorRefreshBuffer) + + seqNoRecovery, err := disabledSrc.GetExpectedNextSequenceNumber(ctx, otherSrcSelector) + require.NoError(t, err) + _, err = disabledSrc.SendMessage(ctx, otherSrcSelector, + cciptestinterfaces.MessageFields{Receiver: receiverOnOtherSrc}, + cciptestinterfaces.MessageOptions{Version: 3}) + require.NoError(t, err) + sentEvtRecovery, err := disabledSrc.ConfirmSendOnSource(ctx, otherSrcSelector, cciptestinterfaces.MessageEventKey{SeqNum: seqNoRecovery}, defaultSentTimeout) + require.NoError(t, err) + + recoveryCtx, cancelRecovery := context.WithTimeout(ctx, 45*time.Second) + defer cancelRecovery() + _, err = aggregatorClient.WaitForVerifierResultForMessage(recoveryCtx, sentEvtRecovery.MessageID, 500*time.Millisecond) + require.NoError(t, err, "message should reach the aggregator after source chain is re-enabled") +} diff --git a/internal/mocks/mock_AggregatorMetricLabeler.go b/internal/mocks/mock_AggregatorMetricLabeler.go index c99357bd0..ed18918ee 100644 --- a/internal/mocks/mock_AggregatorMetricLabeler.go +++ b/internal/mocks/mock_AggregatorMetricLabeler.go @@ -665,6 +665,40 @@ func (_c *MockAggregatorMetricLabeler_RecordTimeToAggregation_Call) RunAndReturn return _c } +// SetChainDisabledStatus provides a mock function with given fields: ctx, disabled +func (_m *MockAggregatorMetricLabeler) SetChainDisabledStatus(ctx context.Context, disabled int64) { + _m.Called(ctx, disabled) +} + +// MockAggregatorMetricLabeler_SetChainDisabledStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetChainDisabledStatus' +type MockAggregatorMetricLabeler_SetChainDisabledStatus_Call struct { + *mock.Call +} + +// SetChainDisabledStatus is a helper method to define mock.On call +// - ctx context.Context +// - disabled int64 +func (_e *MockAggregatorMetricLabeler_Expecter) SetChainDisabledStatus(ctx interface{}, disabled interface{}) *MockAggregatorMetricLabeler_SetChainDisabledStatus_Call { + return &MockAggregatorMetricLabeler_SetChainDisabledStatus_Call{Call: _e.mock.On("SetChainDisabledStatus", ctx, disabled)} +} + +func (_c *MockAggregatorMetricLabeler_SetChainDisabledStatus_Call) Run(run func(ctx context.Context, disabled int64)) *MockAggregatorMetricLabeler_SetChainDisabledStatus_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockAggregatorMetricLabeler_SetChainDisabledStatus_Call) Return() *MockAggregatorMetricLabeler_SetChainDisabledStatus_Call { + _c.Call.Return() + return _c +} + +func (_c *MockAggregatorMetricLabeler_SetChainDisabledStatus_Call) RunAndReturn(run func(context.Context, int64)) *MockAggregatorMetricLabeler_SetChainDisabledStatus_Call { + _c.Run(run) + return _c +} + // SetOrphanBacklog provides a mock function with given fields: ctx, count func (_m *MockAggregatorMetricLabeler) SetOrphanBacklog(ctx context.Context, count int) { _m.Called(ctx, count)