diff --git a/client.go b/client.go
index 36008218..c62cfb48 100644
--- a/client.go
+++ b/client.go
@@ -17,10 +17,12 @@ import (
 	"slices"
 	"strconv"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"time"
 
 	"github.com/rs/zerolog"
+	"go.mau.fi/util/exsync"
 	"go.mau.fi/util/ptr"
 	"go.mau.fi/util/random"
 	"go.mau.fi/util/retryafter"
@@ -96,7 +98,7 @@ type Client struct {
 	RequestHook  func(req *http.Request)
 	ResponseHook func(req *http.Request, resp *http.Response, err error, duration time.Duration)
 
-	UpdateRequestOnRetry func(req *http.Request, cause error) *http.Request
+	RequestRetryTrigger *exsync.Event
 
 	SyncPresence event.Presence
 	SyncTraceLog bool
@@ -620,16 +622,16 @@ func (cli *Client) doRetry(
 		Str("url", req.URL.String()).
 		Int("retry_in_seconds", int(backoff.Seconds())).
 		Msg("Request failed, retrying")
-	select {
-	case <-time.After(backoff):
-	case <-req.Context().Done():
-		if !errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) {
+
+	// A retry requested through RequestRetryTrigger is executed immediately (the
+	// caller's req.Context() is still live); any other failure backs off first.
+	if !errors.Is(cause, ErrContextCancelRetry) {
+		select {
+		case <-time.After(backoff):
+		case <-req.Context().Done():
 			return nil, nil, req.Context().Err()
 		}
 	}
-	if cli.UpdateRequestOnRetry != nil {
-		req = cli.UpdateRequestOnRetry(req, cause)
-	}
 	return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, sizeLimit, client)
 }
 
@@ -738,6 +740,91 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
 	}
 }
 
