/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.athena.jdbc.authentication;

import com.amazon.athena.jdbc.authentication.utils.JwtTrustedIdentityProviderUtils;
import com.amazon.athena.jdbc.cache.Cache;
import com.amazon.athena.jdbc.cache.TokenCacheEntry;
import com.amazon.athena.logging.AthenaLogger;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.Optional;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.services.athena.AthenaClient;
import software.amazon.awssdk.services.athena.model.ListTagsForResourceRequest;
import software.amazon.awssdk.services.ssoadmin.SsoAdminClient;
import software.amazon.awssdk.services.ssoadmin.model.Tag;
import software.amazon.awssdk.services.ssooidc.SsoOidcClient;
import software.amazon.awssdk.services.ssooidc.model.CreateTokenWithIamRequest;
import software.amazon.awssdk.services.ssooidc.model.CreateTokenWithIamResponse;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
import software.amazon.awssdk.services.sts.model.AssumeRoleResponse;
import software.amazon.awssdk.services.sts.model.AssumeRoleWithWebIdentityRequest;
import software.amazon.awssdk.services.sts.model.AssumeRoleWithWebIdentityResponse;
import software.amazon.awssdk.services.sts.model.Credentials;
import software.amazon.awssdk.services.sts.model.ProvidedContext;

/**
 * {@link AwsCredentialsProvider} that turns a caller-supplied JWT into identity-enhanced AWS session credentials
 * via IAM Identity Center (IdC) trusted identity propagation.
 * <p>
 * The credential chain performed by {@link #resolveCredentials()} is:
 * <ol>
 *   <li>{@code AssumeRoleWithWebIdentity} on an anonymous STS client, exchanging the JWT for temporary
 *       credentials of the application role.</li>
 *   <li>Resolve the IdC customer application ARN: the explicit value if configured, otherwise the
 *       {@value #CUSTOMER_APPLICATION_ARN_TAG} tag on the Athena workgroup.</li>
 *   <li>Resolve the access role ARN: the explicit value if configured, otherwise the
 *       {@value #ACCESS_ROLE_ARN_TAG} tag on the IdC application.</li>
 *   <li>Obtain an IdC token through {@code CreateTokenWithIAM} (JWT-bearer grant, or refresh-token grant
 *       when a stale cached entry carries a refresh token). The result is stored in the shared driver cache.</li>
 *   <li>{@code AssumeRole} on the access role, passing the IdC identity context as a provided context.</li>
 * </ol>
 * The resulting session credentials are kept in memory and reused until they are within
 * {@link #EXPIRATION_THRESHOLD_SECS} of expiring.
 * <p>
 * Thread-safety: {@link #resolveCredentials()} is synchronized, and it is the only entry point that touches
 * this instance's mutable state.
 */
