diff --git a/packages/api/model.go b/packages/api/model.go index d31eac0a..4000da29 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -824,6 +824,7 @@ type PAMAccessRequest struct { AccountName string `json:"accountName,omitempty"` ProjectId string `json:"projectId,omitempty"` MfaSessionId string `json:"mfaSessionId,omitempty"` + Reason string `json:"reason,omitempty"` } type PAMAccessResponse struct { diff --git a/packages/cmd/pam.go b/packages/cmd/pam.go index 2c5b5cb6..aa9c013c 100644 --- a/packages/cmd/pam.go +++ b/packages/cmd/pam.go @@ -1,14 +1,36 @@ package cmd import ( + "os" "time" pam "github.com/Infisical/infisical-merge/packages/pam/local" "github.com/Infisical/infisical-merge/packages/util" + "github.com/mattn/go-isatty" "github.com/rs/zerolog/log" "github.com/spf13/cobra" ) +func readReasonFlag(cmd *cobra.Command) string { + reason, _ := cmd.Flags().GetString("reason") + return reason +} + +func resolveReason(cmd *cobra.Command) string { + if cmd.Flags().Changed("reason") { + reason, _ := cmd.Flags().GetString("reason") + return reason + } + if !isatty.IsTerminal(os.Stdin.Fd()) { + return "" + } + reason, err := pam.PromptForReason(false) + if err != nil { + return "" + } + return reason +} + var pamCmd = &cobra.Command{ Use: "pam", Short: "PAM-related commands", @@ -72,6 +94,8 @@ var pamDbAccessCmd = &cobra.Command{ util.HandleError(err, "Unable to parse port flag") } + reason := resolveReason(cmd) + log.Debug().Msg("PAM Database Access: Trying to fetch secrets using logged in details") loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true) @@ -92,6 +116,7 @@ var pamDbAccessCmd = &cobra.Command{ pam.StartDatabaseLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, pam.PAMAccessParams{ ResourceName: resourceName, AccountName: accountName, + Reason: reason, }, projectID, durationStr, port) }, } @@ -194,6 +219,13 @@ func runSSHCommand(cmd *cobra.Command, args []string, options pam.SSHAccessOptio projectID = workspaceFile.WorkspaceId } + var reason string + if options.ExecCommand != "" { + reason = readReasonFlag(cmd) + } else { + reason = resolveReason(cmd) + } + log.Debug().Msg("PAM SSH: Trying to fetch credentials using logged in details") loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true) @@ -214,6 +246,7 @@ func runSSHCommand(cmd *cobra.Command, args []string, options pam.SSHAccessOptio pam.StartSSHLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, pam.PAMAccessParams{ ResourceName: resourceName, AccountName: accountName, + Reason: reason, }, projectID, durationStr, options) } @@ -273,6 +306,8 @@ var pamKubernetesAccessCmd = &cobra.Command{ projectID = workspaceFile.WorkspaceId } + reason := resolveReason(cmd) + log.Debug().Msg("PAM Kubernetes Access: Trying to fetch credentials using logged in details") loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true) @@ -293,6 +328,7 @@ var pamKubernetesAccessCmd = &cobra.Command{ pam.StartKubernetesLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, pam.PAMAccessParams{ ResourceName: resourceName, AccountName: accountName, + Reason: reason, }, projectID, durationStr, port) }, } @@ -352,6 +388,8 @@ var pamRedisAccessCmd = &cobra.Command{ util.HandleError(err, "Unable to parse port flag") } + reason := resolveReason(cmd) + log.Debug().Msg("PAM Redis Access: Trying to fetch secrets using logged in details") loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true) @@ -372,6 +410,7 @@ var pamRedisAccessCmd = &cobra.Command{ pam.StartRedisLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, pam.PAMAccessParams{ ResourceName: resourceName, AccountName: accountName, + Reason: reason, }, projectID, durationStr, port) }, } @@ -384,6 +423,7 @@ func init() { pamDbAccessCmd.Flags().String("duration", "1h", "Duration for database access session (e.g., '1h', '30m', '2h30m')") pamDbAccessCmd.Flags().Int("port", 0, "Port for the local database proxy server (0 for auto-assign)") pamDbAccessCmd.Flags().String("project-id", "", "Project ID of the account to access") + pamDbAccessCmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)") pamDbAccessCmd.MarkFlagRequired("resource") pamDbAccessCmd.MarkFlagRequired("account") @@ -393,6 +433,7 @@ func init() { cmd.Flags().String("account", "", "Name of the account within the resource") cmd.Flags().String("duration", "1h", "Duration for SSH access session (e.g., '1h', '30m', '2h30m')") cmd.Flags().String("project-id", "", "Project ID of the account to access") + cmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)") cmd.MarkFlagRequired("resource") cmd.MarkFlagRequired("account") } @@ -413,6 +454,7 @@ func init() { pamKubernetesAccessCmd.Flags().String("duration", "1h", "Duration for kubernetes access session (e.g., '1h', '30m', '2h30m')") pamKubernetesAccessCmd.Flags().Int("port", 0, "Port for the local kubernetes proxy server (0 for auto-assign)") pamKubernetesAccessCmd.Flags().String("project-id", "", "Project ID of the account to access") + pamKubernetesAccessCmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)") pamKubernetesAccessCmd.MarkFlagRequired("resource") pamKubernetesAccessCmd.MarkFlagRequired("account") @@ -423,6 +465,7 @@ func init() { pamRedisAccessCmd.Flags().String("duration", "1h", "Duration for Redis access session (e.g., '1h', '30m', '2h30m')") pamRedisAccessCmd.Flags().Int("port", 0, "Port for the local Redis proxy server (0 for auto-assign)") pamRedisAccessCmd.Flags().String("project-id", "", "Project ID of the account to access") + pamRedisAccessCmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)") pamRedisAccessCmd.MarkFlagRequired("resource") pamRedisAccessCmd.MarkFlagRequired("account") diff --git a/packages/pam/local/base-proxy.go b/packages/pam/local/base-proxy.go index 5dac6957..0ef603cd 100644 --- a/packages/pam/local/base-proxy.go +++ b/packages/pam/local/base-proxy.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net" + "os" "slices" "strconv" "strings" @@ -21,12 +22,14 @@ import ( "github.com/Infisical/infisical-merge/packages/util" "github.com/go-resty/resty/v2" "github.com/manifoldco/promptui" + "github.com/mattn/go-isatty" "github.com/rs/zerolog/log" ) type PAMAccessParams struct { ResourceName string AccountName string + Reason string } // GetDisplayName returns a user-friendly display name for the access params @@ -41,6 +44,7 @@ func (p PAMAccessParams) ToAPIRequest(projectID, duration string) api.PAMAccessR ResourceName: p.ResourceName, AccountName: p.AccountName, ProjectId: projectID, + Reason: p.Reason, } } @@ -313,14 +317,53 @@ func (b *BaseProxyServer) WaitForConnectionsWithTimeout(timeout time.Duration) { } } +const reasonRequiredErrorName = "PAM_REASON_REQUIRED" + +func PromptForReason(required bool) (string, error) { + label := "Reason for access" + prompt := promptui.Prompt{ + Label: label, + Validate: func(input string) error { + if required && strings.TrimSpace(input) == "" { + return fmt.Errorf("a reason is required") + } + return nil + }, + } + result, err := prompt.Run() + if err != nil { + return "", err + } + return strings.TrimSpace(result), nil +} + // CallPAMAccessWithMFA attempts to access a PAM account and handles MFA if required // This is a shared function used by all PAM proxies -func CallPAMAccessWithMFA(httpClient *resty.Client, pamRequest api.PAMAccessRequest) (api.PAMAccessResponse, error) { +func CallPAMAccessWithMFA( + httpClient *resty.Client, + pamRequest api.PAMAccessRequest, + interactive bool, +) (api.PAMAccessResponse, error) { // Initial request pamResponse, err := api.CallPAMAccess(httpClient, pamRequest) if err != nil { - // Check if MFA is required if apiErr, ok := err.(*api.APIError); ok { + // Reason required by account policy + if apiErr.Name == reasonRequiredErrorName { + if !interactive || !isatty.IsTerminal(os.Stdin.Fd()) { + return api.PAMAccessResponse{}, fmt.Errorf( + "a reason is required to access this account — pass one with --reason") + } + log.Info().Msg("A reason is required to access this account.") + reason, promptErr := PromptForReason(true) + if promptErr != nil { + return api.PAMAccessResponse{}, fmt.Errorf("reason prompt cancelled: %w", promptErr) + } + pamRequest.Reason = reason + return CallPAMAccessWithMFA(httpClient, pamRequest, interactive) + } + + // MFA required if apiErr.Name == "SESSION_MFA_REQUIRED" { // Extract MFA details from error if details, ok := apiErr.Details.(map[string]interface{}); ok { @@ -347,7 +390,7 @@ func CallPAMAccessWithMFA(httpClient *resty.Client, pamRequest api.PAMAccessRequ } } } - // Return original error if not MFA-related + // Return original error if not MFA/reason-related return api.PAMAccessResponse{}, err } diff --git a/packages/pam/local/database-proxy.go b/packages/pam/local/database-proxy.go index aa1e2450..9f51f4e3 100644 --- a/packages/pam/local/database-proxy.go +++ b/packages/pam/local/database-proxy.go @@ -40,7 +40,7 @@ func StartDatabaseLocalProxy(accessToken string, accessParams PAMAccessParams, p pamRequest := accessParams.ToAPIRequest(projectID, durationStr) - pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest) + pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest, true) if err != nil { if HandleApprovalWorkflow(httpClient, err, projectID, accessParams, durationStr) { return diff --git a/packages/pam/local/kubernetes-proxy.go b/packages/pam/local/kubernetes-proxy.go index c41fa795..0e94fe72 100644 --- a/packages/pam/local/kubernetes-proxy.go +++ b/packages/pam/local/kubernetes-proxy.go @@ -39,7 +39,7 @@ func StartKubernetesLocalProxy(accessToken string, accessParams PAMAccessParams, pamRequest := accessParams.ToAPIRequest(projectId, durationStr) - pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest) + pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest, true) if err != nil { if HandleApprovalWorkflow(httpClient, err, projectId, accessParams, durationStr) { return diff --git a/packages/pam/local/redis-proxy.go b/packages/pam/local/redis-proxy.go index bab8487b..901f1659 100644 --- a/packages/pam/local/redis-proxy.go +++ b/packages/pam/local/redis-proxy.go @@ -32,7 +32,7 @@ func StartRedisLocalProxy(accessToken string, accessParams PAMAccessParams, proj pamRequest := accessParams.ToAPIRequest(projectID, durationStr) - pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest) + pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest, true) if err != nil { if HandleApprovalWorkflow(httpClient, err, projectID, accessParams, durationStr) { return diff --git a/packages/pam/local/ssh-proxy.go b/packages/pam/local/ssh-proxy.go index e41df1e0..efa3ba88 100644 --- a/packages/pam/local/ssh-proxy.go +++ b/packages/pam/local/ssh-proxy.go @@ -41,7 +41,8 @@ func StartSSHLocalProxy(accessToken string, accessParams PAMAccessParams, projec pamRequest := accessParams.ToAPIRequest(projectID, durationStr) - pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest) + interactive := options.ExecCommand == "" + pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest, interactive) if err != nil { if HandleApprovalWorkflow(httpClient, err, projectID, accessParams, durationStr) { return