diff --git a/BitwardenShared/Core/Auth/Services/TestHelpers/MockKeychainRepository.swift b/BitwardenShared/Core/Auth/Services/TestHelpers/MockKeychainRepository.swift index ce881f74b5..bd2d1d1a3c 100644 --- a/BitwardenShared/Core/Auth/Services/TestHelpers/MockKeychainRepository.swift +++ b/BitwardenShared/Core/Auth/Services/TestHelpers/MockKeychainRepository.swift @@ -35,6 +35,10 @@ class MockKeychainRepository: KeychainRepository { var setServerCommunicationConfigCalledConfig: BitwardenSdk.ServerCommunicationConfig? var setServerCommunicationConfigCalledHostname: String? // swiftlint:disable:this identifier_name + // Track which userId was passed to get/set methods for testing + var getAccessTokenUserId: String? + var getRefreshTokenUserId: String? + func deleteAllItems() async throws { deleteAllItemsCalled = true mockStorage.removeAll() @@ -77,7 +81,8 @@ class MockKeychainRepository: KeychainRepository { } func getAccessToken(userId: String) async throws -> String { - try getAccessTokenResult.get() + getAccessTokenUserId = userId + return try getAccessTokenResult.get() } func getAuthenticatorVaultKey(userId: String) async throws -> String { @@ -89,7 +94,8 @@ class MockKeychainRepository: KeychainRepository { } func getRefreshToken(userId: String) async throws -> String { - try getRefreshTokenResult.get() + getRefreshTokenUserId = userId + return try getRefreshTokenResult.get() } func getPendingAdminLoginRequest(userId: String) async throws -> String? { diff --git a/BitwardenShared/Core/Platform/Services/API/AccountTokenProvider.swift b/BitwardenShared/Core/Platform/Services/API/AccountTokenProvider.swift index f76c45baa7..16ad139807 100644 --- a/BitwardenShared/Core/Platform/Services/API/AccountTokenProvider.swift +++ b/BitwardenShared/Core/Platform/Services/API/AccountTokenProvider.swift @@ -81,16 +81,22 @@ actor DefaultAccountTokenProvider: AccountTokenProvider { defer { self.refreshTask = nil } do { - let refreshToken = try await tokenService.getRefreshToken() + // Check if this is the best place to apply the changes + let userId = try await tokenService.getActiveAccountId() + + // Use captured userId for all operations + let refreshToken = try await tokenService.getRefreshToken(userId: userId) let response = try await httpService.send( IdentityTokenRefreshRequest(refreshToken: refreshToken), ) let expirationDate = timeProvider.presentTime.addingTimeInterval(TimeInterval(response.expiresIn)) + // Store tokens using the SAME userId (even if active account changed) try await tokenService.setTokens( accessToken: response.accessToken, refreshToken: response.refreshToken, expirationDate: expirationDate, + userId: userId, ) return response.accessToken diff --git a/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderTests.swift b/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderTests.swift index cfbb0a3796..3ad11a2bd2 100644 --- a/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderTests.swift +++ b/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderTests.swift @@ -171,4 +171,31 @@ class AccountTokenProviderTests: BitwardenTestCase { _ = try await subject.refreshToken() } } + + /// `refreshToken()` captures userId once at the start and uses it throughout the refresh operation, + /// preventing race conditions when the active account changes during the HTTP request. + /// This ensures tokens from Account A are stored under Account A's keychain entry, + /// even if the active account switches to Account B during the async HTTP operation. + func test_refreshToken_accountSwitchDuringRequest_storesTokensForOriginalAccount() async throws { + // Setup: Account 1 is active + tokenService.activeAccountId = "1" + tokenService.refreshTokenByUserId["1"] = "REFRESH_1" + + client.result = .httpSuccess(testData: .identityTokenRefresh) + + // Simulate account switch during HTTP request + client.onRequest = { _ in + // Active account switches to Account 2 while HTTP request is in flight + self.tokenService.activeAccountId = "2" + } + + let newAccessToken = try await subject.refreshToken() + + // Verify: Tokens stored under Account 1 (original), NOT Account 2 + XCTAssertEqual(newAccessToken, "ACCESS_TOKEN") + XCTAssertEqual(tokenService.setTokensCalledWithUserId, "1") + XCTAssertEqual(tokenService.getRefreshTokenCalledWithUserId, "1") + XCTAssertEqual(tokenService.accessTokenByUserId["1"], "ACCESS_TOKEN") + XCTAssertNil(tokenService.accessTokenByUserId["2"], "Access token should NOT be stored under Account 2") + } } diff --git a/BitwardenShared/Core/Platform/Services/TestHelpers/MockTokenService.swift b/BitwardenShared/Core/Platform/Services/TestHelpers/MockTokenService.swift index f739f53ee0..316bbb7552 100644 --- a/BitwardenShared/Core/Platform/Services/TestHelpers/MockTokenService.swift +++ b/BitwardenShared/Core/Platform/Services/TestHelpers/MockTokenService.swift @@ -10,15 +10,35 @@ class MockTokenService: TokenService { var getIsExternalResult: Result = .success(false) var refreshToken: String? = "REFRESH_TOKEN" + // Track which userId was used in explicit userId methods + var getAccessTokenCalledWithUserId: String? + var getRefreshTokenCalledWithUserId: String? + var setTokensCalledWithUserId: String? + var activeAccountId: String = "1" + var accessTokenByUserId: [String: String] = [:] + var refreshTokenByUserId: [String: String] = [:] + func getAccessToken() async throws -> String { guard let accessToken else { throw StateServiceError.noActiveAccount } return accessToken } + func getAccessToken(userId: String) async throws -> String { + getAccessTokenCalledWithUserId = userId + return accessTokenByUserId[userId] ?? accessToken ?? "ACCESS_TOKEN" + } + func getAccessTokenExpirationDate() async throws -> Date? { try accessTokenExpirationDateResult.get() } + func getActiveAccountId() async throws -> String { + if activeAccountId.isEmpty { + throw StateServiceError.noActiveAccount + } + return activeAccountId + } + func getIsExternal() async throws -> Bool { try getIsExternalResult.get() } @@ -28,9 +48,24 @@ class MockTokenService: TokenService { return refreshToken } + func getRefreshToken(userId: String) async throws -> String { + getRefreshTokenCalledWithUserId = userId + return refreshTokenByUserId[userId] ?? refreshToken ?? "REFRESH_TOKEN" + } + func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async { self.accessToken = accessToken self.refreshToken = refreshToken self.expirationDate = expirationDate } + + func setTokens(accessToken: String, refreshToken: String, expirationDate: Date, userId: String) async { + setTokensCalledWithUserId = userId + accessTokenByUserId[userId] = accessToken + refreshTokenByUserId[userId] = refreshToken + self.expirationDate = expirationDate + // Also update legacy properties for backward compatibility with existing tests + self.accessToken = accessToken + self.refreshToken = refreshToken + } } diff --git a/BitwardenShared/Core/Platform/Services/TokenService.swift b/BitwardenShared/Core/Platform/Services/TokenService.swift index 306e0bdabb..6811d02701 100644 --- a/BitwardenShared/Core/Platform/Services/TokenService.swift +++ b/BitwardenShared/Core/Platform/Services/TokenService.swift @@ -11,12 +11,25 @@ protocol TokenService: AnyObject { /// func getAccessToken() async throws -> String + /// Returns the access token for a specific user. + /// + /// - Parameter userId: The user ID to get the access token for. + /// - Returns: The access token for the specified user. + /// + func getAccessToken(userId: String) async throws -> String + /// Returns the access token's expiration date for the current account. /// /// - Returns: The access token's expiration date for the current account. /// func getAccessTokenExpirationDate() async throws -> Date? + /// Returns the active account's user ID. + /// + /// - Returns: The active account's user ID. + /// + func getActiveAccountId() async throws -> String + /// Returns whether the user is an external user. /// /// - Returns: Whether the user is an external user. @@ -29,6 +42,13 @@ protocol TokenService: AnyObject { /// func getRefreshToken() async throws -> String + /// Returns the refresh token for a specific user. + /// + /// - Parameter userId: The user ID to get the refresh token for. + /// - Returns: The refresh token for the specified user. + /// + func getRefreshToken(userId: String) async throws -> String + /// Sets a new access and refresh token for the current account. /// /// - Parameters: @@ -37,6 +57,16 @@ protocol TokenService: AnyObject { /// - expirationDate: The access token's expiration date. /// func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async throws + + /// Sets a new access and refresh token for a specific user. + /// + /// - Parameters: + /// - accessToken: The account's updated access token. + /// - refreshToken: The account's updated refresh token. + /// - expirationDate: The access token's expiration date. + /// - userId: The user ID to set the tokens for. + /// + func setTokens(accessToken: String, refreshToken: String, expirationDate: Date, userId: String) async throws } // MARK: - DefaultTokenService @@ -102,6 +132,24 @@ actor DefaultTokenService: TokenService { try await keychainRepository.setRefreshToken(refreshToken, userId: userId) await stateService.setAccessTokenExpirationDate(expirationDate, userId: userId) } + + func getAccessToken(userId: String) async throws -> String { + try await keychainRepository.getAccessToken(userId: userId) + } + + func getActiveAccountId() async throws -> String { + try await stateService.getActiveAccountId() + } + + func getRefreshToken(userId: String) async throws -> String { + try await keychainRepository.getRefreshToken(userId: userId) + } + + func setTokens(accessToken: String, refreshToken: String, expirationDate: Date, userId: String) async throws { + try await keychainRepository.setAccessToken(accessToken, userId: userId) + try await keychainRepository.setRefreshToken(refreshToken, userId: userId) + await stateService.setAccessTokenExpirationDate(expirationDate, userId: userId) + } } // MARK: ClientManagedTokens (SDK) diff --git a/BitwardenShared/Core/Platform/Services/TokenServiceTests.swift b/BitwardenShared/Core/Platform/Services/TokenServiceTests.swift index 8fcd0034bf..51ae143ec5 100644 --- a/BitwardenShared/Core/Platform/Services/TokenServiceTests.swift +++ b/BitwardenShared/Core/Platform/Services/TokenServiceTests.swift @@ -174,4 +174,58 @@ class TokenServiceTests: BitwardenTestCase { ) XCTAssertEqual(stateService.accessTokenExpirationDateByUserId["1"], expirationDate) } + + // MARK: Tests - Explicit userId Methods replace the older ones if changes are correct + + /// `getAccessToken(userId:)` returns the access token for the specified user without calling getActiveAccountId. + func test_getAccessToken_withUserId() async throws { + keychainRepository.getAccessTokenResult = .success("USER_2_TOKEN") + + let accessToken = try await subject.getAccessToken(userId: "2") + XCTAssertEqual(accessToken, "USER_2_TOKEN") + XCTAssertEqual(keychainRepository.getAccessTokenUserId, "2") + } + + /// `getActiveAccountId()` returns the active account's user ID. + func test_getActiveAccountId() async throws { + stateService.activeAccount = .fixture() + + let userId = try await subject.getActiveAccountId() + XCTAssertEqual(userId, "1") + } + + /// `getActiveAccountId()` throws an error if there isn't an active account. + func test_getActiveAccountId_noAccount() async { + stateService.activeAccount = nil + + await assertAsyncThrows(error: StateServiceError.noActiveAccount) { + _ = try await subject.getActiveAccountId() + } + } + + /// `getRefreshToken(userId:)` returns the refresh token for the specified user without calling getActiveAccountId. + func test_getRefreshToken_withUserId() async throws { + keychainRepository.getRefreshTokenResult = .success("USER_2_REFRESH") + + let refreshToken = try await subject.getRefreshToken(userId: "2") + XCTAssertEqual(refreshToken, "USER_2_REFRESH") + XCTAssertEqual(keychainRepository.getRefreshTokenUserId, "2") + } + + /// `setTokens(accessToken:refreshToken:expirationDate:userId:)` sets tokens for the specified user + /// without calling getActiveAccountId. + func test_setTokens_withUserId() async throws { + let expirationDate = Date(year: 2025, month: 10, day: 1) + try await subject.setTokens(accessToken: "🔑", refreshToken: "🔒", expirationDate: expirationDate, userId: "2") + + XCTAssertEqual( + keychainRepository.mockStorage[keychainRepository.formattedKey(for: .accessToken(userId: "2"))], + "🔑", + ) + XCTAssertEqual( + keychainRepository.mockStorage[keychainRepository.formattedKey(for: .refreshToken(userId: "2"))], + "🔒", + ) + XCTAssertEqual(stateService.accessTokenExpirationDateByUserId["2"], expirationDate) + } } diff --git a/TestHelpers/API/MockHTTPClient.swift b/TestHelpers/API/MockHTTPClient.swift index cc69d65d09..aee0dc6978 100644 --- a/TestHelpers/API/MockHTTPClient.swift +++ b/TestHelpers/API/MockHTTPClient.swift @@ -13,6 +13,10 @@ public final class MockHTTPClient: HTTPClient { /// A list of download results that will be returned in order for future requests. public var downloadResults: [Result] = [] + /// A callback that is invoked when a request is received, before returning the result. + /// Useful for simulating state changes during async operations. + public var onRequest: ((HTTPRequest) -> Void)? + /// A list of requests that have been received by the HTTP client. public var requests: [HTTPRequest] = [] @@ -62,6 +66,9 @@ public final class MockHTTPClient: HTTPClient { public func send(_ request: HTTPRequest) async throws -> HTTPResponse { requests.append(request) + // Invoke callback to allow tests to simulate state changes during async operations + onRequest?(request) + guard !results.isEmpty else { throw MockHTTPClientError.noResultForRequest } let result = results.removeFirst()