public class JwtTrustedIdentityCredentialsProvider
implements AwsCredentialsProvider {
    private static final AthenaLogger logger = AthenaLogger.of(JwtTrustedIdentityCredentialsProvider.class);
    /** Safety margin: credentials/tokens expiring within this window are treated as already expired. */
    private static final Duration EXPIRATION_THRESHOLD_SECS = Duration.ofSeconds(180L);
    private static final String GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer";
    private static final String REFRESH_TOKEN_REQUESTED_TYPE = "urn:ietf:params:oauth:token-type:refresh_token";
    private static final String PROVIDER_ARN = "arn:aws:iam::aws:contextProvider/IdentityCenter";
    private static final String CUSTOMER_APPLICATION_ARN_TAG = "AthenaDriverOidcAppArn";
    private static final String ACCESS_ROLE_ARN_TAG = "AthenaDriverOidcAppAccessRoleArn";
    private static final String REFRESH_TOKEN_GRANT_TYPE = "refresh_token";
    private final Clock clock;
    private final Integer roleSessionDuration;
    private final String webIdentityToken;
    private final String applicationRoleArn;
    private final String workgroupArn;
    private final String roleSessionName;
    private final String accessRoleArn;
    private final String customerIdcApplicationArn;
    private final JwtTrustedIdentityProviderUtils jwtTrustedIdentityProviderUtils;
    private final Cache driverCache;
    // Mutable state below is only read/written while holding this instance's monitor (see resolveCredentials).
    /** Last access-role credentials obtained; {@code null} until the first resolution. */
    private Credentials credentials;
    /** Customer IdC application ARN, resolved lazily once and then reused. */
    private String resolvedCustomerIdcApplicationArn;
    /** Access role ARN, resolved lazily once and then reused. */
    private String resolvedAccessRoleArn;

    /**
     * Creates a provider that uses the system default-zone clock.
     *
     * @param webIdentityToken the JWT presented to STS and IdC
     * @param applicationRoleArn role assumed with the JWT via {@code AssumeRoleWithWebIdentity}
     * @param workgroupArn Athena workgroup whose tags are consulted when {@code customerIdcApplicationArn} is null
     * @param customerIdcApplicationArn explicit IdC application ARN, or {@code null} to look it up from workgroup tags
     * @param accessRoleArn explicit access role ARN, or {@code null} to look it up from IdC application tags
     * @param roleSessionName session name used for both STS role assumptions
     * @param roleSessionDuration duration (seconds) requested for the access-role session; may be {@code null}
     * @param jwtTrustedIdentityProviderUtils factory for the AWS service clients
     * @param driverCache cache used to persist IdC tokens across provider instances
     */
    JwtTrustedIdentityCredentialsProvider(String webIdentityToken, String applicationRoleArn, String workgroupArn, String customerIdcApplicationArn, String accessRoleArn, String roleSessionName, Integer roleSessionDuration, JwtTrustedIdentityProviderUtils jwtTrustedIdentityProviderUtils, Cache driverCache) {
        this(webIdentityToken, applicationRoleArn, workgroupArn, customerIdcApplicationArn, accessRoleArn, roleSessionName, roleSessionDuration, jwtTrustedIdentityProviderUtils, Clock.systemDefaultZone(), driverCache);
    }

    /**
     * Creates a provider with an explicit {@link Clock}, which makes the expiry checks testable.
     * The other parameters are as described on the other constructor.
     */
    JwtTrustedIdentityCredentialsProvider(String webIdentityToken, String applicationRoleArn, String workgroupArn, String customerIdcApplicationArn, String accessRoleArn, String roleSessionName, Integer roleSessionDuration, JwtTrustedIdentityProviderUtils jwtTrustedIdentityProviderUtils, Clock clock, Cache driverCache) {
        this.webIdentityToken = webIdentityToken;
        this.applicationRoleArn = applicationRoleArn;
        this.workgroupArn = workgroupArn;
        this.customerIdcApplicationArn = customerIdcApplicationArn;
        this.accessRoleArn = accessRoleArn;
        this.roleSessionName = roleSessionName;
        this.roleSessionDuration = roleSessionDuration;
        this.jwtTrustedIdentityProviderUtils = jwtTrustedIdentityProviderUtils;
        this.clock = clock;
        this.driverCache = driverCache;
    }

    /**
     * Returns session credentials for the access role. The full credential chain is re-run when no credentials
     * are held yet, or when the held ones expire within {@link #EXPIRATION_THRESHOLD_SECS}.
     * <p>
     * Synchronized because SDK clients may call this from several request threads concurrently. Without the lock,
     * concurrent callers could each start a refresh (duplicating STS/IdC calls) and race on the lazily resolved
     * fields. Holding the lock during the refresh deliberately makes other callers wait for the new credentials.
     *
     * @return session credentials (access key, secret key, session token)
     */
    @Override
    public synchronized AwsCredentials resolveCredentials() {
        Instant refreshDeadline = this.clock.instant().plus(EXPIRATION_THRESHOLD_SECS);
        boolean needsUpdate = this.credentials == null || this.credentials.expiration().isBefore(refreshDeadline);
        if (needsUpdate) {
            this.credentials = this.obtainCredentials();
        }
        return AwsSessionCredentials.create(this.credentials.accessKeyId(), this.credentials.secretAccessKey(), this.credentials.sessionToken());
    }

    /** Runs the full credential chain described on the class and returns the access-role credentials. */
    private Credentials obtainCredentials() {
        logger.debug("Obtaining credentials from STS", new Object[0]);
        logger.trace("Sending AssumeRoleWithWebIdentity request with ApplicationRoleArn: {}", this.applicationRoleArn);
        Credentials stsCredentials = this.assumeApplicationRoleWithJWT();
        // The application-role credentials authorize the lookups and the token exchange below.
        this.resolveCustomerIdcApplication(stsCredentials);
        this.resolveAccessRoleArn(stsCredentials);
        TokenCacheEntry tokenCacheEntry = this.resolveIDCToken(stsCredentials);
        return this.getStsIdentityEnhancedAccessRoleCredentials(stsCredentials, tokenCacheEntry.getStsIdentityContext());
    }

    /** Exchanges the JWT for application-role credentials via {@code AssumeRoleWithWebIdentity} on an anonymous STS client. */
    private Credentials assumeApplicationRoleWithJWT() {
        try (StsClient stsClient = this.jwtTrustedIdentityProviderUtils.createAnonymousStsClient()) {
            AssumeRoleWithWebIdentityRequest request = (AssumeRoleWithWebIdentityRequest)AssumeRoleWithWebIdentityRequest.builder().webIdentityToken(this.webIdentityToken).roleArn(this.applicationRoleArn).roleSessionName(this.roleSessionName).build();
            AssumeRoleWithWebIdentityResponse response = stsClient.assumeRoleWithWebIdentity(request);
            logger.info("Obtained application role temporary credentials from STS.", new Object[0]);
            return response.credentials();
        }
    }

    /**
     * Sets {@link #resolvedCustomerIdcApplicationArn} if it is not already set. Uses the configured ARN when
     * present; otherwise reads the {@value #CUSTOMER_APPLICATION_ARN_TAG} tag from the Athena workgroup.
     *
     * @throws IllegalArgumentException if the workgroup carries no such tag
     */
    private void resolveCustomerIdcApplication(Credentials stsCredentials) {
        if (this.resolvedCustomerIdcApplicationArn != null) {
            return;
        }
        if (this.customerIdcApplicationArn != null) {
            this.resolvedCustomerIdcApplicationArn = this.customerIdcApplicationArn;
            return;
        }
        logger.info("No IDC application ARN specified, attempting to find it by listing the Workgroup Tags", new Object[0]);
        try (AthenaClient athenaClient = this.jwtTrustedIdentityProviderUtils.createAthenaClient(stsCredentials)) {
            ListTagsForResourceRequest listTagsForResourceRequest = (ListTagsForResourceRequest)ListTagsForResourceRequest.builder().resourceARN(this.workgroupArn).build();
            Optional<String> tagValue = athenaClient.listTagsForResourcePaginator(listTagsForResourceRequest).stream().flatMap(page -> page.tags().stream()).filter(tag -> tag.key().equals(CUSTOMER_APPLICATION_ARN_TAG)).findFirst().map(software.amazon.awssdk.services.athena.model.Tag::value);
            tagValue.ifPresent(value -> logger.info("Successfully found the customer application ARN: {} from Athena Workgroup", value));
            this.resolvedCustomerIdcApplicationArn = tagValue.orElseThrow(() -> new IllegalArgumentException("Unable to find the customer application ARN from the workgroup tags"));
        }
    }

    /**
     * Sets {@link #resolvedAccessRoleArn} if it is not already set. Uses the configured ARN when present;
     * otherwise reads the {@value #ACCESS_ROLE_ARN_TAG} tag from the resolved IdC application.
     * Must run after {@link #resolveCustomerIdcApplication(Credentials)}.
     *
     * @throws IllegalArgumentException if the IdC application carries no such tag
     */
    private void resolveAccessRoleArn(Credentials stsCredentials) {
        if (this.resolvedAccessRoleArn != null) {
            return;
        }
        if (this.accessRoleArn != null) {
            this.resolvedAccessRoleArn = this.accessRoleArn;
            return;
        }
        logger.info("No access role ARN specified, attempting to find it by listing the tags on the IdC application", new Object[0]);
        try (SsoAdminClient ssoAdminClient = this.jwtTrustedIdentityProviderUtils.createSsoAdminClient(stsCredentials)) {
            software.amazon.awssdk.services.ssoadmin.model.ListTagsForResourceRequest listTagsForResourceRequest = (software.amazon.awssdk.services.ssoadmin.model.ListTagsForResourceRequest)software.amazon.awssdk.services.ssoadmin.model.ListTagsForResourceRequest.builder().resourceArn(this.resolvedCustomerIdcApplicationArn).build();
            Optional<String> tagValue = ssoAdminClient.listTagsForResourcePaginator(listTagsForResourceRequest).stream().flatMap(page -> page.tags().stream()).filter(tag -> tag.key().equals(ACCESS_ROLE_ARN_TAG)).findFirst().map(Tag::value);
            tagValue.ifPresent(value -> logger.info("Found an access role ARN from IDC application tags: {}", value));
            this.resolvedAccessRoleArn = tagValue.orElseThrow(() -> new IllegalArgumentException("Unable to find the access role ARN from the IdC application tags"));
        }
    }

    /**
     * Returns an IdC token entry, preferring a still-valid cached one. A stale cached entry is refreshed with its
     * refresh token (or re-fetched with the JWT when it has none). A missing entry is fetched with the JWT.
     * Any newly obtained entry is written back to {@link #driverCache}.
     */
    private TokenCacheEntry resolveIDCToken(Credentials stsCredentials) {
        String cacheKey = this.buildTokenCacheKey();
        Optional<?> cacheEntry = this.driverCache.get(cacheKey);
        String refreshToken = null;
        if (cacheEntry.isPresent()) {
            TokenCacheEntry cachedEntry = (TokenCacheEntry)cacheEntry.get();
            if (!cachedEntry.isExpired(EXPIRATION_THRESHOLD_SECS)) {
                logger.debug("Cached IDC token found.", new Object[0]);
                return cachedEntry;
            }
            logger.debug("The cached token is either expired or about to expire, refreshing it.", new Object[0]);
            refreshToken = cachedEntry.getRefreshToken();
        } else {
            logger.debug("No cached entry found, trying to fetch a token.", new Object[0]);
        }
        TokenCacheEntry updatedCacheEntry = this.getOrRefreshIdcToken(stsCredentials, refreshToken);
        this.driverCache.store(cacheKey, updatedCacheEntry);
        return updatedCacheEntry;
    }

    /**
     * Builds the driver-cache key for the IdC token. The key is the Base64 encoding of
     * {@code <applicationArn>_<jwt>_<accessRoleArn>}. UTF-8 is used explicitly so the key does not depend on the
     * platform default charset.
     * NOTE(review): the key embeds the raw JWT, only Base64-encoded; confirm this is acceptable if the cache is persisted.
     */
    private String buildTokenCacheKey() {
        String formattedCacheKey = String.format("%s_%s_%s", this.resolvedCustomerIdcApplicationArn, this.webIdentityToken, this.resolvedAccessRoleArn);
        return Base64.getEncoder().encodeToString(formattedCacheKey.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Calls {@code CreateTokenWithIAM}. It uses the refresh-token grant when {@code refreshToken} is non-null,
     * and otherwise the JWT-bearer grant with the web identity token as the assertion.
     *
     * @param refreshToken refresh token from a stale cache entry, or {@code null} for a fresh exchange
     * @return a cache entry holding the new tokens, an absolute expiration computed from {@code expiresIn},
     *         and the STS identity context
     */
    private TokenCacheEntry getOrRefreshIdcToken(Credentials stsCredentials, String refreshToken) {
        try (SsoOidcClient ssoOidcClient = this.jwtTrustedIdentityProviderUtils.createSsoOidcClient(stsCredentials)) {
            CreateTokenWithIamRequest.Builder createTokenWithIamRequestBuilder = CreateTokenWithIamRequest.builder().clientId(this.resolvedCustomerIdcApplicationArn);
            if (refreshToken != null) {
                createTokenWithIamRequestBuilder.grantType(REFRESH_TOKEN_GRANT_TYPE).requestedTokenType(REFRESH_TOKEN_REQUESTED_TYPE).refreshToken(refreshToken);
            } else {
                createTokenWithIamRequestBuilder.grantType(GRANT_TYPE).assertion(this.webIdentityToken);
            }
            CreateTokenWithIamResponse createTokenWithIamResponse = ssoOidcClient.createTokenWithIAM((CreateTokenWithIamRequest)createTokenWithIamRequestBuilder.build());
            logger.info("Successfully created a token after calling IDC.", new Object[0]);
            // expiresIn is relative (seconds); convert it to an absolute instant so the cached entry can be checked later.
            Instant idcTokenExpirationTime = this.clock.instant().plusSeconds(createTokenWithIamResponse.expiresIn().intValue());
            logger.debug("Retrieved IDC token expiration time: {}", idcTokenExpirationTime);
            return TokenCacheEntry.builder().refreshToken(createTokenWithIamResponse.refreshToken()).idToken(createTokenWithIamResponse.idToken()).expiration(idcTokenExpirationTime).stsIdentityContext(createTokenWithIamResponse.awsAdditionalDetails().identityContext()).build();
        }
    }

    /**
     * Assumes the resolved access role with the IdC identity context attached as a provided context. This yields
     * identity-enhanced credentials.
     *
     * @param stsCredentials application-role credentials used to call STS
     * @param stsIdentityContext identity context returned by IdC's {@code CreateTokenWithIAM}
     */
    private Credentials getStsIdentityEnhancedAccessRoleCredentials(Credentials stsCredentials, String stsIdentityContext) {
        try (StsClient stsClient = this.jwtTrustedIdentityProviderUtils.createStsClient(stsCredentials)) {
            ProvidedContext providedContext = (ProvidedContext)ProvidedContext.builder().providerArn(PROVIDER_ARN).contextAssertion(stsIdentityContext).build();
            AssumeRoleRequest assumeRoleRequest = (AssumeRoleRequest)AssumeRoleRequest.builder().roleArn(this.resolvedAccessRoleArn).roleSessionName(this.roleSessionName).durationSeconds(this.roleSessionDuration).providedContexts(providedContext).build();
            AssumeRoleResponse assumeRoleResponse = stsClient.assumeRole(assumeRoleRequest);
            logger.info("Successfully obtained Identity Enhanced credentials with access role from STS", new Object[0]);
            return assumeRoleResponse.credentials();
        }
    }
}

