Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/api/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ type PAMAccessRequest struct {
AccountName string `json:"accountName,omitempty"`
ProjectId string `json:"projectId,omitempty"`
Comment thread
carlosmonastyrski marked this conversation as resolved.
MfaSessionId string `json:"mfaSessionId,omitempty"`
Comment thread
carlosmonastyrski marked this conversation as resolved.
Reason string `json:"reason,omitempty"`
}
Comment thread
carlosmonastyrski marked this conversation as resolved.

type PAMAccessResponse struct {
Expand Down
38 changes: 38 additions & 0 deletions packages/cmd/pam.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
package cmd

import (
"os"
"time"

pam "github.com/Infisical/infisical-merge/packages/pam/local"
"github.com/Infisical/infisical-merge/packages/util"
"github.com/mattn/go-isatty"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
)

func readReasonFlag(cmd *cobra.Command) string {
reason, _ := cmd.Flags().GetString("reason")
return reason
}

func resolveReason(cmd *cobra.Command) string {
if cmd.Flags().Changed("reason") {
reason, _ := cmd.Flags().GetString("reason")
return reason
}
if !isatty.IsTerminal(os.Stdin.Fd()) {
return ""
}
reason, err := pam.PromptForReason(false)
if err != nil {
return ""
}
return reason
Comment thread
carlosmonastyrski marked this conversation as resolved.
}

var pamCmd = &cobra.Command{
Use: "pam",
Short: "PAM-related commands",
Expand Down Expand Up @@ -72,6 +94,8 @@ var pamDbAccessCmd = &cobra.Command{
util.HandleError(err, "Unable to parse port flag")
}

reason := resolveReason(cmd)

log.Debug().Msg("PAM Database Access: Trying to fetch secrets using logged in details")

loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true)
Expand All @@ -92,6 +116,7 @@ var pamDbAccessCmd = &cobra.Command{
pam.StartDatabaseLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, pam.PAMAccessParams{
ResourceName: resourceName,
AccountName: accountName,
Reason: reason,
}, projectID, durationStr, port)
},
}
Expand Down Expand Up @@ -194,6 +219,8 @@ func runSSHCommand(cmd *cobra.Command, args []string, options pam.SSHAccessOptio
projectID = workspaceFile.WorkspaceId
}

reason := readReasonFlag(cmd)

log.Debug().Msg("PAM SSH: Trying to fetch credentials using logged in details")

loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true)
Expand All @@ -214,6 +241,7 @@ func runSSHCommand(cmd *cobra.Command, args []string, options pam.SSHAccessOptio
pam.StartSSHLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, pam.PAMAccessParams{
ResourceName: resourceName,
AccountName: accountName,
Reason: reason,
}, projectID, durationStr, options)
}

Expand Down Expand Up @@ -273,6 +301,8 @@ var pamKubernetesAccessCmd = &cobra.Command{
projectID = workspaceFile.WorkspaceId
}

reason := resolveReason(cmd)

log.Debug().Msg("PAM Kubernetes Access: Trying to fetch credentials using logged in details")

loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true)
Expand All @@ -293,6 +323,7 @@ var pamKubernetesAccessCmd = &cobra.Command{
pam.StartKubernetesLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, pam.PAMAccessParams{
ResourceName: resourceName,
AccountName: accountName,
Reason: reason,
}, projectID, durationStr, port)
},
}
Expand Down Expand Up @@ -352,6 +383,8 @@ var pamRedisAccessCmd = &cobra.Command{
util.HandleError(err, "Unable to parse port flag")
}

reason := resolveReason(cmd)

log.Debug().Msg("PAM Redis Access: Trying to fetch secrets using logged in details")

loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true)
Expand All @@ -372,6 +405,7 @@ var pamRedisAccessCmd = &cobra.Command{
pam.StartRedisLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, pam.PAMAccessParams{
ResourceName: resourceName,
AccountName: accountName,
Reason: reason,
}, projectID, durationStr, port)
},
}
Expand All @@ -384,6 +418,7 @@ func init() {
pamDbAccessCmd.Flags().String("duration", "1h", "Duration for database access session (e.g., '1h', '30m', '2h30m')")
pamDbAccessCmd.Flags().Int("port", 0, "Port for the local database proxy server (0 for auto-assign)")
pamDbAccessCmd.Flags().String("project-id", "", "Project ID of the account to access")
pamDbAccessCmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)")
pamDbAccessCmd.MarkFlagRequired("resource")
pamDbAccessCmd.MarkFlagRequired("account")

Expand All @@ -393,6 +428,7 @@ func init() {
cmd.Flags().String("account", "", "Name of the account within the resource")
cmd.Flags().String("duration", "1h", "Duration for SSH access session (e.g., '1h', '30m', '2h30m')")
cmd.Flags().String("project-id", "", "Project ID of the account to access")
cmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)")
cmd.MarkFlagRequired("resource")
cmd.MarkFlagRequired("account")
}
Expand All @@ -413,6 +449,7 @@ func init() {
pamKubernetesAccessCmd.Flags().String("duration", "1h", "Duration for kubernetes access session (e.g., '1h', '30m', '2h30m')")
pamKubernetesAccessCmd.Flags().Int("port", 0, "Port for the local kubernetes proxy server (0 for auto-assign)")
pamKubernetesAccessCmd.Flags().String("project-id", "", "Project ID of the account to access")
pamKubernetesAccessCmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)")
pamKubernetesAccessCmd.MarkFlagRequired("resource")
pamKubernetesAccessCmd.MarkFlagRequired("account")