+// prepareRequestAttempt returns the request to use for a single attempt and a
+// cleanup function for it.
+//
+// When RequestRetryTrigger is nil, req is returned unchanged with a nil cleanup
+// function. Otherwise the attempt gets a child context that is canceled with
+// ErrContextCancelRetry as soon as the trigger's Wait returns nil. The cleanup
+// function cancels that child context with context.Canceled and may be called
+// any number of times.
+func (cli *Client) prepareRequestAttempt(req *http.Request) (*http.Request, func()) {
+	// if there's no retry trigger, nothing to do
+	if cli.RequestRetryTrigger == nil {
+		return req, nil
+	}
+
+	attemptCtx, cancel := context.WithCancelCause(req.Context())
+
+	go func() {
+		// If we hear of a reset, cancel the request context with a retry message
+		if cli.RequestRetryTrigger.Wait(attemptCtx) == nil {
+			cancel(ErrContextCancelRetry)
+		}
+	}()
+
+	return req.WithContext(attemptCtx), sync.OnceFunc(func() {
+		cancel(context.Canceled)
+	})
+}
+
+// cleanupReadCloser is an io.ReadCloser that runs cleanup after closing the
+// wrapped body.
+type cleanupReadCloser struct {
+	io.ReadCloser
+	cleanup func()
+}
+
+// cleanupReadCloserWriterTo is the same wrapper for bodies that implement
+// io.WriterTo, so the WriteTo method stays visible to callers.
+type cleanupReadCloserWriterTo struct {
+	io.ReadCloser
+	cleanup func()
+}
+
+// Close closes the wrapped body and then runs the cleanup function, if any.
+func (crc cleanupReadCloser) Close() error {
+	err := crc.ReadCloser.Close()
+	if crc.cleanup != nil {
+		crc.cleanup()
+	}
+	return err
+}
+
+// Close closes the wrapped body and then runs the cleanup function, if any.
+func (crc cleanupReadCloserWriterTo) Close() error {
+	err := crc.ReadCloser.Close()
+	if crc.cleanup != nil {
+		crc.cleanup()
+	}
+	return err
+}
+
+// WriteTo forwards to the wrapped body's WriteTo. It panics if the wrapped
+// body does not implement io.WriterTo.
+func (crc cleanupReadCloserWriterTo) WriteTo(w io.Writer) (int64, error) {
+	return crc.ReadCloser.(io.WriterTo).WriteTo(w)
+}
+
+// maybeWrapRespBody wraps rc so that cleanup runs when the body is closed. The
+// wrapper exposes io.WriterTo only if rc implements it. If cleanup is nil, rc
+// is returned as-is.
+func maybeWrapRespBody(rc io.ReadCloser, cleanup func()) io.ReadCloser {
+	if cleanup == nil {
+		return rc
+	}
+	if _, ok := rc.(io.WriterTo); ok {
+		return cleanupReadCloserWriterTo{
+			ReadCloser: rc,
+			cleanup:    cleanup,
+		}
+	}
+	return cleanupReadCloser{
+		ReadCloser: rc,
+		cleanup:    cleanup,
+	}
+}
+
 func (cli *Client) executeCompiledRequest(
 	req *http.Request,
 	retries int,
@@ -748,30 +835,45 @@ func (cli *Client) executeCompiledRequest(
 	sizeLimit int64,
 	client *http.Client,
 ) ([]byte, *http.Response, error) {
-	cli.RequestStart(req)
+	attemptReq, cleanup := cli.prepareRequestAttempt(req)
+	cli.RequestStart(attemptReq)
 	startTime := time.Now()
-	res, err := client.Do(req)
+	res, err := client.Do(attemptReq)
 	duration := time.Since(startTime)
-	if res != nil && !dontReadResponse {
-		defer res.Body.Close()
+	if res != nil {
+		// Cleanup the child attempt context once the body is closed
+		res.Body = maybeWrapRespBody(res.Body, cleanup)
+		if !dontReadResponse {
+			defer res.Body.Close()
+		}
 	}
 	if err != nil {
-		// Either error is *not* canceled or the underlying cause of cancelation explicitly asks to retry
+		// cleanup child attempt context on error
+		if cleanup != nil {
+			cleanup()
+		}
+
+		// Either error is *not* canceled or the underlying cause of cancellation explicitly asks to retry
+		attemptCause := context.Cause(attemptReq.Context())
+		retryCause := err
+		if errors.Is(attemptCause, ErrContextCancelRetry) {
+			retryCause = attemptCause
+		}
 		canRetry := !errors.Is(err, context.Canceled) ||
-			errors.Is(context.Cause(req.Context()), ErrContextCancelRetry)
+			errors.Is(attemptCause, ErrContextCancelRetry)
 		if retries > 0 && canRetry {
 			return cli.doRetry(
-				req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
+				req, retryCause, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
 			)
 		}
 		err = HTTPError{
-			Request:  req,
+			Request:  attemptReq,
 			Response: res,
 
 			Message:      "request error",
 			WrappedError: err,
 		}
-		cli.LogRequestDone(req, res, err, nil, 0, duration)
+		cli.LogRequestDone(attemptReq, res, err, nil, 0, duration)
 		return nil, res, err
 	}
 
@@ -784,11 +886,11 @@ func (cli *Client) executeCompiledRequest(
 
 	var body []byte
 	if res.StatusCode < 200 || res.StatusCode >= 300 {
-		body, err = ParseErrorResponse(req, res)
-		cli.LogRequestDone(req, res, nil, nil, len(body), duration)
+		body, err = ParseErrorResponse(attemptReq, res)
+		cli.LogRequestDone(attemptReq, res, nil, nil, len(body), duration)
 	} else {
-		body, err = handler(req, res, responseJSON, sizeLimit)
-		cli.LogRequestDone(req, res, nil, err, len(body), duration)
+		body, err = handler(attemptReq, res, responseJSON, sizeLimit)
+		cli.LogRequestDone(attemptReq, res, nil, err, len(body), duration)
 	}
 	return body, res, err
 }
diff --git a/client_retry_test.go b/client_retry_test.go
new file mode 100644
index 00000000..b753e43b
--- /dev/null
+++ b/client_retry_test.go
@@ -0,0 +1,528 @@
+package mautrix
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/rs/zerolog"
+	"github.com/stretchr/testify/require"
+	"go.mau.fi/util/exsync"
+)
+
+// newTestClient returns a Client pointed at serverURL with one retry, a 200ms
+// backoff and a fresh RequestRetryTrigger.
+func newTestClient(t *testing.T, serverURL string) *Client {
+	t.Helper()
+	parsedURL, err := url.Parse(serverURL)
+	require.NoError(t, err)
+	return &Client{
+		HomeserverURL:       parsedURL,
+		Client:              http.DefaultClient,
+		Log:                 zerolog.New(io.Discard),
+		DefaultHTTPRetries:  1,
+		DefaultHTTPBackoff:  200 * time.Millisecond,
+		RequestRetryTrigger: exsync.NewEvent(),
+	}
+}
+
+// TestRequestRetryTriggerRetriesActiveAttempt checks that notifying the trigger
+// aborts an in-flight attempt and retries it without waiting for the backoff.
+func TestRequestRetryTriggerRetriesActiveAttempt(t *testing.T) {
+	requestStarted := make(chan struct{})
+	var attempts atomic.Int32
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch attempts.Add(1) {
+		case 1:
+			close(requestStarted)
+			<-r.Context().Done()
+		case 2:
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"ok":true}`))
+		default:
+			// Errorf, not Fatalf: FailNow must only run on the test goroutine.
+			t.Errorf("unexpected extra request attempt %d", attempts.Load())
+		}
+	}))
+	t.Cleanup(server.Close)
+
+	client := newTestClient(t, server.URL)
+	var response struct {
+		OK bool `json:"ok"`
+	}
+	errCh := make(chan error, 1)
+	go func() {
+		_, err := client.MakeRequest(context.Background(), http.MethodGet, server.URL, nil, &response)
+		errCh <- err
+	}()
+
+	select {
+	case <-requestStarted:
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for initial request attempt")
+	}
+
+	resetAt := time.Now()
+	client.RequestRetryTrigger.Notify()
+
+	select {
+	case err := <-errCh:
+		require.NoError(t, err)
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for retried request to finish")
+	}
+
+	require.True(t, response.OK)
+	require.EqualValues(t, 2, attempts.Load())
+	require.Less(t, time.Since(resetAt), 150*time.Millisecond)
+}
+
+// TestRequestRetryTriggerUsesNormalRetryBudget checks that a trigger-induced
+// cancellation is not retried when no retries are left.
+func TestRequestRetryTriggerUsesNormalRetryBudget(t *testing.T) {
+	requestStarted := make(chan struct{})
+	var attempts atomic.Int32
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch attempts.Add(1) {
+		case 1:
+			close(requestStarted)
+			<-r.Context().Done()
+		default:
+			// Errorf, not Fatalf: FailNow must only run on the test goroutine.
+			t.Errorf("unexpected extra request attempt %d", attempts.Load())
+		}
+	}))
+	t.Cleanup(server.Close)
+
+	client := newTestClient(t, server.URL)
+	client.DefaultHTTPRetries = 0
+
+	errCh := make(chan error, 1)
+	go func() {
+		_, err := client.MakeRequest(context.Background(), http.MethodGet, server.URL, nil, nil)
+		errCh <- err
+	}()
+
+	select {
+	case <-requestStarted:
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for request start")
+	}
+
+	client.RequestRetryTrigger.Notify()
+
+	select {
+	case err := <-errCh:
+		require.Error(t, err)
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for canceled request to finish")
+	}
+
+	require.EqualValues(t, 1, attempts.Load())
+}
+
+// TestCallerCancellationDoesNotRetry checks that cancelling the caller's
+// context fails the request with context.Canceled instead of retrying.
+func TestCallerCancellationDoesNotRetry(t *testing.T) {
+	requestStarted := make(chan struct{})
+	var attempts atomic.Int32
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		attempts.Add(1)
+		close(requestStarted)
+		<-r.Context().Done()
+	}))
+	t.Cleanup(server.Close)
+
+	client := newTestClient(t, server.URL)
+	ctx, cancel := context.WithCancel(context.Background())
+	errCh := make(chan error, 1)
+	go func() {
+		_, err := client.MakeRequest(ctx, http.MethodGet, server.URL, nil, nil)
+		errCh <- err
+	}()
+
+	select {
+	case <-requestStarted:
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for request start")
+	}
+
+	cancel()
+
+	select {
+	case err := <-errCh:
+		require.ErrorIs(t, err, context.Canceled)
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for canceled request to finish")
+	}
+
+	require.EqualValues(t, 1, attempts.Load())
+}
+
+// TestRequestRetryTriggerDoesNotInterruptBackoff checks that a notification
+// sent between attempts does not shorten the regular backoff sleep.
+func TestRequestRetryTriggerDoesNotInterruptBackoff(t *testing.T) {
+	firstAttemptDone := make(chan time.Time, 1)
+	secondAttemptStarted := make(chan time.Time, 1)
+	var attempts atomic.Int32
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch attempts.Add(1) {
+		case 1:
+			w.WriteHeader(http.StatusBadGateway)
+			firstAttemptDone <- time.Now()
+		case 2:
+			secondAttemptStarted <- time.Now()
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{}`))
+		default:
+			// Errorf, not Fatalf: FailNow must only run on the test goroutine.
+			t.Errorf("unexpected extra request attempt %d", attempts.Load())
+		}
+	}))
+	t.Cleanup(server.Close)
+
+	client := newTestClient(t, server.URL)
+	client.DefaultHTTPBackoff = 250 * time.Millisecond
+	errCh := make(chan error, 1)
+	go func() {
+		_, err := client.MakeRequest(context.Background(), http.MethodGet, server.URL, nil, nil)
+		errCh <- err
+	}()
+
+	var firstAt time.Time
+	select {
+	case firstAt = <-firstAttemptDone:
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for first attempt to fail")
+	}
+
+	time.Sleep(50 * time.Millisecond)
+	client.RequestRetryTrigger.Notify()
+
+	var secondAt time.Time
+	select {
+	case secondAt = <-secondAttemptStarted:
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for retried request")
+	}
+
+	select {
+	case err := <-errCh:
+		require.NoError(t, err)
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for request completion")
+	}
+
+	require.GreaterOrEqual(t, secondAt.Sub(firstAt), 200*time.Millisecond)
+	require.EqualValues(t, 2, attempts.Load())
+}
+
+// TestRequestRetryTriggerCancelsStreamingBody checks that a notification makes
+// reads fail on a response body the caller is still streaming.
+func TestRequestRetryTriggerCancelsStreamingBody(t *testing.T) {
+	streamStarted := make(chan struct{})
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/octet-stream")
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("hello"))
+		if flusher, ok := w.(http.Flusher); ok {
+			flusher.Flush()
+		}
+		close(streamStarted)
+		<-r.Context().Done()
+	}))
+	t.Cleanup(server.Close)
+
+	client := newTestClient(t, server.URL)
+	_, resp, err := client.MakeFullRequestWithResp(context.Background(), FullRequest{
+		Method:           http.MethodGet,
+		URL:              server.URL,
+		DontReadResponse: true,
+	})
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+	defer resp.Body.Close()
+
+	select {
+	case <-streamStarted:
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for stream start")
+	}
+
+	buf := make([]byte, 5)
+	n, err := io.ReadFull(resp.Body, buf)
+	require.NoError(t, err)
+	require.Equal(t, 5, n)
+	require.Equal(t, "hello", string(buf))
+
+	client.RequestRetryTrigger.Notify()
+
+	_, err = resp.Body.Read(make([]byte, 1))
+	require.Error(t, err)
+}
+
+// TestDontReadResponseCleanupRunsOnBodyClose checks that the attempt context
+// stays live until the caller closes an unread response body.
+func TestDontReadResponseCleanupRunsOnBodyClose(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/octet-stream")
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("hello"))
+	}))
+	t.Cleanup(server.Close)
+
+	client := newTestClient(t, server.URL)
+	attemptCtxCh := make(chan context.Context, 1)
+	client.RequestHook = func(req *http.Request) {
+		select {
+		case attemptCtxCh <- req.Context():
+		default:
+		}
+	}
+
+	_, resp, err := client.MakeFullRequestWithResp(context.Background(), FullRequest{
+		Method:           http.MethodGet,
+		URL:              server.URL,
+		DontReadResponse: true,
+	})
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+
+	var attemptCtx context.Context
+	select {
+	case attemptCtx = <-attemptCtxCh:
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for attempt context")
+	}
+
+	select {
+	case <-attemptCtx.Done():
+		t.Fatal("attempt context canceled before body close")
+	case <-time.After(100 * time.Millisecond):
+	}
+
+	require.NoError(t, resp.Body.Close())
+
+	select {
+	case <-attemptCtx.Done():
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for attempt context cleanup after body close")
+	}
+	require.ErrorIs(t, context.Cause(attemptCtx), context.Canceled)
+}
+
+// TestRedirectErrorCleansUpAttemptContext checks that the attempt context is
+// canceled when the HTTP client fails with a redirect error.
+func TestRedirectErrorCleansUpAttemptContext(t *testing.T) {
+	var server *httptest.Server
+	server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path == "/final" {
+			w.WriteHeader(http.StatusOK)
+			return
+		}
+		http.Redirect(w, r, server.URL+"/final", http.StatusFound)
+	}))
+	t.Cleanup(server.Close)
+
+	client := newTestClient(t, server.URL)
+	httpClient := server.Client()
+	httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
+		return errors.New("stop redirect")
+	}
+	client.Client = httpClient
+
+	attemptCtxCh := make(chan context.Context, 1)
+	client.RequestHook = func(req *http.Request) {
+		select {
+		case attemptCtxCh <- req.Context():
+		default:
+		}
+	}
+
+	_, _, err := client.MakeFullRequestWithResp(context.Background(), FullRequest{
+		Method: http.MethodGet,
+		URL:    server.URL,
+	})
+	require.Error(t, err)
+
+	var attemptCtx context.Context
+	select {
+	case attemptCtx = <-attemptCtxCh:
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for attempt context")
+	}
+
+	select {
+	case <-attemptCtx.Done():
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for attempt context cleanup after redirect error")
+	}
+	require.ErrorIs(t, context.Cause(attemptCtx), context.Canceled)
+}
+
+// readSeekCloser adds a no-op Close to a bytes.Reader.
+type readSeekCloser struct {
+	*bytes.Reader
+}
+
+// Close implements io.Closer and does nothing.
+func (r readSeekCloser) Close() error {
+	return nil
+}
+
+// testRoundTripper adapts a function to http.RoundTripper.
+type testRoundTripper func(*http.Request) (*http.Response, error)
+
+// RoundTrip calls the underlying function.
+func (trt testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	return trt(req)
+}
+
+// writerToReadCloser is a bytes.Reader based body with explicit Close and
+// WriteTo methods.
+type writerToReadCloser struct {
+	*bytes.Reader
+}
+
+// Close implements io.Closer and does nothing.
+func (wrc *writerToReadCloser) Close() error {
+	return nil
+}
+
+// WriteTo copies the remaining bytes to w.
+func (wrc *writerToReadCloser) WriteTo(w io.Writer) (int64, error) {
+	return io.Copy(w, wrc.Reader)
+}
+
+// TestRequestRetryTriggerReplaysRequestBody checks that a trigger-induced
+// retry sends the full request body again.
+func TestRequestRetryTriggerReplaysRequestBody(t *testing.T) {
+	requestStarted := make(chan struct{})
+	bodyBytes := []byte("hello retry body")
+	var attempts atomic.Int32
+	receivedBodies := make(chan []byte, 2)
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		body, err := io.ReadAll(r.Body)
+		if err != nil {
+			// Errorf, not require: FailNow must only run on the test goroutine.
+			t.Errorf("failed to read request body: %v", err)
+		}
+		receivedBodies <- body
+
+		switch attempts.Add(1) {
+		case 1:
+			close(requestStarted)
+			<-r.Context().Done()
+		case 2:
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"ok":true}`))
+		default:
+			// Errorf, not Fatalf: FailNow must only run on the test goroutine.
+			t.Errorf("unexpected extra request attempt %d", attempts.Load())
+		}
+	}))
+	t.Cleanup(server.Close)
+
+	client := newTestClient(t, server.URL)
+	var response struct {
+		OK bool `json:"ok"`
+	}
+	errCh := make(chan error, 1)
+	go func() {
+		_, err := client.MakeFullRequest(context.Background(), FullRequest{
+			Method:        http.MethodPost,
+			URL:           server.URL,
+			RequestBody:   readSeekCloser{bytes.NewReader(bodyBytes)},
+			RequestLength: int64(len(bodyBytes)),
+			ResponseJSON:  &response,
+		})
+		errCh <- err
+	}()
+
+	select {
+	case <-requestStarted:
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for initial request attempt")
+	}
+
+	client.RequestRetryTrigger.Notify()
+
+	select {
+	case err := <-errCh:
+		require.NoError(t, err)
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for retried request to finish")
+	}
+
+	require.True(t, response.OK)
+	require.EqualValues(t, 2, attempts.Load())
+	require.Equal(t, bodyBytes, <-receivedBodies)
+	require.Equal(t, bodyBytes, <-receivedBodies)
+}
+
+// TestDontReadResponseCleanupWrapperPreservesWriterTo checks that the wrapped
+// response body still exposes io.WriterTo.
+func TestDontReadResponseCleanupWrapperPreservesWriterTo(t *testing.T) {
+	body := &writerToReadCloser{Reader: bytes.NewReader([]byte("hello writer-to"))}
+	client := newTestClient(t, "https://example.com")
+	client.Client = &http.Client{
+		Transport: testRoundTripper(func(req *http.Request) (*http.Response, error) {
+			return &http.Response{
+				StatusCode: http.StatusOK,
+				Header:     http.Header{"Content-Type": []string{"application/octet-stream"}},
+				Body:       body,
+			}, nil
+		}),
+	}
+
+	_, resp, err := client.MakeFullRequestWithResp(context.Background(), FullRequest{
+		Method:           http.MethodGet,
+		URL:              "https://example.com",
+		DontReadResponse: true,
+	})
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+
+	writerTo, ok := resp.Body.(io.WriterTo)
+	require.True(t, ok)
+
+	var copied bytes.Buffer
+	_, err = writerTo.WriteTo(&copied)
+	require.NoError(t, err)
+	require.Equal(t, "hello writer-to", copied.String())
+	require.NoError(t, resp.Body.Close())
+}
+
+// TestDontReadResponseWithoutRetryTriggerDoesNotWrapBody checks that the body
+// is returned as-is when no RequestRetryTrigger is set.
+func TestDontReadResponseWithoutRetryTriggerDoesNotWrapBody(t *testing.T) {
+	body := &writerToReadCloser{Reader: bytes.NewReader([]byte("hello raw body"))}
+	client := newTestClient(t, "https://example.com")
+	client.RequestRetryTrigger = nil
+	client.Client = &http.Client{
+		Transport: testRoundTripper(func(req *http.Request) (*http.Response, error) {
+			return &http.Response{
+				StatusCode: http.StatusOK,
+				Header:     http.Header{"Content-Type": []string{"application/octet-stream"}},
+				Body:       body,
+			}, nil
+		}),
+	}
+
+	_, resp, err := client.MakeFullRequestWithResp(context.Background(), FullRequest{
+		Method:           http.MethodGet,
+		URL:              "https://example.com",
+		DontReadResponse: true,
+	})
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+	require.Same(t, body, resp.Body)
+	require.NoError(t, resp.Body.Close())
+}