Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class MockKeychainRepository: KeychainRepository {
var setServerCommunicationConfigCalledConfig: BitwardenSdk.ServerCommunicationConfig?
var setServerCommunicationConfigCalledHostname: String? // swiftlint:disable:this identifier_name

// Track which userId was passed to get/set methods for testing
var getAccessTokenUserId: String?
var getRefreshTokenUserId: String?

func deleteAllItems() async throws {
deleteAllItemsCalled = true
mockStorage.removeAll()
Expand Down Expand Up @@ -77,7 +81,8 @@ class MockKeychainRepository: KeychainRepository {
}

func getAccessToken(userId: String) async throws -> String {
try getAccessTokenResult.get()
getAccessTokenUserId = userId
return try getAccessTokenResult.get()
}

func getAuthenticatorVaultKey(userId: String) async throws -> String {
Expand All @@ -89,7 +94,8 @@ class MockKeychainRepository: KeychainRepository {
}

func getRefreshToken(userId: String) async throws -> String {
try getRefreshTokenResult.get()
getRefreshTokenUserId = userId
return try getRefreshTokenResult.get()
}

func getPendingAdminLoginRequest(userId: String) async throws -> String? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,22 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
defer { self.refreshTask = nil }

do {
let refreshToken = try await tokenService.getRefreshToken()
// Check if this is the best place to apply the changes
let userId = try await tokenService.getActiveAccountId()

// Use captured userId for all operations
let refreshToken = try await tokenService.getRefreshToken(userId: userId)
let response = try await httpService.send(
IdentityTokenRefreshRequest(refreshToken: refreshToken),
)
let expirationDate = timeProvider.presentTime.addingTimeInterval(TimeInterval(response.expiresIn))

// Store tokens using the SAME userId (even if active account changed)
try await tokenService.setTokens(
accessToken: response.accessToken,
refreshToken: response.refreshToken,
expirationDate: expirationDate,
userId: userId,
)

return response.accessToken
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,31 @@ class AccountTokenProviderTests: BitwardenTestCase {
_ = try await subject.refreshToken()
}
}

/// `refreshToken()` captures userId once at the start and uses it throughout the refresh operation,
/// preventing race conditions when the active account changes during the HTTP request.
/// This ensures tokens from Account A are stored under Account A's keychain entry,
/// even if the active account switches to Account B during the async HTTP operation.
func test_refreshToken_accountSwitchDuringRequest_storesTokensForOriginalAccount() async throws {
// Setup: Account 1 is active
tokenService.activeAccountId = "1"
tokenService.refreshTokenByUserId["1"] = "REFRESH_1"

client.result = .httpSuccess(testData: .identityTokenRefresh)

// Simulate account switch during HTTP request
client.onRequest = { _ in
// Active account switches to Account 2 while HTTP request is in flight
self.tokenService.activeAccountId = "2"
}

let newAccessToken = try await subject.refreshToken()

// Verify: Tokens stored under Account 1 (original), NOT Account 2
XCTAssertEqual(newAccessToken, "ACCESS_TOKEN")
XCTAssertEqual(tokenService.setTokensCalledWithUserId, "1")
XCTAssertEqual(tokenService.getRefreshTokenCalledWithUserId, "1")
XCTAssertEqual(tokenService.accessTokenByUserId["1"], "ACCESS_TOKEN")
XCTAssertNil(tokenService.accessTokenByUserId["2"], "Access token should NOT be stored under Account 2")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,35 @@ class MockTokenService: TokenService {
var getIsExternalResult: Result<Bool, Error> = .success(false)
var refreshToken: String? = "REFRESH_TOKEN"

// Track which userId was used in explicit userId methods
var getAccessTokenCalledWithUserId: String?
var getRefreshTokenCalledWithUserId: String?
var setTokensCalledWithUserId: String?
var activeAccountId: String = "1"
var accessTokenByUserId: [String: String] = [:]
var refreshTokenByUserId: [String: String] = [:]

func getAccessToken() async throws -> String {
guard let accessToken else { throw StateServiceError.noActiveAccount }
return accessToken
}

func getAccessToken(userId: String) async throws -> String {
getAccessTokenCalledWithUserId = userId
return accessTokenByUserId[userId] ?? accessToken ?? "ACCESS_TOKEN"
}

func getAccessTokenExpirationDate() async throws -> Date? {
try accessTokenExpirationDateResult.get()
}

func getActiveAccountId() async throws -> String {
if activeAccountId.isEmpty {
throw StateServiceError.noActiveAccount
}
return activeAccountId
}

func getIsExternal() async throws -> Bool {
try getIsExternalResult.get()
}
Expand All @@ -28,9 +48,24 @@ class MockTokenService: TokenService {
return refreshToken
}

func getRefreshToken(userId: String) async throws -> String {
getRefreshTokenCalledWithUserId = userId
return refreshTokenByUserId[userId] ?? refreshToken ?? "REFRESH_TOKEN"
}

func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async {
self.accessToken = accessToken
self.refreshToken = refreshToken
self.expirationDate = expirationDate
}

func setTokens(accessToken: String, refreshToken: String, expirationDate: Date, userId: String) async {
setTokensCalledWithUserId = userId
accessTokenByUserId[userId] = accessToken
refreshTokenByUserId[userId] = refreshToken
self.expirationDate = expirationDate
// Also update legacy properties for backward compatibility with existing tests
self.accessToken = accessToken
self.refreshToken = refreshToken
}
}
48 changes: 48 additions & 0 deletions BitwardenShared/Core/Platform/Services/TokenService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,25 @@ protocol TokenService: AnyObject {
///
func getAccessToken() async throws -> String

/// Returns the access token for a specific user.
///
/// - Parameter userId: The user ID to get the access token for.
/// - Returns: The access token for the specified user.
///
func getAccessToken(userId: String) async throws -> String

/// Returns the access token's expiration date for the current account.
///
/// - Returns: The access token's expiration date for the current account.
///
func getAccessTokenExpirationDate() async throws -> Date?

/// Returns the active account's user ID.
///
/// - Returns: The active account's user ID.
///
func getActiveAccountId() async throws -> String

/// Returns whether the user is an external user.
///
/// - Returns: Whether the user is an external user.
Expand All @@ -29,6 +42,13 @@ protocol TokenService: AnyObject {
///
func getRefreshToken() async throws -> String

/// Returns the refresh token for a specific user.
///
/// - Parameter userId: The user ID to get the refresh token for.
/// - Returns: The refresh token for the specified user.
///
func getRefreshToken(userId: String) async throws -> String

/// Sets a new access and refresh token for the current account.
///
/// - Parameters:
Expand All @@ -37,6 +57,16 @@ protocol TokenService: AnyObject {
/// - expirationDate: The access token's expiration date.
///
func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async throws

/// Sets a new access and refresh token for a specific user.
///
/// - Parameters:
/// - accessToken: The account's updated access token.
/// - refreshToken: The account's updated refresh token.
/// - expirationDate: The access token's expiration date.
/// - userId: The user ID to set the tokens for.
///
func setTokens(accessToken: String, refreshToken: String, expirationDate: Date, userId: String) async throws
}

// MARK: - DefaultTokenService
Expand Down Expand Up @@ -102,6 +132,24 @@ actor DefaultTokenService: TokenService {
try await keychainRepository.setRefreshToken(refreshToken, userId: userId)
await stateService.setAccessTokenExpirationDate(expirationDate, userId: userId)
}

func getAccessToken(userId: String) async throws -> String {
try await keychainRepository.getAccessToken(userId: userId)
}

func getActiveAccountId() async throws -> String {
try await stateService.getActiveAccountId()
}

func getRefreshToken(userId: String) async throws -> String {
try await keychainRepository.getRefreshToken(userId: userId)
}

func setTokens(accessToken: String, refreshToken: String, expirationDate: Date, userId: String) async throws {
try await keychainRepository.setAccessToken(accessToken, userId: userId)
try await keychainRepository.setRefreshToken(refreshToken, userId: userId)
await stateService.setAccessTokenExpirationDate(expirationDate, userId: userId)
}
}

// MARK: ClientManagedTokens (SDK)
Expand Down
54 changes: 54 additions & 0 deletions BitwardenShared/Core/Platform/Services/TokenServiceTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,58 @@ class TokenServiceTests: BitwardenTestCase {
)
XCTAssertEqual(stateService.accessTokenExpirationDateByUserId["1"], expirationDate)
}

// MARK: Tests - Explicit userId Methods replace the older ones if changes are correct

/// `getAccessToken(userId:)` returns the access token for the specified user without calling getActiveAccountId.
func test_getAccessToken_withUserId() async throws {
keychainRepository.getAccessTokenResult = .success("USER_2_TOKEN")

let accessToken = try await subject.getAccessToken(userId: "2")
XCTAssertEqual(accessToken, "USER_2_TOKEN")
XCTAssertEqual(keychainRepository.getAccessTokenUserId, "2")
}

/// `getActiveAccountId()` returns the active account's user ID.
func test_getActiveAccountId() async throws {
stateService.activeAccount = .fixture()

let userId = try await subject.getActiveAccountId()
XCTAssertEqual(userId, "1")
}

/// `getActiveAccountId()` throws an error if there isn't an active account.
func test_getActiveAccountId_noAccount() async {
stateService.activeAccount = nil

await assertAsyncThrows(error: StateServiceError.noActiveAccount) {
_ = try await subject.getActiveAccountId()
}
}

/// `getRefreshToken(userId:)` returns the refresh token for the specified user without calling getActiveAccountId.
func test_getRefreshToken_withUserId() async throws {
keychainRepository.getRefreshTokenResult = .success("USER_2_REFRESH")

let refreshToken = try await subject.getRefreshToken(userId: "2")
XCTAssertEqual(refreshToken, "USER_2_REFRESH")
XCTAssertEqual(keychainRepository.getRefreshTokenUserId, "2")
}

/// `setTokens(accessToken:refreshToken:expirationDate:userId:)` sets tokens for the specified user
/// without calling getActiveAccountId.
func test_setTokens_withUserId() async throws {
let expirationDate = Date(year: 2025, month: 10, day: 1)
try await subject.setTokens(accessToken: "πŸ”‘", refreshToken: "πŸ”’", expirationDate: expirationDate, userId: "2")

XCTAssertEqual(
keychainRepository.mockStorage[keychainRepository.formattedKey(for: .accessToken(userId: "2"))],
"πŸ”‘",
)
XCTAssertEqual(
keychainRepository.mockStorage[keychainRepository.formattedKey(for: .refreshToken(userId: "2"))],
"πŸ”’",
)
XCTAssertEqual(stateService.accessTokenExpirationDateByUserId["2"], expirationDate)
}
}
7 changes: 7 additions & 0 deletions TestHelpers/API/MockHTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ public final class MockHTTPClient: HTTPClient {
/// A list of download results that will be returned in order for future requests.
public var downloadResults: [Result<URL, Error>] = []

/// A callback that is invoked when a request is received, before returning the result.
/// Useful for simulating state changes during async operations.
public var onRequest: ((HTTPRequest) -> Void)?

/// A list of requests that have been received by the HTTP client.
public var requests: [HTTPRequest] = []

Expand Down Expand Up @@ -62,6 +66,9 @@ public final class MockHTTPClient: HTTPClient {
public func send(_ request: HTTPRequest) async throws -> HTTPResponse {
requests.append(request)

// Invoke callback to allow tests to simulate state changes during async operations
onRequest?(request)

guard !results.isEmpty else { throw MockHTTPClientError.noResultForRequest }

let result = results.removeFirst()
Expand Down
Loading