Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package com.cornellappdev.uplift.data.auth

import com.apollographql.apollo.ApolloClient
import com.apollographql.apollo.api.ApolloRequest
import com.apollographql.apollo.api.ApolloResponse
import com.apollographql.apollo.api.ExecutionContext
import com.apollographql.apollo.api.Operation
import com.apollographql.apollo.interceptor.ApolloInterceptor
import com.apollographql.apollo.interceptor.ApolloInterceptorChain
import com.cornellappdev.uplift.RefreshAccessTokenMutation
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.emitAll
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flow
import javax.inject.Inject
import javax.inject.Named
import javax.inject.Singleton

/**
* Execution context to track retries.
*/
internal class RetryContext(val retryCount: Int) : ExecutionContext.Element {
override val key: ExecutionContext.Key<*> = Key

companion object Key : ExecutionContext.Key<RetryContext>
}

/**
* An Apollo Interceptor that handles token expiration errors that return as 200 OK with
* GraphQL errors (specifically "Signature has expired").
*/
@Singleton
class ApolloAuthInterceptor @Inject constructor(
private val tokenManager: TokenManager,
private val sessionManager: SessionManager,
@Named("refresh") private val refreshClient: ApolloClient
) : ApolloInterceptor {
override fun <D : Operation.Data> intercept(
request: ApolloRequest<D>,
chain: ApolloInterceptorChain
): Flow<ApolloResponse<D>> = flow {
val response = chain.proceed(request).first()

val retryCount = request.executionContext[RetryContext]?.retryCount ?: 0
// TODO: replace string check with explicit error codes if backend implements
if (response.errors?.any { it.message.contains("Signature has expired") } == true && retryCount < 1) {
val refreshToken = tokenManager.getRefreshToken()
if (refreshToken != null) {
try {
val mutationResponse = refreshClient.mutation(RefreshAccessTokenMutation())
.addHttpHeader("Authorization", "Bearer $refreshToken")
.execute()

val newAccessToken = mutationResponse.data?.refreshAccessToken?.newAccessToken
if (newAccessToken != null) {
tokenManager.saveTokens(newAccessToken, refreshToken)
// Retry the request with the new token
val newRequest = request.newBuilder()
.addExecutionContext(RetryContext(retryCount + 1))
.build()
emitAll(chain.proceed(newRequest))
return@flow
} else {
sessionManager.logout()
}
} catch (e: Exception) {
sessionManager.logout()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
}
} else {
sessionManager.logout()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
}
}

emit(response)
}
}
7 changes: 6 additions & 1 deletion app/src/main/java/com/cornellappdev/uplift/di/AppModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.cornellappdev.uplift.di
import com.apollographql.apollo.ApolloClient
import com.apollographql.apollo.network.okHttpClient
import com.cornellappdev.uplift.BuildConfig
import com.cornellappdev.uplift.data.auth.ApolloAuthInterceptor
import com.cornellappdev.uplift.data.auth.AuthInterceptor
import com.cornellappdev.uplift.data.auth.TokenAuthenticator
import dagger.Module
Expand Down Expand Up @@ -67,10 +68,14 @@ object AppModule {
@Provides
@Singleton
@Named("main")
fun provideApolloClient(@Named("main") okHttpClient: OkHttpClient): ApolloClient {
fun provideApolloClient(
@Named("main") okHttpClient: OkHttpClient,
apolloAuthInterceptor: ApolloAuthInterceptor
): ApolloClient {
return ApolloClient.Builder()
.serverUrl(BuildConfig.BACKEND_URL)
.okHttpClient(okHttpClient)
.addInterceptor(apolloAuthInterceptor)
.build()
}

Expand Down
Loading