Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 48 additions & 22 deletions pkg/providers/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,16 +249,21 @@ func New(block schema.OptionBlock) (*Provider, error) {

if err != nil && options.AssumeRoleName != "" && len(options.AccountIds) > 0 {
// Base user doesn't have DescribeRegions permission, try with assumed role
tempSession, err := createAssumedRoleSession(options, sess, config)
if err != nil {
return nil, errors.Wrap(err, "could not create assumed role session")
var regionErr error
for _, accountId := range options.AccountIds {
tempSession, err := createAssumedRoleSession(options, sess, config, accountId)
if err != nil {
regionErr = err
continue
}
tempRC := ec2.New(tempSession)
regions, regionErr = tempRC.DescribeRegions(&ec2.DescribeRegionsInput{})
if regionErr == nil {
break
}
}

// Use assumed role session for DescribeRegions
tempRC := ec2.New(tempSession)
regions, err = tempRC.DescribeRegions(&ec2.DescribeRegionsInput{})
if err != nil {
return nil, errors.Wrap(err, "could not get list of regions even with assumed role")
if regionErr != nil {
return nil, errors.Wrap(regionErr, "could not get list of regions with any account")
}
} else if err != nil {
return nil, errors.Wrap(err, "could not get list of regions")
Expand All @@ -270,12 +275,9 @@ func New(block schema.OptionBlock) (*Provider, error) {
return provider, nil
}

func createAssumedRoleSession(options *ProviderOptions, sess *session.Session, config *aws.Config) (*session.Session, error) {
if len(options.AccountIds) == 0 {
return nil, errors.New("no account IDs provided for assume role")
}
func createAssumedRoleSession(options *ProviderOptions, sess *session.Session, config *aws.Config, accountId string) (*session.Session, error) {
stsClient := sts.New(sess)
roleArn := fmt.Sprintf("arn:aws:iam::%s:role/%s", options.AccountIds[0], options.AssumeRoleName)
roleArn := fmt.Sprintf("arn:aws:iam::%s:role/%s", accountId, options.AssumeRoleName)

roleInput := &sts.AssumeRoleInput{
RoleArn: aws.String(roleArn),
Expand Down Expand Up @@ -524,14 +526,38 @@ func (p *Provider) Verify(ctx context.Context) error {
}

if p.options.AssumeRoleName != "" && len(p.options.AccountIds) > 0 {
tempSession, err := createAssumedRoleSession(p.options, p.session, p.session.Config)
if err != nil {
return err
}
p.initServices(tempSession)
err = p.verify()
if err != nil {
return err
var mu sync.Mutex
var failedAccounts []string
var wg sync.WaitGroup

for _, accountId := range p.options.AccountIds {
wg.Add(1)
go func(id string) {
defer wg.Done()
tempSession, err := createAssumedRoleSession(p.options, p.session, p.session.Config, id)
if err != nil {
mu.Lock()
failedAccounts = append(failedAccounts, id)
mu.Unlock()
return
}
tempProvider := &Provider{options: p.options, session: tempSession}
tempProvider.initServices(tempSession)
if err := tempProvider.verify(); err != nil {
mu.Lock()
failedAccounts = append(failedAccounts, id)
mu.Unlock()
}
}(accountId)
}
wg.Wait()

if len(failedAccounts) > 0 {
msg := fmt.Sprintf("failed to assume role %s in accounts: %s", p.options.AssumeRoleName, strings.Join(failedAccounts, ", "))
if p.options.OrgDiscoveryRoleArn != "" {
msg += ". Add these to exclude_account_ids if they should not be part of discovery"
}
return errors.New(msg)
}
return nil
}
Expand Down
Loading