Expand All @@ -423,6 +460,7 @@ func init() {
pamRedisAccessCmd.Flags().String("duration", "1h", "Duration for Redis access session (e.g., '1h', '30m', '2h30m')")
pamRedisAccessCmd.Flags().Int("port", 0, "Port for the local Redis proxy server (0 for auto-assign)")
pamRedisAccessCmd.Flags().String("project-id", "", "Project ID of the account to access")
pamRedisAccessCmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)")
pamRedisAccessCmd.MarkFlagRequired("resource")
pamRedisAccessCmd.MarkFlagRequired("account")

Expand Down
49 changes: 46 additions & 3 deletions packages/pam/local/base-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net"
"os"
"slices"
"strconv"
"strings"
Expand All @@ -21,12 +22,14 @@ import (
"github.com/Infisical/infisical-merge/packages/util"
"github.com/go-resty/resty/v2"
"github.com/manifoldco/promptui"
"github.com/mattn/go-isatty"
"github.com/rs/zerolog/log"
)

type PAMAccessParams struct {
ResourceName string
AccountName string
Reason string
}

// GetDisplayName returns a user-friendly display name for the access params
Expand All @@ -41,6 +44,7 @@ func (p PAMAccessParams) ToAPIRequest(projectID, duration string) api.PAMAccessR
ResourceName: p.ResourceName,
AccountName: p.AccountName,
ProjectId: projectID,
Reason: p.Reason,
}
}

Comment thread
carlosmonastyrski marked this conversation as resolved.
Expand Down Expand Up @@ -313,14 +317,53 @@ func (b *BaseProxyServer) WaitForConnectionsWithTimeout(timeout time.Duration) {
}
}

const reasonRequiredErrorName = "PAM_REASON_REQUIRED"

func PromptForReason(required bool) (string, error) {
label := "Reason for access"
prompt := promptui.Prompt{
Label: label,
Validate: func(input string) error {
if required && strings.TrimSpace(input) == "" {
return fmt.Errorf("a reason is required")
}
return nil
},
}
result, err := prompt.Run()
if err != nil {
return "", err
}
return strings.TrimSpace(result), nil
}

// CallPAMAccessWithMFA attempts to access a PAM account and handles MFA if required
// This is a shared function used by all PAM proxies
func CallPAMAccessWithMFA(httpClient *resty.Client, pamRequest api.PAMAccessRequest) (api.PAMAccessResponse, error) {
func CallPAMAccessWithMFA(
httpClient *resty.Client,
pamRequest api.PAMAccessRequest,
interactive bool,
) (api.PAMAccessResponse, error) {
// Initial request
pamResponse, err := api.CallPAMAccess(httpClient, pamRequest)
if err != nil {
// Check if MFA is required
if apiErr, ok := err.(*api.APIError); ok {
// Reason required by account policy
if apiErr.Name == reasonRequiredErrorName {
if !interactive || !isatty.IsTerminal(os.Stdin.Fd()) {
return api.PAMAccessResponse{}, fmt.Errorf(
"a reason is required to access this account — pass one with --reason")
}
log.Info().Msg("A reason is required to access this account.")
reason, promptErr := PromptForReason(true)
if promptErr != nil {
return api.PAMAccessResponse{}, fmt.Errorf("reason prompt cancelled: %w", promptErr)
}
pamRequest.Reason = reason
return CallPAMAccessWithMFA(httpClient, pamRequest, interactive)
}

// MFA required
if apiErr.Name == "SESSION_MFA_REQUIRED" {
// Extract MFA details from error
if details, ok := apiErr.Details.(map[string]interface{}); ok {
Comment thread
carlosmonastyrski marked this conversation as resolved.
Expand All @@ -347,7 +390,7 @@ func CallPAMAccessWithMFA(httpClient *resty.Client, pamRequest api.PAMAccessRequ
}
}
}
// Return original error if not MFA-related
// Return original error if not MFA/reason-related
return api.PAMAccessResponse{}, err
}

Expand Down
2 changes: 1 addition & 1 deletion packages/pam/local/database-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func StartDatabaseLocalProxy(accessToken string, accessParams PAMAccessParams, p

pamRequest := accessParams.ToAPIRequest(projectID, durationStr)

pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest)
pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest, true)
if err != nil {
if HandleApprovalWorkflow(httpClient, err, projectID, accessParams, durationStr) {
return
Expand Down
2 changes: 1 addition & 1 deletion packages/pam/local/kubernetes-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func StartKubernetesLocalProxy(accessToken string, accessParams PAMAccessParams,

pamRequest := accessParams.ToAPIRequest(projectId, durationStr)

pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest)
pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest, true)
if err != nil {
if HandleApprovalWorkflow(httpClient, err, projectId, accessParams, durationStr) {
return
Expand Down
2 changes: 1 addition & 1 deletion packages/pam/local/redis-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func StartRedisLocalProxy(accessToken string, accessParams PAMAccessParams, proj

pamRequest := accessParams.ToAPIRequest(projectID, durationStr)

pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest)
pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest, true)
if err != nil {
if HandleApprovalWorkflow(httpClient, err, projectID, accessParams, durationStr) {
return
Expand Down
3 changes: 2 additions & 1 deletion packages/pam/local/ssh-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ func StartSSHLocalProxy(accessToken string, accessParams PAMAccessParams, projec

pamRequest := accessParams.ToAPIRequest(projectID, durationStr)

pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest)
interactive := options.ExecCommand == ""
pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest, interactive)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now, for SSH, we are not asking for the optional reason, both on the CLI and in the web. I think we should only disable this optional prompt for SSH EXEC and not for SSH ACCESS.

This will make SSH consistent with other resource types.

if err != nil {
if HandleApprovalWorkflow(httpClient, err, projectID, accessParams, durationStr) {
return
Expand Down
Loading