diff --git a/aggregator/pkg/handlers/batch_write_commit_verifier_node_result.go b/aggregator/pkg/handlers/batch_write_commit_verifier_node_result.go index 802396e54..73bb20120 100644 --- a/aggregator/pkg/handlers/batch_write_commit_verifier_node_result.go +++ b/aggregator/pkg/handlers/batch_write_commit_verifier_node_result.go @@ -19,6 +19,11 @@ import ( type BatchWriteCommitVerifierNodeResultHandler struct { handler *WriteCommitVerifierNodeResultHandler maxCommitVerifierNodeResultRequestsPerBatch int + wg sync.WaitGroup +} + +func (h *BatchWriteCommitVerifierNodeResultHandler) Wait() { + h.wg.Wait() } func (h *BatchWriteCommitVerifierNodeResultHandler) logger(ctx context.Context) logger.SugaredLogger { @@ -41,12 +46,12 @@ func (h *BatchWriteCommitVerifierNodeResultHandler) Handle(ctx context.Context, responses := make([]*committeepb.WriteCommitteeVerifierNodeResultResponse, len(requests)) errors := NewBatchErrorArray(len(requests)) - wg := sync.WaitGroup{} + h.wg = sync.WaitGroup{} for i, r := range requests { - wg.Add(1) + h.wg.Add(1) go func(i int, r *committeepb.WriteCommitteeVerifierNodeResultRequest) { - defer wg.Done() + defer h.wg.Done() if r == nil { SetBatchError(errors, i, codes.InvalidArgument, fmt.Sprintf("nil request at index %d", i)) responses[i] = &committeepb.WriteCommitteeVerifierNodeResultResponse{ @@ -73,7 +78,7 @@ func (h *BatchWriteCommitVerifierNodeResultHandler) Handle(ctx context.Context, done := make(chan struct{}) go func() { - wg.Wait() + h.wg.Wait() close(done) }() diff --git a/aggregator/pkg/handlers/batch_write_commit_verifier_node_result_test.go b/aggregator/pkg/handlers/batch_write_commit_verifier_node_result_test.go index 88d191a44..e311cb590 100644 --- a/aggregator/pkg/handlers/batch_write_commit_verifier_node_result_test.go +++ b/aggregator/pkg/handlers/batch_write_commit_verifier_node_result_test.go @@ -87,6 +87,7 @@ func TestBatchWriteCommitCCVNodeDataHandler_BatchSizeValidation(t *testing.T) { writeHandler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, sig, time.Millisecond) batchHandler := NewBatchWriteCommitVerifierNodeResultHandler(writeHandler, tc.maxBatchSize) + defer batchHandler.Wait() // ensure all goroutines finish before the test exits requests := make([]*committeepb.WriteCommitteeVerifierNodeResultRequest, tc.numRequests) for i := range requests { @@ -143,6 +144,7 @@ func TestBatchWriteCommitCCVNodeDataHandler_MixedSuccessAndInvalidArgument(t *te writeHandler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, sig, time.Millisecond) batchHandler := NewBatchWriteCommitVerifierNodeResultHandler(writeHandler, 10) + defer batchHandler.Wait() // ensure all goroutines finish before the test exits validReq := makeValidProtoRequest() invalidReq := makeValidProtoRequest() @@ -250,6 +252,7 @@ func TestBatchWriteCommitCCVNodeDataHandler_CancelledContextReturnsImmediately(t writeHandler := NewWriteCommitCCVNodeDataHandler(store, agg, mon, lggr, sig, blockDuration) batchHandler := NewBatchWriteCommitVerifierNodeResultHandler(writeHandler, 10) + defer batchHandler.Wait() // ensure all goroutines finish before the test exits ctx, cancel := context.WithCancel(auth.ToContext(context.Background(), auth.CreateCallerIdentity(testCallerID, false))) defer cancel()