15 KiB
15 KiB
Spring Boot Integration with AWS KMS
Configuration
Basic Configuration
/**
 * Minimal KMS wiring: registers a blocking and a non-blocking KMS client bean,
 * both targeting the us-east-1 region.
 */
@Configuration
public class KmsConfiguration {

    /** Region shared by every client created here. */
    private static final Region KMS_REGION = Region.US_EAST_1;

    /** Synchronous KMS client bound to {@link #KMS_REGION}. */
    @Bean
    public KmsClient kmsClient() {
        var builder = KmsClient.builder();
        builder.region(KMS_REGION);
        return builder.build();
    }

    /** Asynchronous KMS client bound to {@link #KMS_REGION}. */
    @Bean
    public KmsAsyncClient kmsAsyncClient() {
        var builder = KmsAsyncClient.builder();
        builder.region(KMS_REGION);
        return builder.build();
    }
}
Configuration with Custom Settings
/**
 * Builds a {@link KmsClient} from the {@code aws.kms.*} properties.
 *
 * <p>Properties: {@code region} (default us-east-1), {@code endpoint} (optional endpoint
 * override), {@code timeout} (limit for each API call including retries, default 10s), and an
 * optional {@code access-key-id}/{@code secret-access-key} pair. Blank values are treated as
 * unset, so entries such as {@code aws.kms.endpoint=} do not break client creation.
 */
@Configuration
@ConfigurationProperties(prefix = "aws.kms")
public class KmsAdvancedConfiguration {
    private Region region = Region.US_EAST_1;
    private String endpoint;
    private Duration timeout = Duration.ofSeconds(10);
    private String accessKeyId;
    private String secretAccessKey;

    /**
     * Creates the KMS client with up to 3 retries and the configured call timeout.
     *
     * @return a client using the configured region, optional endpoint and optional static credentials
     */
    @Bean
    public KmsClient kmsClient() {
        KmsClientBuilder builder = KmsClient.builder()
                .region(region)
                .overrideConfiguration(c -> {
                    c.retryPolicy(RetryPolicy.builder()
                            .numRetries(3)
                            .build());
                    // Bound the total duration of each API call (all attempts) by aws.kms.timeout.
                    if (timeout != null) {
                        c.apiCallTimeout(timeout);
                    }
                });
        // An empty string would become an invalid endpoint URI, so only non-blank values apply.
        if (hasText(endpoint)) {
            builder.endpointOverride(URI.create(endpoint));
        }
        // Static credentials only when both halves are set; otherwise the SDK's default
        // credentials provider chain is used.
        if (hasText(accessKeyId) && hasText(secretAccessKey)) {
            builder.credentialsProvider(StaticCredentialsProvider.create(
                    AwsBasicCredentials.create(accessKeyId, secretAccessKey)));
        }
        return builder.build();
    }

    /** Returns true when {@code value} is non-null and contains a non-whitespace character. */
    private static boolean hasText(String value) {
        return value != null && !value.isBlank();
    }

    // Getters and Setters
    public Region getRegion() { return region; }
    public void setRegion(Region region) { this.region = region; }
    public String getEndpoint() { return endpoint; }
    public void setEndpoint(String endpoint) { this.endpoint = endpoint; }
    public Duration getTimeout() { return timeout; }
    public void setTimeout(Duration timeout) { this.timeout = timeout; }
    public String getAccessKeyId() { return accessKeyId; }
    public void setAccessKeyId(String accessKeyId) { this.accessKeyId = accessKeyId; }
    public String getSecretAccessKey() { return secretAccessKey; }
    public void setSecretAccessKey(String secretAccessKey) { this.secretAccessKey = secretAccessKey; }
}
Application Properties
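# AWS KMS Configuration
aws.kms.region=us-east-1
# Optional endpoint override (e.g. a local KMS emulator); keep commented out to use the regional AWS endpoint
#aws.kms.endpoint=
aws.kms.timeout=10s
# Optional static credentials; keep commented out to use the AWS SDK default credentials provider chain
#aws.kms.access-key-id=
#aws.kms.secret-access-key=
# KMS Key Configuration
kms.encryption-key-id=alias/your-encryption-key
kms.signing-key-id=alias/your-signing-key
# Master key used by the envelope encryption service to generate and decrypt data keys
kms.master-key-id=alias/your-master-key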
Encryption Service
Basic Encryption Service
/**
 * Encrypts and decrypts short strings directly with KMS under {@code kms.encryption-key-id};
 * ciphertext is exchanged as Base64 text.
 */
@Service
public class KmsEncryptionService {
    private final KmsClient kmsClient;

    @Value("${kms.encryption-key-id}")
    private String keyId;

    public KmsEncryptionService(KmsClient kmsClient) {
        this.kmsClient = kmsClient;
    }

    /**
     * Encrypts UTF-8 text under the configured key.
     *
     * @param plaintext text to protect
     * @return the KMS ciphertext blob, Base64-encoded
     * @throws RuntimeException wrapping the {@link KmsException} when the KMS call fails
     */
    public String encrypt(String plaintext) {
        EncryptRequest encryptRequest = EncryptRequest.builder()
                .keyId(keyId)
                .plaintext(SdkBytes.fromString(plaintext, StandardCharsets.UTF_8))
                .build();
        byte[] blob;
        try {
            blob = kmsClient.encrypt(encryptRequest).ciphertextBlob().asByteArray();
        } catch (KmsException kmsError) {
            throw new RuntimeException("Encryption failed", kmsError);
        }
        return Base64.getEncoder().encodeToString(blob);
    }

    /**
     * Reverses {@link #encrypt(String)}.
     *
     * @param ciphertextBase64 Base64 text produced by {@code encrypt}
     * @return the decrypted text, decoded as UTF-8
     * @throws RuntimeException wrapping the {@link KmsException} when the KMS call fails
     */
    public String decrypt(String ciphertextBase64) {
        byte[] blob = Base64.getDecoder().decode(ciphertextBase64);
        DecryptRequest decryptRequest = DecryptRequest.builder()
                .ciphertextBlob(SdkBytes.fromByteArray(blob))
                .build();
        SdkBytes clear;
        try {
            clear = kmsClient.decrypt(decryptRequest).plaintext();
        } catch (KmsException kmsError) {
            throw new RuntimeException("Decryption failed", kmsError);
        }
        return clear.asString(StandardCharsets.UTF_8);
    }
}
Secure Data Repository
/**
 * JDBC access to the {@code secure_data} table. Values are encrypted with
 * {@link KmsEncryptionService} before insert and decrypted after select.
 */
@Repository
public class SecureDataRepository {
    private static final String INSERT_SQL =
            "INSERT INTO secure_data (id, encrypted_value) VALUES (?, ?)";
    private static final String SELECT_SQL =
            "SELECT encrypted_value FROM secure_data WHERE id = ?";

    private final KmsEncryptionService encryptionService;
    private final JdbcTemplate jdbcTemplate;

    public SecureDataRepository(KmsEncryptionService encryptionService,
                                JdbcTemplate jdbcTemplate) {
        this.encryptionService = encryptionService;
        this.jdbcTemplate = jdbcTemplate;
    }

    /** Encrypts {@code sensitiveData} and inserts it as a new row keyed by {@code id}. */
    public void saveSecureData(String id, String sensitiveData) {
        jdbcTemplate.update(INSERT_SQL, id, encryptionService.encrypt(sensitiveData));
    }

    /**
     * Loads and decrypts the value stored for {@code id}. Spring's {@code queryForObject}
     * throws if the query does not return exactly one row.
     */
    public String getSecureData(String id) {
        return encryptionService.decrypt(
                jdbcTemplate.queryForObject(SELECT_SQL, String.class, id));
    }
}
Advanced Envelope Encryption Service (requires the `kms.master-key-id` property: the KMS key that generates and protects the data keys)
/**
 * Envelope encryption for payloads too large for a direct KMS Encrypt call.
 *
 * <p>Data is encrypted locally with AES-GCM under a data key that KMS generates
 * ({@code AES_256} spec) from {@code kms.master-key-id}. Only the KMS-encrypted copy of the data
 * key is stored with the ciphertext in {@link EncryptedEnvelope}.
 *
 * <p>Caching: the current data key (plaintext and encrypted forms) stays in memory for up to one
 * hour, so consecutive encryptions reuse it instead of calling KMS each time. Data keys recovered
 * through KMS Decrypt are cached separately for the same duration, so decrypting old envelopes
 * never replaces the key used for new encryptions. Cached arrays are never handed to the cipher
 * or zeroed. Each operation works on a private copy that is wiped in a {@code finally} block.
 *
 * <p>{@code encryptedData} layout: a random 12-byte IV followed by the AES-GCM ciphertext with
 * its 128-bit authentication tag appended.
 */
@Service
public class EnvelopeEncryptionService {
    /** 96-bit IV, the length recommended for GCM. */
    private static final int GCM_IV_LENGTH = 12;
    /** GCM authentication tag length, in bits. */
    private static final int GCM_TAG_BITS = 128;
    /** Source of per-message IVs; a GCM IV must never repeat under the same key. */
    private static final java.security.SecureRandom SECURE_RANDOM = new java.security.SecureRandom();

    private final KmsClient kmsClient;
    @Value("${kms.master-key-id}")
    private String masterKeyId;
    /** Current data key for encryption, keyed by master key id. */
    private final Cache<String, DataKeyPair> keyCache =
            Caffeine.newBuilder()
                    .expireAfterWrite(1, TimeUnit.HOURS)
                    .maximumSize(100)
                    .build();
    /** Data keys recovered via KMS Decrypt, keyed by the Base64 form of their encrypted blob. */
    private final Cache<String, DataKeyPair> decryptedKeyCache =
            Caffeine.newBuilder()
                    .expireAfterWrite(1, TimeUnit.HOURS)
                    .maximumSize(100)
                    .build();

    public EnvelopeEncryptionService(KmsClient kmsClient) {
        this.kmsClient = kmsClient;
    }

    /**
     * Encrypts {@code data} under the current (possibly cached) data key.
     *
     * @param data bytes to encrypt
     * @return the IV-prefixed ciphertext plus the KMS-encrypted data key
     * @throws RuntimeException wrapping the JCE failure; KMS errors propagate unchanged
     */
    public EncryptedEnvelope encryptLargeData(byte[] data) {
        DataKeyPair dataKeyPair = keyCache.get(masterKeyId, this::generateDataKey);
        // Work on a copy so wiping it cannot corrupt the cached key.
        byte[] key = dataKeyPair.plaintext().clone();
        try {
            byte[] encryptedData = encryptWithAES(data, key);
            return new EncryptedEnvelope(encryptedData, dataKeyPair.encrypted().clone());
        } catch (java.security.GeneralSecurityException e) {
            throw new RuntimeException("Envelope encryption failed", e);
        } finally {
            Arrays.fill(key, (byte) 0);
        }
    }

    /**
     * Decrypts an envelope produced by {@link #encryptLargeData(byte[])}.
     *
     * @param envelope ciphertext and KMS-encrypted data key
     * @return the original bytes
     * @throws RuntimeException wrapping the JCE failure (including tag mismatch or truncated
     *     input); KMS errors propagate unchanged
     */
    public byte[] decryptLargeData(EncryptedEnvelope envelope) {
        byte[] key = resolvePlaintextKey(envelope.encryptedKey());
        try {
            return decryptWithAES(envelope.encryptedData(), key);
        } catch (java.security.GeneralSecurityException e) {
            throw new RuntimeException("Envelope decryption failed", e);
        } finally {
            Arrays.fill(key, (byte) 0);
        }
    }

    /** Asks KMS for a fresh AES-256 data key under {@code keyId}. */
    private DataKeyPair generateDataKey(String keyId) {
        GenerateDataKeyResponse dataKeyResponse = kmsClient.generateDataKey(
                GenerateDataKeyRequest.builder()
                        .keyId(keyId)
                        .keySpec(DataKeySpec.AES_256)
                        .build());
        return new DataKeyPair(
                dataKeyResponse.plaintext().asByteArray(),
                dataKeyResponse.ciphertextBlob().asByteArray());
    }

    /**
     * Returns a private copy of the plaintext data key for {@code encryptedKey}. The copy comes
     * from the current encryption key if it matches, otherwise from the decrypted-key cache,
     * which calls KMS on a miss. The caller must wipe the returned array.
     */
    private byte[] resolvePlaintextKey(byte[] encryptedKey) {
        DataKeyPair current = keyCache.getIfPresent(masterKeyId);
        if (current != null && Arrays.equals(current.encrypted(), encryptedKey)) {
            return current.plaintext().clone();
        }
        String cacheKey = java.util.Base64.getEncoder().encodeToString(encryptedKey);
        DataKeyPair recovered = decryptedKeyCache.get(cacheKey, ignored -> decryptDataKey(encryptedKey));
        return recovered.plaintext().clone();
    }

    /** Decrypts an encrypted data key with KMS. */
    private DataKeyPair decryptDataKey(byte[] encryptedKey) {
        DecryptResponse decryptResponse = kmsClient.decrypt(
                DecryptRequest.builder()
                        .ciphertextBlob(SdkBytes.fromByteArray(encryptedKey))
                        .build());
        return new DataKeyPair(decryptResponse.plaintext().asByteArray(), encryptedKey.clone());
    }

    /** AES-GCM encrypts {@code data} with a fresh random IV and returns {@code IV || ciphertext+tag}. */
    private byte[] encryptWithAES(byte[] data, byte[] key) throws java.security.GeneralSecurityException {
        byte[] iv = new byte[GCM_IV_LENGTH];
        SECURE_RANDOM.nextBytes(iv);
        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(key, "AES"),
                new GCMParameterSpec(GCM_TAG_BITS, iv));
        byte[] sealed = cipher.doFinal(data);
        byte[] out = new byte[GCM_IV_LENGTH + sealed.length];
        System.arraycopy(iv, 0, out, 0, GCM_IV_LENGTH);
        System.arraycopy(sealed, 0, out, GCM_IV_LENGTH, sealed.length);
        return out;
    }

    /** Reverses {@link #encryptWithAES}: reads the IV prefix, then decrypts and verifies the rest. */
    private byte[] decryptWithAES(byte[] data, byte[] key) throws java.security.GeneralSecurityException {
        if (data == null || data.length < GCM_IV_LENGTH + GCM_TAG_BITS / 8) {
            throw new java.security.GeneralSecurityException("Ciphertext too short to contain IV and tag");
        }
        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(key, "AES"),
                new GCMParameterSpec(GCM_TAG_BITS, data, 0, GCM_IV_LENGTH));
        return cipher.doFinal(data, GCM_IV_LENGTH, data.length - GCM_IV_LENGTH);
    }

    public record DataKeyPair(byte[] plaintext, byte[] encrypted) {}
    public record EncryptedEnvelope(byte[] encryptedData, byte[] encryptedKey) {}
}
Data Encryption Interceptor
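SQL Encryption Interceptor (a sketch: `StatementInterceptor` stands in for your JDBC driver's or ORM's interception hook, and `EncryptingResultSetWrapper` must still delegate every remaining `ResultSet` method)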
/**
 * Interceptor that hands callers an {@link EncryptingResultSetWrapper} around each result set,
 * so reads of encrypted columns go through {@link KmsEncryptionService}.
 *
 * <p>NOTE(review): {@code StatementInterceptor} is not defined in this snippet; the two
 * overridden signatures are assumed to match it. Confirm against the actual interface.
 */
public class KmsDataEncryptInterceptor implements StatementInterceptor {
    private final KmsEncryptionService encryptionService;

    public KmsDataEncryptInterceptor(KmsEncryptionService service) {
        this.encryptionService = service;
    }

    /** Wraps {@code resultSet}; the statement and connection are not inspected. */
    @Override
    public ResultSet intercept(ResultSet resultSet, Statement statement, Connection connection) throws SQLException {
        final ResultSet decrypting = new EncryptingResultSetWrapper(resultSet, encryptionService);
        return decrypting;
    }

    /** Nothing to do once the statement has executed. */
    @Override
    public void interceptAfterExecution(Statement statement) {
        // Intentionally empty.
    }
}
/**
 * {@link ResultSet} decorator that decrypts values of encrypted columns on read.
 *
 * <p>Only {@link #getString(String)} decrypts. This is a sketch: every remaining
 * {@code ResultSet} method still has to be implemented by delegating to {@code delegate}.
 */
class EncryptingResultSetWrapper implements ResultSet {
    private final ResultSet delegate;
    private final KmsEncryptionService encryptionService;

    public EncryptingResultSetWrapper(ResultSet delegate, KmsEncryptionService encryptionService) {
        this.delegate = delegate;
        this.encryptionService = encryptionService;
    }

    /**
     * Returns the column value, decrypted when {@code columnLabel} names an encrypted column.
     * SQL NULL stays {@code null}.
     */
    @Override
    public String getString(String columnLabel) throws SQLException {
        String value = delegate.getString(columnLabel);
        if (value == null) {
            return null;
        }
        return isEncryptedColumn(columnLabel) ? encryptionService.decrypt(value) : value;
    }

    /**
     * Heuristic: a column is encrypted when its label contains "encrypted" or "secure".
     * Matching ignores case because JDBC column labels are case-insensitive, so
     * {@code ENCRYPTED_VALUE} and {@code encrypted_value} address the same column and must be
     * treated alike. {@code Locale.ROOT} avoids locale-specific lower-casing.
     */
    private boolean isEncryptedColumn(String columnLabel) {
        String normalized = columnLabel.toLowerCase(java.util.Locale.ROOT);
        return normalized.contains("encrypted") || normalized.contains("secure");
    }

    // Delegate other methods to original ResultSet
    @Override
    public boolean next() throws SQLException {
        return delegate.next();
    }

    // ... other ResultSet method implementations
}
Configuration Profiles
Development Profile
# src/main/resources/application-dev.properties
aws.kms.region=us-east-1
kms.encryption-key-id=alias/dev-encryption-key
logging.level.com.yourcompany=DEBUG
Production Profile
# src/main/resources/application-prod.properties
aws.kms.region=${AWS_REGION:us-east-1}
kms.encryption-key-id=${KMS_ENCRYPTION_KEY_ID:alias/production-encryption-key}
logging.level.com.yourcompany=WARN
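# These static credentials are read by Spring Cloud AWS, not by the aws.kms.* configuration shown earlier.
# In production, prefer an IAM role (the SDK default credentials provider chain) over static keys.
spring.cloud.aws.credentials.access-key=${AWS_ACCESS_KEY_ID}
spring.cloud.aws.credentials.secret-key=${AWS_SECRET_ACCESS_KEY}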
Test Configuration
/**
 * Test-profile overrides that swap the real KMS client for a Mockito mock.
 * Unstubbed mock calls return {@code null}, so tests must stub the KMS responses they rely on.
 */
@Configuration
@Profile("test")
public class KmsTestConfiguration {

    /** Mock KMS client; {@code @Primary} so it wins over any production {@code KmsClient} bean. */
    @Bean
    @Primary
    public KmsClient testKmsClient() {
        // Return a mock or test-specific KMS client
        return mock(KmsClient.class);
    }

    /**
     * Encryption service wired to the mock client. It is {@code @Primary} because a
     * component-scanned {@code @Service KmsEncryptionService} would otherwise be a second bean
     * of the same type. That would make by-type injection (e.g. into SecureDataRepository)
     * ambiguous.
     */
    @Bean
    @Primary
    public KmsEncryptionService testKmsEncryptionService() {
        // Inside a @Configuration class this call is proxied and returns the singleton mock bean.
        return new KmsEncryptionService(testKmsClient());
    }
}
Health Checks and Monitoring
KMS Health Indicator
@Component
public class KmsHealthIndicator implements HealthIndicator {
private final KmsClient kmsClient;
private final String keyId;
public KmsHealthIndicator(KmsClient kmsClient,
@Value("${kms.encryption-key-id}") String keyId) {
this.kmsClient = kmsClient;
this.keyId = keyId;
}
@Override
public Health health() {
try {
// Test KMS connectivity by describing the key
DescribeKeyRequest request = DescribeKeyRequest.builder()
.keyId(keyId)
.build();
DescribeKeyResponse response = kmsClient.describeKey(request);
// Check if key is in a healthy state
KeyState keyState = response.keyMetadata().keyState();
boolean isHealthy = keyState == KeyState.ENABLED;
if (isHealthy) {
return Health.up()
.withDetail("keyId", keyId)
.withDetail("keyState", keyState)
.withDetail("keyArn", response.keyMetadata().arn())
.build();
} else {
return Health.down()
.withDetail("keyId", keyId)
.withDetail("keyState", keyState)
.withDetail("message", "KMS key is not in ENABLED state")
.build();
}
} catch (KmsException e) {
return Health.down()
.withDetail("keyId", keyId)
.withDetail("error", e.awsErrorDetails().errorMessage())
.withDetail("errorCode", e.awsErrorDetails().errorCode())
.build();
}
}
}
Metrics Collection
/**
 * KMS encrypt/decrypt with Micrometer instrumentation. Each operation increments a count
 * meter and is recorded in a timer, and KMS failures increment a per-operation error counter.
 */
@Service
public class KmsMetricsCollector {
    private final MeterRegistry meterRegistry;
    private final KmsClient kmsClient;
    private final Counter encryptionCounter;
    private final Counter decryptionCounter;
    private final Timer encryptionTimer;
    private final Timer decryptionTimer;

    /** Key used for encryption; the same property KmsEncryptionService reads. */
    @Value("${kms.encryption-key-id}")
    private String keyId;

    public KmsMetricsCollector(MeterRegistry meterRegistry, KmsClient kmsClient) {
        this.meterRegistry = meterRegistry;
        this.kmsClient = kmsClient;
        this.encryptionCounter = Counter.builder("kms.encryption.count")
                .description("Number of encryption operations")
                .register(meterRegistry);
        this.decryptionCounter = Counter.builder("kms.decryption.count")
                .description("Number of decryption operations")
                .register(meterRegistry);
        this.encryptionTimer = Timer.builder("kms.encryption.time")
                .description("Time taken for encryption operations")
                .register(meterRegistry);
        this.decryptionTimer = Timer.builder("kms.decryption.time")
                .description("Time taken for decryption operations")
                .register(meterRegistry);
    }

    /**
     * Encrypts UTF-8 text under the configured key while recording metrics.
     *
     * @return Base64-encoded KMS ciphertext blob
     * @throws KmsException rethrown after incrementing {@code kms.encryption.errors}
     */
    public String encryptWithMetrics(String plaintext) {
        encryptionCounter.increment();
        return encryptionTimer.record(() -> {
            try {
                EncryptRequest request = EncryptRequest.builder()
                        .keyId(keyId)
                        .plaintext(SdkBytes.fromString(plaintext, StandardCharsets.UTF_8))
                        .build();
                EncryptResponse response = kmsClient.encrypt(request);
                return Base64.getEncoder().encodeToString(
                        response.ciphertextBlob().asByteArray());
            } catch (KmsException e) {
                meterRegistry.counter("kms.encryption.errors")
                        .increment();
                throw e;
            }
        });
    }

    /**
     * Decrypts Base64 ciphertext produced by {@link #encryptWithMetrics(String)} while recording
     * the decryption count and timer meters.
     *
     * @return the decrypted text, decoded as UTF-8
     * @throws KmsException rethrown after incrementing {@code kms.decryption.errors}
     */
    public String decryptWithMetrics(String ciphertextBase64) {
        decryptionCounter.increment();
        return decryptionTimer.record(() -> {
            try {
                DecryptRequest request = DecryptRequest.builder()
                        .ciphertextBlob(SdkBytes.fromByteArray(
                                Base64.getDecoder().decode(ciphertextBase64)))
                        .build();
                return kmsClient.decrypt(request).plaintext().asString(StandardCharsets.UTF_8);
            } catch (KmsException e) {
                meterRegistry.counter("kms.decryption.errors")
                        .increment();
                throw e;
            }
        });
    }
}