diff --git a/BitwardenShared/Core/Auth/Services/TestHelpers/MockKeychainRepository.swift b/BitwardenShared/Core/Auth/Services/TestHelpers/MockKeychainRepository.swift index cd7c62da14..77a6dd4a7a 100644 --- a/BitwardenShared/Core/Auth/Services/TestHelpers/MockKeychainRepository.swift +++ b/BitwardenShared/Core/Auth/Services/TestHelpers/MockKeychainRepository.swift @@ -35,6 +35,10 @@ class MockKeychainRepository: KeychainRepository { var setServerCommunicationConfigCalledConfig: BitwardenSdk.ServerCommunicationConfig? var setServerCommunicationConfigCalledHostname: String? // swiftlint:disable:this identifier_name + // Track which userId was passed to get/set methods for testing + var getAccessTokenUserId: String? + var getRefreshTokenUserId: String? + func deleteAllItems() async throws { deleteAllItemsCalled = true mockStorage.removeAll() @@ -77,7 +81,8 @@ class MockKeychainRepository: KeychainRepository { } func getAccessToken(userId: String) async throws -> String { - try getAccessTokenResult.get() + getAccessTokenUserId = userId + return try getAccessTokenResult.get() } func getAuthenticatorVaultKey(userId: String) async throws -> String { @@ -89,7 +94,8 @@ class MockKeychainRepository: KeychainRepository { } func getRefreshToken(userId: String) async throws -> String { - try getRefreshTokenResult.get() + getRefreshTokenUserId = userId + return try getRefreshTokenResult.get() } func getPendingAdminLoginRequest(userId: String) async throws -> String? { diff --git a/BitwardenShared/Core/Platform/Services/API/APIService.swift b/BitwardenShared/Core/Platform/Services/API/APIService.swift index 899500f4fe..4c5b035a59 100644 --- a/BitwardenShared/Core/Platform/Services/API/APIService.swift +++ b/BitwardenShared/Core/Platform/Services/API/APIService.swift @@ -44,6 +44,7 @@ class APIService { /// - client: The underlying `HTTPClient` that performs the network request. Defaults /// to `URLSession.shared`. /// - environmentService: The service used by the application to retrieve the environment settings. + /// - errorReporter: The service used by the application to report non-fatal errors. /// - flightRecorder: The service used by the application for recording temporary debug logs. /// - serverCommunicationConfigClientSingleton: The service to get the server communication client /// used to break circular dependency. @@ -55,6 +56,7 @@ class APIService { accountTokenProvider: AccountTokenProvider? = nil, client: HTTPClient = URLSession.shared, environmentService: EnvironmentService, + errorReporter: ErrorReporter, flightRecorder: FlightRecorder, serverCommunicationConfigClientSingleton: @escaping () -> ServerCommunicationConfigClientSingleton?, stateService: StateService, @@ -85,6 +87,7 @@ class APIService { self.accountTokenProvider = accountTokenProvider ?? DefaultAccountTokenProvider( httpService: httpServiceBuilder.makeService(baseURLGetter: { environmentService.identityURL }), tokenService: tokenService, + errorReporter: errorReporter, ) apiService = httpServiceBuilder.makeService( diff --git a/BitwardenShared/Core/Platform/Services/API/AccountTokenProvider.swift b/BitwardenShared/Core/Platform/Services/API/AccountTokenProvider.swift index f76c45baa7..bb1bc5771c 100644 --- a/BitwardenShared/Core/Platform/Services/API/AccountTokenProvider.swift +++ b/BitwardenShared/Core/Platform/Services/API/AccountTokenProvider.swift @@ -22,6 +22,9 @@ actor DefaultAccountTokenProvider: AccountTokenProvider { /// The delegate to use for specific operations on the token provider. private weak var accountTokenProviderDelegate: AccountTokenProviderDelegate? + /// The service used to report non-fatal errors. + private let errorReporter: ErrorReporter + /// The `HTTPService` used to make the API call to refresh the access token. private let httpService: HTTPService @@ -42,15 +45,18 @@ actor DefaultAccountTokenProvider: AccountTokenProvider { /// - httpService: The service used to make the API call to refresh the access token. /// - timeProvider: The service used to get the present time. /// - tokenService: The service used to get the current tokens from. + /// - errorReporter: The service used to report non-fatal errors. /// init( httpService: HTTPService, timeProvider: TimeProvider = CurrentTime(), tokenService: TokenService, + errorReporter: ErrorReporter, ) { self.httpService = httpService self.timeProvider = timeProvider self.tokenService = tokenService + self.errorReporter = errorReporter } // MARK: Methods @@ -81,16 +87,29 @@ actor DefaultAccountTokenProvider: AccountTokenProvider { defer { self.refreshTask = nil } do { - let refreshToken = try await tokenService.getRefreshToken() + let expectedUserId = try await tokenService.getActiveAccountId() + + let refreshToken = try await tokenService.getRefreshToken(userId: expectedUserId) let response = try await httpService.send( IdentityTokenRefreshRequest(refreshToken: refreshToken), ) let expirationDate = timeProvider.presentTime.addingTimeInterval(TimeInterval(response.expiresIn)) + let userIdAfter = try await tokenService.getActiveAccountId() + guard expectedUserId == userIdAfter else { + let error = AccountTokenProviderError( + userIdBefore: expectedUserId, + userIdAfter: userIdAfter, + ) + errorReporter.log(error: error) + throw error + } + try await tokenService.setTokens( accessToken: response.accessToken, refreshToken: response.refreshToken, expirationDate: expirationDate, + userId: expectedUserId, ) return response.accessToken diff --git a/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderError.swift b/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderError.swift new file mode 100644 index 0000000000..95eef456b2 --- /dev/null +++ b/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderError.swift @@ -0,0 +1,24 @@ +import Foundation + +// MARK: - AccountTokenProviderError + +/// Error logged when the active account changes during a token refresh operation. +/// +struct AccountTokenProviderError: Error, CustomStringConvertible { + // MARK: Properties + + /// The active user ID before the token refresh operation. + let userIdBefore: String + + /// The active user ID after the token refresh operation. + let userIdAfter: String + + // MARK: CustomStringConvertible + + var description: String { + """ + Token refresh race condition detected: Active account changed from '\(userIdBefore)' to '\(userIdAfter)' \ + during token refresh operation. Tokens were not stored. + """ + } +} diff --git a/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderTests.swift b/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderTests.swift index cfbb0a3796..5c55fa1d1d 100644 --- a/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderTests.swift +++ b/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderTests.swift @@ -11,6 +11,7 @@ class AccountTokenProviderTests: BitwardenTestCase { // MARK: Properties var client: MockHTTPClient! + var errorReporter: MockErrorReporter! var subject: DefaultAccountTokenProvider! var timeProvider: MockTimeProvider! var tokenService: MockTokenService! @@ -25,6 +26,7 @@ class AccountTokenProviderTests: BitwardenTestCase { super.setUp() client = MockHTTPClient() + errorReporter = MockErrorReporter() timeProvider = MockTimeProvider(.mockTime(Date(year: 2025, month: 10, day: 2))) tokenService = MockTokenService() @@ -32,6 +34,7 @@ class AccountTokenProviderTests: BitwardenTestCase { httpService: HTTPService(baseURL: URL(string: "https://example.com")!, client: client), timeProvider: timeProvider, tokenService: tokenService, + errorReporter: errorReporter, ) } @@ -39,6 +42,7 @@ class AccountTokenProviderTests: BitwardenTestCase { try await super.tearDown() client = nil + errorReporter = nil subject = nil timeProvider = nil tokenService = nil @@ -171,4 +175,94 @@ class AccountTokenProviderTests: BitwardenTestCase { _ = try await subject.refreshToken() } } + + /// `refreshToken()` throws and does not store tokens when the active account switches during the HTTP request, + /// preventing tokens from being stored under the wrong account. + func test_refreshToken_accountSwitchDuringRequest_throwsAndDoesNotStoreTokens() async throws { + // Setup: Account 1 is active + tokenService.activeAccountId = "1" + tokenService.refreshTokenByUserId["1"] = "REFRESH_1" + + client.result = .httpSuccess(testData: .identityTokenRefresh) + + // Simulate account switch during HTTP request + client.onRequest = { _ in + // Active account switches to Account 2 while HTTP request is in flight + self.tokenService.activeAccountId = "2" + } + + await assertAsyncThrows { + _ = try await subject.refreshToken() + } + + // Verify: error logged, setTokens never called, no tokens stored under either account + XCTAssertEqual(errorReporter.errors.count, 1) + let error = errorReporter.errors[0] as? AccountTokenProviderError + XCTAssertNotNil(error) + XCTAssertEqual(error?.userIdBefore, "1") + XCTAssertEqual(error?.userIdAfter, "2") + XCTAssertNil(tokenService.setTokensCalledWithUserId) + } + + /// `refreshToken()` logs an error and throws when the active account changes during the token refresh operation, + /// preventing tokens from being stored in the wrong account. + func test_refreshToken_throwsRaceCondition_whenUserIdChanges() async throws { + tokenService.activeAccountId = "user-1" + tokenService.accessToken = "🔑" + tokenService.refreshToken = "🔒" + + client.result = .httpSuccess(testData: .identityTokenRefresh) + + // Simulate account switch during HTTP request + client.onRequest = { _ in + self.tokenService.activeAccountId = "user-2" + } + + await assertAsyncThrows { + _ = try await subject.refreshToken() + } + + // Verify error was logged and setTokens was never called + XCTAssertEqual(errorReporter.errors.count, 1) + let error = errorReporter.errors[0] as? AccountTokenProviderError + XCTAssertNotNil(error) + XCTAssertEqual(error?.userIdBefore, "user-1") + XCTAssertEqual(error?.userIdAfter, "user-2") + XCTAssertNil(tokenService.setTokensCalledWithUserId) + } + + /// `refreshToken()` does not log an error when the active account remains the same. + func test_refreshToken_doesNotLogError_whenUserIdStaysSame() async throws { + tokenService.activeAccountId = "user-1" + tokenService.accessToken = "🔑" + tokenService.refreshToken = "🔒" + + client.result = .httpSuccess(testData: .identityTokenRefresh) + + _ = try await subject.refreshToken() + + // Verify no error was logged + XCTAssertEqual(errorReporter.errors.count, 0) + } + + /// `refreshToken()` throws when `getActiveAccountId` throws before setting tokens, + /// preventing tokens from being stored without verifying the active account. + func test_refreshToken_throws_whenGetUserIdAfterThrows() async throws { + tokenService.accessToken = "🔑" + tokenService.refreshToken = "🔒" + client.result = .httpSuccess(testData: .identityTokenRefresh) + + // Clear the active account ID during the HTTP request so the + // getActiveAccountId() call (before setTokens) throws noActiveAccount. + client.onRequest = { _ in + self.tokenService.activeAccountId = "" + } + + await assertAsyncThrows(error: StateServiceError.noActiveAccount) { + _ = try await subject.refreshToken() + } + + // setTokens was never called + XCTAssertNil(tokenService.setTokensCalledWithUserId) + } } diff --git a/BitwardenShared/Core/Platform/Services/API/RefreshableAPIServiceTests.swift b/BitwardenShared/Core/Platform/Services/API/RefreshableAPIServiceTests.swift index bef4c57300..45535ad689 100644 --- a/BitwardenShared/Core/Platform/Services/API/RefreshableAPIServiceTests.swift +++ b/BitwardenShared/Core/Platform/Services/API/RefreshableAPIServiceTests.swift @@ -21,6 +21,7 @@ class RefreshableAPIServiceTests: BitwardenTestCase { subject = APIService( accountTokenProvider: accountTokenProvider, environmentService: MockEnvironmentService(), + errorReporter: MockErrorReporter(), flightRecorder: MockFlightRecorder(), serverCommunicationConfigClientSingleton: { MockServerCommunicationConfigClientSingleton() }, stateService: MockStateService(), diff --git a/BitwardenShared/Core/Platform/Services/API/TestHelpers/APIService+Mocks.swift b/BitwardenShared/Core/Platform/Services/API/TestHelpers/APIService+Mocks.swift index 2e1737674f..45be21d6ab 100644 --- a/BitwardenShared/Core/Platform/Services/API/TestHelpers/APIService+Mocks.swift +++ b/BitwardenShared/Core/Platform/Services/API/TestHelpers/APIService+Mocks.swift @@ -9,6 +9,7 @@ extension APIService { accountTokenProvider: AccountTokenProvider? = nil, client: HTTPClient, environmentService: EnvironmentService = MockEnvironmentService(), + errorReporter: ErrorReporter = MockErrorReporter(), flightRecorder: FlightRecorder = MockFlightRecorder(), // swiftlint:disable:next line_length serverCommunicationConfigClientSingleton: ServerCommunicationConfigClientSingleton = MockServerCommunicationConfigClientSingleton(), @@ -18,6 +19,7 @@ extension APIService { accountTokenProvider: accountTokenProvider, client: client, environmentService: environmentService, + errorReporter: errorReporter, flightRecorder: flightRecorder, serverCommunicationConfigClientSingleton: { serverCommunicationConfigClientSingleton }, stateService: stateService, diff --git a/BitwardenShared/Core/Platform/Services/ServiceContainer.swift b/BitwardenShared/Core/Platform/Services/ServiceContainer.swift index 502a68bfad..61f27004f7 100644 --- a/BitwardenShared/Core/Platform/Services/ServiceContainer.swift +++ b/BitwardenShared/Core/Platform/Services/ServiceContainer.swift @@ -532,6 +532,7 @@ public class ServiceContainer: Services { // swiftlint:disable:this type_body_le let apiService = APIService( client: noRedirectSession, environmentService: environmentService, + errorReporter: errorReporter, flightRecorder: flightRecorder, serverCommunicationConfigClientSingleton: { serverCommConfigClientSingletonHolder }, stateService: stateService, diff --git a/BitwardenShared/Core/Platform/Services/TestHelpers/MockTokenService.swift b/BitwardenShared/Core/Platform/Services/TestHelpers/MockTokenService.swift index f739f53ee0..316bbb7552 100644 --- a/BitwardenShared/Core/Platform/Services/TestHelpers/MockTokenService.swift +++ b/BitwardenShared/Core/Platform/Services/TestHelpers/MockTokenService.swift @@ -10,15 +10,35 @@ class MockTokenService: TokenService { var getIsExternalResult: Result = .success(false) var refreshToken: String? = "REFRESH_TOKEN" + // Track which userId was used in explicit userId methods + var getAccessTokenCalledWithUserId: String? + var getRefreshTokenCalledWithUserId: String? + var setTokensCalledWithUserId: String? + var activeAccountId: String = "1" + var accessTokenByUserId: [String: String] = [:] + var refreshTokenByUserId: [String: String] = [:] + func getAccessToken() async throws -> String { guard let accessToken else { throw StateServiceError.noActiveAccount } return accessToken } + func getAccessToken(userId: String) async throws -> String { + getAccessTokenCalledWithUserId = userId + return accessTokenByUserId[userId] ?? accessToken ?? "ACCESS_TOKEN" + } + func getAccessTokenExpirationDate() async throws -> Date? { try accessTokenExpirationDateResult.get() } + func getActiveAccountId() async throws -> String { + if activeAccountId.isEmpty { + throw StateServiceError.noActiveAccount + } + return activeAccountId + } + func getIsExternal() async throws -> Bool { try getIsExternalResult.get() } @@ -28,9 +48,24 @@ class MockTokenService: TokenService { return refreshToken } + func getRefreshToken(userId: String) async throws -> String { + getRefreshTokenCalledWithUserId = userId + return refreshTokenByUserId[userId] ?? refreshToken ?? "REFRESH_TOKEN" + } + func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async { self.accessToken = accessToken self.refreshToken = refreshToken self.expirationDate = expirationDate } + + func setTokens(accessToken: String, refreshToken: String, expirationDate: Date, userId: String) async { + setTokensCalledWithUserId = userId + accessTokenByUserId[userId] = accessToken + refreshTokenByUserId[userId] = refreshToken + self.expirationDate = expirationDate + // Also update legacy properties for backward compatibility with existing tests + self.accessToken = accessToken + self.refreshToken = refreshToken + } } diff --git a/BitwardenShared/Core/Platform/Services/TokenService.swift b/BitwardenShared/Core/Platform/Services/TokenService.swift index 306e0bdabb..6811d02701 100644 --- a/BitwardenShared/Core/Platform/Services/TokenService.swift +++ b/BitwardenShared/Core/Platform/Services/TokenService.swift @@ -11,12 +11,25 @@ protocol TokenService: AnyObject { /// func getAccessToken() async throws -> String + /// Returns the access token for a specific user. + /// + /// - Parameter userId: The user ID to get the access token for. + /// - Returns: The access token for the specified user. + /// + func getAccessToken(userId: String) async throws -> String + /// Returns the access token's expiration date for the current account. /// /// - Returns: The access token's expiration date for the current account. /// func getAccessTokenExpirationDate() async throws -> Date? + /// Returns the active account's user ID. + /// + /// - Returns: The active account's user ID. + /// + func getActiveAccountId() async throws -> String + /// Returns whether the user is an external user. /// /// - Returns: Whether the user is an external user. @@ -29,6 +42,13 @@ protocol TokenService: AnyObject { /// func getRefreshToken() async throws -> String + /// Returns the refresh token for a specific user. + /// + /// - Parameter userId: The user ID to get the refresh token for. + /// - Returns: The refresh token for the specified user. + /// + func getRefreshToken(userId: String) async throws -> String + /// Sets a new access and refresh token for the current account. /// /// - Parameters: @@ -37,6 +57,16 @@ protocol TokenService: AnyObject { /// - expirationDate: The access token's expiration date. /// func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async throws + + /// Sets a new access and refresh token for a specific user. + /// + /// - Parameters: + /// - accessToken: The account's updated access token. + /// - refreshToken: The account's updated refresh token. + /// - expirationDate: The access token's expiration date. + /// - userId: The user ID to set the tokens for. + /// + func setTokens(accessToken: String, refreshToken: String, expirationDate: Date, userId: String) async throws } // MARK: - DefaultTokenService @@ -102,6 +132,24 @@ actor DefaultTokenService: TokenService { try await keychainRepository.setRefreshToken(refreshToken, userId: userId) await stateService.setAccessTokenExpirationDate(expirationDate, userId: userId) } + + func getAccessToken(userId: String) async throws -> String { + try await keychainRepository.getAccessToken(userId: userId) + } + + func getActiveAccountId() async throws -> String { + try await stateService.getActiveAccountId() + } + + func getRefreshToken(userId: String) async throws -> String { + try await keychainRepository.getRefreshToken(userId: userId) + } + + func setTokens(accessToken: String, refreshToken: String, expirationDate: Date, userId: String) async throws { + try await keychainRepository.setAccessToken(accessToken, userId: userId) + try await keychainRepository.setRefreshToken(refreshToken, userId: userId) + await stateService.setAccessTokenExpirationDate(expirationDate, userId: userId) + } } // MARK: ClientManagedTokens (SDK) diff --git a/BitwardenShared/Core/Platform/Services/TokenServiceTests.swift b/BitwardenShared/Core/Platform/Services/TokenServiceTests.swift index 8fcd0034bf..51ae143ec5 100644 --- a/BitwardenShared/Core/Platform/Services/TokenServiceTests.swift +++ b/BitwardenShared/Core/Platform/Services/TokenServiceTests.swift @@ -174,4 +174,58 @@ class TokenServiceTests: BitwardenTestCase { ) XCTAssertEqual(stateService.accessTokenExpirationDateByUserId["1"], expirationDate) } + + // MARK: Tests - Explicit userId Methods replace the older ones if changes are correct + + /// `getAccessToken(userId:)` returns the access token for the specified user without calling getActiveAccountId. + func test_getAccessToken_withUserId() async throws { + keychainRepository.getAccessTokenResult = .success("USER_2_TOKEN") + + let accessToken = try await subject.getAccessToken(userId: "2") + XCTAssertEqual(accessToken, "USER_2_TOKEN") + XCTAssertEqual(keychainRepository.getAccessTokenUserId, "2") + } + + /// `getActiveAccountId()` returns the active account's user ID. + func test_getActiveAccountId() async throws { + stateService.activeAccount = .fixture() + + let userId = try await subject.getActiveAccountId() + XCTAssertEqual(userId, "1") + } + + /// `getActiveAccountId()` throws an error if there isn't an active account. + func test_getActiveAccountId_noAccount() async { + stateService.activeAccount = nil + + await assertAsyncThrows(error: StateServiceError.noActiveAccount) { + _ = try await subject.getActiveAccountId() + } + } + + /// `getRefreshToken(userId:)` returns the refresh token for the specified user without calling getActiveAccountId. + func test_getRefreshToken_withUserId() async throws { + keychainRepository.getRefreshTokenResult = .success("USER_2_REFRESH") + + let refreshToken = try await subject.getRefreshToken(userId: "2") + XCTAssertEqual(refreshToken, "USER_2_REFRESH") + XCTAssertEqual(keychainRepository.getRefreshTokenUserId, "2") + } + + /// `setTokens(accessToken:refreshToken:expirationDate:userId:)` sets tokens for the specified user + /// without calling getActiveAccountId. + func test_setTokens_withUserId() async throws { + let expirationDate = Date(year: 2025, month: 10, day: 1) + try await subject.setTokens(accessToken: "🔑", refreshToken: "🔒", expirationDate: expirationDate, userId: "2") + + XCTAssertEqual( + keychainRepository.mockStorage[keychainRepository.formattedKey(for: .accessToken(userId: "2"))], + "🔑", + ) + XCTAssertEqual( + keychainRepository.mockStorage[keychainRepository.formattedKey(for: .refreshToken(userId: "2"))], + "🔒", + ) + XCTAssertEqual(stateService.accessTokenExpirationDateByUserId["2"], expirationDate) + } } diff --git a/BitwardenShared/UI/Vault/VaultItem/AddEditItem/AddEditItemProcessor.swift b/BitwardenShared/UI/Vault/VaultItem/AddEditItem/AddEditItemProcessor.swift index 42a92f3e53..1cafc08248 100644 --- a/BitwardenShared/UI/Vault/VaultItem/AddEditItem/AddEditItemProcessor.swift +++ b/BitwardenShared/UI/Vault/VaultItem/AddEditItem/AddEditItemProcessor.swift @@ -696,6 +696,15 @@ final class AddEditItemProcessor: StateProcessor] = [] + /// A callback that is invoked when a request is received, before returning the result. + /// Useful for simulating state changes during async operations. + public var onRequest: ((HTTPRequest) -> Void)? + /// A list of requests that have been received by the HTTP client. public var requests: [HTTPRequest] = [] @@ -62,6 +66,9 @@ public final class MockHTTPClient: HTTPClient { public func send(_ request: HTTPRequest) async throws -> HTTPResponse { requests.append(request) + // Invoke callback if provided (useful for simulating state changes during async operations) + onRequest?(request) + guard !results.isEmpty else { throw MockHTTPClientError.noResultForRequest } let result = results.removeFirst()