diff --git a/src/_internal/types.ts b/src/_internal/types.ts index 77039c723..5a644ccf6 100644 --- a/src/_internal/types.ts +++ b/src/_internal/types.ts @@ -4,12 +4,17 @@ import type * as revalidateEvents from './events' export type GlobalState = [ Record, // EVENT_REVALIDATORS Record, // MUTATION: [ts, end_ts] - Record, // FETCH: [data, ts] + Record, // FETCH: [data, ts, abort_controller] Record>, // PRELOAD ScopedMutator, // Mutator (key: string, value: any, prev: any) => void, // Setter (key: string, callback: (current: any, prev: any) => void) => () => void // Subscriber ] +// Extra options passed to the fetcher +export type FetcherOptions = { + /** An AbortSignal to support request cancellation */ + signal: AbortSignal +} export type FetcherResponse = Data | Promise export type BareFetcher = ( ...args: any[] @@ -18,11 +23,11 @@ export type Fetcher< Data = unknown, SWRKey extends Key = Key > = SWRKey extends () => infer Arg | null | undefined | false - ? (arg: Arg) => FetcherResponse + ? (arg: Arg, options: FetcherOptions) => FetcherResponse : SWRKey extends null | undefined | false ? never : SWRKey extends infer Arg - ? (arg: Arg) => FetcherResponse + ? (arg: Arg, options: FetcherOptions) => FetcherResponse : never export type BlockingData< diff --git a/src/_internal/utils/mutate.ts b/src/_internal/utils/mutate.ts index 071af7803..aff07963d 100644 --- a/src/_internal/utils/mutate.ts +++ b/src/_internal/utils/mutate.ts @@ -102,6 +102,12 @@ export async function internalMutate( ? options.revalidate(get().data, _k) : options.revalidate !== false if (revalidate) { + // Cancel ongoing fetches + const maybeCurrentFetchController = FETCH[key]?.[2] + if (maybeCurrentFetchController) { + maybeCurrentFetchController.abort() + } + // Invalidate the key by deleting the concurrent request markers so new // requests will not be deduped. delete FETCH[key] diff --git a/src/index/use-swr.ts b/src/index/use-swr.ts index 7edcdcf4b..f3c16f1ef 100644 --- a/src/index/use-swr.ts +++ b/src/index/use-swr.ts @@ -440,9 +440,21 @@ export const useSWRHandler = ( // Start the request and save the timestamp. // Key must be truthy if entering here. + + // We also need to abort the current fetch as its result must be + // discarded anyway. + const maybeCurrentFetchController = FETCH[key]?.[2] + if (maybeCurrentFetchController) { + // Abort the ongoing fetch request if any. + maybeCurrentFetchController.abort() + } + + const abortController = new AbortController() + const signal = abortController.signal FETCH[key] = [ - currentFetcher(fnArg as DefinitelyTruthy), - getTimestamp() + currentFetcher(fnArg as DefinitelyTruthy, { signal }), + getTimestamp(), + abortController ] } @@ -452,17 +464,18 @@ export const useSWRHandler = ( newData = await newData if (shouldStartNewRequest) { - // If the request isn't interrupted, clean it up after the + // If the request wasn't interrupted, clean it up after the // deduplication interval. setTimeout(cleanupState, config.dedupingInterval) } - // If there're other ongoing request(s), started after the current one, + // If there're other new request(s) started after the current one, // we need to ignore the current one to avoid possible race conditions: // req1------------------>res1 (current one) // req2---------------->res2 - // the request that fired later will always be kept. - // The timestamp maybe be `undefined` or a number + // Requests fired later will always be kept. + + // The timestamp maybe be `undefined` or a number: if (!FETCH[key] || FETCH[key][1] !== startAt) { if (shouldStartNewRequest) { if (callbackSafeguard()) { @@ -505,6 +518,7 @@ export const useSWRHandler = ( } return false } + // Deep compare with the latest state to avoid extra re-renders. // For local state, compare and assign. const cacheData = getCache().data diff --git a/test/use-swr-auto-abort.test.tsx b/test/use-swr-auto-abort.test.tsx new file mode 100644 index 000000000..7538ec67f --- /dev/null +++ b/test/use-swr-auto-abort.test.tsx @@ -0,0 +1,218 @@ +import { act, screen } from '@testing-library/react' +import useSWR from 'swr' +import { createKey, renderWithConfig, sleep } from './utils' + +describe('useSWR - auto abort', () => { + it('should abort previous request when a new request starts', async () => { + const key = createKey() + let abortedCount = 0 + let fetchCount = 0 + + const fetcher = async ( + _key: string, + { signal }: { signal: AbortSignal } + ) => { + fetchCount++ + const currentFetch = fetchCount + + signal.addEventListener('abort', () => { + abortedCount++ + }) + + await sleep(100) + + if (signal.aborted) { + throw new Error(`aborted-${currentFetch}`) + } + + return `response-${currentFetch}` + } + + let mutate: any + + function Page() { + const { data, mutate: boundMutate, error } = useSWR(key, fetcher) + mutate = boundMutate + + return ( +
+ data:{data},error:{error?.message} +
+ ) + } + + renderWithConfig() + + // Make sure the first request is ongoing + await sleep(20) + + // Immediately trigger a mutation + await act(() => mutate()) + + // Final state should be from the mutation + await screen.findByText(/data:response-2/) + + // The subsequent requests should have aborted previous ones + expect(fetchCount).toBe(2) + expect(abortedCount).toBe(1) + }) + + it('should pass AbortSignal to fetcher', async () => { + const key = createKey() + let receivedSignal: AbortSignal | undefined + + const fetcher = async ( + _key: string, + { signal }: { signal: AbortSignal } + ) => { + receivedSignal = signal + await sleep(10) + return 'data' + } + + function Page() { + const { data } = useSWR(key, fetcher) + return
data:{data}
+ } + + renderWithConfig() + + await screen.findByText('data:data') + + // Verify that an AbortSignal was passed to the fetcher + expect(receivedSignal).toBeInstanceOf(AbortSignal) + expect(receivedSignal.aborted).toBe(false) + }) + + it('should not abort request during deduplication', async () => { + const key = createKey() + let abortedCount = 0 + let fetchCount = 0 + + const fetcher = async ( + _key: string, + { signal }: { signal?: AbortSignal } + ) => { + fetchCount++ + + signal?.addEventListener('abort', () => { + abortedCount++ + }) + + await sleep(100) + return 'data' + } + + function Page1() { + const { data } = useSWR(key, fetcher) + return
page1:{data}
+ } + + function Page2() { + const { data } = useSWR(key, fetcher) + return
page2:{data}
+ } + + renderWithConfig( + <> + + + + ) + + await screen.findByText('page1:data') + await screen.findByText('page2:data') + + // Should only fetch once due to deduplication + expect(fetchCount).toBe(1) + // No aborts should have occurred + expect(abortedCount).toBe(0) + }) + + it('should handle fetch errors gracefully when aborted', async () => { + const key = createKey() + let fetchCount = 0 + + const fetcher = async ( + _key: string, + { signal }: { signal?: AbortSignal } + ) => { + fetchCount++ + + await sleep(50) + + if (signal?.aborted) { + const error = new Error('Aborted') + error.name = 'AbortError' + throw error + } + + return 'data' + } + + let mutate: any + + function Page() { + const { + data, + error, + mutate: boundMutate + } = useSWR(key, fetcher, { + // Disable error retry to make the test faster + shouldRetryOnError: false + }) + mutate = boundMutate + + return ( +
+
data:{data}
+ {error &&
error:{error.message}
} +
+ ) + } + + renderWithConfig() + + // Wait for initial fetch + await screen.findByText('data:data') + + // Trigger rapid revalidations + await act(() => mutate()) + + await sleep(10) + + await act(() => mutate()) + + // Wait a bit for things to settle + await sleep(200) + + // Should eventually show data (not error) + expect(screen.getByText(/data:data/)).toBeInTheDocument() + expect(fetchCount).toBeGreaterThan(1) + }) + + it('should cleanup abort controller after request completes', async () => { + const key = createKey() + let signal: AbortSignal | undefined + + const fetcher = async (_key: string, opts: { signal: AbortSignal }) => { + signal = opts.signal + await sleep(50) + return 'data' + } + + function Page() { + const { data } = useSWR(key, fetcher) + return
data:{data}
+ } + + renderWithConfig() + + await screen.findByText('data:data') + + // After the request completes, the signal should not be aborted + // (it's cleaned up properly) + expect(signal).toBeDefined() + expect(signal?.aborted).toBe(false) + }) +})