Initial commit

Zhongwei Li
2025-11-29 18:28:30 +08:00
commit 171acedaa4
220 changed files with 85967 additions and 0 deletions

---
name: chunking-strategy
description: Implement optimal chunking strategies in RAG systems and document processing pipelines. Use when building retrieval-augmented generation systems, vector databases, or processing large documents that require breaking into semantically meaningful segments for embeddings and search.
allowed-tools: Read, Write, Bash
category: artificial-intelligence
tags: [rag, chunking, vector-search, embeddings, document-processing]
version: 1.0.0
---
# Chunking Strategy for RAG Systems
## Overview
Implement optimal chunking strategies for Retrieval-Augmented Generation (RAG) systems and document processing pipelines. This skill provides a comprehensive framework for breaking large documents into smaller, semantically meaningful segments that preserve context while enabling efficient retrieval and search.
## When to Use
Use this skill when building RAG systems, optimizing vector search performance, implementing document processing pipelines, handling multi-modal content, or performance-tuning existing RAG systems with poor retrieval quality.
## Instructions
### Choose Chunking Strategy
Select appropriate chunking strategy based on document type and use case:
1. **Fixed-Size Chunking** (Level 1)
- Use for simple documents without clear structure
- Start with 512 tokens and 10-20% overlap
- Adjust size based on query type: 256 for factoid, 1024 for analytical
2. **Recursive Character Chunking** (Level 2)
- Use for documents with clear structural boundaries
- Implement hierarchical separators: paragraphs → sentences → words
- Customize separators for document types (HTML, Markdown)
3. **Structure-Aware Chunking** (Level 3)
- Use for structured documents (Markdown, code, tables, PDFs)
- Preserve semantic units: functions, sections, table blocks
- Validate structure preservation post-splitting
4. **Semantic Chunking** (Level 4)
- Use for complex documents with thematic shifts
- Implement embedding-based boundary detection
- Configure similarity threshold (0.8) and buffer size (3-5 sentences)
5. **Advanced Methods** (Level 5)
- Use Late Chunking for long-context embedding models
- Apply Contextual Retrieval for high-precision requirements
- Monitor computational costs vs. retrieval improvements
Reference detailed strategy implementations in [references/strategies.md](references/strategies.md).
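For structured Markdown (Level 3), a structure-aware split can be sketched with LangChain's `MarkdownHeaderTextSplitter`; the header mapping below is illustrative:
```python
from langchain.text_splitter import MarkdownHeaderTextSplitter

# Split on top-level and second-level headings, keeping the heading
# path as metadata so each chunk retains its section context
splitter = MarkdownHeaderTextSplitter(headers_to_split_on=[
    ("#", "h1"),
    ("##", "h2"),
])
sections = splitter.split_text(markdown_text)  # markdown_text: your document
```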
### Implement Chunking Pipeline
Follow these steps to implement effective chunking:
1. **Pre-process documents**
- Analyze document structure and content types
- Identify multi-modal content (tables, images, code)
- Assess information density and complexity
2. **Select strategy parameters**
- Choose chunk size based on embedding model context window
- Set overlap percentage (10-20% for most cases)
- Configure strategy-specific parameters
3. **Process and validate**
- Apply chosen chunking strategy
- Validate semantic coherence of chunks
- Test with representative documents
4. **Evaluate and iterate**
- Measure retrieval precision and recall
- Monitor processing latency and resource usage
- Optimize based on specific use case requirements
Reference detailed implementation guidelines in [references/implementation.md](references/implementation.md).
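As a rough sketch, these steps wire together into a single pipeline; `analyze_document`, `select_strategy`, and `validate_chunks` are hypothetical stand-ins for the project-specific pieces described above:
```python
def chunking_pipeline(document: str):
    analysis = analyze_document(document)         # 1. pre-process and analyze
    strategy, params = select_strategy(analysis)  # 2. pick strategy and parameters
    chunks = strategy.split(document, **params)   # 3. process
    report = validate_chunks(chunks, analysis)    #    ...and validate coherence
    return chunks, report                         # 4. evaluate offline, then iterate
```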
### Evaluate Performance
Use these metrics to evaluate chunking effectiveness:
- **Retrieval Precision**: Fraction of retrieved chunks that are relevant
- **Retrieval Recall**: Fraction of relevant chunks that are retrieved
- **End-to-End Accuracy**: Quality of final RAG responses
- **Processing Time**: Latency impact on overall system
- **Resource Usage**: Memory and computational costs
Reference detailed evaluation framework in [references/evaluation.md](references/evaluation.md).
## Examples
### Basic Fixed-Size Chunking
```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Configure for factoid queries
splitter = RecursiveCharacterTextSplitter(
    chunk_size=256,
    chunk_overlap=25,
    length_function=len
)
chunks = splitter.split_documents(documents)
```
### Structure-Aware Code Chunking
```python
import ast

def chunk_python_code(code):
    """Split Python code into semantic chunks (top-level functions and classes)."""
    tree = ast.parse(code)
    chunks = []
    # Iterate top-level nodes only; ast.walk would also visit nested
    # definitions and produce overlapping chunks
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
            chunks.append(ast.get_source_segment(code, node))
    return chunks
```
### Semantic Chunking with Embeddings
```python
def semantic_chunk(text, similarity_threshold=0.8):
    """Chunk text based on semantic boundaries.

    Assumes `split_into_sentences`, `generate_embeddings`, and
    `cosine_similarity` are provided by your NLP/embedding stack.
    """
    sentences = split_into_sentences(text)
    if not sentences:
        return []
    embeddings = generate_embeddings(sentences)
    chunks = []
    current_chunk = [sentences[0]]
    for i in range(1, len(sentences)):
        similarity = cosine_similarity(embeddings[i - 1], embeddings[i])
        # A similarity drop marks a topic boundary: close the current chunk
        if similarity < similarity_threshold:
            chunks.append(" ".join(current_chunk))
            current_chunk = [sentences[i]]
        else:
            current_chunk.append(sentences[i])
    chunks.append(" ".join(current_chunk))
    return chunks
```
## Best Practices
### Core Principles
- Balance context preservation with retrieval precision
- Maintain semantic coherence within chunks
- Optimize for embedding model constraints
- Preserve document structure when beneficial
### Implementation Guidelines
- Start simple with fixed-size chunking (512 tokens, 10-20% overlap)
- Test thoroughly with representative documents
- Monitor both accuracy metrics and computational costs
- Iterate based on specific document characteristics
### Common Pitfalls to Avoid
- Over-chunking: Creating too many small, context-poor chunks
- Under-chunking: Missing relevant information due to oversized chunks
- Ignoring document structure and semantic boundaries
- Using one-size-fits-all approach for diverse content types
- Neglecting overlap for boundary-crossing information
## Constraints
### Resource Considerations
- Semantic and contextual methods require significant computational resources
- Late chunking needs long-context embedding models
- Complex strategies increase processing latency
- Monitor memory usage for large document processing
### Quality Requirements
- Validate chunk semantic coherence post-processing
- Test with domain-specific documents before deployment
- Ensure chunks maintain standalone meaning where possible
- Implement proper error handling for edge cases
## References
Reference detailed documentation in the [references/](references/) folder:
- [strategies.md](references/strategies.md) - Detailed strategy implementations
- [implementation.md](references/implementation.md) - Complete implementation guidelines
- [evaluation.md](references/evaluation.md) - Performance evaluation framework
- [tools.md](references/tools.md) - Recommended libraries and frameworks
- [research.md](references/research.md) - Key research papers and findings
- [advanced-strategies.md](references/advanced-strategies.md) - 11 comprehensive chunking methods
- [semantic-methods.md](references/semantic-methods.md) - Semantic and contextual approaches
- [visualization-tools.md](references/visualization-tools.md) - Evaluation and visualization tools

(File diff suppressed because it is too large)

# Performance Evaluation Framework
This document provides comprehensive methodologies for evaluating chunking strategy performance and effectiveness.
## Evaluation Metrics
### Core Retrieval Metrics
#### Retrieval Precision
Measures the fraction of retrieved chunks that are relevant to the query.
```python
from typing import Dict, List

def calculate_precision(retrieved_chunks: List[Dict], relevant_chunks: List[Dict]) -> float:
    """
    Calculate retrieval precision
    Precision = |Relevant ∩ Retrieved| / |Retrieved|
    """
    retrieved_ids = {chunk.get('id') for chunk in retrieved_chunks}
    relevant_ids = {chunk.get('id') for chunk in relevant_chunks}
    intersection = retrieved_ids & relevant_ids
    if not retrieved_ids:
        return 0.0
    return len(intersection) / len(retrieved_ids)
```
#### Retrieval Recall
Measures the fraction of relevant chunks that are successfully retrieved.
```python
def calculate_recall(retrieved_chunks: List[Dict], relevant_chunks: List[Dict]) -> float:
    """
    Calculate retrieval recall
    Recall = |Relevant ∩ Retrieved| / |Relevant|
    """
    retrieved_ids = {chunk.get('id') for chunk in retrieved_chunks}
    relevant_ids = {chunk.get('id') for chunk in relevant_chunks}
    intersection = retrieved_ids & relevant_ids
    if not relevant_ids:
        return 0.0
    return len(intersection) / len(relevant_ids)
```
#### F1-Score
Harmonic mean of precision and recall.
```python
def calculate_f1_score(precision: float, recall: float) -> float:
    """
    Calculate F1-score
    F1 = 2 * (Precision * Recall) / (Precision + Recall)
    """
    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)
```
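A minimal usage sketch, assuming chunks are dicts with an `id` key as above:
```python
retrieved = [{'id': 'c1'}, {'id': 'c2'}, {'id': 'c3'}]
relevant = [{'id': 'c2'}, {'id': 'c4'}]

p = calculate_precision(retrieved, relevant)  # 1/3: one of three retrieved is relevant
r = calculate_recall(retrieved, relevant)     # 1/2: one of two relevant was retrieved
print(f"P={p:.2f} R={r:.2f} F1={calculate_f1_score(p, r):.2f}")
```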
### Mean Reciprocal Rank (MRR)
Measures the rank of the first relevant result.
```python
def calculate_mrr(queries: List[Dict], results: List[List[Dict]]) -> float:
    """
    Calculate Mean Reciprocal Rank
    """
    reciprocal_ranks = []
    for query, query_results in zip(queries, results):
        relevant_found = False
        for rank, result in enumerate(query_results, 1):
            if result.get('is_relevant', False):
                reciprocal_ranks.append(1.0 / rank)
                relevant_found = True
                break
        if not relevant_found:
            reciprocal_ranks.append(0.0)
    return sum(reciprocal_ranks) / len(reciprocal_ranks) if reciprocal_ranks else 0.0
```
### Mean Average Precision (MAP)
Considers both precision and the ranking of relevant documents.
```python
def calculate_average_precision(retrieved_chunks: List[Dict], relevant_chunks: List[Dict]) -> float:
    """
    Calculate Average Precision for a single query
    """
    relevant_ids = {chunk.get('id') for chunk in relevant_chunks}
    if not relevant_ids:
        return 0.0
    precisions = []
    relevant_count = 0
    for rank, chunk in enumerate(retrieved_chunks, 1):
        if chunk.get('id') in relevant_ids:
            relevant_count += 1
            precision_at_rank = relevant_count / rank
            precisions.append(precision_at_rank)
    return sum(precisions) / len(relevant_ids)

def calculate_map(queries: List[Dict], results: List[List[Dict]]) -> float:
    """
    Calculate Mean Average Precision across multiple queries
    """
    average_precisions = []
    for query, query_results in zip(queries, results):
        ap = calculate_average_precision(query_results, query.get('relevant_chunks', []))
        average_precisions.append(ap)
    return sum(average_precisions) / len(average_precisions) if average_precisions else 0.0
```
### Normalized Discounted Cumulative Gain (NDCG)
Measures ranking quality with emphasis on highly relevant results.
```python
import numpy as np

def calculate_dcg(retrieved_chunks: List[Dict]) -> float:
    """
    Calculate Discounted Cumulative Gain
    """
    dcg = 0.0
    for rank, chunk in enumerate(retrieved_chunks, 1):
        relevance = chunk.get('relevance_score', 0)
        dcg += relevance / np.log2(rank + 1)
    return dcg

def calculate_ndcg(retrieved_chunks: List[Dict], ideal_chunks: List[Dict]) -> float:
    """
    Calculate Normalized Discounted Cumulative Gain
    """
    dcg = calculate_dcg(retrieved_chunks)
    idcg = calculate_dcg(ideal_chunks)
    if idcg == 0:
        return 0.0
    return dcg / idcg
```
## End-to-End RAG Evaluation
### Answer Quality Metrics
#### Factual Consistency
Measures how well the generated answer aligns with retrieved chunks.
```python
import spacy
from transformers import pipeline

class FactualConsistencyEvaluator:
    def __init__(self):
        self.nlp = spacy.load("en_core_web_sm")
        # top_k=None returns scores for all labels rather than just the argmax
        self.nli_pipeline = pipeline("text-classification",
                                     model="roberta-large-mnli",
                                     top_k=None)

    def evaluate_consistency(self, answer: str, retrieved_chunks: List[str]) -> float:
        """
        Evaluate factual consistency between answer and retrieved context
        """
        if not retrieved_chunks:
            return 0.0
        # Combine retrieved chunks as context
        context = " ".join(retrieved_chunks[:3])  # Use top 3 chunks
        # Use Natural Language Inference to check consistency
        result = self.nli_pipeline(f"premise: {context} hypothesis: {answer}")
        # Extract consistency score (entailment probability); labels are
        # ordered by score, so the top label decides the outcome
        for item in result:
            if item['label'] == 'ENTAILMENT':
                return item['score']
            elif item['label'] == 'CONTRADICTION':
                return 1.0 - item['score']
        return 0.5  # Neutral if NLI is inconclusive
```
#### Answer Completeness
Measures how completely the answer addresses the user's query.
```python
def evaluate_completeness(answer: str, query: str, reference_answer: str = None) -> float:
    """
    Evaluate answer completeness
    """
    # Extract key entities from query
    query_entities = extract_entities(query)
    answer_entities = extract_entities(answer)
    # Calculate entity coverage
    if not query_entities:
        return 0.5  # Neutral if no entities in query
    covered_entities = query_entities & answer_entities
    entity_coverage = len(covered_entities) / len(query_entities)
    # If reference answer is available, compare against it
    if reference_answer:
        reference_entities = extract_entities(reference_answer)
        answer_reference_overlap = len(answer_entities & reference_entities) / max(len(reference_entities), 1)
        return (entity_coverage + answer_reference_overlap) / 2
    return entity_coverage

def extract_entities(text: str) -> set:
    """
    Extract named entities from text (simplified)
    """
    # This would use a proper NER model in practice
    import re
    # Simple noun phrase extraction as placeholder
    noun_phrases = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', text)
    return set(noun_phrases)
```
#### Response Relevance
Measures how relevant the answer is to the original query.
```python
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

class RelevanceEvaluator:
    def __init__(self, model_name="all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)

    def evaluate_relevance(self, query: str, answer: str) -> float:
        """
        Evaluate semantic relevance between query and answer
        """
        # Generate embeddings
        query_embedding = self.model.encode([query])
        answer_embedding = self.model.encode([answer])
        # Calculate cosine similarity
        similarity = cosine_similarity(query_embedding, answer_embedding)[0][0]
        return float(similarity)
```
## Performance Metrics
### Processing Time
```python
import time
from dataclasses import dataclass
from typing import List, Dict

@dataclass
class PerformanceMetrics:
    total_time: float
    chunking_time: float
    embedding_time: float
    search_time: float
    generation_time: float
    throughput: float  # documents per second

class PerformanceProfiler:
    def __init__(self):
        self.timings = {}
        self.start_times = {}

    def start_timer(self, operation: str):
        self.start_times[operation] = time.time()

    def end_timer(self, operation: str):
        if operation in self.start_times:
            duration = time.time() - self.start_times[operation]
            if operation not in self.timings:
                self.timings[operation] = []
            self.timings[operation].append(duration)
            return duration
        return 0.0

    def get_performance_metrics(self, document_count: int) -> PerformanceMetrics:
        total_time = sum(sum(times) for times in self.timings.values())
        return PerformanceMetrics(
            total_time=total_time,
            chunking_time=sum(self.timings.get('chunking', [0])),
            embedding_time=sum(self.timings.get('embedding', [0])),
            search_time=sum(self.timings.get('search', [0])),
            generation_time=sum(self.timings.get('generation', [0])),
            throughput=document_count / total_time if total_time > 0 else 0
        )
```
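Typical usage looks like this (the `chunk_documents` call is a hypothetical stand-in for your own chunking step):
```python
profiler = PerformanceProfiler()

profiler.start_timer('chunking')
chunks = chunk_documents(docs)  # hypothetical: your chunking call
profiler.end_timer('chunking')

metrics = profiler.get_performance_metrics(document_count=len(docs))
print(f"Throughput: {metrics.throughput:.1f} docs/sec")
```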
### Memory Usage
```python
import os
import time
from typing import Dict, List

import psutil

class MemoryProfiler:
    def __init__(self):
        self.process = psutil.Process(os.getpid())
        self.memory_snapshots = []

    def take_memory_snapshot(self, label: str):
        """Take a snapshot of current memory usage"""
        memory_info = self.process.memory_info()
        memory_mb = memory_info.rss / 1024 / 1024  # Convert to MB
        self.memory_snapshots.append({
            'label': label,
            'memory_mb': memory_mb,
            'timestamp': time.time()
        })

    def get_peak_memory_usage(self) -> float:
        """Get peak memory usage in MB"""
        if not self.memory_snapshots:
            return 0.0
        return max(snapshot['memory_mb'] for snapshot in self.memory_snapshots)

    def get_memory_usage_by_operation(self) -> Dict[str, float]:
        """Get memory usage breakdown by operation"""
        if not self.memory_snapshots:
            return {}
        memory_by_op = {}
        for i in range(1, len(self.memory_snapshots)):
            prev_snapshot = self.memory_snapshots[i - 1]
            curr_snapshot = self.memory_snapshots[i]
            operation = curr_snapshot['label']
            memory_delta = curr_snapshot['memory_mb'] - prev_snapshot['memory_mb']
            if operation not in memory_by_op:
                memory_by_op[operation] = []
            memory_by_op[operation].append(memory_delta)
        return {op: sum(deltas) for op, deltas in memory_by_op.items()}
```
## Evaluation Datasets
### Standardized Test Sets
#### Question-Answer Pairs
```python
from dataclasses import dataclass
from typing import Dict, List, Optional
import json

@dataclass
class EvaluationQuery:
    id: str
    question: str
    reference_answer: Optional[str]
    relevant_chunk_ids: List[str]
    query_type: str  # factoid, analytical, comparative
    difficulty: str  # easy, medium, hard
    domain: str  # finance, medical, legal, technical

class EvaluationDataset:
    def __init__(self, name: str):
        self.name = name
        self.queries: List[EvaluationQuery] = []
        self.documents: Dict[str, str] = {}
        self.chunks: Dict[str, Dict] = {}

    def add_query(self, query: EvaluationQuery):
        self.queries.append(query)

    def add_document(self, doc_id: str, content: str):
        self.documents[doc_id] = content

    def add_chunk(self, chunk_id: str, content: str, doc_id: str, metadata: Dict):
        self.chunks[chunk_id] = {
            'id': chunk_id,
            'content': content,
            'doc_id': doc_id,
            'metadata': metadata
        }

    def save_to_file(self, filepath: str):
        data = {
            'name': self.name,
            'queries': [
                {
                    'id': q.id,
                    'question': q.question,
                    'reference_answer': q.reference_answer,
                    'relevant_chunk_ids': q.relevant_chunk_ids,
                    'query_type': q.query_type,
                    'difficulty': q.difficulty,
                    'domain': q.domain
                }
                for q in self.queries
            ],
            'documents': self.documents,
            'chunks': self.chunks
        }
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)

    @classmethod
    def load_from_file(cls, filepath: str):
        with open(filepath, 'r') as f:
            data = json.load(f)
        dataset = cls(data['name'])
        dataset.documents = data['documents']
        dataset.chunks = data['chunks']
        for q_data in data['queries']:
            query = EvaluationQuery(
                id=q_data['id'],
                question=q_data['question'],
                reference_answer=q_data.get('reference_answer'),
                relevant_chunk_ids=q_data['relevant_chunk_ids'],
                query_type=q_data['query_type'],
                difficulty=q_data['difficulty'],
                domain=q_data['domain']
            )
            dataset.add_query(query)
        return dataset
```
### Dataset Generation
#### Synthetic Query Generation
```python
import random
import re
from typing import Dict, List

class SyntheticQueryGenerator:
    def __init__(self):
        self.query_templates = {
            'factoid': [
                "What is {concept}?",
                "When did {event} occur?",
                "Who developed {technology}?",
                "How many {items} are mentioned?",
                "What is the value of {metric}?"
            ],
            'analytical': [
                "Compare and contrast {concept1} and {concept2}.",
                "Analyze the impact of {concept} on {domain}.",
                "What are the advantages and disadvantages of {technology}?",
                "Explain the relationship between {concept1} and {concept2}.",
                "Evaluate the effectiveness of {approach} for {problem}."
            ],
            'comparative': [
                "Which is better: {option1} or {option2}?",
                "How does {method1} differ from {method2}?",
                "Compare the performance of {system1} and {system2}.",
                "What are the key differences between {approach1} and {approach2}?"
            ]
        }

    def generate_queries_from_chunks(self, chunks: List[Dict], num_queries: int = 100) -> List[EvaluationQuery]:
        """Generate synthetic queries from document chunks"""
        queries = []
        # Extract entities and concepts from chunks
        entities = self._extract_entities_from_chunks(chunks)
        for i in range(num_queries):
            query_type = random.choice(['factoid', 'analytical', 'comparative'])
            template = random.choice(self.query_templates[query_type])
            # Fill template with extracted entities
            query_text = self._fill_template(template, entities)
            # Find relevant chunks for this query
            relevant_chunks = self._find_relevant_chunks(query_text, chunks)
            query = EvaluationQuery(
                id=f"synthetic_{i}",
                question=query_text,
                reference_answer=None,  # Would need generation model
                relevant_chunk_ids=[chunk['id'] for chunk in relevant_chunks],
                query_type=query_type,
                difficulty=random.choice(['easy', 'medium', 'hard']),
                domain='synthetic'
            )
            queries.append(query)
        return queries

    def _extract_entities_from_chunks(self, chunks: List[Dict]) -> Dict[str, List[str]]:
        """Extract entities, concepts, and relationships from chunks"""
        # This would use proper NER in practice
        entities = {
            'concepts': [],
            'technologies': [],
            'methods': [],
            'metrics': [],
            'events': []
        }
        for chunk in chunks:
            content = chunk['content']
            # Simplified entity extraction
            words = content.split()
            entities['concepts'].extend([word for word in words if len(word) > 6])
            entities['technologies'].extend([word for word in words if 'technology' in word.lower()])
            entities['methods'].extend([word for word in words if 'method' in word.lower()])
            entities['metrics'].extend([word for word in words if '%' in word or '$' in word])
        # Remove duplicates and limit
        for key in entities:
            entities[key] = list(set(entities[key]))[:50]
        return entities

    def _fill_template(self, template: str, entities: Dict[str, List[str]]) -> str:
        """Fill query template with random entities"""
        def replace_placeholder(match):
            placeholder = match.group(1)
            # Map placeholders to entity types
            entity_mapping = {
                'concept': 'concepts',
                'concept1': 'concepts',
                'concept2': 'concepts',
                'technology': 'technologies',
                'method': 'methods',
                'method1': 'methods',
                'method2': 'methods',
                'metric': 'metrics',
                'event': 'events',
                'items': 'concepts',
                'option1': 'concepts',
                'option2': 'concepts',
                'approach': 'methods',
                'problem': 'concepts',
                'domain': 'concepts',
                'system1': 'concepts',
                'system2': 'concepts'
            }
            entity_type = entity_mapping.get(placeholder, 'concepts')
            available_entities = entities.get(entity_type, ['something'])
            if available_entities:
                return random.choice(available_entities)
            else:
                return 'something'
        return re.sub(r'\{(\w+)\}', replace_placeholder, template)

    def _find_relevant_chunks(self, query: str, chunks: List[Dict], k: int = 3) -> List[Dict]:
        """Find chunks most relevant to the query"""
        # Simple keyword matching for synthetic generation
        query_words = set(query.lower().split())
        chunk_scores = []
        for chunk in chunks:
            chunk_words = set(chunk['content'].lower().split())
            overlap = len(query_words & chunk_words)
            chunk_scores.append((overlap, chunk))
        # Sort by overlap and return top k
        chunk_scores.sort(key=lambda x: x[0], reverse=True)
        return [chunk for _, chunk in chunk_scores[:k]]
```
## A/B Testing Framework
### Statistical Significance Testing
```python
import numpy as np
from scipy import stats
from typing import List, Dict, Tuple

class ABTestAnalyzer:
    def __init__(self):
        self.significance_level = 0.05

    def compare_metrics(self, control_metrics: List[float],
                        treatment_metrics: List[float],
                        metric_name: str) -> Dict:
        """
        Compare metrics between control and treatment groups
        """
        control_mean = np.mean(control_metrics)
        treatment_mean = np.mean(treatment_metrics)
        control_std = np.std(control_metrics)
        treatment_std = np.std(treatment_metrics)
        # Perform t-test
        t_statistic, p_value = stats.ttest_ind(control_metrics, treatment_metrics)
        # Calculate effect size (Cohen's d)
        pooled_std = np.sqrt(((len(control_metrics) - 1) * control_std**2 +
                              (len(treatment_metrics) - 1) * treatment_std**2) /
                             (len(control_metrics) + len(treatment_metrics) - 2))
        cohens_d = (treatment_mean - control_mean) / pooled_std if pooled_std > 0 else 0
        # Determine significance
        is_significant = p_value < self.significance_level
        return {
            'metric_name': metric_name,
            'control_mean': control_mean,
            'treatment_mean': treatment_mean,
            'absolute_difference': treatment_mean - control_mean,
            'relative_difference': ((treatment_mean - control_mean) / control_mean * 100) if control_mean != 0 else 0,
            'control_std': control_std,
            'treatment_std': treatment_std,
            't_statistic': t_statistic,
            'p_value': p_value,
            'is_significant': is_significant,
            'effect_size': cohens_d,
            'significance_level': self.significance_level
        }

    def analyze_ab_test_results(self,
                                control_results: Dict[str, List[float]],
                                treatment_results: Dict[str, List[float]]) -> Dict:
        """
        Analyze A/B test results across multiple metrics
        """
        analysis_results = {}
        # Ensure both dictionaries have the same keys
        all_metrics = set(control_results.keys()) & set(treatment_results.keys())
        for metric in all_metrics:
            analysis_results[metric] = self.compare_metrics(
                control_results[metric],
                treatment_results[metric],
                metric
            )
        # Calculate overall summary
        significant_improvements = sum(1 for result in analysis_results.values()
                                       if result['is_significant'] and result['relative_difference'] > 0)
        significant_degradations = sum(1 for result in analysis_results.values()
                                       if result['is_significant'] and result['relative_difference'] < 0)
        analysis_results['summary'] = {
            'total_metrics_compared': len(analysis_results),
            'significant_improvements': significant_improvements,
            'significant_degradations': significant_degradations,
            'no_significant_change': len(analysis_results) - significant_improvements - significant_degradations
        }
        return analysis_results
```
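For example, comparing per-query precision between two strategies (the numbers are illustrative):
```python
analyzer = ABTestAnalyzer()
result = analyzer.compare_metrics(
    control_metrics=[0.61, 0.58, 0.64, 0.60],    # e.g. fixed-size chunking
    treatment_metrics=[0.70, 0.66, 0.72, 0.69],  # e.g. semantic chunking
    metric_name="retrieval_precision"
)
print(result['p_value'], result['is_significant'], result['effect_size'])
```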
## Automated Evaluation Pipeline
### End-to-End Evaluation
```python
import random
from typing import Any, Dict, List

import numpy as np

class ChunkingEvaluationPipeline:
    def __init__(self, strategies: Dict[str, Any], dataset: EvaluationDataset):
        self.strategies = strategies
        self.dataset = dataset
        self.results = {}
        self.profiler = PerformanceProfiler()
        self.memory_profiler = MemoryProfiler()

    def run_evaluation(self) -> Dict:
        """Run comprehensive evaluation of all strategies"""
        evaluation_results = {}
        for strategy_name, strategy in self.strategies.items():
            print(f"Evaluating strategy: {strategy_name}")
            # Reset profilers for each strategy
            self.profiler = PerformanceProfiler()
            self.memory_profiler = MemoryProfiler()
            # Evaluate strategy
            strategy_results = self._evaluate_strategy(strategy, strategy_name)
            evaluation_results[strategy_name] = strategy_results
        # Compare strategies
        comparison_results = self._compare_strategies(evaluation_results)
        return {
            'individual_results': evaluation_results,
            'comparison': comparison_results,
            'recommendations': self._generate_recommendations(comparison_results)
        }

    def _evaluate_strategy(self, strategy: Any, strategy_name: str) -> Dict:
        """Evaluate a single chunking strategy"""
        results = {
            'strategy_name': strategy_name,
            'retrieval_metrics': {},
            'quality_metrics': {},
            'performance_metrics': {}
        }
        # Track memory usage
        self.memory_profiler.take_memory_snapshot(f"{strategy_name}_start")
        # Process all documents
        self.profiler.start_timer('total_processing')
        all_chunks = {}
        for doc_id, content in self.dataset.documents.items():
            self.profiler.start_timer('chunking')
            chunks = strategy.chunk(content)
            self.profiler.end_timer('chunking')
            all_chunks[doc_id] = chunks
        self.memory_profiler.take_memory_snapshot(f"{strategy_name}_after_chunking")
        # Generate embeddings for chunks
        self.profiler.start_timer('embedding')
        chunk_embeddings = self._generate_embeddings(all_chunks)
        self.profiler.end_timer('embedding')
        self.memory_profiler.take_memory_snapshot(f"{strategy_name}_after_embedding")
        # Evaluate retrieval performance
        retrieval_results = self._evaluate_retrieval(all_chunks, chunk_embeddings)
        results['retrieval_metrics'] = retrieval_results
        # Evaluate chunk quality
        quality_results = self._evaluate_chunk_quality(all_chunks)
        results['quality_metrics'] = quality_results
        # Get performance metrics
        self.profiler.end_timer('total_processing')
        performance_metrics = self.profiler.get_performance_metrics(len(self.dataset.documents))
        results['performance_metrics'] = performance_metrics.__dict__
        # Get memory metrics
        self.memory_profiler.take_memory_snapshot(f"{strategy_name}_end")
        results['memory_metrics'] = {
            'peak_memory_mb': self.memory_profiler.get_peak_memory_usage(),
            'memory_by_operation': self.memory_profiler.get_memory_usage_by_operation()
        }
        return results

    def _evaluate_retrieval(self, all_chunks: Dict, chunk_embeddings: Dict) -> Dict:
        """Evaluate retrieval performance"""
        retrieval_metrics = {
            'precision': [],
            'recall': [],
            'f1_score': []
        }
        for query in self.dataset.queries:
            # Perform retrieval
            self.profiler.start_timer('search')
            retrieved_chunks = self._retrieve_chunks(query.question, chunk_embeddings, k=10)
            self.profiler.end_timer('search')
            # Build the gold-standard relevant set from the query labels
            # (deriving it from the retrieved list would make recall trivially 1.0)
            relevant_chunks = [{'id': chunk_id} for chunk_id in query.relevant_chunk_ids]
            # Calculate metrics
            precision = calculate_precision(retrieved_chunks, relevant_chunks)
            recall = calculate_recall(retrieved_chunks, relevant_chunks)
            f1 = calculate_f1_score(precision, recall)
            retrieval_metrics['precision'].append(precision)
            retrieval_metrics['recall'].append(recall)
            retrieval_metrics['f1_score'].append(f1)
        # Calculate averages
        return {metric: float(np.mean(values)) for metric, values in retrieval_metrics.items() if values}

    def _evaluate_chunk_quality(self, all_chunks: Dict) -> Dict:
        """Evaluate quality of generated chunks"""
        # ChunkQualityAssessor and DocumentAnalyzer are defined in the
        # implementation guidelines (references/implementation.md)
        quality_assessor = ChunkQualityAssessor()
        quality_scores = []
        for doc_id, chunks in all_chunks.items():
            # Analyze document
            content = self.dataset.documents[doc_id]
            analyzer = DocumentAnalyzer()
            analysis = analyzer.analyze(content)
            # Assess chunk quality
            scores = quality_assessor.assess_chunks(chunks, analysis)
            quality_scores.append(scores)
        # Aggregate quality scores
        if quality_scores:
            avg_scores = {}
            for metric in quality_scores[0].keys():
                avg_scores[metric] = np.mean([scores[metric] for scores in quality_scores])
            return avg_scores
        return {}

    def _compare_strategies(self, evaluation_results: Dict) -> Dict:
        """Compare performance across strategies"""
        ab_analyzer = ABTestAnalyzer()
        comparison = {}
        # Compare each metric across strategies
        strategy_names = list(evaluation_results.keys())
        for i in range(len(strategy_names)):
            for j in range(i + 1, len(strategy_names)):
                strategy1 = strategy_names[i]
                strategy2 = strategy_names[j]
                comparison_key = f"{strategy1}_vs_{strategy2}"
                comparison[comparison_key] = {}
                # Compare retrieval metrics
                for metric in ['precision', 'recall', 'f1_score']:
                    if (metric in evaluation_results[strategy1]['retrieval_metrics'] and
                            metric in evaluation_results[strategy2]['retrieval_metrics']):
                        comparison[comparison_key][f"retrieval_{metric}"] = ab_analyzer.compare_metrics(
                            [evaluation_results[strategy1]['retrieval_metrics'][metric]],
                            [evaluation_results[strategy2]['retrieval_metrics'][metric]],
                            f"retrieval_{metric}"
                        )
        return comparison

    def _generate_recommendations(self, comparison_results: Dict) -> Dict:
        """Generate recommendations based on evaluation results"""
        recommendations = {
            'best_overall': None,
            'best_for_precision': None,
            'best_for_recall': None,
            'best_for_performance': None,
            'trade_offs': []
        }
        # This would analyze the comparison results and generate specific recommendations
        # Implementation depends on specific use case requirements
        return recommendations

    def _generate_embeddings(self, all_chunks: Dict) -> Dict:
        """Generate embeddings for all chunks"""
        # This would use the actual embedding model
        # Placeholder implementation
        embeddings = {}
        for doc_id, chunks in all_chunks.items():
            embeddings[doc_id] = []
            for chunk in chunks:
                # Generate embedding for chunk content
                embedding = np.random.rand(384)  # Placeholder
                embeddings[doc_id].append({
                    'chunk': chunk,
                    'embedding': embedding
                })
        return embeddings

    def _retrieve_chunks(self, query: str, chunk_embeddings: Dict, k: int = 10) -> List[Dict]:
        """Retrieve most relevant chunks for a query"""
        # This would use actual similarity search
        # Placeholder implementation
        all_chunks = []
        for doc_embeddings in chunk_embeddings.values():
            for chunk_data in doc_embeddings:
                all_chunks.append(chunk_data['chunk'])
        # Simple random selection as placeholder
        selected = random.sample(all_chunks, min(k, len(all_chunks)))
        return selected
```
This comprehensive evaluation framework provides the tools needed to thoroughly assess chunking strategies across multiple dimensions: retrieval effectiveness, answer quality, system performance, and statistical significance. The modular design allows for easy extension and customization based on specific requirements and use cases.

# Complete Implementation Guidelines
This document provides comprehensive implementation guidance for building effective chunking systems.
## System Architecture
### Core Components
```
Document Processor
├── Ingestion Layer
│   ├── Document Type Detection
│   ├── Format Parsing (PDF, HTML, Markdown, etc.)
│   └── Content Extraction
├── Analysis Layer
│   ├── Structure Analysis
│   ├── Content Type Identification
│   └── Complexity Assessment
├── Strategy Selection Layer
│   ├── Rule-based Selection
│   ├── ML-based Prediction
│   └── Adaptive Configuration
├── Chunking Layer
│   ├── Strategy Implementation
│   ├── Parameter Optimization
│   └── Quality Validation
└── Output Layer
    ├── Chunk Metadata Generation
    ├── Embedding Integration
    └── Storage Preparation
```
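The flow through these layers can be sketched as a thin orchestrator; `DocumentAnalyzer`, `StrategySelector`, and `ChunkQualityAssessor` are defined in the sections below, and the wiring itself is illustrative:
```python
class DocumentProcessor:
    """Minimal orchestration of the layers above."""

    def __init__(self):
        self.analyzer = DocumentAnalyzer()    # Analysis Layer
        self.selector = StrategySelector()    # Strategy Selection Layer
        self.assessor = ChunkQualityAssessor()

    def process(self, content: str):
        analysis = self.analyzer.analyze(content)
        strategy = self.selector.get_strategy(analysis)
        chunks = strategy.chunk(content, analysis)               # Chunking Layer
        quality = self.assessor.assess_chunks(chunks, analysis)  # Quality Validation
        return chunks, quality
```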
## Pre-processing Pipeline
### Document Analysis Framework
```python
from dataclasses import dataclass
from typing import List, Dict, Any
import re

@dataclass
class DocumentAnalysis:
    doc_type: str
    structure_score: float  # 0-1, higher means more structured
    complexity_score: float  # 0-1, higher means more complex
    content_types: List[str]
    language: str
    estimated_tokens: int
    has_multimodal: bool

class DocumentAnalyzer:
    def __init__(self):
        self.structure_patterns = {
            'markdown': [r'^#+\s', r'^\*\*.*\*\*$', r'^\* ', r'^\d+\. '],
            'html': [r'<h[1-6]>', r'<p>', r'<div>', r'<table>'],
            'latex': [r'\\section', r'\\subsection', r'\\begin\{', r'\\end\{'],
            'academic': [r'^\d+\.', r'^\d+\.\d+', r'^[A-Z]\.', r'^Figure \d+']
        }

    def analyze(self, content: str) -> DocumentAnalysis:
        doc_type = self.detect_document_type(content)
        structure_score = self.calculate_structure_score(content, doc_type)
        complexity_score = self.calculate_complexity_score(content)
        content_types = self.identify_content_types(content)
        language = self.detect_language(content)
        estimated_tokens = self.estimate_tokens(content)
        has_multimodal = self.detect_multimodal_content(content)
        return DocumentAnalysis(
            doc_type=doc_type,
            structure_score=structure_score,
            complexity_score=complexity_score,
            content_types=content_types,
            language=language,
            estimated_tokens=estimated_tokens,
            has_multimodal=has_multimodal
        )

    def detect_document_type(self, content: str) -> str:
        content_lower = content.lower()
        if '<html' in content_lower or '<body' in content_lower:
            return 'html'
        elif '#' in content and '##' in content:
            return 'markdown'
        elif '\\documentclass' in content_lower or '\\begin{' in content_lower:
            return 'latex'
        elif any(keyword in content_lower for keyword in ['abstract', 'introduction', 'conclusion', 'references']):
            return 'academic'
        elif 'def ' in content or 'class ' in content or 'function ' in content_lower:
            return 'code'
        else:
            return 'plain'

    def calculate_structure_score(self, content: str, doc_type: str) -> float:
        patterns = self.structure_patterns.get(doc_type, [])
        if not patterns:
            return 0.5  # Default for unstructured content
        line_count = len(content.split('\n'))
        structured_lines = 0
        for line in content.split('\n'):
            for pattern in patterns:
                if re.search(pattern, line.strip()):
                    structured_lines += 1
                    break
        return min(structured_lines / max(line_count, 1), 1.0)

    def calculate_complexity_score(self, content: str) -> float:
        # Factors that increase complexity
        avg_sentence_length = self.calculate_avg_sentence_length(content)
        vocabulary_richness = self.calculate_vocabulary_richness(content)
        nested_structure = self.detect_nested_structure(content)
        # Normalize and combine
        complexity = (
            min(avg_sentence_length / 30, 1.0) * 0.3 +
            vocabulary_richness * 0.4 +
            nested_structure * 0.3
        )
        return min(complexity, 1.0)

    def identify_content_types(self, content: str) -> List[str]:
        types = []
        if '```' in content or 'def ' in content or 'function ' in content.lower():
            types.append('code')
        if '|' in content and '\n' in content:
            types.append('tables')
        if re.search(r'\!\[.*\]\(.*\)', content):
            types.append('images')
        if re.search(r'http[s]?://', content):
            types.append('links')
        if re.search(r'\d+\.\d+', content) or re.search(r'\$\d', content):
            types.append('numbers')
        return types if types else ['text']

    def detect_language(self, content: str) -> str:
        # Simple language detection - can be enhanced with proper language detection libraries
        if re.search(r'[\u4e00-\u9fff]', content):
            return 'chinese'
        elif re.search(r'[\u0600-\u06ff]', content):
            return 'arabic'
        elif re.search(r'[\u0400-\u04ff]', content):
            return 'russian'
        else:
            return 'english'  # Default assumption

    def estimate_tokens(self, content: str) -> int:
        # Rough estimation - actual tokenization varies by model
        word_count = len(content.split())
        return int(word_count * 1.3)  # Average tokens per word

    def detect_multimodal_content(self, content: str) -> bool:
        multimodal_indicators = [
            r'\!\[.*\]\(.*\)',  # Images
            r'<iframe',  # Embedded content
            r'<object',  # Embedded objects
            r'<embed',  # Embedded media
        ]
        return any(re.search(pattern, content) for pattern in multimodal_indicators)

    def calculate_avg_sentence_length(self, content: str) -> float:
        sentences = re.split(r'[.!?]+', content)
        sentences = [s.strip() for s in sentences if s.strip()]
        if not sentences:
            return 0
        return sum(len(s.split()) for s in sentences) / len(sentences)

    def calculate_vocabulary_richness(self, content: str) -> float:
        words = content.lower().split()
        if not words:
            return 0
        unique_words = set(words)
        return len(unique_words) / len(words)

    def detect_nested_structure(self, content: str) -> float:
        # Detect nested lists, indented content, etc.
        lines = content.split('\n')
        indented_lines = 0
        for line in lines:
            if line.strip() and line.startswith(' '):
                indented_lines += 1
        return indented_lines / max(len(lines), 1)
```
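Usage is a single call per document, for example:
```python
analyzer = DocumentAnalyzer()
analysis = analyzer.analyze(document_text)  # document_text: any loaded document
print(analysis.doc_type, analysis.structure_score, analysis.estimated_tokens)
```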
### Strategy Selection Engine
```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List

class ChunkingStrategy(ABC):
    @abstractmethod
    def chunk(self, content: str, analysis: DocumentAnalysis) -> List[Dict[str, Any]]:
        pass

class StrategySelector:
    def __init__(self):
        self.strategies = {
            'fixed_size': FixedSizeStrategy(),
            'recursive': RecursiveStrategy(),
            'structure_aware': StructureAwareStrategy(),
            'semantic': SemanticStrategy(),
            'adaptive': AdaptiveStrategy()
        }

    def select_strategy(self, analysis: DocumentAnalysis) -> str:
        # Rule-based selection logic
        if analysis.structure_score > 0.8 and analysis.doc_type in ['markdown', 'html', 'latex']:
            return 'structure_aware'
        elif analysis.complexity_score > 0.7 and analysis.estimated_tokens < 10000:
            return 'semantic'
        elif analysis.doc_type == 'code':
            return 'structure_aware'
        elif analysis.structure_score < 0.3:
            return 'fixed_size'
        elif analysis.complexity_score > 0.5:
            return 'recursive'
        else:
            return 'adaptive'

    def get_strategy(self, analysis: DocumentAnalysis) -> ChunkingStrategy:
        strategy_name = self.select_strategy(analysis)
        return self.strategies[strategy_name]

# Example strategy implementations
class FixedSizeStrategy(ChunkingStrategy):
    def __init__(self, default_size=512, default_overlap=50):
        self.default_size = default_size
        self.default_overlap = default_overlap

    def chunk(self, content: str, analysis: DocumentAnalysis) -> List[Dict[str, Any]]:
        # Adjust parameters based on analysis
        if analysis.complexity_score > 0.7:
            chunk_size = 1024
        elif analysis.complexity_score < 0.3:
            chunk_size = 256
        else:
            chunk_size = self.default_size
        overlap = int(chunk_size * 0.1)  # 10% overlap
        # Implementation here...
        return self._fixed_size_chunk(content, chunk_size, overlap)

    def _fixed_size_chunk(self, content: str, chunk_size: int, overlap: int) -> List[Dict[str, Any]]:
        # Implementation using RecursiveCharacterTextSplitter or custom logic
        pass

class AdaptiveStrategy(ChunkingStrategy):
    def chunk(self, content: str, analysis: DocumentAnalysis) -> List[Dict[str, Any]]:
        # Combine multiple strategies based on content characteristics;
        # initialize both lists so the merge step below is always defined
        structured_chunks: List[Dict[str, Any]] = []
        unstructured_chunks: List[Dict[str, Any]] = []
        if analysis.structure_score > 0.6:
            # Use structure-aware for structured parts
            structured_chunks = self._chunk_structured_parts(content, analysis)
        else:
            # Use fixed-size for unstructured parts
            unstructured_chunks = self._chunk_unstructured_parts(content, analysis)
        # Merge and optimize
        return self._merge_chunks(structured_chunks + unstructured_chunks)

    def _chunk_structured_parts(self, content: str, analysis: DocumentAnalysis) -> List[Dict[str, Any]]:
        # Implementation for structured content
        pass

    def _chunk_unstructured_parts(self, content: str, analysis: DocumentAnalysis) -> List[Dict[str, Any]]:
        # Implementation for unstructured content
        pass

    def _merge_chunks(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Implementation for merging and optimizing chunks
        pass
```
## Quality Assurance Framework
### Chunk Quality Metrics
```python
import re
from typing import List, Dict, Any
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

class ChunkQualityAssessor:
    def __init__(self):
        self.quality_weights = {
            'coherence': 0.3,
            'completeness': 0.25,
            'size_appropriateness': 0.2,
            'semantic_similarity': 0.15,
            'boundary_quality': 0.1
        }

    def assess_chunks(self, chunks: List[Dict[str, Any]], analysis: DocumentAnalysis) -> Dict[str, float]:
        scores = {}
        # Coherence: Do chunks make sense on their own?
        scores['coherence'] = self._assess_coherence(chunks)
        # Completeness: Do chunks preserve important information?
        scores['completeness'] = self._assess_completeness(chunks, analysis)
        # Size appropriateness: Are chunks within optimal size range?
        scores['size_appropriateness'] = self._assess_size(chunks)
        # Semantic similarity: Are chunks thematically consistent?
        scores['semantic_similarity'] = self._assess_semantic_consistency(chunks)
        # Boundary quality: Are chunk boundaries placed well?
        scores['boundary_quality'] = self._assess_boundary_quality(chunks)
        # Calculate overall quality score
        overall_score = sum(
            score * self.quality_weights[metric]
            for metric, score in scores.items()
        )
        scores['overall'] = overall_score
        return scores

    def _assess_coherence(self, chunks: List[Dict[str, Any]]) -> float:
        # Simple heuristic-based coherence assessment
        coherence_scores = []
        for chunk in chunks:
            content = chunk['content']
            # Check for complete sentences
            sentences = re.split(r'[.!?]+', content)
            complete_sentences = sum(1 for s in sentences if s.strip())
            coherence = complete_sentences / max(len(sentences), 1)
            coherence_scores.append(coherence)
        return np.mean(coherence_scores)

    def _assess_completeness(self, chunks: List[Dict[str, Any]], analysis: DocumentAnalysis) -> float:
        # Check if important structural elements are preserved
        if analysis.doc_type in ['markdown', 'html']:
            return self._assess_structure_preservation(chunks, analysis)
        else:
            return self._assess_content_preservation(chunks)

    def _assess_structure_preservation(self, chunks: List[Dict[str, Any]], analysis: DocumentAnalysis) -> float:
        # Check if headings, lists, and other structural elements are preserved
        preserved_elements = 0
        total_elements = 0
        for chunk in chunks:
            content = chunk['content']
            # Count preserved structural elements
            headings = len(re.findall(r'^#+\s', content, re.MULTILINE))
            lists = len(re.findall(r'^\s*[-*+]\s', content, re.MULTILINE))
            preserved_elements += headings + lists
            total_elements += 1  # Simplified count
        return preserved_elements / max(total_elements, 1)

    def _assess_content_preservation(self, chunks: List[Dict[str, Any]]) -> float:
        # Simple check based on content ratio
        total_content = ''.join(chunk['content'] for chunk in chunks)
        # This would need comparison with original content
        return 0.8  # Placeholder

    def _assess_size(self, chunks: List[Dict[str, Any]]) -> float:
        optimal_min = 100  # tokens
        optimal_max = 1000  # tokens
        size_scores = []
        for chunk in chunks:
            token_count = self._estimate_tokens(chunk['content'])
            if optimal_min <= token_count <= optimal_max:
                score = 1.0
            elif token_count < optimal_min:
                score = token_count / optimal_min
            else:
                score = max(0, 1 - (token_count - optimal_max) / optimal_max)
            size_scores.append(score)
        return np.mean(size_scores)

    def _assess_semantic_consistency(self, chunks: List[Dict[str, Any]]) -> float:
        # This would require embedding models for actual implementation
        # Placeholder implementation
        return 0.7

    def _assess_boundary_quality(self, chunks: List[Dict[str, Any]]) -> float:
        # Check if boundaries don't split important content
        boundary_scores = []
        for chunk in chunks:
            content = chunk['content']
            # Check for incomplete sentences at boundaries
            if not content.strip().endswith(('.', '!', '?', '>', '}')):
                boundary_scores.append(0.5)
            else:
                boundary_scores.append(1.0)
        return np.mean(boundary_scores)

    def _estimate_tokens(self, content: str) -> int:
        # Simple token estimation
        return len(content.split()) * 4 // 3  # Rough approximation
```
## Error Handling and Edge Cases
### Robust Error Handling
```python
import logging
from typing import Optional, List, Dict, Any
from dataclasses import dataclass

@dataclass
class ChunkingError:
    error_type: str
    message: str
    chunk_index: Optional[int] = None
    recovery_action: Optional[str] = None

class ChunkingErrorHandler:
    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.error_handlers = {
            'empty_content': self._handle_empty_content,
            'oversized_chunk': self._handle_oversized_chunk,
            'encoding_error': self._handle_encoding_error,
            'memory_error': self._handle_memory_error,
            'structure_parsing_error': self._handle_structure_parsing_error
        }

    def handle_error(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        error_type = self._classify_error(error)
        handler = self.error_handlers.get(error_type, self._handle_generic_error)
        return handler(error, context)

    def _classify_error(self, error: Exception) -> str:
        if isinstance(error, ValueError) and 'empty' in str(error).lower():
            return 'empty_content'
        elif isinstance(error, MemoryError):
            return 'memory_error'
        elif isinstance(error, UnicodeError):
            return 'encoding_error'
        elif 'too large' in str(error).lower():
            return 'oversized_chunk'
        elif 'parsing' in str(error).lower():
            return 'structure_parsing_error'
        else:
            return 'generic_error'

    def _handle_empty_content(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        self.logger.warning(f"Empty content encountered: {error}")
        return ChunkingError(
            error_type='empty_content',
            message=str(error),
            recovery_action='skip_empty_content'
        )

    def _handle_oversized_chunk(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        self.logger.warning(f"Oversized chunk detected: {error}")
        return ChunkingError(
            error_type='oversized_chunk',
            message=str(error),
            chunk_index=context.get('chunk_index'),
            recovery_action='reduce_chunk_size'
        )

    def _handle_encoding_error(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        self.logger.error(f"Encoding error: {error}")
        return ChunkingError(
            error_type='encoding_error',
            message=str(error),
            recovery_action='fallback_encoding'
        )

    def _handle_memory_error(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        self.logger.error(f"Memory error during chunking: {error}")
        return ChunkingError(
            error_type='memory_error',
            message=str(error),
            recovery_action='process_in_batches'
        )

    def _handle_structure_parsing_error(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        self.logger.warning(f"Structure parsing failed: {error}")
        return ChunkingError(
            error_type='structure_parsing_error',
            message=str(error),
            recovery_action='fallback_to_fixed_size'
        )

    def _handle_generic_error(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        self.logger.error(f"Unexpected error during chunking: {error}")
        return ChunkingError(
            error_type='generic_error',
            message=str(error),
            recovery_action='skip_and_continue'
        )
```
## Performance Optimization
### Caching and Memoization
```python
import hashlib
import json
import logging
import pickle
from typing import Dict, Any, Optional, List

import redis

class ChunkingCache:
    def __init__(self, redis_url: Optional[str] = None):
        if redis_url:
            self.redis_client = redis.from_url(redis_url)
        else:
            self.redis_client = None
        self.local_cache = {}

    def _generate_cache_key(self, content: str, strategy: str, params: Dict[str, Any]) -> str:
        content_hash = hashlib.md5(content.encode()).hexdigest()
        params_str = json.dumps(params, sort_keys=True)
        params_hash = hashlib.md5(params_str.encode()).hexdigest()
        return f"chunking:{strategy}:{content_hash}:{params_hash}"

    def get(self, content: str, strategy: str, params: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
        cache_key = self._generate_cache_key(content, strategy, params)
        # Try local cache first
        if cache_key in self.local_cache:
            return self.local_cache[cache_key]
        # Try Redis cache
        if self.redis_client:
            try:
                cached_data = self.redis_client.get(cache_key)
                if cached_data:
                    chunks = pickle.loads(cached_data)
                    self.local_cache[cache_key] = chunks  # Cache locally too
                    return chunks
            except Exception as e:
                logging.warning(f"Redis cache error: {e}")
        return None

    def set(self, content: str, strategy: str, params: Dict[str, Any], chunks: List[Dict[str, Any]]) -> None:
        cache_key = self._generate_cache_key(content, strategy, params)
        # Store in local cache
        self.local_cache[cache_key] = chunks
        # Store in Redis cache
        if self.redis_client:
            try:
                cached_data = pickle.dumps(chunks)
                self.redis_client.setex(cache_key, 3600, cached_data)  # 1 hour TTL
            except Exception as e:
                logging.warning(f"Redis cache set error: {e}")

    def clear_local_cache(self):
        self.local_cache.clear()

    def clear_redis_cache(self):
        if self.redis_client:
            pattern = "chunking:*"
            keys = self.redis_client.keys(pattern)
            if keys:
                self.redis_client.delete(*keys)
```
### Batch Processing
```python
import asyncio
import concurrent.futures
import logging
from typing import Any, Callable, Dict, List

class BatchChunkingProcessor:
    def __init__(self, max_workers: int = 4, batch_size: int = 10):
        self.max_workers = max_workers
        self.batch_size = batch_size

    def process_documents_batch(self, documents: List[str],
                                chunking_function: Callable[[str], List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
        """Process multiple documents in parallel"""
        results = []
        # Process in batches to avoid memory issues
        for i in range(0, len(documents), self.batch_size):
            batch = documents[i:i + self.batch_size]
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                future_to_doc = {
                    executor.submit(chunking_function, doc): doc
                    for doc in batch
                }
                batch_results = []
                for future in concurrent.futures.as_completed(future_to_doc):
                    try:
                        chunks = future.result()
                        batch_results.append(chunks)
                    except Exception as e:
                        logging.error(f"Error processing document: {e}")
                        batch_results.append([])  # Empty result for failed processing
            results.extend(batch_results)
        return results

    async def process_documents_async(self, documents: List[str],
                                      chunking_function: Callable[[str], List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
        """Process documents asynchronously"""
        semaphore = asyncio.Semaphore(self.max_workers)

        async def process_single_document(doc: str) -> List[Dict[str, Any]]:
            async with semaphore:
                # Run the synchronous chunking function in an executor
                loop = asyncio.get_event_loop()
                return await loop.run_in_executor(None, chunking_function, doc)

        tasks = [process_single_document(doc) for doc in documents]
        return await asyncio.gather(*tasks, return_exceptions=True)
```
## Monitoring and Observability
### Metrics Collection
```python
import time
from dataclasses import dataclass
from typing import Dict, Any, List
from collections import defaultdict

@dataclass
class ChunkingMetrics:
    total_documents: int
    total_chunks: int
    avg_chunk_size: float
    processing_time: float
    memory_usage: float
    error_count: int
    strategy_distribution: Dict[str, int]

class MetricsCollector:
    def __init__(self):
        self.metrics = defaultdict(list)
        # Strategy usage is a counter, not a list of samples, so track it separately
        self.strategy_counts = defaultdict(int)
        self.start_time = None

    def start_timing(self):
        self.start_time = time.time()

    def end_timing(self) -> float:
        if self.start_time:
            duration = time.time() - self.start_time
            self.metrics['processing_time'].append(duration)
            self.start_time = None
            return duration
        return 0.0

    def record_chunk_count(self, count: int):
        self.metrics['chunk_count'].append(count)

    def record_chunk_size(self, size: int):
        self.metrics['chunk_size'].append(size)

    def record_strategy_usage(self, strategy: str):
        self.strategy_counts[strategy] += 1

    def record_error(self, error_type: str):
        self.metrics['errors'].append(error_type)

    def record_memory_usage(self, memory_mb: float):
        self.metrics['memory_usage'].append(memory_mb)

    def get_summary(self) -> ChunkingMetrics:
        return ChunkingMetrics(
            total_documents=len(self.metrics['processing_time']),
            total_chunks=sum(self.metrics['chunk_count']),
            avg_chunk_size=sum(self.metrics['chunk_size']) / max(len(self.metrics['chunk_size']), 1),
            processing_time=sum(self.metrics['processing_time']),
            memory_usage=sum(self.metrics['memory_usage']) / max(len(self.metrics['memory_usage']), 1),
            error_count=len(self.metrics['errors']),
            strategy_distribution=dict(self.strategy_counts)
        )

    def reset(self):
        self.metrics.clear()
        self.strategy_counts.clear()
        self.start_time = None
```
This implementation guide provides a comprehensive foundation for building robust, scalable chunking systems that can handle various document types and use cases while maintaining high quality and performance.

# Key Research Papers and Findings
This document summarizes important research papers and findings related to chunking strategies for RAG systems.
## Seminal Papers
### "Reconstructing Context: Evaluating Advanced Chunking Strategies for RAG" (arXiv:2504.19754)
**Key Findings**:
- Page-level chunking achieved highest average accuracy (0.648) with lowest variance across different query types
- Optimal chunk size varies significantly by document type and query complexity
- Factoid queries perform better with smaller chunks (256-512 tokens)
- Complex analytical queries benefit from larger chunks (1024+ tokens)
**Methodology**:
- Evaluated 7 different chunking strategies across multiple document types
- Tested with both factoid and analytical queries
- Measured end-to-end RAG performance
**Practical Implications**:
- Start with page-level chunking for general-purpose RAG systems
- Adapt chunk size based on expected query patterns
- Consider hybrid approaches for mixed query types
### "Lost in the Middle: How Language Models Use Long Contexts"
**Key Findings**:
- Language models tend to pay more attention to information at the beginning and end of context
- Information in the middle of long contexts is often ignored
- Performance degradation is most severe for centrally located information
**Practical Implications**:
- Place most important information at chunk boundaries
- Consider chunk overlap to ensure important context appears multiple times
- Use ranking to prioritize relevant chunks for inclusion in context
### "Grounded Language Learning in a Simulated 3D World"
**Related Concepts**:
- Importance of grounding text in visual/contextual information
- Multi-modal learning approaches for better understanding
**Relevance to Chunking**:
- Supports contextual chunking approaches that preserve visual/contextual relationships
- Validates importance of maintaining document structure and relationships
## Industry Research
### NVIDIA Research: "Finding the Best Chunking Strategy for Accurate AI Responses"
**Key Findings**:
- Page-level chunking outperformed sentence and paragraph-level approaches
- Fixed-size chunking showed consistent but suboptimal performance
- Semantic chunking provided improvements for complex documents
**Technical Details**:
- Tested chunk sizes from 128 to 2048 tokens
- Evaluated across financial, technical, and legal documents
- Measured both retrieval accuracy and generation quality
**Recommendations**:
- Use 512-1024 token chunks as starting point
- Implement adaptive chunking based on document complexity
- Consider page boundaries as natural chunk separators
### Cohere Research: "Effective Chunking Strategies for RAG"
**Key Findings**:
- Recursive character splitting provides a good balance of performance and simplicity
- Document structure awareness improves retrieval by 15-20%
- Overlap of 10-20% provides optimal context preservation
**Methodology**:
- Compared 12 chunking strategies across 6 document types
- Measured retrieval precision, recall, and F1-score
- Tested with both dense and sparse retrieval
**Best Practices Identified**:
- Start with recursive character splitting with 10-20% overlap
- Preserve document structure (headings, lists, tables)
- Customize chunk size based on embedding model context window
### Anthropic: "Contextual Retrieval"
**Key Innovation**:
- Enhance each chunk with LLM-generated contextual information before embedding
- Improves retrieval precision by 25-30% for complex documents
- Particularly effective for technical and academic content
**Implementation Approach**:
1. Split document using traditional methods
2. For each chunk, generate contextual information using LLM
3. Prepend context to chunk before embedding
4. Use hybrid search (dense + sparse) with weighted ranking
**Trade-offs**:
- Significant computational overhead (2-3x processing time)
- Higher embedding storage requirements
- Improved retrieval precision justifies cost for high-value applications
## Algorithmic Advances
### Semantic Chunking Algorithms
#### "Semantic Segmentation of Text Documents"
**Core Idea**: Use cosine similarity between consecutive sentence embeddings to identify natural boundaries.
**Algorithm**:
1. Split document into sentences
2. Generate embeddings for each sentence
3. Calculate similarity between consecutive sentences
4. Create boundaries where similarity drops below threshold
5. Merge short segments with neighbors
**Performance**: 20-30% improvement in retrieval relevance over fixed-size chunking for technical documents.
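A compact sketch of the five steps, including the merge pass for short segments (the regex sentence splitter and default parameters are simplifying assumptions):
```python
import re

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

def semantic_segment(text, threshold=0.8, min_sentences=2):
    model = SentenceTransformer("all-MiniLM-L6-v2")
    # Steps 1-2: split into sentences and embed them
    sentences = [s.strip() for s in re.split(r"[.!?]+", text) if s.strip()]
    if not sentences:
        return []
    embeddings = model.encode(sentences)

    # Steps 3-4: cut wherever consecutive-sentence similarity drops
    segments, current = [], [sentences[0]]
    for i in range(1, len(sentences)):
        sim = cosine_similarity([embeddings[i - 1]], [embeddings[i]])[0][0]
        if sim < threshold:
            segments.append(current)
            current = []
        current.append(sentences[i])
    segments.append(current)

    # Step 5: merge segments shorter than min_sentences into their neighbor
    merged = []
    for seg in segments:
        if merged and len(seg) < min_sentences:
            merged[-1].extend(seg)
        else:
            merged.append(seg)
    return [" ".join(seg) for seg in merged]
```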
#### "Hierarchical Semantic Chunking"
**Core Idea**: Multi-level semantic segmentation for document organization.
**Algorithm**:
1. Document-level semantic analysis
2. Section-level boundary detection
3. Paragraph-level segmentation
4. Sentence-level refinement
**Benefits**: Maintains document hierarchy while adapting to semantic structure.
### Advanced Embedding Techniques
#### "Late Chunking: Contextual Chunk Embeddings"
**Core Innovation**: Generate embeddings for entire document first, then create chunk embeddings from token-level embeddings.
**Advantages**:
- Preserves global document context
- Reduces context fragmentation
- Better for documents with complex inter-relationships
**Requirements**:
- Long-context embedding models (8k+ tokens)
- Significant computational resources
- Specialized implementation
#### "Hierarchical Embedding Retrieval"
**Approach**: Create embeddings at multiple granularities (document, section, paragraph, sentence).
**Implementation**:
1. Generate embeddings at each level
2. Store in hierarchical vector database
3. Query at appropriate granularity based on information needs
**Performance**: 15-25% improvement in precision for complex queries.
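A minimal sketch of the multi-granularity idea with in-memory indexes (the per-level split heuristics and the query routine are assumptions):
```python
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

def build_hierarchical_index(document):
    # Step 1: embed the same content at several granularities
    levels = {
        "section": [s for s in document.split("\n\n\n") if s.strip()],
        "paragraph": [p for p in document.split("\n\n") if p.strip()],
        "sentence": [s for s in document.split(". ") if s.strip()],
    }
    return {name: (units, model.encode(units)) for name, units in levels.items() if units}

def query_index(index, query, level="paragraph", top_k=3):
    # Step 3: query at the granularity matching the information need
    units, embeddings = index[level]
    q = model.encode([query])[0]
    scores = embeddings @ q / (np.linalg.norm(embeddings, axis=1) * np.linalg.norm(q))
    return [units[i] for i in np.argsort(-scores)[:top_k]]
```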
## Evaluation Methodologies
### Retrieval-Augmented Generation Assessment Frameworks
#### RAGAS Framework
**Metrics**:
- **Faithfulness**: Consistency between generated answer and retrieved context
- **Answer Relevancy**: Relevance of generated answer to the question
- **Context Relevancy**: Relevance of retrieved context to the question
- **Context Recall**: Coverage of relevant information in retrieved context
**Evaluation Process**:
1. Generate questions from document corpus
2. Retrieve relevant chunks using different strategies
3. Generate answers using retrieved chunks
4. Evaluate using automated metrics and human judgment
#### ARES Framework
**Innovation**: Automated evaluation using synthetic questions and LLM-based assessment.
**Key Features**:
- Generates diverse question types (factoid, analytical, comparative)
- Uses LLMs to evaluate answer quality
- Provides scalable evaluation without human annotation
### Benchmark Datasets
#### Natural Questions (NQ)
**Description**: Real user questions from Google Search with relevant Wikipedia passages.
**Relevance**: Natural language queries with authentic relevance judgments.
#### MS MARCO
**Description**: Large-scale passage ranking dataset with real search queries.
**Relevance**: High-quality relevance judgments for passage retrieval.
#### HotpotQA
**Description**: Multi-hop question answering requiring information from multiple documents.
**Relevance**: Tests ability to retrieve and synthesize information from multiple chunks.
## Domain-Specific Research
### Medical Documents
#### "Optimal Chunking for Medical Question Answering"
**Key Findings**:
- Medical terminology requires specialized handling
- Section-based chunking (History, Diagnosis, Treatment) most effective
- Preserving doctor-patient dialogue context crucial
**Recommendations**:
- Use medical-specific tokenizers
- Preserve section headers and structure
- Maintain temporal relationships in medical histories
### Legal Documents
#### "Chunking Strategies for Legal Document Analysis"
**Key Findings**:
- Legal citations and cross-references require special handling
- Contract clause boundaries serve as natural chunk separators
- Case law benefits from hierarchical chunking
**Best Practices**:
- Preserve legal citation structure
- Use clause and section boundaries
- Maintain context for legal definitions and references
### Financial Documents
#### "SEC Filing Chunking for Financial Analysis"
**Key Findings**:
- Table preservation critical for financial data
- XBRL tagging provides natural segmentation
- Risk factors sections benefit from specialized treatment
**Approach**:
- Preserve complete tables when possible
- Use XBRL tags for structured data
- Create specialized chunks for risk sections
## Emerging Trends
### Multi-Modal Chunking
#### "Integrating Text, Tables, and Images in RAG Systems"
**Innovation**: Unified chunking approach for mixed-modal content.
**Approach**:
- Extract and describe images using vision models
- Preserve table structure and relationships
- Create unified embeddings for mixed content
**Results**: 35% improvement in complex document understanding.
### Adaptive Chunking
#### "Machine Learning-Based Chunk Size Optimization"
**Core Idea**: Use ML models to predict optimal chunking parameters.
**Features**:
- Document length and complexity
- Query type distribution
- Embedding model characteristics
- Performance requirements
**Benefits**: Dynamic optimization based on use case and content.
### Real-time Chunking
#### "Streaming Chunking for Live Document Processing"
**Innovation**: Process documents as they become available.
**Techniques**:
- Incremental boundary detection
- Dynamic chunk size adjustment
- Context preservation across chunks
**Applications**: Live news feeds, social media analysis, meeting transcripts.
## Implementation Challenges
### Computational Efficiency
#### "Scalable Chunking for Large Document Collections"
**Challenges**:
- Processing millions of documents efficiently
- Memory usage optimization
- Distributed processing requirements
**Solutions**:
- Batch processing with parallel execution
- Streaming approaches for large documents
- Distributed chunking with load balancing
### Quality Assurance
#### "Evaluating Chunk Quality at Scale"
**Challenges**:
- Automated quality assessment
- Detecting poor chunk boundaries
- Maintaining consistency across document types
**Approaches**:
- Heuristic-based quality metrics
- LLM-based evaluation
- Human-in-the-loop validation
## Future Research Directions
### Context-Aware Chunking
**Open Questions**:
- How to optimally preserve cross-chunk relationships?
- Can we predict chunk quality without human evaluation?
- What is the optimal balance between size and context?
### Domain Adaptation
**Research Areas**:
- Automatic domain detection and adaptation
- Transfer learning across domains
- Zero-shot chunking for new document types
### Evaluation Standards
**Needs**:
- Standardized evaluation benchmarks
- Cross-paper comparison methodologies
- Real-world performance metrics
## Practical Recommendations Based on Research
### Starting Points
1. **For General RAG Systems**: Page-level or recursive character chunking with 512-1024 tokens and 10-20% overlap
2. **For Technical Documents**: Structure-aware chunking with semantic boundary detection
3. **For High-Value Applications**: Contextual retrieval with LLM-generated context
### Evolution Strategy
1. **Begin**: Simple fixed-size chunking (512 tokens)
2. **Improve**: Add document structure awareness
3. **Optimize**: Implement semantic boundaries
4. **Advanced**: Consider contextual retrieval for critical use cases
### Key Success Factors
1. **Match strategy to document type and query patterns**
2. **Preserve document structure when beneficial**
3. **Use overlap to maintain context across boundaries**
4. **Monitor both accuracy and computational costs**
5. **Iterate based on specific use case requirements**
This research foundation provides evidence-based guidance for implementing effective chunking strategies across various domains and use cases.


@@ -0,0 +1,423 @@
# Detailed Chunking Strategies
This document provides comprehensive implementation details for all chunking strategies mentioned in the main skill.
## Level 1: Fixed-Size Chunking
### Implementation
```python
from langchain.text_splitter import RecursiveCharacterTextSplitter
class FixedSizeChunker:
def __init__(self, chunk_size=512, chunk_overlap=50):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len,
separators=["\n\n", "\n", " ", ""]
)
def chunk(self, documents):
return self.splitter.split_documents(documents)
```
### Parameter Recommendations
| Use Case | Chunk Size | Overlap | Rationale |
|----------|------------|---------|-----------|
| Factoid Queries | 256 | 25 | Small chunks for precise answers |
| General Q&A | 512 | 50 | Balanced approach for most cases |
| Analytical Queries | 1024 | 100 | Larger context for complex analysis |
| Code Documentation | 300 | 30 | Preserve code context while maintaining focus |
### Best Practices
- Start with 512 tokens and 10-20% overlap
- Adjust based on embedding model context window
- Use overlap for queries where context might span boundaries
- Monitor token count vs. character count based on model
## Level 2: Recursive Character Chunking
### Implementation
```python
from langchain.text_splitter import RecursiveCharacterTextSplitter
class RecursiveChunker:
def __init__(self, chunk_size=512, separators=None):
self.chunk_size = chunk_size
self.separators = separators or ["\n\n", "\n", " ", ""]
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=0,
length_function=len,
separators=self.separators
)
def chunk(self, text):
return self.splitter.create_documents([text])
# Document-specific configurations
def get_chunker_for_document_type(doc_type):
configurations = {
"markdown": ["\n## ", "\n### ", "\n\n", "\n", " ", ""],
"html": ["</div>", "</p>", "\n\n", "\n", " ", ""],
"code": ["\n\n", "\n", " ", ""],
"plain": ["\n\n", "\n", " ", ""]
}
return RecursiveChunker(separators=configurations.get(doc_type, ["\n\n", "\n", " ", ""]))
```
### Customization Guidelines
- **Markdown**: Use headings as primary separators
- **HTML**: Use block-level tags as separators
- **Code**: Preserve function and class boundaries
- **Academic papers**: Prioritize paragraph and section breaks
## Level 3: Structure-Aware Chunking
### Markdown Documents
```python
import markdown
from bs4 import BeautifulSoup
class MarkdownChunker:
def __init__(self, max_chunk_size=512):
self.max_chunk_size = max_chunk_size
def chunk(self, markdown_text):
html = markdown.markdown(markdown_text)
soup = BeautifulSoup(html, 'html.parser')
chunks = []
current_chunk = ""
current_heading = "Introduction"
for element in soup.find_all(['h1', 'h2', 'h3', 'p', 'pre', 'table']):
if element.name.startswith('h'):
if current_chunk.strip():
chunks.append({
"content": current_chunk.strip(),
"heading": current_heading
})
current_heading = element.get_text().strip()
current_chunk = f"{element}\n"
elif element.name in ['pre', 'table']:
# Preserve code blocks and tables intact
if len(current_chunk) + len(str(element)) > self.max_chunk_size:
if current_chunk.strip():
chunks.append({
"content": current_chunk.strip(),
"heading": current_heading
})
current_chunk = f"{element}\n"
else:
current_chunk += f"{element}\n"
else:
current_chunk += str(element)
if current_chunk.strip():
chunks.append({
"content": current_chunk.strip(),
"heading": current_heading
})
return chunks
```
### Code Documents
```python
import ast
import re
class CodeChunker:
def __init__(self, language='python'):
self.language = language
    def chunk_python(self, code):
        tree = ast.parse(code)
        chunks = []
        # Note: ast.walk also visits nested definitions, so inner functions
        # appear both inside their parent's chunk and as chunks of their own
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
                start_line = node.lineno - 1
                # end_lineno exists on Python 3.8+; otherwise fall back to a guess
                end_line = node.end_lineno if hasattr(node, 'end_lineno') else start_line + 10
lines = code.split('\n')
chunk_lines = lines[start_line:end_line]
chunks.append('\n'.join(chunk_lines))
return chunks
def chunk_javascript(self, code):
        # Regex fallback for languages without stdlib AST parsers;
        # note: [^}]* cannot match nested braces, so a real parser is preferable
        function_pattern = r'(function\s+\w+\s*\([^)]*\)\s*\{[^}]*\})'
        class_pattern = r'(class\s+\w+\s*\{[^}]*\})'
patterns = [function_pattern, class_pattern]
chunks = []
for pattern in patterns:
matches = re.finditer(pattern, code, re.MULTILINE | re.DOTALL)
for match in matches:
chunks.append(match.group(1))
return chunks
def chunk(self, code):
if self.language == 'python':
return self.chunk_python(code)
elif self.language == 'javascript':
return self.chunk_javascript(code)
else:
# Fallback to line-based chunking
return self.chunk_by_lines(code)
def chunk_by_lines(self, code, max_lines=50):
lines = code.split('\n')
chunks = []
for i in range(0, len(lines), max_lines):
chunk = '\n'.join(lines[i:i+max_lines])
chunks.append(chunk)
return chunks
```
### Tabular Data
```python
from io import StringIO

import pandas as pd
class TableChunker:
def __init__(self, max_rows=100, summary_rows=5):
self.max_rows = max_rows
self.summary_rows = summary_rows
def chunk(self, table_data):
if isinstance(table_data, str):
df = pd.read_csv(StringIO(table_data))
else:
df = table_data
chunks = []
if len(df) <= self.max_rows:
# Small table - keep intact
chunks.append({
"type": "full_table",
"content": df.to_string(),
"metadata": {
"rows": len(df),
"columns": len(df.columns)
}
})
else:
# Large table - create summary + chunks
summary = df.head(self.summary_rows)
chunks.append({
"type": "table_summary",
"content": f"Table Summary ({len(df)} rows, {len(df.columns)} columns):\n{summary.to_string()}",
"metadata": {
"total_rows": len(df),
"summary_rows": self.summary_rows,
"columns": list(df.columns)
}
})
# Chunk the remaining data
for i in range(self.summary_rows, len(df), self.max_rows):
chunk_df = df.iloc[i:i+self.max_rows]
chunks.append({
"type": "table_chunk",
"content": f"Rows {i+1}-{min(i+self.max_rows, len(df))}:\n{chunk_df.to_string()}",
"metadata": {
"start_row": i + 1,
"end_row": min(i + self.max_rows, len(df)),
"columns": list(df.columns)
}
})
return chunks
```
## Level 4: Semantic Chunking
### Implementation
```python
import re

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
class SemanticChunker:
def __init__(self, model_name="all-MiniLM-L6-v2", similarity_threshold=0.8, buffer_size=3):
self.model = SentenceTransformer(model_name)
self.similarity_threshold = similarity_threshold
self.buffer_size = buffer_size
def split_into_sentences(self, text):
# Simple sentence splitting - can be enhanced with nltk/spacy
sentences = re.split(r'[.!?]+', text)
return [s.strip() for s in sentences if s.strip()]
def chunk(self, text):
sentences = self.split_into_sentences(text)
if len(sentences) <= self.buffer_size:
return [text]
# Create embeddings
embeddings = self.model.encode(sentences)
chunks = []
current_chunk_sentences = []
for i in range(len(sentences)):
current_chunk_sentences.append(sentences[i])
# Check if we should create a boundary
if i < len(sentences) - 1:
similarity = cosine_similarity(
[embeddings[i]],
[embeddings[i + 1]]
)[0][0]
if similarity < self.similarity_threshold and len(current_chunk_sentences) >= 2:
chunks.append(' '.join(current_chunk_sentences))
current_chunk_sentences = []
# Add remaining sentences
if current_chunk_sentences:
chunks.append(' '.join(current_chunk_sentences))
return chunks
```
### Parameter Tuning
| Parameter | Range | Effect |
|-----------|-------|--------|
| similarity_threshold | 0.5-0.9 | Higher values create more chunks |
| buffer_size | 1-10 | Larger buffers provide more context |
| model_name | Various | Different models for different domains |
### Optimization Tips
- Use domain-specific models for specialized content
- Adjust threshold based on content complexity
- Cache embeddings for repeated processing
- Consider batch processing for large documents
## Level 5: Advanced Contextual Methods
### Late Chunking
```python
import torch
from transformers import AutoTokenizer, AutoModel
class LateChunker:
    def __init__(self, model_name="jinaai/jina-embeddings-v2-base-en"):
        # Late chunking assumes a long-context *embedding* model (8k+ tokens),
        # not a dialogue model; jina-embeddings-v2 is one example choice
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
def chunk(self, text, chunk_size=512):
# Tokenize entire document
tokens = self.tokenizer(text, return_tensors="pt", truncation=False)
# Get token-level embeddings
with torch.no_grad():
outputs = self.model(**tokens, output_hidden_states=True)
token_embeddings = outputs.last_hidden_state[0]
# Create chunk embeddings from token embeddings
chunks = []
for i in range(0, len(token_embeddings), chunk_size):
chunk_tokens = token_embeddings[i:i+chunk_size]
chunk_embedding = torch.mean(chunk_tokens, dim=0)
chunks.append({
"content": self.tokenizer.decode(tokens["input_ids"][0][i:i+chunk_size]),
"embedding": chunk_embedding.numpy()
})
return chunks
```
### Contextual Retrieval
```python
import openai
class ContextualChunker:
def __init__(self, api_key):
self.client = openai.OpenAI(api_key=api_key)
def generate_context(self, chunk, full_document):
prompt = f"""
Given the following document and a chunk from it, provide a brief context
that helps understand the chunk's meaning within the full document.
Document:
{full_document[:2000]}...
Chunk:
{chunk}
Context (max 50 words):
"""
response = self.client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": prompt}],
max_tokens=100,
temperature=0
)
return response.choices[0].message.content.strip()
def chunk_with_context(self, text, base_chunker):
# First create base chunks
base_chunks = base_chunker.chunk(text)
# Then add context to each chunk
contextualized_chunks = []
for chunk in base_chunks:
context = self.generate_context(chunk.page_content, text)
contextualized_content = f"Context: {context}\n\nContent: {chunk.page_content}"
contextualized_chunks.append({
"content": contextualized_content,
"original_content": chunk.page_content,
"context": context
})
return contextualized_chunks
```
## Performance Considerations
### Computational Cost Analysis
| Strategy | Time Complexity | Space Complexity | Relative Cost |
|----------|-----------------|------------------|---------------|
| Fixed-Size | O(n) | O(n) | Low |
| Recursive | O(n) | O(n) | Low |
| Structure-Aware | O(n log n) | O(n) | Medium |
| Semantic | O(n²) | O(n²) | High |
| Late Chunking | O(n) | O(n) | Very High |
| Contextual | O(n²) | O(n²) | Very High |
### Optimization Strategies
1. **Parallel Processing**: Process chunks concurrently when possible
2. **Caching**: Store embeddings and intermediate results
3. **Batch Operations**: Group similar operations together
4. **Progressive Loading**: Process large documents in streaming fashion
5. **Model Selection**: Choose appropriate models for task complexity
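As a concrete illustration of strategies 1 and 3, a corpus can be chunked across worker processes with only the standard library (`chunk_one` is a stand-in for any chunker defined above):
```python
from concurrent.futures import ProcessPoolExecutor

def chunk_one(document, chunk_size=512):
    # Stand-in for any chunker defined above (naive fixed-size shown)
    return [document[i:i + chunk_size] for i in range(0, len(document), chunk_size)]

def chunk_corpus(documents, workers=4):
    # Each worker chunks a batch of documents; map preserves input order
    with ProcessPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(chunk_one, documents))
```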


@@ -0,0 +1,867 @@
# Recommended Libraries and Frameworks
This document provides a comprehensive guide to tools, libraries, and frameworks for implementing chunking strategies.
## Core Chunking Libraries
### LangChain
**Overview**: Comprehensive framework for building applications with large language models, includes robust text splitting utilities.
**Installation**:
```bash
pip install langchain langchain-text-splitters
```
**Key Features**:
- Multiple text splitting strategies
- Integration with various document loaders
- Support for different content types (code, markdown, etc.)
- Customizable separators and parameters
**Example Usage**:
```python
from langchain.text_splitter import (
RecursiveCharacterTextSplitter,
CharacterTextSplitter,
TokenTextSplitter,
MarkdownTextSplitter,
PythonCodeTextSplitter
)
# Basic recursive splitting
splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
length_function=len,
separators=["\n\n", "\n", " ", ""]
)
chunks = splitter.split_text(large_text)
# Markdown-specific splitting
markdown_splitter = MarkdownTextSplitter(
chunk_size=1000,
chunk_overlap=100
)
# Code-specific splitting
code_splitter = PythonCodeTextSplitter(
chunk_size=1000,
chunk_overlap=100
)
```
**Pros**:
- Well-maintained and actively developed
- Extensive documentation and examples
- Integrates well with other LangChain components
- Supports multiple document types
**Cons**:
- Can be heavy dependency for simple use cases
- Some advanced features require LangChain ecosystem
### LlamaIndex
**Overview**: Data framework for LLM applications with advanced indexing and retrieval capabilities.
**Installation**:
```bash
pip install llama-index
```
**Key Features**:
- Advanced semantic chunking
- Hierarchical indexing
- Context-aware retrieval
- Integration with vector databases
**Example Usage**:
```python
from llama_index.core.node_parser import (
SentenceSplitter,
SemanticSplitterNodeParser
)
from llama_index.core import SimpleDirectoryReader
from llama_index.embeddings.openai import OpenAIEmbedding
# Basic sentence splitting
splitter = SentenceSplitter(
chunk_size=1024,
chunk_overlap=20
)
# Semantic chunking with embeddings
embed_model = OpenAIEmbedding()
semantic_splitter = SemanticSplitterNodeParser(
buffer_size=1,
breakpoint_percentile_threshold=95,
embed_model=embed_model
)
# Load and process documents
documents = SimpleDirectoryReader("./data").load_data()
nodes = semantic_splitter.get_nodes_from_documents(documents)
```
**Pros**:
- Excellent semantic chunking capabilities
- Built for production RAG systems
- Strong vector database integration
- Active community support
**Cons**:
- More complex setup for basic use cases
- Semantic chunking requires embedding model setup
### Unstructured
**Overview**: Open-source library for processing unstructured documents, especially strong with multi-modal content.
**Installation**:
```bash
pip install "unstructured[pdf,png,jpg]"
```
**Key Features**:
- Multi-modal document processing
- Support for PDFs, images, and various formats
- Structure preservation
- Table extraction and processing
**Example Usage**:
```python
from unstructured.partition.auto import partition
from unstructured.chunking.title import chunk_by_title
# Partition document by type
elements = partition(filename="document.pdf")
# Chunk by title/heading structure
chunks = chunk_by_title(
elements,
combine_text_under_n_chars=2000,
max_characters=10000,
new_after_n_chars=1500,
multipage_sections=True
)
# Access chunked content
for chunk in chunks:
print(f"Category: {chunk.category}")
print(f"Content: {chunk.text[:200]}...")
```
**Pros**:
- Excellent for PDF and image processing
- Preserves document structure
- Handles tables and figures well
- Strong multi-modal capabilities
**Cons**:
- Can be slower for large documents
- Requires additional dependencies for some formats
## Text Processing Libraries
### NLTK (Natural Language Toolkit)
**Installation**:
```bash
pip install nltk
```
**Key Features**:
- Sentence tokenization
- Language detection
- Text preprocessing
- Linguistic analysis
**Example Usage**:
```python
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords
# Download required data
nltk.download('punkt')
nltk.download('stopwords')
# Sentence and word tokenization
text = "This is a sample sentence. This is another sentence."
sentences = sent_tokenize(text)
words = word_tokenize(text)
# Stop words removal
stop_words = set(stopwords.words('english'))
filtered_words = [word for word in words if word.lower() not in stop_words]
```
### spaCy
**Installation**:
```bash
pip install spacy
python -m spacy download en_core_web_sm
```
**Key Features**:
- Industrial-strength NLP
- Named entity recognition
- Dependency parsing
- Sentence boundary detection
**Example Usage**:
```python
import spacy
# Load language model
nlp = spacy.load("en_core_web_sm")
# Process text
doc = nlp("This is a sample sentence. This is another sentence.")
# Extract sentences
sentences = [sent.text for sent in doc.sents]
# Named entities
entities = [(ent.text, ent.label_) for ent in doc.ents]
# Dependency parsing for better chunking
for token in doc:
print(f"{token.text}: {token.dep_} (head: {token.head.text})")
```
### Sentence Transformers
**Installation**:
```bash
pip install sentence-transformers
```
**Key Features**:
- Pre-trained sentence embeddings
- Semantic similarity calculation
- Multi-lingual support
- Custom model training
**Example Usage**:
```python
from sentence_transformers import SentenceTransformer, util
import numpy as np
# Load pre-trained model
model = SentenceTransformer('all-MiniLM-L6-v2')
# Generate embeddings
sentences = ["This is a sentence.", "This is another sentence."]
embeddings = model.encode(sentences)
# Calculate semantic similarity
similarity = util.cos_sim(embeddings[0], embeddings[1])
# Find semantic boundaries for chunking
def find_semantic_boundaries(text, model, threshold=0.8):
sentences = [s.strip() for s in text.split('.') if s.strip()]
embeddings = model.encode(sentences)
boundaries = [0]
for i in range(1, len(sentences)):
similarity = util.cos_sim(embeddings[i-1], embeddings[i])
if similarity < threshold:
boundaries.append(i)
return boundaries
```
## Vector Databases and Search
### ChromaDB
**Installation**:
```bash
pip install chromadb
```
**Key Features**:
- In-memory and persistent storage
- Built-in embedding functions
- Similarity search
- Metadata filtering
**Example Usage**:
```python
import chromadb
from chromadb.utils import embedding_functions
# Initialize client
client = chromadb.Client()
# Create collection
collection = client.create_collection(
name="document_chunks",
embedding_function=embedding_functions.DefaultEmbeddingFunction()
)
# Add chunks
collection.add(
documents=[chunk["content"] for chunk in chunks],
metadatas=[chunk.get("metadata", {}) for chunk in chunks],
ids=[chunk["id"] for chunk in chunks]
)
# Search
results = collection.query(
query_texts=["What is chunking?"],
n_results=5
)
```
### Pinecone
**Installation**:
```bash
pip install pinecone-client
```
**Key Features**:
- Managed vector database service
- High-performance similarity search
- Metadata filtering
- Scalable infrastructure
**Example Usage**:
```python
import pinecone
from sentence_transformers import SentenceTransformer
# Initialize
pinecone.init(api_key="your-api-key", environment="your-environment")
index_name = "document-chunks"
# Create index if it doesn't exist
if index_name not in pinecone.list_indexes():
pinecone.create_index(
name=index_name,
dimension=384, # Match embedding model
metric="cosine"
)
index = pinecone.Index(index_name)
# Generate embeddings and upsert
model = SentenceTransformer('all-MiniLM-L6-v2')
for chunk in chunks:
embedding = model.encode(chunk["content"])
index.upsert(
vectors=[{
"id": chunk["id"],
"values": embedding.tolist(),
"metadata": chunk.get("metadata", {})
}]
)
# Search
query_embedding = model.encode("search query")
results = index.query(
vector=query_embedding.tolist(),
top_k=5,
include_metadata=True
)
```
### Weaviate
**Installation**:
```bash
pip install weaviate-client
```
**Key Features**:
- GraphQL API
- Hybrid search (dense + sparse)
- Real-time updates
- Schema validation
**Example Usage**:
```python
import weaviate
# Connect to Weaviate
client = weaviate.Client("http://localhost:8080")
# Define schema
client.schema.create_class({
"class": "DocumentChunk",
"description": "A chunk of document content",
"properties": [
{
"name": "content",
"dataType": ["text"]
},
{
"name": "source",
"dataType": ["string"]
}
]
})
# Add data
for chunk in chunks:
client.data_object.create(
data_object={
"content": chunk["content"],
"source": chunk.get("source", "unknown")
},
class_name="DocumentChunk"
)
# Search
results = client.query.get(
"DocumentChunk",
["content", "source"]
).with_near_text({
"concepts": ["search query"]
}).with_limit(5).do()
```
## Evaluation and Testing
### RAGAS
**Installation**:
```bash
pip install ragas
```
**Key Features**:
- RAG evaluation metrics
- Answer quality assessment
- Context relevance measurement
- Faithfulness evaluation
**Example Usage**:
```python
from ragas import evaluate
from ragas.metrics import (
faithfulness,
answer_relevancy,
context_relevancy,
context_recall
)
from datasets import Dataset
# Prepare evaluation data
dataset = Dataset.from_dict({
"question": ["What is chunking?"],
"answer": ["Chunking is the process of breaking large documents into smaller segments"],
"contexts": [["Chunking involves dividing text into manageable pieces for better processing"]],
"ground_truth": ["Chunking is a document processing technique"]
})
# Evaluate
result = evaluate(
dataset=dataset,
metrics=[
faithfulness,
answer_relevancy,
context_relevancy,
context_recall
]
)
print(result)
```
### TruEra (TruLens)
**Installation**:
```bash
pip install trulens trulens-apps
```
**Key Features**:
- LLM application evaluation
- Feedback functions
- Hallucination detection
- Performance monitoring
**Example Usage**:
```python
from trulens.core import TruSession
from trulens.apps.custom import TruCustomApp, instrument
from trulens.feedback import GroundTruthAgreement

# Initialize session
session = TruSession()

# Define feedback functions (expects a golden set of question/answer pairs;
# newer versions also take a provider argument)
f_groundedness = GroundTruthAgreement(ground_truth)

# Instrumented methods are traced when called inside a recording context
class ChunkingApp:
    @instrument
    def chunk_and_query(self, text, query):
        chunks = chunk_function(text)
        relevant_chunks = search_function(chunks, query)
        answer = generate_function(relevant_chunks, query)
        return answer

# Record evaluation
app = ChunkingApp()
tru_app = TruCustomApp(app, app_name="chunking-eval")
with tru_app as recording:
    app.chunk_and_query("large document text", "what is the main topic?")
```
## Document Processing
### PyPDF2
**Installation**:
```bash
pip install PyPDF2
```
**Key Features**:
- PDF text extraction
- Page manipulation
- Metadata extraction
- Form field processing
**Example Usage**:
```python
import PyPDF2
def extract_text_from_pdf(pdf_path):
text = ""
with open(pdf_path, 'rb') as file:
reader = PyPDF2.PdfReader(file)
for page in reader.pages:
text += page.extract_text()
return text
# Extract text by page for better chunking
def extract_pages(pdf_path):
pages = []
with open(pdf_path, 'rb') as file:
reader = PyPDF2.PdfReader(file)
for i, page in enumerate(reader.pages):
pages.append({
"page_number": i + 1,
"content": page.extract_text()
})
return pages
```
### python-docx
**Installation**:
```bash
pip install python-docx
```
**Key Features**:
- Microsoft Word document processing
- Paragraph and table extraction
- Style preservation
- Metadata access
**Example Usage**:
```python
from docx import Document
def extract_from_docx(docx_path):
doc = Document(docx_path)
content = []
for paragraph in doc.paragraphs:
if paragraph.text.strip():
content.append({
"type": "paragraph",
"text": paragraph.text,
"style": paragraph.style.name
})
for table in doc.tables:
table_text = []
for row in table.rows:
row_text = [cell.text for cell in row.cells]
table_text.append(" | ".join(row_text))
content.append({
"type": "table",
"text": "\n".join(table_text)
})
return content
```
## Specialized Libraries
### tiktoken (OpenAI)
**Installation**:
```bash
pip install tiktoken
```
**Key Features**:
- Accurate token counting for OpenAI models
- Fast encoding/decoding
- Multiple model support
- Language model specific tokenization
**Example Usage**:
```python
import tiktoken
# Get encoding for specific model
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
# Encode text
tokens = encoding.encode("This is a sample text")
print(f"Token count: {len(tokens)}")
# Decode tokens
text = encoding.decode(tokens)
# Count tokens without full encoding
def count_tokens(text, model="gpt-3.5-turbo"):
encoding = tiktoken.encoding_for_model(model)
return len(encoding.encode(text))
# Use in chunking
def chunk_by_tokens(text, max_tokens=1000):
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
tokens = encoding.encode(text)
chunks = []
for i in range(0, len(tokens), max_tokens):
chunk_tokens = tokens[i:i + max_tokens]
chunk_text = encoding.decode(chunk_tokens)
chunks.append(chunk_text)
return chunks
```
### PDFMiner
**Installation**:
```bash
pip install pdfminer.six
```
**Key Features**:
- Detailed PDF analysis
- Layout preservation
- Font and style information
- High-precision text extraction
**Example Usage**:
```python
from pdfminer.high_level import extract_pages
from pdfminer.layout import LTChar, LTTextContainer

def _first_char(element):
    # Depth-first search for the first LTChar, which carries the font data
    for obj in element:
        if isinstance(obj, LTChar):
            return obj
        if hasattr(obj, "__iter__"):
            found = _first_char(obj)
            if found is not None:
                return found
    return None

def extract_structured_text(pdf_path):
    structured_content = []
    for page_layout in extract_pages(pdf_path):
        page_content = []
        for element in page_layout:
            if isinstance(element, LTTextContainer):
                text = element.get_text()
                # Text containers expose no font attributes themselves;
                # approximate them from the first character of the run
                char = _first_char(element)
                font_info = {
                    "font_size": char.size if char else None,
                    "is_bold": bool(char and "Bold" in char.fontname),
                    "x0": element.x0,
                    "y0": element.y0
                }
                page_content.append({
                    "text": text.strip(),
                    "font_info": font_info
                })
        structured_content.append({
            "page_number": page_layout.pageid,
            "content": page_content
        })
    return structured_content
```
## Performance and Optimization
### Dask
**Installation**:
```bash
pip install dask[complete]
```
**Key Features**:
- Parallel processing
- Out-of-core computation
- Distributed computing
- Integration with pandas
**Example Usage**:
```python
import dask.bag as db
from dask.distributed import Client
# Setup distributed client
client = Client(n_workers=4)
# Parallel chunking of multiple documents
def chunk_document(document):
# Your chunking logic here
return chunk_function(document)
# Process documents in parallel
documents = ["doc1", "doc2", "doc3", ...] # List of document contents
document_bag = db.from_sequence(documents)
# Apply chunking function in parallel
chunked_documents = document_bag.map(chunk_document)
# Compute results
results = chunked_documents.compute()
```
### Ray
**Installation**:
```bash
pip install ray
```
**Key Features**:
- Distributed computing
- Actor model
- Autoscaling
- ML pipeline integration
**Example Usage**:
```python
import ray
# Initialize Ray
ray.init()
@ray.remote
class ChunkingWorker:
def __init__(self, strategy):
self.strategy = strategy
def chunk_documents(self, documents):
results = []
for doc in documents:
chunks = self.strategy.chunk(doc)
results.append(chunks)
return results
# Create workers
workers = [ChunkingWorker.remote(strategy) for _ in range(4)]
# Distribute work
documents_batch = [documents[i::4] for i in range(4)]
futures = [worker.chunk_documents.remote(batch)
for worker, batch in zip(workers, documents_batch)]
# Get results
results = ray.get(futures)
```
## Development and Testing
### pytest
**Installation**:
```bash
pip install pytest pytest-asyncio
```
**Example Tests**:
```python
import pytest
from your_chunking_module import FixedSizeChunker, SemanticChunker
class TestFixedSizeChunker:
def test_chunk_size_respect(self):
chunker = FixedSizeChunker(chunk_size=100, chunk_overlap=10)
text = "word " * 50 # 50 words
chunks = chunker.chunk(text)
        for chunk in chunks:
            assert len(chunk) <= 100  # chunk_size is measured in characters here
def test_overlap_consistency(self):
chunker = FixedSizeChunker(chunk_size=50, chunk_overlap=10)
text = "word " * 30
chunks = chunker.chunk(text)
# Check overlap between consecutive chunks
for i in range(1, len(chunks)):
chunk1_words = set(chunks[i-1].split()[-10:])
chunk2_words = set(chunks[i].split()[:10])
overlap = len(chunk1_words & chunk2_words)
assert overlap >= 5 # Allow some tolerance
@pytest.mark.asyncio
async def test_semantic_chunker():
    # Assumes SemanticChunker exposes an async `chunk_async` wrapper around
    # the synchronous `chunk` shown earlier; otherwise call `chunk` directly
    chunker = SemanticChunker()
text = "First topic sentence. Another sentence about first topic. " \
"Now switching to second topic. More about second topic."
chunks = await chunker.chunk_async(text)
# Should detect topic change and create boundary
assert len(chunks) >= 2
assert "first topic" in chunks[0].lower()
assert "second topic" in chunks[1].lower()
```
### Memory Profiler
**Installation**:
```bash
pip install memory-profiler
```
**Example Usage**:
```python
from memory_profiler import profile
@profile
def chunk_large_document():
chunker = FixedSizeChunker(chunk_size=1000)
large_text = "word " * 100000 # Large document
chunks = chunker.chunk(large_text)
return chunks
# Run with: python -m memory_profiler your_script.py
```
This comprehensive toolset provides everything needed to implement, test, and optimize chunking strategies for various use cases, from simple text processing to production-grade RAG systems.


@@ -0,0 +1,302 @@
---
name: prompt-engineering
category: backend
tags: [prompt-engineering, few-shot-learning, chain-of-thought, optimization, templates, system-prompts, llm-performance, ai-patterns]
version: 1.0.0
description: This skill should be used when creating, optimizing, or implementing advanced prompt patterns including few-shot learning, chain-of-thought reasoning, prompt optimization workflows, template systems, and system prompt design. It provides comprehensive frameworks for building production-ready prompts with measurable performance improvements.
---
# Prompt Engineering
This skill provides comprehensive frameworks for creating, optimizing, and implementing advanced prompt patterns that significantly improve LLM performance across various tasks and models.
## When to Use This Skill
Use this skill when:
- Creating new prompts for complex reasoning or analytical tasks
- Optimizing existing prompts for better accuracy or efficiency
- Implementing few-shot learning with strategic example selection
- Designing chain-of-thought reasoning for multi-step problems
- Building reusable prompt templates and systems
- Developing system prompts for consistent model behavior
- Troubleshooting poor prompt performance or failure modes
- Scaling prompt systems for production use cases
## Core Prompt Engineering Patterns
### 1. Few-Shot Learning Implementation
Select examples using semantic similarity and diversity sampling to maximize learning within context window constraints.
#### Example Selection Strategy
- Use `references/few-shot-patterns.md` for comprehensive selection frameworks
- Balance example count (3-5 optimal) with context window limitations
- Include edge cases and boundary conditions in example sets
- Prioritize diverse examples that cover problem space variations
- Order examples from simple to complex for progressive learning
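A minimal sketch of similarity-based selection with sentence-transformers (the pool format of input/output dicts is an assumption; a diversity-aware variant such as MMR can extend this):
```python
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")

def select_examples(target_input, example_pool, k=4):
    # example_pool: list of {"input": ..., "output": ...} dicts
    pool_embeddings = model.encode([ex["input"] for ex in example_pool])
    target_embedding = model.encode(target_input)
    scores = util.cos_sim(target_embedding, pool_embeddings)[0]
    top = scores.argsort(descending=True)[:k]
    return [example_pool[int(i)] for i in top]
```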
#### Few-Shot Template Structure
```
Example 1 (Basic case):
Input: {representative_input}
Output: {expected_output}
Example 2 (Edge case):
Input: {challenging_input}
Output: {robust_output}
Example 3 (Error case):
Input: {problematic_input}
Output: {corrected_output}
Now handle: {target_input}
```
### 2. Chain-of-Thought Reasoning
Elicit step-by-step reasoning for complex problem-solving through structured thinking patterns.
#### Implementation Patterns
- Reference `references/cot-patterns.md` for detailed reasoning frameworks
- Use "Let's think step by step" for zero-shot CoT initiation
- Provide complete reasoning traces for few-shot CoT demonstrations
- Implement self-consistency by sampling multiple reasoning paths
- Include verification and validation steps in reasoning chains
#### CoT Template Structure
```
Let's approach this step-by-step:
Step 1: {break_down_the_problem}
Analysis: {detailed_reasoning}
Step 2: {identify_key_components}
Analysis: {component_analysis}
Step 3: {synthesize_solution}
Analysis: {solution_justification}
Final Answer: {conclusion_with_confidence}
```
### 3. Prompt Optimization Workflows
Implement iterative refinement processes with measurable performance metrics and systematic A/B testing.
#### Optimization Process
- Use `references/optimization-frameworks.md` for comprehensive optimization strategies
- Measure baseline performance before optimization attempts
- Implement single-variable changes for accurate attribution
- Track metrics: accuracy, consistency, latency, token efficiency
- Use statistical significance testing for A/B validation
- Document optimization iterations and their impacts
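A minimal sketch of the single-variable A/B loop (the `llm` callable and `score` function are assumptions):
```python
import statistics

def ab_test(prompt_a, prompt_b, test_cases, llm, score, runs=5):
    # llm(prompt) returns a completion; score(output, expected) returns 0-1.
    # Repeated runs expose consistency, not just mean accuracy.
    results = {}
    for name, prompt in [("A", prompt_a), ("B", prompt_b)]:
        scores = [
            score(llm(prompt.format(**case)), case["expected"])
            for _ in range(runs)
            for case in test_cases
        ]
        results[name] = (statistics.mean(scores), statistics.stdev(scores))
    return results  # compare means, then test significance before rollout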
#### Performance Metrics Framework
- **Accuracy**: Task completion rate and output correctness
- **Consistency**: Response stability across multiple runs
- **Efficiency**: Token usage and response time optimization
- **Robustness**: Performance across edge cases and variations
- **Safety**: Adherence to guidelines and harm prevention
### 4. Template Systems Architecture
Build modular, reusable prompt components with variable interpolation and conditional sections.
#### Template Design Principles
- Reference `references/template-systems.md` for modular template frameworks
- Use clear variable naming conventions (e.g., `{user_input}`, `{context}`)
- Implement conditional sections for different scenario handling
- Design role-based templates for specific use cases
- Create hierarchical template composition patterns
#### Template Structure Example
```
# System Context
You are a {role} with {expertise_level} expertise in {domain}.
# Task Context
{if background_information}
Background: {background_information}
{endif}
# Instructions
{task_instructions}
# Examples
{example_count}
# Output Format
{output_specification}
# Input
{user_query}
```
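A minimal renderer for the `{variable}` and `{if ...}{endif}` syntax used above (regex-based; a production system might prefer Jinja2):
```python
import re

def render(template, variables):
    # Conditional sections: keep the body only when the variable is truthy
    def handle_if(match):
        name, body = match.group(1), match.group(2)
        return body if variables.get(name) else ""

    text = re.sub(r"\{if (\w+)\}(.*?)\{endif\}", handle_if, template, flags=re.S)
    # Variable interpolation: unknown placeholders stay visible for debugging
    return re.sub(r"\{(\w+)\}", lambda m: str(variables.get(m.group(1), m.group(0))), text)

print(render("Hello {name}!{if note} Note: {note}{endif}", {"name": "Ada"}))
# -> "Hello Ada!"
```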
### 5. System Prompt Design
Design comprehensive system prompts that establish consistent model behavior, output formats, and safety constraints.
#### System Prompt Components
- Use `references/system-prompt-design.md` for detailed design guidelines
- Define clear role specification and expertise boundaries
- Establish output format requirements and structural constraints
- Include safety guidelines and content policy adherence
- Set context for background information and domain knowledge
#### System Prompt Framework
```
You are an expert {role} specializing in {domain} with {experience_level} of experience.
## Core Capabilities
- List specific capabilities and expertise areas
- Define scope of knowledge and limitations
## Behavioral Guidelines
- Specify interaction style and communication approach
- Define error handling and uncertainty protocols
- Establish quality standards and verification requirements
## Output Requirements
- Specify format expectations and structural requirements
- Define content inclusion and exclusion criteria
- Establish consistency and validation requirements
## Safety and Ethics
- Include content policy adherence
- Specify bias mitigation requirements
- Define harm prevention protocols
```
## Implementation Workflows
### Workflow 1: Create New Prompt from Requirements
1. **Analyze Requirements**
- Identify task complexity and reasoning requirements
- Determine target model capabilities and limitations
- Define success criteria and evaluation metrics
- Assess need for few-shot learning or CoT reasoning
2. **Select Pattern Strategy**
- Use few-shot learning for classification or transformation tasks
- Apply CoT for complex reasoning or multi-step problems
- Implement template systems for reusable prompt architecture
- Design system prompts for consistent behavior requirements
3. **Draft Initial Prompt**
- Structure prompt with clear sections and logical flow
- Include relevant examples or reasoning demonstrations
- Specify output format and quality requirements
- Incorporate safety guidelines and constraints
4. **Validate and Test**
- Test with diverse input scenarios including edge cases
- Measure performance against defined success criteria
- Iterate refinement based on testing results
- Document optimization decisions and their rationale
### Workflow 2: Optimize Existing Prompt
1. **Performance Analysis**
- Measure current prompt performance metrics
- Identify failure modes and error patterns
- Analyze token efficiency and response latency
- Assess consistency across multiple runs
2. **Optimization Strategy**
- Apply systematic A/B testing with single-variable changes
- Use few-shot learning to improve task adherence
- Implement CoT reasoning for complex task components
- Refine template structure for better clarity
3. **Implementation and Testing**
- Deploy optimized prompts with controlled rollout
- Monitor performance metrics in production environment
- Compare against baseline using statistical significance
- Document improvements and lessons learned
### Workflow 3: Scale Prompt Systems
1. **Modular Architecture Design**
- Decompose complex prompts into reusable components
- Create template inheritance hierarchies
- Implement dynamic example selection systems
- Build automated quality assurance frameworks
2. **Production Integration**
- Implement prompt versioning and rollback capabilities
- Create performance monitoring and alerting systems
- Build automated testing frameworks for prompt validation
- Establish update and deployment workflows
## Quality Assurance
### Validation Requirements
- Test prompts with at least 10 diverse scenarios
- Include edge cases, boundary conditions, and failure modes
- Verify output format compliance and structural consistency
- Validate safety guideline adherence and harm prevention
- Measure performance across multiple model runs
### Performance Standards
- Achieve >90% task completion for well-defined use cases
- Maintain <5% variance across multiple runs for consistency
- Optimize token usage without sacrificing accuracy
- Ensure response latency meets application requirements
- Demonstrate robust handling of edge cases and unexpected inputs
## Integration with Other Skills
This skill integrates seamlessly with:
- **langchain4j-ai-services-patterns**: Interface-based prompt design
- **langchain4j-rag-implementation-patterns**: Context-enhanced prompting
- **langchain4j-testing-strategies**: Prompt validation frameworks
- **unit-test-parameterized**: Systematic prompt testing approaches
## Resources and References
- `references/few-shot-patterns.md`: Comprehensive few-shot learning frameworks
- `references/cot-patterns.md`: Chain-of-thought reasoning patterns and examples
- `references/optimization-frameworks.md`: Systematic prompt optimization methodologies
- `references/template-systems.md`: Modular template design and implementation
- `references/system-prompt-design.md`: System prompt architecture and best practices
## Usage Examples
### Example 1: Classification Task with Few-Shot Learning
```
Classify customer feedback into categories using semantic similarity for example selection and diversity sampling for edge case coverage.
```
### Example 2: Complex Reasoning with Chain-of-Thought
```
Implement step-by-step reasoning for financial analysis with verification steps and confidence scoring.
```
### Example 3: Template System for Customer Service
```
Create modular templates with role-based components and conditional sections for different inquiry types.
```
### Example 4: System Prompt for Code Generation
```
Design comprehensive system prompt with behavioral guidelines, output requirements, and safety constraints.
```
## Common Pitfalls and Solutions
- **Overfitting examples**: Use diverse example sets with semantic variety
- **Context window overflow**: Implement strategic example selection and compression
- **Inconsistent outputs**: Specify clear output formats and validation requirements
- **Poor generalization**: Include edge cases and boundary conditions in training examples
- **Safety violations**: Incorporate comprehensive content policies and harm prevention
## Performance Optimization
- Monitor token usage and implement compression strategies
- Use caching for repeated prompt components
- Optimize example selection for maximum learning efficiency
- Implement progressive disclosure for complex prompt systems
- Balance prompt complexity with response quality requirements
This skill provides the foundational patterns and methodologies for building production-ready prompt systems that consistently deliver high performance across diverse use cases and model types.


@@ -0,0 +1,426 @@
# Chain-of-Thought Reasoning Patterns
This reference provides comprehensive frameworks for implementing effective chain-of-thought (CoT) reasoning that improves model performance on complex, multi-step problems.
## Core Principles
### Step-by-Step Reasoning Elicitation
#### Problem Decomposition Strategy
- Break complex problems into manageable sub-problems
- Identify dependencies and relationships between components
- Establish logical flow and sequence of reasoning steps
- Define clear decision points and validation criteria
#### Verification and Validation Integration
- Include self-checking mechanisms at critical junctures
- Implement consistency checks across reasoning steps
- Add confidence scoring for uncertain conclusions
- Provide fallback strategies for ambiguous situations
## Zero-Shot Chain-of-Thought Patterns
### Basic CoT Initiation
```
Let's think step by step to solve this problem:
1. First, I need to understand what the question is asking for
2. Then, I'll identify the key information and constraints
3. Next, I'll consider different approaches to solve it
4. I'll work through the solution methodically
5. Finally, I'll verify my answer makes sense
Problem: {problem_statement}
Step 1: Understanding the question
{analysis}
Step 2: Key information and constraints
{information_analysis}
Step 3: Solution approach
{approach_analysis}
Step 4: Working through the solution
{detailed_solution}
Step 5: Verification
{verification}
Final Answer: {conclusion}
```
### Enhanced CoT with Confidence
```
Let me think through this systematically, breaking down the problem and checking my reasoning at each step.
**Problem**: {problem_description}
**Step 1: Problem Analysis**
- What am I being asked to solve?
- What information is provided?
- What are the constraints?
- My confidence in understanding: {score}/10
**Step 2: Strategy Selection**
- Possible approaches:
1. {approach_1}
2. {approach_2}
3. {approach_3}
- Selected approach: {chosen_approach}
- Rationale: {reasoning_for_choice}
**Step 3: Execution**
- {detailed_step_by_step_solution}
**Step 4: Verification**
- Does the answer make sense?
- Have I addressed all parts of the question?
- Confidence in final answer: {score}/10
**Final Answer**: {solution_with_confidence_score}
```
## Few-Shot Chain-of-Thought Patterns
### Mathematical Reasoning Template
```
Solve the following math problem step by step.
Example 1:
Problem: A store sells apples and oranges at whole-dollar prices. John buys 4 apples and 2 oranges and spends exactly $14. What could each fruit cost?
Step 1: Set up the equation
Let a = cost of apples, o = cost of oranges
4a + 2o = 14
Step 2: Simplify the equation
Divide both sides by 2: 2a + o = 7
Step 3: Test positive integer solutions
If a = 1, then 2(1) + o = 7 → o = 5
If a = 2, then 2(2) + o = 7 → o = 3
If a = 3, then 2(3) + o = 7 → o = 1
Step 4: Verify against the original problem
For a=1, o=5: 4(1) + 2(5) = 4 + 10 = 14 ✓
For a=2, o=3: 4(2) + 2(3) = 8 + 6 = 14 ✓
For a=3, o=1: 4(3) + 2(1) = 12 + 2 = 14 ✓
All three solutions work, so the prices are underdetermined: (apples $1, oranges $5), (apples $2, oranges $3), or (apples $3, oranges $1).
Example 2:
Problem: Sarah is twice as old as Tom was 3 years ago. In 5 years, Sarah will be 1.5 times Tom's current age. How old are they now?
Step 1: Define variables
Let S = Sarah's current age, T = Tom's current age
Step 2: Set up equations
"Sarah is twice as old as Tom was 3 years ago": S = 2(T - 3)
"In 5 years, Sarah will be 1.5 times Tom's current age": S + 5 = 1.5T
Step 3: Solve the system
From first equation: S = 2T - 6
Substitute into second: (2T - 6) + 5 = 1.5T
2T - 1 = 1.5T
0.5T = 1
T = 2
Step 4: Find Sarah's age
S = 2(2) - 6 = 4 - 6 = -2
Step 5: Check for validity
Negative age is impossible, so either I made an arithmetic error or the premises are inconsistent.
Re-checking the algebra:
S = 2(T - 3) = 2T - 6
S + 5 = 1.5T
Substitute: (2T - 6) + 5 = 1.5T → 2T - 1 = 1.5T → 0.5T = 1 → T = 2
S = 2(2 - 3) = -2
The algebra is correct, so the inconsistency lies in the premises themselves: as stated, the problem has no valid (non-negative) ages. Flagging an unsolvable problem is itself the correct final answer.
Now solve:
Problem: {math_problem}
Step 1: Define variables
{variable_definitions}
Step 2: Set up equations
{equation_setup}
Step 3: Solve the system
{solution_process}
Step 4: Verify the solution
{verification}
Final Answer: {answer}
```
### Logical Reasoning Template
```
Analyze the logical argument and determine if it's valid.
Example 1:
Premise 1: All birds can fly
Premise 2: Penguins are birds
Conclusion: Therefore, penguins can fly
Step 1: Analyze the structure
This is a syllogism with form:
All A are B
C is A
Therefore, C is B
Step 2: Evaluate premise validity
Premise 1: "All birds can fly" - This is false (penguins, ostriches cannot fly)
Premise 2: "Penguins are birds" - This is true
Step 3: Check logical validity
The logical structure is valid, but since Premise 1 is false, the conclusion may not be true
Step 4: Real-world verification
In reality, penguins cannot fly despite being birds
Conclusion: The argument is logically valid but unsound, because Premise 1 is false
Example 2:
Premise 1: If it rains, then the ground gets wet
Premise 2: It is raining
Conclusion: Therefore, the ground gets wet
Step 1: Analyze the structure
This is modus ponens:
If P, then Q
P
Therefore, Q
Step 2: Evaluate premise validity
Premise 1: "If it rains, then the ground gets wet" - Generally true
Premise 2: "It is raining" - Given as true
Step 3: Check logical validity
Modus ponens is a valid argument form
Step 4: Verify the conclusion
Given the premises, the conclusion follows logically
Conclusion: The argument is both logically valid and sound
Now analyze:
Argument: {logical_argument}
Step 1: Analyze the argument structure
{structure_analysis}
Step 2: Evaluate premise validity
{premise_evaluation}
Step 3: Check logical validity
{validity_check}
Step 4: Verify the conclusion
{conclusion_verification}
Final Assessment: {argument_validity_assessment}
```
## Self-Consistency Techniques
### Multiple Reasoning Paths
```
I'll solve this problem using three different approaches and see which result is most reliable.
**Problem**: {complex_problem}
**Approach 1: Direct Calculation**
{first_approach_reasoning}
Result 1: {result_1}
**Approach 2: Logical Deduction**
{second_approach_reasoning}
Result 2: {result_2}
**Approach 3: Pattern Recognition**
{third_approach_reasoning}
Result 3: {result_3}
**Consistency Analysis:**
- Approach 1 and 2 agree: {yes/no}
- Approach 1 and 3 agree: {yes/no}
- Approach 2 and 3 agree: {yes/no}
**Final Decision:**
{majority_result} appears in {count} out of 3 approaches.
Confidence: {high/medium/low}
Most Likely Answer: {final_answer_with_confidence}
```
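The template above can be automated by sampling several reasoning paths at nonzero temperature and majority-voting the extracted answers (the `llm` callable and the answer-extraction convention are assumptions):
```python
import re
from collections import Counter

def self_consistent_answer(problem, llm, samples=5):
    # llm(prompt, temperature) is assumed to return the model's text output
    answers = []
    for _ in range(samples):
        response = llm(
            f"Let's think step by step.\n\nProblem: {problem}\n"
            "End your response with 'Final Answer: <answer>'.",
            temperature=0.7,
        )
        match = re.search(r"Final Answer:\s*(.+)", response)
        if match:
            answers.append(match.group(1).strip())
    if not answers:
        return None, 0.0
    answer, count = Counter(answers).most_common(1)[0]
    return answer, count / samples  # agreement ratio doubles as confidence
```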
### Verification Loop Pattern
```
Let me solve this step by step and verify each step.
**Problem**: {problem_description}
**Step 1: Initial Analysis**
{initial_analysis}
Verification: Does this make sense? {verification_1}
**Step 2: Solution Development**
{solution_development}
Verification: Does this logically follow from step 1? {verification_2}
**Step 3: Result Calculation**
{result_calculation}
Verification: Does this answer the original question? {verification_3}
**Step 4: Cross-Check**
Let me try a different approach to confirm:
{alternative_approach}
Results comparison: {comparison_analysis}
**Final Answer:**
{conclusion_with_verification_status}
```
## Specialized CoT Patterns
### Code Debugging CoT
```
Debug the following code by analyzing it step by step.
**Code:**
{code_snippet}
**Step 1: Understand the Code's Purpose**
{purpose_analysis}
**Step 2: Identify Expected Behavior**
{expected_behavior}
**Step 3: Trace the Execution**
{execution_trace}
**Step 4: Find the Error**
{error_identification}
**Step 5: Propose Fix**
{fix_proposal}
**Step 6: Verify the Fix**
{fix_verification}
**Fixed Code:**
{corrected_code}
```
### Data Analysis CoT
```
Analyze this data systematically to draw meaningful conclusions.
**Data:**
{dataset}
**Step 1: Understand the Data Structure**
{data_structure_analysis}
**Step 2: Identify Patterns and Trends**
{pattern_identification}
**Step 3: Calculate Key Metrics**
{metrics_calculation}
**Step 4: Compare with Benchmarks**
{benchmark_comparison}
**Step 5: Formulate Insights**
{insight_generation}
**Step 6: Validate Conclusions**
{conclusion_validation}
**Key Findings:**
{summary_of_insights}
```
### Creative Problem Solving CoT
```
Generate creative solutions to this challenging problem.
**Problem:**
{creative_problem}
**Step 1: Reframe the Problem**
{problem_reframing}
**Step 2: Brainstorm Multiple Angles**
- Technical approach: {technical_ideas}
- Business approach: {business_ideas}
- User experience approach: {ux_ideas}
- Unconventional approach: {unconventional_ideas}
**Step 3: Evaluate Each Approach**
{approach_evaluation}
**Step 4: Synthesize Best Elements**
{synthesis_process}
**Step 5: Develop Final Solution**
{solution_development}
**Step 6: Test for Feasibility**
{feasibility_testing}
**Recommended Solution:**
{final_creative_solution}
```
## Implementation Guidelines
### When to Use Chain-of-Thought
- **Multi-step problems**: Tasks requiring sequential reasoning
- **Complex calculations**: Mathematical or logical derivations
- **Problem decomposition**: Tasks that benefit from breaking down
- **Verification needs**: When accuracy is critical
- **Educational contexts**: When showing reasoning is valuable
### CoT Effectiveness Factors
- **Problem complexity**: Higher benefit for complex problems
- **Task type**: Mathematical, logical, and analytical tasks benefit most
- **Model capability**: Newer models handle CoT more effectively
- **Context window**: Ensure sufficient space for reasoning steps
- **Output requirements**: Detailed explanations benefit from CoT
### Common Pitfalls to Avoid
- **Over-explaining simple steps**: Keep proportional detail
- **Circular reasoning**: Ensure logical progression
- **Missing verification**: Always include validation steps
- **Inconsistent confidence**: Use realistic confidence scoring
- **Premature conclusions**: Don't jump to answers without full reasoning
## Integration with Other Techniques
### CoT + Few-Shot Learning
- Include reasoning traces in examples
- Show step-by-step problem-solving demonstrations
- Teach verification and self-checking patterns
### CoT + Template Systems
- Embed CoT patterns within structured templates
- Use conditional CoT based on problem complexity
- Implement adaptive reasoning depth (see the sketch after these lists)
### CoT + Prompt Optimization
- Test different CoT formulations
- Optimize reasoning step granularity
- Balance detail with efficiency
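As a consolidated example of these integrations, a prompt builder can attach reasoning-trace examples and pick the reasoning depth from an estimated problem complexity. A minimal sketch; the example record fields and the complexity score are illustrative assumptions:

```python
def build_cot_prompt(task, examples, complexity):
    """Assemble a few-shot CoT prompt with adaptive reasoning depth."""
    # More steps for harder problems, fewer for simple ones
    steps = 2 if complexity < 0.3 else 4 if complexity < 0.7 else 6

    parts = []
    for ex in examples:  # each example carries its own reasoning trace
        parts.append(f"Problem: {ex['problem']}\n"
                     f"Reasoning: {ex['reasoning']}\n"
                     f"Answer: {ex['answer']}\n")
    parts.append(f"Problem: {task}\n"
                 f"Think through this in about {steps} numbered steps, "
                 f"then state the final answer.")
    return "\n".join(parts)
```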
This framework provides comprehensive patterns for implementing effective chain-of-thought reasoning across diverse problem types and applications.


@@ -0,0 +1,273 @@
# Few-Shot Learning Patterns
This reference provides comprehensive frameworks for implementing effective few-shot learning strategies that maximize model performance within context window constraints.
## Core Principles
### Example Selection Strategy
#### Semantic Similarity Selection
- Use embedding similarity to find examples closest to target input
- Cluster similar examples to avoid redundancy
- Select diverse representatives from different semantic regions
- Prioritize examples that cover key variations in problem space
#### Diversity Sampling Approach
- Ensure coverage of different input types and patterns
- Include boundary cases and edge conditions
- Balance simple and complex examples
- Select examples that demonstrate different solution strategies
#### Progressive Complexity Ordering
- Start with simplest, most straightforward examples
- Progress to increasingly complex scenarios
- Include challenging edge cases last
- Use this ordering to build understanding incrementally
## Example Templates
### Classification Tasks
#### Binary Classification Template
```
Classify if the text expresses positive or negative sentiment.
Example 1:
Text: "I love this product! It works exactly as advertised and exceeded my expectations."
Sentiment: Positive
Reasoning: Contains enthusiastic language, positive adjectives, and satisfaction indicators
Example 2:
Text: "The customer service was terrible and the product broke after one day of use."
Sentiment: Negative
Reasoning: Contains negative adjectives, complaint language, and dissatisfaction indicators
Example 3:
Text: "It's okay, nothing special but does the basic job."
Sentiment: Negative
Reasoning: Lukewarm, unenthusiastic language with minimal positive elements; in a binary scheme, neutral text without clear positive signals is labeled Negative
Now classify:
Text: {input_text}
Sentiment:
Reasoning:
```
#### Multi-Class Classification Template
```
Categorize the customer inquiry into one of: Technical Support, Billing, Sales, or General.
Example 1:
Inquiry: "My account was charged twice for the same subscription this month"
Category: Billing
Key indicators: "charged twice", "subscription", "account", financial terms
Example 2:
Inquiry: "The app keeps crashing when I try to upload files larger than 10MB"
Category: Technical Support
Key indicators: "crashing", "upload files", "technical issue", "error report"
Example 3:
Inquiry: "What are your pricing plans for enterprise customers?"
Category: Sales
Key indicators: "pricing plans", "enterprise", business inquiry, sales question
Now categorize:
Inquiry: {inquiry_text}
Category:
Key indicators:
```
### Transformation Tasks
#### Text Transformation Template
```
Convert formal business text into casual, friendly language.
Example 1:
Formal: "We regret to inform you that your request cannot be processed at this time due to insufficient documentation."
Casual: "Sorry, but we can't process your request right now because some documents are missing."
Example 2:
Formal: "The aforementioned individual has demonstrated exceptional proficiency in the designated responsibilities."
Casual: "They've done a great job with their tasks and really know what they're doing."
Example 3:
Formal: "Please be advised that the scheduled meeting has been postponed pending further notice."
Casual: "Hey, just letting you know that we've put off the meeting for now and will let you know when it's rescheduled."
Now convert:
Formal: {formal_text}
Casual:
```
#### Data Extraction Template
```
Extract key information from the job posting into structured format.
Example 1:
Job Posting: "We are seeking a Senior Software Engineer with 5+ years of experience in Python and cloud technologies. This is a remote position offering $120k-$150k salary plus equity."
Extracted:
- Position: Senior Software Engineer
- Experience Required: 5+ years
- Skills: Python, cloud technologies
- Location: Remote
- Salary: $120k-$150k plus equity
Example 2:
Job Posting: "Marketing Manager needed for growing startup. Must have 3 years experience in digital marketing, social media management, and content creation. San Francisco office, competitive compensation."
Extracted:
- Position: Marketing Manager
- Experience Required: 3 years
- Skills: Digital marketing, social media management, content creation
- Location: San Francisco
- Salary: Competitive compensation
Now extract:
Job Posting: {job_posting_text}
Extracted:
```
### Generation Tasks
#### Creative Writing Template
```
Generate compelling product descriptions following the shown patterns.
Example 1:
Product: Wireless headphones with noise cancellation
Description: "Immerse yourself in crystal-clear audio with our premium wireless headphones. Advanced noise cancellation technology blocks out distractions while 30-hour battery life keeps you connected all day long."
Example 2:
Product: Smart home security camera
Description: "Protect what matters most with intelligent monitoring that alerts you to activity instantly. AI-powered detection distinguishes between people, pets, and vehicles for truly smart security."
Example 3:
Product: Portable espresso maker
Description: "Barista-quality espresso anywhere, anytime. Compact design meets professional-grade extraction in this revolutionary portable machine that delivers perfect shots in under 30 seconds."
Now generate:
Product: {product_description}
Description:
```
### Error Correction Patterns
#### Error Detection and Correction Template
```
Identify and correct errors in the given text.
Example 1:
Text with errors: "Their going to the park to play there new game with they're friends."
Correction: "They're going to the park to play their new game with their friends."
Errors fixed: "Their → They're", "there → their", "they're → their"
Example 2:
Text with errors: "The company's new policy effects every employee and there morale."
Correction: "The company's new policy affects every employee and their morale."
Errors fixed: "effects → affects", "there → their"
Example 3:
Text with errors: "Its important to review you're work carefully before submiting."
Correction: "It's important to review your work carefully before submitting."
Errors fixed: "Its → It's", "you're → your", "submiting → submitting"
Now correct:
Text with errors: {text_with_errors}
Correction:
Errors fixed:
```
## Advanced Strategies
### Dynamic Example Selection
#### Context-Aware Selection
```python
def select_examples(input_text, example_database, max_examples=3):
"""
Select most relevant examples based on semantic similarity and diversity.
"""
# 1. Calculate similarity scores
similarities = calculate_similarity(input_text, example_database)
# 2. Sort by similarity
sorted_examples = sort_by_similarity(similarities)
# 3. Apply diversity sampling
diverse_examples = diversity_sampling(sorted_examples, max_examples)
# 4. Order by complexity
final_examples = order_by_complexity(diverse_examples)
return final_examples
```
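A concrete version of this selection logic can be built with standard TF-IDF similarity. The sketch below fills in the placeholder helpers with a greedy diversity pass; the 0.9 redundancy threshold is an illustrative assumption, not a tuned value:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def select_examples_tfidf(input_text, example_texts, max_examples=3):
    """Pick examples similar to the input while penalizing redundancy."""
    vectorizer = TfidfVectorizer()
    matrix = vectorizer.fit_transform([input_text] + example_texts)
    input_vec, example_vecs = matrix[0], matrix[1:]

    # Similarity of every candidate example to the input
    scores = cosine_similarity(input_vec, example_vecs)[0]
    ranked = sorted(range(len(example_texts)), key=lambda i: -scores[i])

    selected = []
    for idx in ranked:
        # Greedy diversity: skip candidates too close to an already-chosen one
        if all(cosine_similarity(example_vecs[idx], example_vecs[j])[0][0] < 0.9
               for j in selected):
            selected.append(idx)
        if len(selected) == max_examples:
            break
    return [example_texts[i] for i in selected]
```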
#### Adaptive Example Count
```python
def determine_example_count(input_complexity, context_limit):
"""
Adjust example count based on input complexity and available context.
"""
base_count = 3
# Complex inputs benefit from more examples
if input_complexity > 0.8:
return min(base_count + 2, context_limit)
elif input_complexity > 0.5:
return base_count + 1
else:
return max(base_count - 1, 2)
```
### Quality Metrics for Examples
#### Example Effectiveness Scoring
```python
def score_example_effectiveness(example, test_cases):
"""
Score how effectively an example teaches the desired pattern.
"""
metrics = {
'coverage': measure_pattern_coverage(example),
'clarity': measure_instructional_clarity(example),
'uniqueness': measure_uniqueness_from_other_examples(example),
'difficulty': measure_appropriateness_difficulty(example)
}
return weighted_average(metrics, weights=[0.3, 0.3, 0.2, 0.2])
```
## Best Practices
### Example Quality Guidelines
- **Clarity**: Examples should clearly demonstrate the desired pattern
- **Accuracy**: Input-output pairs must be correct and consistent
- **Relevance**: Examples should be representative of target task
- **Diversity**: Include variation in input types and complexity levels
- **Completeness**: Cover edge cases and boundary conditions
### Context Management
- **Token Efficiency**: Optimize example length while maintaining clarity (see the packing sketch after this list)
- **Progressive Disclosure**: Start simple, increase complexity gradually
- **Redundancy Elimination**: Remove overlapping or duplicate examples
- **Compression**: Use concise representations where possible
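One way to enforce these budget constraints is to pack examples greedily until a token budget is exhausted. A minimal sketch, using a crude whitespace count as a stand-in for a real tokenizer:

```python
def pack_examples(examples, token_budget):
    """Greedily add examples (assumed pre-sorted by value) within a budget."""
    def rough_tokens(text):
        # Whitespace split as a cheap proxy; swap in a real tokenizer in practice
        return len(text.split())

    packed, used = [], 0
    for example in examples:
        cost = rough_tokens(example)
        if used + cost > token_budget:
            continue  # skip examples that would overflow the window
        packed.append(example)
        used += cost
    return packed
```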
### Common Pitfalls to Avoid
- **Overfitting**: Don't include too many examples from same pattern
- **Under-representation**: Ensure coverage of important variations
- **Ambiguity**: Examples should have clear, unambiguous solutions
- **Context Overflow**: Balance example count with window limitations
- **Poor Ordering**: Place examples in logical progression order
## Integration with Other Patterns
Few-shot learning combines effectively with:
- **Chain-of-Thought**: Add reasoning steps to examples
- **Template Systems**: Use few-shot within structured templates
- **Prompt Optimization**: Test different example selections
- **System Prompts**: Establish few-shot learning expectations in system prompts
This framework provides the foundation for implementing effective few-shot learning across diverse tasks and model types.


@@ -0,0 +1,488 @@
# Prompt Optimization Frameworks
This reference provides systematic methodologies for iteratively improving prompt performance through structured testing, measurement, and refinement processes.
## Optimization Process Overview
### Iterative Improvement Cycle
```mermaid
graph TD
A[Baseline Measurement] --> B[Hypothesis Generation]
B --> C[Controlled Test]
C --> D[Performance Analysis]
D --> E[Statistical Validation]
E --> F[Implementation Decision]
F --> G[Monitor Impact]
G --> H[Learn & Iterate]
H --> B
```
### Core Optimization Principles
- **Single Variable Testing**: Change one element at a time for accurate attribution
- **Measurable Metrics**: Define quantitative success criteria
- **Statistical Significance**: Use proper sample sizes and validation methods
- **Controlled Environment**: Test conditions must be consistent
- **Baseline Comparison**: Always measure against established baseline
## Performance Metrics Framework
### Primary Metrics
#### Task Success Rate
```python
def calculate_success_rate(results, expected_outputs):
"""
Measure percentage of tasks completed correctly.
"""
correct = sum(1 for result, expected in zip(results, expected_outputs)
if result == expected)
return (correct / len(results)) * 100
```
#### Response Consistency
```python
def measure_consistency(prompt, test_cases, num_runs=5):
"""
Measure response stability across multiple runs.
"""
responses = {}
for test_case in test_cases:
test_responses = []
for _ in range(num_runs):
response = execute_prompt(prompt, test_case)
test_responses.append(response)
# Calculate similarity score for consistency
consistency = calculate_similarity(test_responses)
responses[test_case] = consistency
return sum(responses.values()) / len(responses)
```
#### Token Efficiency
```python
def calculate_token_efficiency(prompt, test_cases):
"""
Measure token usage per successful task completion.
"""
total_tokens = 0
successful_tasks = 0
for test_case in test_cases:
response = execute_prompt_with_metrics(prompt, test_case)
total_tokens += response.token_count
if response.is_successful:
successful_tasks += 1
return total_tokens / successful_tasks if successful_tasks > 0 else float('inf')
```
#### Response Latency
```python
import time

def measure_response_time(prompt, test_cases):
"""
Measure average response time.
"""
times = []
for test_case in test_cases:
start_time = time.time()
execute_prompt(prompt, test_case)
end_time = time.time()
times.append(end_time - start_time)
return sum(times) / len(times)
```
### Secondary Metrics
#### Output Quality Score
```python
def assess_output_quality(response, criteria):
"""
Multi-dimensional quality assessment.
"""
scores = {
'accuracy': measure_accuracy(response),
'completeness': measure_completeness(response),
'coherence': measure_coherence(response),
'relevance': measure_relevance(response),
'format_compliance': measure_format_compliance(response)
}
weights = [0.3, 0.2, 0.2, 0.2, 0.1]
return sum(score * weight for score, weight in zip(scores.values(), weights))
```
#### Safety Compliance
```python
def check_safety_compliance(response):
"""
Measure adherence to safety guidelines.
"""
violations = []
# Check for various safety issues
if contains_harmful_content(response):
violations.append('harmful_content')
if has_bias(response):
violations.append('bias')
if violates_privacy(response):
violations.append('privacy_violation')
safety_score = max(0, 100 - len(violations) * 25)
return safety_score, violations
```
## A/B Testing Methodology
### Controlled Test Design
```python
import random

def design_ab_test(baseline_prompt, variant_prompt, test_cases):
"""
Design controlled A/B test with proper statistical power.
"""
# Calculate required sample size
effect_size = estimate_effect_size(baseline_prompt, variant_prompt)
sample_size = calculate_sample_size(effect_size, power=0.8, alpha=0.05)
# Random assignment
randomized_cases = random.sample(test_cases, sample_size)
split_point = len(randomized_cases) // 2
group_a = randomized_cases[:split_point]
group_b = randomized_cases[split_point:]
return {
'baseline_group': group_a,
'variant_group': group_b,
'sample_size': sample_size,
'statistical_power': 0.8,
'significance_level': 0.05
}
```
### Statistical Analysis
```python
import numpy as np
from scipy import stats

def analyze_ab_results(baseline_results, variant_results):
"""
Perform statistical analysis of A/B test results.
"""
# Calculate means and standard deviations
baseline_mean = np.mean(baseline_results)
variant_mean = np.mean(variant_results)
baseline_std = np.std(baseline_results)
variant_std = np.std(variant_results)
# Perform t-test
t_statistic, p_value = stats.ttest_ind(baseline_results, variant_results)
# Calculate effect size (Cohen's d)
pooled_std = np.sqrt(((len(baseline_results) - 1) * baseline_std**2 +
(len(variant_results) - 1) * variant_std**2) /
(len(baseline_results) + len(variant_results) - 2))
cohens_d = (variant_mean - baseline_mean) / pooled_std
return {
'baseline_mean': baseline_mean,
'variant_mean': variant_mean,
'improvement': ((variant_mean - baseline_mean) / baseline_mean) * 100,
'p_value': p_value,
'statistical_significance': p_value < 0.05,
'effect_size': cohens_d,
'recommendation': 'implement_variant' if p_value < 0.05 and cohens_d > 0.2 else 'keep_baseline'
}
```
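For example, running the analysis on two small synthetic result sets shows how the output drives the implement/keep decision (the numbers below are made up):

```python
baseline = [0.62, 0.58, 0.65, 0.60, 0.63, 0.59, 0.61, 0.64]
variant = [0.70, 0.68, 0.72, 0.66, 0.71, 0.69, 0.73, 0.67]

report = analyze_ab_results(baseline, variant)
print(f"improvement: {report['improvement']:.1f}%")
print(f"p-value: {report['p_value']:.4f}")
print(report['recommendation'])  # 'implement_variant' when p < 0.05 and d > 0.2
```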
## Optimization Strategies
### Strategy 1: Progressive Enhancement
#### Stepwise Improvement Process
```python
def progressive_optimization(base_prompt, test_cases, max_iterations=10):
"""
Incrementally improve prompt through systematic testing.
"""
current_prompt = base_prompt
current_performance = evaluate_prompt(current_prompt, test_cases)
optimization_history = []
for iteration in range(max_iterations):
# Generate improvement hypotheses
hypotheses = generate_improvement_hypotheses(current_prompt, current_performance)
best_improvement = None
best_performance = current_performance
for hypothesis in hypotheses:
# Test hypothesis
test_prompt = apply_hypothesis(current_prompt, hypothesis)
test_performance = evaluate_prompt(test_prompt, test_cases)
# Validate improvement
if is_statistically_significant(current_performance, test_performance):
if test_performance.overall_score > best_performance.overall_score:
best_improvement = hypothesis
best_performance = test_performance
# Apply best improvement if found
if best_improvement:
current_prompt = apply_hypothesis(current_prompt, best_improvement)
optimization_history.append({
'iteration': iteration,
'hypothesis': best_improvement,
'performance_before': current_performance,
'performance_after': best_performance,
'improvement': best_performance.overall_score - current_performance.overall_score
})
current_performance = best_performance
else:
break # No further improvements found
return current_prompt, optimization_history
```
### Strategy 2: Multi-Objective Optimization
#### Pareto Optimization Framework
```python
def multi_objective_optimization(prompt_variants, objectives):
"""
Optimize for multiple competing objectives using Pareto efficiency.
"""
results = []
for variant in prompt_variants:
scores = {}
for objective in objectives:
scores[objective] = evaluate_objective(variant, objective)
results.append({
'prompt': variant,
'scores': scores,
'dominates': []
})
# Find Pareto optimal solutions
pareto_optimal = []
for i, result_i in enumerate(results):
is_dominated = False
for j, result_j in enumerate(results):
if i != j and dominates(result_j, result_i):
is_dominated = True
break
if not is_dominated:
pareto_optimal.append(result_i)
return pareto_optimal
def dominates(result_a, result_b):
    """
    Check if result_a Pareto-dominates result_b: at least as good in every
    objective and strictly better in at least one.
    """
    scores_a, scores_b = result_a['scores'], result_b['scores']
    at_least_as_good = all(scores_a[obj] >= scores_b[obj] for obj in scores_a)
    strictly_better = any(scores_a[obj] > scores_b[obj] for obj in scores_a)
    return at_least_as_good and strictly_better
```
### Strategy 3: Adaptive Testing
#### Dynamic Test Allocation
```python
def adaptive_testing(prompt_variants, initial_budget):
"""
Dynamically allocate testing budget to promising variants.
"""
# Initial exploration phase
exploration_results = {}
    per_variant_budget = initial_budget // len(prompt_variants)
    for variant in prompt_variants:
        exploration_results[variant] = test_prompt(variant, per_variant_budget)
    # Exploitation phase - allocate more budget to promising variants
    total_budget_spent = len(prompt_variants) * per_variant_budget
remaining_budget = initial_budget - total_budget_spent
# Sort by performance
sorted_variants = sorted(exploration_results.items(),
key=lambda x: x[1].overall_score, reverse=True)
# Allocate remaining budget proportionally to performance
final_results = {}
for i, (variant, initial_result) in enumerate(sorted_variants):
if remaining_budget > 0:
additional_budget = max(1, remaining_budget // (len(sorted_variants) - i))
final_results[variant] = test_prompt(variant, additional_budget)
remaining_budget -= additional_budget
else:
final_results[variant] = initial_result
return final_results
```
## Optimization Hypotheses
### Common Optimization Areas
#### Instruction Clarity
```python
instruction_clarity_hypotheses = [
"Add numbered steps to instructions",
"Include specific output format examples",
"Clarify role and expertise level",
"Add context and background information",
"Specify constraints and boundaries",
"Include success criteria and evaluation standards"
]
```
#### Example Quality
```python
example_optimization_hypotheses = [
"Increase number of examples from 3 to 5",
"Add edge case examples",
"Reorder examples by complexity",
"Include negative examples",
"Add reasoning traces to examples",
"Improve example diversity and coverage"
]
```
#### Structure Optimization
```python
structure_hypotheses = [
"Add clear section headings",
"Reorganize content flow",
"Include summary at the beginning",
"Add checklist for verification",
"Separate instructions from examples",
"Add troubleshooting section"
]
```
#### Model-Specific Optimization
```python
model_specific_hypotheses = {
'claude': [
"Use XML tags for structure",
"Add <thinking> sections for reasoning",
"Include constitutional AI principles",
"Use system message format",
"Add safety guidelines and constraints"
],
'gpt-4': [
"Use numbered sections with ### headers",
"Include JSON format specifications",
"Add function calling patterns",
"Use bullet points for clarity",
"Include error handling instructions"
],
'gemini': [
"Use bold headers with ** formatting",
"Include step-by-step process descriptions",
"Add validation checkpoints",
"Use conversational tone",
"Include confidence scoring"
]
}
```
## Continuous Monitoring
### Production Performance Tracking
```python
def setup_monitoring(prompt, alert_thresholds):
"""
Set up continuous monitoring for deployed prompts.
"""
monitors = {
'success_rate': MetricMonitor('success_rate', alert_thresholds['success_rate']),
'response_time': MetricMonitor('response_time', alert_thresholds['response_time']),
'token_cost': MetricMonitor('token_cost', alert_thresholds['token_cost']),
'safety_score': MetricMonitor('safety_score', alert_thresholds['safety_score'])
}
def monitor_performance():
recent_data = collect_recent_performance(prompt)
alerts = []
for metric_name, monitor in monitors.items():
if metric_name in recent_data:
alert = monitor.check(recent_data[metric_name])
if alert:
alerts.append(alert)
return alerts
return monitor_performance
```
### Automated Rollback System
```python
def automated_rollback_system(prompts, monitoring_data):
"""
Automatically rollback to previous version if performance degrades.
"""
def check_and_rollback(current_prompt, baseline_prompt):
current_metrics = monitoring_data.get_metrics(current_prompt)
baseline_metrics = monitoring_data.get_metrics(baseline_prompt)
        # Check if performance degradation exceeds threshold
        # (assumes all tracked metrics are higher-is-better)
        degradation_threshold = 0.1  # 10% degradation
for metric in current_metrics:
if current_metrics[metric] < baseline_metrics[metric] * (1 - degradation_threshold):
return True, f"Performance degradation in {metric}"
return False, "Performance acceptable"
return check_and_rollback
```
## Optimization Tools and Utilities
### Prompt Variation Generator
```python
def generate_prompt_variations(base_prompt):
"""
Generate systematic variations for testing.
"""
variations = {}
# Instruction variations
variations['more_detailed'] = add_detailed_instructions(base_prompt)
variations['simplified'] = simplify_instructions(base_prompt)
variations['structured'] = add_structured_format(base_prompt)
# Example variations
variations['more_examples'] = add_examples(base_prompt)
variations['better_examples'] = improve_example_quality(base_prompt)
variations['diverse_examples'] = add_example_diversity(base_prompt)
# Format variations
variations['numbered_steps'] = add_numbered_steps(base_prompt)
variations['bullet_points'] = use_bullet_points(base_prompt)
variations['sections'] = add_section_headers(base_prompt)
return variations
```
### Performance Dashboard
```python
def create_performance_dashboard(optimization_history):
"""
Create visualization of optimization progress.
"""
# Generate performance metrics over time
metrics_over_time = {
'iterations': [h['iteration'] for h in optimization_history],
'success_rates': [h['performance_after'].success_rate for h in optimization_history],
'token_efficiency': [h['performance_after'].token_efficiency for h in optimization_history],
'response_times': [h['performance_after'].response_time for h in optimization_history]
}
return PerformanceDashboard(metrics_over_time)
```
This comprehensive framework provides systematic methodologies for continuous prompt improvement through data-driven optimization and rigorous testing processes.


@@ -0,0 +1,494 @@
# System Prompt Design
This reference provides comprehensive frameworks for designing effective system prompts that establish consistent model behavior, define clear boundaries, and ensure reliable performance across diverse applications.
## System Prompt Architecture
### Core Components Structure
```
1. Role Definition & Expertise
2. Behavioral Guidelines & Constraints
3. Interaction Protocols
4. Output Format Specifications
5. Safety & Ethical Guidelines
6. Context & Background Information
7. Quality Standards & Verification
8. Error Handling & Uncertainty Protocols
```
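In code, this architecture suggests assembling the system prompt from ordered, named sections so individual components can be swapped or tested independently. A minimal sketch; the section contents are placeholders:

```python
SECTION_ORDER = [
    "role_definition", "behavioral_guidelines", "interaction_protocols",
    "output_format", "safety_guidelines", "context",
    "quality_standards", "error_handling",
]

def assemble_system_prompt(sections):
    """Join the defined components in their canonical order."""
    parts = []
    for name in SECTION_ORDER:
        if name in sections:  # components are optional
            title = name.replace("_", " ").title()
            parts.append(f"## {title}\n{sections[name]}")
    return "\n\n".join(parts)

prompt = assemble_system_prompt({
    "role_definition": "You are a senior technical writer...",
    "safety_guidelines": "Decline requests for harmful content...",
})
```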
## Component Design Patterns
### 1. Role Definition Framework
#### Comprehensive Role Specification
```markdown
## Role Definition
You are an expert {role} with {experience_level} of specialized experience in {domain}. Your expertise includes:
### Core Competencies
- {competency_1}
- {competency_2}
- {competency_3}
- {competency_4}
### Knowledge Boundaries
- You have deep knowledge of {strength_area_1} and {strength_area_2}
- Your knowledge is current as of {knowledge_cutoff_date}
- You should acknowledge limitations in {limitation_area}
- When uncertain about recent developments, state this explicitly
### Professional Standards
- Adhere to {industry_standard_1} guidelines
- Follow {industry_standard_2} best practices
- Maintain {professional_attribute} in all interactions
- Ensure compliance with {regulatory_framework}
```
#### Specialized Role Templates
##### Technical Expert Role
```markdown
## Technical Expert Role
You are a Senior {domain} Engineer with {years} years of experience in {specialization}. Your expertise encompasses:
### Technical Proficiency
- Deep understanding of {technology_stack}
- Experience with {specific_frameworks} and {tools}
- Knowledge of {design_patterns} and {architectures}
- Proficiency in {programming_languages} and {development_methodologies}
### Problem-Solving Approach
- Analyze problems systematically using {methodology}
- Consider multiple solution approaches before recommending
- Evaluate trade-offs between {criteria_1}, {criteria_2}, and {criteria_3}
- Provide scalable and maintainable solutions
### Communication Style
- Explain technical concepts clearly to both technical and non-technical audiences
- Use precise terminology when appropriate
- Provide concrete examples and code snippets when helpful
- Structure responses with clear sections and logical flow
```
##### Analyst Role
```markdown
## Analyst Role
You are a professional {analysis_type} Analyst with expertise in {data_domain} and {methodology}. Your analytical approach includes:
### Analytical Framework
- Apply {analytical_methodology} for systematic analysis
- Use {statistical_techniques} for data interpretation
- Consider {contextual_factors} in your analysis
- Validate findings through {verification_methods}
### Critical Thinking Process
- Question assumptions and identify potential biases
- Evaluate evidence quality and source reliability
- Consider alternative explanations and perspectives
- Synthesize information from multiple sources
### Reporting Standards
- Present findings with appropriate confidence levels
- Distinguish between facts, interpretations, and recommendations
- Provide evidence-based conclusions
- Acknowledge limitations and uncertainties
```
### 2. Behavioral Guidelines Design
#### Comprehensive Behavior Framework
```markdown
## Behavioral Guidelines
### Interaction Style
- Maintain {tone} tone throughout all interactions
- Use {communication_approach} when explaining complex concepts
- Be {responsiveness_level} in addressing user questions
- Demonstrate {empathy_level} when dealing with user challenges
### Response Standards
- Provide responses that are {length_preference} and {detail_preference}
- Structure information using {organization_pattern}
- Include {frequency} examples and illustrations
- Use {format_preference} formatting for clarity
### Quality Expectations
- Ensure all information is {accuracy_standard}
- Provide citations for {information_type} when available
- Cross-verify information using {verification_method}
- Update knowledge based on {update_criteria}
```
#### Model-Specific Behavior Patterns
##### Claude 3.5/4 Specific Guidelines
```markdown
## Claude-Specific Behavioral Guidelines
### Constitutional Alignment
- Follow constitutional AI principles in all responses
- Prioritize helpfulness while maintaining safety
- Consider multiple perspectives before concluding
- Avoid harmful content while remaining useful
### Output Formatting
- Use XML tags for structured information: <tag>content</tag>
- Include thinking blocks for complex reasoning: <thinking>...</thinking>
- Provide clear section headers with proper hierarchy
- Use markdown formatting for improved readability
### Safety Protocols
- Apply content policies consistently
- Identify and flag potentially harmful requests
- Provide safe alternatives when appropriate
- Maintain transparency about limitations
```
##### GPT-4 Specific Guidelines
```markdown
## GPT-4 Specific Behavioral Guidelines
### Structured Response Patterns
- Use numbered lists for step-by-step processes
- Implement clear section boundaries with ### headers
- Provide JSON formatted outputs when specified
- Use consistent indentation and formatting
### Function Calling Integration
- Recognize when function calling would be appropriate
- Structure responses to facilitate tool usage
- Provide clear parameter specifications
- Handle function results systematically
### Optimization Behaviors
- Balance conciseness with comprehensiveness
- Prioritize information relevance and importance
- Use efficient language patterns
- Minimize redundancy while maintaining clarity
```
### 3. Output Format Specifications
#### Comprehensive Format Framework
```markdown
## Output Format Requirements
### Structure Standards
- Begin responses with {opening_pattern}
- Use {section_pattern} for major sections
- Implement {hierarchy_pattern} for information organization
- Include {closing_pattern} for response completion
### Content Organization
- Present information in {presentation_order}
- Group related information using {grouping_method}
- Use {transition_pattern} between sections
- Include {summary_element} for complex responses
### Format Specifications
{if json_format_required}
- Provide responses in valid JSON format
- Use consistent key naming conventions
- Include all required fields
- Validate JSON syntax before output
{endif}
{if markdown_format_required}
- Use markdown for formatting and emphasis
- Include appropriate heading levels
- Use code blocks for technical content
- Implement tables for structured data
{endif}
```
### 4. Safety and Ethical Guidelines
#### Comprehensive Safety Framework
```markdown
## Safety and Ethical Guidelines
### Content Policies
- Avoid generating {prohibited_content_1}
- Do not provide {prohibited_content_2}
- Flag {sensitive_topics} for human review
- Provide {safe_alternatives} when appropriate
### Ethical Considerations
- Consider {ethical_principle_1} in all responses
- Evaluate potential {ethical_impact} of provided information
- Balance helpfulness with {safety_consideration}
- Maintain {transparency_standard} about limitations
### Bias Mitigation
- Actively identify and mitigate {bias_type_1}
- Present information {neutrality_standard}
- Include {diverse_perspectives} when appropriate
- Avoid {stereotype_patterns}
### Harm Prevention
- Identify potential {harm_type_1} in responses
- Implement {prevention_mechanism} for harmful content
- Provide {warning_system} for sensitive topics
- Include {escalation_protocol} for concerning requests
```
### 5. Error Handling and Uncertainty
#### Comprehensive Error Management
```markdown
## Error Handling and Uncertainty Protocols
### Uncertainty Management
- Explicitly state confidence levels for uncertain information
- Use phrases like "I believe," "It appears that," "Based on available information"
- Acknowledge when information may be {uncertainty_type}
- Provide {verification_method} for uncertain claims
### Error Recognition
- Identify when {error_pattern} might have occurred
- Implement {self_checking_mechanism} for accuracy
- Use {validation_process} for important information
- Provide {correction_protocol} when errors are identified
### Limitation Acknowledgment
- Clearly state {knowledge_limitation} when relevant
- Explain {limitation_reason} when unable to provide complete information
- Suggest {alternative_approach} when direct assistance isn't possible
- Provide {escalation_option} for complex scenarios
### Correction Procedures
- Implement {correction_workflow} for identified errors
- Provide {explanation_format} for corrections
- Use {acknowledgment_pattern} for mistakes
- Include {improvement_commitment} for future accuracy
```
## Specialized System Prompt Templates
### 1. Educational Assistant System Prompt
```markdown
# Educational Assistant System Prompt
## Role Definition
You are an expert educational assistant specializing in {subject_area} with {experience_level} of teaching experience. Your pedagogical approach emphasizes {teaching_philosophy} and adapts to different learning styles.
## Educational Philosophy
- Create inclusive and supportive learning environments
- Adapt explanations to match learner's comprehension level
- Use scaffolding techniques to build understanding progressively
- Encourage critical thinking and independent learning
## Teaching Standards
- Provide accurate, up-to-date information verified through {verification_sources}
- Use clear, accessible language appropriate for the target audience
- Include relevant examples and analogies to enhance understanding
- Structure learning objectives with clear progression
## Interaction Protocols
- Assess learner's current understanding before providing explanations
- Ask clarifying questions to tailor responses appropriately
- Provide opportunities for learner questions and feedback
- Offer additional resources for extended learning
## Output Format
- Begin with brief assessment of learner's needs
- Use clear headings and organized structure
- Include summary points for key takeaways
- Provide practice exercises when appropriate
- End with suggestions for further learning
## Safety Guidelines
- Create psychologically safe learning environments
- Avoid language that might discourage or intimidate learners
- Be patient and supportive when learners struggle with concepts
- Respect diverse backgrounds and learning abilities
## Uncertainty Handling
- Acknowledge when topics are beyond current expertise
- Suggest reliable resources for additional information
- Be transparent about the limits of available knowledge
- Encourage critical thinking and independent verification
```
### 2. Technical Documentation Generator System Prompt
```markdown
# Technical Documentation System Prompt
## Role Definition
You are a Senior Technical Writer with {years} of experience creating documentation for {technology_domain}. Your expertise encompasses {documentation_types} and you follow {industry_standards} for technical communication.
## Documentation Standards
- Follow {style_guide} for consistent formatting and terminology
- Ensure clarity and accuracy in all technical explanations
- Include practical examples and code snippets when helpful
- Structure content with clear hierarchy and logical flow
## Quality Requirements
- Maintain technical accuracy verified through {review_process}
- Use consistent terminology throughout documentation
- Provide comprehensive coverage of topics without overwhelming detail
- Include troubleshooting information for common issues
## Audience Considerations
- Target documentation at {audience_level} technical proficiency
- Define technical terms and concepts appropriately
- Provide progressive disclosure of complex information
- Include context and motivation for technical decisions
## Format Specifications
- Use markdown formatting for clear structure and readability
- Include code blocks with syntax highlighting
- Implement consistent section headings and numbering
- Provide navigation aids and cross-references
## Review Process
- Verify technical accuracy through {verification_method}
- Test all code examples and procedures
- Ensure completeness of coverage for documented features
- Validate clarity and comprehensibility with target audience
## Safety and Compliance
- Include security considerations where relevant
- Document potential risks and mitigation strategies
- Follow industry compliance requirements
- Maintain confidentiality for sensitive information
```
### 3. Data Analysis System Prompt
```markdown
# Data Analysis System Prompt
## Role Definition
You are an expert Data Analyst specializing in {data_domain} with {years} of experience in {analysis_methodologies}. Your analytical approach combines {technical_skills} with {business_acumen} to deliver actionable insights.
## Analytical Framework
- Apply {statistical_methods} for rigorous data analysis
- Use {visualization_techniques} for effective data communication
- Implement {quality_assurance} processes for data validation
- Follow {ethical_guidelines} for responsible data handling
## Analysis Standards
- Ensure methodological soundness in all analyses
- Provide clear documentation of analytical processes
- Include appropriate statistical measures and confidence intervals
- Validate findings through {validation_methods}
## Communication Requirements
- Present findings with appropriate technical depth for the audience
- Use clear visualizations and narrative explanations
- Highlight actionable insights and recommendations
- Acknowledge limitations and uncertainties in analyses
## Output Structure
```json
{
"executive_summary": "High-level overview of key findings",
"methodology": "Description of analytical approach and methods used",
"data_overview": "Summary of data sources, quality, and limitations",
"key_findings": [
{
"finding": "Specific discovery or insight",
"evidence": "Supporting data and statistical measures",
"confidence": "Confidence level in the finding",
"implications": "Business or operational implications"
}
],
"recommendations": [
{
"action": "Recommended action",
"priority": "High/Medium/Low",
"expected_impact": "Anticipated outcome",
"implementation_considerations": "Factors to consider"
}
],
"limitations": "Constraints and limitations of the analysis",
"next_steps": "Suggested follow-up analyses or actions"
}
```
## Ethical Considerations
- Protect privacy and confidentiality of data subjects
- Ensure unbiased analysis and interpretation
- Consider potential impact of findings on stakeholders
- Maintain transparency about analytical limitations
```
## System Prompt Testing and Validation
### Validation Framework
```python
class SystemPromptValidator:
def __init__(self):
self.validation_criteria = {
'role_clarity': 0.2,
'instruction_specificity': 0.2,
'safety_completeness': 0.15,
'output_format_clarity': 0.15,
'error_handling_coverage': 0.1,
'behavioral_consistency': 0.1,
'ethical_considerations': 0.1
}
def validate_prompt(self, system_prompt):
"""Validate system prompt against quality criteria."""
scores = {}
scores['role_clarity'] = self.assess_role_clarity(system_prompt)
scores['instruction_specificity'] = self.assess_instruction_specificity(system_prompt)
scores['safety_completeness'] = self.assess_safety_completeness(system_prompt)
scores['output_format_clarity'] = self.assess_output_format_clarity(system_prompt)
scores['error_handling_coverage'] = self.assess_error_handling(system_prompt)
scores['behavioral_consistency'] = self.assess_behavioral_consistency(system_prompt)
scores['ethical_considerations'] = self.assess_ethical_considerations(system_prompt)
# Calculate overall score
overall_score = sum(score * weight for score, weight in
zip(scores.values(), self.validation_criteria.values()))
return {
'overall_score': overall_score,
'individual_scores': scores,
'recommendations': self.generate_recommendations(scores)
}
def test_prompt_consistency(self, system_prompt, test_scenarios):
"""Test prompt behavior consistency across different scenarios."""
results = []
for scenario in test_scenarios:
response = execute_with_system_prompt(system_prompt, scenario)
# Analyze response consistency
consistency_score = self.analyze_response_consistency(response, system_prompt)
results.append({
'scenario': scenario,
'response': response,
'consistency_score': consistency_score
})
average_consistency = sum(r['consistency_score'] for r in results) / len(results)
return {
'average_consistency': average_consistency,
'scenario_results': results,
'recommendations': self.generate_consistency_recommendations(results)
}
```
## Best Practices Summary
### Design Principles
- **Clarity First**: Ensure role and instructions are unambiguous
- **Comprehensive Coverage**: Address all aspects of model behavior
- **Consistency Focus**: Maintain consistent behavior across scenarios
- **Safety Priority**: Include robust safety guidelines and constraints
- **Flexibility Built-in**: Allow for adaptation to different contexts
### Common Pitfalls to Avoid
- **Vague Instructions**: Be specific about expected behaviors
- **Over-constraining**: Allow room for intelligent adaptation
- **Missing Safety Guidelines**: Always include comprehensive safety measures
- **Inconsistent Formatting**: Use consistent structure throughout
- **Ignoring Model Capabilities**: Design prompts that leverage model strengths
This comprehensive system prompt design framework provides the foundation for creating effective, reliable, and safe AI system behaviors across diverse applications and use cases.


@@ -0,0 +1,599 @@
# Template Systems Architecture
This reference provides comprehensive frameworks for building modular, reusable prompt templates with variable interpolation, conditional sections, and hierarchical composition.
## Template Design Principles
### Modularity and Reusability
- **Single Responsibility**: Each template handles one specific type of task
- **Composability**: Templates can be combined to create complex prompts
- **Parameterization**: Variables allow customization without core changes
- **Inheritance**: Base templates can be extended for specific use cases
### Clear Variable Naming Conventions
```
{user_input} - Direct input from user
{context} - Background information
{examples} - Few-shot learning examples
{constraints} - Task limitations and requirements
{output_format} - Desired output structure
{role} - AI role or persona
{expertise_level} - Level of expertise for the role
{domain} - Specific domain or field
{difficulty} - Task complexity level
{language} - Output language specification
```
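With these conventions, Python's built-in `str.format_map` renders `{variable}` placeholders directly; a small `dict` subclass keeps missing variables visible instead of raising. A minimal sketch:

```python
class SafeVars(dict):
    def __missing__(self, key):
        # Leave unknown placeholders intact for easier debugging
        return '{' + key + '}'

template = "You are a {role} with {expertise_level} expertise in {domain}."
rendered = template.format_map(SafeVars(role="data analyst", domain="churn modeling"))
# -> "You are a data analyst with {expertise_level} expertise in churn modeling."
```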
## Core Template Components
### 1. Base Template Structure
```
# Template: Universal Task Framework
# Purpose: Base template for most task types
# Variables: {role}, {task_description}, {context}, {examples}, {output_format}
## System Instructions
You are a {role} with {expertise_level} expertise in {domain}.
## Context Information
{if context}
Background and relevant context:
{context}
{endif}
## Task Description
{task_description}
## Examples
{if examples}
Here are some examples to guide your response:
{examples}
{endif}
## Output Requirements
{output_format}
## Constraints and Guidelines
{constraints}
## User Input
{user_input}
```
### 2. Conditional Sections Framework
```python
def process_conditional_template(template, variables):
"""
Process template with conditional sections.
"""
# Process if/endif blocks
while '{if ' in template:
start = template.find('{if ')
end_condition = template.find('}', start)
condition = template[start+4:end_condition].strip()
        start_endif = template.find('{endif}', end_condition)
        if_content = template[end_condition+1:start_endif].strip()
        # Evaluate condition
        if evaluate_condition(condition, variables):
            template = (template[:start] + if_content
                        + template[start_endif + len('{endif}'):])
        else:
            template = template[:start] + template[start_endif + len('{endif}'):]
# Replace variables
for key, value in variables.items():
template = template.replace(f'{{{key}}}', str(value))
return template
```
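Here, `evaluate_condition` is assumed to treat a bare variable name as a truthiness check. A short usage example under that assumption:

```python
def evaluate_condition(condition, variables):
    # Bare variable name: true when present and non-empty
    return bool(variables.get(condition))

template = "Task: {task}\n{if context}Context:\n{context}\n{endif}Answer:"
print(process_conditional_template(template, {'task': 'Summarize', 'context': ''}))
# The Context block is dropped because 'context' is empty
```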
### 3. Variable Interpolation System
```python
class TemplateEngine:
    def __init__(self):
        self.variables = {}
        self.functions = {
            'upper': str.upper,
            'lower': str.lower,
            'capitalize': str.capitalize,
            'pluralize': self.pluralize,
            'format_date': self.format_date,
            'truncate': self.truncate
        }

    def set_variable(self, name, value):
        """Set a template variable."""
        self.variables[name] = value

    def render(self, template):
        """Render template with variable substitution."""
        # Process function calls {variable|function} before plain variables
        template = self.process_functions(template)
        # Replace variables
        for key, value in self.variables.items():
            template = template.replace(f'{{{key}}}', str(value))
        return template

    def process_functions(self, template):
        """Process template functions of the form {variable|function}."""
        import re
        pattern = r'\{(\w+)\|(\w+)\}'

        def replace_function(match):
            var_name, func_name = match.groups()
            value = str(self.variables.get(var_name, ''))
            if func_name in self.functions:
                return self.functions[func_name](value)
            return value

        return re.sub(pattern, replace_function, template)

    @staticmethod
    def pluralize(value):
        """Naive pluralization helper."""
        return value if value.endswith('s') else value + 's'

    @staticmethod
    def format_date(value):
        """Normalize an ISO-style timestamp to YYYY-MM-DD."""
        return value[:10]

    @staticmethod
    def truncate(value, length=80):
        """Shorten long values with an ellipsis."""
        return value if len(value) <= length else value[:length - 3] + '...'
```
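Usage then looks like this; the `|` syntax applies a registered function to a variable before insertion:

```python
engine = TemplateEngine()
engine.set_variable('role', 'analyst')
engine.set_variable('domain', 'retail pricing')
print(engine.render("You are a senior {role|capitalize} covering {domain}."))
# -> "You are a senior Analyst covering retail pricing."
```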
## Specialized Template Types
### 1. Classification Template
```
# Template: Multi-Class Classification
# Purpose: Classify inputs into predefined categories
# Required Variables: {input_text}, {categories}, {role}
## Classification Framework
You are a {role} specializing in accurate text classification.
## Classification Categories
{categories}
## Classification Process
1. Analyze the input text carefully
2. Identify key indicators and features
3. Match against category definitions
4. Select the most appropriate category
5. Provide confidence score
## Input to Classify
{input_text}
## Output Format
```json
{{
"category": "selected_category",
"confidence": 0.95,
"reasoning": "Brief explanation of classification logic",
"key_indicators": ["indicator1", "indicator2"]
}}
```
```
### 2. Transformation Template
```
# Template: Text Transformation
# Purpose: Transform text from one format/style to another
# Required Variables: {source_text}, {target_format}, {transformation_rules}
## Transformation Task
Transform the given {source_format} text into {target_format} following these rules:
{transformation_rules}
## Source Text
{source_text}
## Transformation Process
1. Analyze the structure and content of the source text
2. Apply the specified transformation rules
3. Maintain the core meaning and intent
4. Ensure proper {target_format} formatting
5. Verify completeness and accuracy
## Transformed Output
```
### 3. Generation Template
```
# Template: Creative Generation
# Purpose: Generate creative content based on specifications
# Required Variables: {content_type}, {specifications}, {style_guidelines}
## Creative Generation Task
Generate {content_type} that meets the following specifications:
## Content Specifications
{specifications}
## Style Guidelines
{style_guidelines}
## Quality Requirements
- Originality and creativity
- Adherence to specifications
- Appropriate tone and style
- Clear structure and coherence
- Audience-appropriate language
## Generated Content
```
### 4. Analysis Template
```
# Template: Comprehensive Analysis
# Purpose: Perform detailed analysis of given input
# Required Variables: {input_data}, {analysis_framework}, {focus_areas}
## Analysis Framework
You are an expert analyst with deep expertise in {domain}.
## Analysis Scope
Focus on these key areas:
{focus_areas}
## Analysis Methodology
{analysis_framework}
## Input Data for Analysis
{input_data}
## Analysis Process
1. Initial assessment and context understanding
2. Detailed examination of each focus area
3. Pattern and trend identification
4. Comparative analysis with benchmarks
5. Insight generation and recommendation formulation
## Analysis Output Structure
```yaml
executive_summary:
key_findings: []
overall_assessment: ""
detailed_analysis:
{focus_area_1}:
observations: []
patterns: []
insights: []
{focus_area_2}:
observations: []
patterns: []
insights: []
recommendations:
immediate: []
short_term: []
long_term: []
```
## Advanced Template Patterns
### 1. Hierarchical Template Composition
```python
class HierarchicalTemplate:
    def __init__(self, name, content, parent=None):
        self.name = name
        self.content = content
        self.parent = parent
        self.children = []
        self.variables = {}

    def add_child(self, child_template):
        """Add a child template."""
        child_template.parent = self
        self.children.append(child_template)

    def render(self, variables=None):
        """Render template with inherited variables."""
        # Combine variables from the parent hierarchy, root first, so that
        # closer ancestors and the current template override distant ones
        combined_vars = {}
        ancestors = []
        current = self.parent
        while current:
            ancestors.append(current)
            current = current.parent
        for ancestor in reversed(ancestors):
            combined_vars.update(ancestor.variables)
        # Add current variables
        combined_vars.update(self.variables)
        # Override with provided variables
        if variables:
            combined_vars.update(variables)
        # Render this template's own content
        rendered_content = self.render_content(self.content, combined_vars)
        # Render children into their {child:<name>} placeholders
        for child in self.children:
            child_rendered = child.render(combined_vars)
            rendered_content = rendered_content.replace(
                f'{{child:{child.name}}}', child_rendered
            )
        return rendered_content

    @staticmethod
    def render_content(content, variables):
        """Substitute {variable} placeholders in a content string."""
        for key, value in variables.items():
            content = content.replace(f'{{{key}}}', str(value))
        return content
```
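A brief composition example: the parent supplies shared variables and a `{child:...}` slot that the child fills:

```python
parent = HierarchicalTemplate('report', 'Report for {client}\n{child:summary}')
parent.variables['client'] = 'Acme Corp'

summary = HierarchicalTemplate('summary', 'Summary prepared for {client}.')
parent.add_child(summary)

print(parent.render())
# Report for Acme Corp
# Summary prepared for Acme Corp.
```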
### 2. Role-Based Template System
```python
class RoleBasedTemplate:
def __init__(self):
self.roles = {
'analyst': {
'persona': 'You are a professional analyst with expertise in data interpretation and pattern recognition.',
'approach': 'systematic',
'output_style': 'detailed and evidence-based',
'verification': 'Always cross-check findings and cite sources'
},
'creative_writer': {
'persona': 'You are a creative writer with a talent for engaging storytelling and vivid descriptions.',
'approach': 'imaginative',
'output_style': 'descriptive and engaging',
'verification': 'Ensure narrative consistency and flow'
},
'technical_expert': {
'persona': 'You are a technical expert with deep knowledge of {domain} and practical implementation experience.',
'approach': 'methodical',
'output_style': 'precise and technical',
'verification': 'Include technical accuracy and best practices'
}
}
def create_prompt(self, role, task, domain=None):
"""Create role-specific prompt template."""
role_config = self.roles.get(role, self.roles['analyst'])
template = f"""
## Role Definition
{role_config['persona']}
## Approach
Use a {role_config['approach']} approach to this task.
## Task
{task}
## Output Style
{role_config['output_style']}
## Verification
{role_config['verification']}
"""
if domain and '{domain}' in role_config['persona']:
template = template.replace('{domain}', domain)
return template
```
### 3. Dynamic Template Selection
```python
class DynamicTemplateSelector:
def __init__(self):
self.templates = {}
self.selection_rules = {}
def register_template(self, name, template, selection_criteria):
"""Register a template with selection criteria."""
self.templates[name] = template
self.selection_rules[name] = selection_criteria
def select_template(self, task_characteristics):
"""Select the most appropriate template based on task characteristics."""
best_template = None
best_score = 0
for name, criteria in self.selection_rules.items():
score = self.calculate_match_score(task_characteristics, criteria)
if score > best_score:
best_score = score
best_template = name
return self.templates[best_template] if best_template else None
def calculate_match_score(self, task_characteristics, criteria):
"""Calculate how well task matches template criteria."""
score = 0
total_weight = 0
for characteristic, weight in criteria.items():
if characteristic in task_characteristics:
if task_characteristics[characteristic] == weight['value']:
score += weight['weight']
total_weight += weight['weight']
return score / total_weight if total_weight > 0 else 0
```
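Registration and selection might look like this; the criteria schema (a value plus a weight per characteristic) follows the matcher above, and `classification_template` / `generation_template` stand for template strings assumed to be defined elsewhere:

```python
selector = DynamicTemplateSelector()
selector.register_template(
    'classification', classification_template,
    {'task_type': {'value': 'classification', 'weight': 0.7},
     'has_categories': {'value': True, 'weight': 0.3}},
)
selector.register_template(
    'generation', generation_template,
    {'task_type': {'value': 'generation', 'weight': 1.0}},
)

chosen = selector.select_template({'task_type': 'classification',
                                   'has_categories': True})
```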
## Template Implementation Examples
### Example 1: Customer Service Template
```python
customer_service_template = """
# Customer Service Response Template
## Role Definition
You are a {customer_service_role} with {experience_level} of customer service experience in {industry}.
## Context
{if customer_history}
Customer History:
{customer_history}
{endif}
{if issue_context}
Issue Context:
{issue_context}
{endif}
## Response Guidelines
- Maintain {tone} tone throughout
- Address all aspects of the customer's inquiry
- Provide {level_of_detail} explanation
- Include {additional_elements}
- Follow company {communication_style} style
## Customer Inquiry
{customer_inquiry}
## Response Structure
1. Greeting and acknowledgment
2. Understanding and empathy
3. Solution or explanation
4. Additional assistance offered
5. Professional closing
## Response
"""
```
### Example 2: Technical Documentation Template
```python
documentation_template = """
# Technical Documentation Generator
## Role Definition
You are a {technical_writer_role} specializing in {technology} documentation with {experience_level} of experience.
## Documentation Standards
- Target audience: {audience_level}
- Technical depth: {technical_depth}
- Include examples: {include_examples}
- Add troubleshooting: {add_troubleshooting}
- Version: {version}
## Content to Document
{content_to_document}
## Documentation Structure
```markdown
# {title}
## Overview
{overview}
## Prerequisites
{prerequisites}
## {main_sections}
## Examples
{if include_examples}
{examples}
{endif}
## Troubleshooting
{if add_troubleshooting}
{troubleshooting}
{endif}
## Additional Resources
{additional_resources}
```
## Generated Documentation
"""
```
## Template Management System
### Version Control Integration
```python
class TemplateVersionManager:
def __init__(self):
self.versions = {}
self.current_versions = {}
def create_version(self, template_name, template_content, author, description):
"""Create a new version of a template."""
import datetime
import hashlib
version_id = hashlib.md5(template_content.encode()).hexdigest()[:8]
timestamp = datetime.datetime.now().isoformat()
version_info = {
'version_id': version_id,
'content': template_content,
'author': author,
'description': description,
'timestamp': timestamp,
'parent_version': self.current_versions.get(template_name)
}
if template_name not in self.versions:
self.versions[template_name] = []
self.versions[template_name].append(version_info)
self.current_versions[template_name] = version_id
return version_id
def rollback(self, template_name, version_id):
"""Rollback to a specific version."""
if template_name in self.versions:
for version in self.versions[template_name]:
if version['version_id'] == version_id:
self.current_versions[template_name] = version_id
return version['content']
return None
```
### Performance Monitoring
```python
class TemplatePerformanceMonitor:
def __init__(self):
self.usage_stats = {}
self.performance_metrics = {}
def track_usage(self, template_name, execution_time, success):
"""Track template usage and performance."""
if template_name not in self.usage_stats:
self.usage_stats[template_name] = {
'usage_count': 0,
'total_time': 0,
'success_count': 0,
'failure_count': 0
}
stats = self.usage_stats[template_name]
stats['usage_count'] += 1
stats['total_time'] += execution_time
if success:
stats['success_count'] += 1
else:
stats['failure_count'] += 1
def get_performance_report(self, template_name):
"""Generate performance report for a template."""
if template_name not in self.usage_stats:
return None
stats = self.usage_stats[template_name]
avg_time = stats['total_time'] / stats['usage_count']
success_rate = stats['success_count'] / stats['usage_count']
return {
'template_name': template_name,
'total_usage': stats['usage_count'],
'average_execution_time': avg_time,
'success_rate': success_rate,
'failure_rate': 1 - success_rate
}
```
## Best Practices
### Template Quality Guidelines
- **Clear Documentation**: Include purpose, variables, and usage examples
- **Consistent Naming**: Use standardized variable naming conventions
- **Error Handling**: Include fallback mechanisms for missing variables
- **Performance Optimization**: Minimize template complexity and rendering time
- **Testing**: Implement comprehensive template testing frameworks
### Security Considerations
- **Input Validation**: Sanitize all template variables
- **Injection Prevention**: Prevent code injection in template rendering
- **Access Control**: Implement proper authorization for template modifications
- **Audit Trail**: Track template changes and usage
This comprehensive template system architecture provides the foundation for building scalable, maintainable prompt templates that can be efficiently managed and optimized across diverse use cases.

skills/ai/rag/SKILL.md

@@ -0,0 +1,286 @@
---
name: rag-implementation
description: Build Retrieval-Augmented Generation (RAG) systems for AI applications with vector databases and semantic search. Use when implementing knowledge-grounded AI, building document Q&A systems, or integrating LLMs with external knowledge bases.
allowed-tools: Read, Write, Bash
category: ai-engineering
tags: [rag, vector-databases, embeddings, retrieval, semantic-search]
version: 1.0.0
---
# RAG Implementation
Build Retrieval-Augmented Generation systems that extend AI capabilities with external knowledge sources.
## Overview
RAG (Retrieval-Augmented Generation) enhances AI applications by retrieving relevant information from knowledge bases and incorporating it into AI responses, reducing hallucinations and providing accurate, grounded answers.
## When to Use
Use this skill when:
- Building Q&A systems over proprietary documents
- Creating chatbots with current, factual information
- Implementing semantic search with natural language queries
- Reducing hallucinations with grounded responses
- Enabling AI systems to access domain-specific knowledge
- Building documentation assistants
- Creating research tools with source citation
- Developing knowledge management systems
## Core Components
### Vector Databases
Store and efficiently retrieve document embeddings for semantic search.
**Key Options:**
- **Pinecone**: Managed, scalable, production-ready
- **Weaviate**: Open-source, hybrid search capabilities
- **Milvus**: High performance, on-premise deployment
- **Chroma**: Lightweight, easy local development
- **Qdrant**: Fast, advanced filtering
- **FAISS**: Meta's library, full control
### Embedding Models
Convert text to numerical vectors for similarity search.
**Popular Models:**
- **text-embedding-ada-002** (OpenAI): General purpose, 1536 dimensions
- **all-MiniLM-L6-v2**: Fast, lightweight, 384 dimensions
- **e5-large-v2**: High quality, English (multilingual variants available)
- **bge-large-en-v1.5**: State-of-the-art performance
### Retrieval Strategies
Find relevant content based on user queries.
**Approaches:**
- **Dense Retrieval**: Semantic similarity via embeddings
- **Sparse Retrieval**: Keyword matching (BM25, TF-IDF)
- **Hybrid Search**: Combine dense + sparse for best results
- **Multi-Query**: Generate multiple query variations
- **Contextual Compression**: Extract only relevant parts
## Quick Implementation
### Basic RAG Setup
```java
// Load documents from file system
List<Document> documents = FileSystemDocumentLoader.loadDocuments("/path/to/docs");
// Create embedding store
InMemoryEmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
// Ingest documents into the store
EmbeddingStoreIngestor.ingest(documents, embeddingStore);
// Create AI service with RAG capability
Assistant assistant = AiServices.builder(Assistant.class)
.chatModel(chatModel)
.chatMemory(MessageWindowChatMemory.withMaxMessages(10))
.contentRetriever(EmbeddingStoreContentRetriever.from(embeddingStore))
.build();
```
### Document Processing Pipeline
```java
// Split documents into chunks
DocumentSplitter splitter = DocumentSplitters.recursive(
    500,  // max segment size (chars)
    100   // overlap (chars)
);
// Create embedding model
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey("your-api-key")
.build();
// Create embedding store
EmbeddingStore<TextSegment> embeddingStore = PgVectorEmbeddingStore.builder()
.host("localhost")
.database("postgres")
.user("postgres")
.password("password")
.table("embeddings")
.dimension(1536)
.build();
// Process and store documents
for (Document document : documents) {
List<TextSegment> segments = splitter.split(document);
for (TextSegment segment : segments) {
Embedding embedding = embeddingModel.embed(segment).content();
embeddingStore.add(embedding, segment);
}
}
```
## Implementation Patterns
### Pattern 1: Simple Document Q&A
Create a basic Q&A system over your documents.
```java
public interface DocumentAssistant {
String answer(String question);
}
DocumentAssistant assistant = AiServices.builder(DocumentAssistant.class)
.chatModel(chatModel)
.contentRetriever(retriever)
.build();
```
### Pattern 2: Metadata-Filtered Retrieval
Filter results based on document metadata.
```java
// Add metadata during document loading
Document document = Document.builder()
.text("Content here")
.metadata("source", "technical-manual.pdf")
.metadata("category", "technical")
.metadata("date", "2024-01-15")
.build();
// Filter during retrieval
EmbeddingStoreContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.maxResults(5)
.minScore(0.7)
.filter(metadataKey("category").isEqualTo("technical"))
.build();
```
### Pattern 3: Multi-Source Retrieval
Combine results from multiple knowledge sources.
```java
ContentRetriever webRetriever = EmbeddingStoreContentRetriever.from(webStore);
ContentRetriever documentRetriever = EmbeddingStoreContentRetriever.from(documentStore);
ContentRetriever databaseRetriever = EmbeddingStoreContentRetriever.from(databaseStore);
// Combine results
List<Content> allResults = new ArrayList<>();
allResults.addAll(webRetriever.retrieve(query));
allResults.addAll(documentRetriever.retrieve(query));
allResults.addAll(databaseRetriever.retrieve(query));
// Rerank combined results (reranker stands for any scoring model, e.g. a cross-encoder)
List<Content> rerankedResults = reranker.reorder(query, allResults);
```
## Best Practices
### Document Preparation
- Clean and preprocess documents before ingestion (see the sketch after this list)
- Remove irrelevant content and formatting artifacts
- Standardize document structure for consistent processing
- Add relevant metadata for filtering and context
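A light preprocessing sketch for the cleaning step above, shown in Python for brevity; the regex patterns are examples to adapt to your corpus:
```python
import re

def clean_document(text):
    """Light cleanup before ingestion (patterns are illustrative only)."""
    text = re.sub(r"[ \t]+", " ", text)            # collapse runs of spaces/tabs
    text = re.sub(r"\n{3,}", "\n\n", text)         # cap consecutive blank lines
    text = re.sub(r"Page \d+ of \d+", "", text)    # drop a common footer pattern
    return text.strip()
```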
### Chunking Strategy
- Use 500-1000 tokens per chunk for optimal balance
- Include 10-20% overlap to preserve context at boundaries
- Consider document structure when determining chunk boundaries
- Test different chunk sizes for your specific use case
### Retrieval Optimization
- Start with high k values (10-20) then filter/rerank
- Use metadata filtering to improve relevance
- Combine multiple retrieval strategies for better coverage
- Monitor retrieval quality and user feedback
### Performance Considerations
- Cache embeddings for frequently accessed content
- Use batch processing for document ingestion
- Optimize vector store configuration for your scale
- Monitor query performance and system resources
## Common Issues and Solutions
### Poor Retrieval Quality
**Problem**: Retrieved documents don't match user queries
**Solutions**:
- Improve document preprocessing and cleaning
- Adjust chunk size and overlap parameters
- Try different embedding models
- Use hybrid search combining semantic and keyword matching
### Irrelevant Results
**Problem**: Retrieved documents contain relevant information but are not specific enough
**Solutions**:
- Add metadata filtering for domain-specific constraints
- Implement reranking with cross-encoder models
- Use contextual compression to extract relevant parts
- Fine-tune retrieval parameters (k values, similarity thresholds)
### Performance Issues
**Problem**: Slow response times during retrieval
**Solutions**:
- Optimize vector store configuration and indexing
- Implement caching for frequently retrieved content
- Use smaller embedding models for faster inference
- Consider approximate nearest neighbor algorithms
### Hallucination Prevention
**Problem**: AI generates information not present in retrieved documents
**Solutions**:
- Improve prompt engineering to emphasize grounding (see the sketch after this list)
- Add verification steps to check answer alignment
- Include confidence scoring for responses
- Implement fact-checking mechanisms
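One way to phrase a grounding prompt for the first solution above; the wording is a sketch, not a prescribed template:
```python
GROUNDED_PROMPT = """Answer the question using ONLY the context below.
If the context does not contain the answer, reply "I don't know."

Context:
{context}

Question: {question}
Answer:"""

def build_grounded_prompt(context_chunks, question):
    return GROUNDED_PROMPT.format(
        context="\n\n".join(context_chunks),
        question=question,
    )
```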
## Evaluation Framework
### Retrieval Metrics
- **Precision@k**: Percentage of relevant documents in top-k results
- **Recall@k**: Percentage of all relevant documents found in top-k results
- **Mean Reciprocal Rank (MRR)**: Average rank of first relevant result
- **Normalized Discounted Cumulative Gain (nDCG)**: Ranking quality metric
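Although the examples in this skill are Java, a compact Python sketch of the first three metrics, assuming retrieved results are ID lists and relevance judgments are non-empty ID sets:
```python
def precision_at_k(retrieved, relevant, k):
    """Fraction of the top-k retrieved IDs that are relevant."""
    return sum(1 for doc_id in retrieved[:k] if doc_id in relevant) / k

def recall_at_k(retrieved, relevant, k):
    """Fraction of all relevant IDs found in the top-k."""
    return sum(1 for doc_id in retrieved[:k] if doc_id in relevant) / len(relevant)

def mean_reciprocal_rank(queries):
    """queries: list of (retrieved_ids, relevant_id_set) pairs."""
    total = 0.0
    for retrieved, relevant in queries:
        for rank, doc_id in enumerate(retrieved, start=1):
            if doc_id in relevant:
                total += 1.0 / rank
                break
    return total / len(queries)

# Example: 2 of the top 3 results are relevant; first hit at rank 1.
retrieved = ["d1", "d7", "d3", "d9"]
relevant = {"d1", "d3", "d5"}
print(precision_at_k(retrieved, relevant, 3))          # ~0.667
print(recall_at_k(retrieved, relevant, 3))             # ~0.667
print(mean_reciprocal_rank([(retrieved, relevant)]))   # 1.0
```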
### Answer Quality Metrics
- **Faithfulness**: Degree to which answers are grounded in retrieved documents
- **Answer Relevance**: How well answers address user questions
- **Context Recall**: Percentage of relevant context used in answers
- **Context Precision**: Percentage of retrieved context that is relevant
### User Experience Metrics
- **Response Time**: Time from query to answer
- **User Satisfaction**: Feedback ratings on answer quality
- **Task Completion**: Rate of successful task completion
- **Engagement**: User interaction patterns with the system
## Resources
### Reference Documentation
- [Vector Database Comparison](references/vector-databases.md) - Detailed comparison of vector database options
- [Embedding Models Guide](references/embedding-models.md) - Model selection and optimization
- [Retrieval Strategies](references/retrieval-strategies.md) - Advanced retrieval techniques
- [Document Chunking](references/document-chunking.md) - Chunking strategies and best practices
- [LangChain4j RAG Guide](references/langchain4j-rag-guide.md) - Official implementation patterns
### Assets
- `assets/vector-store-config.yaml` - Configuration templates for different vector stores
- `assets/retriever-pipeline.java` - Complete RAG pipeline implementation
- `assets/evaluation-metrics.java` - Evaluation framework code
## Constraints and Limitations
1. **Token Limits**: Respect model context window limitations
2. **API Rate Limits**: Manage external API rate limits and costs
3. **Data Privacy**: Ensure compliance with data protection regulations
4. **Resource Requirements**: Consider memory and computational requirements
5. **Maintenance**: Plan for regular updates and system monitoring
## Security Considerations
- Secure access to vector databases and embedding services
- Implement proper authentication and authorization
- Validate and sanitize user inputs
- Monitor for abuse and unusual usage patterns
- Regular security audits and penetration testing


@@ -0,0 +1,307 @@
package com.example.rag;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.parser.TextDocumentParser;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import dev.langchain4j.store.embedding.pinecone.PineconeEmbeddingStore;
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
import dev.langchain4j.store.embedding.qdrant.QdrantEmbeddingStore;
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* Complete RAG Pipeline Implementation
*
* This class provides a comprehensive implementation of a RAG (Retrieval-Augmented Generation)
* system with support for multiple vector stores and advanced retrieval strategies.
*/
public class RAGPipeline {
private final EmbeddingModel embeddingModel;
private final EmbeddingStore<TextSegment> embeddingStore;
private final DocumentSplitter documentSplitter;
private final RAGConfig config;
/**
* Configuration class for RAG pipeline
*/
public static class RAGConfig {
private String vectorStoreType = "chroma";
private String openAiApiKey;
private String pineconeApiKey;
private String pineconeEnvironment;
private String pineconeIndex = "rag-documents";
private String chromaCollection = "rag-documents";
private String chromaPersistPath = "./chroma_db";
private String qdrantHost = "localhost";
private int qdrantPort = 6333;
private String qdrantCollection = "rag-documents";
private int chunkSize = 1000;
private int chunkOverlap = 200;
private int embeddingDimension = 1536;
// Getters and setters
public String getVectorStoreType() { return vectorStoreType; }
public void setVectorStoreType(String vectorStoreType) { this.vectorStoreType = vectorStoreType; }
public String getOpenAiApiKey() { return openAiApiKey; }
public void setOpenAiApiKey(String openAiApiKey) { this.openAiApiKey = openAiApiKey; }
public String getPineconeApiKey() { return pineconeApiKey; }
public void setPineconeApiKey(String pineconeApiKey) { this.pineconeApiKey = pineconeApiKey; }
public String getPineconeEnvironment() { return pineconeEnvironment; }
public void setPineconeEnvironment(String pineconeEnvironment) { this.pineconeEnvironment = pineconeEnvironment; }
public String getPineconeIndex() { return pineconeIndex; }
public void setPineconeIndex(String pineconeIndex) { this.pineconeIndex = pineconeIndex; }
public String getChromaCollection() { return chromaCollection; }
public void setChromaCollection(String chromaCollection) { this.chromaCollection = chromaCollection; }
public String getChromaPersistPath() { return chromaPersistPath; }
public void setChromaPersistPath(String chromaPersistPath) { this.chromaPersistPath = chromaPersistPath; }
public String getQdrantHost() { return qdrantHost; }
public void setQdrantHost(String qdrantHost) { this.qdrantHost = qdrantHost; }
public int getQdrantPort() { return qdrantPort; }
public void setQdrantPort(int qdrantPort) { this.qdrantPort = qdrantPort; }
public String getQdrantCollection() { return qdrantCollection; }
public void setQdrantCollection(String qdrantCollection) { this.qdrantCollection = qdrantCollection; }
public int getChunkSize() { return chunkSize; }
public void setChunkSize(int chunkSize) { this.chunkSize = chunkSize; }
public int getChunkOverlap() { return chunkOverlap; }
public void setChunkOverlap(int chunkOverlap) { this.chunkOverlap = chunkOverlap; }
public int getEmbeddingDimension() { return embeddingDimension; }
public void setEmbeddingDimension(int embeddingDimension) { this.embeddingDimension = embeddingDimension; }
}
/**
* Constructor
*/
public RAGPipeline(RAGConfig config) {
this.config = config;
this.embeddingModel = createEmbeddingModel();
this.embeddingStore = createEmbeddingStore();
this.documentSplitter = createDocumentSplitter();
}
/**
* Create embedding model based on configuration
*/
private EmbeddingModel createEmbeddingModel() {
return OpenAiEmbeddingModel.builder()
.apiKey(config.getOpenAiApiKey())
.modelName("text-embedding-ada-002")
.build();
}
/**
* Create embedding store based on configuration
*/
private EmbeddingStore<TextSegment> createEmbeddingStore() {
switch (config.getVectorStoreType().toLowerCase()) {
case "pinecone":
return PineconeEmbeddingStore.builder()
.apiKey(config.getPineconeApiKey())
.environment(config.getPineconeEnvironment())
.index(config.getPineconeIndex())
.dimension(config.getEmbeddingDimension())
.build();
case "chroma":
return ChromaEmbeddingStore.builder()
.collectionName(config.getChromaCollection())
.persistDirectory(config.getChromaPersistPath())
.build();
case "qdrant":
return QdrantEmbeddingStore.builder()
.host(config.getQdrantHost())
.port(config.getQdrantPort())
.collectionName(config.getQdrantCollection())
.dimension(config.getEmbeddingDimension())
.build();
case "memory":
default:
return new InMemoryEmbeddingStore<>();
}
}
/**
* Create document splitter
*/
private DocumentSplitter createDocumentSplitter() {
        return DocumentSplitters.recursive(
                config.getChunkSize(),
                config.getChunkOverlap()
        );
}
/**
* Load documents from directory
*/
public List<Document> loadDocuments(String directoryPath) {
try {
Path directory = Paths.get(directoryPath);
List<Document> documents = FileSystemDocumentLoader.loadDocuments(directory);
            // Add metadata to documents. Replace entries by index: reassigning
            // the enhanced-for loop variable would not update the list.
            for (int i = 0; i < documents.size(); i++) {
                Document document = documents.get(i);
                Map<String, Object> metadata = new HashMap<>(document.metadata().toMap());
                metadata.put("loaded_at", System.currentTimeMillis());
                metadata.put("source_directory", directoryPath);
                documents.set(i, Document.from(document.text(), Metadata.from(metadata)));
            }
return documents;
} catch (Exception e) {
throw new RuntimeException("Failed to load documents from " + directoryPath, e);
}
}
/**
* Process and ingest documents
*/
public void ingestDocuments(List<Document> documents) {
        // Split documents into segments
        List<TextSegment> segments = documentSplitter.splitAll(documents);
        // Add additional metadata to segments
        for (int i = 0; i < segments.size(); i++) {
            TextSegment segment = segments.get(i);
            Map<String, Object> metadata = new HashMap<>(segment.metadata().toMap());
            metadata.put("segment_index", i);
            metadata.put("total_segments", segments.size());
            metadata.put("processed_at", System.currentTimeMillis());
            segments.set(i, TextSegment.from(segment.text(), Metadata.from(metadata)));
        }
        // Embed the segments explicitly, then store them alongside their text
        List<Embedding> embeddings = embeddingModel.embedAll(segments).content();
        embeddingStore.addAll(embeddings, segments);
        System.out.println("Ingested " + documents.size() + " documents into " +
                segments.size() + " segments");
}
/**
* Search documents with optional filtering
*/
public List<TextSegment> search(String query, int maxResults, Filter filter) {
        Embedding queryEmbedding = embeddingModel.embed(query).content();
        EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
                .queryEmbedding(queryEmbedding)
                .maxResults(maxResults)
                .filter(filter)  // may be null for unfiltered search
                .build();
        return embeddingStore.search(request).matches().stream()
                .map(EmbeddingMatch::embedded)
                .collect(Collectors.toList());
}
/**
* Search documents with metadata filtering
*/
public List<TextSegment> searchWithMetadataFilter(String query, int maxResults,
Map<String, Object> metadataFilters) {
        Filter filter = null;
        if (metadataFilters != null && !metadataFilters.isEmpty()) {
            for (Map.Entry<String, Object> entry : metadataFilters.entrySet()) {
                String key = entry.getKey();
                Object value = entry.getValue();
                Filter clause;
                if (value instanceof String) {
                    clause = MetadataFilterBuilder.metadataKey(key).isEqualTo((String) value);
                } else if (value instanceof Number) {
                    clause = MetadataFilterBuilder.metadataKey(key).isEqualTo(((Number) value).doubleValue());
                } else {
                    continue; // add more type handling as needed
                }
                // Combine individual clauses with logical AND
                filter = (filter == null) ? clause : filter.and(clause);
            }
        }
return search(query, maxResults, filter);
}
/**
* Get statistics about the stored documents
*/
public RAGStatistics getStatistics() {
// This is a simplified implementation
// In practice, you might want to track more detailed statistics
return new RAGStatistics(
embeddingStore.getClass().getSimpleName(),
config.getVectorStoreType()
);
}
/**
* Statistics holder class
*/
public static class RAGStatistics {
private final String storeType;
private final String implementation;
public RAGStatistics(String storeType, String implementation) {
this.storeType = storeType;
this.implementation = implementation;
}
public String getStoreType() { return storeType; }
public String getImplementation() { return implementation; }
@Override
public String toString() {
return "RAGStatistics{" +
"storeType='" + storeType + '\'' +
", implementation='" + implementation + '\'' +
'}';
}
}
/**
* Example usage
*/
public static void main(String[] args) {
// Configure the pipeline
RAGConfig config = new RAGConfig();
config.setVectorStoreType("chroma"); // or "pinecone", "qdrant", "memory"
config.setOpenAiApiKey("your-openai-api-key");
config.setChunkSize(1000);
config.setChunkOverlap(200);
// Create pipeline
RAGPipeline pipeline = new RAGPipeline(config);
// Load documents
List<Document> documents = pipeline.loadDocuments("./documents");
// Ingest documents
pipeline.ingestDocuments(documents);
// Search for relevant content
List<TextSegment> results = pipeline.search("What is machine learning?", 5, null);
// Print results
for (int i = 0; i < results.size(); i++) {
TextSegment segment = results.get(i);
System.out.println("Result " + (i + 1) + ":");
System.out.println("Content: " + segment.text().substring(0, Math.min(200, segment.text().length())) + "...");
System.out.println("Metadata: " + segment.metadata());
System.out.println();
}
// Print statistics
System.out.println("Pipeline Statistics: " + pipeline.getStatistics());
}
}


@@ -0,0 +1,127 @@
# Vector Store Configuration Templates
# This file contains configuration templates for different vector databases
# Chroma (Local/Development)
chroma:
type: chroma
settings:
persist_directory: "./chroma_db"
collection_name: "rag_documents"
host: "localhost"
port: 8000
# Recommended for: Development, small-scale applications
# Pros: Easy setup, local deployment, free
# Cons: Limited scalability, single-node only
# Pinecone (Cloud/Production)
pinecone:
type: pinecone
settings:
api_key: "${PINECONE_API_KEY}"
environment: "us-west1-gcp"
index_name: "rag-documents"
dimension: 1536
metric: "cosine"
pods: 1
pod_type: "p1.x1"
# Recommended for: Production applications, large-scale
# Pros: Managed service, scalable, fast
# Cons: Cost, requires internet connection
# Weaviate (Open-source/Cloud)
weaviate:
type: weaviate
settings:
url: "http://localhost:8080"
api_key: "${WEAVIATE_API_KEY}"
class_name: "Document"
text_key: "content"
vectorizer: "text2vec-openai"
module_config:
text2vec-openai:
model: "ada"
modelVersion: "002"
type: "text"
baseUrl: "https://api.openai.com/v1"
# Recommended for: Hybrid search, GraphQL API
# Pros: Open-source, hybrid search, flexible
# Cons: More complex setup
# Qdrant (Performance-focused)
qdrant:
type: qdrant
settings:
host: "localhost"
port: 6333
collection_name: "rag_documents"
vector_size: 1536
distance: "Cosine"
api_key: "${QDRANT_API_KEY}"
# Recommended for: Performance, advanced filtering
# Pros: Fast, good filtering, open-source
# Cons: Newer project, smaller community
# Milvus (Enterprise/Scale)
milvus:
type: milvus
settings:
host: "localhost"
port: 19530
collection_name: "rag_documents"
dimension: 1536
index_type: "IVF_FLAT"
metric_type: "COSINE"
nlist: 1024
# Recommended for: Enterprise, large-scale deployments
# Pros: High performance, distributed
# Cons: Complex setup, resource intensive
# FAISS (Local/Research)
faiss:
type: faiss
settings:
index_type: "IndexFlatL2"
dimension: 1536
save_path: "./faiss_index"
# Recommended for: Research, local processing
# Pros: Fast, local, no dependencies
# Cons: Library rather than a server; persistence and metadata filtering are manual
# Common Configuration Parameters
common:
chunking:
chunk_size: 1000
chunk_overlap: 200
separators: ["\n\n", "\n", " ", ""]
embedding:
model: "text-embedding-ada-002"
batch_size: 100
max_retries: 3
timeout: 30
retrieval:
default_k: 5
similarity_threshold: 0.7
max_results: 20
performance:
cache_embeddings: true
cache_size: 1000
parallel_processing: true
batch_size: 50
# Environment Variables Template
# Copy this to .env file and fill in your values
environment:
OPENAI_API_KEY: "your-openai-api-key-here"
PINECONE_API_KEY: "your-pinecone-api-key-here"
PINECONE_ENVIRONMENT: "us-west1-gcp"
WEAVIATE_API_KEY: "your-weaviate-api-key-here"
QDRANT_API_KEY: "your-qdrant-api-key-here"


@@ -0,0 +1,137 @@
# Document Chunking Strategies
## Overview
Document chunking is the process of breaking large documents into smaller, manageable pieces that can be effectively embedded and retrieved.
## Chunking Strategies
### 1. Recursive Character Text Splitter
**Method**: Split text based on character count, trying separators in order
**Use Case**: General purpose text splitting
**Advantages**: Preserves sentence and paragraph boundaries when possible
```python
from langchain.text_splitter import RecursiveCharacterTextSplitter
splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
length_function=len,
separators=["\n\n", "\n", " ", ""] # Try these in order
)
chunks = splitter.split_documents(documents)
```
### 2. Token-Based Splitting
**Method**: Split based on token count rather than characters
**Use Case**: When working with token limits of language models
**Advantages**: Better control over context window usage
```python
from langchain.text_splitter import TokenTextSplitter
splitter = TokenTextSplitter(
chunk_size=512,
chunk_overlap=50
)
chunks = splitter.split_documents(documents)
```
### 3. Semantic Chunking
**Method**: Split based on semantic similarity
**Use Case**: When maintaining semantic coherence is important
**Advantages**: Chunks are more semantically meaningful
```python
from langchain_experimental.text_splitter import SemanticChunker
from langchain.embeddings import OpenAIEmbeddings
splitter = SemanticChunker(
embeddings=OpenAIEmbeddings(),
breakpoint_threshold_type="percentile"
)
chunks = splitter.split_documents(documents)
```
### 4. Markdown Header Splitter
**Method**: Split based on markdown headers
**Use Case**: Structured documents with clear hierarchical organization
**Advantages**: Maintains document structure and context
```python
from langchain.text_splitter import MarkdownHeaderTextSplitter
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
]
splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
chunks = splitter.split_text(markdown_text)  # takes the raw Markdown string
```
### 5. HTML Splitter
**Method**: Split based on HTML tags
**Use Case**: Web pages and HTML documents
**Advantages**: Preserves HTML structure and metadata
```python
from langchain.text_splitter import HTMLHeaderTextSplitter
headers_to_split_on = [
("h1", "Header 1"),
("h2", "Header 2"),
("h3", "Header 3"),
]
splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
chunks = splitter.split_text(html_string)  # takes the raw HTML string
```
## Parameter Tuning
### Chunk Size
- **Small chunks (200-400 tokens)**: More precise retrieval, but may lose context
- **Medium chunks (500-1000 tokens)**: Good balance of precision and context
- **Large chunks (1000-2000 tokens)**: More context, but less precise retrieval
### Chunk Overlap
- **Purpose**: Preserve context at chunk boundaries
- **Typical range**: 10-20% of chunk size
- **Higher overlap**: Better context preservation, but more redundancy
- **Lower overlap**: Less redundancy, but may lose important context
### Separators
- **Hierarchical separators**: Start with larger boundaries (paragraphs), then smaller (sentences)
- **Custom separators**: Add domain-specific separators for better results (see the sketch after this list)
- **Language-specific**: Adjust for different languages and writing styles
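A sketch of a custom separator hierarchy for Markdown-heavy corpora; the separators and sizes are assumptions to tune for your documents:
```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Try section breaks first, then paragraphs, sentences, words.
markdown_splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,
    chunk_overlap=100,
    separators=["\n## ", "\n### ", "\n\n", "\n", " ", ""],
)

markdown_text = "## Intro\nSome text...\n\n## Details\nMore text..."
chunks = markdown_splitter.split_text(markdown_text)
```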
## Best Practices
1. **Preserve Context**: Ensure chunks contain enough surrounding context
2. **Maintain Coherence**: Keep semantically related content together
3. **Respect Boundaries**: Avoid breaking sentences or important phrases
4. **Consider Query Types**: Adapt chunking strategy to typical user queries
5. **Test and Iterate**: Evaluate different chunking strategies for your specific use case
## Evaluation Metrics
1. **Retrieval Quality**: How well chunks answer user queries
2. **Context Preservation**: Whether important context is maintained
3. **Chunk Distribution**: Evenness of chunk sizes (see the sketch after this list)
4. **Boundary Quality**: How natural chunk boundaries are
5. **Retrieval Efficiency**: Impact on retrieval speed and accuracy
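A quick way to eyeball chunk-size distribution, assuming chunks are plain strings and sizes are measured in characters:
```python
import statistics

def chunk_size_report(chunks):
    """Summary statistics for judging chunk evenness."""
    sizes = [len(c) for c in chunks]
    return {
        "count": len(sizes),
        "mean": statistics.mean(sizes),
        "stdev": statistics.stdev(sizes) if len(sizes) > 1 else 0.0,
        "min": min(sizes),
        "max": max(sizes),
    }
```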
## Advanced Techniques
### Adaptive Chunking
Adjust chunk size based on document structure and content density (see the sketch below).
### Hierarchical Chunking
Create multiple levels of chunks for different retrieval scenarios.
### Query-Aware Chunking
Optimize chunk boundaries based on typical query patterns.
### Domain-Specific Splitting
Use specialized splitters for specific document types (legal, medical, technical).
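A rough sketch of adaptive chunking under an assumed density heuristic; the sentence-length threshold and chunk sizes are illustrative, not calibrated:
```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

def adaptive_chunk_size(text, dense_size=500, sparse_size=1500):
    """Dense prose (long average sentences) gets smaller chunks."""
    sentences = [s for s in text.split(".") if s.strip()]
    avg_len = sum(len(s) for s in sentences) / max(len(sentences), 1)
    return dense_size if avg_len > 120 else sparse_size

def adaptive_split(text):
    size = adaptive_chunk_size(text)
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=size,
        chunk_overlap=size // 10,  # roughly 10% overlap
    )
    return splitter.split_text(text)
```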


@@ -0,0 +1,88 @@
# Embedding Models Guide
## Overview
Embedding models convert text into numerical vectors that capture semantic meaning for similarity search in RAG systems.
## Popular Embedding Models
### 1. text-embedding-ada-002 (OpenAI)
- **Dimensions**: 1536
- **Type**: General purpose
- **Use Case**: Most applications requiring high quality embeddings
- **Performance**: Excellent balance of quality and speed
### 2. all-MiniLM-L6-v2 (Sentence Transformers)
- **Dimensions**: 384
- **Type**: Lightweight
- **Use Case**: Applications requiring fast inference
- **Performance**: Good quality, very fast
### 3. e5-large-v2
- **Dimensions**: 1024
- **Type**: High quality
- **Use Case**: Applications needing superior performance
- **Performance**: Excellent quality; multilingual variants (multilingual-e5) available
### 4. Instructor
- **Dimensions**: Variable (commonly 768)
- **Type**: Task-specific
- **Use Case**: Domain-specific applications
- **Performance**: Can be fine-tuned for specific tasks
### 5. bge-large-en-v1.5
- **Dimensions**: 1024
- **Type**: State-of-the-art
- **Use Case**: Applications requiring best possible quality
- **Performance**: SOTA performance on benchmarks
## Selection Criteria
1. **Quality vs Speed**: Balance between embedding quality and inference speed
2. **Dimension Size**: Impact on storage and retrieval performance
3. **Domain**: Specific language or domain requirements
4. **Cost**: API costs vs local deployment
5. **Batch Size**: Throughput requirements
6. **Language**: Multilingual support needs
## Usage Examples
### OpenAI Embeddings
```python
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
vector = embeddings.embed_query("Your text here")
```
### Sentence Transformers
```python
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2')
vector = model.encode("Your text here")
```
### Hugging Face Models
```python
from langchain.embeddings import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
```
## Optimization Tips
1. **Batch Processing**: Process multiple texts together for efficiency
2. **Model Quantization**: Reduce model size for faster inference
3. **Caching**: Cache embeddings for frequently used texts (see the sketch after this list)
4. **GPU Acceleration**: Use GPU for faster processing when available
5. **Model Selection**: Choose appropriate model size for your use case
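A sketch combining tips 1 and 3: batch the cache misses and key the cache by content hash. The in-memory dict cache is an assumption; production systems would use a persistent store:
```python
import hashlib
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
_cache = {}

def _key(text):
    return hashlib.sha256(text.encode("utf-8")).hexdigest()

def embed_cached(texts):
    """Embed texts in one batch, reusing cached vectors for repeats."""
    missing = [t for t in texts if _key(t) not in _cache]
    if missing:
        vectors = model.encode(missing, batch_size=64)  # batched for throughput
        for text, vec in zip(missing, vectors):
            _cache[_key(text)] = vec
    return [_cache[_key(t)] for t in texts]
```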
## Evaluation Metrics
1. **Semantic Similarity**: How well embeddings capture meaning
2. **Retrieval Performance**: Quality of retrieved documents
3. **Speed**: Inference time per document
4. **Memory Usage**: RAM requirements for the model
5. **Cost**: API costs or infrastructure requirements


@@ -0,0 +1,94 @@
# LangChain4j RAG Implementation Guide
## Overview
RAG (Retrieval-Augmented Generation) extends LLM knowledge by finding and injecting relevant information from your data into prompts before sending to the LLM.
## What is RAG?
RAG helps LLMs answer questions using domain-specific knowledge by retrieving relevant information to reduce hallucinations.
## RAG Flavors in LangChain4j
### 1. Easy RAG
Simplest way to start with minimal setup. Handles document loading, splitting, and embedding automatically.
### 2. Core RAG APIs
Modular components including:
- Document
- TextSegment
- EmbeddingModel
- EmbeddingStore
- DocumentSplitter
### 3. Advanced RAG
Complex pipelines supporting:
- Query transformation
- Multi-source retrieval
- Re-ranking with components like QueryTransformer and ContentRetriever
## RAG Stages
### 1. Indexing
Pre-process documents for efficient search
### 2. Retrieval
Find relevant content based on user queries
## Core Components
### Documents with metadata
Structured representation of your content with associated metadata for filtering and context.
### Text segments (chunks)
Smaller, manageable pieces of documents that are embedded and stored in vector databases.
### Embedding models
Convert text segments into numerical vectors for similarity search.
### Embedding stores (vector databases)
Store and efficiently retrieve embedded text segments.
### Content retrievers
Find relevant content based on user queries.
### Query transformers
Transform and optimize user queries for better retrieval.
### Content aggregators
Combine and rank retrieved content.
## Advanced Features
- Query transformation and routing
- Multiple retrievers for different data sources
- Re-ranking models for improved relevance
- Metadata filtering for targeted retrieval
- Parallel processing for performance
## Implementation Example (Easy RAG)
```java
// Load documents
List<Document> documents = FileSystemDocumentLoader.loadDocuments("/path/to/docs");
// Create embedding store
InMemoryEmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
// Ingest documents
EmbeddingStoreIngestor.ingest(documents, embeddingStore);
// Create AI service
Assistant assistant = AiServices.builder(Assistant.class)
.chatModel(chatModel)
.chatMemory(MessageWindowChatMemory.withMaxMessages(10))
.contentRetriever(EmbeddingStoreContentRetriever.from(embeddingStore))
.build();
```
## Best Practices
1. **Document Preparation**: Clean and structure documents before ingestion
2. **Chunk Size**: Balance between context preservation and retrieval precision
3. **Metadata Strategy**: Include relevant metadata for filtering and context
4. **Embedding Model Selection**: Choose models appropriate for your domain
5. **Retrieval Strategy**: Select appropriate k values and filtering criteria
6. **Evaluation**: Continuously evaluate retrieval quality and answer accuracy


@@ -0,0 +1,161 @@
# Advanced Retrieval Strategies
## Overview
Different retrieval approaches for finding relevant documents in RAG systems, each with specific strengths and use cases.
## Retrieval Approaches
### 1. Dense Retrieval
**Method**: Semantic similarity via embeddings
**Use Case**: Understanding meaning and context
**Example**: Finding documents about "machine learning" when query is "AI algorithms"
```python
from langchain.vectorstores import Chroma
vectorstore = Chroma.from_documents(chunks, embeddings)
results = vectorstore.similarity_search("query", k=5)
```
### 2. Sparse Retrieval
**Method**: Keyword matching (BM25, TF-IDF)
**Use Case**: Exact term matching and keyword-specific queries
**Example**: Finding documents containing specific technical terms
```python
from langchain.retrievers import BM25Retriever
bm25_retriever = BM25Retriever.from_documents(chunks)
bm25_retriever.k = 5
results = bm25_retriever.get_relevant_documents("query")
```
### 3. Hybrid Search
**Method**: Combine dense + sparse retrieval
**Use Case**: Balance between semantic understanding and keyword matching
```python
from langchain.retrievers import BM25Retriever, EnsembleRetriever
# Sparse retriever (BM25)
bm25_retriever = BM25Retriever.from_documents(chunks)
bm25_retriever.k = 5
# Dense retriever (embeddings)
embedding_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
# Combine with weights
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, embedding_retriever],
weights=[0.3, 0.7]
)
```
### 4. Multi-Query Retrieval
**Method**: Generate multiple query variations
**Use Case**: Complex queries that can be interpreted in multiple ways
```python
from langchain.retrievers.multi_query import MultiQueryRetriever
# Generate multiple query perspectives
retriever = MultiQueryRetriever.from_llm(
retriever=vectorstore.as_retriever(),
llm=OpenAI()
)
# Single query → multiple variations → combined results
results = retriever.get_relevant_documents("What is the main topic?")
```
### 5. HyDE (Hypothetical Document Embeddings)
**Method**: Generate hypothetical documents for better retrieval
**Use Case**: When queries are very different from document style
```python
# Generate hypothetical document based on query
hypothetical_doc = llm.generate(f"Write a document about: {query}")
# Use hypothetical doc for retrieval
results = vectorstore.similarity_search(hypothetical_doc, k=5)
```
## Advanced Retrieval Patterns
### Contextual Compression
Compress retrieved documents to only include relevant parts
```python
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
compressor = LLMChainExtractor.from_llm(llm)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=vectorstore.as_retriever()
)
```
### Parent Document Retriever
Store small chunks for retrieval, return larger chunks for context
```python
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
store = InMemoryStore()
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
retriever = ParentDocumentRetriever(
vectorstore=vectorstore,
docstore=store,
child_splitter=child_splitter,
parent_splitter=parent_splitter
)
```
## Retrieval Optimization Techniques
### 1. Metadata Filtering
Filter results based on document metadata
```python
results = vectorstore.similarity_search(
"query",
filter={"category": "technical", "date": {"$gte": "2023-01-01"}},
k=5
)
```
### 2. Maximal Marginal Relevance (MMR)
Balance relevance with diversity
```python
results = vectorstore.max_marginal_relevance_search(
"query",
k=5,
fetch_k=20,
lambda_mult=0.5 # 0=max diversity, 1=max relevance
)
```
### 3. Reranking
Improve top results with cross-encoder
```python
from sentence_transformers import CrossEncoder
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
candidates = vectorstore.similarity_search("query", k=20)
pairs = [[query, doc.page_content] for doc in candidates]
scores = reranker.predict(pairs)
reranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)[:5]
```
## Selection Guidelines
1. **Query Type**: Choose strategy based on typical query patterns
2. **Document Type**: Consider document structure and content
3. **Performance Requirements**: Balance quality vs speed
4. **Domain Knowledge**: Leverage domain-specific patterns
5. **User Expectations**: Match retrieval behavior to user expectations


@@ -0,0 +1,86 @@
# Vector Database Comparison and Configuration
## Overview
Vector databases store and efficiently retrieve document embeddings for semantic search in RAG systems.
## Popular Vector Database Options
### 1. Pinecone
- **Type**: Managed cloud service
- **Features**: Scalable, fast queries, managed infrastructure
- **Use Case**: Production applications requiring high availability
### 2. Weaviate
- **Type**: Open-source, hybrid search
- **Features**: Combines vector and keyword search, GraphQL API
- **Use Case**: Applications needing both semantic and traditional search
### 3. Milvus
- **Type**: High performance, on-premise
- **Features**: Distributed architecture, GPU acceleration
- **Use Case**: Large-scale deployments with custom infrastructure
### 4. Chroma
- **Type**: Lightweight, easy to use
- **Features**: Local deployment, simple API
- **Use Case**: Development and small-scale applications
### 5. Qdrant
- **Type**: Fast, filtered search
- **Features**: Advanced filtering, payload support
- **Use Case**: Applications requiring complex metadata filtering
### 6. FAISS
- **Type**: Meta's library, local deployment
- **Features**: High performance, CPU/GPU optimized
- **Use Case**: Research and applications needing full control
## Configuration Examples
### Pinecone Setup
```python
import pinecone
from langchain.vectorstores import Pinecone
pinecone.init(api_key="your-api-key", environment="us-west1-gcp")
index = pinecone.Index("your-index-name")
vectorstore = Pinecone(index, embeddings.embed_query, "text")
```
### Weaviate Setup
```python
import weaviate
from langchain.vectorstores import Weaviate
client = weaviate.Client("http://localhost:8080")
vectorstore = Weaviate(client, "Document", "content", embeddings)
```
### Chroma Local Setup
```python
from langchain.vectorstores import Chroma
vectorstore = Chroma(
collection_name="my_collection",
embedding_function=embeddings,
persist_directory="./chroma_db"
)
```
## Selection Criteria
1. **Scale**: Number of documents and expected query volume
2. **Performance**: Latency requirements and throughput needs
3. **Deployment**: Cloud vs on-premise preferences
4. **Features**: Filtering, hybrid search, metadata support
5. **Cost**: Budget constraints and operational overhead
6. **Maintenance**: Team expertise and available resources
## Best Practices
1. **Indexing Strategy**: Choose appropriate distance metrics (cosine, euclidean)
2. **Sharding**: Distribute data for large-scale deployments
3. **Monitoring**: Track query performance and system health
4. **Backups**: Implement regular backup procedures
5. **Security**: Secure access to sensitive data
6. **Optimization**: Tune parameters for your specific use case


@@ -0,0 +1,396 @@
---
name: aws-rds-spring-boot-integration
description: Configure AWS RDS (Aurora, MySQL, PostgreSQL) with Spring Boot applications. Use when setting up datasources, connection pooling, security, and production-ready database configuration.
category: aws
tags: [aws, rds, aurora, spring-boot, spring-data-jpa, datasource, configuration, hikari, mysql, postgresql]
version: 1.1.0
allowed-tools: Read, Write, Bash, Glob
---
# AWS RDS Spring Boot Integration
Configure AWS RDS databases (Aurora, MySQL, PostgreSQL) with Spring Boot applications for production-ready connectivity.
## When to Use This Skill
Use this skill when:
- Setting up AWS RDS Aurora with Spring Data JPA
- Configuring datasource properties for Aurora, MySQL, or PostgreSQL endpoints
- Implementing HikariCP connection pooling for RDS
- Setting up environment-specific configurations (dev/prod)
- Configuring SSL connections to AWS RDS
- Troubleshooting RDS connection issues
- Setting up database migrations with Flyway
- Integrating with AWS Secrets Manager for credential management
- Optimizing connection pool settings for RDS workloads
- Implementing read/write split with Aurora
## Prerequisites
Before starting AWS RDS Spring Boot integration:
1. AWS account with RDS access
2. Spring Boot project (3.x)
3. RDS instance created and running (Aurora/MySQL/PostgreSQL)
4. Security group configured for database access
5. Database endpoint information available
6. Database credentials secured (environment variables or Secrets Manager)
## Quick Start
### Step 1: Add Dependencies
**Maven (pom.xml):**
```xml
<dependencies>
<!-- Spring Data JPA -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-jpa</artifactId>
</dependency>
<!-- Aurora MySQL Driver -->
<dependency>
<groupId>com.mysql</groupId>
<artifactId>mysql-connector-j</artifactId>
<version>8.2.0</version>
<scope>runtime</scope>
</dependency>
<!-- Aurora PostgreSQL Driver (alternative) -->
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<scope>runtime</scope>
</dependency>
<!-- Flyway for database migrations -->
<dependency>
<groupId>org.flywaydb</groupId>
<artifactId>flyway-core</artifactId>
</dependency>
<!-- Validation -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-validation</artifactId>
</dependency>
</dependencies>
```
**Gradle (build.gradle):**
```gradle
dependencies {
implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
implementation 'org.springframework.boot:spring-boot-starter-validation'
// Aurora MySQL
runtimeOnly 'com.mysql:mysql-connector-j:8.2.0'
// Aurora PostgreSQL (alternative)
runtimeOnly 'org.postgresql:postgresql'
// Flyway
implementation 'org.flywaydb:flyway-core'
}
```
### Step 2: Basic Datasource Configuration
**application.properties (Aurora MySQL):**
```properties
# Aurora MySQL Datasource - Cluster Endpoint
spring.datasource.url=jdbc:mysql://myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:3306/devops
spring.datasource.username=admin
spring.datasource.password=${DB_PASSWORD}
spring.datasource.driver-class-name=com.mysql.cj.jdbc.Driver
# JPA/Hibernate Configuration
spring.jpa.hibernate.ddl-auto=validate
spring.jpa.show-sql=false
spring.jpa.properties.hibernate.dialect=org.hibernate.dialect.MySQL8Dialect
spring.jpa.properties.hibernate.format_sql=true
spring.jpa.open-in-view=false
# HikariCP Connection Pool
spring.datasource.hikari.maximum-pool-size=20
spring.datasource.hikari.minimum-idle=5
spring.datasource.hikari.connection-timeout=20000
spring.datasource.hikari.idle-timeout=300000
spring.datasource.hikari.max-lifetime=1200000
# Flyway Configuration
spring.flyway.enabled=true
spring.flyway.baseline-on-migrate=true
spring.flyway.locations=classpath:db/migration
```
**application.properties (Aurora PostgreSQL):**
```properties
# Aurora PostgreSQL Datasource
spring.datasource.url=jdbc:postgresql://myapp-aurora-pg-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:5432/devops
spring.datasource.username=admin
spring.datasource.password=${DB_PASSWORD}
spring.datasource.driver-class-name=org.postgresql.Driver
# JPA/Hibernate Configuration
spring.jpa.hibernate.ddl-auto=validate
spring.jpa.show-sql=false
spring.jpa.properties.hibernate.dialect=org.hibernate.dialect.PostgreSQLDialect
spring.jpa.properties.hibernate.jdbc.lob.non_contextual_creation=true
spring.jpa.open-in-view=false
```
### Step 3: Set Up Environment Variables
```bash
# Production environment variables
export DB_PASSWORD=YourStrongPassword123!
export SPRING_PROFILES_ACTIVE=prod
# For development
export SPRING_PROFILES_ACTIVE=dev
```
## Configuration Examples
### Simple Aurora Cluster (MySQL)
**application.yml:**
```yaml
spring:
application:
name: DevOps
datasource:
url: jdbc:mysql://myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:3306/devops
username: admin
password: ${DB_PASSWORD}
driver-class-name: com.mysql.cj.jdbc.Driver
hikari:
pool-name: AuroraHikariPool
maximum-pool-size: 20
minimum-idle: 5
connection-timeout: 20000
idle-timeout: 300000
max-lifetime: 1200000
leak-detection-threshold: 60000
connection-test-query: SELECT 1
jpa:
hibernate:
ddl-auto: validate
show-sql: false
open-in-view: false
properties:
hibernate:
dialect: org.hibernate.dialect.MySQL8Dialect
format_sql: true
jdbc:
batch_size: 20
order_inserts: true
order_updates: true
flyway:
enabled: true
baseline-on-migrate: true
locations: classpath:db/migration
validate-on-migrate: true
logging:
level:
org.hibernate.SQL: WARN
com.zaxxer.hikari: INFO
```
### Read/Write Split Configuration
For read-heavy workloads, use separate writer and reader datasources:
**application.properties:**
```properties
# Aurora MySQL - Writer Endpoint
spring.datasource.writer.jdbc-url=jdbc:mysql://myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:3306/devops
spring.datasource.writer.username=admin
spring.datasource.writer.password=${DB_PASSWORD}
spring.datasource.writer.driver-class-name=com.mysql.cj.jdbc.Driver
# Aurora MySQL - Reader Endpoint (Read Replicas)
spring.datasource.reader.jdbc-url=jdbc:mysql://myapp-aurora-cluster.cluster-ro-abc123xyz.us-east-1.rds.amazonaws.com:3306/devops
spring.datasource.reader.username=admin
spring.datasource.reader.password=${DB_PASSWORD}
spring.datasource.reader.driver-class-name=com.mysql.cj.jdbc.Driver
# HikariCP for Writer
spring.datasource.writer.hikari.maximum-pool-size=15
spring.datasource.writer.hikari.minimum-idle=5
# HikariCP for Reader
spring.datasource.reader.hikari.maximum-pool-size=25
spring.datasource.reader.hikari.minimum-idle=10
```
### SSL Configuration
**Aurora MySQL with SSL:**
```properties
spring.datasource.url=jdbc:mysql://myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:3306/devops?useSSL=true&requireSSL=true&verifyServerCertificate=true
```
**Aurora PostgreSQL with SSL:**
```properties
spring.datasource.url=jdbc:postgresql://myapp-aurora-pg-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:5432/devops?ssl=true&sslmode=require
```
## Environment-Specific Configuration
### Development Profile
**application-dev.properties:**
```properties
# Local MySQL for development
spring.datasource.url=jdbc:mysql://localhost:3306/devops_dev
spring.datasource.username=root
spring.datasource.password=root
# Enable DDL auto-update in development
spring.jpa.hibernate.ddl-auto=update
spring.jpa.show-sql=true
# Smaller connection pool for local dev
spring.datasource.hikari.maximum-pool-size=5
spring.datasource.hikari.minimum-idle=2
```
### Production Profile
**application-prod.properties:**
```properties
# Aurora Cluster Endpoint (Production)
spring.datasource.url=jdbc:mysql://${AURORA_ENDPOINT}:3306/${DB_NAME}
spring.datasource.username=${DB_USERNAME}
spring.datasource.password=${DB_PASSWORD}
# Validate schema only in production
spring.jpa.hibernate.ddl-auto=validate
spring.jpa.show-sql=false
spring.jpa.open-in-view=false
# Production-optimized connection pool
spring.datasource.hikari.maximum-pool-size=30
spring.datasource.hikari.minimum-idle=10
spring.datasource.hikari.connection-timeout=20000
spring.datasource.hikari.idle-timeout=300000
spring.datasource.hikari.max-lifetime=1200000
# Enable Flyway migrations
spring.flyway.enabled=true
spring.flyway.validate-on-migrate=true
```
## Database Migration Setup
Create migration files for Flyway:
```
src/main/resources/db/migration/
├── V1__create_users_table.sql
├── V2__add_phone_column.sql
└── V3__create_orders_table.sql
```
**V1__create_users_table.sql:**
```sql
CREATE TABLE users (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(100) NOT NULL,
email VARCHAR(255) NOT NULL UNIQUE,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_email (email)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
```
## Advanced Features
For advanced configuration, see the reference documents:
- [Multi-datasource, SSL, Secrets Manager integration](references/advanced-configuration.md)
- [Common issues and solutions](references/troubleshooting.md)
## Best Practices
### Connection Pool Optimization
- Use HikariCP with Aurora-optimized settings
- Set appropriate pool sizes based on Aurora instance capacity
- Configure connection timeouts for failover handling
- Enable leak detection
### Security Best Practices
- Never hardcode credentials in configuration files
- Use environment variables or AWS Secrets Manager
- Enable SSL/TLS connections
- Configure proper security group rules
- Use IAM Database Authentication when possible
### Performance Optimization
- Enable batch operations for bulk data operations
- Disable open-in-view pattern to prevent lazy loading issues
- Use appropriate indexing for Aurora queries
- Configure connection pooling for high availability
### Monitoring
- Enable Spring Boot Actuator for database metrics
- Monitor connection pool metrics
- Set up proper logging for debugging
- Configure health checks for database connectivity
## Testing
Create a health check endpoint to test database connectivity:
```java
@RestController
@RequestMapping("/api/health")
public class DatabaseHealthController {
@Autowired
private DataSource dataSource;
@GetMapping("/db-connection")
public ResponseEntity<Map<String, Object>> testDatabaseConnection() {
Map<String, Object> response = new HashMap<>();
try (Connection connection = dataSource.getConnection()) {
response.put("status", "success");
response.put("database", connection.getCatalog());
response.put("url", connection.getMetaData().getURL());
response.put("connected", true);
return ResponseEntity.ok(response);
} catch (Exception e) {
response.put("status", "failed");
response.put("error", e.getMessage());
response.put("connected", false);
return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE).body(response);
}
}
}
```
**Test with cURL:**
```bash
curl http://localhost:8080/api/health/db-connection
```
## Support
For detailed troubleshooting and advanced configuration, refer to:
- [AWS RDS Aurora Advanced Configuration](references/advanced-configuration.md)
- [AWS RDS Aurora Troubleshooting Guide](references/troubleshooting.md)
- [AWS RDS Aurora documentation](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/java_aurora_code_examples.html)
- [Spring Boot Data RDS Aurora documentation](https://www.baeldung.com/aws-aurora-rds-java)


@@ -0,0 +1,279 @@
# AWS RDS Aurora Advanced Configuration
## Read/Write Split Configuration
For applications with heavy read operations, configure separate datasources:
**Multi-Datasource Configuration Class:**
```java
@Configuration
public class AuroraDataSourceConfig {
@Primary
@Bean(name = "writerDataSource")
@ConfigurationProperties("spring.datasource.writer")
public DataSource writerDataSource() {
return DataSourceBuilder.create().build();
}
@Bean(name = "readerDataSource")
@ConfigurationProperties("spring.datasource.reader")
public DataSource readerDataSource() {
return DataSourceBuilder.create().build();
}
@Primary
@Bean(name = "writerEntityManagerFactory")
public LocalContainerEntityManagerFactoryBean writerEntityManagerFactory(
EntityManagerFactoryBuilder builder,
@Qualifier("writerDataSource") DataSource dataSource) {
return builder
.dataSource(dataSource)
.packages("com.example.domain")
.persistenceUnit("writer")
.build();
}
@Bean(name = "readerEntityManagerFactory")
public LocalContainerEntityManagerFactoryBean readerEntityManagerFactory(
EntityManagerFactoryBuilder builder,
@Qualifier("readerDataSource") DataSource dataSource) {
return builder
.dataSource(dataSource)
.packages("com.example.domain")
.persistenceUnit("reader")
.build();
}
@Primary
@Bean(name = "writerTransactionManager")
public PlatformTransactionManager writerTransactionManager(
@Qualifier("writerEntityManagerFactory") EntityManagerFactory entityManagerFactory) {
return new JpaTransactionManager(entityManagerFactory);
}
@Bean(name = "readerTransactionManager")
public PlatformTransactionManager readerTransactionManager(
@Qualifier("readerEntityManagerFactory") EntityManagerFactory entityManagerFactory) {
return new JpaTransactionManager(entityManagerFactory);
}
}
```
**Usage in Repository:**
```java
// Repositories do not switch endpoints automatically: map each repository
// package to the matching entity manager, e.g. with
// @EnableJpaRepositories(basePackages = "com.example.read",
//     entityManagerFactoryRef = "readerEntityManagerFactory",
//     transactionManagerRef = "readerTransactionManager")
@Repository
public interface UserReadRepository extends JpaRepository<User, Long> {
    // Backed by the reader endpoint via readerEntityManagerFactory
}

@Repository
public interface UserWriteRepository extends JpaRepository<User, Long> {
    // Backed by the writer endpoint via writerEntityManagerFactory
}
```
## SSL/TLS Configuration
Enable SSL for secure connections to Aurora:
**Aurora MySQL with SSL:**
```properties
spring.datasource.url=jdbc:mysql://myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:3306/devops?useSSL=true&requireSSL=true&verifyServerCertificate=true
```
**Aurora PostgreSQL with SSL:**
```properties
spring.datasource.url=jdbc:postgresql://myapp-aurora-pg-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:5432/devops?ssl=true&sslmode=require
```
**Download RDS Certificate:**
```bash
# Download RDS CA certificate
wget https://truststore.pki.rds.amazonaws.com/global/global-bundle.pem
# Configure in application
spring.datasource.url=jdbc:mysql://...?useSSL=true&trustCertificateKeyStoreUrl=file:///path/to/global-bundle.pem
```
## AWS Secrets Manager Integration
**Add AWS SDK Dependency:**
```xml
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>secretsmanager</artifactId>
<version>2.20.0</version>
</dependency>
```
**Secrets Manager Configuration:**
```java
@Configuration
public class AuroraDataSourceConfig {
@Value("${aws.secretsmanager.secret-name}")
private String secretName;
@Value("${aws.region}")
private String region;
@Bean
public DataSource dataSource() {
Map<String, String> credentials = getAuroraCredentials();
HikariConfig config = new HikariConfig();
config.setJdbcUrl(credentials.get("url"));
config.setUsername(credentials.get("username"));
config.setPassword(credentials.get("password"));
config.setMaximumPoolSize(20);
config.setMinimumIdle(5);
config.setConnectionTimeout(20000);
return new HikariDataSource(config);
}
private Map<String, String> getAuroraCredentials() {
SecretsManagerClient client = SecretsManagerClient.builder()
.region(Region.of(region))
.build();
GetSecretValueRequest request = GetSecretValueRequest.builder()
.secretId(secretName)
.build();
GetSecretValueResponse response = client.getSecretValue(request);
String secretString = response.secretString();
// Parse JSON secret
ObjectMapper mapper = new ObjectMapper();
try {
return mapper.readValue(secretString, Map.class);
} catch (Exception e) {
throw new RuntimeException("Failed to parse secret", e);
}
}
}
```
**application.properties (Secrets Manager):**
```properties
aws.secretsmanager.secret-name=prod/aurora/credentials
aws.region=us-east-1
```
## Database Migration with Flyway
### Setup Flyway
**Create Migration Directory:**
```
src/main/resources/db/migration/
├── V1__create_users_table.sql
├── V2__add_phone_column.sql
└── V3__create_orders_table.sql
```
**V1__create_users_table.sql:**
```sql
CREATE TABLE users (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(100) NOT NULL,
email VARCHAR(255) NOT NULL UNIQUE,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_email (email)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
```
**V2__add_phone_column.sql:**
```sql
ALTER TABLE users ADD COLUMN phone VARCHAR(20);
```
**Flyway Configuration:**
```properties
spring.jpa.hibernate.ddl-auto=validate
spring.flyway.enabled=true
spring.flyway.baseline-on-migrate=true
spring.flyway.locations=classpath:db/migration
spring.flyway.validate-on-migrate=true
```
## Connection Pool Optimization for Aurora
**Recommended HikariCP Settings:**
```properties
# Aurora-optimized connection pool
spring.datasource.hikari.maximum-pool-size=20
spring.datasource.hikari.minimum-idle=5
spring.datasource.hikari.connection-timeout=20000
spring.datasource.hikari.idle-timeout=300000
spring.datasource.hikari.max-lifetime=1200000
spring.datasource.hikari.leak-detection-threshold=60000
# Only needed for legacy non-JDBC4 drivers; modern MySQL/PostgreSQL drivers can omit this
spring.datasource.hikari.connection-test-query=SELECT 1
```
**Formula for Pool Size:**
```
connections = ((core_count * 2) + effective_spindle_count)
For Aurora: Use 20-30 connections per application instance
```
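As a worked example, a 4-vCPU application instance on SSD-backed storage (effective spindle count ≈ 1) lands at (4 × 2) + 1 = 9 core connections; the values below are a starting point to load-test, not a tuned recommendation:
```properties
# (4 cores * 2) + 1 = 9; rounded up slightly for headroom
spring.datasource.hikari.maximum-pool-size=10
spring.datasource.hikari.minimum-idle=2
```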
## Failover Handling
Aurora automatically handles failover between instances. Configure connection retry:
```properties
# Connection retry configuration
spring.datasource.hikari.connection-timeout=30000
spring.datasource.url=jdbc:mysql://cluster-endpoint:3306/db?failOverReadOnly=false&maxReconnects=3&connectTimeout=30000
```
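On top of driver-level reconnects, application-level retries can absorb the short window while the cluster endpoint's DNS flips to the new writer. A minimal sketch using Spring Retry (the dependency, and the `orderRepository` shown here, are assumptions for illustration):
```java
// Retry writes that fail while DNS still points at the old writer
@Retryable(value = CannotGetJdbcConnectionException.class,
           maxAttempts = 3, backoff = @Backoff(delay = 2000))
public Order saveOrder(Order order) {
    return orderRepository.save(order);
}
```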
## Read Replica Load Balancing
Aurora's reader endpoint distributes connections across the available replicas via DNS round-robin; use it for read-heavy traffic:
```properties
# Reader endpoint for read-heavy workloads
spring.datasource.reader.url=jdbc:mysql://cluster-ro-endpoint:3306/db
```
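One minimal way to aim ad-hoc read queries at that pool, assuming a second DataSource bean named `readerDataSource` is built from the property above:
```java
@Bean(name = "readerJdbcTemplate")
public JdbcTemplate readerJdbcTemplate(
        @Qualifier("readerDataSource") DataSource readerDataSource) {
    // Read-heavy queries go through the reader endpoint's connection pool
    return new JdbcTemplate(readerDataSource);
}
```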
## Performance Optimization
**Enable batch operations:**
```properties
spring.jpa.properties.hibernate.jdbc.batch_size=20
spring.jpa.properties.hibernate.order_inserts=true
spring.jpa.properties.hibernate.order_updates=true
spring.jpa.properties.hibernate.batch_versioned_data=true
```
**Disable open-in-view pattern:**
```properties
spring.jpa.open-in-view=false
```
**Production logging configuration:**
```properties
# Disable SQL logging in production
logging.level.org.hibernate.SQL=WARN
logging.level.org.springframework.jdbc=WARN
# HikariCP pool logging (drop the pool logger to INFO if DEBUG is too verbose)
logging.level.com.zaxxer.hikari=INFO
logging.level.com.zaxxer.hikari.pool=DEBUG
```
**Enable Spring Boot Actuator for metrics:**
```xml
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-actuator</artifactId>
</dependency>
```
```properties
management.endpoints.web.exposure.include=health,metrics,info
management.endpoint.health.show-details=always
```
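With the metrics endpoint exposed, HikariCP pool gauges (published under Micrometer's `hikaricp.*` names) can be checked at runtime:
```bash
curl http://localhost:8080/actuator/metrics/hikaricp.connections.active
```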

View File

@@ -0,0 +1,180 @@
# AWS RDS Aurora Troubleshooting Guide
## Common Issues and Solutions
### Connection Timeout to Aurora Cluster
**Error:** `Communications link failure` or `Connection timed out`
**Solutions:**
- Verify security group inbound rules allow traffic on port 3306 (MySQL) or 5432 (PostgreSQL)
- Check Aurora cluster endpoint is correct (cluster vs instance endpoint)
- Ensure your IP/CIDR is whitelisted in security group
- Verify VPC and subnet configuration
- Check if Aurora cluster is in the same VPC or VPC peering is configured
```bash
# Test connection from EC2/local machine
telnet myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com 3306
```
### Access Denied for User
**Error:** `Access denied for user 'admin'@'...'`
**Solutions:**
- Verify master username and password are correct
- Check if IAM authentication is required but not configured
- Reset master password in Aurora console if needed
- Verify user permissions in database
```sql
-- Check user permissions
SHOW GRANTS FOR 'admin'@'%';
```
### Database Not Found
**Error:** `Unknown database 'devops'`
**Solutions:**
- Verify initial database name was created with cluster
- Create database manually using MySQL/PostgreSQL client
- Check database name in JDBC URL matches existing database
```sql
-- Connect to Aurora and create database
CREATE DATABASE devops CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
```
### SSL Connection Issues
**Error:** `SSL connection error` or `Certificate validation failed`
**Solutions:**
```properties
# Option 1: Disable SSL verification (NOT recommended for production)
spring.datasource.url=jdbc:mysql://...?useSSL=false
# Option 2: Properly configure SSL with the RDS certificate (import
# global-bundle.pem into a Java truststore first; trustCertificateKeyStoreUrl expects a keystore)
spring.datasource.url=jdbc:mysql://...?useSSL=true&requireSSL=true&verifyServerCertificate=true&trustCertificateKeyStoreUrl=file:///path/to/rds-truststore.jks
# Option 3: Trust all certificates (NOT recommended for production)
spring.datasource.url=jdbc:mysql://...?useSSL=true&requireSSL=true&verifyServerCertificate=false
```
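To inspect the certificate chain the cluster actually presents, test from the client side first (host is the example endpoint; `-starttls mysql` needs OpenSSL 1.1.1+):
```bash
openssl s_client -connect myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:3306 \
  -starttls mysql -CAfile global-bundle.pem
```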
### Too Many Connections
**Error:** `Too many connections` or `Connection pool exhausted`
**Solutions:**
- Review Aurora instance max_connections parameter
- Optimize HikariCP pool size
- Check for connection leaks in application code
```properties
# Reduce pool size
spring.datasource.hikari.maximum-pool-size=15
spring.datasource.hikari.minimum-idle=5
# Enable leak detection
spring.datasource.hikari.leak-detection-threshold=60000
```
**Check Aurora max_connections:**
```sql
SHOW VARIABLES LIKE 'max_connections';
-- Default for Aurora: depends on instance class
-- db.r6g.large: ~1000 connections
```
### Slow Query Performance
**Error:** Queries taking longer than expected
**Solutions:**
- Enable slow query log in Aurora parameter group
- Review connection pool settings
- Check Aurora instance metrics in CloudWatch
- Optimize queries and add indexes
```properties
# Enable query logging (development only)
logging.level.org.hibernate.SQL=DEBUG
logging.level.org.hibernate.type.descriptor.sql.BasicBinder=TRACE
```
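Cluster-side, the slow query log is enabled through the DB cluster parameter group; a hedged AWS CLI sketch (the parameter group name is hypothetical):
```bash
aws rds modify-db-cluster-parameter-group \
  --db-cluster-parameter-group-name myapp-aurora-params \
  --parameters "ParameterName=slow_query_log,ParameterValue=1,ApplyMethod=immediate" \
               "ParameterName=long_query_time,ParameterValue=2,ApplyMethod=immediate"
```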
### Failover Delays
**Error:** Application freezes during Aurora failover
**Solutions:**
- Configure connection timeout appropriately
- Use cluster endpoint (not instance endpoint)
- Implement connection retry logic
```properties
spring.datasource.hikari.connection-timeout=20000
spring.datasource.hikari.validation-timeout=5000
spring.datasource.url=jdbc:mysql://...?failOverReadOnly=false&maxReconnects=3
```
## Testing Aurora Connection
### Connection Test with Spring Boot Application
**Create a Simple Test Endpoint:**
```java
@RestController
@RequestMapping("/api/health")
public class DatabaseHealthController {
@Autowired
private DataSource dataSource;
@GetMapping("/db-connection")
public ResponseEntity<Map<String, Object>> testDatabaseConnection() {
Map<String, Object> response = new HashMap<>();
try (Connection connection = dataSource.getConnection()) {
response.put("status", "success");
response.put("database", connection.getCatalog());
response.put("url", connection.getMetaData().getURL());
response.put("connected", true);
return ResponseEntity.ok(response);
} catch (Exception e) {
response.put("status", "failed");
response.put("error", e.getMessage());
response.put("connected", false);
return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE).body(response);
}
}
}
```
**Test with cURL:**
```bash
curl http://localhost:8080/api/health/db-connection
```
### Verify Aurora Connection with MySQL/PostgreSQL Client
**MySQL Client Connection:**
```bash
# Connect to Aurora MySQL cluster
mysql -h myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com \
-P 3306 \
-u admin \
-p devops
# Verify connection
SHOW DATABASES;
SELECT @@version;
SHOW VARIABLES LIKE 'aurora_version';
```
**PostgreSQL Client Connection:**
```bash
# Connect to Aurora PostgreSQL
psql -h myapp-aurora-pg-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com \
-p 5432 \
-U admin \
-d devops
# Verify connection
\l
SELECT version();
```

View File

@@ -0,0 +1,377 @@
---
name: aws-sdk-java-v2-bedrock
description: Amazon Bedrock patterns using AWS SDK for Java 2.x. Use when working with foundation models (listing, invoking), text generation, image generation, embeddings, streaming responses, or integrating generative AI with Spring Boot applications.
category: aws
tags: [aws, bedrock, java, sdk, generative-ai, foundation-models]
version: 2.0.0
allowed-tools: Read, Write, Bash
---
# AWS SDK for Java 2.x - Amazon Bedrock
## When to Use
Use this skill when:
- Listing and inspecting foundation models on Amazon Bedrock
- Invoking foundation models for text generation (Claude, Llama, Titan)
- Generating images with AI models (Stable Diffusion)
- Creating text embeddings for RAG applications
- Implementing streaming responses for real-time generation
- Working with multiple AI providers through unified API
- Integrating generative AI into Spring Boot applications
- Building AI-powered chatbots and assistants
## Overview
Amazon Bedrock provides access to foundation models from leading AI providers through a unified API. This skill covers patterns for working with various models including Claude, Llama, Titan, and Stable Diffusion using AWS SDK for Java 2.x.
## Quick Start
### Dependencies
```xml
<!-- Bedrock (model management) -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bedrock</artifactId>
</dependency>
<!-- Bedrock Runtime (model invocation) -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bedrockruntime</artifactId>
</dependency>
<!-- For JSON processing -->
<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20231013</version>
</dependency>
```
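These artifacts omit versions on the assumption that the AWS SDK BOM manages them; a minimal sketch of the corresponding import (pin whichever 2.x release you standardize on):
```xml
<dependencyManagement>
  <dependencies>
    <dependency>
      <groupId>software.amazon.awssdk</groupId>
      <artifactId>bom</artifactId>
      <version>2.36.3</version>
      <type>pom</type>
      <scope>import</scope>
    </dependency>
  </dependencies>
</dependencyManagement>
```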
### Basic Client Setup
```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrock.BedrockClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
// Model management client
BedrockClient bedrockClient = BedrockClient.builder()
.region(Region.US_EAST_1)
.build();
// Model invocation client
BedrockRuntimeClient bedrockRuntimeClient = BedrockRuntimeClient.builder()
.region(Region.US_EAST_1)
.build();
```
## Core Patterns
### Model Discovery
```java
import software.amazon.awssdk.services.bedrock.model.*;
import java.util.List;
public List<FoundationModelSummary> listFoundationModels(BedrockClient bedrockClient) {
return bedrockClient.listFoundationModels().modelSummaries();
}
```
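The listing can also be filtered server-side via the request's `by*` parameters, for example by provider:
```java
// Only models published by Anthropic
List<FoundationModelSummary> anthropicModels = bedrockClient
    .listFoundationModels(r -> r.byProvider("anthropic"))
    .modelSummaries();
```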
### Model Invocation
```java
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.bedrockruntime.model.*;
import org.json.JSONObject;
public String invokeModel(BedrockRuntimeClient client, String modelId, String prompt) {
JSONObject payload = createPayload(modelId, prompt);
InvokeModelResponse response = client.invokeModel(request -> request
.modelId(modelId)
.body(SdkBytes.fromUtf8String(payload.toString())));
return extractTextFromResponse(modelId, response.body().asUtf8String());
}
private JSONObject createPayload(String modelId, String prompt) {
if (modelId.startsWith("anthropic.claude")) {
return new JSONObject()
.put("anthropic_version", "bedrock-2023-05-31")
.put("max_tokens", 1000)
.put("messages", new JSONObject[]{
new JSONObject().put("role", "user").put("content", prompt)
});
} else if (modelId.startsWith("amazon.titan")) {
return new JSONObject()
.put("inputText", prompt)
.put("textGenerationConfig", new JSONObject()
.put("maxTokenCount", 512)
.put("temperature", 0.7));
    } else if (modelId.startsWith("meta.llama")) {
        // [INST] tags are the Llama 2 chat format; Llama 3 models use the
        // <|begin_of_text|>/<|start_header_id|> header-token template instead
        return new JSONObject()
            .put("prompt", "[INST] " + prompt + " [/INST]")
            .put("max_gen_len", 512)
            .put("temperature", 0.7);
}
throw new IllegalArgumentException("Unsupported model: " + modelId);
}
```
### Streaming Responses
```java
// Streaming requires the async client; the synchronous BedrockRuntimeClient
// does not expose invokeModelWithResponseStream
public void streamResponse(BedrockRuntimeAsyncClient client, String modelId, String prompt) {
    JSONObject payload = createPayload(modelId, prompt);
    InvokeModelWithResponseStreamRequest streamRequest =
        InvokeModelWithResponseStreamRequest.builder()
            .modelId(modelId)
            .body(SdkBytes.fromUtf8String(payload.toString()))
            .build();
    InvokeModelWithResponseStreamResponseHandler handler =
        InvokeModelWithResponseStreamResponseHandler.builder()
            .onEventStream(stream -> stream.subscribe(event -> {
                if (event instanceof PayloadPart payloadPart) {
                    String chunk = payloadPart.bytes().asUtf8String();
                    processChunk(modelId, chunk);
                }
            }))
            .build();
    // Block until the stream finishes
    client.invokeModelWithResponseStream(streamRequest, handler).join();
}
```
### Text Embeddings
```java
import org.json.JSONArray;

public double[] createEmbeddings(BedrockRuntimeClient client, String text) {
String modelId = "amazon.titan-embed-text-v1";
JSONObject payload = new JSONObject().put("inputText", text);
InvokeModelResponse response = client.invokeModel(request -> request
.modelId(modelId)
.body(SdkBytes.fromUtf8String(payload.toString())));
JSONObject responseBody = new JSONObject(response.body().asUtf8String());
JSONArray embeddingArray = responseBody.getJSONArray("embedding");
double[] embeddings = new double[embeddingArray.length()];
for (int i = 0; i < embeddingArray.length(); i++) {
embeddings[i] = embeddingArray.getDouble(i);
}
return embeddings;
}
```
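Downstream, embeddings are usually compared with cosine similarity; a small self-contained helper (plain math, no SDK assumptions):
```java
public static double cosineSimilarity(double[] a, double[] b) {
    double dot = 0, normA = 0, normB = 0;
    for (int i = 0; i < a.length; i++) {
        dot += a[i] * b[i];   // accumulate dot product
        normA += a[i] * a[i]; // and squared norms
        normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}
```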
### Spring Boot Integration
```java
@Configuration
public class BedrockConfiguration {
@Bean
public BedrockClient bedrockClient() {
return BedrockClient.builder()
.region(Region.US_EAST_1)
.build();
}
@Bean
public BedrockRuntimeClient bedrockRuntimeClient() {
return BedrockRuntimeClient.builder()
.region(Region.US_EAST_1)
.build();
}
}
@Service
public class BedrockAIService {
private final BedrockRuntimeClient bedrockRuntimeClient;
@Value("${bedrock.default-model-id:anthropic.claude-sonnet-4-5-20250929-v1:0}")
private String defaultModelId;
public BedrockAIService(BedrockRuntimeClient bedrockRuntimeClient) {
this.bedrockRuntimeClient = bedrockRuntimeClient;
}
public String generateText(String prompt) {
return generateText(prompt, defaultModelId);
}
public String generateText(String prompt, String modelId) {
    try {
        Map<String, Object> payload = createPayload(modelId, prompt);
        // writeValueAsString throws a checked JsonProcessingException,
        // so serialize before building the request
        String payloadJson = new ObjectMapper().writeValueAsString(payload);
        InvokeModelResponse response = bedrockRuntimeClient.invokeModel(
            request -> request
                .modelId(modelId)
                .body(SdkBytes.fromUtf8String(payloadJson)));
        return extractTextFromResponse(modelId, response.body().asUtf8String());
    } catch (JsonProcessingException e) {
        throw new IllegalStateException("Failed to serialize request payload", e);
    }
}
}
```
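A thin REST wiring for the service above, to show where it slots into a typical application (paths and types are illustrative):
```java
@RestController
@RequestMapping("/api/ai")
public class GenerationController {

    private final BedrockAIService aiService;

    public GenerationController(BedrockAIService aiService) {
        this.aiService = aiService;
    }

    @PostMapping("/generate")
    public ResponseEntity<String> generate(@RequestBody String prompt) {
        return ResponseEntity.ok(aiService.generateText(prompt));
    }
}
```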
## Basic Usage Example
```java
BedrockRuntimeClient client = BedrockRuntimeClient.builder()
.region(Region.US_EAST_1)
.build();
String prompt = "Explain quantum computing in simple terms";
String response = invokeModel(client, "anthropic.claude-sonnet-4-5-20250929-v1:0", prompt);
System.out.println(response);
```
## Best Practices
### Model Selection
- **Claude 4.5 Sonnet**: Best for complex reasoning, analysis, and creative tasks
- **Claude 4.5 Haiku**: Fast and affordable for real-time applications
- **Claude 3.7 Sonnet**: Earlier-generation model with extended-thinking support
- **Llama 3.x / Llama 4**: Open-weight alternatives, good for general tasks
- **Titan**: AWS native, cost-effective for simple text generation
### Performance Optimization
- Reuse client instances (don't create new clients for each request)
- Use async clients for I/O-bound workloads (see the sketch after this list)
- Implement streaming for long responses
- Cache foundation model lists
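A sketch of the async-client bullet above: `invokeModel` on the async client returns a `CompletableFuture`, so the call never blocks a request thread:
```java
BedrockRuntimeAsyncClient asyncClient = BedrockRuntimeAsyncClient.builder()
    .region(Region.US_EAST_1)
    .build();

// Titan-style payload; attach a callback instead of waiting on the result
String payload = new JSONObject().put("inputText", "Hello, Bedrock").toString();
asyncClient.invokeModel(request -> request
        .modelId("amazon.titan-text-express-v1")
        .body(SdkBytes.fromUtf8String(payload)))
    .thenAccept(response -> System.out.println(response.body().asUtf8String()));
```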
### Security
- Never log sensitive prompt data
- Use IAM roles for authentication (never access keys)
- Implement rate limiting for public applications
- Sanitize user inputs to prevent prompt injection
### Error Handling
- Implement retry logic for throttling (exponential backoff)
- Handle model-specific validation errors
- Validate responses before processing
- Use proper exception handling for different error types
### Cost Optimization
- Use appropriate max_tokens limits
- Choose cost-effective models for simple tasks
- Cache embeddings when possible
- Monitor usage and set budget alerts
## Common Model IDs
```java
// Claude Models
public static final String CLAUDE_SONNET_4_5 = "anthropic.claude-sonnet-4-5-20250929-v1:0";
public static final String CLAUDE_HAIKU_4_5 = "anthropic.claude-haiku-4-5-20251001-v1:0";
public static final String CLAUDE_OPUS_4_1 = "anthropic.claude-opus-4-1-20250805-v1:0";
public static final String CLAUDE_3_7_SONNET = "anthropic.claude-3-7-sonnet-20250219-v1:0";
public static final String CLAUDE_OPUS_4 = "anthropic.claude-opus-4-20250514-v1:0";
public static final String CLAUDE_SONNET_4 = "anthropic.claude-sonnet-4-20250514-v1:0";
public static final String CLAUDE_3_5_SONNET_V2 = "anthropic.claude-3-5-sonnet-20241022-v2:0";
public static final String CLAUDE_3_5_HAIKU = "anthropic.claude-3-5-haiku-20241022-v1:0";
public static final String CLAUDE_3_OPUS = "anthropic.claude-3-opus-20240229-v1:0";
// Llama Models
public static final String LLAMA_3_3_70B = "meta.llama3-3-70b-instruct-v1:0";
public static final String LLAMA_3_2_90B = "meta.llama3-2-90b-instruct-v1:0";
public static final String LLAMA_3_2_11B = "meta.llama3-2-11b-instruct-v1:0";
public static final String LLAMA_3_2_3B = "meta.llama3-2-3b-instruct-v1:0";
public static final String LLAMA_3_2_1B = "meta.llama3-2-1b-instruct-v1:0";
public static final String LLAMA_4_MAV_17B = "meta.llama4-maverick-17b-instruct-v1:0";
public static final String LLAMA_4_SCOUT_17B = "meta.llama4-scout-17b-instruct-v1:0";
public static final String LLAMA_3_1_405B = "meta.llama3-1-405b-instruct-v1:0";
public static final String LLAMA_3_1_70B = "meta.llama3-1-70b-instruct-v1:0";
public static final String LLAMA_3_1_8B = "meta.llama3-1-8b-instruct-v1:0";
public static final String LLAMA_3_70B = "meta.llama3-70b-instruct-v1:0";
public static final String LLAMA_3_8B = "meta.llama3-8b-instruct-v1:0";
// Amazon Titan Models
public static final String TITAN_TEXT_EXPRESS = "amazon.titan-text-express-v1";
public static final String TITAN_TEXT_LITE = "amazon.titan-text-lite-v1";
public static final String TITAN_EMBEDDINGS = "amazon.titan-embed-text-v1";
public static final String TITAN_IMAGE_GENERATOR = "amazon.titan-image-generator-v1";
// Stable Diffusion
public static final String STABLE_DIFFUSION_XL = "stability.stable-diffusion-xl-v1";
// Mistral AI Models
public static final String MISTRAL_LARGE_2407 = "mistral.mistral-large-2407-v1:0";
public static final String MISTRAL_LARGE_2402 = "mistral.mistral-large-2402-v1:0";
public static final String MISTRAL_SMALL_2402 = "mistral.mistral-small-2402-v1:0";
public static final String MISTRAL_PIXTRAL_2502 = "mistral.pixtral-large-2502-v1:0";
public static final String MISTRAL_MIXTRAL_8X7B = "mistral.mixtral-8x7b-instruct-v0:1";
public static final String MISTRAL_7B = "mistral.mistral-7b-instruct-v0:2";
// Amazon Nova Models
public static final String NOVA_PREMIER = "amazon.nova-premier-v1:0";
public static final String NOVA_PRO = "amazon.nova-pro-v1:0";
public static final String NOVA_LITE = "amazon.nova-lite-v1:0";
public static final String NOVA_MICRO = "amazon.nova-micro-v1:0";
public static final String NOVA_CANVAS = "amazon.nova-canvas-v1:0";
public static final String NOVA_REEL = "amazon.nova-reel-v1:1";
// Other Models
public static final String COHERE_COMMAND = "cohere.command-text-v14";
public static final String DEEPSEEK_R1 = "deepseek.r1-v1:0";
public static final String DEEPSEEK_V3_1 = "deepseek.v3-v1:0";
```
## Examples
See the [examples directory](examples/) for comprehensive usage patterns.
## Advanced Topics
See the [Advanced Topics](references/advanced-topics.md) for:
- Multi-model service patterns
- Advanced error handling with retries
- Batch processing strategies
- Performance optimization techniques
- Custom response parsing
## Model Reference
See the [Model Reference](references/model-reference.md) for:
- Detailed model specifications
- Payload/response formats for each provider
- Performance characteristics
- Model selection guidelines
- Configuration templates
## Testing Strategies
See the [Testing Strategies](references/testing-strategies.md) for:
- Unit testing with mocked clients
- Integration testing with LocalStack
- Performance testing
- Streaming response testing
- Test data management
## Related Skills
- `aws-sdk-java-v2-core` - Core AWS SDK patterns
- `langchain4j-ai-services-patterns` - LangChain4j integration
- `spring-boot-dependency-injection` - Spring DI patterns
- `spring-boot-test-patterns` - Spring testing patterns
## References
- [AWS Bedrock User Guide](references/aws-bedrock-user-guide.md)
- [AWS SDK for Java 2.x Documentation](references/aws-sdk-java-bedrock-api.md)
- [Bedrock API Reference](references/aws-bedrock-api-reference.md)
- [AWS SDK Examples](references/aws-sdk-examples.md)
- [Official AWS Examples](bedrock_code_examples.md)
- [Supported Models](bedrock_models_supported.md)
- [Runtime Examples](bedrock_runtime_code_examples.md)

View File

@@ -0,0 +1,249 @@
# Amazon Bedrock examples using SDK for Java 2.x
The following code examples show you how to perform actions and implement common scenarios by using
the AWS SDK for Java 2.x with Amazon Bedrock.
_Actions_ are code excerpts from larger programs and must be run in context. While actions show you how to call individual service functions, you can see actions in context in their related scenarios.
Each example includes a link to the complete source code, where you can find
instructions on how to set up and run the code in context.
## Actions
The following code example shows how to use `GetFoundationModel`.
**SDK for Java 2.x**
###### Note
There's more on GitHub. Find the complete example and learn how to set up and run in the [AWS Code Examples Repository](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/example_code/bedrock#code-examples).
Get details about a foundation model using the synchronous Amazon Bedrock client.
```java
/**
* Get details about an Amazon Bedrock foundation model.
*
* @param bedrockClient The service client for accessing Amazon Bedrock.
* @param modelIdentifier The model identifier.
* @return An object containing the foundation model's details.
*/
public static FoundationModelDetails getFoundationModel(BedrockClient bedrockClient, String modelIdentifier) {
try {
GetFoundationModelResponse response = bedrockClient.getFoundationModel(
r -> r.modelIdentifier(modelIdentifier)
);
FoundationModelDetails model = response.modelDetails();
System.out.println(" Model ID: " + model.modelId());
System.out.println(" Model ARN: " + model.modelArn());
System.out.println(" Model Name: " + model.modelName());
System.out.println(" Provider Name: " + model.providerName());
System.out.println(" Lifecycle status: " + model.modelLifecycle().statusAsString());
System.out.println(" Input modalities: " + model.inputModalities());
System.out.println(" Output modalities: " + model.outputModalities());
System.out.println(" Supported customizations: " + model.customizationsSupported());
System.out.println(" Supported inference types: " + model.inferenceTypesSupported());
System.out.println(" Response streaming supported: " + model.responseStreamingSupported());
return model;
} catch (ValidationException e) {
throw new IllegalArgumentException(e.getMessage());
} catch (SdkException e) {
System.err.println(e.getMessage());
throw new RuntimeException(e);
}
}
```
Get details about a foundation model using the asynchronous Amazon Bedrock client.
```java
/**
* Get details about an Amazon Bedrock foundation model.
*
* @param bedrockClient The async service client for accessing Amazon Bedrock.
* @param modelIdentifier The model identifier.
* @return An object containing the foundation model's details.
*/
public static FoundationModelDetails getFoundationModel(BedrockAsyncClient bedrockClient, String modelIdentifier) {
try {
CompletableFuture<GetFoundationModelResponse> future = bedrockClient.getFoundationModel(
r -> r.modelIdentifier(modelIdentifier)
);
FoundationModelDetails model = future.get().modelDetails();
System.out.println(" Model ID: " + model.modelId());
System.out.println(" Model ARN: " + model.modelArn());
System.out.println(" Model Name: " + model.modelName());
System.out.println(" Provider Name: " + model.providerName());
System.out.println(" Lifecycle status: " + model.modelLifecycle().statusAsString());
System.out.println(" Input modalities: " + model.inputModalities());
System.out.println(" Output modalities: " + model.outputModalities());
System.out.println(" Supported customizations: " + model.customizationsSupported());
System.out.println(" Supported inference types: " + model.inferenceTypesSupported());
System.out.println(" Response streaming supported: " + model.responseStreamingSupported());
return model;
} catch (ExecutionException e) {
if (e.getMessage().contains("ValidationException")) {
throw new IllegalArgumentException(e.getMessage());
} else {
System.err.println(e.getMessage());
throw new RuntimeException(e);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
System.err.println(e.getMessage());
throw new RuntimeException(e);
}
}
```
- For API details, see
[GetFoundationModel](https://docs.aws.amazon.com/goto/SdkForJavaV2/bedrock-2023-04-20/GetFoundationModel)
in _AWS SDK for Java 2.x API Reference_.
The following code example shows how to use `ListFoundationModels`.
**SDK for Java 2.x**
###### Note
There's more on GitHub. Find the complete example and learn how to set up and run in the [AWS Code Examples Repository](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/example_code/bedrock#code-examples).
List the available Amazon Bedrock foundation models using the synchronous Amazon Bedrock client.
```java
/**
* Lists Amazon Bedrock foundation models that you can use.
* You can filter the results with the request parameters.
*
* @param bedrockClient The service client for accessing Amazon Bedrock.
* @return A list of objects containing the foundation models' details
*/
public static List<FoundationModelSummary> listFoundationModels(BedrockClient bedrockClient) {
try {
ListFoundationModelsResponse response = bedrockClient.listFoundationModels(r -> {});
List<FoundationModelSummary> models = response.modelSummaries();
if (models.isEmpty()) {
System.out.println("No available foundation models in " + region.toString());
} else {
for (FoundationModelSummary model : models) {
System.out.println("Model ID: " + model.modelId());
System.out.println("Provider: " + model.providerName());
System.out.println("Name: " + model.modelName());
System.out.println();
}
}
return models;
} catch (SdkClientException e) {
System.err.println(e.getMessage());
throw new RuntimeException(e);
}
}
```
List the available Amazon Bedrock foundation models using the asynchronous Amazon Bedrock client.
```java
/**
* Lists Amazon Bedrock foundation models that you can use.
* You can filter the results with the request parameters.
*
* @param bedrockClient The async service client for accessing Amazon Bedrock.
* @return A list of objects containing the foundation models' details
*/
public static List<FoundationModelSummary> listFoundationModels(BedrockAsyncClient bedrockClient) {
try {
CompletableFuture<ListFoundationModelsResponse> future = bedrockClient.listFoundationModels(r -> {});
List<FoundationModelSummary> models = future.get().modelSummaries();
if (models.isEmpty()) {
System.out.println("No available foundation models in " + region.toString());
} else {
for (FoundationModelSummary model : models) {
System.out.println("Model ID: " + model.modelId());
System.out.println("Provider: " + model.providerName());
System.out.println("Name: " + model.modelName());
System.out.println();
}
}
return models;
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
System.err.println(e.getMessage());
throw new RuntimeException(e);
} catch (ExecutionException e) {
System.err.println(e.getMessage());
throw new RuntimeException(e);
}
}
```
- For API details, see
[ListFoundationModels](https://docs.aws.amazon.com/goto/SdkForJavaV2/bedrock-2023-04-20/ListFoundationModels)
in _AWS SDK for Java 2.x API Reference_.

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,274 @@
# Advanced Model Patterns
## Model-Specific Configuration
### Claude Models Configuration
```java
// Claude 3 Sonnet
public String invokeClaude3Sonnet(BedrockRuntimeClient client, String prompt) {
String modelId = "anthropic.claude-3-sonnet-20240229-v1:0";
JSONObject payload = new JSONObject()
.put("anthropic_version", "bedrock-2023-05-31")
.put("max_tokens", 1000)
.put("temperature", 0.7)
.put("top_p", 1.0)
.put("messages", new JSONObject[]{
new JSONObject()
.put("role", "user")
.put("content", prompt)
});
InvokeModelResponse response = client.invokeModel(request -> request
.modelId(modelId)
.body(SdkBytes.fromUtf8String(payload.toString())));
JSONObject responseBody = new JSONObject(response.body().asUtf8String());
return responseBody.getJSONArray("content")
.getJSONObject(0)
.getString("text");
}
// Claude 3 Haiku (faster, cheaper)
public String invokeClaude3Haiku(BedrockRuntimeClient client, String prompt) {
    String modelId = "anthropic.claude-3-haiku-20240307-v1:0";
    JSONObject payload = new JSONObject()
        .put("anthropic_version", "bedrock-2023-05-31")
        .put("max_tokens", 400)
        .put("messages", new JSONObject[]{
            new JSONObject()
                .put("role", "user")
                .put("content", prompt)
        });
    // Same invocation and parsing pattern as the Sonnet example above
    InvokeModelResponse response = client.invokeModel(request -> request
        .modelId(modelId)
        .body(SdkBytes.fromUtf8String(payload.toString())));
    return new JSONObject(response.body().asUtf8String())
        .getJSONArray("content")
        .getJSONObject(0)
        .getString("text");
}
```
### Llama Models Configuration
```java
// Llama 3 70B
public String invokeLlama3_70B(BedrockRuntimeClient client, String prompt) {
String modelId = "meta.llama3-70b-instruct-v1:0";
JSONObject payload = new JSONObject()
.put("prompt", prompt)
.put("max_gen_len", 512)
.put("temperature", 0.7)
.put("top_p", 0.9)
.put("stop", new String[]{"[INST]", "[/INST]"}); // Custom stop tokens
InvokeModelResponse response = client.invokeModel(request -> request
.modelId(modelId)
.body(SdkBytes.fromUtf8String(payload.toString())));
JSONObject responseBody = new JSONObject(response.body().asUtf8String());
return responseBody.getString("generation");
}
```
## Multi-Model Service Layer
```java
@Service
public class MultiModelService {
private final BedrockRuntimeClient bedrockRuntimeClient;
private final ObjectMapper objectMapper;
public MultiModelService(BedrockRuntimeClient bedrockRuntimeClient,
ObjectMapper objectMapper) {
this.bedrockRuntimeClient = bedrockRuntimeClient;
this.objectMapper = objectMapper;
}
public String invokeModel(String modelId, String prompt, Map<String, Object> additionalParams) {
    Map<String, Object> payload = createModelPayload(modelId, prompt, additionalParams);
    try {
        // Serialize outside the request lambda: writeValueAsString throws a
        // checked exception that a Consumer lambda cannot propagate
        String payloadJson = objectMapper.writeValueAsString(payload);
        InvokeModelResponse response = bedrockRuntimeClient.invokeModel(
            request -> request
                .modelId(modelId)
                .body(SdkBytes.fromUtf8String(payloadJson)));
        return extractResponseContent(modelId, response.body().asUtf8String());
    } catch (Exception e) {
        throw new RuntimeException("Model invocation failed: " + e.getMessage(), e);
    }
}
private Map<String, Object> createModelPayload(String modelId, String prompt,
Map<String, Object> additionalParams) {
Map<String, Object> payload = new HashMap<>();
if (modelId.startsWith("anthropic.claude")) {
payload.put("anthropic_version", "bedrock-2023-05-31");
payload.put("messages", List.of(Map.of("role", "user", "content", prompt)));
// Add common parameters with defaults
payload.putIfAbsent("max_tokens", 1000);
payload.putIfAbsent("temperature", 0.7);
} else if (modelId.startsWith("meta.llama")) {
payload.put("prompt", prompt);
payload.putIfAbsent("max_gen_len", 512);
payload.putIfAbsent("temperature", 0.7);
} else if (modelId.startsWith("amazon.titan")) {
payload.put("inputText", prompt);
payload.putIfAbsent("textGenerationConfig",
Map.of("maxTokenCount", 512, "temperature", 0.7));
}
// Add additional parameters
if (additionalParams != null) {
payload.putAll(additionalParams);
}
return payload;
}
}
```
## Advanced Error Handling
```java
@Component
public class BedrockErrorHandler {
    // Retry on ThrottlingException itself and rethrow it unchanged so Spring
    // Retry can apply the backoff (retrying only SdkClientException would
    // never fire for server-side throttling)
    @Retryable(value = ThrottlingException.class, maxAttempts = 3,
               backoff = @Backoff(delay = 1000, multiplier = 2))
    public String invokeWithRetry(BedrockRuntimeClient client, String modelId,
                                  String payloadJson) {
        try {
            InvokeModelResponse response = client.invokeModel(request -> request
                .modelId(modelId)
                .body(SdkBytes.fromUtf8String(payloadJson)));
            return response.body().asUtf8String();
        } catch (ThrottlingException e) {
            throw e; // propagate so @Retryable can retry with backoff
        } catch (ValidationException e) {
            throw new IllegalArgumentException("Invalid request: " + e.getMessage(), e);
        } catch (SdkException e) {
            throw new RuntimeException("AWS SDK error: " + e.getMessage(), e);
        }
    }
}
```
## Batch Processing
```java
@Service
public class BedrockBatchService {
public List<String> processBatch(BedrockRuntimeClient client, String modelId,
List<String> prompts) {
return prompts.parallelStream()
.map(prompt -> invokeModelWithTimeout(client, modelId, prompt, 30))
.collect(Collectors.toList());
}
private String invokeModelWithTimeout(BedrockRuntimeClient client, String modelId,
String prompt, int timeoutSeconds) {
ExecutorService executor = Executors.newSingleThreadExecutor();
Future<String> future = executor.submit(() -> {
            // Illustrative generic payload; real payloads are model-specific
            // (see the createPayload patterns in the main skill document)
            JSONObject payload = new JSONObject()
                .put("prompt", prompt)
                .put("max_tokens", 500);
InvokeModelResponse response = client.invokeModel(request -> request
.modelId(modelId)
.body(SdkBytes.fromUtf8String(payload.toString())));
return response.body().asUtf8String();
});
try {
return future.get(timeoutSeconds, TimeUnit.SECONDS);
} catch (TimeoutException e) {
future.cancel(true);
throw new RuntimeException("Model invocation timed out");
} catch (Exception e) {
throw new RuntimeException("Batch processing error", e);
} finally {
executor.shutdown();
}
}
}
```
## Model Performance Optimization
```java
@Configuration
public class BedrockOptimizationConfig {
@Bean
public BedrockRuntimeClient optimizedBedrockRuntimeClient() {
return BedrockRuntimeClient.builder()
.region(Region.US_EAST_1)
.overrideConfiguration(ClientOverrideConfiguration.builder()
.apiCallTimeout(Duration.ofSeconds(30))
.apiCallAttemptTimeout(Duration.ofSeconds(20))
.build())
.httpClient(ApacheHttpClient.builder()
.connectionTimeout(Duration.ofSeconds(10))
.socketTimeout(Duration.ofSeconds(30))
.build())
.build();
}
}
```
## Custom Response Parsing
```java
public class BedrockResponseParser {
public static TextResponse parseTextResponse(String modelId, String responseBody) {
try {
switch (getModelProvider(modelId)) {
case ANTHROPIC:
return parseAnthropicResponse(responseBody);
case META:
return parseMetaResponse(responseBody);
case AMAZON:
return parseAmazonResponse(responseBody);
default:
throw new IllegalArgumentException("Unsupported model: " + modelId);
}
} catch (Exception e) {
throw new ResponseParsingException("Failed to parse response for model: " + modelId, e);
}
}
private static TextResponse parseAnthropicResponse(String responseBody) throws JSONException {
JSONObject json = new JSONObject(responseBody);
JSONArray content = json.getJSONArray("content");
String text = content.getJSONObject(0).getString("text");
int usage = json.getJSONObject("usage").getInt("input_tokens");
return new TextResponse(text, usage, "anthropic");
}
private static TextResponse parseMetaResponse(String responseBody) throws JSONException {
JSONObject json = new JSONObject(responseBody);
String text = json.getString("generation");
// Note: Meta doesn't provide token usage in basic response
return new TextResponse(text, 0, "meta");
}
private enum ModelProvider {
ANTHROPIC, META, AMAZON
}
public record TextResponse(String content, int tokensUsed, String provider) {}
}
```

View File

@@ -0,0 +1,372 @@
# Advanced Amazon Bedrock Topics
This document covers advanced patterns and topics for working with Amazon Bedrock using AWS SDK for Java 2.x.
## Multi-Model Service Pattern
Create a service that can handle multiple foundation models with unified interfaces.
```java
@Service
public class MultiModelAIService {
private final BedrockRuntimeClient bedrockRuntimeClient;
public MultiModelAIService(BedrockRuntimeClient bedrockRuntimeClient) {
this.bedrockRuntimeClient = bedrockRuntimeClient;
}
public GenerationResult generate(GenerationRequest request) {
String modelId = request.getModelId();
String prompt = request.getPrompt();
switch (getModelProvider(modelId)) {
case ANTHROPIC:
return generateWithAnthropic(modelId, prompt, request.getConfig());
case AMAZON:
return generateWithAmazon(modelId, prompt, request.getConfig());
case META:
return generateWithMeta(modelId, prompt, request.getConfig());
default:
throw new IllegalArgumentException("Unsupported model provider: " + modelId);
}
}
private GenerationProvider getModelProvider(String modelId) {
if (modelId.startsWith("anthropic.")) return GenerationProvider.ANTHROPIC;
if (modelId.startsWith("amazon.")) return GenerationProvider.AMAZON;
if (modelId.startsWith("meta.")) return GenerationProvider.META;
throw new IllegalArgumentException("Unknown provider for model: " + modelId);
}
}
```
## Advanced Error Handling with Retries
Implement robust error handling with exponential backoff:
```java
import java.time.Duration;
import java.util.Set;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.core.retry.backoff.EqualJitterBackoffStrategy;
import software.amazon.awssdk.core.retry.conditions.RetryOnExceptionsCondition;

public class BedrockWithRetry {

    private final BedrockRuntimeClient client;

    public BedrockWithRetry() {
        // Retry throttling errors inside the SDK with jittered exponential backoff
        RetryPolicy retryPolicy = RetryPolicy.builder()
            .numRetries(3)
            .retryCondition(RetryOnExceptionsCondition.create(Set.of(ThrottlingException.class)))
            .backoffStrategy(EqualJitterBackoffStrategy.builder()
                .baseDelay(Duration.ofMillis(500))
                .maxBackoffTime(Duration.ofSeconds(10))
                .build())
            .build();
        this.client = BedrockRuntimeClient.builder()
            .overrideConfiguration(ClientOverrideConfiguration.builder()
                .retryPolicy(retryPolicy)
                .build())
            .build();
    }

    public String invokeModelWithRetry(String modelId, String payload) {
        try {
            InvokeModelRequest request = InvokeModelRequest.builder()
                .modelId(modelId)
                .body(SdkBytes.fromUtf8String(payload))
                .build();
            InvokeModelResponse response = client.invokeModel(request);
            return response.body().asUtf8String();
        } catch (ThrottlingException e) {
            // Reached only after the SDK has exhausted its retries
            throw new BedrockThrottledException("Rate limit exceeded for model: " + modelId, e);
        } catch (ValidationException e) {
            throw new BedrockValidationException("Invalid request for model: " + modelId, e);
        }
    }
}
```
## Batch Processing Strategies
Process multiple requests efficiently:
```java
@Service
public class BatchGenerationService {
private final BedrockRuntimeClient bedrockRuntimeClient;
public BatchGenerationService(BedrockRuntimeClient bedrockRuntimeClient) {
this.bedrockRuntimeClient = bedrockRuntimeClient;
}
public List<BatchResult> processBatch(List<BatchRequest> requests) {
// Process in parallel
return requests.parallelStream()
.map(this::processSingleRequest)
.collect(Collectors.toList());
}
private BatchResult processSingleRequest(BatchRequest request) {
try {
InvokeModelRequest modelRequest = InvokeModelRequest.builder()
.modelId(request.getModelId())
.body(SdkBytes.fromUtf8String(request.getPayload()))
.build();
InvokeModelResponse response = bedrockRuntimeClient.invokeModel(modelRequest);
return BatchResult.success(
request.getRequestId(),
response.body().asUtf8String()
);
} catch (Exception e) {
return BatchResult.failure(request.getRequestId(), e.getMessage());
}
}
}
```
## Performance Optimization Techniques
### Connection Pooling
```java
import java.time.Duration;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;

public class BedrockClientFactory {
    public static BedrockRuntimeClient createOptimizedClient() {
        // Apache HTTP client with a shared connection pool
        SdkHttpClient httpClient = ApacheHttpClient.builder()
            .maxConnections(50)
            .socketTimeout(Duration.ofSeconds(30))
            .connectionTimeout(Duration.ofSeconds(30))
            .build();
        return BedrockRuntimeClient.builder()
            .region(Region.US_EAST_1)
            .httpClient(httpClient)
            .build();
    }
}
```
### Response Caching
```java
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
@Service
public class CachedAIService {
private final BedrockRuntimeClient bedrockRuntimeClient;
private final Cache<String, String> responseCache;
public CachedAIService(BedrockRuntimeClient bedrockRuntimeClient) {
this.bedrockRuntimeClient = bedrockRuntimeClient;
this.responseCache = Caffeine.newBuilder()
.maximumSize(1000)
.expireAfterWrite(1, TimeUnit.HOURS)
.build();
}
public String generateText(String prompt, String modelId) {
String cacheKey = modelId + ":" + prompt.hashCode();
return responseCache.get(cacheKey, key -> {
String payload = createPayload(modelId, prompt);
InvokeModelRequest request = InvokeModelRequest.builder()
.modelId(modelId)
.body(SdkBytes.fromUtf8String(payload))
.build();
InvokeModelResponse response = bedrockRuntimeClient.invokeModel(request);
return response.body().asUtf8String();
});
}
}
```
## Custom Response Parsing
Create specialized parsers for different model responses:
```java
public interface ResponseParser {
String parse(String responseJson);
}
public class AnthropicResponseParser implements ResponseParser {
@Override
public String parse(String responseJson) {
try {
JSONObject jsonResponse = new JSONObject(responseJson);
return jsonResponse.getJSONArray("content")
.getJSONObject(0)
.getString("text");
} catch (Exception e) {
throw new ResponseParsingException("Failed to parse Anthropic response", e);
}
}
}
public class AmazonTitanResponseParser implements ResponseParser {
@Override
public String parse(String responseJson) {
try {
JSONObject jsonResponse = new JSONObject(responseJson);
return jsonResponse.getJSONArray("results")
.getJSONObject(0)
.getString("outputText");
} catch (Exception e) {
throw new ResponseParsingException("Failed to parse Amazon Titan response", e);
}
}
}
public class LlamaResponseParser implements ResponseParser {
@Override
public String parse(String responseJson) {
try {
JSONObject jsonResponse = new JSONObject(responseJson);
return jsonResponse.getString("generation");
} catch (Exception e) {
throw new ResponseParsingException("Failed to parse Llama response", e);
}
}
}
```
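A parser can then be selected by model-ID prefix; a minimal factory over the three implementations above:
```java
public final class ResponseParsers {

    private ResponseParsers() {}

    public static ResponseParser forModel(String modelId) {
        if (modelId.startsWith("anthropic.")) return new AnthropicResponseParser();
        if (modelId.startsWith("amazon.titan")) return new AmazonTitanResponseParser();
        if (modelId.startsWith("meta.llama")) return new LlamaResponseParser();
        throw new IllegalArgumentException("No parser registered for model: " + modelId);
    }
}
```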
## Metrics and Monitoring
Implement comprehensive monitoring:
```java
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import io.micrometer.core.instrument.Counter;
@Service
public class MonitoredAIService {
private final BedrockRuntimeClient bedrockRuntimeClient;
private final Timer generationTimer;
private final Counter errorCounter;
public MonitoredAIService(BedrockRuntimeClient bedrockRuntimeClient,
MeterRegistry meterRegistry) {
this.bedrockRuntimeClient = bedrockRuntimeClient;
this.generationTimer = Timer.builder("bedrock.generation.time")
.description("Time spent generating text with Bedrock")
.register(meterRegistry);
this.errorCounter = Counter.builder("bedrock.generation.errors")
.description("Number of generation errors")
.register(meterRegistry);
}
public String generateText(String prompt, String modelId) {
return generationTimer.record(() -> {
try {
String payload = createPayload(modelId, prompt);
InvokeModelRequest request = InvokeModelRequest.builder()
.modelId(modelId)
.body(SdkBytes.fromUtf8String(payload))
.build();
InvokeModelResponse response = bedrockRuntimeClient.invokeModel(request);
return response.body().asUtf8String();
} catch (Exception e) {
errorCounter.increment();
throw new GenerationException("Failed to generate text", e);
}
});
}
}
```
## Advanced Configuration Management
```java
@Configuration
@ConfigurationProperties(prefix = "bedrock")
public class AdvancedBedrockConfiguration {
private String defaultRegion = "us-east-1";
private int maxRetries = 3;
private Duration timeout = Duration.ofSeconds(30);
private boolean enableMetrics = true;
private int maxCacheSize = 1000;
private Duration cacheExpireAfter = Duration.ofHours(1);
@Bean
@Primary
public BedrockRuntimeClient bedrockRuntimeClient() {
BedrockRuntimeClient.Builder builder = BedrockRuntimeClient.builder()
.region(Region.of(defaultRegion));
        if (enableMetrics) {
            // Publish SDK client-side metrics to CloudWatch (requires the
            // software.amazon.awssdk:cloudwatch-metric-publisher dependency)
            builder.overrideConfiguration(c -> c.addMetricPublisher(
                CloudWatchMetricPublisher.create()));
        }
return builder.build();
}
// Getters and setters
}
```
## Streaming Response Handling
Advanced streaming with proper backpressure handling:
```java
@Service
public class StreamingAIService {

    // Streaming requires the async client; the sync client does not
    // expose invokeModelWithResponseStream
    private final BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient;

    public StreamingAIService(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) {
        this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient;
    }

    public Flux<String> streamResponse(String modelId, String prompt) {
        InvokeModelWithResponseStreamRequest request =
            InvokeModelWithResponseStreamRequest.builder()
                .modelId(modelId)
                .body(SdkBytes.fromUtf8String(createPayload(modelId, prompt)))
                .build();
        return Flux.create(sink -> {
            InvokeModelWithResponseStreamResponseHandler handler =
                InvokeModelWithResponseStreamResponseHandler.builder()
                    .onEventStream(stream -> stream.subscribe(event -> {
                        if (event instanceof PayloadPart payloadPart) {
                            processChunk(payloadPart.bytes().asUtf8String(), sink);
                        }
                    }))
                    .onComplete(sink::complete)
                    .onError(e -> sink.error(new StreamingException("Stream failed", e)))
                    .build();
            bedrockRuntimeAsyncClient.invokeModelWithResponseStream(request, handler);
        });
    }
private void processChunk(String chunk, FluxSink<String> sink) {
try {
JSONObject chunkJson = new JSONObject(chunk);
if (chunkJson.getString("type").equals("content_block_delta")) {
String text = chunkJson.getJSONObject("delta").getString("text");
sink.next(text);
}
} catch (Exception e) {
sink.error(new ChunkProcessingException("Failed to process chunk", e));
}
}
}
```

View File

@@ -0,0 +1,18 @@
<!DOCTYPE HTML><html xmlns="http://www.w3.org/1999/xhtml"><head><meta http-equiv="Content-Type" content="text/html; charset=UTF-8"><title>Amazon Bedrock</title><meta xmlns="" name="subtitle" content="API Reference"><meta xmlns="" name="abstract" content="Details about operations and parameters in the Amazon Bedrock API Reference"><meta http-equiv="refresh" content="10;URL=welcome.html"><script type="text/javascript"><!--
var myDefaultPage = "welcome.html";
var myPage = document.location.search.substr(1);
var myHash = document.location.hash;
if (myPage == null || myPage.length == 0) {
myPage = myDefaultPage;
} else {
var docfile = myPage.match(/[^=\;\/?:\s]+\.html/);
if (docfile == null) {
myPage = myDefaultPage;
} else {
myPage = docfile + myHash;
}
}
self.location.replace(myPage);
--></script></head><body></body></html>

View File

@@ -0,0 +1,18 @@
<!DOCTYPE HTML><html xmlns="http://www.w3.org/1999/xhtml"><head><meta http-equiv="Content-Type" content="text/html; charset=UTF-8"><title>Amazon Bedrock</title><meta xmlns="" name="subtitle" content="User Guide"><meta xmlns="" name="abstract" content="User Guide for the Amazon Bedrock service."><meta http-equiv="refresh" content="10;URL=what-is-bedrock.html"><script type="text/javascript"><!--
var myDefaultPage = "what-is-bedrock.html";
var myPage = document.location.search.substr(1);
var myHash = document.location.hash;
if (myPage == null || myPage.length == 0) {
myPage = myDefaultPage;
} else {
var docfile = myPage.match(/[^=\;\/?:\s]+\.html/);
if (docfile == null) {
myPage = myDefaultPage;
} else {
myPage = docfile + myHash;
}
}
self.location.replace(myPage);
--></script></head><body></body></html>

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,148 @@
<!DOCTYPE HTML>
<html lang="en">
<head>
<!-- Generated by javadoc (23) on Tue Oct 28 00:04:26 UTC 2025 -->
<title>software.amazon.awssdk.services.bedrock (AWS SDK for Java - 2.36.3)</title>
<meta name="viewport" content="width=device-width, initial-scale=1">
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
<meta name="dc.created" content="2025-10-28">
<meta name="description" content="declaration: package: software.amazon.awssdk.services.bedrock">
<meta name="generator" content="javadoc/PackageWriter">
<link rel="stylesheet" type="text/css" href="../../../../../resource-files/jquery-ui.min.css" title="Style">
<link rel="stylesheet" type="text/css" href="../../../../../resource-files/stylesheet.css" title="Style">
<link rel="stylesheet" type="text/css" href="../../../../../resource-files/aws-sdk-java-v2-javadoc.css" title="Style">
<script type="text/javascript" src="../../../../../script-files/script.js"></script>
<script type="text/javascript" src="../../../../../script-files/jquery-3.7.1.min.js"></script>
<script type="text/javascript" src="../../../../../script-files/jquery-ui.min.js"></script>
</head>
<body class="package-declaration-page">
<script type="text/javascript">const pathtoroot = "../../../../../";
loadScripts(document, 'script');</script>
<noscript>
<div>JavaScript is disabled on your browser.</div>
</noscript>
<header role="banner">
<nav role="navigation">
<!-- ========= START OF TOP NAVBAR ======= -->
<div class="top-nav" id="navbar-top">
<div class="nav-content">
<div class="nav-menu-button"><button id="navbar-toggle-button" aria-controls="navbar-top" aria-expanded="false" aria-label="Toggle navigation links"><span class="nav-bar-toggle-icon">&nbsp;</span><span class="nav-bar-toggle-icon">&nbsp;</span><span class="nav-bar-toggle-icon">&nbsp;</span></button></div>
<div class="skip-nav"><a href="#skip-navbar-top" title="Skip navigation links">Skip navigation links</a></div>
<ul id="navbar-top-firstrow" class="nav-list" title="Navigation">
<li><a href="../../../../../index.html">Overview</a></li>
<li class="nav-bar-cell1-rev">Package</li>
<li><a href="../../../../../index-all.html">Index</a></li>
<li><a href="../../../../../search.html">Search</a></li>
<li><a href="../../../../../help-doc.html#package">Help</a></li>
</ul>
<div class="about-language"><h2>AWS SDK for Java API Reference - 2.36.3</h2></div>
</div>
</div>
<div class="sub-nav">
<div class="nav-content">
<ol class="sub-nav-list">
<li><a href="package-summary.html" class="current-selection">software.amazon.awssdk.services.bedrock</a></li>
</ol>
<div class="nav-list-search">
<input type="text" id="search-input" disabled placeholder="Search" aria-label="Search in documentation" autocomplete="off">
<input type="reset" id="reset-search" disabled value="Reset">
</div>
</div>
</div>
<!-- ========= END OF TOP NAVBAR ========= -->
<span class="skip-nav" id="skip-navbar-top"></span></nav>
</header>
<div class="main-grid">
<nav role="navigation" class="toc" aria-label="Table of contents">
<div class="toc-header">Contents</div>
<button class="hide-sidebar"><span>Hide sidebar&nbsp;</span>&#10094;</button><button class="show-sidebar">&#10095;<span>&nbsp;Show sidebar</span></button>
<ol class="toc-list">
<li><a href="#" tabindex="0">Description</a></li>
<li><a href="#related-package-summary" tabindex="0">Related Packages</a></li>
<li><a href="#class-summary" tabindex="0">Classes and Interfaces</a></li>
</ol>
</nav>
<main role="main">
<div class="header">
<h1 title="Package software.amazon.awssdk.services.bedrock" class="title">Package software.amazon.awssdk.services.bedrock</h1>
</div>
<hr>
<div class="horizontal-scroll">
<div class="package-signature">package <span class="element-name">software.amazon.awssdk.services.bedrock</span></div>
<section class="package-description" id="package-description">
<div class="block"><p>
Describes the API operations for creating, managing, fine-tuning, and evaluating Amazon Bedrock models.
</p></div>
</section>
</div>
<section class="summary">
<ul class="summary-list">
<li>
<div id="related-package-summary">
<div class="caption"><span>Related Packages</span></div>
<div class="summary-table two-column-summary">
<div class="table-header col-first">Package</div>
<div class="table-header col-last">Description</div>
<div class="col-first even-row-color"><a href="endpoints/package-summary.html">software.amazon.awssdk.services.bedrock.endpoints</a></div>
<div class="col-last even-row-color">&nbsp;</div>
<div class="col-first odd-row-color"><a href="internal/package-summary.html">software.amazon.awssdk.services.bedrock.internal</a></div>
<div class="col-last odd-row-color">&nbsp;</div>
<div class="col-first even-row-color"><a href="model/package-summary.html">software.amazon.awssdk.services.bedrock.model</a></div>
<div class="col-last even-row-color">&nbsp;</div>
<div class="col-first odd-row-color"><a href="paginators/package-summary.html">software.amazon.awssdk.services.bedrock.paginators</a></div>
<div class="col-last odd-row-color">&nbsp;</div>
<div class="col-first even-row-color"><a href="transform/package-summary.html">software.amazon.awssdk.services.bedrock.transform</a></div>
<div class="col-last even-row-color">&nbsp;</div>
</div>
</div>
</li>
<li>
<div id="class-summary">
<div class="table-tabs" role="tablist" aria-orientation="horizontal"><button id="class-summary-tab0" role="tab" aria-selected="true" aria-controls="class-summary.tabpanel" tabindex="0" onkeydown="switchTab(event)" onclick="show('class-summary', 'class-summary', 2)" class="active-table-tab">All Classes and Interfaces</button><button id="class-summary-tab1" role="tab" aria-selected="false" aria-controls="class-summary.tabpanel" tabindex="-1" onkeydown="switchTab(event)" onclick="show('class-summary', 'class-summary-tab1', 2)" class="table-tab">Interfaces</button><button id="class-summary-tab2" role="tab" aria-selected="false" aria-controls="class-summary.tabpanel" tabindex="-1" onkeydown="switchTab(event)" onclick="show('class-summary', 'class-summary-tab2', 2)" class="table-tab">Classes</button></div>
<div id="class-summary.tabpanel" role="tabpanel" aria-labelledby="class-summary-tab0">
<div class="summary-table two-column-summary">
<div class="table-header col-first">Class</div>
<div class="table-header col-last">Description</div>
<div class="col-first even-row-color class-summary class-summary-tab1"><a href="BedrockAsyncClient.html" title="interface in software.amazon.awssdk.services.bedrock">BedrockAsyncClient</a></div>
<div class="col-last even-row-color class-summary class-summary-tab1">
<div class="block">Service client for accessing Amazon Bedrock asynchronously.</div>
</div>
<div class="col-first odd-row-color class-summary class-summary-tab1"><a href="BedrockAsyncClientBuilder.html" title="interface in software.amazon.awssdk.services.bedrock">BedrockAsyncClientBuilder</a></div>
<div class="col-last odd-row-color class-summary class-summary-tab1">
<div class="block">A builder for creating an instance of <a href="BedrockAsyncClient.html" title="interface in software.amazon.awssdk.services.bedrock"><code>BedrockAsyncClient</code></a>.</div>
</div>
<div class="col-first even-row-color class-summary class-summary-tab1"><a href="BedrockBaseClientBuilder.html" title="interface in software.amazon.awssdk.services.bedrock">BedrockBaseClientBuilder</a>&lt;B extends <a href="BedrockBaseClientBuilder.html" title="interface in software.amazon.awssdk.services.bedrock">BedrockBaseClientBuilder</a>&lt;B,<wbr>C&gt;,<wbr>C&gt;</div>
<div class="col-last even-row-color class-summary class-summary-tab1">
<div class="block">This includes configuration specific to Amazon Bedrock that is supported by both <a href="BedrockClientBuilder.html" title="interface in software.amazon.awssdk.services.bedrock"><code>BedrockClientBuilder</code></a> and
<a href="BedrockAsyncClientBuilder.html" title="interface in software.amazon.awssdk.services.bedrock"><code>BedrockAsyncClientBuilder</code></a>.</div>
</div>
<div class="col-first odd-row-color class-summary class-summary-tab1"><a href="BedrockClient.html" title="interface in software.amazon.awssdk.services.bedrock">BedrockClient</a></div>
<div class="col-last odd-row-color class-summary class-summary-tab1">
<div class="block">Service client for accessing Amazon Bedrock.</div>
</div>
<div class="col-first even-row-color class-summary class-summary-tab1"><a href="BedrockClientBuilder.html" title="interface in software.amazon.awssdk.services.bedrock">BedrockClientBuilder</a></div>
<div class="col-last even-row-color class-summary class-summary-tab1">
<div class="block">A builder for creating an instance of <a href="BedrockClient.html" title="interface in software.amazon.awssdk.services.bedrock"><code>BedrockClient</code></a>.</div>
</div>
<div class="col-first odd-row-color class-summary class-summary-tab2"><a href="BedrockServiceClientConfiguration.html" title="class in software.amazon.awssdk.services.bedrock">BedrockServiceClientConfiguration</a></div>
<div class="col-last odd-row-color class-summary class-summary-tab2">
<div class="block">Class to expose the service client settings to the user.</div>
</div>
<div class="col-first even-row-color class-summary class-summary-tab1"><a href="BedrockServiceClientConfiguration.Builder.html" title="interface in software.amazon.awssdk.services.bedrock">BedrockServiceClientConfiguration.Builder</a></div>
<div class="col-last even-row-color class-summary class-summary-tab1">
<div class="block">A builder for creating a <a href="BedrockServiceClientConfiguration.html" title="class in software.amazon.awssdk.services.bedrock"><code>BedrockServiceClientConfiguration</code></a></div>
</div>
</div>
</div>
</div>
</li>
</ul>
</section>
</main>
</div>
</body>
</html>

View File

@@ -0,0 +1,340 @@
# Model Reference
## Supported Foundation Models
### Amazon Models
#### Amazon Titan Text
**Model ID:** `amazon.titan-text-express-v1`
- **Description:** High-quality text generation model
- **Context Window:** Up to 8K tokens
- **Languages:** English, Spanish, French, German, Italian, Portuguese
**Payload Format:**
```json
{
"inputText": "Your prompt here",
"textGenerationConfig": {
"maxTokenCount": 512,
"temperature": 0.7,
"topP": 0.9
}
}
```
**Response Format:**
```json
{
"results": [{
"outputText": "Generated text"
}]
}
```
#### Amazon Titan Text Lite
**Model ID:** `amazon.titan-text-lite-v1`
- **Description:** Cost-effective text generation model
- **Context Window:** Up to 4K tokens
- **Use Case:** Simple text generation tasks
#### Amazon Titan Embeddings
**Model ID:** `amazon.titan-embed-text-v1`
- **Description:** High-quality text embeddings
- **Context Window:** 8K tokens
- **Output:** 1536-dimensional vector
**Payload Format:**
```json
{
"inputText": "Your text here"
}
```
**Response Format:**
```json
{
"embedding": [0.1, -0.2, 0.3, ...]
}
```
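A minimal invocation sketch for the formats above, assuming the AWS SDK for Java 2.x `BedrockRuntimeClient` and the org.json library (not shown elsewhere in this file):
```java
import org.json.JSONArray;
import org.json.JSONObject;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;

// Embed a string with Titan Embeddings and read back the vector
try (BedrockRuntimeClient runtime = BedrockRuntimeClient.create()) {
    String body = new JSONObject().put("inputText", "Your text here").toString();
    String json = runtime.invokeModel(r -> r
            .modelId("amazon.titan-embed-text-v1")
            .body(SdkBytes.fromUtf8String(body)))
        .body().asUtf8String();
    JSONArray embedding = new JSONObject(json).getJSONArray("embedding");
    double[] vector = new double[embedding.length()];
    for (int i = 0; i < vector.length; i++) {
        vector[i] = embedding.getDouble(i);
    }
    System.out.println("Embedding dimensions: " + vector.length);
}
```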
#### Amazon Titan Image Generator
**Model ID:** `amazon.titan-image-generator-v1`
- **Description:** High-quality image generation
- **Image Size:** 512x512, 1024x1024
- **Use Case:** Text-to-image generation
**Payload Format:**
```json
{
"taskType": "TEXT_IMAGE",
"textToImageParams": {
"text": "Your description"
},
"imageGenerationConfig": {
"numberOfImages": 1,
"quality": "standard",
"cfgScale": 8.0,
"height": 512,
"width": 512,
"seed": 12345
}
}
```
### Anthropic Models
#### Claude 3.5 Sonnet
**Model ID:** `anthropic.claude-3-5-sonnet-20241022-v2:0`
- **Description:** High-performance model for complex reasoning, analysis, and creative tasks
- **Context Window:** 200K tokens
- **Languages:** Multiple languages supported
- **Use Case:** Code generation, complex analysis, creative writing, research
- **Features:** Tool use, function calling, JSON mode
**Payload Format:**
```json
{
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 1000,
"messages": [{
"role": "user",
"content": "Your message"
}]
}
```
**Response Format:**
```json
{
"content": [{
"text": "Response content"
}],
"usage": {
"input_tokens": 10,
"output_tokens": 20
}
}
```
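The same pattern works for the Claude payloads; a sketch under the same SDK and JSON-library assumptions as the embedding example above:
```java
import org.json.JSONArray;
import org.json.JSONObject;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;

// Send one user message and pull the reply text out of the content array
try (BedrockRuntimeClient runtime = BedrockRuntimeClient.create()) {
    String body = new JSONObject()
        .put("anthropic_version", "bedrock-2023-05-31")
        .put("max_tokens", 1000)
        .put("messages", new JSONArray().put(new JSONObject()
            .put("role", "user")
            .put("content", "Your message")))
        .toString();
    JSONObject result = new JSONObject(runtime.invokeModel(r -> r
            .modelId("anthropic.claude-3-5-sonnet-20241022-v2:0")
            .body(SdkBytes.fromUtf8String(body)))
        .body().asUtf8String());
    String reply = result.getJSONArray("content").getJSONObject(0).getString("text");
    System.out.println(reply);
}
```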
#### Claude 3.5 Haiku
**Model ID:** `anthropic.claude-3-5-haiku-20241022-v1:0`
- **Description:** Fast and affordable model for real-time applications
- **Context Window:** 200K tokens
- **Use Case:** Real-time applications, chatbots, quick responses
- **Features:** Tool use, function calling, JSON mode
#### Claude 3 Opus
**Model ID:** `anthropic.claude-3-opus-20240229-v1:0`
- **Description:** Most capable model
- **Context Window:** 200K tokens
- **Use Case:** Complex reasoning, analysis
#### Claude 3 Sonnet (Legacy)
**Model ID:** `anthropic.claude-3-sonnet-20240229-v1:0`
- **Description:** Previous generation model
- **Context Window:** 200K tokens
- **Use Case:** General purpose applications
### Meta Models
#### Llama 3.1 70B
**Model ID:** `meta.llama3-1-70b-instruct-v1:0`
- **Description:** Latest generation large open-source model
- **Context Window:** 128K tokens
- **Use Case:** General purpose instruction following, complex reasoning
- **Features:** Improved instruction following, larger context window
#### Llama 3.1 8B
**Model ID:** `meta.llama3-1-8b-instruct-v1:0`
- **Description:** Latest generation small fast model
- **Context Window:** 8K tokens
- **Use Case:** Fast inference, lightweight applications
#### Llama 3 70B
**Model ID:** `meta.llama3-70b-instruct-v1:0`
- **Description:** Previous generation large open-source model
- **Context Window:** 8K tokens
- **Use Case:** General purpose instruction following
**Payload Format:**
```json
{
"prompt": "[INST] Your prompt here [/INST]",
"max_gen_len": 512,
"temperature": 0.7,
"top_p": 0.9
}
```
**Response Format:**
```json
{
"generation": "Generated text"
}
```
#### Llama 3 8B
**Model ID:** `meta.llama3-8b-instruct-v1:0`
- **Description:** Smaller, faster version
- **Context Window:** 8K tokens
- **Use Case:** Fast inference, lightweight applications
### Stability AI Models
#### Stable Diffusion XL
**Model ID:** `stability.stable-diffusion-xl-v1`
- **Description:** High-quality image generation
- **Image Size:** Up to 1024x1024
- **Use Case:** Text-to-image generation, art creation
**Payload Format:**
```json
{
"text_prompts": [{
"text": "Your description"
}],
"style_preset": "photographic",
"seed": 12345,
"cfg_scale": 10,
"steps": 50
}
```
**Response Format:**
```json
{
"artifacts": [{
"base64": "base64-encoded-image-data",
"finishReason": "SUCCESS"
}]
}
```
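Because the artifact arrives base64-encoded, decoding is the one extra step; a sketch (method and file names are illustrative):
```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Base64;
import org.json.JSONObject;

// Persist the first artifact of an SDXL response body as a PNG file
static void saveFirstArtifact(String responseJson, Path target) throws IOException {
    JSONObject artifact = new JSONObject(responseJson)
        .getJSONArray("artifacts").getJSONObject(0);
    if ("SUCCESS".equals(artifact.getString("finishReason"))) {
        byte[] png = Base64.getDecoder().decode(artifact.getString("base64"));
        Files.write(target, png);
    }
}
```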
### Other Models
#### Cohere Command
**Model ID:** `cohere.command-text-v14`
- **Description:** Text generation model
- **Context Window:** 4K tokens
- **Use Case:** Content generation, summarization
#### Mistral Models
**Model ID:** `mistral.mistral-7b-instruct-v0:2`
- **Description:** High-performing open-source model
- **Context Window:** 32K tokens
- **Use Case:** Instruction following, code generation
**Model ID:** `mistral.mixtral-8x7b-instruct-v0:1`
- **Description:** Mixture of experts model
- **Context Window:** 32K tokens
- **Use Case:** Complex reasoning tasks
## Model Selection Guide
### Use Case Recommendations
| Use Case | Recommended Models | Notes |
|----------|-------------------|-------|
| **General Chat/Chatbots** | Claude 3.5 Haiku, Llama 3 8B | Fast response times |
| **Content Creation** | Claude 3.5 Sonnet, Cohere | Creative, coherent outputs |
| **Code Generation** | Claude 3.5 Sonnet, Llama 3.1 70B | Strong code understanding and generation |
| **Analysis & Reasoning** | Claude 3 Opus, Claude 3.5 Sonnet | Complex reasoning |
| **Real-time Applications** | Claude 3.5 Haiku, Titan Lite | Fast inference |
| **Cost-sensitive Apps** | Titan Lite, Claude 3.5 Haiku | Lower cost per token |
| **High Quality** | Claude 3 Opus, Claude 3.5 Sonnet | Premium quality |
### Performance Characteristics
| Model | Speed | Cost | Quality | Context Window |
|-------|-------|------|---------|----------------|
| Claude 3 Opus | Slow | High | Excellent | 200K |
| Claude 3.5 Sonnet | Medium | Medium | Excellent | 200K |
| Claude 3.5 Haiku | Fast | Low | Good | 200K |
| Claude 3 Sonnet (Legacy) | Medium | Medium | Good | 200K |
| Llama 3.1 70B | Medium | Medium | Good | 128K |
| Llama 3.1 8B | Fast | Low | Fair | 128K |
| Llama 3 70B | Medium | Medium | Good | 8K |
| Llama 3 8B | Fast | Low | Fair | 8K |
| Titan Express | Fast | Medium | Good | 8K |
| Titan Lite | Fast | Low | Fair | 4K |
## Model Comparison Matrix
| Feature | Claude 3 | Llama 3 | Titan | Stability |
|---------|----------|---------|-------|-----------|
| **Streaming** | ✅ | ✅ | ✅ | ❌ |
| **Tool Use** | ✅ | ❌ | ❌ | ❌ |
| **Image Generation** | ❌ | ❌ | ✅ | ✅ |
| **Embeddings** | ❌ | ❌ | ✅ | ❌ |
| **Multiple Languages** | ✅ | ✅ | ✅ | ✅ |
| **Context Window** | 200K | 8K (128K for 3.1) | 4K-8K | N/A |
| **Open Source** | ❌ | ✅ | ❌ | ✅ |
## Model Configuration Templates
### Text Generation Template
```java
import org.json.JSONArray;
import org.json.JSONObject;

private static JSONObject createTextGenerationPayload(String modelId, String prompt) {
    JSONObject payload = new JSONObject();
    if (modelId.startsWith("anthropic.claude")) {
        payload.put("anthropic_version", "bedrock-2023-05-31");
        payload.put("max_tokens", 1000);
        payload.put("messages", new JSONArray().put(new JSONObject()
            .put("role", "user")
            .put("content", prompt)));
    } else if (modelId.startsWith("meta.llama")) {
        // Llama 3 chat template ([INST] ... [/INST] is the Llama 2 format)
        payload.put("prompt", "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
            + prompt + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
        payload.put("max_gen_len", 512);
    } else if (modelId.startsWith("amazon.titan")) {
        payload.put("inputText", prompt);
        payload.put("textGenerationConfig", new JSONObject()
            .put("maxTokenCount", 512)
            .put("temperature", 0.7));
    }
    return payload;
}
```
### Image Generation Template
```java
private static JSONObject createImageGenerationPayload(String modelId, String prompt) {
    JSONObject payload = new JSONObject();
    if (modelId.equals("amazon.titan-image-generator-v1")) {
        payload.put("taskType", "TEXT_IMAGE");
        payload.put("textToImageParams", new JSONObject().put("text", prompt));
        payload.put("imageGenerationConfig", new JSONObject()
            .put("numberOfImages", 1)
            .put("quality", "standard")
            .put("height", 512)
            .put("width", 512));
    } else if (modelId.equals("stability.stable-diffusion-xl-v1")) {
        payload.put("text_prompts", new JSONArray().put(new JSONObject().put("text", prompt)));
        payload.put("style_preset", "photographic");
        payload.put("steps", 50);
        payload.put("cfg_scale", 10);
    }
    return payload;
}
```
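A usage sketch wiring the text template into an actual call, assuming the AWS SDK for Java 2.x `BedrockRuntimeClient`:
```java
import org.json.JSONObject;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;

String modelId = "anthropic.claude-3-5-sonnet-20241022-v2:0";
JSONObject payload = createTextGenerationPayload(modelId, "Summarize this document.");
try (BedrockRuntimeClient runtime = BedrockRuntimeClient.create()) {
    String raw = runtime.invokeModel(r -> r
            .modelId(modelId)
            .body(SdkBytes.fromUtf8String(payload.toString())))
        .body().asUtf8String();
    // Parse according to the per-model response formats documented above
    System.out.println(raw);
}
```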

View File

@@ -0,0 +1,121 @@
# Model ID Lookup Guide
This document provides quick lookup for the most commonly used model IDs in Amazon Bedrock.
## Text Generation Models
### Claude (Anthropic)
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Claude 4.5 Sonnet | `anthropic.claude-sonnet-4-5-20250929-v1:0` | Latest high-performance model | Complex reasoning, coding, creative tasks |
| Claude 4.5 Haiku | `anthropic.claude-haiku-4-5-20251001-v1:0` | Latest fast model | Real-time applications, chatbots |
| Claude 3.7 Sonnet | `anthropic.claude-3-7-sonnet-20250219-v1:0` | Hybrid reasoning with optional extended thinking | High-stakes decisions, complex analysis |
| Claude Opus 4.1 | `anthropic.claude-opus-4-1-20250805-v1:0` | Most powerful creative | Advanced creative tasks |
| Claude 3.5 Sonnet v2 | `anthropic.claude-3-5-sonnet-20241022-v2:0` | High-performance model | General use, coding |
| Claude 3.5 Haiku | `anthropic.claude-3-5-haiku-20241022-v1:0` | Fast and affordable | Real-time applications |
### Llama (Meta)
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Llama 3.3 70B | `meta.llama3-3-70b-instruct-v1:0` | Latest generation | Complex reasoning, general use |
| Llama 3.2 90B | `meta.llama3-2-90b-instruct-v1:0` | Large context | Long context tasks |
| Llama 3.2 11B | `meta.llama3-2-11b-instruct-v1:0` | Medium model | Balanced performance |
| Llama 3.2 3B | `meta.llama3-2-3b-instruct-v1:0` | Small model | Fast inference |
| Llama 3.2 1B | `meta.llama3-2-1b-instruct-v1:0` | Ultra-fast | Quick responses |
| Llama 3.1 70B | `meta.llama3-1-70b-instruct-v1:0` | Previous gen | General use |
| Llama 3.1 8B | `meta.llama3-1-8b-instruct-v1:0` | Fast small model | Lightweight applications |
### Mistral AI
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Mistral Large 2407 | `mistral.mistral-large-2407-v1:0` | Latest large model | Complex reasoning |
| Mistral Large 2402 | `mistral.mistral-large-2402-v1:0` | Previous large model | General use |
| Mistral Pixtral 2502 | `mistral.pixtral-large-2502-v1:0` | Multimodal | Text + image understanding |
| Mistral 7B | `mistral.mistral-7b-instruct-v0:2` | Small fast model | Quick responses |
### Amazon
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Titan Text Express | `amazon.titan-text-express-v1` | Fast text generation | Quick responses |
| Titan Text Lite | `amazon.titan-text-lite-v1` | Cost-effective | Budget-sensitive apps |
| Titan Embeddings | `amazon.titan-embed-text-v1` | Text embeddings | Semantic search |
### Cohere
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Command R+ | `cohere.command-r-plus-v1:0` | High performance | Complex tasks |
| Command R | `cohere.command-r-v1:0` | General purpose | Standard use cases |
## Image Generation Models
### Stability AI
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Stable Diffusion 3.5 Large | `stability.sd3-5-large-v1:0` | Latest image gen | High-quality images |
| Stable Diffusion XL | `stability.stable-diffusion-xl-v1` | Previous generation | General image generation |
### Amazon Nova
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Nova Canvas | `amazon.nova-canvas-v1:0` | Image generation | Creative images |
| Nova Reel | `amazon.nova-reel-v1:1` | Video generation | Video content |
## Embedding Models
### Amazon
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Titan Embeddings | `amazon.titan-embed-text-v1` | Text embeddings | Semantic search |
| Titan Embeddings V2 | `amazon.titan-embed-text-v2:0` | Improved embeddings | Better accuracy |
### Cohere
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Embed English | `cohere.embed-english-v3` | English embeddings | English content |
| Embed Multilingual | `cohere.embed-multilingual-v3` | Multi-language | International use |
## Selection Guide
### By Speed
1. **Fastest**: Llama 3.2 1B, Claude 4.5 Haiku, Titan Lite
2. **Fast**: Mistral 7B, Llama 3.2 3B
3. **Medium**: Claude 3.5 Sonnet, Llama 3.2 11B
4. **Slow**: Claude 4.5 Sonnet, Llama 3.3 70B
### By Quality
1. **Highest**: Claude 4.5 Sonnet, Claude 3.7 Sonnet, Claude Opus 4.1
2. **High**: Claude 3.5 Sonnet, Llama 3.3 70B
3. **Medium**: Mistral Large, Llama 3.2 11B
4. **Basic**: Mistral 7B, Llama 3.2 3B
### By Cost
1. **Most Affordable**: Claude 4.5 Haiku, Llama 3.2 1B
2. **Affordable**: Mistral 7B, Titan Lite
3. **Medium**: Claude 3.5 Haiku, Llama 3.2 3B
4. **Expensive**: Claude 4.5 Sonnet, Llama 3.3 70B
## Common Patterns
### Default Model Selection
```java
// For most applications
String DEFAULT_MODEL = "anthropic.claude-sonnet-4-5-20250929-v1:0";
// For real-time applications
String FAST_MODEL = "anthropic.claude-haiku-4-5-20251001-v1:0";
// For budget-sensitive applications
String CHEAP_MODEL = "amazon.titan-text-lite-v1";
// For complex reasoning
String POWERFUL_MODEL = "anthropic.claude-3-7-sonnet-20250219-v1:0";
```
### Model Fallback Chain
```java
private static final String[] MODEL_CHAIN = {
"anthropic.claude-sonnet-4-5-20250929-v1:0", // Primary
"anthropic.claude-haiku-4-5-20251001-v1:0", // Fast fallback
"amazon.titan-text-lite-v1" // Cheap fallback
};
```
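A sketch of walking that chain, assuming a request body that every model in the chain accepts (in practice, build a per-model payload before each attempt); `ThrottlingException` comes from the Bedrock Runtime model package:
```java
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.ThrottlingException;

static String invokeWithFallback(BedrockRuntimeClient runtime, String payload) {
    for (String modelId : MODEL_CHAIN) {
        try {
            return runtime.invokeModel(r -> r
                    .modelId(modelId)
                    .body(SdkBytes.fromUtf8String(payload)))
                .body().asUtf8String();
        } catch (ThrottlingException e) {
            // Current model is saturated; fall through to the next one
        }
    }
    throw new IllegalStateException("All models in the fallback chain failed");
}
```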

View File

@@ -0,0 +1,365 @@
# Testing Strategies
## Unit Testing
### Mocking Bedrock Clients
```java
@ExtendWith(MockitoExtension.class)
class BedrockServiceTest {
@Mock
private BedrockRuntimeClient bedrockRuntimeClient;
@InjectMocks
private BedrockAIService aiService;
@Test
void shouldGenerateTextWithClaude() {
// Arrange
String modelId = "anthropic.claude-3-sonnet-20240229-v1:0";
String prompt = "Hello, world!";
String expectedResponse = "Hello! How can I help you today?";
InvokeModelResponse mockResponse = InvokeModelResponse.builder()
.body(SdkBytes.fromUtf8String(
"{\"content\":[{\"text\":\"" + expectedResponse + "\"}]}"))
.build();
when(bedrockRuntimeClient.invokeModel(any(InvokeModelRequest.class)))
.thenReturn(mockResponse);
// Act
String result = aiService.generateText(prompt, modelId);
// Assert
assertThat(result).isEqualTo(expectedResponse);
verify(bedrockRuntimeClient).invokeModel(argThat(request ->
request.modelId().equals(modelId)));
}
@Test
void shouldHandleThrottling() {
// Arrange
when(bedrockRuntimeClient.invokeModel(any(InvokeModelRequest.class)))
.thenThrow(ThrottlingException.builder()
.message("Rate limit exceeded")
.build());
// Act & Assert
assertThatThrownBy(() -> aiService.generateText("test"))
.isInstanceOf(RuntimeException.class)
.hasMessageContaining("Rate limit exceeded");
}
}
```
### Testing Error Conditions
```java
@Test
void shouldHandleInvalidModelId() {
String invalidModelId = "invalid.model.id";
String prompt = "test";
when(bedrockRuntimeClient.invokeModel(any(InvokeModelRequest.class)))
.thenThrow(ValidationException.builder()
.message("Invalid model identifier")
.build());
assertThatThrownBy(() -> aiService.generateText(prompt, invalidModelId))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Invalid model identifier");
}
```
### Testing Multiple Models
```java
@ParameterizedTest
@EnumSource(ModelProvider.class)
void shouldSupportAllModels(ModelProvider modelProvider) {
String prompt = "Hello";
String modelId = modelProvider.getModelId();
String expectedResponse = "Response";
InvokeModelResponse mockResponse = InvokeModelResponse.builder()
.body(SdkBytes.fromUtf8String(createMockResponse(modelProvider, expectedResponse)))
.build();
when(bedrockRuntimeClient.invokeModel(any(InvokeModelRequest.class)))
.thenReturn(mockResponse);
String result = aiService.generateText(prompt, modelId);
assertThat(result).isEqualTo(expectedResponse);
}
private enum ModelProvider {
CLAUDE("anthropic.claude-3-sonnet-20240229-v1:0"),
LLAMA("meta.llama3-70b-instruct-v1:0"),
TITAN("amazon.titan-text-express-v1");
private final String modelId;
ModelProvider(String modelId) {
this.modelId = modelId;
}
public String getModelId() {
return modelId;
}
}
```
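`createMockResponse` is referenced above but not shown; a minimal sketch consistent with the per-provider response shapes used throughout this guide:
```java
private static String createMockResponse(ModelProvider provider, String text) {
    return switch (provider) {
        case CLAUDE -> "{\"content\":[{\"text\":\"" + text + "\"}]}";
        case LLAMA -> "{\"generation\":\"" + text + "\"}";
        case TITAN -> "{\"results\":[{\"outputText\":\"" + text + "\"}]}";
    };
}
```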
## Integration Testing
### Testcontainers Integration
```java
@Testcontainers
@SpringBootTest(classes = BedrockConfiguration.class)
@ActiveProfiles("test")
class BedrockIntegrationTest {
    @Container
    static LocalStackContainer localStack = new LocalStackContainer(
            DockerImageName.parse("localstack/localstack:latest"))
        // The testcontainers Service enum has no Bedrock entry; pass the service
        // by name (Bedrock emulation may require a LocalStack Pro image)
        .withServices(LocalStackContainer.EnabledService.named("bedrock"))
        .withEnv("DEFAULT_REGION", "us-east-1");
    @Autowired
    private BedrockRuntimeClient bedrockRuntimeClient;
    @Autowired
    private BedrockClient bedrockClient;
    @Test
    void shouldConnectToLocalStack() {
        assertThat(bedrockRuntimeClient).isNotNull();
    }
    @Test
    void shouldListFoundationModels() {
        // listFoundationModels is a control-plane call on BedrockClient,
        // not on BedrockRuntimeClient
        ListFoundationModelsResponse response = bedrockClient.listFoundationModels();
        assertThat(response.modelSummaries()).isNotEmpty();
    }
}
```
### LocalStack Configuration
```java
@Configuration
public class LocalStackConfig {
    @Value("${localstack.endpoint:http://localhost:4566}")
    private URI localStackEndpoint;
    @Bean
    @ConditionalOnProperty(name = "localstack.enabled", havingValue = "true")
    public AwsCredentialsProvider localStackCredentialsProvider() {
        // LocalStack accepts any static test credentials
        return StaticCredentialsProvider.create(
            AwsBasicCredentials.create("test", "test"));
    }
    @Bean
    @ConditionalOnProperty(name = "localstack.enabled", havingValue = "true")
    public BedrockRuntimeClient localStackBedrockRuntimeClient(
            AwsCredentialsProvider credentialsProvider) {
        return BedrockRuntimeClient.builder()
            .credentialsProvider(credentialsProvider)
            .endpointOverride(localStackEndpoint)
            .region(Region.US_EAST_1)
            .build();
    }
}
```
### Performance Testing
```java
@Test
void shouldPerformWithinTimeLimit() {
String prompt = "Performance test prompt";
int iterationCount = 100;
long startTime = System.currentTimeMillis();
for (int i = 0; i < iterationCount; i++) {
InvokeModelResponse response = bedrockRuntimeClient.invokeModel(
request -> request
.modelId("anthropic.claude-3-sonnet-20240229-v1:0")
.body(SdkBytes.fromUtf8String(createPayload(prompt))));
}
long duration = System.currentTimeMillis() - startTime;
double avgTimePerRequest = (double) duration / iterationCount;
assertThat(avgTimePerRequest).isLessThan(5000); // Less than 5 seconds per request
System.out.println("Average response time: " + avgTimePerRequest + "ms");
}
```
## Testing Streaming Responses
### Streaming Handler Testing
```java
@Test
void shouldStreamResponse() {
    String prompt = "Stream this response";
    StringBuilder streamedContent = new StringBuilder();
    InvokeModelWithResponseStreamRequest streamRequest =
        InvokeModelWithResponseStreamRequest.builder()
            .modelId("anthropic.claude-3-sonnet-20240229-v1:0")
            .body(SdkBytes.fromUtf8String(createPayload(prompt)))
            .build();
    // Visitor collects each payload chunk as it arrives
    InvokeModelWithResponseStreamResponseHandler.Visitor visitor =
        InvokeModelWithResponseStreamResponseHandler.Visitor.builder()
            .onChunk(chunk -> streamedContent.append(chunk.bytes().asUtf8String()))
            .build();
    InvokeModelWithResponseStreamResponseHandler handler =
        InvokeModelWithResponseStreamResponseHandler.builder()
            .onEventStream(stream -> stream.subscribe(event -> event.accept(visitor)))
            .onError(error -> fail("Streaming failed: " + error.getMessage()))
            .build();
    // Streaming requires the async client; join() blocks until the stream completes
    bedrockRuntimeAsyncClient.invokeModelWithResponseStream(streamRequest, handler).join();
    assertThat(streamedContent.toString()).isNotEmpty();
}
```
## Testing Configuration
### Testing Different Regions
```java
@ParameterizedTest
@EnumSource(value = Region.class,
names = {"US_EAST_1", "US_WEST_2", "EU_WEST_1"})
void shouldWorkInAllRegions(Region region) {
BedrockRuntimeClient client = BedrockRuntimeClient.builder()
.region(region)
.build();
assertThat(client).isNotNull();
}
```
### Testing Authentication
```java
@Test
void shouldUseIamRoleForAuthentication() {
BedrockRuntimeClient client = BedrockRuntimeClient.builder()
.region(Region.US_EAST_1)
.build();
// Test that client can make basic calls
ListFoundationModelsResponse response = client.listFoundationModels();
assertThat(response).isNotNull();
}
```
## Test Data Management
### Test Response Fixtures
```java
public class BedrockTestFixtures {
public static String createClaudeResponse() {
return "{\"content\":[{\"text\":\"Hello! How can I help you today?\"}]}";
}
public static String createLlamaResponse() {
return "{\"generation\":\"Hello! How can I assist you?\"}";
}
public static String createTitanResponse() {
return "{\"results\":[{\"outputText\":\"Hello! How can I help?\"}]}";
}
public static String createPayload(String prompt) {
return new JSONObject()
.put("anthropic_version", "bedrock-2023-05-31")
.put("max_tokens", 1000)
.put("messages", new JSONObject[]{
new JSONObject()
.put("role", "user")
.put("content", prompt)
})
.toString();
}
}
```
### Integration Test Suite
```java
@Suite
@SelectClasses({
BedrockAIServiceTest.class,
BedrockConfigurationTest.class,
BedrockStreamingTest.class,
BedrockErrorHandlingTest.class
})
public class BedrockTestSuite {
// Integration test suite for all Bedrock functionality
}
```
## Testing Guidelines
### Unit Testing Best Practices
1. **Mock External Dependencies:** Always mock AWS SDK clients in unit tests
2. **Test Error Scenarios:** Include tests for throttling, validation errors, and network issues
3. **Parameterized Tests:** Test multiple models and configurations efficiently
4. **Performance Assertions:** Include basic performance benchmarks
5. **Test Data Fixtures:** Reuse test response data across tests
### Integration Testing Best Practices
1. **Use LocalStack:** Test against LocalStack for local development
2. **Test Multiple Regions:** Verify functionality across different AWS regions
3. **Test Edge Cases:** Include timeout, retry, and concurrent request scenarios
4. **Monitor Performance:** Track response times and error rates
5. **Clean Up Resources:** Ensure proper cleanup after integration tests
### Testing Configuration
```properties
# application-test.properties
localstack.enabled=true
aws.region=us-east-1
bedrock.timeout=5000
bedrock.retry.max-attempts=3
```

View File

@@ -0,0 +1,660 @@
---
name: aws-sdk-java-v2-core
description: Core patterns and best practices for AWS SDK for Java 2.x. Use when configuring AWS service clients, setting up authentication, managing credentials, configuring timeouts, HTTP clients, or following AWS SDK best practices.
category: aws
tags: [aws, java, sdk, core, authentication, configuration]
version: 1.1.0
allowed-tools: Read, Write, Bash
---
# AWS SDK for Java 2.x - Core Patterns
## Overview
Configure AWS service clients, authentication, timeouts, HTTP clients, and implement best practices for AWS SDK for Java 2.x applications. This skill provides essential patterns for building robust, performant, and secure integrations with AWS services.
## When to Use
Use this skill when:
- Setting up AWS SDK for Java 2.x service clients with proper configuration
- Configuring authentication and credential management strategies
- Implementing client lifecycle management and resource cleanup
- Optimizing performance with HTTP client configuration and connection pooling
- Setting up proper timeout configurations for API calls
- Implementing error handling and retry policies
- Enabling monitoring and metrics collection
- Integrating AWS SDK with Spring Boot applications
- Testing AWS integrations with LocalStack and Testcontainers
## Quick Start
### Basic Service Client Setup
```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
// Basic client with region
S3Client s3Client = S3Client.builder()
.region(Region.US_EAST_1)
.build();
// Always close clients when done
try (S3Client s3 = S3Client.builder().region(Region.US_EAST_1).build()) {
// Use client
} // Auto-closed
```
### Basic Authentication
```java
// Uses default credential provider chain
S3Client s3Client = S3Client.builder()
.region(Region.US_EAST_1)
.build(); // Automatically detects credentials
```
## Client Configuration
### Service Client Builder Pattern
```java
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.http.apache.ProxyConfiguration;
import software.amazon.awssdk.metrics.publishers.cloudwatch.CloudWatchMetricPublisher;
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
import java.time.Duration;
import java.net.URI;
// Advanced client configuration
S3Client s3Client = S3Client.builder()
.region(Region.EU_SOUTH_2)
.credentialsProvider(EnvironmentVariableCredentialsProvider.create())
.overrideConfiguration(b -> b
.apiCallTimeout(Duration.ofSeconds(30))
.apiCallAttemptTimeout(Duration.ofSeconds(10))
.addMetricPublisher(CloudWatchMetricPublisher.create()))
.httpClientBuilder(ApacheHttpClient.builder()
.maxConnections(100)
.connectionTimeout(Duration.ofSeconds(5))
.proxyConfiguration(ProxyConfiguration.builder()
.endpoint(URI.create("http://proxy:8080"))
.build()))
.build();
```
### Separate Configuration Objects
```java
ClientOverrideConfiguration clientConfig = ClientOverrideConfiguration.builder()
.apiCallTimeout(Duration.ofSeconds(30))
.apiCallAttemptTimeout(Duration.ofSeconds(10))
.addMetricPublisher(CloudWatchMetricPublisher.create())
.build();
SdkHttpClient httpClient = ApacheHttpClient.builder()
.maxConnections(100)
.connectionTimeout(Duration.ofSeconds(5))
.build();
S3Client s3Client = S3Client.builder()
.region(Region.EU_SOUTH_2)
.credentialsProvider(EnvironmentVariableCredentialsProvider.create())
.overrideConfiguration(clientConfig)
.httpClient(httpClient)
.build();
```
## Authentication and Credentials
### Default Credentials Provider Chain
```java
// SDK automatically uses default credential provider chain:
// 1. Java system properties (aws.accessKeyId and aws.secretAccessKey)
// 2. Environment variables (AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY)
// 3. Web identity token from AWS_WEB_IDENTITY_TOKEN_FILE
// 4. Shared credentials and config files (~/.aws/credentials and ~/.aws/config)
// 5. Amazon ECS container credentials
// 6. Amazon EC2 instance profile credentials
S3Client s3Client = S3Client.builder()
.region(Region.US_EAST_1)
.build(); // Uses default credential provider chain
```
### Explicit Credentials Providers
```java
import software.amazon.awssdk.auth.credentials.*;
// Environment variables
AwsCredentialsProvider envCredentials = EnvironmentVariableCredentialsProvider.create();
// Profile from ~/.aws/credentials
AwsCredentialsProvider profileCredentials = ProfileCredentialsProvider.create("myprofile");
// Static credentials (NOT recommended for production)
AwsCredentialsProvider staticCredentials = StaticCredentialsProvider.create(
    AwsBasicCredentials.create("accessKeyId", "secretAccessKey")
);
// Instance profile (for EC2)
AwsCredentialsProvider instanceProfileCredentials = InstanceProfileCredentialsProvider.create();
// Use with client
S3Client s3Client = S3Client.builder()
.region(Region.US_EAST_1)
.credentialsProvider(profileCredentials)
.build();
```
### SSO Authentication Setup
```properties
# ~/.aws/config
[default]
sso_session = my-sso
sso_account_id = 111122223333
sso_role_name = SampleRole
region = us-east-1
output = json
[sso-session my-sso]
sso_region = us-east-1
sso_start_url = https://provided-domain.awsapps.com/start
sso_registration_scopes = sso:account:access
```
```bash
# Login before running application
aws sso login
# Verify active session
aws sts get-caller-identity
```
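Profile-based SSO resolution also needs the SDK's `sso` and `ssooidc` modules on the runtime classpath (versions managed by the SDK BOM):
```xml
<dependency>
    <groupId>software.amazon.awssdk</groupId>
    <artifactId>sso</artifactId>
</dependency>
<dependency>
    <groupId>software.amazon.awssdk</groupId>
    <artifactId>ssooidc</artifactId>
</dependency>
```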
## HTTP Client Configuration
### Apache HTTP Client (Recommended for Sync)
```java
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
// builder().build() returns the SdkHttpClient interface
SdkHttpClient httpClient = ApacheHttpClient.builder()
.maxConnections(100)
.connectionTimeout(Duration.ofSeconds(5))
.socketTimeout(Duration.ofSeconds(30))
.connectionTimeToLive(Duration.ofMinutes(5))
.expectContinueEnabled(true)
.build();
S3Client s3Client = S3Client.builder()
.httpClient(httpClient)
.build();
```
### Netty HTTP Client (For Async Operations)
```java
import io.netty.handler.ssl.SslProvider;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
SdkAsyncHttpClient httpClient = NettyNioAsyncHttpClient.builder()
.maxConcurrency(100)
.connectionTimeout(Duration.ofSeconds(5))
.readTimeout(Duration.ofSeconds(30))
.writeTimeout(Duration.ofSeconds(30))
.sslProvider(SslProvider.OPENSSL) // Better performance than JDK
.build();
S3AsyncClient s3AsyncClient = S3AsyncClient.builder()
.httpClient(httpClient)
.build();
```
### URL Connection HTTP Client (Lightweight)
```java
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.urlconnection.UrlConnectionHttpClient;
SdkHttpClient httpClient = UrlConnectionHttpClient.builder()
.socketTimeout(Duration.ofSeconds(30))
.build();
```
## Best Practices
### 1. Reuse Service Clients
**DO:**
```java
@Service
public class S3Service {
private final S3Client s3Client;
public S3Service() {
this.s3Client = S3Client.builder()
.region(Region.US_EAST_1)
.build();
}
// Reuse s3Client for all operations
}
```
**DON'T:**
```java
public void uploadFile(String bucket, String key) {
// Creates new client each time - wastes resources!
S3Client s3 = S3Client.builder().build();
s3.putObject(...);
s3.close();
}
```
### 2. Configure API Timeouts
```java
S3Client s3Client = S3Client.builder()
.overrideConfiguration(b -> b
.apiCallTimeout(Duration.ofSeconds(30))
.apiCallAttemptTimeout(Duration.ofMillis(5000)))
.build();
```
### 3. Close Unused Clients
```java
// Try-with-resources
try (S3Client s3 = S3Client.builder().build()) {
s3.listBuckets();
}
// Explicit close
S3Client s3Client = S3Client.builder().build();
try {
s3Client.listBuckets();
} finally {
s3Client.close();
}
```
### 4. Close Streaming Responses
```java
try (ResponseInputStream<GetObjectResponse> s3Object =
s3Client.getObject(GetObjectRequest.builder()
.bucket(bucket)
.key(key)
.build())) {
// Read and process stream immediately
byte[] data = s3Object.readAllBytes();
} // Stream auto-closed, connection returned to pool
```
### 5. Optimize SSL for Async Clients
**Add dependency:**
```xml
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-tcnative-boringssl-static</artifactId>
<version>2.0.61.Final</version>
<scope>runtime</scope>
</dependency>
```
**Configure SSL:**
```java
SdkAsyncHttpClient httpClient = NettyNioAsyncHttpClient.builder()
.sslProvider(SslProvider.OPENSSL)
.build();
S3AsyncClient s3AsyncClient = S3AsyncClient.builder()
.httpClient(httpClient)
.build();
```
## Spring Boot Integration
### Configuration Properties
```java
@ConfigurationProperties(prefix = "aws")
public record AwsProperties(
String region,
String accessKeyId,
String secretAccessKey,
S3Properties s3,
DynamoDbProperties dynamoDb
) {
public record S3Properties(
Integer maxConnections,
Integer connectionTimeoutSeconds,
Integer apiCallTimeoutSeconds
) {}
public record DynamoDbProperties(
Integer maxConnections,
Integer readTimeoutSeconds
) {}
}
```
### Client Configuration Beans
```java
@Configuration
@EnableConfigurationProperties(AwsProperties.class)
public class AwsClientConfiguration {
private final AwsProperties awsProperties;
public AwsClientConfiguration(AwsProperties awsProperties) {
this.awsProperties = awsProperties;
}
@Bean
public S3Client s3Client() {
return S3Client.builder()
.region(Region.of(awsProperties.region()))
.credentialsProvider(credentialsProvider())
.overrideConfiguration(clientOverrideConfiguration(
awsProperties.s3().apiCallTimeoutSeconds()))
.httpClient(apacheHttpClient(
awsProperties.s3().maxConnections(),
awsProperties.s3().connectionTimeoutSeconds()))
.build();
}
private AwsCredentialsProvider credentialsProvider() {
if (awsProperties.accessKeyId() != null &&
awsProperties.secretAccessKey() != null) {
return StaticCredentialsProvider.create(
AwsBasicCredentials.create(
awsProperties.accessKeyId(),
awsProperties.secretAccessKey()));
}
return DefaultCredentialsProvider.create();
}
private ClientOverrideConfiguration clientOverrideConfiguration(
Integer apiCallTimeoutSeconds) {
return ClientOverrideConfiguration.builder()
.apiCallTimeout(Duration.ofSeconds(
apiCallTimeoutSeconds != null ? apiCallTimeoutSeconds : 30))
.apiCallAttemptTimeout(Duration.ofSeconds(10))
.build();
}
private SdkHttpClient apacheHttpClient(
Integer maxConnections,
Integer connectionTimeoutSeconds) {
return ApacheHttpClient.builder()
.maxConnections(maxConnections != null ? maxConnections : 50)
.connectionTimeout(Duration.ofSeconds(
connectionTimeoutSeconds != null ? connectionTimeoutSeconds : 5))
.socketTimeout(Duration.ofSeconds(30))
.build();
}
}
```
### Application Properties
```yaml
aws:
region: us-east-1
s3:
max-connections: 100
connection-timeout-seconds: 5
api-call-timeout-seconds: 30
dynamo-db:
max-connections: 50
read-timeout-seconds: 30
```
## Error Handling
```java
import software.amazon.awssdk.services.s3.model.S3Exception;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.exception.SdkServiceException;
try {
s3Client.getObject(request);
} catch (S3Exception e) {
// Service-specific exception
System.err.println("S3 Error: " + e.awsErrorDetails().errorMessage());
System.err.println("Error Code: " + e.awsErrorDetails().errorCode());
System.err.println("Status Code: " + e.statusCode());
System.err.println("Request ID: " + e.requestId());
} catch (SdkServiceException e) {
// Generic service exception
System.err.println("AWS Service Error: " + e.getMessage());
} catch (SdkClientException e) {
// Client-side error (network, timeout, etc.)
System.err.println("Client Error: " + e.getMessage());
}
```
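Throttling often warrants distinct handling from other service errors; a sketch using the `isThrottlingException()` helper that service exceptions expose, assuming the same `request` as above:
```java
try {
    s3Client.getObject(request);
} catch (S3Exception e) {
    if (e.isThrottlingException()) {
        // The SDK's retry policy already retried; back off before trying again
    } else if (e.statusCode() == 404) {
        // Missing bucket or key is often recoverable at the application level
    } else {
        throw e;
    }
}
```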
## Testing Patterns
### LocalStack Integration
```java
@TestConfiguration
public class LocalStackAwsConfig {
@Bean
public S3Client s3Client() {
return S3Client.builder()
.region(Region.US_EAST_1)
.endpointOverride(URI.create("http://localhost:4566"))
.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create("test", "test")))
.build();
}
}
```
### Testcontainers with LocalStack
```java
@Testcontainers
@SpringBootTest
class S3IntegrationTest {
@Container
static LocalStackContainer localstack = new LocalStackContainer(
DockerImageName.parse("localstack/localstack:3.0"))
.withServices(LocalStackContainer.Service.S3);
@DynamicPropertySource
static void overrideProperties(DynamicPropertyRegistry registry) {
registry.add("aws.s3.endpoint",
() -> localstack.getEndpointOverride(LocalStackContainer.Service.S3));
registry.add("aws.region", () -> localstack.getRegion());
registry.add("aws.access-key-id", localstack::getAccessKey);
registry.add("aws.secret-access-key", localstack::getSecretKey);
}
}
```
## Maven Dependencies
```xml
<dependencyManagement>
<dependencies>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bom</artifactId>
<version>2.25.0</version> <!-- use the latest stable version -->
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<!-- Core SDK -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sdk-core</artifactId>
</dependency>
<!-- Apache HTTP Client (recommended for sync) -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>apache-client</artifactId>
</dependency>
<!-- Netty HTTP Client (for async) -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>netty-nio-client</artifactId>
</dependency>
<!-- URL Connection HTTP Client (lightweight) -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>url-connection-client</artifactId>
</dependency>
<!-- CloudWatch Metrics -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>cloudwatch-metric-publisher</artifactId>
</dependency>
<!-- OpenSSL for better performance -->
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-tcnative-boringssl-static</artifactId>
<version>2.0.61.Final</version>
<scope>runtime</scope>
</dependency>
</dependencies>
```
## Gradle Dependencies
```gradle
dependencies {
implementation platform('software.amazon.awssdk:bom:2.25.0')
implementation 'software.amazon.awssdk:sdk-core'
implementation 'software.amazon.awssdk:apache-client'
implementation 'software.amazon.awssdk:netty-nio-client'
implementation 'software.amazon.awssdk:cloudwatch-metric-publisher'
runtimeOnly 'io.netty:netty-tcnative-boringssl-static:2.0.61.Final'
}
```
## Examples
### Basic S3 Upload
```java
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
try (S3Client s3 = S3Client.builder().region(Region.US_EAST_1).build()) {
PutObjectRequest request = PutObjectRequest.builder()
.bucket("my-bucket")
.key("uploads/file.txt")
.build();
s3.putObject(request, RequestBody.fromString("Hello, World!"));
}
```
### S3 List Objects with Pagination
```java
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Response;
try (S3Client s3 = S3Client.builder().region(Region.US_EAST_1).build()) {
ListObjectsV2Request request = ListObjectsV2Request.builder()
.bucket("my-bucket")
.build();
ListObjectsV2Response response = s3.listObjectsV2(request);
response.contents().forEach(object -> {
System.out.println("Object key: " + object.key());
});
}
```
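A single `listObjectsV2` call returns at most 1,000 keys; for a complete listing, the SDK-managed paginator follows continuation tokens transparently:
```java
try (S3Client s3 = S3Client.builder().region(Region.US_EAST_1).build()) {
    s3.listObjectsV2Paginator(b -> b.bucket("my-bucket"))
        .contents()
        .forEach(object -> System.out.println("Object key: " + object.key()));
}
```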
### Async S3 Upload
```java
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
S3AsyncClient s3AsyncClient = S3AsyncClient.builder().build();
PutObjectRequest request = PutObjectRequest.builder()
.bucket("my-bucket")
.key("async-upload.txt")
.build();
CompletableFuture<PutObjectResponse> future = s3AsyncClient.putObject(
        request, AsyncRequestBody.fromString("Hello, Async World!"));
future.thenAccept(response -> {
System.out.println("Upload completed: " + response.eTag());
}).exceptionally(error -> {
System.err.println("Upload failed: " + error.getMessage());
return null;
});
```
## Performance Considerations
1. **Connection Pooling**: Default max connections is 50. Increase for high-throughput applications.
2. **Timeouts**: Always set both `apiCallTimeout` and `apiCallAttemptTimeout`.
3. **Client Reuse**: Create clients once, reuse throughout application lifecycle.
4. **Stream Handling**: Close streams immediately to prevent connection pool exhaustion.
5. **Async for I/O**: Use async clients for I/O-bound operations.
6. **OpenSSL**: Use OpenSSL with Netty for better SSL performance.
7. **Metrics**: Enable CloudWatch metrics to monitor performance.
## Security Best Practices
1. **Never hardcode credentials**: Use credential providers or environment variables.
2. **Use IAM roles**: Prefer IAM roles over access keys when possible.
3. **Rotate credentials**: Implement credential rotation for long-lived keys.
4. **Least privilege**: Grant minimum required permissions.
5. **Enable SSL**: Always use HTTPS endpoints (default).
6. **Audit logging**: Enable AWS CloudTrail for API call auditing.
## Related Skills
- `aws-sdk-java-v2-s3` - S3-specific patterns and examples
- `aws-sdk-java-v2-dynamodb` - DynamoDB patterns and examples
- `aws-sdk-java-v2-lambda` - Lambda patterns and examples
## References
See [references/](references/) for detailed documentation:
- [Developer Guide](references/developer-guide.md) - Comprehensive guide and architecture overview
- [API Reference](references/api-reference.md) - Complete API documentation for core classes
- [Best Practices](references/best-practices.md) - In-depth best practices and configuration examples
## Additional Resources
- [AWS SDK for Java 2.x Developer Guide](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/home.html)
- [AWS SDK for Java 2.x API Reference](https://sdk.amazonaws.com/java/api/latest/)
- [Best Practices](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/best-practices.html)
- [GitHub Repository](https://github.com/aws/aws-sdk-java-v2)

View File

@@ -0,0 +1,258 @@
# AWS SDK for Java 2.x API Reference
## Core Client Classes
### SdkClient
Base interface for all SDK service clients.
```java
public interface SdkClient extends SdkAutoCloseable {
    // Base client interface
}
```
### AwsClient
AWS-specific client interface; every AWS service client extends it.
```java
public interface AwsClient extends SdkClient {
    // Marker interface for AWS service clients
}
```
## Client Builders
### ClientBuilder
Base builder interface for all AWS service clients.
**Key Methods:**
- `region(Region region)` - Set AWS region
- `credentialsProvider(CredentialsProvider credentialsProvider)` - Configure authentication
- `overrideConfiguration(ClientOverrideConfiguration overrideConfiguration)` - Override default settings
- `httpClient(SdkHttpClient httpClient)` - Specify HTTP client implementation
- `build()` - Create client instance
## Configuration Classes
### ClientOverrideConfiguration
Controls client-level configuration including timeouts and metrics.
**Key Properties:**
- `apiCallTimeout(Duration)` - Total timeout for all retry attempts
- `apiCallAttemptTimeout(Duration)` - Timeout per individual attempt
- `retryPolicy(RetryPolicy)` - Retry behavior configuration
- `addMetricPublisher(MetricPublisher)` - Enable metrics collection
### Builder Example
```java
ClientOverrideConfiguration config = ClientOverrideConfiguration.builder()
.apiCallTimeout(Duration.ofSeconds(30))
.apiCallAttemptTimeout(Duration.ofSeconds(10))
.addMetricPublisher(CloudWatchMetricPublisher.create())
.build();
```
## HTTP Client Implementations
### ApacheHttpClient
Synchronous HTTP client with advanced features.
**Builder Configuration:**
- `maxConnections(Integer)` - Maximum concurrent connections
- `connectionTimeout(Duration)` - Connection establishment timeout
- `socketTimeout(Duration)` - Socket read/write timeout
- `connectionTimeToLive(Duration)` - Connection lifetime
- `proxyConfiguration(ProxyConfiguration)` - Proxy settings
### NettyNioAsyncHttpClient
Asynchronous HTTP client for high-performance applications.
**Builder Configuration:**
- `maxConcurrency(Integer)` - Maximum concurrent operations
- `connectionTimeout(Duration)` - Connection timeout
- `readTimeout(Duration)` - Read operation timeout
- `writeTimeout(Duration)` - Write operation timeout
- `sslProvider(SslProvider)` - SSL/TLS implementation
### UrlConnectionHttpClient
Lightweight HTTP client using Java's URLConnection.
**Builder Configuration:**
- `socketTimeout(Duration)` - Socket timeout
- `connectionTimeout(Duration)` - Connection timeout
## Authentication and Credentials
### Credential Providers
#### EnvironmentVariableCredentialsProvider
Reads credentials from environment variables.
```java
AwsCredentialsProvider provider = EnvironmentVariableCredentialsProvider.create();
```
#### SystemPropertyCredentialsProvider
Reads credentials from Java system properties.
```java
AwsCredentialsProvider provider = SystemPropertyCredentialsProvider.create();
```
#### ProfileCredentialsProvider
Reads credentials from AWS configuration files.
```java
AwsCredentialsProvider provider = ProfileCredentialsProvider.create("profile-name");
```
#### StaticCredentialsProvider
Provides static credentials (not recommended for production).
```java
AwsBasicCredentials credentials = AwsBasicCredentials.create("key", "secret");
AwsCredentialsProvider provider = StaticCredentialsProvider.create(credentials);
```
#### DefaultCredentialsProvider
Implements the default credential provider chain.
```java
AwsCredentialsProvider provider = DefaultCredentialsProvider.create();
```
### SSO Authentication
#### SSO Profiles
SSO-based authentication resolves through the standard profile machinery; with the `sso` and `ssooidc` modules on the classpath, an SSO profile loads like any other profile.
```java
AwsCredentialsProvider ssoProvider = ProfileCredentialsProvider.create("my-sso-profile");
```
## Error Handling Classes
### SdkClientException
Client-side exceptions (network, timeout, configuration issues).
```java
try {
awsOperation();
} catch (SdkClientException e) {
// Handle client-side errors
}
```
### SdkServiceException
Service-side exceptions (AWS service errors). Catch the `AwsServiceException` subclass when you need the AWS error details.
```java
try {
    awsOperation();
} catch (AwsServiceException e) {
    // Handle service-side errors
    System.err.println("Error Code: " + e.awsErrorDetails().errorCode());
    System.err.println("Request ID: " + e.requestId());
}
```
### S3Exception
S3-specific exceptions.
```java
try {
s3Operation();
} catch (S3Exception e) {
// Handle S3-specific errors
System.err.println("S3 Error: " + e.awsErrorDetails().errorMessage());
}
```
## Metrics and Monitoring
### CloudWatchMetricPublisher
Publishes metrics to AWS CloudWatch.
```java
CloudWatchMetricPublisher publisher = CloudWatchMetricPublisher.create();
```
### MetricPublisher
Base interface for custom metrics publishers.
```java
public interface MetricPublisher extends SdkAutoCloseable {
    void publish(MetricCollection metricCollection);
}
```
## Utility Classes
### Duration and Time
Configure timeouts using Java Duration.
```java
Duration apiTimeout = Duration.ofSeconds(30);
Duration attemptTimeout = Duration.ofSeconds(10);
```
### Region
AWS regions for service endpoints.
```java
Region region = Region.US_EAST_1;
Region regionEU = Region.EU_WEST_1;
```
### URI
Endpoint configuration and proxy settings.
```java
URI proxyUri = URI.create("http://proxy:8080");
URI endpointOverride = URI.create("http://localhost:4566");
```
## Configuration Best Practices
### Resource Management
Always close clients when no longer needed.
```java
try (S3Client s3 = S3Client.builder().build()) {
// Use client
} // Auto-closed
```
### Connection Pooling
Reuse clients to avoid connection pool overhead.
```java
@Service
public class AwsService {
private final S3Client s3Client;
public AwsService() {
this.s3Client = S3Client.builder().build();
}
// Reuse s3Client throughout application
}
```
### Error Handling
Implement comprehensive error handling for robust applications.
```java
try {
// AWS operation
} catch (SdkServiceException e) {
// Handle service errors
} catch (SdkClientException e) {
// Handle client errors
} catch (Exception e) {
// Handle other errors
}
```

View File

@@ -0,0 +1,344 @@
# AWS SDK for Java 2.x Best Practices
## Client Configuration
### Timeout Configuration
Always configure both API call and attempt timeouts to prevent hanging requests.
```java
ClientOverrideConfiguration config = ClientOverrideConfiguration.builder()
.apiCallTimeout(Duration.ofSeconds(30)) // Total for all retries
.apiCallAttemptTimeout(Duration.ofSeconds(10)) // Per-attempt timeout
.build();
```
**Best Practices:**
- Set `apiCallTimeout` higher than `apiCallAttemptTimeout`
- Use appropriate timeouts based on your service's characteristics
- Consider network latency and service response times
- Monitor timeout metrics to adjust as needed
### HTTP Client Selection
Choose the appropriate HTTP client for your use case.
#### For Synchronous Applications (Apache HttpClient)
```java
SdkHttpClient httpClient = ApacheHttpClient.builder()
.maxConnections(100)
.connectionTimeout(Duration.ofSeconds(5))
.socketTimeout(Duration.ofSeconds(30))
.build();
```
**Best Use Cases:**
- Traditional synchronous applications
- Medium-throughput operations
- When blocking behavior is acceptable
#### For Asynchronous Applications (Netty NIO Client)
```java
SdkAsyncHttpClient httpClient = NettyNioAsyncHttpClient.builder()
.maxConcurrency(100)
.connectionTimeout(Duration.ofSeconds(5))
.readTimeout(Duration.ofSeconds(30))
.writeTimeout(Duration.ofSeconds(30))
.sslProvider(SslProvider.OPENSSL)
.build();
```
**Best Use Cases:**
- High-throughput applications
- I/O-bound operations
- When non-blocking behavior is required
- For improved SSL performance
#### For Lightweight Applications (URL Connection Client)
```java
SdkHttpClient httpClient = UrlConnectionHttpClient.builder()
.socketTimeout(Duration.ofSeconds(30))
.build();
```
**Best Use Cases:**
- Simple applications with low requirements
- When minimizing dependencies
- For basic operations
## Authentication and Security
### Credential Management
#### Default Provider Chain
```java
// Use default chain (recommended)
S3Client s3Client = S3Client.builder().build();
```
**Default Chain Order:**
1. Java system properties (`aws.accessKeyId`, `aws.secretAccessKey`)
2. Environment variables (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`)
3. Web identity token from `AWS_WEB_IDENTITY_TOKEN_FILE`
4. Shared credentials file (`~/.aws/credentials`)
5. Config file (`~/.aws/config`)
6. Amazon ECS container credentials
7. Amazon EC2 instance profile credentials
#### Explicit Credential Provider
```java
// Use specific credential provider
AwsCredentialsProvider credentials = ProfileCredentialsProvider.create("my-profile");
S3Client s3Client = S3Client.builder()
.credentialsProvider(credentials)
.build();
```
#### IAM Roles (Preferred for Production)
```java
// Use IAM role credentials
AwsCredentialsProvider instanceProfile = InstanceProfileCredentialsProvider.create();
S3Client s3Client = S3Client.builder()
.credentialsProvider(instanceProfile)
.build();
```
### Security Best Practices
1. **Never hardcode credentials** - Use credential providers or environment variables
2. **Use IAM roles** - Prefer over access keys when possible
3. **Implement credential rotation** - For long-lived access keys
4. **Apply least privilege** - Grant minimum required permissions
5. **Enable SSL** - Always use HTTPS (default behavior)
6. **Monitor access** - Enable AWS CloudTrail for auditing
7. **Use SSO for human users** - Instead of long-term credentials
## Resource Management
### Client Lifecycle
```java
// Option 1: Try-with-resources (recommended)
try (S3Client s3 = S3Client.builder().build()) {
// Use client
} // Auto-closed
// Option 2: Explicit close
S3Client s3 = S3Client.builder().build();
try {
// Use client
} finally {
s3.close();
}
```
### Stream Handling
Close streams immediately to prevent connection pool exhaustion.
```java
try (ResponseInputStream<GetObjectResponse> response =
s3Client.getObject(GetObjectRequest.builder()
.bucket(bucketName)
.key(objectKey)
.build())) {
// Read and process data immediately
byte[] data = response.readAllBytes();
} // Stream auto-closed, connection returned to pool
```
## Performance Optimization
### Connection Pooling
```java
// Configure connection pooling
SdkHttpClient httpClient = ApacheHttpClient.builder()
.maxConnections(100) // Adjust based on your needs
.connectionTimeout(Duration.ofSeconds(5))
.socketTimeout(Duration.ofSeconds(30))
.connectionTimeToLive(Duration.ofMinutes(5))
.build();
```
**Best Practices:**
- Set appropriate `maxConnections` based on expected load
- Consider connection time to live (TTL)
- Monitor connection pool metrics
- Use appropriate timeouts
### SSL Optimization
Use OpenSSL with Netty for better SSL performance.
```xml
<!-- Maven dependency -->
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-tcnative-boringssl-static</artifactId>
<version>2.0.61.Final</version>
<scope>runtime</scope>
</dependency>
```
```java
// Use OpenSSL for async clients
SdkAsyncHttpClient httpClient = NettyNioAsyncHttpClient.builder()
.sslProvider(SslProvider.OPENSSL)
.build();
```
### Async for I/O-Bound Operations
```java
// Use async clients for I/O-bound operations
S3AsyncClient s3AsyncClient = S3AsyncClient.builder()
.httpClient(httpClient)
.build();
// Use CompletableFuture for non-blocking operations
CompletableFuture<PutObjectResponse> future = s3AsyncClient.putObject(request);
future.thenAccept(response -> {
// Handle success
}).exceptionally(error -> {
// Handle error
return null;
});
```
## Monitoring and Observability
### Enable SDK Metrics
```java
CloudWatchMetricPublisher publisher = CloudWatchMetricPublisher.create();
S3Client s3Client = S3Client.builder()
.overrideConfiguration(b -> b
.addMetricPublisher(publisher))
.build();
```
### CloudWatch Integration
Configure CloudWatch metrics publisher to collect SDK metrics.
```java
CloudWatchMetricPublisher cloudWatchPublisher = CloudWatchMetricPublisher.builder()
    .namespace("MyApplication")
    .cloudWatchClient(CloudWatchAsyncClient.create())
    .build();
```
### Custom Metrics
Implement custom metrics for application-specific monitoring.
```java
public class CustomMetricPublisher implements MetricPublisher {
    @Override
    public void publish(MetricCollection metrics) {
        // A MetricCollection iterates MetricRecord entries
        metrics.forEach(record ->
            System.out.println("Metric: " + record.metric().name() + " = " + record.value()));
    }

    @Override
    public void close() {
        // MetricPublisher extends SdkAutoCloseable; nothing to release here
    }
}
```
## Error Handling
### Comprehensive Error Handling
```java
try {
awsOperation();
} catch (AwsServiceException e) {
    // Service-side error; AwsServiceException exposes the AWS error details
    System.err.println("AWS Service Error: " + e.awsErrorDetails().errorMessage());
    System.err.println("Error Code: " + e.awsErrorDetails().errorCode());
    System.err.println("Status Code: " + e.statusCode());
    System.err.println("Request ID: " + e.requestId());
} catch (SdkClientException e) {
// Client-side error (network, timeout, etc.)
System.err.println("Client Error: " + e.getMessage());
} catch (Exception e) {
// Other errors
System.err.println("Unexpected Error: " + e.getMessage());
}
```
### Retry Configuration
```java
RetryPolicy retryPolicy = RetryPolicy.builder()
.numRetries(3)
.retryCondition(RetryCondition.defaultRetryCondition())
.backoffStrategy(BackoffStrategy.defaultStrategy())
.build();
```
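The policy is inert until attached to a client; a sketch wiring it in through the override configuration:
```java
S3Client s3Client = S3Client.builder()
    .overrideConfiguration(b -> b.retryPolicy(retryPolicy))
    .build();
```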
## Testing Strategies
### Local Testing with LocalStack
```java
@TestConfiguration
public class LocalStackConfig {
@Bean
public S3Client s3Client() {
return S3Client.builder()
.endpointOverride(URI.create("http://localhost:4566"))
.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create("test", "test")))
.build();
}
}
```
### Testcontainers Integration
```java
@Testcontainers
@SpringBootTest
public class AwsIntegrationTest {
@Container
static LocalStackContainer localstack = new LocalStackContainer(DockerImageName.parse("localstack/localstack:3.0"))
.withServices(LocalStackContainer.Service.S3);
@DynamicPropertySource
static void configProperties(DynamicPropertyRegistry registry) {
registry.add("aws.endpoint", () -> localstack.getEndpointOverride(LocalStackContainer.Service.S3));
}
}
```
## Configuration Templates
### High-Throughput Configuration
```java
ApacheHttpClient highThroughputClient = ApacheHttpClient.builder()
.maxConnections(200)
.connectionTimeout(Duration.ofSeconds(3))
.socketTimeout(Duration.ofSeconds(30))
.connectionTimeToLive(Duration.ofMinutes(10))
.build();
S3Client s3Client = S3Client.builder()
.region(Region.US_EAST_1)
.httpClient(highThroughputClient)
.overrideConfiguration(b -> b
.apiCallTimeout(Duration.ofSeconds(45))
.apiCallAttemptTimeout(Duration.ofSeconds(15)))
.build();
```
### Low-Latency Configuration
```java
ApacheHttpClient lowLatencyClient = ApacheHttpClient.builder()
.maxConnections(50)
.connectionTimeout(Duration.ofSeconds(2))
.socketTimeout(Duration.ofSeconds(10))
.build();
S3Client s3Client = S3Client.builder()
.region(Region.US_EAST_1)
.httpClient(lowLatencyClient)
.overrideConfiguration(b -> b
.apiCallTimeout(Duration.ofSeconds(15))
.apiCallAttemptTimeout(Duration.ofSeconds(3)))
.build();
```

View File

@@ -0,0 +1,130 @@
# AWS SDK for Java 2.x Developer Guide
## Overview
The AWS SDK for Java 2.x provides a modern, type-safe API for AWS services. Built on Java 8+, it offers improved performance, better error handling, and enhanced security compared to v1.x.
## Key Features
- **Modern Architecture**: Built on Java 8+ with reactive and async support
- **Type Safety**: Comprehensive type annotations and validation
- **Performance Optimized**: Connection pooling, async support, and SSL optimization
- **Enhanced Security**: Better credential management and security practices
- **Extensive Coverage**: Support for all AWS services with regular updates
## Core Concepts
### Service Clients
The primary interface for interacting with AWS services. All clients implement the `SdkClient` interface.
```java
// S3Client example
S3Client s3 = S3Client.builder().region(Region.US_EAST_1).build();
```
### Client Configuration
Configure behavior through builders supporting:
- Timeout settings
- HTTP client selection
- Authentication methods
- Monitoring and metrics
### Credential Providers
Multiple authentication methods:
- Environment variables
- System properties
- Shared credential files
- IAM roles
- SSO integration
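For example, the default chain resolves these sources in order, or a specific profile can be pinned explicitly. A minimal sketch (profile name is illustrative):
```java
// Default chain: env vars -> system properties -> credentials file -> IAM roles
S3Client s3 = S3Client.builder()
        .credentialsProvider(DefaultCredentialsProvider.create())
        .build();

// Or pin a named profile from ~/.aws/credentials
S3Client profileClient = S3Client.builder()
        .credentialsProvider(ProfileCredentialsProvider.create("my-profile"))
        .build();
```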
### HTTP Clients
Choose from three HTTP implementations:
- Apache HttpClient (synchronous)
- Netty NIO Client (asynchronous)
- URL Connection Client (lightweight)
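Selection happens at client build time. A minimal sketch using the lightweight client (assumes the `url-connection-client` module is on the classpath):
```java
// Synchronous, minimal-dependency HTTP client
S3Client s3 = S3Client.builder()
        .httpClientBuilder(UrlConnectionHttpClient.builder())
        .build();
```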
## Migration from v1.x
The SDK 2.x is not backward compatible with v1.x. Key changes:
- Builder pattern for client creation
- Different package structure
- Enhanced error handling
- New credential system
- Improved resource management
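For example, client construction moves from v1's static factory to the v2 builder (a sketch for comparison):
```java
// v1.x
AmazonS3 s3V1 = AmazonS3ClientBuilder.standard()
        .withRegion(Regions.US_EAST_1)
        .build();

// v2.x
S3Client s3V2 = S3Client.builder()
        .region(Region.US_EAST_1)
        .build();
```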
## Getting Started
Include the BOM (Bill of Materials) for version management:
```xml
<dependencyManagement>
<dependencies>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bom</artifactId>
<version>2.25.0</version> <!-- Use latest stable version -->
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
```
Add service-specific dependencies:
```xml
<dependencies>
<!-- S3 Service -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>s3</artifactId>
</dependency>
<!-- Core SDK -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sdk-core</artifactId>
</dependency>
</dependencies>
```
## Architecture Overview
```
AWS Service Client
├── Configuration Layer
│ ├── Client Override Configuration
│ └── HTTP Client Configuration
├── Authentication Layer
│ ├── Credential Providers
│ └── Security Context
├── Transport Layer
│ ├── HTTP Client (Apache/Netty/URLConn)
│ └── Connection Pool
└── Protocol Layer
├── Service Protocol Implementation
└── Error Handling
```
## Service Discovery
Each AWS service is exposed through a dedicated client interface in its own Maven module; list APIs also ship generated paginators that handle continuation tokens automatically.
### Available Services
All AWS services are available through dedicated client interfaces:
- S3 (Simple Storage Service)
- DynamoDB (NoSQL Database)
- Lambda (Serverless Functions)
- EC2 (Compute Cloud)
- RDS (Managed Databases)
- And 200+ other services
For a complete list, see the AWS Service documentation.
## Support and Community
- **GitHub Issues**: Report bugs and request features
- **AWS Amplify**: For mobile app developers
- **Migration Guide**: Available for v1.x users
- **Changelog**: Track changes on GitHub

View File

@@ -0,0 +1,392 @@
---
name: aws-sdk-java-v2-dynamodb
description: Amazon DynamoDB patterns using AWS SDK for Java 2.x. Use when creating, querying, scanning, or performing CRUD operations on DynamoDB tables, working with indexes, batch operations, transactions, or integrating with Spring Boot applications.
category: aws
tags: [aws, dynamodb, java, sdk, nosql, database]
version: 1.1.0
allowed-tools: Read, Write, Bash
---
# AWS SDK for Java 2.x - Amazon DynamoDB
## When to Use
Use this skill when:
- Creating, updating, or deleting DynamoDB tables
- Performing CRUD operations on DynamoDB items
- Querying or scanning tables
- Working with Global Secondary Indexes (GSI) or Local Secondary Indexes (LSI)
- Implementing batch operations for efficiency
- Using DynamoDB transactions
- Integrating DynamoDB with Spring Boot applications
- Working with DynamoDB Enhanced Client for type-safe operations
## Dependencies
Add to `pom.xml`:
```xml
<!-- Low-level DynamoDB client -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>dynamodb</artifactId>
</dependency>
<!-- Enhanced client (recommended) -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>dynamodb-enhanced</artifactId>
</dependency>
```
## Client Setup
### Low-Level Client
```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
DynamoDbClient dynamoDb = DynamoDbClient.builder()
.region(Region.US_EAST_1)
.build();
```
### Enhanced Client (Recommended)
```java
import software.amazon.awssdk.enhanced.dynamodb.DynamoDbEnhancedClient;
DynamoDbEnhancedClient enhancedClient = DynamoDbEnhancedClient.builder()
.dynamoDbClient(dynamoDb)
.build();
```
## Entity Mapping
To define DynamoDB entities, use `@DynamoDbBean` annotation:
```java
@DynamoDbBean
public class Customer {
    private String customerId;
    private String name;
    private String email;
    private String orderId;

    @DynamoDbPartitionKey
    public String getCustomerId() { return customerId; }

    @DynamoDbAttribute("customer_name")
    public String getName() { return name; }

    @DynamoDbSortKey
    public String getOrderId() { return orderId; }

    // Remaining getters and setters omitted; bean annotations belong on getters
}
```
For complex entity mapping with GSIs and custom converters, see [Entity Mapping Reference](references/entity-mapping.md).
## CRUD Operations
### Basic Operations
```java
// Create or update item
DynamoDbTable<Customer> table = enhancedClient.table("Customers", TableSchema.fromBean(Customer.class));
table.putItem(customer);
// Get item
Customer result = table.getItem(Key.builder().partitionValue(customerId).build());
// Update item (updateItem returns the updated item)
Customer updated = table.updateItem(customer);
// Delete item
table.deleteItem(Key.builder().partitionValue(customerId).build());
```
### Composite Key Operations
```java
// Get item with composite key
Order order = table.getItem(Key.builder()
.partitionValue(customerId)
.sortValue(orderId)
.build());
```
## Query Operations
### Basic Query
```java
import software.amazon.awssdk.enhanced.dynamodb.model.QueryConditional;
QueryConditional queryConditional = QueryConditional
.keyEqualTo(Key.builder()
.partitionValue(customerId)
.build());
List<Order> orders = table.query(queryConditional).items().stream()
.collect(Collectors.toList());
```
### Advanced Query with Filters
```java
import software.amazon.awssdk.enhanced.dynamodb.Expression;
Expression filter = Expression.builder()
        .expression("#s = :pending") // "status" is a DynamoDB reserved word
        .putExpressionName("#s", "status")
        .putExpressionValue(":pending", AttributeValue.builder().s("PENDING").build())
        .build();
List<Order> pendingOrders = table.query(r -> r
.queryConditional(queryConditional)
.filterExpression(filter))
.items().stream()
.collect(Collectors.toList());
```
For detailed query patterns, see [Advanced Operations Reference](references/advanced-operations.md).
## Scan Operations
```java
// Scan all items
List<Customer> allCustomers = table.scan().items().stream()
.collect(Collectors.toList());
// Scan with filter
Expression filter = Expression.builder()
.expression("points >= :minPoints")
.putExpressionValue(":minPoints", AttributeValue.builder().n("1000").build())
.build();
List<Customer> vipCustomers = table.scan(r -> r.filterExpression(filter))
.items().stream()
.collect(Collectors.toList());
```
## Batch Operations
### Batch Get
```java
import software.amazon.awssdk.enhanced.dynamodb.model.*;
List<Key> keys = customerIds.stream()
.map(id -> Key.builder().partitionValue(id).build())
.collect(Collectors.toList());
ReadBatch.Builder<Customer> batchBuilder = ReadBatch.builder(Customer.class)
.mappedTableResource(table);
keys.forEach(batchBuilder::addGetItem);
BatchGetResultPageIterable result = enhancedClient.batchGetItem(r ->
r.addReadBatch(batchBuilder.build()));
List<Customer> customers = result.resultsForTable(table).stream()
.collect(Collectors.toList());
```
### Batch Write
```java
WriteBatch.Builder<Customer> batchBuilder = WriteBatch.builder(Customer.class)
.mappedTableResource(table);
customers.forEach(batchBuilder::addPutItem);
enhancedClient.batchWriteItem(r -> r.addWriteBatch(batchBuilder.build()));
```
## Transactions
### Transactional Write
```java
enhancedClient.transactWriteItems(r -> r
.addPutItem(customerTable, customer)
.addPutItem(orderTable, order));
```
### Transactional Read
```java
TransactGetItemsEnhancedRequest request = TransactGetItemsEnhancedRequest.builder()
.addGetItem(customerTable, customerKey)
.addGetItem(orderTable, orderKey)
.build();
List<Document> results = enhancedClient.transactGetItems(request);
```
## Spring Boot Integration
### Configuration
```java
@Configuration
public class DynamoDbConfiguration {
@Bean
public DynamoDbClient dynamoDbClient() {
return DynamoDbClient.builder()
.region(Region.US_EAST_1)
.build();
}
@Bean
public DynamoDbEnhancedClient dynamoDbEnhancedClient(DynamoDbClient dynamoDbClient) {
return DynamoDbEnhancedClient.builder()
.dynamoDbClient(dynamoDbClient)
.build();
}
}
```
### Repository Pattern
```java
@Repository
public class CustomerRepository {
private final DynamoDbTable<Customer> customerTable;
public CustomerRepository(DynamoDbEnhancedClient enhancedClient) {
this.customerTable = enhancedClient.table("Customers", TableSchema.fromBean(Customer.class));
}
public void save(Customer customer) {
customerTable.putItem(customer);
}
public Optional<Customer> findById(String customerId) {
Key key = Key.builder().partitionValue(customerId).build();
return Optional.ofNullable(customerTable.getItem(key));
}
}
```
For comprehensive Spring Boot integration patterns, see [Spring Boot Integration Reference](references/spring-boot-integration.md).
## Testing
### Unit Testing with Mocks
```java
@ExtendWith(MockitoExtension.class)
class CustomerServiceTest {
@Mock
private DynamoDbClient dynamoDbClient;
@Mock
private DynamoDbEnhancedClient enhancedClient;
@Mock
private DynamoDbTable<Customer> customerTable;
@InjectMocks
private CustomerService customerService;
@Test
void saveCustomer_ShouldReturnSavedCustomer() {
// Arrange
when(enhancedClient.table(anyString(), any(TableSchema.class)))
.thenReturn(customerTable);
Customer customer = new Customer("123", "John Doe", "john@example.com");
// Act
Customer result = customerService.saveCustomer(customer);
// Assert
assertNotNull(result);
verify(customerTable).putItem(customer);
}
}
```
### Integration Testing with LocalStack
```java
@Testcontainers
@SpringBootTest
class DynamoDbIntegrationTest {
@Container
static LocalStackContainer localstack = new LocalStackContainer(
DockerImageName.parse("localstack/localstack:3.0"))
.withServices(LocalStackContainer.Service.DYNAMODB);
@DynamicPropertySource
static void configureProperties(DynamicPropertyRegistry registry) {
registry.add("aws.endpoint",
() -> localstack.getEndpointOverride(LocalStackContainer.Service.DYNAMODB).toString());
}
@Autowired
private DynamoDbEnhancedClient enhancedClient;
@Test
void testCustomerCRUDOperations() {
// Test implementation
}
}
```
For detailed testing strategies, see [Testing Strategies](references/testing-strategies.md).
## Best Practices
1. **Use Enhanced Client**: Provides type-safe operations with less boilerplate
2. **Design partition keys carefully**: Distribute data evenly across partitions
3. **Use composite keys**: Leverage sort keys for efficient queries
4. **Create GSIs strategically**: Support different access patterns
5. **Use batch operations**: Reduce API calls for multiple items
6. **Implement pagination**: Stream results page by page instead of loading large result sets into memory
7. **Use transactions**: For operations that must be atomic
8. **Avoid scans**: Prefer queries with proper indexes
9. **Handle conditional writes**: Prevent race conditions
10. **Use proper error handling**: Handle exceptions like `ProvisionedThroughputExceeded`
## Common Patterns
### Conditional Operations
```java
PutItemEnhancedRequest request = PutItemEnhancedRequest.builder(table)
.item(customer)
.conditionExpression("attribute_not_exists(customerId)")
.build();
table.putItem(request);
```
### Pagination
```java
ScanEnhancedRequest request = ScanEnhancedRequest.builder()
.limit(100)
.build();
PageIterable<Customer> results = table.scan(request);
results.stream().forEach(page -> {
    // Process each page via page.items()
});
```
## Performance Considerations
- Monitor read/write capacity units
- Implement exponential backoff for retries
- Use proper pagination for large datasets
- Consider eventual consistency for reads
- Use `ReturnConsumedCapacity` to monitor capacity usage (see the sketch below)
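To make the last two points concrete, a minimal sketch with the low-level client (table and key are illustrative):
```java
// Strongly consistent read that also reports consumed capacity
GetItemResponse resp = dynamoDb.getItem(GetItemRequest.builder()
        .tableName("Customers")
        .key(Map.of("customerId", AttributeValue.builder().s("123").build()))
        .consistentRead(true)
        .returnConsumedCapacity(ReturnConsumedCapacity.TOTAL)
        .build());
System.out.println("Consumed RCUs: " + resp.consumedCapacity().capacityUnits());
```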
## Related Skills
- `aws-sdk-java-v2-core` - Core AWS SDK patterns
- `spring-data-jpa` - Alternative data access patterns
- `unit-test-service-layer` - Service testing patterns
- `unit-test-wiremock-rest-api` - Testing external APIs
## References
- [AWS DynamoDB Documentation](https://docs.aws.amazon.com/dynamodb/)
- [AWS SDK for Java Documentation](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/)
- [DynamoDB Examples](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/example_code/dynamodb)
- [LocalStack for Testing](https://docs.localstack.cloud/user-guide/aws/)
For detailed implementations, see the references folder:
- [Entity Mapping Reference](references/entity-mapping.md)
- [Advanced Operations Reference](references/advanced-operations.md)
- [Spring Boot Integration Reference](references/spring-boot-integration.md)
- [Testing Strategies](references/testing-strategies.md)

View File

@@ -0,0 +1,195 @@
# Advanced Operations Reference
This document covers advanced DynamoDB operations and patterns.
## Query Operations
### Key Conditions
#### QueryConditional.keyEqualTo()
```java
QueryConditional equalTo = QueryConditional
.keyEqualTo(Key.builder()
.partitionValue("customer123")
.build());
```
#### QueryConditional.sortBetween()
```java
QueryConditional between = QueryConditional
.sortBetween(
Key.builder().partitionValue("customer123").sortValue("2023-01-01").build(),
Key.builder().partitionValue("customer123").sortValue("2023-12-31").build());
```
#### QueryConditional.sortBeginsWith()
```java
QueryConditional beginsWith = QueryConditional
        .sortBeginsWith(Key.builder()
.partitionValue("customer123")
.sortValue("2023-")
.build());
```
### Filter Expressions
```java
Expression filter = Expression.builder()
.expression("points >= :minPoints AND status = :status")
.putExpressionName("#p", "points")
.putExpressionName("#s", "status")
.putExpressionValue(":minPoints", AttributeValue.builder().n("1000").build())
.putExpressionValue(":status", AttributeValue.builder().s("ACTIVE").build())
.build();
```
### Projection Expressions
```java
// With the enhanced client, set projections directly on the request
ScanEnhancedRequest request = ScanEnhancedRequest.builder()
        .attributesToProject("customerId", "name", "email")
        .build();
```
## Scan Operations
### Pagination
```java
ScanEnhancedRequest request = ScanEnhancedRequest.builder()
.limit(100)
.build();
PageIterable<Customer> results = table.scan(request);
results.stream().forEach(page -> {
    // Process each page of results via page.items()
});
```
### Conditional Scan
```java
Expression filter = Expression.builder()
.expression("active = :active")
.putExpressionValue(":active", AttributeValue.builder().bool(true).build())
.build();
return table.scan(r -> r
.filterExpression(filter)
.limit(50))
.items().stream()
.collect(Collectors.toList());
```
## Batch Operations
### Batch Get with Unprocessed Keys
```java
List<Key> keys = customerIds.stream()
.map(id -> Key.builder().partitionValue(id).build())
.collect(Collectors.toList());
ReadBatch.Builder<Customer> batchBuilder = ReadBatch.builder(Customer.class)
.mappedTableResource(table);
keys.forEach(batchBuilder::addGetItem);
BatchGetResultPageIterable result = enhancedClient.batchGetItem(r ->
r.addReadBatch(batchBuilder.build()));
// Handle unprocessed keys
result.stream()
        .flatMap(page -> page.unprocessedKeysForTable(table).stream())
        .forEach(key -> {
            // Retry logic for unprocessed keys
        });
```
### Batch Write with Different Operations
```java
WriteBatch.Builder<Customer> batchBuilder = WriteBatch.builder(Customer.class)
.mappedTableResource(table);
batchBuilder.addPutItem(customer1);
batchBuilder.addDeleteItem(customer2);
batchBuilder.addPutItem(customer3);
enhancedClient.batchWriteItem(r -> r.addWriteBatch(batchBuilder.build()));
```
## Transactions
### Conditional Writes
```java
PutItemEnhancedRequest putRequest = PutItemEnhancedRequest.builder(table)
.item(customer)
.conditionExpression("attribute_not_exists(customerId)")
.build();
table.putItem(putRequest);
```
### Multiple Table Operations
```java
TransactWriteItemsEnhancedRequest request = TransactWriteItemsEnhancedRequest.builder()
.addPutItem(customerTable, customer)
.addPutItem(orderTable, order)
.addUpdateItem(productTable, product)
.addDeleteItem(cartTable, cartKey)
.build();
enhancedClient.transactWriteItems(request);
```
## Conditional Operations
### Condition Expressions
```java
// Check if attribute exists
Expression notExists = Expression.builder()
        .expression("attribute_not_exists(customerId)")
        .build();

// Check attribute values
Expression pointsCheck = Expression.builder()
        .expression("points > :currentPoints")
        .putExpressionValue(":currentPoints", AttributeValue.builder().n("500").build())
        .build();

// Multiple conditions
Expression combined = Expression.builder()
        .expression("points > :min AND active = :active")
        .putExpressionValue(":min", AttributeValue.builder().n("100").build())
        .putExpressionValue(":active", AttributeValue.builder().bool(true).build())
        .build();
```
## Error Handling
### Handling DynamoDB Exceptions
```java
try {
    table.putItem(customer);
} catch (ProvisionedThroughputExceededException e) {
    // Throttled: back off and retry (see below)
} catch (TransactionCanceledException e) {
    // Handle transaction cancellation
} catch (ConditionalCheckFailedException e) {
    // Handle conditional check failure
} catch (ResourceNotFoundException e) {
    // Handle table not found
} catch (DynamoDbException e) {
    // Handle other DynamoDB exceptions
}
```
### Exponential Backoff for Retry
```java
int maxRetries = 3;
long baseDelay = 1000L; // 1 second
for (int attempt = 0; attempt < maxRetries; attempt++) {
    try {
        operation();
        break;
    } catch (ProvisionedThroughputExceededException e) {
        if (attempt == maxRetries - 1) {
            throw e; // Retries exhausted
        }
        long delay = baseDelay * (1L << attempt); // 1s, 2s, 4s, ...
        try {
            Thread.sleep(delay);
        } catch (InterruptedException ie) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Retry interrupted", ie);
        }
    }
}
```

View File

@@ -0,0 +1,120 @@
# Entity Mapping Reference
This document provides detailed information about entity mapping in DynamoDB Enhanced Client.
## @DynamoDbBean Annotation
The `@DynamoDbBean` annotation marks a class as a DynamoDB entity:
```java
@DynamoDbBean
public class Customer {
// Class implementation
}
```
## Field Annotations
### @DynamoDbPartitionKey
Marks a field as the partition key:
```java
@DynamoDbPartitionKey
public String getCustomerId() {
return customerId;
}
```
### @DynamoDbSortKey
Marks a field as the sort key (used with composite keys):
```java
@DynamoDbSortKey
@DynamoDbAttribute("order_id")
public String getOrderId() {
return orderId;
}
```
### @DynamoDbAttribute
Maps a field to a DynamoDB attribute with custom name:
```java
@DynamoDbAttribute("customer_name")
public String getName() {
return name;
}
```
### @DynamoDbSecondaryPartitionKey
Marks a field as a partition key for a Global Secondary Index:
```java
@DynamoDbSecondaryPartitionKey(indexNames = "category-index")
public String getCategory() {
return category;
}
```
### @DynamoDbSecondarySortKey
Marks a field as a sort key for a Global or Local Secondary Index:
```java
@DynamoDbSecondarySortKey(indexNames = "category-index")
public BigDecimal getPrice() {
return price;
}
```
### @DynamoDbConvertedBy
Custom attribute conversion:
```java
@DynamoDbConvertedBy(LocalDateTimeConverter.class)
public LocalDateTime getCreatedAt() {
return createdAt;
}
```
## Supported Data Types
The enhanced client automatically handles the following data types:
- String → S (String)
- Integer, Long → N (Number)
- BigDecimal → N (Number)
- Boolean → BOOL
- LocalDateTime → S (ISO-8601 format)
- LocalDate → S (ISO-8601 format)
- UUID → S (String)
- Enum → S (String representation)
- Custom types with converters
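A hedged sketch of a bean combining several of these types (class and attribute names are illustrative):
```java
@DynamoDbBean
public class Product {
    private String productId;     // stored as S
    private Long stock;           // stored as N
    private BigDecimal price;     // stored as N
    private Boolean active;       // stored as BOOL
    private LocalDate launchDate; // stored as S (ISO-8601)

    @DynamoDbPartitionKey
    public String getProductId() { return productId; }
    // Remaining getters and setters omitted
}
```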
## Custom Converters
Create custom converters for complex data types:
```java
public class LocalDateTimeConverter implements AttributeConverter<LocalDateTime> {
    @Override
    public AttributeValue transformFrom(LocalDateTime input) {
        return AttributeValue.builder().s(input.toString()).build();
    }

    @Override
    public LocalDateTime transformTo(AttributeValue input) {
        return LocalDateTime.parse(input.s());
    }

    @Override
    public EnhancedType<LocalDateTime> type() {
        return EnhancedType.of(LocalDateTime.class);
    }

    @Override
    public AttributeValueType attributeValueType() {
        return AttributeValueType.S;
    }
}
```

View File

@@ -0,0 +1,377 @@
# Spring Boot Integration Reference
This document provides detailed information about integrating DynamoDB with Spring Boot applications.
## Configuration
### Basic Configuration
```java
@Configuration
public class DynamoDbConfiguration {
@Bean
@Profile("local")
public DynamoDbClient dynamoDbClient() {
return DynamoDbClient.builder()
.region(Region.US_EAST_1)
.build();
}
@Bean
@Profile("prod")
public DynamoDbClient dynamoDbClientProd(
@Value("${aws.region}") String region,
@Value("${aws.accessKeyId}") String accessKeyId,
@Value("${aws.secretAccessKey}") String secretAccessKey) {
return DynamoDbClient.builder()
.region(Region.of(region))
.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create(accessKeyId, secretAccessKey)))
.build();
}
@Bean
public DynamoDbEnhancedClient dynamoDbEnhancedClient(DynamoDbClient dynamoDbClient) {
return DynamoDbEnhancedClient.builder()
.dynamoDbClient(dynamoDbClient)
.build();
}
}
```
### Properties Configuration
`application-local.properties`:
```properties
aws.region=us-east-1
```
`application-prod.properties`:
```properties
aws.region=us-east-1
aws.accessKeyId=${AWS_ACCESS_KEY_ID}
aws.secretAccessKey=${AWS_SECRET_ACCESS_KEY}
```
## Repository Pattern Implementation
### Base Repository Interface
```java
public interface DynamoDbRepository<T> {
    void save(T entity);
    Optional<T> findById(String partitionKey);
    Optional<T> findById(String partitionKey, String sortKey);
    void delete(String partitionKey);
    void delete(String partitionKey, String sortKey);
    List<T> findAll();
    List<T> findAll(int limit);
    boolean existsById(String partitionKey);
    boolean existsById(String partitionKey, String sortKey);
}
public interface CustomerRepository extends DynamoDbRepository<Customer> {
List<Customer> findByEmail(String email);
List<Customer> findByPointsGreaterThan(Integer minPoints);
}
```
### Generic Repository Implementation
```java
// Spring cannot component-scan a generic class that needs Class<T> and a table
// name, so register concrete instances via @Bean methods instead of @Repository
public class GenericDynamoDbRepository<T> implements DynamoDbRepository<T> {
    private final DynamoDbTable<T> table;

    public GenericDynamoDbRepository(DynamoDbEnhancedClient enhancedClient,
                                     Class<T> entityClass,
                                     String tableName) {
        this.table = enhancedClient.table(tableName, TableSchema.fromBean(entityClass));
    }

    @Override
    public void save(T entity) {
        table.putItem(entity);
    }

    @Override
    public Optional<T> findById(String partitionKey) {
        return Optional.ofNullable(table.getItem(
                Key.builder().partitionValue(partitionKey).build()));
    }

    @Override
    public Optional<T> findById(String partitionKey, String sortKey) {
        return Optional.ofNullable(table.getItem(Key.builder()
                .partitionValue(partitionKey)
                .sortValue(sortKey)
                .build()));
    }

    @Override
    public void delete(String partitionKey) {
        table.deleteItem(Key.builder().partitionValue(partitionKey).build());
    }

    @Override
    public void delete(String partitionKey, String sortKey) {
        table.deleteItem(Key.builder()
                .partitionValue(partitionKey)
                .sortValue(sortKey)
                .build());
    }

    @Override
    public List<T> findAll() {
        return table.scan().items().stream().collect(Collectors.toList());
    }

    @Override
    public List<T> findAll(int limit) {
        return table.scan(ScanEnhancedRequest.builder().limit(limit).build())
                .items().stream()
                .collect(Collectors.toList());
    }

    @Override
    public boolean existsById(String partitionKey) {
        return findById(partitionKey).isPresent();
    }

    @Override
    public boolean existsById(String partitionKey, String sortKey) {
        return findById(partitionKey, sortKey).isPresent();
    }
}
```
### Specific Repository Implementation
```java
@Repository
public class CustomerRepositoryImpl extends GenericDynamoDbRepository<Customer>
        implements CustomerRepository {
    private final DynamoDbTable<Customer> customerTable;

    public CustomerRepositoryImpl(DynamoDbEnhancedClient enhancedClient) {
        super(enhancedClient, Customer.class, "Customers");
        this.customerTable = enhancedClient.table(
                "Customers",
                TableSchema.fromBean(Customer.class));
    }
@Override
public List<Customer> findByEmail(String email) {
Expression filter = Expression.builder()
.expression("email = :email")
.putExpressionValue(":email", AttributeValue.builder().s(email).build())
.build();
return customerTable.scan(r -> r.filterExpression(filter))
.items().stream()
.collect(Collectors.toList());
}
@Override
public List<Customer> findByPointsGreaterThan(Integer minPoints) {
Expression filter = Expression.builder()
.expression("points >= :minPoints")
.putExpressionValue(":minPoints", AttributeValue.builder().n(minPoints.toString()).build())
.build();
return customerTable.scan(r -> r.filterExpression(filter))
.items().stream()
.collect(Collectors.toList());
}
}
```
## Service Layer Implementation
### Service with Transaction Management
```java
@Service
@Transactional
public class CustomerService {
private final CustomerRepository customerRepository;
private final OrderRepository orderRepository;
private final DynamoDbEnhancedClient enhancedClient;
public CustomerService(CustomerRepository customerRepository,
OrderRepository orderRepository,
DynamoDbEnhancedClient enhancedClient) {
this.customerRepository = customerRepository;
this.orderRepository = orderRepository;
this.enhancedClient = enhancedClient;
}
public void createCustomerWithOrder(Customer customer, Order order) {
// Use transaction for atomic operation
enhancedClient.transactWriteItems(r -> r
.addPutItem(getCustomerTable(), customer)
.addPutItem(getOrderTable(), order));
}
private DynamoDbTable<Customer> getCustomerTable() {
return enhancedClient.table("Customers", TableSchema.fromBean(Customer.class));
}
private DynamoDbTable<Order> getOrderTable() {
return enhancedClient.table("Orders", TableSchema.fromBean(Order.class));
}
}
```
### Async Operations
```java
@Service
public class AsyncCustomerService {
private final DynamoDbEnhancedClient enhancedClient;
public CompletableFuture<Void> saveCustomerAsync(Customer customer) {
return CompletableFuture.runAsync(() -> {
DynamoDbTable<Customer> table = enhancedClient.table(
"Customers",
TableSchema.fromBean(Customer.class));
table.putItem(customer);
});
}
public CompletableFuture<List<Customer>> findCustomersByPointsAsync(Integer minPoints) {
return CompletableFuture.supplyAsync(() -> {
Expression filter = Expression.builder()
.expression("points >= :minPoints")
.putExpressionValue(":minPoints", AttributeValue.builder().n(minPoints.toString()).build())
.build();
DynamoDbTable<Customer> table = enhancedClient.table(
"Customers",
TableSchema.fromBean(Customer.class));
return table.scan(r -> r.filterExpression(filter))
.items().stream()
.collect(Collectors.toList());
});
}
}
```
## Testing with LocalStack
### Test Configuration
```java
@TestConfiguration
public class DynamoDbTestConfig {
    @Bean
    public DynamoDbEnhancedClient dynamoDbEnhancedClient(DynamoDbClient dynamoDbClient) {
        return DynamoDbEnhancedClient.builder()
                .dynamoDbClient(dynamoDbClient)
                .build();
    }
}

@SpringBootTest
@Import({LocalStackDynamoDbConfig.class, DynamoDbTestConfig.class})
public class CustomerRepositoryIntegrationTest {
@Autowired
private DynamoDbEnhancedClient enhancedClient;
@BeforeEach
void setUp() {
// Clean up test data
clearTestData();
}
@Test
void testCustomerOperations() {
// Test implementation
}
}
```
### LocalStack Container Setup
```java
@TestConfiguration
public class LocalStackDynamoDbConfig {
    // Started once per test JVM; Testcontainers cleans up via its Ryuk sidecar
    static final LocalStackContainer localstack = new LocalStackContainer(
            DockerImageName.parse("localstack/localstack:3.0"))
            .withServices(LocalStackContainer.Service.DYNAMODB);

    static {
        localstack.start();
    }

    @Bean
    public DynamoDbClient dynamoDbClient() {
        return DynamoDbClient.builder()
                .region(Region.of(localstack.getRegion()))
                .endpointOverride(localstack.getEndpointOverride(LocalStackContainer.Service.DYNAMODB))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create(localstack.getAccessKey(), localstack.getSecretKey())))
                .build();
    }
}
```
## Health Check Integration
### Custom Health Indicator
```java
@Component
public class DynamoDbHealthIndicator implements HealthIndicator {
private final DynamoDbClient dynamoDbClient;
public DynamoDbHealthIndicator(DynamoDbClient dynamoDbClient) {
this.dynamoDbClient = dynamoDbClient;
}
@Override
public Health health() {
try {
dynamoDbClient.listTables();
return Health.up()
.withDetail("region", dynamoDbClient.serviceClientConfiguration().region())
.build();
} catch (Exception e) {
return Health.down()
.withException(e)
.build();
}
}
}
```
## Metrics Collection
### Micrometer Integration
```java
@Component
public class DynamoDbMetricsCollector {
    private final MeterRegistry meterRegistry;

    public DynamoDbMetricsCollector(MeterRegistry meterRegistry) {
        this.meterRegistry = meterRegistry;
    }

    @EventListener
    public void handleDynamoDbOperation(DynamoDbOperationEvent event) {
        // Record the duration carried by the event, tagged by operation and table
        Timer.builder("dynamodb.operation")
                .tag("operation", event.getOperation())
                .tag("table", event.getTable())
                .register(meterRegistry)
                .record(Duration.ofMillis(event.getDuration()));
    }
}
public class DynamoDbOperationEvent {
private String operation;
private String table;
private long duration;
// Getters and setters
}
```

View File

@@ -0,0 +1,407 @@
# Testing Strategies for DynamoDB
This document provides comprehensive testing strategies for DynamoDB applications using the AWS SDK for Java 2.x.
## Unit Testing with Mocks
### Mocking DynamoDbClient
```java
@ExtendWith(MockitoExtension.class)
class CustomerServiceTest {
@Mock
private DynamoDbClient dynamoDbClient;
@Mock
private DynamoDbEnhancedClient enhancedClient;
@Mock
private DynamoDbTable<Customer> customerTable;
@InjectMocks
private CustomerService customerService;
@Test
void saveCustomer_ShouldReturnSavedCustomer() {
// Arrange
Customer customer = new Customer("123", "John Doe", "john@example.com");
when(enhancedClient.table(anyString(), any(TableSchema.class)))
        .thenReturn(customerTable);
// putItem returns void, so it needs no stubbing
// Act
Customer result = customerService.saveCustomer(customer);
// Assert
assertNotNull(result);
assertEquals("123", result.getCustomerId());
verify(customerTable).putItem(customer);
}
@Test
void getCustomer_NotFound_ShouldReturnEmpty() {
// Arrange
when(enhancedClient.table(anyString(), any(TableSchema.class)))
.thenReturn(customerTable);
when(customerTable.getItem(any(Key.class)))
.thenReturn(null);
// Act
Optional<Customer> result = customerService.getCustomer("123");
// Assert
assertFalse(result.isPresent());
verify(customerTable).getItem(any(Key.class));
}
}
```
### Testing Query Operations
```java
@Test
void queryCustomersByStatus_ShouldReturnMatchingCustomers() {
// Arrange
List<Customer> mockCustomers = List.of(
new Customer("1", "Alice", "alice@example.com"),
new Customer("2", "Bob", "bob@example.com")
);
DynamoDbTable<Customer> mockTable = mock(DynamoDbTable.class);
DynamoDbIndex<Customer> mockIndex = mock(DynamoDbIndex.class);
Page<Customer> mockPage = Page.create(mockCustomers);

when(enhancedClient.table(anyString(), any(TableSchema.class)))
        .thenReturn(mockTable);
when(mockTable.index("status-index"))
        .thenReturn(mockIndex);
// DynamoDbIndex.query returns SdkIterable<Page<T>>; stub it with a single page
when(mockIndex.query(any(QueryEnhancedRequest.class)))
        .thenReturn(() -> List.of(mockPage).iterator());
// Act
List<Customer> result = customerService.findByStatus("ACTIVE");
// Assert
assertEquals(2, result.size());
verify(mockIndex).query(any(QueryEnhancedRequest.class));
}
```
## Integration Testing with Testcontainers
### LocalStack Setup
```java
@Testcontainers
@SpringBootTest
@AutoConfigureMockMvc
class DynamoDbIntegrationTest {
@Container
static LocalStackContainer localstack = new LocalStackContainer(
DockerImageName.parse("localstack/localstack:3.0"))
.withServices(LocalStackContainer.Service.DYNAMODB);
@DynamicPropertySource
static void configureProperties(DynamicPropertyRegistry registry) {
registry.add("aws.region", () -> Region.US_EAST_1.toString());
registry.add("aws.accessKeyId", () -> localstack.getAccessKey());
registry.add("aws.secretAccessKey", () -> localstack.getSecretKey());
registry.add("aws.endpoint",
() -> localstack.getEndpointOverride(LocalStackContainer.Service.DYNAMODB).toString());
}
@Autowired
private DynamoDbEnhancedClient enhancedClient;
@BeforeEach
void setup() {
createTestTable();
}
@Test
void testCustomerCRUDOperations() {
// Test create
Customer customer = new Customer("test-123", "Test User", "test@example.com");
enhancedClient.table("Customers", TableSchema.fromBean(Customer.class))
.putItem(customer);
// Test read
Customer retrieved = enhancedClient.table("Customers", TableSchema.fromBean(Customer.class))
.getItem(Key.builder().partitionValue("test-123").build());
assertNotNull(retrieved);
assertEquals("Test User", retrieved.getName());
// Test update
customer.setPoints(1000);
enhancedClient.table("Customers", TableSchema.fromBean(Customer.class))
.putItem(customer);
// Test delete
enhancedClient.table("Customers", TableSchema.fromBean(Customer.class))
.deleteItem(Key.builder().partitionValue("test-123").build());
}
private void createTestTable() {
DynamoDbClient client = DynamoDbClient.builder()
.region(Region.US_EAST_1)
.endpointOverride(localstack.getEndpointOverride(LocalStackContainer.Service.DYNAMODB))
.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create(localstack.getAccessKey(), localstack.getSecretKey())))
.build();
CreateTableRequest request = CreateTableRequest.builder()
.tableName("Customers")
.keySchema(KeySchemaElement.builder()
.attributeName("customerId")
.keyType(KeyType.HASH)
.build())
.attributeDefinitions(AttributeDefinition.builder()
.attributeName("customerId")
.attributeType(ScalarAttributeType.S)
.build())
.provisionedThroughput(ProvisionedThroughput.builder()
.readCapacityUnits(5L)
.writeCapacityUnits(5L)
.build())
.build();
client.createTable(request);
waitForTableActive(client, "Customers");
}
private void waitForTableActive(DynamoDbClient client, String tableName) {
    try (DynamoDbWaiter waiter = client.waiter()) {
        // Blocks until the table is ACTIVE or the default wait times out
        waiter.waitUntilTableExists(r -> r.tableName(tableName));
    }
}
}
```
### Testcontainers with PostgreSQL
```java
@SpringBootTest
@Testcontainers
@AutoConfigureDataJpa
class CustomerRepositoryTest {
@Container
static PostgreSQLContainer<?> postgres = new PostgreSQLContainer<>("postgres:15-alpine")
.withDatabaseName("testdb")
.withUsername("test")
.withPassword("test");
@DynamicPropertySource
static void postgresProperties(DynamicPropertyRegistry registry) {
registry.add("spring.datasource.url", postgres::getJdbcUrl);
registry.add("spring.datasource.username", postgres::getUsername);
registry.add("spring.datasource.password", postgres::getPassword);
}
@Autowired
private CustomerRepository customerRepository;
@Test
void testRepositoryWithRealDatabase() {
// Test with real database
Customer customer = new Customer("123", "Test User", "test@example.com");
customerRepository.save(customer);
Customer retrieved = customerRepository.findById("123").orElse(null);
assertNotNull(retrieved);
assertEquals("Test User", retrieved.getName());
}
}
```
## Performance Testing
### Load Testing with Gatling
```java
class CustomerSimulation extends Simulation {
HttpProtocolBuilder httpProtocolBuilder = http
.baseUrl("http://localhost:8080")
.acceptHeader("application/json");
ScenarioBuilder scn = scenario("Customer Operations")
.exec(http("create_customer")
.post("/api/customers")
.body(StringBody(
"""{
"customerId": "test-123",
"name": "Test User",
"email": "test@example.com"
}"""))
.asJson()
.check(status().is(201)))
.exec(http("get_customer")
.get("/api/customers/test-123")
.check(status().is(200)));
{
setUp(
scn.injectOpen(
rampUsersPerSec(10).to(100).during(60),
constantUsersPerSec(100).during(120)
)
).protocols(httpProtocolBuilder);
}
}
```
### Microbenchmark Testing
```java
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS)
@Fork(1)
@State(Scope.Benchmark)
public class DynamoDbPerformanceBenchmark {
private DynamoDbEnhancedClient enhancedClient;
private DynamoDbTable<Customer> customerTable;
private Customer testCustomer;
@Setup
public void setup() {
enhancedClient = DynamoDbEnhancedClient.builder()
.dynamoDbClient(DynamoDbClient.builder().build())
.build();
customerTable = enhancedClient.table("Customers", TableSchema.fromBean(Customer.class));
testCustomer = new Customer("benchmark-123", "Benchmark User", "benchmark@example.com");
}
@Benchmark
public void testPutItem() {
customerTable.putItem(testCustomer);
}
@Benchmark
public void testGetItem() {
customerTable.getItem(Key.builder().partitionValue("benchmark-123").build());
}
@Benchmark
public void testScan() {
    // Full table scan materialized into a list
    customerTable.scan().items().stream().collect(Collectors.toList());
}
}
```
## Property-Based Testing
### Using jqwik
```java
@Property
@Report(Reporting.GENERATED)
void customerSerializationShouldBeConsistent(
@ForAll("customers") Customer customer
) {
// When
String serialized = serializeCustomer(customer);
Customer deserialized = deserializeCustomer(serialized);
// Then
assertEquals(customer.getCustomerId(), deserialized.getCustomerId());
assertEquals(customer.getName(), deserialized.getName());
assertEquals(customer.getEmail(), deserialized.getEmail());
}
@Provide
Arbitrary<Customer> customers() {
    Arbitrary<String> ids = Arbitraries.of("customer-", "user-", "client-")
            .flatMap(prefix -> Arbitraries.integers().between(1000, 9999)
                    .map(n -> prefix + n));
    Arbitrary<String> names = Arbitraries.strings().alpha().ofLength(10);
    Arbitrary<String> emails = Arbitraries.strings().alpha().ofMinLength(3).ofMaxLength(12)
            .map(local -> local + "@example.com");
    return Combinators.combine(ids, names, emails).as(Customer::new);
}
```
## Test Data Management
### Test Data Factory
```java
@Component
public class TestDataFactory {
private final DynamoDbEnhancedClient enhancedClient;
@Autowired
public TestDataFactory(DynamoDbEnhancedClient enhancedClient) {
this.enhancedClient = enhancedClient;
}
public Customer createTestCustomer(String id) {
Customer customer = new Customer(
id != null ? id : UUID.randomUUID().toString(),
"Test User",
"test@example.com"
);
customer.setPoints(1000);
customer.setCreatedAt(LocalDateTime.now());
enhancedClient.table("Customers", TableSchema.fromBean(Customer.class))
.putItem(customer);
return customer;
}
public void cleanupTestData() {
// Implementation to clean up test data
}
}
```
### Test Database Configuration
```java
@TestConfiguration
public class TestDataConfig {
@Bean
public TestDataCleaner testDataCleaner() {
return new TestDataCleaner();
}
}
@Component
public class TestDataCleaner {
    private final DynamoDbClient dynamoDbClient;

    public TestDataCleaner(DynamoDbClient dynamoDbClient) {
        this.dynamoDbClient = dynamoDbClient;
    }

    @EventListener(ApplicationReadyEvent.class)
    public void cleanup() {
        // Clean up leftover test data once the context is ready
    }
}
```

View File

@@ -0,0 +1,416 @@
---
name: aws-sdk-java-v2-kms
description: AWS Key Management Service (KMS) patterns using AWS SDK for Java 2.x. Use when creating/managing encryption keys, encrypting/decrypting data, generating data keys, digital signing, key rotation, or integrating encryption into Spring Boot applications.
category: aws
tags: [aws, kms, java, sdk, encryption, security]
version: 1.1.0
allowed-tools: Read, Write, Bash, WebFetch
---
# AWS SDK for Java 2.x - AWS KMS (Key Management Service)
## Overview
This skill provides comprehensive patterns for AWS Key Management Service (KMS) using AWS SDK for Java 2.x. Focus on implementing secure encryption solutions with proper key management, envelope encryption, and Spring Boot integration patterns.
## When to Use
Use this skill when:
- Creating and managing symmetric encryption keys for data protection
- Implementing client-side encryption and envelope encryption patterns
- Generating data keys for local data encryption with KMS-managed keys
- Setting up digital signatures and verification with asymmetric keys
- Integrating encryption capabilities into Spring Boot applications
- Implementing secure key lifecycle management
- Setting up key rotation policies and access controls
## Dependencies
### Maven
```xml
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>kms</artifactId>
</dependency>
```
### Gradle
```groovy
implementation 'software.amazon.awssdk:kms:2.x.x'
```
## Client Setup
### Basic Synchronous Client
```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.kms.KmsClient;
KmsClient kmsClient = KmsClient.builder()
.region(Region.US_EAST_1)
.build();
```
### Basic Asynchronous Client
```java
import software.amazon.awssdk.services.kms.KmsAsyncClient;
KmsAsyncClient kmsAsyncClient = KmsAsyncClient.builder()
.region(Region.US_EAST_1)
.build();
```
### Advanced Client Configuration
```java
KmsClient kmsClient = KmsClient.builder()
.region(Region.of(System.getenv("AWS_REGION")))
.credentialsProvider(DefaultCredentialsProvider.create())
.overrideConfiguration(c -> c.retryPolicy(RetryPolicy.builder()
.numRetries(3)
.build()))
.build();
```
## Basic Key Management
### Create Encryption Key
```java
public String createEncryptionKey(KmsClient kmsClient, String description) {
CreateKeyRequest request = CreateKeyRequest.builder()
.description(description)
.keyUsage(KeyUsageType.ENCRYPT_DECRYPT)
.build();
CreateKeyResponse response = kmsClient.createKey(request);
return response.keyMetadata().keyId();
}
```
### Describe Key
```java
public KeyMetadata getKeyMetadata(KmsClient kmsClient, String keyId) {
DescribeKeyRequest request = DescribeKeyRequest.builder()
.keyId(keyId)
.build();
return kmsClient.describeKey(request).keyMetadata();
}
```
### Enable/Disable Key
```java
public void toggleKeyState(KmsClient kmsClient, String keyId, boolean enable) {
if (enable) {
kmsClient.enableKey(EnableKeyRequest.builder().keyId(keyId).build());
} else {
kmsClient.disableKey(DisableKeyRequest.builder().keyId(keyId).build());
}
}
```
## Basic Encryption and Decryption
### Encrypt Data
```java
public String encryptData(KmsClient kmsClient, String keyId, String plaintext) {
SdkBytes plaintextBytes = SdkBytes.fromString(plaintext, StandardCharsets.UTF_8);
EncryptRequest request = EncryptRequest.builder()
.keyId(keyId)
.plaintext(plaintextBytes)
.build();
EncryptResponse response = kmsClient.encrypt(request);
return Base64.getEncoder().encodeToString(
response.ciphertextBlob().asByteArray());
}
```
### Decrypt Data
```java
public String decryptData(KmsClient kmsClient, String ciphertextBase64) {
byte[] ciphertext = Base64.getDecoder().decode(ciphertextBase64);
SdkBytes ciphertextBytes = SdkBytes.fromByteArray(ciphertext);
DecryptRequest request = DecryptRequest.builder()
.ciphertextBlob(ciphertextBytes)
.build();
DecryptResponse response = kmsClient.decrypt(request);
return response.plaintext().asString(StandardCharsets.UTF_8);
}
```
## Envelope Encryption Pattern
### Generate and Use Data Key
```java
public DataKeyResult encryptWithEnvelope(KmsClient kmsClient, String masterKeyId, byte[] data) {
// Generate data key
GenerateDataKeyRequest keyRequest = GenerateDataKeyRequest.builder()
.keyId(masterKeyId)
.keySpec(DataKeySpec.AES_256)
.build();
GenerateDataKeyResponse keyResponse = kmsClient.generateDataKey(keyRequest);
// Encrypt data locally with the plaintext data key
byte[] plaintextKey = keyResponse.plaintext().asByteArray();
byte[] encryptedData = encryptWithAES(data, plaintextKey);
// Zero our working copy (asByteArray() returns a copy)
Arrays.fill(plaintextKey, (byte) 0);
return new DataKeyResult(
encryptedData,
keyResponse.ciphertextBlob().asByteArray());
}
public byte[] decryptWithEnvelope(KmsClient kmsClient,
DataKeyResult encryptedEnvelope) {
// Decrypt data key
DecryptRequest keyDecryptRequest = DecryptRequest.builder()
.ciphertextBlob(SdkBytes.fromByteArray(
encryptedEnvelope.encryptedKey()))
.build();
DecryptResponse keyDecryptResponse = kmsClient.decrypt(keyDecryptRequest);
// Decrypt data locally with the recovered data key
byte[] plaintextKey = keyDecryptResponse.plaintext().asByteArray();
byte[] decryptedData = decryptWithAES(
        encryptedEnvelope.encryptedData(),
        plaintextKey);
// Zero our working copy of the plaintext key
Arrays.fill(plaintextKey, (byte) 0);
return decryptedData;
}
```
## Digital Signatures
### Create Signing Key and Sign Data
```java
public String createAndSignData(KmsClient kmsClient, String description, String message) {
// Create signing key
CreateKeyRequest keyRequest = CreateKeyRequest.builder()
.description(description)
.keySpec(KeySpec.RSA_2048)
.keyUsage(KeyUsageType.SIGN_VERIFY)
.build();
CreateKeyResponse keyResponse = kmsClient.createKey(keyRequest);
String keyId = keyResponse.keyMetadata().keyId();
// Sign data
SignRequest signRequest = SignRequest.builder()
.keyId(keyId)
.message(SdkBytes.fromString(message, StandardCharsets.UTF_8))
.signingAlgorithm(SigningAlgorithmSpec.RSASSA_PSS_SHA_256)
.build();
SignResponse signResponse = kmsClient.sign(signRequest);
return Base64.getEncoder().encodeToString(
signResponse.signature().asByteArray());
}
```
### Verify Signature
```java
public boolean verifySignature(KmsClient kmsClient,
String keyId,
String message,
String signatureBase64) {
byte[] signature = Base64.getDecoder().decode(signatureBase64);
VerifyRequest verifyRequest = VerifyRequest.builder()
.keyId(keyId)
.message(SdkBytes.fromString(message, StandardCharsets.UTF_8))
.signature(SdkBytes.fromByteArray(signature))
.signingAlgorithm(SigningAlgorithmSpec.RSASSA_PSS_SHA_256)
.build();
VerifyResponse verifyResponse = kmsClient.verify(verifyRequest);
return verifyResponse.signatureValid();
}
```
## Spring Boot Integration
### Configuration Class
```java
@Configuration
public class KmsConfiguration {
@Bean
public KmsClient kmsClient() {
return KmsClient.builder()
.region(Region.US_EAST_1)
.build();
}
@Bean
public KmsAsyncClient kmsAsyncClient() {
return KmsAsyncClient.builder()
.region(Region.US_EAST_1)
.build();
}
}
```
### Encryption Service
```java
@Service
@RequiredArgsConstructor
public class KmsEncryptionService {
private final KmsClient kmsClient;
@Value("${kms.encryption-key-id}")
private String keyId;
public String encrypt(String plaintext) {
try {
EncryptRequest request = EncryptRequest.builder()
.keyId(keyId)
.plaintext(SdkBytes.fromString(plaintext, StandardCharsets.UTF_8))
.build();
EncryptResponse response = kmsClient.encrypt(request);
return Base64.getEncoder().encodeToString(
response.ciphertextBlob().asByteArray());
} catch (KmsException e) {
throw new RuntimeException("Encryption failed", e);
}
}
public String decrypt(String ciphertextBase64) {
try {
byte[] ciphertext = Base64.getDecoder().decode(ciphertextBase64);
DecryptRequest request = DecryptRequest.builder()
.ciphertextBlob(SdkBytes.fromByteArray(ciphertext))
.build();
DecryptResponse response = kmsClient.decrypt(request);
return response.plaintext().asString(StandardCharsets.UTF_8);
} catch (KmsException e) {
throw new RuntimeException("Decryption failed", e);
}
}
}
```
## Examples
### Basic Encryption Example
```java
public class BasicEncryptionExample {
public static void main(String[] args) {
KmsClient kmsClient = KmsClient.builder()
.region(Region.US_EAST_1)
.build();
// Create key
String keyId = createEncryptionKey(kmsClient, "Example encryption key");
System.out.println("Created key: " + keyId);
// Encrypt and decrypt
String plaintext = "Hello, World!";
String encrypted = encryptData(kmsClient, keyId, plaintext);
String decrypted = decryptData(kmsClient, encrypted);
System.out.println("Original: " + plaintext);
System.out.println("Decrypted: " + decrypted);
}
}
```
### Envelope Encryption Example
```java
public class EnvelopeEncryptionExample {
public static void main(String[] args) {
KmsClient kmsClient = KmsClient.builder()
.region(Region.US_EAST_1)
.build();
String masterKeyId = "alias/your-master-key";
String largeData = "This is a large amount of data that needs encryption...";
byte[] data = largeData.getBytes(StandardCharsets.UTF_8);
// Encrypt using envelope pattern
DataKeyResult encryptedEnvelope = encryptWithEnvelope(
kmsClient, masterKeyId, data);
// Decrypt
byte[] decryptedData = decryptWithEnvelope(
kmsClient, encryptedEnvelope);
String result = new String(decryptedData, StandardCharsets.UTF_8);
System.out.println("Decrypted: " + result);
}
}
```
## Best Practices
### Security
- **Always use envelope encryption for large data** - Encrypt data locally and only encrypt the data key with KMS
- **Use encryption context** - Add contextual information to track and audit usage
- **Never log sensitive data** - Avoid logging plaintext or encryption keys
- **Implement proper key lifecycle** - Enable automatic rotation and set deletion policies
- **Use separate keys for different purposes** - Don't reuse keys across multiple applications
### Performance
- **Cache encrypted data keys** - Reduce KMS API calls by caching data keys
- **Use async operations** - Leverage async clients for non-blocking I/O
- **Reuse client instances** - Don't create new clients for each operation
- **Implement connection pooling** - Configure proper connection pooling settings
### Error Handling
- **Implement retry logic** - Handle throttling exceptions with exponential backoff
- **Check key states** - Verify key is enabled before performing operations
- **Use circuit breakers** - Prevent cascading failures during KMS outages
- **Log errors comprehensively** - Include KMS error codes and context
## References
For detailed implementation patterns, advanced techniques, and comprehensive examples:
- @references/technical-guide.md - Complete technical implementation patterns
- @references/spring-boot-integration.md - Spring Boot integration patterns
- @references/testing.md - Testing strategies and examples
- @references/best-practices.md - Security and operational best practices
## Related Skills
- @aws-sdk-java-v2-core - Core AWS SDK patterns and configuration
- @aws-sdk-java-v2-dynamodb - DynamoDB integration patterns
- @aws-sdk-java-v2-secrets-manager - Secrets management patterns
- @spring-boot-dependency-injection - Spring dependency injection patterns
## External References
- [AWS KMS Developer Guide](https://docs.aws.amazon.com/kms/latest/developerguide/)
- [AWS SDK for Java 2.x Documentation](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/home.html)
- [KMS Best Practices](https://docs.aws.amazon.com/kms/latest/developerguide/best-practices.html)

View File

@@ -0,0 +1,550 @@
# AWS KMS Best Practices
## Security Best Practices
### Key Management
1. **Use Separate Keys for Different Purposes**
- Create unique keys for different applications or data types
- Avoid reusing keys across multiple purposes
- Use aliases instead of raw key IDs for references
```java
// Good: Create specific keys
String encryptionKey = kms.createKey("Database encryption key");
String signingKey = kms.createSigningKey("Document signing key");
// Bad: Use the same key for everything
```
2. **Enable Automatic Key Rotation**
- Enable automatic key rotation for enhanced security
- Review rotation schedules based on compliance requirements
```java
public void enableKeyRotation(KmsClient kmsClient, String keyId) {
EnableKeyRotationRequest request = EnableKeyRotationRequest.builder()
.keyId(keyId)
.build();
kmsClient.enableKeyRotation(request);
}
```
3. **Implement Key Lifecycle Policies**
- Set key expiration dates based on data retention policies
- Schedule key deletion when no longer needed
- Use key policies to enforce lifecycle rules
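A hedged sketch of scheduling deletion with the mandatory waiting window (7-30 days):
```java
// Deletion is irreversible once the pending window elapses
kmsClient.scheduleKeyDeletion(ScheduleKeyDeletionRequest.builder()
        .keyId(keyId)
        .pendingWindowInDays(30)
        .build());
```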
4. **Use Key Aliases**
- Always use aliases instead of raw key IDs
- Create meaningful aliases following naming conventions
- Regularly review and update aliases
```java
public void createKeyWithAlias(KmsClient kmsClient, String alias, String description) {
// Create key
CreateKeyResponse response = kmsClient.createKey(
CreateKeyRequest.builder()
.description(description)
.build());
// Create alias
CreateAliasRequest aliasRequest = CreateAliasRequest.builder()
.aliasName(alias)
.targetKeyId(response.keyMetadata().keyId())
.build();
kmsClient.createAlias(aliasRequest);
}
```
### Encryption Security
1. **Never Log Plaintext or Encryption Keys**
- Avoid logging sensitive data in any form
- Ensure proper logging configuration to prevent accidental exposure
```java
// Bad: Logging sensitive data
logger.info("Encrypted data: {}", encryptedData);
// Good: Log only metadata
logger.info("Encryption completed for user: {}", userId);
```
2. **Use Encryption Context**
- Always include encryption context for additional security
- Use contextual information to verify data integrity
```java
public Map<String, String> createEncryptionContext(String userId, String dataType) {
return Map.of(
"userId", userId,
"dataType", dataType,
"timestamp", Instant.now().toString()
);
}
```
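The context supplied at encrypt time must be supplied verbatim on decrypt, so volatile values such as timestamps need to be stored alongside the ciphertext. A minimal sketch (key ID assumed):
```java
Map<String, String> context = Map.of("userId", "user-1", "dataType", "profile");

EncryptResponse enc = kmsClient.encrypt(EncryptRequest.builder()
        .keyId(keyId)
        .plaintext(SdkBytes.fromUtf8String("secret"))
        .encryptionContext(context)
        .build());

DecryptResponse dec = kmsClient.decrypt(DecryptRequest.builder()
        .ciphertextBlob(enc.ciphertextBlob())
        .encryptionContext(context) // must match exactly
        .build());
```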
3. **Implement Least Privilege IAM Policies**
- Grant minimal required permissions to KMS keys
- Use IAM policies to restrict access to specific resources
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {"AWS": "arn:aws:iam::123456789012:role/app-role"},
"Action": [
"kms:Encrypt",
"kms:Decrypt",
"kms:DescribeKey"
],
"Resource": "arn:aws:kms:us-east-1:123456789012:key/your-key-id",
"Condition": {
"StringEquals": {
"kms:EncryptionContext:userId": "${aws:userid}"
}
}
}
]
}
```
4. **Clear Sensitive Data from Memory**
- Explicitly clear sensitive data from memory after use
- Use secure memory management practices
```java
public void secureMemoryExample() {
byte[] sensitiveKey = new byte[32];
// ... use the key ...
// Clear sensitive data
Arrays.fill(sensitiveKey, (byte) 0);
}
```
## Performance Best Practices
1. **Cache Data Keys for Envelope Encryption**
- Cache encrypted data keys to avoid repeated KMS calls
- Use appropriate cache eviction policies
- Monitor cache hit rates
```java
public class DataKeyCache {
private final Cache<String, byte[]> keyCache;
public DataKeyCache() {
this.keyCache = Caffeine.newBuilder()
.expireAfterWrite(1, TimeUnit.HOURS)
.maximumSize(1000)
.build();
}
public byte[] getCachedDataKey(String keyId, KmsClient kmsClient) {
return keyCache.get(keyId, k -> {
GenerateDataKeyResponse response = kmsClient.generateDataKey(
GenerateDataKeyRequest.builder()
.keyId(keyId)
.keySpec(DataKeySpec.AES_256)
.build());
return response.ciphertextBlob().asByteArray();
});
}
}
```
2. **Use Async Operations for Non-Blocking I/O**
- Leverage async clients for parallel processing
- Use CompletableFuture for chaining operations
```java
public CompletableFuture<Void> processMultipleAsync(List<String> dataItems) {
List<CompletableFuture<Void>> futures = dataItems.stream()
.map(item -> CompletableFuture.runAsync(() ->
encryptAndStoreItem(item)))
.collect(Collectors.toList());
return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
}
```
3. **Implement Connection Pooling**
- Configure connection pooling for better resource utilization
- Set appropriate pool sizes based on load
```java
public KmsClient createPooledClient() {
return KmsClient.builder()
.region(Region.US_EAST_1)
.httpClientBuilder(ApacheHttpClient.builder()
.maxConnections(100)
.connectionTimeToLive(Duration.ofSeconds(30))
.build())
.build();
}
```
4. **Reuse KMS Client Instances**
- Create and reuse client instances rather than creating new ones
- Use dependency injection for client management
```java
@Service
@RequiredArgsConstructor
public class KmsService {
private final KmsClient kmsClient; // Inject and reuse
public void performOperation() {
// Use the same client instance
kmsClient.someOperation();
}
}
```
## Cost Optimization
1. **Use Envelope Encryption for Large Data**
- Generate data keys for encrypting large datasets
- Only use KMS for encrypting the data key, not the entire dataset
```java
public class EnvelopeEncryption {
private final KmsClient kmsClient;
public byte[] encryptLargeData(byte[] largeData) {
// Generate data key
GenerateDataKeyResponse response = kmsClient.generateDataKey(
GenerateDataKeyRequest.builder()
.keyId("master-key-id")
.keySpec(DataKeySpec.AES_256)
.build());
byte[] encryptedKey = response.ciphertextBlob().asByteArray();
byte[] plaintextKey = response.plaintext().asByteArray();
// Encrypt data with local key
byte[] encryptedData = localEncrypt(largeData, plaintextKey);
// Return both encrypted data and encrypted key
return combine(encryptedKey, encryptedData);
}
}
```
2. **Cache Encrypted Data Keys**
- Cache encrypted data keys to avoid repeated KMS calls
- Use time-based cache expiration
3. **Monitor API Usage**
- Track KMS API calls for billing and optimization
- Set up CloudWatch alarms for unexpected usage
```java
public class KmsUsageMonitor {
private final MeterRegistry meterRegistry;
public void recordEncryption() {
meterRegistry.counter("kms.encryption.count").increment();
meterRegistry.timer("kms.encryption.time").record(() -> {
// Perform encryption
});
}
}
```
4. **Use Data Key Caching Libraries**
- Prefer a proven implementation such as the AWS Encryption SDK's data key caching over a hand-rolled cache
- Bound cache lifetime and per-key usage to limit the blast radius of a compromised data key (see the sketch after this list)
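A hedged sketch with the AWS Encryption SDK for Java (assumes the `com.amazonaws.encryptionsdk` artifact; cache size, age, and use limits are illustrative):
```java
import com.amazonaws.encryptionsdk.AwsCrypto;
import com.amazonaws.encryptionsdk.CryptoResult;
import com.amazonaws.encryptionsdk.caching.CachingCryptoMaterialsManager;
import com.amazonaws.encryptionsdk.caching.LocalCryptoMaterialsCache;
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider;
import java.util.concurrent.TimeUnit;

public class CachedDataKeyEncryption {
    private final AwsCrypto crypto = AwsCrypto.standard();
    private final CachingCryptoMaterialsManager materialsManager;

    public CachedDataKeyEncryption(String keyArn) {
        // Reuse each data key for at most 5 minutes or 100 messages, whichever comes first
        this.materialsManager = CachingCryptoMaterialsManager.newBuilder()
            .withMasterKeyProvider(KmsMasterKeyProvider.builder().buildStrict(keyArn))
            .withCache(new LocalCryptoMaterialsCache(100))
            .withMaxAge(5, TimeUnit.MINUTES)
            .withMessageUseLimit(100)
            .build();
    }

    public byte[] encrypt(byte[] plaintext) {
        CryptoResult<byte[], ?> result = crypto.encryptData(materialsManager, plaintext);
        return result.getResult();
    }
}
```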
## Error Handling Best Practices
1. **Implement Retry Logic for Throttling**
- Add retry logic for throttling exceptions
- Use exponential backoff for retries
```java
public class KmsRetryHandler {
private static final int MAX_RETRIES = 3;
private static final long INITIAL_DELAY = 1000; // 1 second
public <T> T executeWithRetry(Supplier<T> operation) {
int attempt = 0;
while (attempt < MAX_RETRIES) {
try {
return operation.get();
} catch (KmsException e) {
if (!isRetryable(e) || attempt == MAX_RETRIES - 1) {
throw e;
}
attempt++;
try {
Thread.sleep(INITIAL_DELAY * (long) Math.pow(2, attempt));
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new RuntimeException("Retry interrupted", ie);
}
}
}
throw new IllegalStateException("Should not reach here");
}
private boolean isRetryable(KmsException e) {
return "ThrottlingException".equals(e.awsErrorDetails().errorCode());
}
}
```
2. **Handle Key State Errors Gracefully**
- Check key state before performing operations
- Handle key states like PendingDeletion, Disabled, etc.
```java
public void performOperationWithKeyStateCheck(KmsClient kmsClient, String keyId) {
KeyMetadata metadata = describeKey(kmsClient, keyId);
switch (metadata.keyState()) {
case ENABLED:
// Perform operation
break;
case DISABLED:
throw new IllegalStateException("Key is disabled");
case PENDING_DELETION:
throw new IllegalStateException("Key is scheduled for deletion");
default:
throw new IllegalStateException("Unknown key state: " + metadata.keyState());
}
}
```
3. **Log KMS-Specific Error Codes**
- Implement comprehensive error logging
- Map KMS error codes to meaningful application errors
```java
public class KmsErrorHandler {
public String mapKmsErrorToAppError(KmsException e) {
String errorCode = e.awsErrorDetails().errorCode();
switch (errorCode) {
case "NotFoundException":
return "Key not found";
case "DisabledException":
return "Key is disabled";
case "AccessDeniedException":
return "Access denied";
case "InvalidKeyUsageException":
return "Invalid key usage";
default:
return "KMS error: " + errorCode;
}
}
}
```
4. **Implement Circuit Breakers**
- Use circuit breakers to handle KMS unavailability
- Prevent cascading failures during outages
```java
// Assumes Resilience4j (io.github.resilience4j:resilience4j-circuitbreaker)
public class KmsCircuitBreaker {
    private final CircuitBreaker circuitBreaker;
    public KmsCircuitBreaker() {
        CircuitBreakerConfig config = CircuitBreakerConfig.custom()
            .failureRateThreshold(50)
            .waitDurationInOpenState(Duration.ofSeconds(30))
            .permittedNumberOfCallsInHalfOpenState(2)
            .slidingWindowSize(10)
            .build();
        this.circuitBreaker = CircuitBreaker.of("kmsService", config);
    }
    public <T> T executeWithCircuitBreaker(Callable<T> operation) throws Exception {
        // Thrown exceptions are recorded as failures; once the failure rate
        // exceeds the threshold the breaker opens and subsequent calls fail
        // fast with CallNotPermittedException until the wait duration elapses
        return circuitBreaker.executeCallable(operation);
    }
}
```
## Testing Best Practices
1. **Test with Mock KMS Client**
- Use mock clients for unit tests
- Verify all expected interactions
```java
@Test
void shouldEncryptWithProperEncryptionContext() {
// Arrange
    when(kmsClient.encrypt(any(EncryptRequest.class)))
        .thenReturn(EncryptResponse.builder()
            .ciphertextBlob(SdkBytes.fromUtf8String("ciphertext"))
            .build());
// Act
String result = encryptionService.encrypt("test", "user123");
// Assert
verify(kmsClient).encrypt(argThat(request ->
request.encryptionContext().containsKey("userId") &&
request.encryptionContext().get("userId").equals("user123")));
}
```
2. **Test Error Scenarios**
- Test various error conditions
- Verify proper error handling and recovery
3. **Performance Testing**
- Test performance under load
- Measure latency and throughput
4. **Integration Testing with Local KMS**
- Test with local KMS when possible
- Verify integration with real AWS services
## Monitoring and Observability
1. **Implement Comprehensive Logging**
- Log all KMS operations with appropriate levels
- Include correlation IDs for tracing
```java
@Aspect
@Component
public class KmsLoggingAspect {
    private static final Logger logger = LoggerFactory.getLogger(KmsLoggingAspect.class);
    @Around("execution(* com.yourcompany.kms..*.*(..))")
public Object logKmsOperation(ProceedingJoinPoint joinPoint) throws Throwable {
String operation = joinPoint.getSignature().getName();
logger.info("Starting KMS operation: {}", operation);
long startTime = System.currentTimeMillis();
try {
Object result = joinPoint.proceed();
long duration = System.currentTimeMillis() - startTime;
logger.info("Completed KMS operation: {} in {}ms", operation, duration);
return result;
} catch (Exception e) {
long duration = System.currentTimeMillis() - startTime;
logger.error("KMS operation {} failed in {}ms: {}", operation, duration, e.getMessage());
throw e;
}
}
}
```
2. **Set Up CloudWatch Alarms**
- Monitor API call rates
- Set up alarms for error rates
- Track key usage patterns
3. **Use Distributed Tracing**
- Implement tracing for KMS operations (see the OpenTelemetry sketch after this list)
- Correlate KMS calls with application operations
4. **Monitor Key Usage Metrics**
- Track key usage patterns
- Monitor for unusual usage patterns
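For the tracing item above, a minimal sketch with the OpenTelemetry API (assumes `io.opentelemetry:opentelemetry-api` on the classpath; tracer and attribute names are illustrative):
```java
import io.opentelemetry.api.GlobalOpenTelemetry;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Scope;

public class TracedKmsOperations {
    private final Tracer tracer = GlobalOpenTelemetry.getTracer("kms-service");

    public byte[] encryptWithTracing(KmsClient kmsClient, String keyId, SdkBytes plaintext) {
        // The span correlates this KMS call with the surrounding request trace
        Span span = tracer.spanBuilder("kms.encrypt").startSpan();
        try (Scope ignored = span.makeCurrent()) {
            span.setAttribute("kms.key_id", keyId);
            EncryptResponse response = kmsClient.encrypt(EncryptRequest.builder()
                .keyId(keyId)
                .plaintext(plaintext)
                .build());
            return response.ciphertextBlob().asByteArray();
        } finally {
            span.end();
        }
    }
}
```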
## Compliance and Auditing
1. **Enable KMS Key Usage Logging**
- Configure CloudTrail to log KMS operations (see the lookup sketch after this list)
- Enable detailed logging for compliance
2. **Regular Security Audits**
- Conduct regular audits of KMS key usage
- Review access policies periodically
3. **Comprehensive Backup Strategy**
- Implement key backup and recovery procedures
- Test backup restoration processes
4. **Comprehensive Access Reviews**
- Regularly review IAM policies for KMS access
- Remove unnecessary permissions
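For the CloudTrail item above, a hedged sketch (assumes the SDK v2 `cloudtrail` module; the last 90 days of management events are queryable without extra setup):
```java
import software.amazon.awssdk.services.cloudtrail.CloudTrailClient;
import software.amazon.awssdk.services.cloudtrail.model.LookupAttribute;
import software.amazon.awssdk.services.cloudtrail.model.LookupAttributeKey;
import software.amazon.awssdk.services.cloudtrail.model.LookupEventsRequest;

public void auditRecentKmsCalls(CloudTrailClient cloudTrail) {
    // Look up recent management events emitted by the KMS service
    cloudTrail.lookupEvents(LookupEventsRequest.builder()
            .lookupAttributes(LookupAttribute.builder()
                .attributeKey(LookupAttributeKey.EVENT_SOURCE)
                .attributeValue("kms.amazonaws.com")
                .build())
            .maxResults(50)
            .build())
        .events()
        .forEach(event -> System.out.println(
            event.eventTime() + " " + event.eventName() + " by " + event.username()));
}
```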
## Advanced Security Considerations
1. **Multi-Region KMS Keys**
- Consider multi-region keys for disaster recovery (see the sketch after this list)
- Test failover scenarios
2. **Cross-Account Access**
- Implement proper cross-account access controls
- Use resource-based policies for account sharing
3. **Custom Key Stores**
- Consider custom key stores for enhanced security
- Implement proper key management in custom stores
4. **External Key Material**
- Use imported key material for enhanced control over key provenance
- Rotate imported keys manually; KMS automatic rotation does not apply to them
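For the multi-region item above, a sketch of creating a primary key and replicating it (assumes the `ReplicateKey` API in SDK v2; the replica Region is illustrative):
```java
public String createMultiRegionKey(KmsClient kmsClient) {
    // Create a multi-Region primary key in the client's Region
    CreateKeyResponse created = kmsClient.createKey(CreateKeyRequest.builder()
        .description("Multi-Region primary key")
        .multiRegion(true)
        .build());
    String primaryArn = created.keyMetadata().arn();
    // Replicate into a second Region so ciphertexts stay decryptable during failover
    kmsClient.replicateKey(ReplicateKeyRequest.builder()
        .keyId(primaryArn)
        .replicaRegion("us-west-2")
        .build());
    return primaryArn;
}
```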
## Development Best Practices
1. **Use Dependency Injection**
- Inject KMS clients rather than creating them directly
- Use proper configuration management
```java
@Configuration
@ConfigurationProperties(prefix = "aws.kms")
public class KmsProperties {
private String region = "us-east-1";
private String encryptionKeyId;
private int maxRetries = 3;
// Getters and setters
}
```
2. **Proper Configuration Management**
- Use environment-specific configurations
- Secure sensitive configuration values
3. **Version Control and Documentation**
- Keep KMS-related code well documented
- Track key usage patterns in version control
4. **Code Reviews**
- Conduct thorough code reviews for KMS-related code
- Focus on security and error handling
## Implementation Checklists
### Key Setup Checklist
- [ ] Create appropriate KMS keys for different purposes
- [ ] Enable automatic key rotation
- [ ] Set up key aliases
- [ ] Configure IAM policies with least privilege
- [ ] Set up CloudTrail logging
### Implementation Checklist
- [ ] Use envelope encryption for large data
- [ ] Implement proper error handling
- [ ] Add comprehensive logging
- [ ] Set up monitoring and alarms
- [ ] Write comprehensive tests
### Security Checklist
- [ ] Never log sensitive data
- [ ] Use encryption context
- [ ] Implement proper access controls
- [ ] Clear sensitive data from memory
- [ ] Regularly audit access patterns
By following these best practices, you can ensure that your AWS KMS implementation is secure, performant, cost-effective, and maintainable.

View File

@@ -0,0 +1,504 @@
# Spring Boot Integration with AWS KMS
## Configuration
### Basic Configuration
```java
@Configuration
public class KmsConfiguration {
@Bean
public KmsClient kmsClient() {
return KmsClient.builder()
.region(Region.US_EAST_1)
.build();
}
@Bean
public KmsAsyncClient kmsAsyncClient() {
return KmsAsyncClient.builder()
.region(Region.US_EAST_1)
.build();
}
}
```
### Configuration with Custom Settings
```java
@Configuration
@ConfigurationProperties(prefix = "aws.kms")
public class KmsAdvancedConfiguration {
private Region region = Region.US_EAST_1;
private String endpoint;
private Duration timeout = Duration.ofSeconds(10);
private String accessKeyId;
private String secretAccessKey;
@Bean
public KmsClient kmsClient() {
KmsClientBuilder builder = KmsClient.builder()
.region(region)
.overrideConfiguration(c -> c.retryPolicy(RetryPolicy.builder()
.numRetries(3)
.build()));
if (endpoint != null) {
builder.endpointOverride(URI.create(endpoint));
}
// Add credentials if provided
if (accessKeyId != null && secretAccessKey != null) {
builder.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create(accessKeyId, secretAccessKey)));
}
return builder.build();
}
// Getters and Setters
public Region getRegion() { return region; }
public void setRegion(Region region) { this.region = region; }
public String getEndpoint() { return endpoint; }
public void setEndpoint(String endpoint) { this.endpoint = endpoint; }
public Duration getTimeout() { return timeout; }
public void setTimeout(Duration timeout) { this.timeout = timeout; }
public String getAccessKeyId() { return accessKeyId; }
public void setAccessKeyId(String accessKeyId) { this.accessKeyId = accessKeyId; }
public String getSecretAccessKey() { return secretAccessKey; }
public void setSecretAccessKey(String secretAccessKey) { this.secretAccessKey = secretAccessKey; }
}
```
### Application Properties
```properties
# AWS KMS Configuration
aws.kms.region=us-east-1
aws.kms.endpoint=
aws.kms.timeout=10s
aws.kms.access-key-id=
aws.kms.secret-access-key=
# KMS Key Configuration
kms.encryption-key-id=alias/your-encryption-key
kms.signing-key-id=alias/your-signing-key
```
## Encryption Service
### Basic Encryption Service
```java
@Service
public class KmsEncryptionService {
private final KmsClient kmsClient;
@Value("${kms.encryption-key-id}")
private String keyId;
public KmsEncryptionService(KmsClient kmsClient) {
this.kmsClient = kmsClient;
}
public String encrypt(String plaintext) {
try {
EncryptRequest request = EncryptRequest.builder()
.keyId(keyId)
.plaintext(SdkBytes.fromString(plaintext, StandardCharsets.UTF_8))
.build();
EncryptResponse response = kmsClient.encrypt(request);
// Return Base64-encoded ciphertext
return Base64.getEncoder()
.encodeToString(response.ciphertextBlob().asByteArray());
} catch (KmsException e) {
throw new RuntimeException("Encryption failed", e);
}
}
public String decrypt(String ciphertextBase64) {
try {
byte[] ciphertext = Base64.getDecoder().decode(ciphertextBase64);
DecryptRequest request = DecryptRequest.builder()
.ciphertextBlob(SdkBytes.fromByteArray(ciphertext))
.build();
DecryptResponse response = kmsClient.decrypt(request);
return response.plaintext().asString(StandardCharsets.UTF_8);
} catch (KmsException e) {
throw new RuntimeException("Decryption failed", e);
}
}
}
```
### Secure Data Repository
```java
@Repository
public class SecureDataRepository {
private final KmsEncryptionService encryptionService;
private final JdbcTemplate jdbcTemplate;
public SecureDataRepository(KmsEncryptionService encryptionService,
JdbcTemplate jdbcTemplate) {
this.encryptionService = encryptionService;
this.jdbcTemplate = jdbcTemplate;
}
public void saveSecureData(String id, String sensitiveData) {
String encryptedData = encryptionService.encrypt(sensitiveData);
jdbcTemplate.update(
"INSERT INTO secure_data (id, encrypted_value) VALUES (?, ?)",
id, encryptedData);
}
public String getSecureData(String id) {
String encryptedData = jdbcTemplate.queryForObject(
"SELECT encrypted_value FROM secure_data WHERE id = ?",
String.class, id);
return encryptionService.decrypt(encryptedData);
}
}
```
### Advanced Envelope Encryption Service
```java
@Service
public class EnvelopeEncryptionService {
private final KmsClient kmsClient;
@Value("${kms.master-key-id}")
private String masterKeyId;
private final Cache<String, DataKeyPair> keyCache =
Caffeine.newBuilder()
.expireAfterWrite(1, TimeUnit.HOURS)
.maximumSize(100)
.build();
public EnvelopeEncryptionService(KmsClient kmsClient) {
this.kmsClient = kmsClient;
}
public EncryptedEnvelope encryptLargeData(byte[] data) {
// Check cache for existing key
DataKeyPair dataKeyPair = keyCache.getIfPresent(masterKeyId);
if (dataKeyPair == null) {
// Generate new data key
GenerateDataKeyResponse dataKeyResponse = kmsClient.generateDataKey(
GenerateDataKeyRequest.builder()
.keyId(masterKeyId)
.keySpec(DataKeySpec.AES_256)
.build());
dataKeyPair = new DataKeyPair(
dataKeyResponse.plaintext().asByteArray(),
dataKeyResponse.ciphertextBlob().asByteArray());
            // Cache the key pair; Caffeine expires entries after one hour
            keyCache.put(masterKeyId, dataKeyPair);
        }
        try {
            // Encrypt with a disposable copy of the plaintext key so that
            // zeroing it does not corrupt the cached entry for later calls
            byte[] workingKey = dataKeyPair.plaintext().clone();
            byte[] encryptedData = encryptWithAES(data, workingKey);
            Arrays.fill(workingKey, (byte) 0);
            return new EncryptedEnvelope(encryptedData, dataKeyPair.encrypted());
} catch (Exception e) {
throw new RuntimeException("Envelope encryption failed", e);
}
}
public byte[] decryptLargeData(EncryptedEnvelope envelope) {
// Get data key from cache or decrypt from KMS
DataKeyPair dataKeyPair = keyCache.getIfPresent(masterKeyId);
if (dataKeyPair == null || !Arrays.equals(dataKeyPair.encrypted(), envelope.encryptedKey())) {
// Decrypt data key from KMS
DecryptResponse decryptResponse = kmsClient.decrypt(
DecryptRequest.builder()
.ciphertextBlob(SdkBytes.fromByteArray(envelope.encryptedKey()))
.build());
dataKeyPair = new DataKeyPair(
decryptResponse.plaintext().asByteArray(),
envelope.encryptedKey());
// Cache for future use
keyCache.put(masterKeyId, dataKeyPair);
}
        try {
            // Again decrypt with a disposable copy so the cached key survives zeroing
            byte[] workingKey = dataKeyPair.plaintext().clone();
            byte[] decryptedData = decryptWithAES(envelope.encryptedData(), workingKey);
            Arrays.fill(workingKey, (byte) 0);
return decryptedData;
} catch (Exception e) {
throw new RuntimeException("Envelope decryption failed", e);
}
}
    private static final int GCM_IV_LENGTH = 12;
    private static final int GCM_TAG_BITS = 128;
    private byte[] encryptWithAES(byte[] data, byte[] key) throws Exception {
        SecretKeySpec keySpec = new SecretKeySpec(key, "AES");
        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        // Use a fresh random IV per message and prepend it to the ciphertext
        byte[] iv = new byte[GCM_IV_LENGTH];
        new SecureRandom().nextBytes(iv);
        cipher.init(Cipher.ENCRYPT_MODE, keySpec, new GCMParameterSpec(GCM_TAG_BITS, iv));
        byte[] ciphertext = cipher.doFinal(data);
        byte[] result = new byte[GCM_IV_LENGTH + ciphertext.length];
        System.arraycopy(iv, 0, result, 0, GCM_IV_LENGTH);
        System.arraycopy(ciphertext, 0, result, GCM_IV_LENGTH, ciphertext.length);
        return result;
    }
    private byte[] decryptWithAES(byte[] data, byte[] key) throws Exception {
        SecretKeySpec keySpec = new SecretKeySpec(key, "AES");
        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        // Recover the IV that encryptWithAES prepended
        cipher.init(Cipher.DECRYPT_MODE, keySpec,
            new GCMParameterSpec(GCM_TAG_BITS, data, 0, GCM_IV_LENGTH));
        return cipher.doFinal(data, GCM_IV_LENGTH, data.length - GCM_IV_LENGTH);
    }
public record DataKeyPair(byte[] plaintext, byte[] encrypted) {}
public record EncryptedEnvelope(byte[] encryptedData, byte[] encryptedKey) {}
}
```
## Data Encryption Interceptor
### SQL Encryption Interceptor
```java
public class KmsDataEncryptInterceptor implements StatementInterceptor {
private final KmsEncryptionService encryptionService;
public KmsDataEncryptInterceptor(KmsEncryptionService encryptionService) {
this.encryptionService = encryptionService;
}
@Override
public ResultSet intercept(ResultSet rs, Statement statement, Connection connection) throws SQLException {
return new EncryptingResultSetWrapper(rs, encryptionService);
}
@Override
public void interceptAfterExecution(Statement statement) {
// No-op
}
}
class EncryptingResultSetWrapper implements ResultSet {
private final ResultSet delegate;
private final KmsEncryptionService encryptionService;
public EncryptingResultSetWrapper(ResultSet delegate, KmsEncryptionService encryptionService) {
this.delegate = delegate;
this.encryptionService = encryptionService;
}
@Override
public String getString(String columnLabel) throws SQLException {
String value = delegate.getString(columnLabel);
if (value == null) return null;
// Check if this is an encrypted column
if (isEncryptedColumn(columnLabel)) {
return encryptionService.decrypt(value);
}
return value;
}
private boolean isEncryptedColumn(String columnLabel) {
// Implement logic to identify encrypted columns
return columnLabel.contains("encrypted") || columnLabel.contains("secure");
}
// Delegate other methods to original ResultSet
@Override
public boolean next() throws SQLException {
return delegate.next();
}
// ... other ResultSet method implementations
}
```
## Configuration Profiles
### Development Profile
```properties
# src/main/resources/application-dev.properties
aws.kms.region=us-east-1
kms.encryption-key-id=alias/dev-encryption-key
logging.level.com.yourcompany=DEBUG
```
### Production Profile
```properties
# src/main/resources/application-prod.properties
aws.kms.region=${AWS_REGION:us-east-1}
kms.encryption-key-id=${KMS_ENCRYPTION_KEY_ID:alias/production-encryption-key}
logging.level.com.yourcompany=WARN
spring.cloud.aws.credentials.access-key=${AWS_ACCESS_KEY_ID}
spring.cloud.aws.credentials.secret-key=${AWS_SECRET_ACCESS_KEY}
```
### Test Configuration
```java
@Configuration
@Profile("test")
public class KmsTestConfiguration {
@Bean
@Primary
public KmsClient testKmsClient() {
// Return a mock or test-specific KMS client
return mock(KmsClient.class);
}
@Bean
public KmsEncryptionService testKmsEncryptionService() {
return new KmsEncryptionService(testKmsClient());
}
}
```
## Health Checks and Monitoring
### KMS Health Indicator
```java
@Component
public class KmsHealthIndicator implements HealthIndicator {
private final KmsClient kmsClient;
private final String keyId;
public KmsHealthIndicator(KmsClient kmsClient,
@Value("${kms.encryption-key-id}") String keyId) {
this.kmsClient = kmsClient;
this.keyId = keyId;
}
@Override
public Health health() {
try {
// Test KMS connectivity by describing the key
DescribeKeyRequest request = DescribeKeyRequest.builder()
.keyId(keyId)
.build();
DescribeKeyResponse response = kmsClient.describeKey(request);
// Check if key is in a healthy state
KeyState keyState = response.keyMetadata().keyState();
boolean isHealthy = keyState == KeyState.ENABLED;
if (isHealthy) {
return Health.up()
.withDetail("keyId", keyId)
.withDetail("keyState", keyState)
.withDetail("keyArn", response.keyMetadata().arn())
.build();
} else {
return Health.down()
.withDetail("keyId", keyId)
.withDetail("keyState", keyState)
.withDetail("message", "KMS key is not in ENABLED state")
.build();
}
} catch (KmsException e) {
return Health.down()
.withDetail("keyId", keyId)
.withDetail("error", e.awsErrorDetails().errorMessage())
.withDetail("errorCode", e.awsErrorDetails().errorCode())
.build();
}
}
}
```
### Metrics Collection
```java
@Service
public class KmsMetricsCollector {
private final MeterRegistry meterRegistry;
private final KmsClient kmsClient;
private final Counter encryptionCounter;
private final Counter decryptionCounter;
private final Timer encryptionTimer;
private final Timer decryptionTimer;
public KmsMetricsCollector(MeterRegistry meterRegistry, KmsClient kmsClient) {
this.meterRegistry = meterRegistry;
this.kmsClient = kmsClient;
this.encryptionCounter = Counter.builder("kms.encryption.count")
.description("Number of encryption operations")
.register(meterRegistry);
this.decryptionCounter = Counter.builder("kms.decryption.count")
.description("Number of decryption operations")
.register(meterRegistry);
this.encryptionTimer = Timer.builder("kms.encryption.time")
.description("Time taken for encryption operations")
.register(meterRegistry);
this.decryptionTimer = Timer.builder("kms.decryption.time")
.description("Time taken for decryption operations")
.register(meterRegistry);
}
public String encryptWithMetrics(String plaintext) {
encryptionCounter.increment();
return encryptionTimer.record(() -> {
try {
EncryptRequest request = EncryptRequest.builder()
.keyId("your-key-id")
.plaintext(SdkBytes.fromString(plaintext, StandardCharsets.UTF_8))
.build();
EncryptResponse response = kmsClient.encrypt(request);
return Base64.getEncoder().encodeToString(
response.ciphertextBlob().asByteArray());
} catch (KmsException e) {
meterRegistry.counter("kms.encryption.errors")
.increment();
throw e;
}
});
}
}
```

View File

@@ -0,0 +1,639 @@
# AWS KMS Technical Guide
## Key Management Operations
### Create KMS Key
```java
import software.amazon.awssdk.services.kms.model.*;
import java.util.stream.Collectors;
public String createKey(KmsClient kmsClient, String description) {
try {
CreateKeyRequest request = CreateKeyRequest.builder()
.description(description)
.keyUsage(KeyUsageType.ENCRYPT_DECRYPT)
.origin(OriginType.AWS_KMS)
.build();
CreateKeyResponse response = kmsClient.createKey(request);
String keyId = response.keyMetadata().keyId();
System.out.println("Created key: " + keyId);
return keyId;
} catch (KmsException e) {
System.err.println("Error creating key: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
### Create Key with Custom Key Store
```java
public String createKeyWithCustomStore(KmsClient kmsClient,
String description,
String customKeyStoreId) {
CreateKeyRequest request = CreateKeyRequest.builder()
.description(description)
.keyUsage(KeyUsageType.ENCRYPT_DECRYPT)
.origin(OriginType.AWS_CLOUDHSM)
.customKeyStoreId(customKeyStoreId)
.build();
CreateKeyResponse response = kmsClient.createKey(request);
return response.keyMetadata().keyId();
}
```
### List Keys
```java
import java.util.List;
public List<KeyListEntry> listKeys(KmsClient kmsClient) {
try {
ListKeysRequest request = ListKeysRequest.builder()
.limit(100)
.build();
ListKeysResponse response = kmsClient.listKeys(request);
response.keys().forEach(key -> {
System.out.println("Key ARN: " + key.keyArn());
System.out.println("Key ID: " + key.keyId());
System.out.println();
});
return response.keys();
} catch (KmsException e) {
System.err.println("Error listing keys: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
### List Keys with Pagination (Async)
```java
import software.amazon.awssdk.services.kms.paginators.ListKeysPublisher;
import java.util.concurrent.CompletableFuture;
public CompletableFuture<Void> listAllKeysAsync(KmsAsyncClient kmsAsyncClient) {
ListKeysRequest request = ListKeysRequest.builder()
.limit(15)
.build();
ListKeysPublisher keysPublisher = kmsAsyncClient.listKeysPaginator(request);
return keysPublisher
.subscribe(r -> r.keys().forEach(key ->
System.out.println("Key ARN: " + key.keyArn())))
.whenComplete((result, exception) -> {
if (exception != null) {
System.err.println("Error: " + exception.getMessage());
} else {
System.out.println("Successfully listed all keys");
}
});
}
```
### Describe Key
```java
public KeyMetadata describeKey(KmsClient kmsClient, String keyId) {
try {
DescribeKeyRequest request = DescribeKeyRequest.builder()
.keyId(keyId)
.build();
DescribeKeyResponse response = kmsClient.describeKey(request);
KeyMetadata metadata = response.keyMetadata();
System.out.println("Key ID: " + metadata.keyId());
System.out.println("Key ARN: " + metadata.arn());
System.out.println("Key State: " + metadata.keyState());
System.out.println("Creation Date: " + metadata.creationDate());
System.out.println("Enabled: " + metadata.enabled());
return metadata;
} catch (KmsException e) {
System.err.println("Error describing key: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
### Enable/Disable Key
```java
public void enableKey(KmsClient kmsClient, String keyId) {
try {
EnableKeyRequest request = EnableKeyRequest.builder()
.keyId(keyId)
.build();
kmsClient.enableKey(request);
System.out.println("Key enabled: " + keyId);
} catch (KmsException e) {
System.err.println("Error enabling key: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
public void disableKey(KmsClient kmsClient, String keyId) {
try {
DisableKeyRequest request = DisableKeyRequest.builder()
.keyId(keyId)
.build();
kmsClient.disableKey(request);
System.out.println("Key disabled: " + keyId);
} catch (KmsException e) {
System.err.println("Error disabling key: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
## Encryption and Decryption
### Encrypt Data
```java
import software.amazon.awssdk.core.SdkBytes;
import java.nio.charset.StandardCharsets;
public byte[] encryptData(KmsClient kmsClient, String keyId, String plaintext) {
try {
SdkBytes plaintextBytes = SdkBytes.fromString(plaintext, StandardCharsets.UTF_8);
EncryptRequest request = EncryptRequest.builder()
.keyId(keyId)
.plaintext(plaintextBytes)
.build();
EncryptResponse response = kmsClient.encrypt(request);
byte[] encryptedData = response.ciphertextBlob().asByteArray();
System.out.println("Data encrypted successfully");
return encryptedData;
} catch (KmsException e) {
System.err.println("Error encrypting data: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
### Decrypt Data
```java
public String decryptData(KmsClient kmsClient, byte[] ciphertext) {
try {
SdkBytes ciphertextBytes = SdkBytes.fromByteArray(ciphertext);
DecryptRequest request = DecryptRequest.builder()
.ciphertextBlob(ciphertextBytes)
.build();
DecryptResponse response = kmsClient.decrypt(request);
String decryptedText = response.plaintext().asString(StandardCharsets.UTF_8);
System.out.println("Data decrypted successfully");
return decryptedText;
} catch (KmsException e) {
System.err.println("Error decrypting data: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
### Encrypt with Encryption Context
```java
import java.util.Map;
public byte[] encryptWithContext(KmsClient kmsClient,
String keyId,
String plaintext,
Map<String, String> encryptionContext) {
try {
EncryptRequest request = EncryptRequest.builder()
.keyId(keyId)
.plaintext(SdkBytes.fromString(plaintext, StandardCharsets.UTF_8))
.encryptionContext(encryptionContext)
.build();
EncryptResponse response = kmsClient.encrypt(request);
return response.ciphertextBlob().asByteArray();
} catch (KmsException e) {
System.err.println("Error encrypting with context: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
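### Decrypt with Encryption Context
For a symmetric key, decryption must supply exactly the context used at encryption; a mismatch fails with `InvalidCiphertextException`:
```java
public String decryptWithContext(KmsClient kmsClient,
                                 byte[] ciphertext,
                                 Map<String, String> encryptionContext) {
    try {
        DecryptRequest request = DecryptRequest.builder()
            .ciphertextBlob(SdkBytes.fromByteArray(ciphertext))
            .encryptionContext(encryptionContext)
            .build();
        DecryptResponse response = kmsClient.decrypt(request);
        return response.plaintext().asString(StandardCharsets.UTF_8);
    } catch (KmsException e) {
        // A context mismatch surfaces here as InvalidCiphertextException
        System.err.println("Error decrypting with context: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```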
## Data Key Generation (Envelope Encryption)
### Generate Data Key
```java
public record DataKeyPair(byte[] plaintext, byte[] encrypted) {}
public DataKeyPair generateDataKey(KmsClient kmsClient, String keyId) {
try {
GenerateDataKeyRequest request = GenerateDataKeyRequest.builder()
.keyId(keyId)
.keySpec(DataKeySpec.AES_256)
.build();
GenerateDataKeyResponse response = kmsClient.generateDataKey(request);
byte[] plaintextKey = response.plaintext().asByteArray();
byte[] encryptedKey = response.ciphertextBlob().asByteArray();
System.out.println("Data key generated");
return new DataKeyPair(plaintextKey, encryptedKey);
} catch (KmsException e) {
System.err.println("Error generating data key: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
### Generate Data Key Without Plaintext
```java
public byte[] generateDataKeyWithoutPlaintext(KmsClient kmsClient, String keyId) {
try {
GenerateDataKeyWithoutPlaintextRequest request =
GenerateDataKeyWithoutPlaintextRequest.builder()
.keyId(keyId)
.keySpec(DataKeySpec.AES_256)
.build();
GenerateDataKeyWithoutPlaintextResponse response =
kmsClient.generateDataKeyWithoutPlaintext(request);
return response.ciphertextBlob().asByteArray();
} catch (KmsException e) {
System.err.println("Error generating data key: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
## Digital Signing
### Create Signing Key
```java
public String createSigningKey(KmsClient kmsClient, String description) {
try {
CreateKeyRequest request = CreateKeyRequest.builder()
.description(description)
.keySpec(KeySpec.RSA_2048)
.keyUsage(KeyUsageType.SIGN_VERIFY)
.origin(OriginType.AWS_KMS)
.build();
CreateKeyResponse response = kmsClient.createKey(request);
return response.keyMetadata().keyId();
} catch (KmsException e) {
System.err.println("Error creating signing key: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
### Sign Data
```java
public byte[] signData(KmsClient kmsClient, String keyId, String message) {
try {
SdkBytes messageBytes = SdkBytes.fromString(message, StandardCharsets.UTF_8);
SignRequest request = SignRequest.builder()
.keyId(keyId)
.message(messageBytes)
.signingAlgorithm(SigningAlgorithmSpec.RSASSA_PSS_SHA_256)
.build();
SignResponse response = kmsClient.sign(request);
byte[] signature = response.signature().asByteArray();
System.out.println("Data signed successfully");
return signature;
} catch (KmsException e) {
System.err.println("Error signing data: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
### Verify Signature
```java
public boolean verifySignature(KmsClient kmsClient,
String keyId,
String message,
byte[] signature) {
try {
VerifyRequest request = VerifyRequest.builder()
.keyId(keyId)
.message(SdkBytes.fromString(message, StandardCharsets.UTF_8))
.signature(SdkBytes.fromByteArray(signature))
.signingAlgorithm(SigningAlgorithmSpec.RSASSA_PSS_SHA_256)
.build();
VerifyResponse response = kmsClient.verify(request);
boolean isValid = response.signatureValid();
System.out.println("Signature valid: " + isValid);
return isValid;
} catch (KmsException e) {
System.err.println("Error verifying signature: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
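### Verify Signature Locally
For high-volume verification you can fetch the public half of the key once and verify offline instead of calling KMS per message. A hedged sketch (assumes Java 11+ `RSASSA-PSS` provider support and a signature produced with `RSASSA_PSS_SHA_256`; salt length 32 matches SHA-256):
```java
import java.security.KeyFactory;
import java.security.PublicKey;
import java.security.Signature;
import java.security.spec.MGF1ParameterSpec;
import java.security.spec.PSSParameterSpec;
import java.security.spec.X509EncodedKeySpec;

public boolean verifyLocally(KmsClient kmsClient, String keyId,
                             byte[] message, byte[] signature) throws Exception {
    // Fetch the DER-encoded public key; cache it in real code
    GetPublicKeyResponse keyResponse = kmsClient.getPublicKey(
        GetPublicKeyRequest.builder().keyId(keyId).build());
    PublicKey publicKey = KeyFactory.getInstance("RSA")
        .generatePublic(new X509EncodedKeySpec(keyResponse.publicKey().asByteArray()));
    Signature verifier = Signature.getInstance("RSASSA-PSS");
    verifier.setParameter(new PSSParameterSpec("SHA-256", "MGF1",
        MGF1ParameterSpec.SHA256, 32, 1));
    verifier.initVerify(publicKey);
    verifier.update(message);
    return verifier.verify(signature);
}
```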
### Sign and Verify (Async)
```java
public CompletableFuture<Boolean> signAndVerifyAsync(KmsAsyncClient kmsAsyncClient,
String message) {
    // Create a signing key, then sign and verify the message
CreateKeyRequest createKeyRequest = CreateKeyRequest.builder()
.keySpec(KeySpec.RSA_2048)
.keyUsage(KeyUsageType.SIGN_VERIFY)
.origin(OriginType.AWS_KMS)
.build();
return kmsAsyncClient.createKey(createKeyRequest)
.thenCompose(createKeyResponse -> {
String keyId = createKeyResponse.keyMetadata().keyId();
            SdkBytes messageBytes = SdkBytes.fromString(message, StandardCharsets.UTF_8);
SignRequest signRequest = SignRequest.builder()
.keyId(keyId)
.message(messageBytes)
.signingAlgorithm(SigningAlgorithmSpec.RSASSA_PSS_SHA_256)
.build();
return kmsAsyncClient.sign(signRequest)
.thenCompose(signResponse -> {
byte[] signedBytes = signResponse.signature().asByteArray();
VerifyRequest verifyRequest = VerifyRequest.builder()
.keyId(keyId)
.message(messageBytes)
.signature(SdkBytes.fromByteArray(signedBytes))
.signingAlgorithm(SigningAlgorithmSpec.RSASSA_PSS_SHA_256)
.build();
return kmsAsyncClient.verify(verifyRequest)
.thenApply(VerifyResponse::signatureValid);
});
})
.exceptionally(throwable -> {
throw new RuntimeException("Failed to sign or verify", throwable);
});
}
```
## Key Tagging
### Tag Key
```java
public void tagKey(KmsClient kmsClient, String keyId, Map<String, String> tags) {
try {
List<Tag> tagList = tags.entrySet().stream()
.map(entry -> Tag.builder()
.tagKey(entry.getKey())
.tagValue(entry.getValue())
.build())
.collect(Collectors.toList());
TagResourceRequest request = TagResourceRequest.builder()
.keyId(keyId)
.tags(tagList)
.build();
kmsClient.tagResource(request);
System.out.println("Key tagged successfully");
} catch (KmsException e) {
System.err.println("Error tagging key: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
### List Tags
```java
public Map<String, String> listTags(KmsClient kmsClient, String keyId) {
try {
ListResourceTagsRequest request = ListResourceTagsRequest.builder()
.keyId(keyId)
.build();
ListResourceTagsResponse response = kmsClient.listResourceTags(request);
return response.tags().stream()
.collect(Collectors.toMap(Tag::tagKey, Tag::tagValue));
} catch (KmsException e) {
System.err.println("Error listing tags: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
## Advanced Techniques
### Envelope Encryption Service
```java
@Service
public class EnvelopeEncryptionService {
private final KmsClient kmsClient;
@Value("${kms.master-key-id}")
private String masterKeyId;
public EnvelopeEncryptionService(KmsClient kmsClient) {
this.kmsClient = kmsClient;
}
public EncryptedEnvelope encryptLargeData(byte[] data) {
// Generate data key
GenerateDataKeyResponse dataKeyResponse = kmsClient.generateDataKey(
GenerateDataKeyRequest.builder()
.keyId(masterKeyId)
.keySpec(DataKeySpec.AES_256)
.build());
byte[] plaintextKey = dataKeyResponse.plaintext().asByteArray();
byte[] encryptedKey = dataKeyResponse.ciphertextBlob().asByteArray();
try {
// Encrypt data with plaintext data key
byte[] encryptedData = encryptWithAES(data, plaintextKey);
// Clear plaintext key from memory
Arrays.fill(plaintextKey, (byte) 0);
return new EncryptedEnvelope(encryptedData, encryptedKey);
} catch (Exception e) {
throw new RuntimeException("Envelope encryption failed", e);
}
}
public byte[] decryptLargeData(EncryptedEnvelope envelope) {
// Decrypt data key
DecryptResponse decryptResponse = kmsClient.decrypt(
DecryptRequest.builder()
.ciphertextBlob(SdkBytes.fromByteArray(envelope.encryptedKey()))
.build());
byte[] plaintextKey = decryptResponse.plaintext().asByteArray();
try {
// Decrypt data with plaintext data key
byte[] decryptedData = decryptWithAES(envelope.encryptedData(), plaintextKey);
// Clear plaintext key from memory
Arrays.fill(plaintextKey, (byte) 0);
return decryptedData;
} catch (Exception e) {
throw new RuntimeException("Envelope decryption failed", e);
}
}
    private static final int GCM_IV_LENGTH = 12;
    private static final int GCM_TAG_BITS = 128;
    private byte[] encryptWithAES(byte[] data, byte[] key) throws Exception {
        SecretKeySpec keySpec = new SecretKeySpec(key, "AES");
        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        // GCM requires an explicit IV; generate one per message and prepend it
        byte[] iv = new byte[GCM_IV_LENGTH];
        new SecureRandom().nextBytes(iv);
        cipher.init(Cipher.ENCRYPT_MODE, keySpec, new GCMParameterSpec(GCM_TAG_BITS, iv));
        byte[] ciphertext = cipher.doFinal(data);
        byte[] result = new byte[GCM_IV_LENGTH + ciphertext.length];
        System.arraycopy(iv, 0, result, 0, GCM_IV_LENGTH);
        System.arraycopy(ciphertext, 0, result, GCM_IV_LENGTH, ciphertext.length);
        return result;
    }
    private byte[] decryptWithAES(byte[] data, byte[] key) throws Exception {
        SecretKeySpec keySpec = new SecretKeySpec(key, "AES");
        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        // Recover the IV that encryptWithAES prepended
        cipher.init(Cipher.DECRYPT_MODE, keySpec,
            new GCMParameterSpec(GCM_TAG_BITS, data, 0, GCM_IV_LENGTH));
        return cipher.doFinal(data, GCM_IV_LENGTH, data.length - GCM_IV_LENGTH);
    }
public record EncryptedEnvelope(byte[] encryptedData, byte[] encryptedKey) {}
}
```
### Error Handling Strategies
```java
public class KmsErrorHandler {
private static final int MAX_RETRIES = 3;
private static final long RETRY_DELAY_MS = 1000;
public <T> T executeWithRetry(Supplier<T> operation, String operationName) {
int attempt = 0;
KmsException lastException = null;
while (attempt < MAX_RETRIES) {
try {
return operation.get();
} catch (KmsException e) {
lastException = e;
attempt++;
// Check if it's a throttling error and retryable
if (e.awsErrorDetails().errorCode().equals("ThrottlingException") && attempt < MAX_RETRIES) {
try {
Thread.sleep(RETRY_DELAY_MS);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new RuntimeException("Retry interrupted", ie);
}
} else {
// Non-retryable error or max retries exceeded
throw e;
}
}
}
throw new RuntimeException(String.format("Failed to execute %s after %d attempts", operationName, MAX_RETRIES), lastException);
}
public boolean isRetryableError(KmsException e) {
String errorCode = e.awsErrorDetails().errorCode();
return "ThrottlingException".equals(errorCode)
|| "TooManyRequestsException".equals(errorCode)
|| "LimitExceededException".equals(errorCode);
}
}
```
### Connection Pooling Configuration
```java
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
public class KmsConnectionPool {
public static KmsClient createPooledClient() {
// Configure connection pool
PoolingHttpClientConnectionManager connectionManager =
new PoolingHttpClientConnectionManager();
connectionManager.setMaxTotal(100);
connectionManager.setDefaultMaxPerRoute(20);
CloseableHttpClient httpClient = HttpClients.custom()
.setConnectionManager(connectionManager)
.build();
ApacheHttpClient.Builder httpClientBuilder = ApacheHttpClient.builder()
.httpClient(httpClient);
return KmsClient.builder()
.region(Region.US_EAST_1)
.httpClientBuilder(httpClientBuilder)
.build();
}
}
```

View File

@@ -0,0 +1,589 @@
# Testing AWS KMS Integration
## Unit Testing with Mocked Client
### Basic Unit Test
```java
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kms.KmsClient;
import software.amazon.awssdk.services.kms.model.*;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.nio.charset.StandardCharsets;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class KmsEncryptionServiceTest {
@Mock
private KmsClient kmsClient;
@InjectMocks
private KmsEncryptionService encryptionService;
@Test
void shouldEncryptData() {
// Arrange
String plaintext = "sensitive data";
byte[] ciphertext = "encrypted".getBytes();
when(kmsClient.encrypt(any(EncryptRequest.class)))
.thenReturn(EncryptResponse.builder()
.ciphertextBlob(SdkBytes.fromByteArray(ciphertext))
.build());
// Act
String result = encryptionService.encrypt(plaintext);
// Assert
assertThat(result).isNotEmpty();
verify(kmsClient).encrypt(any(EncryptRequest.class));
}
@Test
void shouldDecryptData() {
// Arrange
String encryptedText = "ciphertext";
String expectedPlaintext = "sensitive data";
when(kmsClient.decrypt(any(DecryptRequest.class)))
.thenReturn(DecryptResponse.builder()
.plaintext(SdkBytes.fromString(expectedPlaintext, StandardCharsets.UTF_8))
.build());
// Act
String result = encryptionService.decrypt(encryptedText);
// Assert
assertThat(result).isEqualTo(expectedPlaintext);
verify(kmsClient).decrypt(any(DecryptRequest.class));
}
@Test
void shouldThrowExceptionOnEncryptionFailure() {
// Arrange
when(kmsClient.encrypt(any(EncryptRequest.class)))
.thenThrow(KmsException.builder()
.awsErrorDetails(AwsErrorDetails.builder()
.errorCode("KMSDisabledException")
.errorMessage("KMS is disabled")
.build())
.build());
// Act & Assert
assertThatThrownBy(() -> encryptionService.encrypt("test"))
.isInstanceOf(RuntimeException.class)
.hasMessageContaining("Encryption failed");
}
}
```
### Parameterized Tests
```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
@ExtendWith(MockitoExtension.class)
class KmsEncryptionParameterizedTest {
@Mock
private KmsClient kmsClient;
@InjectMocks
private KmsEncryptionService encryptionService;
@ParameterizedTest
@CsvSource({
"hello, world",
"12345, 67890",
"special@chars, normal",
"very long string with multiple words, another string",
"", // empty string
"null test, null test"
})
void shouldEncryptAndDecrypt(String plaintext, String testIdentifier) {
// Arrange
byte[] ciphertext = "encrypted".getBytes();
when(kmsClient.encrypt(any(EncryptRequest.class)))
.thenReturn(EncryptResponse.builder()
.ciphertextBlob(SdkBytes.fromByteArray(ciphertext))
.build());
when(kmsClient.decrypt(any(DecryptRequest.class)))
.thenReturn(DecryptResponse.builder()
.plaintext(SdkBytes.fromString(plaintext, StandardCharsets.UTF_8))
.build());
// Act
String encrypted = encryptionService.encrypt(plaintext);
String decrypted = encryptionService.decrypt(encrypted);
// Assert
assertThat(decrypted).isEqualTo(plaintext);
}
}
```
## Integration Testing with Testcontainers
### Local KMS Mock Setup
```java
import org.testcontainers.containers.localstack.LocalStackContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.services.kms.KmsClient;
import software.amazon.awssdk.regions.Region;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.testcontainers.utility.DockerImageName;
import static org.assertj.core.api.Assertions.assertThat;
import static org.testcontainers.containers.localstack.LocalStackContainer.Service.KMS;
@Testcontainers
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class KmsIntegrationTest {
@Container
private static final LocalStackContainer localStack =
new LocalStackContainer(DockerImageName.parse("localstack/localstack:latest"))
.withServices(KMS);
private KmsClient kmsClient;
@BeforeAll
void setup() {
kmsClient = KmsClient.builder()
.region(Region.of(localStack.getRegion()))
.endpointOverride(localStack.getEndpointOverride(KMS))
.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create(localStack.getAccessKey(), localStack.getSecretKey())))
.build();
}
@Test
void shouldCreateAndManageKeysWithLocalKms() {
// Create a key
String keyId = createTestKey(kmsClient, "test-key");
assertThat(keyId).isNotEmpty();
// Describe the key
KeyMetadata metadata = describeKey(kmsClient, keyId);
assertThat(metadata.keyState()).isEqualTo(KeyState.ENABLED);
// List keys
List<KeyListEntry> keys = listKeys(kmsClient);
assertThat(keys).hasSizeGreaterThan(0);
}
}
```
## Testing with Spring Boot Test Slices
### KmsServiceSlice Test
```java
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.http.MediaType;
import org.junit.jupiter.api.Test;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
@SpringBootTest
@AutoConfigureMockMvc
@ActiveProfiles("test")
class KmsControllerIntegrationTest {
@Autowired
private MockMvc mockMvc;
@MockBean
private KmsEncryptionService kmsEncryptionService;
@Test
void shouldEncryptData() throws Exception {
String plaintext = "test data";
String encrypted = "encrypted-data";
when(kmsEncryptionService.encrypt(plaintext)).thenReturn(encrypted);
mockMvc.perform(post("/api/kms/encrypt")
.contentType(MediaType.APPLICATION_JSON)
.content("{\"data\":\"" + plaintext + "\"}"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.data").value(encrypted));
verify(kmsEncryptionService).encrypt(plaintext);
}
@Test
void shouldHandleEncryptionErrors() throws Exception {
when(kmsEncryptionService.encrypt(any()))
.thenThrow(new RuntimeException("KMS error"));
mockMvc.perform(post("/api/kms/encrypt")
.contentType(MediaType.APPLICATION_JSON)
.content("{\"data\":\"test\"}"))
.andExpect(status().isInternalServerError());
}
}
```
### Testing with SpringBootTest and Configuration
```java
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Primary;
@TestConfiguration
class KmsTestConfiguration {
@Bean
@Primary
public KmsClient testKmsClient() {
// Create a mock KMS client for testing
KmsClient mockClient = mock(KmsClient.class);
// Mock key creation
when(mockClient.createKey(any(CreateKeyRequest.class)))
.thenReturn(CreateKeyResponse.builder()
.keyMetadata(KeyMetadata.builder()
.keyId("test-key-id")
.keyArn("arn:aws:kms:us-east-1:123456789012:key/test-key-id")
.keyState(KeyState.ENABLED)
.build())
.build());
// Mock encryption
when(mockClient.encrypt(any(EncryptRequest.class)))
.thenReturn(EncryptResponse.builder()
.ciphertextBlob(SdkBytes.fromString("encrypted-data", StandardCharsets.UTF_8))
.build());
// Mock decryption
when(mockClient.decrypt(any(DecryptRequest.class)))
.thenReturn(DecryptResponse.builder()
.plaintext(SdkBytes.fromString("decrypted-data", StandardCharsets.UTF_8))
.build());
return mockClient;
}
}
@SpringBootTest(classes = {Application.class, KmsTestConfiguration.class})
class KmsServiceWithTestConfigIntegrationTest {
@Autowired
private KmsEncryptionService encryptionService;
@Test
void shouldUseTestConfiguration() {
String result = encryptionService.encrypt("test");
assertThat(result).isNotEmpty();
}
}
```
## Testing Envelope Encryption
### Envelope Encryption Test
```java
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class EnvelopeEncryptionServiceTest {
@Mock
private KmsClient kmsClient;
@InjectMocks
private EnvelopeEncryptionService envelopeEncryptionService;
@Test
void shouldEncryptAndDecryptLargeData() {
        // Arrange
        byte[] testData = "large test data".getBytes();
        byte[] encryptedDataKey = "encrypted-data-key".getBytes();
        byte[] dataKey = new byte[32]; // AES-256 requires a 32-byte key
        // Mock data key generation; hand out copies because the service
        // zeroes plaintext key material after use
        when(kmsClient.generateDataKey(any(GenerateDataKeyRequest.class)))
            .thenReturn(GenerateDataKeyResponse.builder()
                .plaintext(SdkBytes.fromByteArray(dataKey.clone()))
                .ciphertextBlob(SdkBytes.fromByteArray(encryptedDataKey))
                .build());
        // Mock data key decryption with the same key material
        when(kmsClient.decrypt(any(DecryptRequest.class)))
            .thenReturn(DecryptResponse.builder()
                .plaintext(SdkBytes.fromByteArray(dataKey.clone()))
                .build());
// Act
EncryptedEnvelope encryptedEnvelope = envelopeEncryptionService.encryptLargeData(testData);
byte[] decryptedData = envelopeEncryptionService.decryptLargeData(encryptedEnvelope);
// Assert
assertThat(encryptedEnvelope.encryptedData()).isNotEmpty();
assertThat(encryptedEnvelope.encryptedKey()).isEqualTo(encryptedDataKey);
assertThat(decryptedData).isEqualTo(testData);
// Verify interactions
verify(kmsClient).generateDataKey(any(GenerateDataKeyRequest.class));
verify(kmsClient).decrypt(any(DecryptRequest.class));
}
@Test
void shouldClearSensitiveDataFromMemory() {
        // Arrange
        byte[] testData = "test data".getBytes();
        byte[] encryptedDataKey = "encrypted-key".getBytes();
        when(kmsClient.generateDataKey(any(GenerateDataKeyRequest.class)))
            .thenReturn(GenerateDataKeyResponse.builder()
                .plaintext(SdkBytes.fromByteArray(new byte[32])) // valid AES-256 key length
                .ciphertextBlob(SdkBytes.fromByteArray(encryptedDataKey))
                .build());
        when(kmsClient.decrypt(any(DecryptRequest.class)))
            .thenReturn(DecryptResponse.builder()
                .plaintext(SdkBytes.fromByteArray(new byte[32]))
                .build());
        // Act: round-trip through the service so key material is used and cleared
        EncryptedEnvelope envelope = envelopeEncryptionService.encryptLargeData(testData);
        envelopeEncryptionService.decryptLargeData(envelope);
// Note: Memory clearing is difficult to test directly
// In real tests, you would verify no sensitive data remains in memory traces
}
}
```
## Testing Digital Signatures
### Digital Signature Tests
```java
@ExtendWith(MockitoExtension.class)
class DigitalSignatureServiceTest {
@Mock
private KmsClient kmsClient;
@InjectMocks
private DigitalSignatureService signatureService;
@Test
void shouldSignAndVerifyData() {
// Arrange
String message = "test message";
byte[] signature = "signature-data".getBytes();
when(kmsClient.sign(any(SignRequest.class)))
.thenReturn(SignResponse.builder()
.signature(SdkBytes.fromByteArray(signature))
.build());
when(kmsClient.verify(any(VerifyRequest.class)))
.thenReturn(VerifyResponse.builder()
.signatureValid(true)
.build());
// Act
byte[] signedSignature = signatureService.signData(message);
boolean isValid = signatureService.verifySignature(message, signedSignature);
// Assert
assertThat(signedSignature).isEqualTo(signature);
assertThat(isValid).isTrue();
}
@Test
void shouldDetectInvalidSignature() {
// Arrange
String message = "test message";
byte[] signature = "invalid-signature".getBytes();
when(kmsClient.verify(any(VerifyRequest.class)))
.thenReturn(VerifyResponse.builder()
.signatureValid(false)
.build());
// Act & Assert
assertThatThrownBy(() ->
signatureService.verifySignature(message, signature))
.isInstanceOf(SecurityException.class)
.hasMessageContaining("Invalid signature");
}
}
```
## Performance Testing
### Performance Test with JMH
```java
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.concurrent.TimeUnit;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Warmup(iterations = 3, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(1)
@State(Scope.Benchmark)
public class KmsPerformanceTest {
    private KmsEncryptionService encryptionService;
    private String encryptedData;

    @Setup
    public void setup() {
        // JMH does not run inside a Spring context, so wire the mock manually;
        // stubbing here keeps Mockito setup out of the measured code path.
        // This measures service overhead with the KMS call stubbed out.
        KmsClient kmsClient = mock(KmsClient.class);
        when(kmsClient.encrypt(any(EncryptRequest.class)))
            .thenReturn(EncryptResponse.builder()
                .ciphertextBlob(SdkBytes.fromString("encrypted", StandardCharsets.UTF_8))
                .build());
        when(kmsClient.decrypt(any(DecryptRequest.class)))
            .thenReturn(DecryptResponse.builder()
                .plaintext(SdkBytes.fromString("decrypted", StandardCharsets.UTF_8))
                .build());
        encryptionService = new KmsEncryptionService(kmsClient);
        // decrypt() expects Base64 input, so prepare a valid sample
        encryptedData = Base64.getEncoder()
            .encodeToString("ciphertext".getBytes(StandardCharsets.UTF_8));
    }

    @Benchmark
    public void testEncryptionPerformance(Blackhole bh) {
        bh.consume(encryptionService.encrypt("performance test data with some content"));
    }

    @Benchmark
    public void testDecryptionPerformance(Blackhole bh) {
        bh.consume(encryptionService.decrypt(encryptedData));
    }
}
```
## Testing Error Scenarios
### Error Handling Tests
```java
@ExtendWith(MockitoExtension.class)
class KmsErrorHandlingTest {
@Mock
private KmsClient kmsClient;
@InjectMocks
private KmsEncryptionService encryptionService;
@Test
void shouldHandleThrottlingException() {
// Arrange
when(kmsClient.encrypt(any(EncryptRequest.class)))
.thenThrow(KmsException.builder()
.awsErrorDetails(AwsErrorDetails.builder()
.errorCode("ThrottlingException")
.errorMessage("Rate exceeded")
.build())
.build());
// Act & Assert
assertThatThrownBy(() -> encryptionService.encrypt("test"))
.isInstanceOf(RuntimeException.class)
.hasMessageContaining("Rate limit exceeded");
}
@Test
void shouldHandleDisabledKey() {
// Arrange
when(kmsClient.encrypt(any(EncryptRequest.class)))
.thenThrow(KmsException.builder()
.awsErrorDetails(AwsErrorDetails.builder()
.errorCode("DisabledException")
.errorMessage("Key is disabled")
.build())
.build());
// Act & Assert
assertThatThrownBy(() -> encryptionService.encrypt("test"))
.isInstanceOf(RuntimeException.class)
.hasMessageContaining("Key is disabled");
}
@Test
void shouldHandleNotFoundException() {
// Arrange
when(kmsClient.encrypt(any(EncryptRequest.class)))
.thenThrow(KmsException.builder()
.awsErrorDetails(AwsErrorDetails.builder()
.errorCode("NotFoundException")
.errorMessage("Key not found")
.build())
.build());
// Act & Assert
assertThatThrownBy(() -> encryptionService.encrypt("test"))
.isInstanceOf(RuntimeException.class)
.hasMessageContaining("Key not found");
}
}
```
## Integration Testing with AWS Local
### Testcontainers KMS Setup
```java
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.localstack.LocalStackContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.kms.KmsClient;
import static org.testcontainers.containers.localstack.LocalStackContainer.Service.KMS;
@SpringBootTest
@Testcontainers
class KmsAwsLocalIntegrationTest {
@Container
private static final LocalStackContainer localStack =
new LocalStackContainer(DockerImageName.parse("localstack/localstack:latest"))
.withServices(KMS)
.withEnv("DEFAULT_REGION", "us-east-1");
private KmsClient kmsClient;
@BeforeEach
void setup() {
kmsClient = KmsClient.builder()
            .region(Region.of(localStack.getRegion()))
.endpointOverride(localStack.getEndpointOverride(KMS))
.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create(localStack.getAccessKey(), localStack.getSecretKey())))
.build();
}
@Test
void shouldCreateKeyInLocalKms() {
// This test creates a real key in the local KMS instance
CreateKeyRequest request = CreateKeyRequest.builder()
.description("Test key")
.keyUsage(KeyUsageType.ENCRYPT_DECRYPT)
.build();
CreateKeyResponse response = kmsClient.createKey(request);
assertThat(response.keyMetadata().keyId()).isNotEmpty();
}
}
```

View File

@@ -0,0 +1,508 @@
---
name: aws-sdk-java-v2-lambda
description: AWS Lambda patterns using AWS SDK for Java 2.x. Use when invoking Lambda functions, creating/updating functions, managing function configurations, working with Lambda layers, or integrating Lambda with Spring Boot applications.
category: aws
tags: [aws, lambda, java, sdk, serverless, functions]
version: 1.1.0
allowed-tools: Read, Write, Bash
---
# AWS SDK for Java 2.x - AWS Lambda
## When to Use
Use this skill when:
- Invoking Lambda functions programmatically
- Creating or updating Lambda functions
- Managing Lambda function configurations
- Working with Lambda environment variables
- Managing Lambda layers and aliases
- Implementing asynchronous Lambda invocations
- Integrating Lambda with Spring Boot
## Overview
AWS Lambda is a compute service that runs code without the need to manage servers. Your code runs automatically, scaling up and down with pay-per-use pricing. Use this skill to implement AWS Lambda operations using AWS SDK for Java 2.x in applications and services.
## Dependencies
```xml
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>lambda</artifactId>
</dependency>
```
## Client Setup
To use AWS Lambda, create a LambdaClient with the required region configuration:
```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaClient;
LambdaClient lambdaClient = LambdaClient.builder()
.region(Region.US_EAST_1)
.build();
```
For asynchronous operations, use LambdaAsyncClient:
```java
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
LambdaAsyncClient asyncLambdaClient = LambdaAsyncClient.builder()
.region(Region.US_EAST_1)
.build();
```
## Invoke Lambda Function
### Synchronous Invocation
Invoke Lambda functions synchronously to get immediate results:
```java
import software.amazon.awssdk.services.lambda.model.*;
import software.amazon.awssdk.core.SdkBytes;
public String invokeLambda(LambdaClient lambdaClient,
String functionName,
String payload) {
InvokeRequest request = InvokeRequest.builder()
.functionName(functionName)
.payload(SdkBytes.fromUtf8String(payload))
.build();
InvokeResponse response = lambdaClient.invoke(request);
return response.payload().asUtf8String();
}
```
### Asynchronous Invocation
Use asynchronous invocation for fire-and-forget scenarios:
```java
public void invokeLambdaAsync(LambdaClient lambdaClient,
String functionName,
String payload) {
InvokeRequest request = InvokeRequest.builder()
.functionName(functionName)
.invocationType(InvocationType.EVENT) // Asynchronous
.payload(SdkBytes.fromUtf8String(payload))
.build();
InvokeResponse response = lambdaClient.invoke(request);
System.out.println("Status: " + response.statusCode());
}
```
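### Non-Blocking Invocation with the Async Client
Note the distinction: `InvocationType.EVENT` above queues the event on the Lambda side while the HTTP call still blocks. To keep the calling thread free as well, a sketch with `LambdaAsyncClient`:
```java
import java.util.concurrent.CompletableFuture;

public CompletableFuture<String> invokeNonBlocking(LambdaAsyncClient asyncClient,
                                                   String functionName,
                                                   String payload) {
    InvokeRequest request = InvokeRequest.builder()
        .functionName(functionName)
        .payload(SdkBytes.fromUtf8String(payload))
        .build();
    // The future completes when the response arrives; no thread blocks waiting
    return asyncClient.invoke(request)
        .thenApply(response -> response.payload().asUtf8String());
}
```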
### Invoke with JSON Objects
Work with JSON payloads for complex data structures:
```java
import com.fasterxml.jackson.databind.ObjectMapper;
public <T> String invokeLambdaWithObject(LambdaClient lambdaClient,
String functionName,
T requestObject) throws Exception {
ObjectMapper mapper = new ObjectMapper();
String jsonPayload = mapper.writeValueAsString(requestObject);
InvokeRequest request = InvokeRequest.builder()
.functionName(functionName)
.payload(SdkBytes.fromUtf8String(jsonPayload))
.build();
InvokeResponse response = lambdaClient.invoke(request);
return response.payload().asUtf8String();
}
```
### Parse Typed Responses
Parse JSON responses into typed objects:
```java
public <T> T invokeLambdaAndParse(LambdaClient lambdaClient,
String functionName,
Object request,
Class<T> responseType) throws Exception {
ObjectMapper mapper = new ObjectMapper();
String jsonPayload = mapper.writeValueAsString(request);
InvokeRequest invokeRequest = InvokeRequest.builder()
.functionName(functionName)
.payload(SdkBytes.fromUtf8String(jsonPayload))
.build();
InvokeResponse response = lambdaClient.invoke(invokeRequest);
String responseJson = response.payload().asUtf8String();
return mapper.readValue(responseJson, responseType);
}
```
## Function Management
### List Functions
List all Lambda functions for the current account:
```java
public List<FunctionConfiguration> listFunctions(LambdaClient lambdaClient) {
ListFunctionsResponse response = lambdaClient.listFunctions();
return response.functions();
}
```
### Get Function Configuration
Retrieve function configuration and metadata:
```java
public FunctionConfiguration getFunctionConfig(LambdaClient lambdaClient,
String functionName) {
GetFunctionRequest request = GetFunctionRequest.builder()
.functionName(functionName)
.build();
GetFunctionResponse response = lambdaClient.getFunction(request);
return response.configuration();
}
```
### Update Function Code
Update Lambda function code with new deployment package:
```java
import java.nio.file.Files;
import java.nio.file.Paths;
public void updateFunctionCode(LambdaClient lambdaClient,
String functionName,
String zipFilePath) throws IOException {
byte[] zipBytes = Files.readAllBytes(Paths.get(zipFilePath));
UpdateFunctionCodeRequest request = UpdateFunctionCodeRequest.builder()
.functionName(functionName)
.zipFile(SdkBytes.fromByteArray(zipBytes))
.publish(true)
.build();
UpdateFunctionCodeResponse response = lambdaClient.updateFunctionCode(request);
System.out.println("Updated function version: " + response.version());
}
```
### Update Function Configuration
Modify function settings like timeout, memory, and environment variables:
```java
public void updateFunctionConfiguration(LambdaClient lambdaClient,
String functionName,
Map<String, String> environment) {
Environment env = Environment.builder()
.variables(environment)
.build();
UpdateFunctionConfigurationRequest request = UpdateFunctionConfigurationRequest.builder()
.functionName(functionName)
.environment(env)
.timeout(60)
.memorySize(512)
.build();
lambdaClient.updateFunctionConfiguration(request);
}
```
### Create Function
Create new Lambda functions with code and configuration:
```java
public void createFunction(LambdaClient lambdaClient,
String functionName,
String roleArn,
String handler,
String zipFilePath) throws IOException {
byte[] zipBytes = Files.readAllBytes(Paths.get(zipFilePath));
FunctionCode code = FunctionCode.builder()
.zipFile(SdkBytes.fromByteArray(zipBytes))
.build();
CreateFunctionRequest request = CreateFunctionRequest.builder()
.functionName(functionName)
.runtime(Runtime.JAVA17)
.role(roleArn)
.handler(handler)
.code(code)
.timeout(60)
.memorySize(512)
.build();
CreateFunctionResponse response = lambdaClient.createFunction(request);
System.out.println("Function ARN: " + response.functionArn());
}
```
### Delete Function
Remove Lambda functions when no longer needed:
```java
public void deleteFunction(LambdaClient lambdaClient, String functionName) {
DeleteFunctionRequest request = DeleteFunctionRequest.builder()
.functionName(functionName)
.build();
lambdaClient.deleteFunction(request);
}
```
## Spring Boot Integration
### Configuration
Configure Lambda clients as Spring beans:
```java
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class LambdaConfiguration {
@Bean
public LambdaClient lambdaClient() {
return LambdaClient.builder()
.region(Region.US_EAST_1)
.build();
}
}
```
### Lambda Invoker Service
Create a service for Lambda function invocation:
```java
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
@Service
public class LambdaInvokerService {
private final LambdaClient lambdaClient;
private final ObjectMapper objectMapper;
@Autowired
public LambdaInvokerService(LambdaClient lambdaClient, ObjectMapper objectMapper) {
this.lambdaClient = lambdaClient;
this.objectMapper = objectMapper;
}
public <T, R> R invoke(String functionName, T request, Class<R> responseType) {
try {
String jsonPayload = objectMapper.writeValueAsString(request);
InvokeRequest invokeRequest = InvokeRequest.builder()
.functionName(functionName)
.payload(SdkBytes.fromUtf8String(jsonPayload))
.build();
InvokeResponse response = lambdaClient.invoke(invokeRequest);
if (response.functionError() != null) {
throw new LambdaInvocationException(
"Lambda function error: " + response.functionError());
}
String responseJson = response.payload().asUtf8String();
return objectMapper.readValue(responseJson, responseType);
} catch (Exception e) {
throw new RuntimeException("Failed to invoke Lambda function", e);
}
}
public void invokeAsync(String functionName, Object request) {
try {
String jsonPayload = objectMapper.writeValueAsString(request);
InvokeRequest invokeRequest = InvokeRequest.builder()
.functionName(functionName)
.invocationType(InvocationType.EVENT)
.payload(SdkBytes.fromUtf8String(jsonPayload))
.build();
lambdaClient.invoke(invokeRequest);
} catch (Exception e) {
throw new RuntimeException("Failed to invoke Lambda function async", e);
}
}
}
```
### Typed Lambda Client
Create type-safe interfaces for Lambda services:
```java
public interface OrderProcessor {
OrderResponse processOrder(OrderRequest request);
}
@Service
public class LambdaOrderProcessor implements OrderProcessor {
private final LambdaInvokerService lambdaInvoker;
@Value("${lambda.order-processor.function-name}")
private String functionName;
public LambdaOrderProcessor(LambdaInvokerService lambdaInvoker) {
this.lambdaInvoker = lambdaInvoker;
}
@Override
public OrderResponse processOrder(OrderRequest request) {
return lambdaInvoker.invoke(functionName, request, OrderResponse.class);
}
}
```
## Error Handling
Implement comprehensive error handling for Lambda operations:
```java
public String invokeLambdaSafe(LambdaClient lambdaClient,
String functionName,
String payload) {
try {
InvokeRequest request = InvokeRequest.builder()
.functionName(functionName)
.payload(SdkBytes.fromUtf8String(payload))
.build();
InvokeResponse response = lambdaClient.invoke(request);
// Check for function error
if (response.functionError() != null) {
String errorMessage = response.payload().asUtf8String();
throw new RuntimeException("Lambda error: " + errorMessage);
}
// Check status code
if (response.statusCode() != 200) {
throw new RuntimeException("Lambda invocation failed with status: " +
response.statusCode());
}
return response.payload().asUtf8String();
} catch (LambdaException e) {
System.err.println("Lambda error: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
public class LambdaInvocationException extends RuntimeException {
public LambdaInvocationException(String message) {
super(message);
}
public LambdaInvocationException(String message, Throwable cause) {
super(message, cause);
}
}
```
## Examples
For comprehensive code examples, see the references section:
- **Basic examples** - Simple invocation patterns and function management
- **Spring Boot integration** - Complete Spring Boot configuration and service patterns
- **Testing examples** - Unit and integration test patterns
- **Advanced patterns** - Complex scenarios and best practices
## Best Practices
1. **Reuse Lambda clients**: Create once and reuse across invocations
2. **Set appropriate timeouts**: Match client timeout to Lambda function timeout
3. **Use async invocation**: For fire-and-forget scenarios
4. **Handle errors properly**: Check for function errors and status codes
5. **Use environment variables**: For function configuration
6. **Implement retry logic**: For transient failures (a configuration sketch follows this list)
7. **Monitor invocations**: Use CloudWatch metrics
8. **Version functions**: Use aliases and versions for production
9. **Use VPC**: For accessing resources in private subnets
10. **Optimize payload size**: Keep payloads small for better performance
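As a sketch of practices 1, 2, and 6: client reuse, timeouts, and retry behavior can all be configured once when the client is built (the values below are illustrative):
```java
import java.time.Duration;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;

// Build once and reuse across invocations; tune timeouts to the function's latency profile
LambdaClient lambdaClient = LambdaClient.builder()
        .region(Region.US_EAST_1)
        .overrideConfiguration(ClientOverrideConfiguration.builder()
                .apiCallTimeout(Duration.ofSeconds(65)) // slightly above the function timeout
                .retryPolicy(RetryPolicy.builder().numRetries(3).build())
                .build())
        .build();
```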
## Testing
Test Lambda services using mocks and test assertions:
```java
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
@ExtendWith(MockitoExtension.class)
class LambdaInvokerServiceTest {
@Mock
private LambdaClient lambdaClient;
@Mock
private ObjectMapper objectMapper;
@InjectMocks
private LambdaInvokerService service;
@Test
void shouldInvokeLambdaSuccessfully() throws Exception {
// Test implementation
}
}
```
## Related Skills
- @aws-sdk-java-v2-core - Core AWS SDK patterns and client configuration
- @spring-boot-dependency-injection - Spring dependency injection best practices
- @unit-test-service-layer - Service testing patterns with Mockito
- @spring-boot-actuator - Production monitoring and health checks
## References
For detailed information and examples, see the following reference files:
- **[Official Documentation](references/official-documentation.md)** - AWS Lambda concepts, API reference, and official guidance
- **[Examples](references/examples.md)** - Complete code examples and integration patterns
## Additional Resources
- [Lambda Examples on GitHub](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/example_code/lambda)
- [Lambda API Reference](https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/lambda/package-summary.html)
- [AWS Lambda Developer Guide](https://docs.aws.amazon.com/lambda/latest/dg/welcome.html)

View File

@@ -0,0 +1,544 @@
# AWS Lambda Java SDK Examples
## Client Setup
### Basic Client Configuration
```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaClient;
// Create synchronous client
LambdaClient lambdaClient = LambdaClient.builder()
.region(Region.US_EAST_1)
.build();
// Create asynchronous client
LambdaAsyncClient asyncLambdaClient = LambdaAsyncClient.builder()
.region(Region.US_EAST_1)
.build();
```
### Client with Configuration
```java
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
// ApacheHttpClient requires the 'apache-client' module on the classpath
LambdaClient lambdaClient = LambdaClient.builder()
        .region(Region.US_EAST_1)
        .credentialsProvider(DefaultCredentialsProvider.create())
        .httpClientBuilder(ApacheHttpClient.builder())
        .build();
```
## Function Invocation Examples
### Synchronous Invocation with String Payload
```java
import software.amazon.awssdk.services.lambda.model.*;
import software.amazon.awssdk.core.SdkBytes;
public String invokeLambdaSync(LambdaClient lambdaClient,
String functionName,
String payload) {
InvokeRequest request = InvokeRequest.builder()
.functionName(functionName)
.payload(SdkBytes.fromUtf8String(payload))
.build();
InvokeResponse response = lambdaClient.invoke(request);
// Check for function errors
if (response.functionError() != null) {
throw new RuntimeException("Lambda function error: " +
response.payload().asUtf8String());
}
return response.payload().asUtf8String();
}
```
### Asynchronous Invocation
```java
import java.util.concurrent.CompletableFuture;
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
public CompletableFuture<Integer> invokeLambdaAsync(LambdaAsyncClient lambdaAsyncClient,
                                                    String functionName,
                                                    String payload) {
    InvokeRequest request = InvokeRequest.builder()
            .functionName(functionName)
            .invocationType(InvocationType.EVENT) // Asynchronous: Lambda queues the event
            .payload(SdkBytes.fromUtf8String(payload))
            .build();
    // Event invocations return an empty payload; a 202 status code confirms acceptance
    return lambdaAsyncClient.invoke(request)
            .thenApply(InvokeResponse::statusCode);
}
```
### Invocation with JSON Object
```java
import com.fasterxml.jackson.databind.ObjectMapper;
public <T> String invokeLambdaWithObject(LambdaClient lambdaClient,
String functionName,
T requestObject) throws Exception {
ObjectMapper mapper = new ObjectMapper();
String jsonPayload = mapper.writeValueAsString(requestObject);
InvokeRequest request = InvokeRequest.builder()
.functionName(functionName)
.payload(SdkBytes.fromUtf8String(jsonPayload))
.build();
InvokeResponse response = lambdaClient.invoke(request);
return response.payload().asUtf8String();
}
```
### Parse Typed Response
```java
import com.fasterxml.jackson.databind.ObjectMapper;
public <T> T invokeLambdaAndParse(LambdaClient lambdaClient,
String functionName,
Object request,
Class<T> responseType) throws Exception {
ObjectMapper mapper = new ObjectMapper();
String jsonPayload = mapper.writeValueAsString(request);
InvokeRequest invokeRequest = InvokeRequest.builder()
.functionName(functionName)
.payload(SdkBytes.fromUtf8String(jsonPayload))
.build();
InvokeResponse response = lambdaClient.invoke(invokeRequest);
String responseJson = response.payload().asUtf8String();
return mapper.readValue(responseJson, responseType);
}
```
## Function Management Examples
### List Functions
```java
public List<FunctionConfiguration> listLambdaFunctions(LambdaClient lambdaClient) {
ListFunctionsResponse response = lambdaClient.listFunctions();
return response.functions();
}
// List all functions using the SDK's built-in paginator,
// which follows the response markers across pages automatically
public List<FunctionConfiguration> listAllFunctions(LambdaClient lambdaClient) {
    List<FunctionConfiguration> functions = new ArrayList<>();
    lambdaClient.listFunctionsPaginator().functions().forEach(functions::add);
    return functions;
}
```
### Get Function Configuration
```java
public FunctionConfiguration getFunctionConfig(LambdaClient lambdaClient,
String functionName) {
GetFunctionRequest request = GetFunctionRequest.builder()
.functionName(functionName)
.build();
GetFunctionResponse response = lambdaClient.getFunction(request);
return response.configuration();
}
```
### Get Function Code Location
```java
public String getFunctionCodeLocation(LambdaClient lambdaClient,
                                      String functionName) {
    GetFunctionRequest request = GetFunctionRequest.builder()
            .functionName(functionName)
            .build();
    GetFunctionResponse response = lambdaClient.getFunction(request);
    // code() returns a FunctionCodeLocation holding a presigned download URL,
    // not the deployment package bytes themselves
    return response.code().location();
}
```
### Update Function Code
```java
import java.nio.file.Files;
import java.nio.file.Paths;
public void updateLambdaFunction(LambdaClient lambdaClient,
String functionName,
String zipFilePath) throws IOException {
byte[] zipBytes = Files.readAllBytes(Paths.get(zipFilePath));
UpdateFunctionCodeRequest request = UpdateFunctionCodeRequest.builder()
.functionName(functionName)
.zipFile(SdkBytes.fromByteArray(zipBytes))
.publish(true) // Create new version
.build();
UpdateFunctionCodeResponse response = lambdaClient.updateFunctionCode(request);
System.out.println("Updated function version: " + response.version());
}
```
### Update Function Configuration
```java
public void updateFunctionConfig(LambdaClient lambdaClient,
String functionName,
Map<String, String> environment) {
Environment env = Environment.builder()
.variables(environment)
.build();
UpdateFunctionConfigurationRequest request = UpdateFunctionConfigurationRequest.builder()
.functionName(functionName)
.environment(env)
.timeout(60)
.memorySize(512)
.build();
lambdaClient.updateFunctionConfiguration(request);
}
```
### Create Function
```java
import java.nio.file.Files;
import java.nio.file.Paths;
public void createLambdaFunction(LambdaClient lambdaClient,
String functionName,
String roleArn,
String handler,
String zipFilePath) throws IOException {
byte[] zipBytes = Files.readAllBytes(Paths.get(zipFilePath));
FunctionCode code = FunctionCode.builder()
.zipFile(SdkBytes.fromByteArray(zipBytes))
.build();
CreateFunctionRequest request = CreateFunctionRequest.builder()
.functionName(functionName)
.runtime(Runtime.JAVA17)
.role(roleArn)
.handler(handler)
.code(code)
.timeout(60)
.memorySize(512)
.environment(Environment.builder()
.variables(Map.of("ENV", "production"))
.build())
.build();
CreateFunctionResponse response = lambdaClient.createFunction(request);
System.out.println("Function ARN: " + response.functionArn());
}
```
### Delete Function
```java
public void deleteLambdaFunction(LambdaClient lambdaClient, String functionName) {
DeleteFunctionRequest request = DeleteFunctionRequest.builder()
.functionName(functionName)
.build();
lambdaClient.deleteFunction(request);
}
```
## Spring Boot Integration Examples
### Configuration Class
```java
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class LambdaConfiguration {
@Bean
public LambdaClient lambdaClient() {
return LambdaClient.builder()
.region(Region.US_EAST_1)
.build();
}
@Bean
public LambdaAsyncClient asyncLambdaClient() {
return LambdaAsyncClient.builder()
.region(Region.US_EAST_1)
.build();
}
}
```
### Lambda Invoker Service
```java
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
@Service
public class LambdaInvokerService {
private final LambdaClient lambdaClient;
private final ObjectMapper objectMapper;
@Autowired
public LambdaInvokerService(LambdaClient lambdaClient,
ObjectMapper objectMapper) {
this.lambdaClient = lambdaClient;
this.objectMapper = objectMapper;
}
public <T, R> R invokeFunction(String functionName,
T request,
Class<R> responseType) {
try {
String jsonPayload = objectMapper.writeValueAsString(request);
InvokeRequest invokeRequest = InvokeRequest.builder()
.functionName(functionName)
.payload(SdkBytes.fromUtf8String(jsonPayload))
.build();
InvokeResponse response = lambdaClient.invoke(invokeRequest);
if (response.functionError() != null) {
throw new LambdaInvocationException(
"Lambda function error: " + response.functionError());
}
String responseJson = response.payload().asUtf8String();
return objectMapper.readValue(responseJson, responseType);
} catch (Exception e) {
throw new RuntimeException("Failed to invoke Lambda function", e);
}
}
public void invokeFunctionAsync(String functionName, Object request) {
try {
String jsonPayload = objectMapper.writeValueAsString(request);
InvokeRequest invokeRequest = InvokeRequest.builder()
.functionName(functionName)
.invocationType(InvocationType.EVENT)
.payload(SdkBytes.fromUtf8String(jsonPayload))
.build();
lambdaClient.invoke(invokeRequest);
} catch (Exception e) {
throw new RuntimeException("Failed to invoke Lambda function async", e);
}
}
}
```
### Typed Lambda Client Interface
```java
public interface OrderProcessor {
OrderResponse processOrder(OrderRequest request);
CompletableFuture<OrderResponse> processOrderAsync(OrderRequest request);
}
@Service
public class LambdaOrderProcessor implements OrderProcessor {
private final LambdaInvokerService lambdaInvoker;
private final LambdaAsyncClient asyncLambdaClient;
@Value("${lambda.order-processor.function-name}")
private String functionName;
public LambdaOrderProcessor(LambdaInvokerService lambdaInvoker,
LambdaAsyncClient asyncLambdaClient) {
this.lambdaInvoker = lambdaInvoker;
this.asyncLambdaClient = asyncLambdaClient;
}
@Override
public OrderResponse processOrder(OrderRequest request) {
        return lambdaInvoker.invokeFunction(functionName, request, OrderResponse.class);
}
@Override
public CompletableFuture<OrderResponse> processOrderAsync(OrderRequest request) {
// Implement async invocation using async client
try {
String jsonPayload = new ObjectMapper().writeValueAsString(request);
InvokeRequest invokeRequest = InvokeRequest.builder()
.functionName(functionName)
.payload(SdkBytes.fromUtf8String(jsonPayload))
.build();
return asyncLambdaClient.invoke(invokeRequest)
.thenApply(response -> {
try {
return new ObjectMapper().readValue(
response.payload().asUtf8String(),
OrderResponse.class);
} catch (Exception e) {
throw new RuntimeException("Failed to parse response", e);
}
});
} catch (Exception e) {
throw new RuntimeException("Failed to invoke Lambda function", e);
}
}
}
```
## Error Handling Examples
### Comprehensive Error Handling
```java
public String invokeLambdaWithFullErrorHandling(LambdaClient lambdaClient,
String functionName,
String payload) {
try {
InvokeRequest request = InvokeRequest.builder()
.functionName(functionName)
.payload(SdkBytes.fromUtf8String(payload))
.build();
InvokeResponse response = lambdaClient.invoke(request);
// Check for function error
if (response.functionError() != null) {
String errorMessage = response.payload().asUtf8String();
throw new LambdaInvocationException(
"Lambda function error: " + errorMessage);
}
// Check status code
if (response.statusCode() != 200) {
throw new LambdaInvocationException(
"Lambda invocation failed with status: " + response.statusCode());
}
return response.payload().asUtf8String();
} catch (LambdaException e) {
System.err.println("AWS Lambda error: " + e.awsErrorDetails().errorMessage());
throw new LambdaInvocationException(
"AWS Lambda service error: " + e.awsErrorDetails().errorMessage(), e);
}
}
public class LambdaInvocationException extends RuntimeException {
public LambdaInvocationException(String message) {
super(message);
}
public LambdaInvocationException(String message, Throwable cause) {
super(message, cause);
}
}
```
## Testing Examples
### Unit Test for Lambda Service
```java
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import static org.mockito.Mockito.*;
import static org.assertj.core.api.Assertions.*;
@ExtendWith(MockitoExtension.class)
class LambdaInvokerServiceTest {
@Mock
private LambdaClient lambdaClient;
@Mock
private ObjectMapper objectMapper;
@InjectMocks
private LambdaInvokerService service;
@Test
void shouldInvokeLambdaSuccessfully() throws Exception {
// Given
OrderRequest request = new OrderRequest("ORDER-123");
OrderResponse expectedResponse = new OrderResponse("SUCCESS");
        String jsonPayload = "{\"orderId\":\"ORDER-123\"}";
        String jsonResponse = "{\"status\":\"SUCCESS\"}";
when(objectMapper.writeValueAsString(request))
.thenReturn(jsonPayload);
when(lambdaClient.invoke(any(InvokeRequest.class)))
.thenReturn(InvokeResponse.builder()
.statusCode(200)
.payload(SdkBytes.fromUtf8String(jsonResponse))
.build());
when(objectMapper.readValue(jsonResponse, OrderResponse.class))
.thenReturn(expectedResponse);
// When
        OrderResponse result = service.invokeFunction(
"order-processor", request, OrderResponse.class);
// Then
assertThat(result).isEqualTo(expectedResponse);
verify(lambdaClient).invoke(any(InvokeRequest.class));
}
@Test
void shouldHandleFunctionError() throws Exception {
// Given
OrderRequest request = new OrderRequest("ORDER-123");
        String jsonPayload = "{\"orderId\":\"ORDER-123\"}";
        String errorResponse = "{\"errorMessage\":\"Invalid input\"}";
when(objectMapper.writeValueAsString(request))
.thenReturn(jsonPayload);
when(lambdaClient.invoke(any(InvokeRequest.class)))
.thenReturn(InvokeResponse.builder()
.statusCode(200)
.functionError("Unhandled")
.payload(SdkBytes.fromUtf8String(errorResponse))
.build());
// When & Then
assertThatThrownBy(() ->
            service.invokeFunction("order-processor", request, OrderResponse.class))
.isInstanceOf(LambdaInvocationException.class)
.hasMessageContaining("Lambda function error");
}
}
```
## Maven Dependencies
```xml
<!-- AWS SDK for Java v2 Lambda -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>lambda</artifactId>
<version>2.36.3</version> <!-- use the latest available version -->
</dependency>
<!-- Jackson for JSON processing -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<!-- Spring Boot support -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
```

View File

@@ -0,0 +1,112 @@
# AWS Lambda Official Documentation Reference
## Overview
AWS Lambda is a compute service that runs your code without requiring you to provision or manage servers; it scales automatically and you pay only for the compute you use.
## Common Use Cases
- Stream processing: Process real-time data streams for analytics
- Web applications: Build scalable web apps that automatically adjust
- Mobile backends: Create secure API backends
- IoT backends: Handle web, mobile, IoT, and third-party API requests
- File processing: Process files automatically when uploaded
- Database operations: Respond to database changes and automate data workflows
- Scheduled tasks: Run automated operations on a regular schedule
## How Lambda Works
1. You write and organize your code in Lambda functions
2. You control security through Lambda permissions using execution roles
3. Event sources and AWS services trigger your Lambda functions
4. Lambda runs your code with language-specific runtimes
## Key Features
### Configuration & Security
- Environment variables modify behavior without deployments
- Versions safely test new features while maintaining stable production
- Lambda layers optimize code reuse across multiple functions
- Code signing ensures only approved code reaches production
### Performance
- Concurrency controls manage responsiveness and resource utilization
- Lambda SnapStart reduces cold start times to sub-second performance
- Response streaming delivers large payloads incrementally
- Container images package functions with complex dependencies
### Integration
- VPC networks secure sensitive resources and internal services
- File system integration shares persistent data across function invocations
- Function URLs create public APIs without additional services (see the sketch after this list)
- Lambda extensions augment functions with monitoring and operational tools
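As a rough sketch of the Function URL feature (function name illustrative), a URL can be attached to an existing function through the SDK. `AWS_IAM` auth is shown because `NONE` additionally requires a resource-based policy that permits public invocation:
```java
CreateFunctionUrlConfigResponse urlConfig = lambdaClient.createFunctionUrlConfig(
        CreateFunctionUrlConfigRequest.builder()
                .functionName("order-processor") // illustrative
                .authType(FunctionUrlAuthType.AWS_IAM)
                .build());
System.out.println("Function URL: " + urlConfig.functionUrl());
```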
## AWS Lambda Java SDK API
### Key Classes
- `LambdaClient` - Synchronous service client
- `LambdaAsyncClient` - Asynchronous service client
- `LambdaClientBuilder` - Builder for synchronous client
- `LambdaAsyncClientBuilder` - Builder for asynchronous client
- `LambdaServiceClientConfiguration` - Client settings configuration
### Related Packages
- `software.amazon.awssdk.services.lambda.model` - API models
- `software.amazon.awssdk.services.lambda.transform` - Request/response transformations
- `software.amazon.awssdk.services.lambda.paginators` - Pagination utilities
- `software.amazon.awssdk.services.lambda.waiters` - Waiter utilities
### Authentication
Lambda supports signature version 4 for API authentication.
### CA Requirements
Clients need to support these CAs:
- Amazon Root CA 1
- Starfield Services Root Certificate Authority - G2
- Starfield Class 2 Certification Authority
## Core API Operations
### Function Management Operations
- `CreateFunction` - Create new Lambda function
- `DeleteFunction` - Delete existing function
- `GetFunction` - Retrieve function configuration
- `UpdateFunctionCode` - Update function code
- `UpdateFunctionConfiguration` - Update function settings
- `ListFunctions` - List functions for account
### Invocation Operations
- `Invoke` - Invoke Lambda function synchronously
- `Invoke` with `InvocationType.EVENT` - Asynchronous invocation
### Environment & Configuration
- Environment variable management
- Function configuration updates
- Version and alias management (example after this list)
- Layer management
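A minimal sketch of the version-and-alias workflow (function and alias names are illustrative): publish an immutable version, then point an alias at it so callers can keep targeting `prod` while new code is staged.
```java
PublishVersionResponse published = lambdaClient.publishVersion(
        PublishVersionRequest.builder()
                .functionName("order-processor") // illustrative
                .build());
lambdaClient.createAlias(CreateAliasRequest.builder()
        .functionName("order-processor")
        .name("prod") // illustrative alias
        .functionVersion(published.version())
        .build());
```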
## Examples Overview
The AWS documentation includes examples for:
- Basic Lambda function creation and invocation
- Function configuration and updates
- Environment variable management
- Function listing and cleanup
- Integration patterns
## Best Practices from Official Docs
- Reuse Lambda clients across invocations
- Set appropriate timeouts matching function requirements
- Use async invocation for fire-and-forget scenarios
- Implement proper error handling for function errors and status codes
- Use environment variables for configuration management
- Version functions for production stability
- Monitor invocations using CloudWatch metrics
- Implement retry logic for transient failures
- Use VPC integration for private resources
- Optimize payload sizes for performance
## Security Considerations
- Use IAM roles with least privilege
- Implement proper Lambda permissions
- Use environment variables for sensitive data
- Enable CloudTrail logging
- Monitor security events with CloudWatch
- Use code signing for production deployments
- Implement proper authentication and authorization

View File

@@ -0,0 +1,310 @@
---
name: aws-sdk-java-v2-messaging
description: Implement AWS messaging patterns using AWS SDK for Java 2.x for SQS queues and SNS topics. Send/receive messages, manage FIFO queues, implement DLQ, publish messages, manage subscriptions, and build pub/sub patterns.
category: aws
tags: [aws, sqs, sns, java, sdk, messaging, pub-sub, queues, events]
version: 1.1.0
allowed-tools: Read, Write, Bash, Grep
---
# AWS SDK for Java 2.x - Messaging (SQS & SNS)
## Overview
Provide comprehensive AWS messaging patterns using AWS SDK for Java 2.x for both SQS and SNS services. Include client setup, queue management, message operations, subscription management, and Spring Boot integration patterns.
## When to Use
Use this skill when working with:
- Amazon SQS queues for message queuing
- SNS topics for event publishing and notification
- FIFO queues and standard queues
- Dead Letter Queues (DLQ) for message handling
- SNS subscriptions with email, SMS, SQS, Lambda endpoints
- Pub/sub messaging patterns and event-driven architectures
- Spring Boot integration with AWS messaging services
- Testing strategies using LocalStack or Testcontainers
## Quick Start
### Dependencies
```xml
<!-- SQS -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sqs</artifactId>
</dependency>
<!-- SNS -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sns</artifactId>
</dependency>
```
### Basic Client Setup
```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sns.SnsClient;
SqsClient sqsClient = SqsClient.builder()
.region(Region.US_EAST_1)
.build();
SnsClient snsClient = SnsClient.builder()
.region(Region.US_EAST_1)
.build();
```
## Examples
### Basic SQS Operations
#### Create and Send Message
```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sqs.model.*;
// Setup SQS client
SqsClient sqsClient = SqsClient.builder()
.region(Region.US_EAST_1)
.build();
// Create queue
String queueUrl = sqsClient.createQueue(CreateQueueRequest.builder()
.queueName("my-queue")
.build()).queueUrl();
// Send message
String messageId = sqsClient.sendMessage(SendMessageRequest.builder()
.queueUrl(queueUrl)
.messageBody("Hello, SQS!")
.build()).messageId();
```
#### Receive and Delete Message
```java
// Receive messages with long polling
ReceiveMessageResponse response = sqsClient.receiveMessage(ReceiveMessageRequest.builder()
.queueUrl(queueUrl)
.maxNumberOfMessages(10)
.waitTimeSeconds(20)
.build());
// Process and delete messages
response.messages().forEach(message -> {
System.out.println("Received: " + message.body());
sqsClient.deleteMessage(DeleteMessageRequest.builder()
.queueUrl(queueUrl)
.receiptHandle(message.receiptHandle())
.build());
});
```
### Basic SNS Operations
#### Create Topic and Publish
```java
import software.amazon.awssdk.services.sns.SnsClient;
import software.amazon.awssdk.services.sns.model.*;
// Setup SNS client
SnsClient snsClient = SnsClient.builder()
.region(Region.US_EAST_1)
.build();
// Create topic
String topicArn = snsClient.createTopic(CreateTopicRequest.builder()
.name("my-topic")
.build()).topicArn();
// Publish message
String messageId = snsClient.publish(PublishRequest.builder()
.topicArn(topicArn)
.subject("Test Notification")
.message("Hello, SNS!")
.build()).messageId();
```
### Advanced Examples
#### FIFO Queue Pattern
```java
// Create FIFO queue
Map<QueueAttributeName, String> attributes = Map.of(
QueueAttributeName.FIFO_QUEUE, "true",
QueueAttributeName.CONTENT_BASED_DEDUPLICATION, "true"
);
String fifoQueueUrl = sqsClient.createQueue(CreateQueueRequest.builder()
.queueName("my-queue.fifo")
.attributes(attributes)
.build()).queueUrl();
// Send FIFO message with group ID
String fifoMessageId = sqsClient.sendMessage(SendMessageRequest.builder()
.queueUrl(fifoQueueUrl)
.messageBody("Order #12345")
.messageGroupId("orders")
.messageDeduplicationId(UUID.randomUUID().toString())
.build()).messageId();
```
#### SNS to SQS Subscription
```java
// Create SQS queue for subscription
String subscriptionQueueUrl = sqsClient.createQueue(CreateQueueRequest.builder()
.queueName("notification-subscriber")
.build()).queueUrl();
// Get queue ARN
String queueArn = sqsClient.getQueueAttributes(GetQueueAttributesRequest.builder()
.queueUrl(subscriptionQueueUrl)
.attributeNames(QueueAttributeName.QUEUE_ARN)
.build()).attributes().get(QueueAttributeName.QUEUE_ARN);
// Subscribe SQS to SNS
String subscriptionArn = snsClient.subscribe(SubscribeRequest.builder()
.protocol("sqs")
.endpoint(queueArn)
.topicArn(topicArn)
.build()).subscriptionArn();
```
### Spring Boot Integration Example
```java
@Service
@RequiredArgsConstructor
public class OrderNotificationService {
private final SnsClient snsClient;
private final ObjectMapper objectMapper;
@Value("${aws.sns.order-topic-arn}")
private String orderTopicArn;
public void sendOrderNotification(Order order) {
try {
String jsonMessage = objectMapper.writeValueAsString(order);
snsClient.publish(PublishRequest.builder()
.topicArn(orderTopicArn)
.subject("New Order Received")
.message(jsonMessage)
.messageAttributes(Map.of(
"orderType", MessageAttributeValue.builder()
.dataType("String")
.stringValue(order.getType())
.build()))
.build());
} catch (Exception e) {
throw new RuntimeException("Failed to send order notification", e);
}
}
}
```
## Best Practices
### SQS Best Practices
- **Use long polling**: Set `waitTimeSeconds` (up to the 20-second maximum) to reduce empty responses
- **Batch operations**: Use `sendMessageBatch` for multiple messages to reduce API calls
- **Visibility timeout**: Set appropriately based on message processing time (default 30 seconds)
- **Delete messages**: Always delete messages after successful processing
- **Handle duplicates**: Implement idempotent processing for retries
- **Implement DLQ**: Route failed messages to dead letter queues for analysis (sketch after this list)
- **Monitor queue depth**: Use CloudWatch alarms for high queue backlog
- **Use FIFO queues**: When message order and deduplication are critical
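A minimal sketch of the DLQ practice above, attaching a redrive policy to an existing queue (DLQ ARN and receive count are illustrative):
```java
String dlqArn = "arn:aws:sqs:us-east-1:123456789012:my-queue-dlq"; // illustrative
String redrivePolicy = "{\"maxReceiveCount\":\"5\",\"deadLetterTargetArn\":\"" + dlqArn + "\"}";
sqsClient.setQueueAttributes(SetQueueAttributesRequest.builder()
        .queueUrl(queueUrl)
        .attributes(Map.of(QueueAttributeName.REDRIVE_POLICY, redrivePolicy))
        .build());
```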
### SNS Best Practices
- **Use filter policies**: Reduce noise by filtering messages at the source (example after this list)
- **Message attributes**: Add metadata for subscription routing decisions
- **Retry logic**: Handle transient failures with exponential backoff
- **Monitor failed deliveries**: Set up CloudWatch alarms for failed notifications
- **Security**: Use IAM policies for access control and data encryption
- **FIFO topics**: Use when order and deduplication are critical
- **Avoid large payloads**: Keep messages under 256KB for optimal performance
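A sketch of the filter-policy practice, assuming `subscriptionArn` came from an earlier `subscribe` call; the subscription then receives only messages whose `orderType` attribute is `priority`:
```java
snsClient.setSubscriptionAttributes(SetSubscriptionAttributesRequest.builder()
        .subscriptionArn(subscriptionArn)
        .attributeName("FilterPolicy")
        .attributeValue("{\"orderType\": [\"priority\"]}")
        .build());
```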
### General Guidelines
- **Region consistency**: Use the same region for all AWS resources
- **Resource naming**: Use consistent naming conventions for queues and topics
- **Error handling**: Implement proper exception handling and logging
- **Testing**: Use LocalStack for local development and testing
- **Documentation**: Document subscription endpoints and message formats
## Instructions
### Setup AWS Credentials
Configure AWS credentials using environment variables, AWS CLI, or IAM roles:
```bash
export AWS_ACCESS_KEY_ID=your-access-key
export AWS_SECRET_ACCESS_KEY=your-secret-key
export AWS_REGION=us-east-1
```
### Configure Clients
```java
// Basic client configuration
SqsClient sqsClient = SqsClient.builder()
.region(Region.US_EAST_1)
.build();
// Advanced client with custom configuration
SnsClient snsClient = SnsClient.builder()
.region(Region.US_EAST_1)
.credentialsProvider(DefaultCredentialsProvider.create())
.httpClient(UrlConnectionHttpClient.create())
.build();
```
### Implement Message Processing
1. **Connect** to SQS/SNS using the AWS SDK clients
2. **Create** queues and topics as needed
3. **Send/receive** messages with appropriate timeout settings
4. **Process** messages in batches for efficiency
5. **Delete** messages after successful processing
6. **Handle** failures with proper error handling and retries
### Integrate with Spring Boot
1. **Configure** beans for `SqsClient` and `SnsClient` in `@Configuration` classes
2. **Use** `@Value` to inject queue URLs and topic ARNs from properties
3. **Create** service classes with business logic for messaging operations
4. **Implement** error handling with `@Retryable` or custom retry logic
5. **Test** integration using Testcontainers or LocalStack
### Monitor and Debug
- Use AWS CloudWatch for monitoring queue depth and message metrics
- Enable AWS SDK logging for debugging client operations (snippet after this list)
- Implement proper logging for message processing activities
- Use AWS X-Ray for distributed tracing in production environments
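For example, request-level SDK logging can be enabled by raising the `software.amazon.awssdk.request` logger to DEBUG (Logback syntax shown; adapt to your logging framework):
```xml
<!-- Logs a summary of every SDK request and response -->
<logger name="software.amazon.awssdk.request" level="DEBUG"/>
```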
## Troubleshooting
### Common Issues
- **Queue does not exist**: Verify queue URL and permissions
- **Message not received**: Check visibility timeout and consumer logic
- **Permission denied**: Verify IAM policies and credentials
- **Connection timeout**: Check network connectivity and region configuration
- **Rate limiting**: Implement retry logic with exponential backoff
### Performance Optimization
- Use long polling to reduce empty responses
- Batch message operations to minimize API calls
- Adjust visibility timeout based on processing time
- Implement connection pooling and reuse clients
- Use appropriate message sizes to avoid fragmentation
## Detailed References
For comprehensive API documentation and advanced patterns, see:
- [@references/detailed-sqs-operations] - Complete SQS operations reference
- [@references/detailed-sns-operations] - Complete SNS operations reference
- [@references/spring-boot-integration] - Spring Boot integration patterns
- [@references/aws-official-documentation] - Official AWS documentation and best practices

View File

@@ -0,0 +1,158 @@
# AWS SQS & SNS Official Documentation Reference
This file contains reference information extracted from official AWS resources for the AWS SDK for Java 2.x messaging patterns.
## Source Documents
- [AWS Java SDK v2 Examples - SQS](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/example_code/sqs)
- [AWS Java SDK v2 Examples - SNS](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/example_code/sns)
- [AWS SQS Developer Guide](https://docs.aws.amazon.com/sqs/latest/dg/)
- [AWS SNS Developer Guide](https://docs.aws.amazon.com/sns/latest/dg/)
## Amazon SQS Reference
### Core Operations
- **CreateQueue** - Create new SQS queue
- **DeleteMessage** - Delete individual message from queue
- **ListQueues** - List available queues
- **ReceiveMessage** - Receive messages from queue
- **SendMessage** - Send message to queue
- **SendMessageBatch** - Send multiple messages to queue
### Advanced Features
- **Large Message Handling** - Use S3 for messages larger than 256KB
- **Batch Operations** - Process multiple messages efficiently
- **Long Polling** - Reduce empty responses with `waitTimeSeconds`
- **Visibility Timeout** - Control message visibility during processing
- **Dead Letter Queues (DLQ)** - Handle failed messages
- **FIFO Queues** - Ensure message ordering and deduplication
### Java SDK v2 Key Classes
```java
// Core clients and models
software.amazon.awssdk.services.sqs.SqsClient
software.amazon.awssdk.services.sqs.model.*
software.amazon.awssdk.services.sqs.model.QueueAttributeName
```
## Amazon SNS Reference
### Core Operations
- **CreateTopic** - Create new SNS topic
- **Publish** - Send message to topic
- **Subscribe** - Subscribe endpoint to topic
- **ListSubscriptions** - List topic subscriptions
- **Unsubscribe** - Remove subscription
### Advanced Features
- **Platform Endpoints** - Mobile push notifications
- **SMS Publishing** - Send SMS messages
- **FIFO Topics** - Ordered message delivery with deduplication
- **Filter Policies** - Filter messages based on attributes
- **Message Attributes** - Enrich messages with metadata
- **DLQ for Subscriptions** - Handle failed deliveries
### Java SDK v2 Key Classes
```java
// Core clients and models
software.amazon.awssdk.services.sns.SnsClient
software.amazon.awssdk.services.sns.model.*
software.amazon.awssdk.services.sns.model.MessageAttributeValue
```
## Best Practices from AWS
### SQS Best Practices
1. **Use Long Polling**: Set `waitTimeSeconds` (up to the 20-second maximum) to reduce empty responses
2. **Batch Operations**: Use `SendMessageBatch` for efficiency
3. **Visibility Timeout**: Set appropriately based on processing time
4. **Handle Duplicates**: Implement idempotent processing for retries
5. **Monitor Queue Depth**: Use CloudWatch for monitoring
6. **Implement DLQ**: Route failed messages for analysis
### SNS Best Practices
1. **Use Filter Policies**: Reduce noise by filtering messages
2. **Message Attributes**: Add metadata for routing decisions
3. **Retry Logic**: Handle transient failures gracefully
4. **Monitor Failed Deliveries**: Set up CloudWatch alarms
5. **Security**: Use IAM policies for access control
6. **FIFO Topics**: Use when order and deduplication are critical
## Error Handling Patterns
### Common SQS Errors
- **QueueDoesNotExistException**: Verify queue URL
- **MessageNotInflightException**: Check message visibility
- **OverLimitException**: Implement backoff/retry logic
- **InvalidAttributeValueException**: Validate queue attributes
### Common SNS Errors
- **NotFoundException**: Verify topic ARN
- **InvalidParameterException**: Validate subscription parameters
- **InternalFailureException**: Implement retry logic
- **AuthorizationErrorException**: Check IAM permissions
## Integration Patterns
### Spring Boot Integration
- Use `@Service` classes for business logic
- Inject `SqsClient` and `SnsClient` via constructor injection
- Configure clients with `@Configuration` beans
- Use `@Value` for externalizing configuration
### Testing Strategies
- Use LocalStack for local development
- Mock AWS services with Mockito for unit tests
- Integrate with Testcontainers for integration tests
- Test idempotent operations thoroughly
## Configuration Options
### SQS Configuration
```java
SqsClient sqsClient = SqsClient.builder()
.region(Region.US_EAST_1)
.build();
```
### SNS Configuration
```java
SnsClient snsClient = SnsClient.builder()
.region(Region.US_EAST_1)
.build();
```
### Advanced Configuration
- Override endpoint for local development (combined sketch after this list)
- Configure custom credentials provider
- Set custom HTTP client
- Configure retry policies
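A sketch combining these options for local development (the endpoint and credentials are illustrative LocalStack defaults):
```java
import java.net.URI;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;

SqsClient localSqsClient = SqsClient.builder()
        .region(Region.US_EAST_1)
        .endpointOverride(URI.create("http://localhost:4566")) // illustrative local endpoint
        .credentialsProvider(StaticCredentialsProvider.create(
                AwsBasicCredentials.create("test", "test"))) // illustrative credentials
        .overrideConfiguration(ClientOverrideConfiguration.builder()
                .retryPolicy(RetryPolicy.builder().numRetries(5).build())
                .build())
        .build();
```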
## Monitoring and Observability
### SQS Metrics
- ApproximateNumberOfMessagesVisible
- ApproximateNumberOfMessagesNotVisible
- ApproximateNumberOfMessagesDelayed
- SentMessages
- ReceiveCalls
### SNS Metrics
- NumberOfNotifications
- PublishSuccess
- PublishFailed
- SubscriptionConfirmation
- SubscriptionConfirmationFailed
## Security Considerations
### IAM Permissions
- Grant least privilege access
- Use IAM roles for EC2/ECS
- Implement resource-based policies
- Use condition keys for fine-grained control
### Data Protection
- Encrypt sensitive data in messages
- Use KMS for message encryption (sketch after this list)
- Implement message signing
- Secure endpoints with HTTPS
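As a sketch of KMS encryption, server-side encryption can be enabled at queue creation by pointing the queue at a KMS key (the AWS-managed `alias/aws/sqs` key is shown; a customer-managed key ARN works the same way):
```java
String secureQueueUrl = sqsClient.createQueue(CreateQueueRequest.builder()
        .queueName("secure-queue") // illustrative
        .attributes(Map.of(QueueAttributeName.KMS_MASTER_KEY_ID, "alias/aws/sqs"))
        .build()).queueUrl();
```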

View File

@@ -0,0 +1,179 @@
# Detailed SNS Operations Reference
## Topic Management
### Create Standard Topic
```java
public String createTopic(SnsClient snsClient, String topicName) {
CreateTopicRequest request = CreateTopicRequest.builder()
.name(topicName)
.build();
CreateTopicResponse response = snsClient.createTopic(request);
return response.topicArn();
}
```
### Create FIFO Topic
```java
public String createFifoTopic(SnsClient snsClient, String topicName) {
Map<String, String> attributes = new HashMap<>();
attributes.put("FifoTopic", "true");
attributes.put("ContentBasedDeduplication", "true");
CreateTopicRequest request = CreateTopicRequest.builder()
.name(topicName + ".fifo")
.attributes(attributes)
.build();
CreateTopicResponse response = snsClient.createTopic(request);
return response.topicArn();
}
```
### Topic Operations
```java
public List<Topic> listTopics(SnsClient snsClient) {
return snsClient.listTopics().topics();
}
public String getTopicArn(SnsClient snsClient, String topicName) {
return snsClient.createTopic(CreateTopicRequest.builder()
.name(topicName)
.build()).topicArn();
}
```
## Message Publishing
### Publish Basic Message
```java
public String publishMessage(SnsClient snsClient, String topicArn, String message) {
PublishRequest request = PublishRequest.builder()
.topicArn(topicArn)
.message(message)
.build();
PublishResponse response = snsClient.publish(request);
return response.messageId();
}
```
### Publish with Subject
```java
public String publishMessageWithSubject(SnsClient snsClient,
String topicArn,
String subject,
String message) {
PublishRequest request = PublishRequest.builder()
.topicArn(topicArn)
.subject(subject)
.message(message)
.build();
PublishResponse response = snsClient.publish(request);
return response.messageId();
}
```
### Publish with Attributes
```java
public String publishMessageWithAttributes(SnsClient snsClient,
String topicArn,
String message,
Map<String, String> attributes) {
Map<String, MessageAttributeValue> messageAttributes = attributes.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
e -> MessageAttributeValue.builder()
.dataType("String")
.stringValue(e.getValue())
.build()));
PublishRequest request = PublishRequest.builder()
.topicArn(topicArn)
.message(message)
.messageAttributes(messageAttributes)
.build();
PublishResponse response = snsClient.publish(request);
return response.messageId();
}
```
### Publish FIFO Message
```java
public String publishFifoMessage(SnsClient snsClient,
String topicArn,
String message,
String messageGroupId) {
PublishRequest request = PublishRequest.builder()
.topicArn(topicArn)
.message(message)
.messageGroupId(messageGroupId)
.messageDeduplicationId(UUID.randomUUID().toString())
.build();
PublishResponse response = snsClient.publish(request);
return response.messageId();
}
```
## Subscription Management
### Subscribe Email to Topic
```java
public String subscribeEmail(SnsClient snsClient, String topicArn, String email) {
SubscribeRequest request = SubscribeRequest.builder()
.protocol("email")
.endpoint(email)
.topicArn(topicArn)
.build();
SubscribeResponse response = snsClient.subscribe(request);
return response.subscriptionArn();
}
```
### Subscribe SQS to Topic
```java
public String subscribeSqs(SnsClient snsClient, String topicArn, String queueArn) {
SubscribeRequest request = SubscribeRequest.builder()
.protocol("sqs")
.endpoint(queueArn)
.topicArn(topicArn)
.build();
SubscribeResponse response = snsClient.subscribe(request);
return response.subscriptionArn();
}
```
### Subscribe Lambda to Topic
```java
public String subscribeLambda(SnsClient snsClient, String topicArn, String lambdaArn) {
SubscribeRequest request = SubscribeRequest.builder()
.protocol("lambda")
.endpoint(lambdaArn)
.topicArn(topicArn)
.build();
SubscribeResponse response = snsClient.subscribe(request);
return response.subscriptionArn();
}
```
### Subscription Operations
```java
public List<Subscription> listSubscriptions(SnsClient snsClient, String topicArn) {
return snsClient.listSubscriptionsByTopic(ListSubscriptionsByTopicRequest.builder()
.topicArn(topicArn)
.build()).subscriptions();
}
public void unsubscribe(SnsClient snsClient, String subscriptionArn) {
snsClient.unsubscribe(UnsubscribeRequest.builder()
.subscriptionArn(subscriptionArn)
.build());
}
```

View File

@@ -0,0 +1,199 @@
# Detailed SQS Operations Reference
## Queue Management
### Create Standard Queue
```java
public String createQueue(SqsClient sqsClient, String queueName) {
CreateQueueRequest request = CreateQueueRequest.builder()
.queueName(queueName)
.build();
CreateQueueResponse response = sqsClient.createQueue(request);
return response.queueUrl();
}
```
### Create FIFO Queue
```java
public String createFifoQueue(SqsClient sqsClient, String queueName) {
Map<QueueAttributeName, String> attributes = new HashMap<>();
attributes.put(QueueAttributeName.FIFO_QUEUE, "true");
attributes.put(QueueAttributeName.CONTENT_BASED_DEDUPLICATION, "true");
CreateQueueRequest request = CreateQueueRequest.builder()
.queueName(queueName + ".fifo")
.attributes(attributes)
.build();
CreateQueueResponse response = sqsClient.createQueue(request);
return response.queueUrl();
}
```
### Queue Operations
```java
public String getQueueUrl(SqsClient sqsClient, String queueName) {
return sqsClient.getQueueUrl(GetQueueUrlRequest.builder()
.queueName(queueName)
.build()).queueUrl();
}
public List<String> listQueues(SqsClient sqsClient) {
return sqsClient.listQueues().queueUrls();
}
public void purgeQueue(SqsClient sqsClient, String queueUrl) {
sqsClient.purgeQueue(PurgeQueueRequest.builder()
.queueUrl(queueUrl)
.build());
}
```
## Message Operations
### Send Basic Message
```java
public String sendMessage(SqsClient sqsClient, String queueUrl, String messageBody) {
SendMessageRequest request = SendMessageRequest.builder()
.queueUrl(queueUrl)
.messageBody(messageBody)
.build();
SendMessageResponse response = sqsClient.sendMessage(request);
return response.messageId();
}
```
### Send Message with Attributes
```java
public String sendMessageWithAttributes(SqsClient sqsClient,
String queueUrl,
String messageBody,
Map<String, String> attributes) {
Map<String, MessageAttributeValue> messageAttributes = attributes.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
e -> MessageAttributeValue.builder()
.dataType("String")
.stringValue(e.getValue())
.build()));
SendMessageRequest request = SendMessageRequest.builder()
.queueUrl(queueUrl)
.messageBody(messageBody)
.messageAttributes(messageAttributes)
.build();
SendMessageResponse response = sqsClient.sendMessage(request);
return response.messageId();
}
```
### Send FIFO Message
```java
public String sendFifoMessage(SqsClient sqsClient,
String queueUrl,
String messageBody,
String messageGroupId) {
SendMessageRequest request = SendMessageRequest.builder()
.queueUrl(queueUrl)
.messageBody(messageBody)
.messageGroupId(messageGroupId)
.messageDeduplicationId(UUID.randomUUID().toString())
.build();
SendMessageResponse response = sqsClient.sendMessage(request);
return response.messageId();
}
```
### Send Batch Messages
```java
public void sendBatchMessages(SqsClient sqsClient,
String queueUrl,
List<String> messages) {
List<SendMessageBatchRequestEntry> entries = IntStream.range(0, messages.size())
.mapToObj(i -> SendMessageBatchRequestEntry.builder()
.id(String.valueOf(i))
.messageBody(messages.get(i))
.build())
.collect(Collectors.toList());
SendMessageBatchRequest request = SendMessageBatchRequest.builder()
.queueUrl(queueUrl)
.entries(entries)
.build();
SendMessageBatchResponse response = sqsClient.sendMessageBatch(request);
System.out.println("Successful: " + response.successful().size());
System.out.println("Failed: " + response.failed().size());
}
```
## Message Processing
### Receive Messages with Long Polling
```java
public List<Message> receiveMessages(SqsClient sqsClient, String queueUrl) {
ReceiveMessageRequest request = ReceiveMessageRequest.builder()
.queueUrl(queueUrl)
.maxNumberOfMessages(10)
.waitTimeSeconds(20) // Long polling
.messageAttributeNames("All")
.build();
ReceiveMessageResponse response = sqsClient.receiveMessage(request);
return response.messages();
}
```
### Delete Message
```java
public void deleteMessage(SqsClient sqsClient, String queueUrl, String receiptHandle) {
DeleteMessageRequest request = DeleteMessageRequest.builder()
.queueUrl(queueUrl)
.receiptHandle(receiptHandle)
.build();
sqsClient.deleteMessage(request);
}
```
### Delete Batch Messages
```java
public void deleteBatchMessages(SqsClient sqsClient,
String queueUrl,
List<Message> messages) {
List<DeleteMessageBatchRequestEntry> entries = messages.stream()
.map(msg -> DeleteMessageBatchRequestEntry.builder()
.id(msg.messageId())
.receiptHandle(msg.receiptHandle())
.build())
.collect(Collectors.toList());
DeleteMessageBatchRequest request = DeleteMessageBatchRequest.builder()
.queueUrl(queueUrl)
.entries(entries)
.build();
sqsClient.deleteMessageBatch(request);
}
```
### Change Message Visibility
```java
public void changeMessageVisibility(SqsClient sqsClient,
String queueUrl,
String receiptHandle,
int visibilityTimeout) {
ChangeMessageVisibilityRequest request = ChangeMessageVisibilityRequest.builder()
.queueUrl(queueUrl)
.receiptHandle(receiptHandle)
.visibilityTimeout(visibilityTimeout)
.build();
sqsClient.changeMessageVisibility(request);
}
```

View File

@@ -0,0 +1,292 @@
# Spring Boot Integration Reference
## Configuration
### Basic Bean Configuration
```java
@Configuration
public class MessagingConfiguration {
@Bean
public SqsClient sqsClient() {
return SqsClient.builder()
.region(Region.US_EAST_1)
.build();
}
@Bean
public SnsClient snsClient() {
return SnsClient.builder()
.region(Region.US_EAST_1)
.build();
}
}
```
### Configuration Properties
```yaml
# application.yml
aws:
sqs:
queue-url: https://sqs.us-east-1.amazonaws.com/123456789012/my-queue
sns:
topic-arn: arn:aws:sns:us-east-1:123456789012:my-topic
```
## Service Layer Integration
### SQS Message Service
```java
@Service
@RequiredArgsConstructor
public class SqsMessageService {
private final SqsClient sqsClient;
private final ObjectMapper objectMapper;
@Value("${aws.sqs.queue-url}")
private String queueUrl;
public <T> void sendMessage(T message) {
try {
String jsonMessage = objectMapper.writeValueAsString(message);
SendMessageRequest request = SendMessageRequest.builder()
.queueUrl(queueUrl)
.messageBody(jsonMessage)
.build();
sqsClient.sendMessage(request);
} catch (Exception e) {
throw new RuntimeException("Failed to send SQS message", e);
}
}
public <T> List<T> receiveMessages(Class<T> messageType) {
ReceiveMessageRequest request = ReceiveMessageRequest.builder()
.queueUrl(queueUrl)
.maxNumberOfMessages(10)
.waitTimeSeconds(20)
.build();
ReceiveMessageResponse response = sqsClient.receiveMessage(request);
return response.messages().stream()
.map(msg -> {
try {
return objectMapper.readValue(msg.body(), messageType);
} catch (Exception e) {
throw new RuntimeException("Failed to parse message", e);
}
})
.collect(Collectors.toList());
}
public void deleteMessage(String receiptHandle) {
DeleteMessageRequest request = DeleteMessageRequest.builder()
.queueUrl(queueUrl)
.receiptHandle(receiptHandle)
.build();
sqsClient.deleteMessage(request);
}
}
```
### SNS Notification Service
```java
@Service
@RequiredArgsConstructor
public class SnsNotificationService {
private final SnsClient snsClient;
private final ObjectMapper objectMapper;
@Value("${aws.sns.topic-arn}")
private String topicArn;
public void publishNotification(String subject, Object message) {
try {
String jsonMessage = objectMapper.writeValueAsString(message);
PublishRequest request = PublishRequest.builder()
.topicArn(topicArn)
.subject(subject)
.message(jsonMessage)
.build();
snsClient.publish(request);
} catch (Exception e) {
throw new RuntimeException("Failed to publish SNS notification", e);
}
}
}
```
## Message Listener Pattern
### Scheduled Polling
```java
@Service
@RequiredArgsConstructor
public class SqsMessageListener {
private final SqsClient sqsClient;
private final ObjectMapper objectMapper;
@Value("${aws.sqs.queue-url}")
private String queueUrl;
@Scheduled(fixedDelay = 5000)
public void pollMessages() {
ReceiveMessageRequest request = ReceiveMessageRequest.builder()
.queueUrl(queueUrl)
.maxNumberOfMessages(10)
.waitTimeSeconds(20)
.build();
ReceiveMessageResponse response = sqsClient.receiveMessage(request);
response.messages().forEach(this::processMessage);
}
private void processMessage(Message message) {
try {
// Process message
System.out.println("Processing: " + message.body());
// Delete message after successful processing
deleteMessage(message.receiptHandle());
} catch (Exception e) {
// Handle error - message will become visible again
System.err.println("Failed to process message: " + e.getMessage());
}
}
private void deleteMessage(String receiptHandle) {
DeleteMessageRequest request = DeleteMessageRequest.builder()
.queueUrl(queueUrl)
.receiptHandle(receiptHandle)
.build();
sqsClient.deleteMessage(request);
}
}
```
## Pub/Sub Pattern Integration
### Configuration for Pub/Sub
```java
@Configuration
@RequiredArgsConstructor
public class PubSubConfiguration {
private final SnsClient snsClient;
private final SqsClient sqsClient;
@Bean
@DependsOn("sqsClient")
public String setupPubSub() {
// Create SNS topic
String topicArn = snsClient.createTopic(CreateTopicRequest.builder()
.name("order-events")
.build()).topicArn();
// Create SQS queue
String queueUrl = sqsClient.createQueue(CreateQueueRequest.builder()
.queueName("order-processor")
.build()).queueUrl();
// Get queue ARN
String queueArn = sqsClient.getQueueAttributes(GetQueueAttributesRequest.builder()
.queueUrl(queueUrl)
.attributeNames(QueueAttributeName.QUEUE_ARN)
.build()).attributes().get(QueueAttributeName.QUEUE_ARN);
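        // Note: deliveries succeed only if the queue's access policy also
        // allows this SNS topic to call sqs:SendMessage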
// Subscribe SQS to SNS
snsClient.subscribe(SubscribeRequest.builder()
.protocol("sqs")
.endpoint(queueArn)
.topicArn(topicArn)
.build());
return topicArn;
}
}
```
## Error Handling Patterns
### Retry Mechanism
```java
@Service
@RequiredArgsConstructor
public class RetryableSqsService {
private final SqsClient sqsClient;
private final RetryTemplate retryTemplate;
public void sendMessageWithRetry(String queueUrl, String messageBody) {
retryTemplate.execute(context -> {
try {
SendMessageRequest request = SendMessageRequest.builder()
.queueUrl(queueUrl)
.messageBody(messageBody)
.build();
sqsClient.sendMessage(request);
return null;
} catch (Exception e) {
throw new RetryableException("Failed to send message", e);
}
});
}
}
```
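The `RetryTemplate` above is assumed to be provided as a Spring bean. A minimal sketch using Spring Retry's builder, with illustrative backoff values and retrying on the application-defined `RetryableException` thrown by the service:
```java
@Configuration
public class RetryConfig {
    @Bean
    public RetryTemplate retryTemplate() {
        // Up to 3 attempts with exponential backoff: 1s, 2s, 4s (capped at 10s)
        return RetryTemplate.builder()
                .maxAttempts(3)
                .exponentialBackoff(1000, 2, 10000)
                .retryOn(RetryableException.class)
                .build();
    }
}
```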
## Testing Integration
### LocalStack Configuration
```java
@TestConfiguration
public class LocalStackMessagingConfig {
    @Container
    static LocalStackContainer localstack = new LocalStackContainer(
            DockerImageName.parse("localstack/localstack:3.0"))
            .withServices(
                    LocalStackContainer.Service.SQS,
                    LocalStackContainer.Service.SNS);
    static {
        // Start the container eagerly so the client beans below can resolve endpoints
        localstack.start();
    }
@Bean
public SqsClient sqsClient() {
return SqsClient.builder()
.region(Region.US_EAST_1)
.endpointOverride(
localstack.getEndpointOverride(LocalStackContainer.Service.SQS))
.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create(
localstack.getAccessKey(),
localstack.getSecretKey())))
.build();
}
@Bean
public SnsClient snsClient() {
return SnsClient.builder()
.region(Region.US_EAST_1)
.endpointOverride(
localstack.getEndpointOverride(LocalStackContainer.Service.SNS))
.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create(
localstack.getAccessKey(),
localstack.getSecretKey())))
.build();
}
}
```

View File

@@ -0,0 +1,400 @@
---
name: aws-sdk-java-v2-rds
description: AWS RDS (Relational Database Service) management using AWS SDK for Java 2.x. Use when creating, modifying, monitoring, or managing Amazon RDS database instances, snapshots, parameter groups, and configurations.
category: aws
tags: [aws, rds, database, java, sdk, postgresql, mysql, aurora, spring-boot]
version: 1.1.0
allowed-tools: Read, Write, Bash
---
# AWS SDK for Java v2 - RDS Management
This skill provides comprehensive guidance for working with Amazon RDS (Relational Database Service) using the AWS SDK for Java 2.x, covering database instance management, snapshots, parameter groups, and RDS operations.
## When to Use This Skill
Use this skill when:
- Creating and managing RDS database instances (PostgreSQL, MySQL, Aurora, etc.)
- Taking and restoring database snapshots
- Managing DB parameter groups and configurations
- Querying RDS instance metadata and status
- Setting up Multi-AZ deployments
- Configuring automated backups
- Managing security groups for RDS
- Connecting Lambda functions to RDS databases
- Implementing RDS IAM authentication
- Monitoring RDS instances and metrics
## Getting Started
### RDS Client Setup
The `RdsClient` is the main entry point for interacting with Amazon RDS.
**Basic Client Creation:**
```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.rds.RdsClient;
RdsClient rdsClient = RdsClient.builder()
.region(Region.US_EAST_1)
.build();
// Use client
describeInstances(rdsClient);
// Always close the client
rdsClient.close();
```
**Client with Custom Configuration:**
```java
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
RdsClient rdsClient = RdsClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(ProfileCredentialsProvider.create("myprofile"))
.httpClient(ApacheHttpClient.builder()
.connectionTimeout(Duration.ofSeconds(30))
.socketTimeout(Duration.ofSeconds(60))
.build())
.build();
```
### Describing DB Instances
Retrieve information about existing RDS instances.
**List All DB Instances:**
```java
public static void describeInstances(RdsClient rdsClient) {
try {
DescribeDbInstancesResponse response = rdsClient.describeDBInstances();
List<DBInstance> instanceList = response.dbInstances();
for (DBInstance instance : instanceList) {
System.out.println("Instance ARN: " + instance.dbInstanceArn());
System.out.println("Engine: " + instance.engine());
System.out.println("Status: " + instance.dbInstanceStatus());
System.out.println("Endpoint: " + instance.endpoint().address());
System.out.println("Port: " + instance.endpoint().port());
System.out.println("---");
}
} catch (RdsException e) {
System.err.println(e.getMessage());
System.exit(1);
}
}
```
## Key Operations
### Creating DB Instances
Create new RDS database instances with various configurations.
**Create Basic DB Instance:**
```java
public static String createDBInstance(RdsClient rdsClient,
String dbInstanceIdentifier,
String dbName,
String masterUsername,
String masterPassword) {
try {
CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
.dbInstanceIdentifier(dbInstanceIdentifier)
.dbName(dbName)
.engine("postgres")
.engineVersion("14.7")
.dbInstanceClass("db.t3.micro")
.allocatedStorage(20)
.masterUsername(masterUsername)
.masterUserPassword(masterPassword)
.publiclyAccessible(false)
.build();
CreateDbInstanceResponse response = rdsClient.createDBInstance(request);
System.out.println("Creating DB instance: " + response.dbInstance().dbInstanceArn());
return response.dbInstance().dbInstanceArn();
} catch (RdsException e) {
System.err.println("Error creating instance: " + e.getMessage());
throw e;
}
}
```
### Managing DB Parameter Groups
Create and manage custom parameter groups for database configuration.
**Create DB Parameter Group:**
```java
public static void createDBParameterGroup(RdsClient rdsClient,
String groupName,
String description) {
try {
CreateDbParameterGroupRequest request = CreateDbParameterGroupRequest.builder()
.dbParameterGroupName(groupName)
.dbParameterGroupFamily("postgres15")
.description(description)
.build();
CreateDbParameterGroupResponse response = rdsClient.createDBParameterGroup(request);
System.out.println("Created parameter group: " + response.dbParameterGroup().dbParameterGroupName());
} catch (RdsException e) {
System.err.println("Error creating parameter group: " + e.getMessage());
throw e;
}
}
```
### Managing DB Snapshots
Create, restore, and manage database snapshots.
**Create DB Snapshot:**
```java
public static String createDBSnapshot(RdsClient rdsClient,
String dbInstanceIdentifier,
String snapshotIdentifier) {
try {
CreateDbSnapshotRequest request = CreateDbSnapshotRequest.builder()
.dbInstanceIdentifier(dbInstanceIdentifier)
.dbSnapshotIdentifier(snapshotIdentifier)
.build();
CreateDbSnapshotResponse response = rdsClient.createDBSnapshot(request);
System.out.println("Creating snapshot: " + response.dbSnapshot().dbSnapshotIdentifier());
return response.dbSnapshot().dbSnapshotArn();
} catch (RdsException e) {
System.err.println("Error creating snapshot: " + e.getMessage());
throw e;
}
}
```
## Integration Patterns
### Spring Boot Integration
Refer to [references/spring-boot-integration.md](references/spring-boot-integration.md) for complete Spring Boot integration examples including:
- Spring Boot configuration with application properties
- RDS client bean configuration
- Service layer implementation
- REST controller design
- Exception handling
- Testing strategies
### Lambda Integration
Refer to [references/lambda-integration.md](references/lambda-integration.md) for Lambda integration examples including:
- Traditional Lambda + RDS connections
- Lambda with connection pooling
- Using AWS Secrets Manager for credentials
- Lambda with AWS SDK for RDS management
- Security configuration and best practices
## Advanced Operations
### Modifying DB Instances
Update existing RDS instances.
```java
public static void modifyDBInstance(RdsClient rdsClient,
String dbInstanceIdentifier,
String newInstanceClass) {
try {
ModifyDbInstanceRequest request = ModifyDbInstanceRequest.builder()
.dbInstanceIdentifier(dbInstanceIdentifier)
.dbInstanceClass(newInstanceClass)
.applyImmediately(false) // Apply during maintenance window
.build();
ModifyDbInstanceResponse response = rdsClient.modifyDBInstance(request);
System.out.println("Modified instance: " + response.dbInstance().dbInstanceIdentifier());
System.out.println("New class: " + response.dbInstance().dbInstanceClass());
} catch (RdsException e) {
System.err.println("Error modifying instance: " + e.getMessage());
throw e;
}
}
```
### Deleting DB Instances
Delete RDS instances with optional final snapshot.
```java
public static void deleteDBInstanceWithSnapshot(RdsClient rdsClient,
String dbInstanceIdentifier,
String finalSnapshotIdentifier) {
try {
DeleteDbInstanceRequest request = DeleteDbInstanceRequest.builder()
.dbInstanceIdentifier(dbInstanceIdentifier)
.skipFinalSnapshot(false)
.finalDBSnapshotIdentifier(finalSnapshotIdentifier)
.build();
DeleteDbInstanceResponse response = rdsClient.deleteDBInstance(request);
System.out.println("Deleting instance: " + response.dbInstance().dbInstanceIdentifier());
} catch (RdsException e) {
System.err.println("Error deleting instance: " + e.getMessage());
throw e;
}
}
```
## Best Practices
### Security
**Always use encryption:**
```java
CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
.storageEncrypted(true)
.kmsKeyId("arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012")
.build();
```
**Use VPC security groups:**
```java
CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
.vpcSecurityGroupIds("sg-12345678")
.publiclyAccessible(false)
.build();
```
### High Availability
**Enable Multi-AZ for production:**
```java
CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
.multiAZ(true)
.build();
```
### Backups
**Configure automated backups:**
```java
CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
.backupRetentionPeriod(7)
.preferredBackupWindow("03:00-04:00")
.build();
```
### Monitoring
**Enable CloudWatch logs:**
```java
CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
.enableCloudwatchLogsExports("postgresql", "upgrade")
.build();
```
### Cost Optimization
**Use appropriate instance class:**
```java
// Development
.dbInstanceClass("db.t3.micro")
// Production
.dbInstanceClass("db.r5.large")
```
### Deletion Protection
**Enable for production databases:**
```java
CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
.deletionProtection(true)
.build();
```
### Resource Management
**Always close clients:**
```java
try (RdsClient rdsClient = RdsClient.builder()
.region(Region.US_EAST_1)
.build()) {
// Use client
} // Automatically closed
```
## Dependencies
### Maven Dependencies
```xml
<dependencies>
<!-- AWS SDK for RDS -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>rds</artifactId>
            <version>2.20.0</version> <!-- Use the latest available version -->
</dependency>
<!-- PostgreSQL Driver -->
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
            <version>42.6.0</version> <!-- Use the latest available version -->
</dependency>
<!-- MySQL Driver -->
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>8.0.33</version>
</dependency>
</dependencies>
```
### Gradle Dependencies
```gradle
dependencies {
// AWS SDK for RDS
implementation 'software.amazon.awssdk:rds:2.20.0'
// PostgreSQL Driver
implementation 'org.postgresql:postgresql:42.6.0'
// MySQL Driver
implementation 'mysql:mysql-connector-java:8.0.33'
}
```
## Reference Documentation
For detailed API reference, see:
- [API Reference](references/api-reference.md) - Complete API documentation and data models
- [Spring Boot Integration](references/spring-boot-integration.md) - Spring Boot patterns and examples
- [Lambda Integration](references/lambda-integration.md) - Lambda function patterns and best practices
## Error Handling
See [API Reference](references/api-reference.md#error-handling) for comprehensive error handling patterns including common exceptions, error response structure, and pagination support.
## Performance Considerations
- Use connection pooling for multiple database operations
- Implement retry logic for transient failures
- Monitor CloudWatch metrics for performance optimization
- Use appropriate instance types for workload requirements
- Enable Performance Insights for database optimization
## Support
For support with AWS RDS operations using AWS SDK for Java 2.x:
- AWS Documentation: [Amazon RDS User Guide](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Welcome.html)
- AWS SDK Documentation: [AWS SDK for Java 2.x](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/home.html)
- AWS Support: [AWS Support Center](https://aws.amazon.com/premiumsupport/)

View File

@@ -0,0 +1,122 @@
# AWS RDS API Reference
## Core API Operations
### Describe Operations
- `describeDBInstances` - List database instances
- `describeDBParameterGroups` - List parameter groups
- `describeDBSnapshots` - List database snapshots
- `describeDBSubnetGroups` - List subnet groups
### Instance Management
- `createDBInstance` - Create new database instance
- `modifyDBInstance` - Modify existing instance
- `deleteDBInstance` - Delete database instance
### Parameter Groups
- `createDBParameterGroup` - Create parameter group
- `modifyDBParameterGroup` - Modify parameters
- `deleteDBParameterGroup` - Delete parameter group
### Snapshots
- `createDBSnapshot` - Create database snapshot
- `restoreDBInstanceFromDBSnapshot` - Restore from snapshot (see the sketch below)
- `deleteDBSnapshot` - Delete snapshot
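For example, a minimal restore from a snapshot (instance and snapshot identifiers are placeholders):
```java
RestoreDbInstanceFromDbSnapshotRequest request = RestoreDbInstanceFromDbSnapshotRequest.builder()
        .dbInstanceIdentifier("restored-instance")  // New instance to create (placeholder)
        .dbSnapshotIdentifier("my-snapshot")        // Existing snapshot (placeholder)
        .build();
rdsClient.restoreDBInstanceFromDBSnapshot(request);
```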
## Key Data Models
### DBInstance
```java
String dbInstanceIdentifier() // Instance name
String dbInstanceArn() // ARN identifier
String engine() // Database engine
String engineVersion() // Engine version
String dbInstanceClass() // Instance type
int allocatedStorage() // Storage size in GB
Endpoint endpoint() // Connection endpoint
String dbInstanceStatus() // Instance status
boolean multiAZ() // Multi-AZ enabled
boolean storageEncrypted() // Storage encrypted
```
### DBParameter
```java
String parameterName() // Parameter name
String parameterValue() // Parameter value
String description() // Description
ApplyMethod applyMethod()     // Apply method (immediate or pending-reboot)
```
### CreateDbInstanceRequest Builder
```java
CreateDbInstanceRequest.builder()
.dbInstanceIdentifier(identifier)
.engine("postgres") // Database engine
.engineVersion("15.2") // Engine version
.dbInstanceClass("db.t3.micro") // Instance type
.allocatedStorage(20) // Storage size
.masterUsername(username) // Admin username
.masterUserPassword(password) // Admin password
.publiclyAccessible(false) // Public access
.storageEncrypted(true) // Storage encryption
.multiAZ(true) // High availability
.backupRetentionPeriod(7) // Backup retention
.deletionProtection(true) // Protection from deletion
.build()
```
## Error Handling
### Common Exceptions
- `DBInstanceNotFoundFault` - Instance doesn't exist
- `DBSnapshotAlreadyExistsFault` - Snapshot name conflicts
- `InsufficientDBInstanceCapacity` - Instance type unavailable
- `InvalidParameterValueException` - Invalid configuration value
- `StorageQuotaExceeded` - Storage limit reached
### Error Response Structure
```java
try {
rdsClient.createDBInstance(request);
} catch (RdsException e) {
// AWS specific error handling
String errorCode = e.awsErrorDetails().errorCode();
String errorMessage = e.awsErrorDetails().errorMessage();
switch (errorCode) {
case "DBInstanceNotFoundFault":
// Handle missing instance
break;
case "InvalidParameterValueException":
// Handle invalid parameters
break;
default:
// Generic error handling
}
}
```
## Pagination Support
### List Instances with Pagination
```java
DescribeDbInstancesRequest request = DescribeDbInstancesRequest.builder()
        .maxRecords(100)              // Limit results per page
.build();
String marker = null;
do {
if (marker != null) {
request = request.toBuilder()
.marker(marker)
.build();
}
DescribeDbInstancesResponse response = rdsClient.describeDBInstances(request);
List<DBInstance> instances = response.dbInstances();
// Process instances
marker = response.marker();
} while (marker != null);
```
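Alternatively, the SDK's generated paginator handles the marker loop internally:
```java
rdsClient.describeDBInstancesPaginator().dbInstances()
        .forEach(instance -> System.out.println(instance.dbInstanceIdentifier()));
```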

View File

@@ -0,0 +1,382 @@
# AWS Lambda Integration with RDS
## Lambda RDS Connection Patterns
### 1. Traditional Lambda + RDS Connection
```java
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
public class RdsLambdaHandler implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {
@Override
public APIGatewayProxyResponseEvent handleRequest(APIGatewayProxyRequestEvent event, Context context) {
APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent();
try {
// Get environment variables
String host = System.getenv("ProxyHostName");
String port = System.getenv("Port");
String dbName = System.getenv("DBName");
String username = System.getenv("DBUserName");
String password = System.getenv("DBPassword");
// Create connection string
String connectionString = String.format(
"jdbc:mysql://%s:%s/%s?useSSL=true&requireSSL=true",
host, port, dbName
);
// Execute query
String sql = "SELECT COUNT(*) FROM users";
try (Connection connection = DriverManager.getConnection(connectionString, username, password);
PreparedStatement statement = connection.prepareStatement(sql);
ResultSet resultSet = statement.executeQuery()) {
if (resultSet.next()) {
int count = resultSet.getInt(1);
response.setStatusCode(200);
response.setBody("{\"count\": " + count + "}");
}
}
} catch (Exception e) {
response.setStatusCode(500);
response.setBody("{\"error\": \"" + e.getMessage() + "\"}");
}
return response;
}
}
```
### 2. Lambda with Connection Pooling
```java
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import javax.sql.DataSource;
public class RdsLambdaConfig {
private static DataSource dataSource;
public static synchronized DataSource getDataSource() {
if (dataSource == null) {
HikariConfig config = new HikariConfig();
String host = System.getenv("ProxyHostName");
String port = System.getenv("Port");
String dbName = System.getenv("DBName");
String username = System.getenv("DBUserName");
String password = System.getenv("DBPassword");
config.setJdbcUrl(String.format("jdbc:mysql://%s:%s/%s", host, port, dbName));
config.setUsername(username);
config.setPassword(password);
// Connection pool settings
config.setMaximumPoolSize(5);
config.setMinimumIdle(2);
config.setIdleTimeout(30000);
config.setConnectionTimeout(20000);
config.setMaxLifetime(1800000);
            // MySQL-specific settings ("serverSslCertificate" is not a Connector/J
            // property; configure trust material via sslMode/truststore settings instead)
            config.addDataSourceProperty("useSSL", true);
            config.addDataSourceProperty("requireSSL", true);
            config.addDataSourceProperty("connectTimeout", "30");
dataSource = new HikariDataSource(config);
}
return dataSource;
}
}
```
### 3. Using AWS Secrets Manager for Credentials
```java
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder;
import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest;
import com.amazonaws.services.secretsmanager.model.GetSecretValueResult;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.HashMap;
import java.util.Map;
public class RdsSecretsHelper {
    private static final String SECRET_NAME = "prod/rds/db_credentials";
    private static final String REGION = "us-east-1";
    public static Map<String, String> getRdsCredentials() {
        AWSSecretsManager client = AWSSecretsManagerClientBuilder.standard()
                .withRegion(REGION)
                .build();
        // SDK v1 uses with-methods rather than the v2 builder pattern
        GetSecretValueRequest request = new GetSecretValueRequest()
                .withSecretId(SECRET_NAME);
        GetSecretValueResult result = client.getSecretValue(request);
        try {
            // Parse secret JSON into a string map
            ObjectMapper objectMapper = new ObjectMapper();
            Map<String, Object> secretMap = objectMapper.readValue(result.getSecretString(), HashMap.class);
            Map<String, String> credentials = new HashMap<>();
            secretMap.forEach((key, value) -> credentials.put(key, value.toString()));
            return credentials;
        } catch (Exception e) {
            throw new RuntimeException("Failed to parse secret JSON", e);
        }
    }
}
```
### 4. Lambda with AWS SDK for RDS
```java
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.rds.RdsClient;
import software.amazon.awssdk.services.rds.model.*;
public class RdsManagementLambda implements RequestHandler<ApiRequest, ApiResponse> {
@Override
public ApiResponse handleRequest(ApiRequest request, Context context) {
RdsClient rdsClient = RdsClient.builder()
.region(Region.US_EAST_1)
.build();
try {
switch (request.getAction()) {
case "list-instances":
return listInstances(rdsClient);
case "create-snapshot":
return createSnapshot(rdsClient, request.getInstanceId(), request.getSnapshotId());
case "describe-instance":
return describeInstance(rdsClient, request.getInstanceId());
default:
return new ApiResponse(400, "Unknown action: " + request.getAction());
}
} catch (Exception e) {
context.getLogger().log("Error: " + e.getMessage());
return new ApiResponse(500, "Error: " + e.getMessage());
} finally {
rdsClient.close();
}
}
private ApiResponse listInstances(RdsClient rdsClient) {
DescribeDbInstancesResponse response = rdsClient.describeDBInstances();
return new ApiResponse(200, response.toString());
}
private ApiResponse createSnapshot(RdsClient rdsClient, String instanceId, String snapshotId) {
CreateDbSnapshotRequest request = CreateDbSnapshotRequest.builder()
.dbInstanceIdentifier(instanceId)
.dbSnapshotIdentifier(snapshotId)
.build();
CreateDbSnapshotResponse response = rdsClient.createDBSnapshot(request);
return new ApiResponse(200, "Snapshot created: " + response.dbSnapshot().dbSnapshotIdentifier());
}
private ApiResponse describeInstance(RdsClient rdsClient, String instanceId) {
DescribeDbInstancesRequest request = DescribeDbInstancesRequest.builder()
.dbInstanceIdentifier(instanceId)
.build();
DescribeDbInstancesResponse response = rdsClient.describeDBInstances(request);
return new ApiResponse(200, response.toString());
}
}
class ApiRequest {
private String action;
private String instanceId;
private String snapshotId;
// getters and setters
}
class ApiResponse {
private int statusCode;
private String body;
// constructor, getters
}
```
## Best Practices for Lambda + RDS
### 1. Security Configuration
**IAM Role (scope actions and resources to what the function actually needs; avoid `rds:*` on `"*"` in production):**
```json
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "rds:DescribeDBInstances",
                "rds:CreateDBSnapshot"
            ],
            "Resource": "arn:aws:rds:us-east-1:123456789012:db:*"
        }
    ]
}
```
**Security Group:**
- Use security groups to restrict access (see the sketch below)
- Only allow traffic from the Lambda function's security group (or its IP ranges)
- Use VPC endpoints for private connections
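As a hedged sketch, the database security group can be restricted to traffic from the Lambda function's security group via the EC2 API (both group IDs are placeholders):
```java
import software.amazon.awssdk.services.ec2.Ec2Client;
import software.amazon.awssdk.services.ec2.model.AuthorizeSecurityGroupIngressRequest;
import software.amazon.awssdk.services.ec2.model.IpPermission;
import software.amazon.awssdk.services.ec2.model.UserIdGroupPair;

try (Ec2Client ec2 = Ec2Client.create()) {
    ec2.authorizeSecurityGroupIngress(AuthorizeSecurityGroupIngressRequest.builder()
            .groupId("sg-rds-database")                     // RDS security group (placeholder)
            .ipPermissions(IpPermission.builder()
                    .ipProtocol("tcp")
                    .fromPort(5432)
                    .toPort(5432)
                    .userIdGroupPairs(UserIdGroupPair.builder()
                            .groupId("sg-lambda-function")  // Lambda security group (placeholder)
                            .build())
                    .build())
            .build());
}
```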
### 2. Environment Variables
```bash
# Environment variables for Lambda
DB_HOST=mydb.abc123.us-east-1.rds.amazonaws.com
DB_PORT=5432
DB_NAME=mydatabase
DB_USERNAME=admin
DB_PASSWORD=${DB_PASSWORD}
DB_CONNECTION_STRING=jdbc:postgresql://${DB_HOST}:${DB_PORT}/${DB_NAME}
```
### 3. Error Handling
```java
import com.amazonaws.services.lambda.runtime.LambdaLogger;
public class LambdaErrorHandler {
public static void handleRdsError(Exception e, LambdaLogger logger) {
if (e instanceof RdsException) {
RdsException rdsException = (RdsException) e;
logger.log("RDS Error: " + rdsException.awsErrorDetails().errorCode());
switch (rdsException.awsErrorDetails().errorCode()) {
case "DBInstanceNotFoundFault":
logger.log("Database instance not found");
break;
case "InvalidParameterValueException":
logger.log("Invalid parameter provided");
break;
case "InstanceAlreadyExistsFault":
logger.log("Instance already exists");
break;
default:
logger.log("Unknown RDS error: " + rdsException.getMessage());
}
} else {
logger.log("Non-RDS error: " + e.getMessage());
}
}
}
```
### 4. Performance Optimization
**Cold Start Mitigation:**
```java
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
public class RdsConnectionHelper {
    private static DataSource dataSource;
    private static long lastConnectionTime = 0;
    private static final long CONNECTION_TIMEOUT = 300000; // 5 minutes
    public static Connection getConnection() throws SQLException {
        long currentTime = System.currentTimeMillis();
        if (dataSource == null || (currentTime - lastConnectionTime) > CONNECTION_TIMEOUT) {
            dataSource = createDataSource();
            lastConnectionTime = currentTime;
        }
        return dataSource.getConnection();
    }
    private static DataSource createDataSource() {
        // Reuse the pooled HikariCP data source shown above
        return RdsLambdaConfig.getDataSource();
    }
}
```
**Batch Processing:**
```java
public class RdsBatchProcessor {
    public void processBatch(List<String> userIds, LambdaLogger logger) {
        // Build the IN-clause placeholders before preparing the statement
        String placeholders = userIds.stream()
                .map(id -> "?")
                .collect(Collectors.joining(","));
        String sql = "SELECT * FROM users WHERE user_id IN (" + placeholders + ")";
        try (Connection connection = RdsConnectionHelper.getConnection();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            // Bind one parameter per user id
            for (int i = 0; i < userIds.size(); i++) {
                statement.setString(i + 1, userIds.get(i));
            }
            try (ResultSet resultSet = statement.executeQuery()) {
                // Process results
            }
        } catch (SQLException e) {
            LambdaErrorHandler.handleRdsError(e, logger);
        }
    }
}
```
### 5. Monitoring and Logging
```java
import com.amazonaws.services.cloudwatch.AmazonCloudWatch;
import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder;
import com.amazonaws.services.cloudwatch.model.MetricDatum;
import com.amazonaws.services.cloudwatch.model.PutMetricDataRequest;
import java.util.Collections;
import java.util.Date;
public class RdsMetricsPublisher {
private static final String NAMESPACE = "RDS/Lambda";
private AmazonCloudWatch cloudWatch;
public RdsMetricsPublisher() {
this.cloudWatch = AmazonCloudWatchClientBuilder.defaultClient();
}
public void publishMetric(String metricName, double value) {
MetricDatum datum = new MetricDatum()
.withMetricName(metricName)
.withUnit("Count")
.withValue(value)
.withTimestamp(new Date());
PutMetricDataRequest request = new PutMetricDataRequest()
.withNamespace(NAMESPACE)
.withMetricData(Collections.singletonList(datum));
cloudWatch.putMetricData(request);
}
}
```

View File

@@ -0,0 +1,325 @@
# Spring Boot Integration with AWS RDS
## Configuration
### application.properties
```properties
# AWS Configuration
aws.region=us-east-1
aws.rds.instance-identifier=mydb-instance
# RDS Connection (from RDS endpoint)
spring.datasource.url=jdbc:postgresql://mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydatabase
spring.datasource.username=admin
spring.datasource.password=${DB_PASSWORD}
spring.datasource.driver-class-name=org.postgresql.Driver
# JPA Configuration
spring.jpa.hibernate.ddl-auto=validate
spring.jpa.show-sql=false
spring.jpa.properties.hibernate.dialect=org.hibernate.dialect.PostgreSQLDialect
# Connection Pool Configuration
spring.datasource.hikari.maximum-pool-size=10
spring.datasource.hikari.minimum-idle=5
spring.datasource.hikari.idle-timeout=30000
spring.datasource.hikari.connection-timeout=20000
```
### AWS Configuration
```java
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.rds.RdsClient;
@Configuration
public class AwsRdsConfig {
@Value("${aws.region}")
private String awsRegion;
@Bean
public RdsClient rdsClient() {
return RdsClient.builder()
.region(Region.of(awsRegion))
.build();
}
}
```
### Service Layer
```java
import org.springframework.stereotype.Service;
import software.amazon.awssdk.services.rds.RdsClient;
import software.amazon.awssdk.services.rds.model.*;
import java.util.List;
@Service
public class RdsService {
private final RdsClient rdsClient;
public RdsService(RdsClient rdsClient) {
this.rdsClient = rdsClient;
}
public List<DBInstance> listInstances() {
DescribeDbInstancesResponse response = rdsClient.describeDBInstances();
return response.dbInstances();
}
public DBInstance getInstanceDetails(String instanceId) {
DescribeDbInstancesRequest request = DescribeDbInstancesRequest.builder()
.dbInstanceIdentifier(instanceId)
.build();
DescribeDbInstancesResponse response = rdsClient.describeDBInstances(request);
return response.dbInstances().get(0);
}
public String createSnapshot(String instanceId, String snapshotId) {
CreateDbSnapshotRequest request = CreateDbSnapshotRequest.builder()
.dbInstanceIdentifier(instanceId)
.dbSnapshotIdentifier(snapshotId)
.build();
CreateDbSnapshotResponse response = rdsClient.createDBSnapshot(request);
return response.dbSnapshot().dbSnapshotArn();
}
public void modifyInstance(String instanceId, String newInstanceClass) {
ModifyDbInstanceRequest request = ModifyDbInstanceRequest.builder()
.dbInstanceIdentifier(instanceId)
.dbInstanceClass(newInstanceClass)
.applyImmediately(true)
.build();
rdsClient.modifyDBInstance(request);
}
}
```
### REST Controller
```java
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import software.amazon.awssdk.services.rds.model.DBInstance;
import java.util.List;
@RestController
@RequestMapping("/api/rds")
public class RdsController {
private final RdsService rdsService;
public RdsController(RdsService rdsService) {
this.rdsService = rdsService;
}
@GetMapping("/instances")
public ResponseEntity<List<DBInstance>> listInstances() {
return ResponseEntity.ok(rdsService.listInstances());
}
@GetMapping("/instances/{id}")
public ResponseEntity<DBInstance> getInstanceDetails(@PathVariable String id) {
return ResponseEntity.ok(rdsService.getInstanceDetails(id));
}
@PostMapping("/snapshots")
public ResponseEntity<String> createSnapshot(
@RequestParam String instanceId,
@RequestParam String snapshotId) {
String arn = rdsService.createSnapshot(instanceId, snapshotId);
return ResponseEntity.ok(arn);
}
@PutMapping("/instances/{id}")
public ResponseEntity<String> modifyInstance(
@PathVariable String id,
@RequestParam String instanceClass) {
rdsService.modifyInstance(id, instanceClass);
return ResponseEntity.ok("Instance modified successfully");
}
}
```
### Exception Handling
```java
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.bind.annotation.RestControllerAdvice;
import software.amazon.awssdk.services.rds.model.RdsException;
@RestControllerAdvice
public class RdsExceptionHandler {
@ExceptionHandler(RdsException.class)
@ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR)
public ErrorResponse handleRdsException(RdsException e) {
return new ErrorResponse(
"RDS_ERROR",
e.getMessage(),
e.awsErrorDetails().errorCode()
);
}
@ExceptionHandler(Exception.class)
@ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR)
public ErrorResponse handleGenericException(Exception e) {
return new ErrorResponse(
"INTERNAL_ERROR",
e.getMessage()
);
}
}
class ErrorResponse {
private String code;
private String message;
private String details;
// Constructor, getters, setters
}
```
## Testing
### Unit Tests
```java
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.List;
import static org.mockito.Mockito.*;
import static org.junit.jupiter.api.Assertions.*;
@ExtendWith(MockitoExtension.class)
class RdsServiceTest {
@Mock
private RdsClient rdsClient;
@Test
void listInstances_shouldReturnInstances() {
// Arrange
DescribeDbInstancesResponse response = DescribeDbInstancesResponse.builder()
.dbInstances(List.of(createTestInstance()))
.build();
when(rdsClient.describeDBInstances()).thenReturn(response);
RdsService service = new RdsService(rdsClient);
// Act
List<DBInstance> result = service.listInstances();
// Assert
assertEquals(1, result.size());
verify(rdsClient).describeDBInstances();
}
private DBInstance createTestInstance() {
return DBInstance.builder()
.dbInstanceIdentifier("test-instance")
.engine("postgres")
.dbInstanceStatus("available")
.build();
}
}
```
### Integration Tests
```java
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.ActiveProfiles;
import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
@SpringBootTest
@ActiveProfiles("test")
class RdsServiceIntegrationTest {
@Autowired
private RdsService rdsService;
@Test
void listInstances_integrationTest() {
// This test requires actual AWS credentials and RDS instances
// Should only run with proper test configuration
assumeTrue(false, "Integration test disabled");
List<DBInstance> instances = rdsService.listInstances();
assertNotNull(instances);
}
}
```
## Best Practices
### 1. Configuration Management
- Use Spring profiles for different environments (see the sketch after this list)
- Externalize sensitive configuration (passwords, keys)
- Use Spring Cloud Config for multi-environment management
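For example, per-environment property files selected via `spring.profiles.active` (values illustrative):
```properties
# application-prod.properties
aws.region=us-east-1
spring.datasource.url=jdbc:postgresql://prod-db.example.com:5432/mydatabase
spring.datasource.password=${DB_PASSWORD}
```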
### 2. Connection Pooling
```properties
# HikariCP Configuration
spring.datasource.hikari.maximum-pool-size=20
spring.datasource.hikari.minimum-idle=10
spring.datasource.hikari.idle-timeout=600000
spring.datasource.hikari.connection-timeout=30000
spring.datasource.hikari.connection-test-query=SELECT 1
```
### 3. Retry Logic
```java
import org.springframework.retry.annotation.Retryable;
import org.springframework.retry.annotation.Backoff;
// Requires @EnableRetry on a configuration class
@Service
public class RdsServiceWithRetry {
    private final RdsClient rdsClient;
    public RdsServiceWithRetry(RdsClient rdsClient) {
        this.rdsClient = rdsClient;
    }
@Retryable(value = { RdsException.class },
maxAttempts = 3,
backoff = @Backoff(delay = 1000))
public List<DBInstance> listInstancesWithRetry() {
return rdsClient.describeDBInstances().dbInstances();
}
}
```
### 4. Monitoring
```java
import org.springframework.boot.actuate.health.Health;
import org.springframework.boot.actuate.health.HealthIndicator;
import org.springframework.stereotype.Component;
@Component
public class RdsHealthIndicator implements HealthIndicator {
private final RdsClient rdsClient;
public RdsHealthIndicator(RdsClient rdsClient) {
this.rdsClient = rdsClient;
}
@Override
public Health health() {
try {
rdsClient.describeDBInstances();
return Health.up()
.withDetail("service", "RDS")
.build();
} catch (Exception e) {
return Health.down()
.withDetail("error", e.getMessage())
.build();
}
}
}
```

View File

@@ -0,0 +1,691 @@
---
name: aws-sdk-java-v2-s3
description: Amazon S3 patterns and examples using AWS SDK for Java 2.x. Use when working with S3 buckets, uploading/downloading objects, multipart uploads, presigned URLs, S3 Transfer Manager, object operations, or S3-specific configurations.
category: aws
tags: [aws, s3, java, sdk, storage, objects, transfer-manager, presigned-urls]
version: 1.1.0
allowed-tools: Read, Write, Bash
---
# AWS SDK for Java 2.x - Amazon S3
## When to Use
Use this skill when:
- Creating, listing, or deleting S3 buckets with proper configuration
- Uploading or downloading objects from S3 with metadata and encryption
- Working with multipart uploads for large files (>100MB) with error handling
- Generating presigned URLs for temporary access to S3 objects
- Copying or moving objects between S3 buckets with metadata preservation
- Setting object metadata, storage classes, and access controls
- Implementing S3 Transfer Manager for optimized file transfers
- Integrating S3 with Spring Boot applications for cloud storage
- Setting up S3 event notifications for object lifecycle management
- Managing bucket policies, CORS configuration, and access controls
- Implementing retry mechanisms and error handling for S3 operations
- Testing S3 integrations with LocalStack for development environments
## Dependencies
```xml
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>s3</artifactId>
    <version>2.20.0</version> <!-- Use the latest stable version -->
</dependency>
<!-- For S3 Transfer Manager -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>s3-transfer-manager</artifactId>
    <version>2.20.0</version> <!-- Use the latest stable version -->
</dependency>
<!-- For async operations -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>netty-nio-client</artifactId>
    <version>2.20.0</version> <!-- Use the latest stable version -->
</dependency>
```
## Client Setup
### Basic Synchronous Client
```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
S3Client s3Client = S3Client.builder()
.region(Region.US_EAST_1)
.build();
```
### Basic Asynchronous Client
```java
import software.amazon.awssdk.services.s3.S3AsyncClient;
S3AsyncClient s3AsyncClient = S3AsyncClient.builder()
.region(Region.US_EAST_1)
.build();
```
### Configured Client with Retry Logic
```java
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.core.retry.backoff.EqualJitterBackoffStrategy;
import java.time.Duration;
S3Client s3Client = S3Client.builder()
    .region(Region.US_EAST_1)
    .httpClientBuilder(ApacheHttpClient.builder()
        .maxConnections(200)
        .connectionTimeout(Duration.ofSeconds(5)))
    .overrideConfiguration(b -> b
        .apiCallTimeout(Duration.ofSeconds(60))
        .apiCallAttemptTimeout(Duration.ofSeconds(30))
        .retryPolicy(RetryPolicy.builder()
            .numRetries(3)
            .backoffStrategy(EqualJitterBackoffStrategy.builder()
                .baseDelay(Duration.ofSeconds(1))
                .maxBackoffTime(Duration.ofSeconds(30))
                .build())
            .build()))
    .build();
```
## Basic Bucket Operations
### Create Bucket
```java
import software.amazon.awssdk.services.s3.model.*;
import java.util.concurrent.CompletableFuture;
public void createBucket(S3Client s3Client, String bucketName) {
try {
CreateBucketRequest request = CreateBucketRequest.builder()
.bucket(bucketName)
.build();
s3Client.createBucket(request);
// Wait until bucket is ready
HeadBucketRequest waitRequest = HeadBucketRequest.builder()
.bucket(bucketName)
.build();
s3Client.waiter().waitUntilBucketExists(waitRequest);
System.out.println("Bucket created successfully: " + bucketName);
} catch (S3Exception e) {
System.err.println("Error creating bucket: " + e.awsErrorDetails().errorMessage());
throw e;
}
}
```
### List All Buckets
```java
public List<String> listAllBuckets(S3Client s3Client) {
ListBucketsResponse response = s3Client.listBuckets();
return response.buckets().stream()
.map(Bucket::name)
.collect(Collectors.toList());
}
```
### Check if Bucket Exists
```java
public boolean bucketExists(S3Client s3Client, String bucketName) {
try {
HeadBucketRequest request = HeadBucketRequest.builder()
.bucket(bucketName)
.build();
s3Client.headBucket(request);
return true;
} catch (NoSuchBucketException e) {
return false;
}
}
```
## Basic Object Operations
### Upload File to S3
```java
import software.amazon.awssdk.core.sync.RequestBody;
import java.nio.file.Paths;
public void uploadFile(S3Client s3Client, String bucketName, String key, String filePath) {
PutObjectRequest request = PutObjectRequest.builder()
.bucket(bucketName)
.key(key)
.build();
s3Client.putObject(request, RequestBody.fromFile(Paths.get(filePath)));
System.out.println("File uploaded: " + key);
}
```
### Download File from S3
```java
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import java.nio.file.Paths;
public void downloadFile(S3Client s3Client, String bucketName, String key, String destPath) {
GetObjectRequest request = GetObjectRequest.builder()
.bucket(bucketName)
.key(key)
.build();
s3Client.getObject(request, Paths.get(destPath));
System.out.println("File downloaded: " + destPath);
}
```
### Get Object Metadata
```java
public Map<String, String> getObjectMetadata(S3Client s3Client, String bucketName, String key) {
HeadObjectRequest request = HeadObjectRequest.builder()
.bucket(bucketName)
.key(key)
.build();
HeadObjectResponse response = s3Client.headObject(request);
return response.metadata();
}
```
## Advanced Object Operations
### Upload with Metadata and Encryption
```java
public void uploadWithMetadata(S3Client s3Client, String bucketName, String key,
String filePath, Map<String, String> metadata) {
PutObjectRequest request = PutObjectRequest.builder()
.bucket(bucketName)
.key(key)
.metadata(metadata)
.contentType("application/pdf")
.serverSideEncryption(ServerSideEncryption.AES256)
.storageClass(StorageClass.STANDARD_IA)
.build();
PutObjectResponse response = s3Client.putObject(request,
RequestBody.fromFile(Paths.get(filePath)));
System.out.println("Upload completed. ETag: " + response.eTag());
}
```
### Copy Object Between Buckets
```java
public void copyObject(S3Client s3Client, String sourceBucket, String sourceKey,
String destBucket, String destKey) {
CopyObjectRequest request = CopyObjectRequest.builder()
.sourceBucket(sourceBucket)
.sourceKey(sourceKey)
.destinationBucket(destBucket)
.destinationKey(destKey)
.build();
s3Client.copyObject(request);
System.out.println("Object copied: " + sourceKey + " -> " + destKey);
}
```
### Delete Multiple Objects
```java
public void deleteMultipleObjects(S3Client s3Client, String bucketName, List<String> keys) {
List<ObjectIdentifier> objectIds = keys.stream()
.map(key -> ObjectIdentifier.builder().key(key).build())
.collect(Collectors.toList());
Delete delete = Delete.builder()
.objects(objectIds)
.build();
DeleteObjectsRequest request = DeleteObjectsRequest.builder()
.bucket(bucketName)
.delete(delete)
.build();
DeleteObjectsResponse response = s3Client.deleteObjects(request);
response.deleted().forEach(deleted ->
System.out.println("Deleted: " + deleted.key()));
response.errors().forEach(error ->
System.err.println("Failed to delete " + error.key() + ": " + error.message()));
}
```
## Presigned URLs
### Generate Download URL
```java
import software.amazon.awssdk.services.s3.presigner.S3Presigner;
import software.amazon.awssdk.services.s3.presigner.model.*;
import java.time.Duration;
public String generateDownloadUrl(String bucketName, String key) {
try (S3Presigner presigner = S3Presigner.builder()
.region(Region.US_EAST_1)
.build()) {
GetObjectRequest getObjectRequest = GetObjectRequest.builder()
.bucket(bucketName)
.key(key)
.build();
GetObjectPresignRequest presignRequest = GetObjectPresignRequest.builder()
.signatureDuration(Duration.ofMinutes(10))
.getObjectRequest(getObjectRequest)
.build();
PresignedGetObjectRequest presignedRequest = presigner.presignGetObject(presignRequest);
return presignedRequest.url().toString();
}
}
```
### Generate Upload URL
```java
public String generateUploadUrl(String bucketName, String key) {
try (S3Presigner presigner = S3Presigner.create()) {
PutObjectRequest putObjectRequest = PutObjectRequest.builder()
.bucket(bucketName)
.key(key)
.build();
PutObjectPresignRequest presignRequest = PutObjectPresignRequest.builder()
.signatureDuration(Duration.ofMinutes(5))
.putObjectRequest(putObjectRequest)
.build();
PresignedPutObjectRequest presignedRequest = presigner.presignPutObject(presignRequest);
return presignedRequest.url().toString();
}
}
```
## S3 Transfer Manager
### Upload with Transfer Manager
```java
import software.amazon.awssdk.transfer.s3.*;
import software.amazon.awssdk.transfer.s3.model.*;
import software.amazon.awssdk.transfer.s3.progress.LoggingTransferListener;
public void uploadWithTransferManager(String bucketName, String key, String filePath) {
    try (S3TransferManager transferManager = S3TransferManager.create()) {
        UploadFileRequest uploadRequest = UploadFileRequest.builder()
                .putObjectRequest(req -> req
                        .bucket(bucketName)
                        .key(key))
                .source(Paths.get(filePath))
                // Logs transfer progress at intervals
                .addTransferListener(LoggingTransferListener.create())
                .build();
        FileUpload upload = transferManager.uploadFile(uploadRequest);
        CompletedFileUpload result = upload.completionFuture().join();
        System.out.println("Upload complete. ETag: " + result.response().eTag());
    }
}
```
### Download with Transfer Manager
```java
public void downloadWithTransferManager(String bucketName, String key, String destPath) {
try (S3TransferManager transferManager = S3TransferManager.create()) {
DownloadFileRequest downloadRequest = DownloadFileRequest.builder()
.getObjectRequest(req -> req
.bucket(bucketName)
.key(key))
.destination(Paths.get(destPath))
.build();
FileDownload download = transferManager.downloadFile(downloadRequest);
CompletedFileDownload result = download.completionFuture().join();
System.out.println("Download complete. Size: " + result.response().contentLength());
}
}
```
## Spring Boot Integration
### Configuration Properties
```java
import org.springframework.boot.context.properties.ConfigurationProperties;
@ConfigurationProperties(prefix = "aws.s3")
public class S3Properties {
private String accessKey;
private String secretKey;
private String region = "us-east-1";
private String endpoint;
private String defaultBucket;
private boolean asyncEnabled = false;
private boolean transferManagerEnabled = true;
// Getters and setters
public String getAccessKey() { return accessKey; }
public void setAccessKey(String accessKey) { this.accessKey = accessKey; }
// ... other getters and setters
}
```
### S3 Configuration Class
```java
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.regions.Region;
import java.net.URI;
@Configuration
@EnableConfigurationProperties(S3Properties.class)  // Registers S3Properties for binding
public class S3Configuration {
private final S3Properties properties;
public S3Configuration(S3Properties properties) {
this.properties = properties;
}
@Bean
public S3Client s3Client() {
S3Client.Builder builder = S3Client.builder()
.region(Region.of(properties.getRegion()));
if (properties.getAccessKey() != null && properties.getSecretKey() != null) {
builder.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create(
properties.getAccessKey(),
properties.getSecretKey())));
}
if (properties.getEndpoint() != null) {
builder.endpointOverride(URI.create(properties.getEndpoint()));
}
return builder.build();
}
@Bean
public S3AsyncClient s3AsyncClient() {
S3AsyncClient.Builder builder = S3AsyncClient.builder()
.region(Region.of(properties.getRegion()));
if (properties.getAccessKey() != null && properties.getSecretKey() != null) {
builder.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create(
properties.getAccessKey(),
properties.getSecretKey())));
}
if (properties.getEndpoint() != null) {
builder.endpointOverride(URI.create(properties.getEndpoint()));
}
return builder.build();
}
    @Bean
    public S3TransferManager s3TransferManager() {
        // S3TransferManager is built on the asynchronous client
        return S3TransferManager.builder()
                .s3Client(s3AsyncClient())
                .build();
    }
}
```
### S3 Service
```java
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;
import software.amazon.awssdk.services.s3.presigner.S3Presigner;
import software.amazon.awssdk.services.s3.presigner.model.GetObjectPresignRequest;
import software.amazon.awssdk.transfer.s3.S3TransferManager;
import java.io.IOException;
import java.nio.file.*;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.CompletableFuture;
@Service
@RequiredArgsConstructor
public class S3Service {
private final S3Client s3Client;
private final S3AsyncClient s3AsyncClient;
private final S3TransferManager transferManager;
private final S3Properties properties;
public CompletableFuture<Void> uploadFileAsync(String key, Path file) {
PutObjectRequest request = PutObjectRequest.builder()
.bucket(properties.getDefaultBucket())
.key(key)
.build();
return CompletableFuture.runAsync(() -> {
s3Client.putObject(request, RequestBody.fromFile(file));
});
}
public CompletableFuture<byte[]> downloadFileAsync(String key) {
GetObjectRequest request = GetObjectRequest.builder()
.bucket(properties.getDefaultBucket())
.key(key)
.build();
return CompletableFuture.supplyAsync(() -> {
try (ResponseInputStream<GetObjectResponse> response = s3Client.getObject(request)) {
return response.readAllBytes();
} catch (IOException e) {
throw new RuntimeException("Failed to read S3 object", e);
}
});
}
public CompletableFuture<String> generatePresignedUrl(String key, Duration duration) {
return CompletableFuture.supplyAsync(() -> {
try (S3Presigner presigner = S3Presigner.builder()
.region(Region.of(properties.getRegion()))
.build()) {
GetObjectRequest getRequest = GetObjectRequest.builder()
.bucket(properties.getDefaultBucket())
.key(key)
.build();
GetObjectPresignRequest presignRequest = GetObjectPresignRequest.builder()
.signatureDuration(duration)
.getObjectRequest(getRequest)
.build();
return presigner.presignGetObject(presignRequest).url().toString();
}
});
}
public Flux<S3Object> listObjects(String prefix) {
ListObjectsV2Request request = ListObjectsV2Request.builder()
.bucket(properties.getDefaultBucket())
.prefix(prefix)
.build();
return Flux.create(sink -> {
s3Client.listObjectsV2Paginator(request)
.contents()
.forEach(sink::next);
sink.complete();
});
}
}
```
## Examples
### Basic File Upload Example
```java
public class S3UploadExample {
public static void main(String[] args) {
// Initialize client
S3Client s3Client = S3Client.builder()
.region(Region.US_EAST_1)
.build();
String bucketName = "my-example-bucket";
String filePath = "document.pdf";
String key = "uploads/document.pdf";
// Create bucket if it doesn't exist
if (!bucketExists(s3Client, bucketName)) {
createBucket(s3Client, bucketName);
}
// Upload file
Map<String, String> metadata = Map.of(
"author", "John Doe",
"content-type", "application/pdf",
"upload-date", java.time.LocalDate.now().toString()
);
uploadWithMetadata(s3Client, bucketName, key, filePath, metadata);
// Generate presigned URL
String downloadUrl = generateDownloadUrl(bucketName, key);
System.out.println("Download URL: " + downloadUrl);
// Close client
s3Client.close();
}
}
```
### Batch File Processing Example
```java
import java.nio.file.*;
import java.util.stream.*;
public class S3BatchProcessing {
public void processDirectoryUpload(S3Client s3Client, String bucketName, String directoryPath) {
try (Stream<Path> paths = Files.walk(Paths.get(directoryPath))) {
List<CompletableFuture<Void>> futures = paths
.filter(Files::isRegularFile)
.map(path -> {
                    String key = "uploads/" + path.getFileName();  // Use a key prefix, not the bucket name
return CompletableFuture.runAsync(() -> {
uploadFile(s3Client, bucketName, key, path.toString());
});
})
.collect(Collectors.toList());
// Wait for all uploads to complete
CompletableFuture.allOf(
futures.toArray(new CompletableFuture[0])
).join();
System.out.println("All files uploaded successfully");
} catch (IOException e) {
throw new RuntimeException("Failed to process directory", e);
}
}
}
```
## Best Practices
### Performance Optimization
1. **Use S3 Transfer Manager**: Automatically handles multipart uploads, parallel transfers, and progress tracking for files >100MB
2. **Reuse S3 Client**: Clients are thread-safe and should be reused throughout the application lifecycle
3. **Enable async operations**: Use S3AsyncClient for I/O-bound operations to improve throughput
4. **Configure proper timeouts**: Set appropriate timeouts for large file operations
5. **Use connection pooling**: Configure HTTP client for optimal connection management
### Security Considerations
1. **Use temporary credentials**: Always use IAM roles or AWS STS for short-lived access tokens
2. **Enable server-side encryption**: Use AES-256 or AWS KMS for sensitive data
3. **Implement access controls**: Use bucket policies and IAM roles instead of access keys in production
4. **Validate object metadata**: Sanitize user-provided metadata to prevent header injection
5. **Use presigned URLs**: Avoid exposing credentials by using temporary access URLs
### Error Handling
1. **Implement retry logic**: Network operations should have exponential backoff retry strategies
2. **Handle throttling**: Implement proper handling of 429 Too Many Requests responses
3. **Validate object existence**: Check if objects exist before operations that require them
4. **Clean up failed operations**: Abort multipart uploads that fail
5. **Log appropriately**: Log successful operations and errors for monitoring
### Cost Optimization
1. **Use appropriate storage classes**: Choose STANDARD, STANDARD_IA, INTELLIGENT_TIERING based on access patterns
2. **Implement lifecycle policies**: Automatically transition or expire objects (see the sketch after this list)
3. **Enable object versioning**: For important data that needs retention
4. **Monitor usage**: Track data transfer and storage costs
5. **Minimize API calls**: Use batch operations when possible
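As a hedged sketch, a lifecycle rule that transitions objects under `logs/` to Glacier after 30 days and expires them after a year (bucket name and prefix are placeholders):
```java
s3Client.putBucketLifecycleConfiguration(PutBucketLifecycleConfigurationRequest.builder()
        .bucket("my-example-bucket")                      // Placeholder bucket
        .lifecycleConfiguration(BucketLifecycleConfiguration.builder()
                .rules(LifecycleRule.builder()
                        .id("archive-logs")
                        .status(ExpirationStatus.ENABLED)
                        .filter(LifecycleRuleFilter.builder().prefix("logs/").build())
                        .transitions(Transition.builder()
                                .days(30)
                                .storageClass(TransitionStorageClass.GLACIER)
                                .build())
                        .expiration(LifecycleExpiration.builder().days(365).build())
                        .build())
                .build())
        .build());
```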
## Constraints and Limitations
- **File size limits**: Single PUT operations limited to 5GB; use multipart uploads for larger files (see the sketch after this list)
- **Batch operations**: Maximum 1000 objects per DeleteObjects operation
- **Metadata size**: User-defined metadata limited to 2KB
- **Concurrent transfers**: Transfer Manager handles up to 100 concurrent transfers by default
- **Region consistency**: Cross-region operations may incur additional costs and latency
- **Consistency**: S3 now provides strong read-after-write consistency (since December 2020), so newly written objects are immediately visible
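Given the 5GB single-PUT limit, here is a minimal low-level multipart sketch with cleanup on failure. The Transfer Manager above handles all of this automatically; `bucketName`, `key`, and `chunks` are assumed to be provided, with every part except the last at least 5MB:
```java
String uploadId = s3Client.createMultipartUpload(CreateMultipartUploadRequest.builder()
        .bucket(bucketName).key(key).build()).uploadId();
try {
    List<CompletedPart> parts = new ArrayList<>();
    int partNumber = 1;
    for (ByteBuffer chunk : chunks) {
        // Upload each part and record its ETag for the completion call
        UploadPartResponse partResponse = s3Client.uploadPart(
                UploadPartRequest.builder()
                        .bucket(bucketName).key(key)
                        .uploadId(uploadId).partNumber(partNumber)
                        .build(),
                RequestBody.fromByteBuffer(chunk));
        parts.add(CompletedPart.builder()
                .partNumber(partNumber++)
                .eTag(partResponse.eTag())
                .build());
    }
    s3Client.completeMultipartUpload(CompleteMultipartUploadRequest.builder()
            .bucket(bucketName).key(key).uploadId(uploadId)
            .multipartUpload(CompletedMultipartUpload.builder().parts(parts).build())
            .build());
} catch (RuntimeException e) {
    // Abort so incomplete parts do not keep accruing storage charges
    s3Client.abortMultipartUpload(AbortMultipartUploadRequest.builder()
            .bucket(bucketName).key(key).uploadId(uploadId).build());
    throw e;
}
```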
## References
For more detailed information, see:
- [AWS S3 Object Operations Reference](./references/s3-object-operations.md)
- [S3 Transfer Manager Patterns](./references/s3-transfer-patterns.md)
- [Spring Boot Integration Guide](./references/s3-spring-boot-integration.md)
- [AWS S3 Developer Guide](https://docs.aws.amazon.com/AmazonS3/latest/userguide/)
- [AWS SDK for Java 2.x S3 API](https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/s3/package-summary.html)
## Related Skills
- `aws-sdk-java-v2-core` - Core AWS SDK patterns and configuration
- `spring-boot-dependency-injection` - Spring dependency injection patterns
- `unit-test-service-layer` - Testing service layer patterns
- `unit-test-wiremock-rest-api` - Testing external API integrations

View File

@@ -0,0 +1,371 @@
# S3 Object Operations Reference
## Detailed Object Operations
### Advanced Upload Patterns
#### Streaming Upload
`RequestBody` is not `AutoCloseable`, so it is passed directly to `putObject`; for built-in progress monitoring, prefer the S3 Transfer Manager.
```java
public void uploadStream(S3Client s3Client, String bucketName, String key,
                        String filePath) throws IOException {
    Path path = Paths.get(filePath);
    PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build();
    // Stream the file contents; fromInputStream needs the content length up front
    try (InputStream in = Files.newInputStream(path)) {
        s3Client.putObject(request, RequestBody.fromInputStream(in, Files.size(path)));
    }
}
```
#### Conditional Upload
```java
// Conditional writes require S3 conditional-request support and a recent SDK version;
// the upload succeeds only if the current object's ETag matches expectedETag
public void conditionalUpload(S3Client s3Client, String bucketName, String key,
                             String filePath, String expectedETag) {
    PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .ifMatch(expectedETag)
            .build();
    s3Client.putObject(request, RequestBody.fromFile(Paths.get(filePath)));
}
```
### Advanced Download Patterns
#### Range Requests for Large Files
```java
public void downloadInChunks(S3Client s3Client, String bucketName, String key,
                            String destPath, int chunkSizeMB) {
    long fileSize = getFileSize(s3Client, bucketName, key);
    long chunkSize = (long) chunkSizeMB * 1024 * 1024;
    try (OutputStream os = new FileOutputStream(destPath)) {
        for (long start = 0; start < fileSize; start += chunkSize) {
            long end = Math.min(start + chunkSize - 1, fileSize - 1);
            GetObjectRequest request = GetObjectRequest.builder()
                    .bucket(bucketName)
                    .key(key)
                    .range("bytes=" + start + "-" + end)
                    .build();
            try (ResponseInputStream<GetObjectResponse> response =
                    s3Client.getObject(request)) {
                response.transferTo(os);
            }
        }
    } catch (IOException e) {
        throw new UncheckedIOException("Chunked download failed", e);
    }
}
private long getFileSize(S3Client s3Client, String bucketName, String key) {
    // Read Content-Length from a HEAD request
    return s3Client.headObject(HeadObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build()).contentLength();
}
```
### Metadata Management
#### Setting and Retrieving Object Metadata
```java
// Note: this writes a zero-byte object; S3 object metadata is immutable, so to
// change metadata on an existing object use copyObject with
// MetadataDirective.REPLACE (see "Server-Side Copy with Metadata" below)
public void setObjectMetadata(S3Client s3Client, String bucketName, String key,
                             Map<String, String> metadata) {
    PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .metadata(metadata)
            .build();
    s3Client.putObject(request, RequestBody.empty());
}
public Map<String, String> getObjectMetadata(S3Client s3Client,
String bucketName, String key) {
HeadObjectRequest request = HeadObjectRequest.builder()
.bucket(bucketName)
.key(key)
.build();
HeadObjectResponse response = s3Client.headObject(request);
return response.metadata();
}
```
### Storage Classes and Lifecycle
#### Managing Different Storage Classes
```java
public void uploadWithStorageClass(S3Client s3Client, String bucketName, String key,
String filePath, StorageClass storageClass) {
PutObjectRequest request = PutObjectRequest.builder()
.bucket(bucketName)
.key(key)
.storageClass(storageClass)
.build();
s3Client.putObject(request, RequestBody.fromFile(Paths.get(filePath)));
}
// Storage class options:
// STANDARD - Default storage class
// STANDARD_IA - Infrequent Access
// ONEZONE_IA - Single-zone infrequent access
// INTELLIGENT_TIERING - Automatically optimizes storage
// GLACIER - Archive storage
// DEEP_ARCHIVE - Long-term archive storage
```
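#### Lifecycle Configuration
The heading above also covers lifecycle management; as a minimal sketch (the rule ID, `logs/` prefix, and 90-day window are illustrative), a rule that transitions objects to Glacier could look like:

```java
public void addLifecycleRule(S3Client s3Client, String bucketName) {
    // Transition objects under logs/ to Glacier after 90 days
    LifecycleRule rule = LifecycleRule.builder()
        .id("archive-after-90-days")
        .filter(f -> f.prefix("logs/"))
        .status(ExpirationStatus.ENABLED)
        .transitions(Transition.builder()
            .days(90)
            .storageClass(TransitionStorageClass.GLACIER)
            .build())
        .build();
    s3Client.putBucketLifecycleConfiguration(r -> r
        .bucket(bucketName)
        .lifecycleConfiguration(c -> c.rules(rule)));
}
```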
### Object Tagging
#### Adding and Managing Tags
```java
public void addTags(S3Client s3Client, String bucketName, String key,
Map<String, String> tags) {
Tagging tagging = Tagging.builder()
.tagSet(tags.entrySet().stream()
.map(entry -> Tag.builder()
.key(entry.getKey())
.value(entry.getValue())
.build())
.collect(Collectors.toList()))
.build();
PutObjectTaggingRequest request = PutObjectTaggingRequest.builder()
.bucket(bucketName)
.key(key)
.tagging(tagging)
.build();
s3Client.putObjectTagging(request);
}
public Map<String, String> getTags(S3Client s3Client, String bucketName, String key) {
GetObjectTaggingRequest request = GetObjectTaggingRequest.builder()
.bucket(bucketName)
.key(key)
.build();
GetObjectTaggingResponse response = s3Client.getObjectTagging(request);
return response.tagSet().stream()
.collect(Collectors.toMap(Tag::key, Tag::value));
}
```
### Advanced Copy Operations
#### Server-Side Copy with Metadata
```java
public void copyWithMetadata(S3Client s3Client, String sourceBucket, String sourceKey,
String destBucket, String destKey,
Map<String, String> metadata) {
CopyObjectRequest request = CopyObjectRequest.builder()
.sourceBucket(sourceBucket)
.sourceKey(sourceKey)
.destinationBucket(destBucket)
.destinationKey(destKey)
.metadata(metadata)
.metadataDirective(MetadataDirective.REPLACE)
.build();
s3Client.copyObject(request);
}
```
## Error Handling Patterns
### Retry Mechanisms
```java
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.core.retry.backoff.FixedDelayBackoffStrategy;
import software.amazon.awssdk.core.retry.conditions.RetryCondition;
public S3Client createS3ClientWithRetry() {
    return S3Client.builder()
        .overrideConfiguration(b -> b
            .retryPolicy(RetryPolicy.builder()
                .numRetries(3)
                .backoffStrategy(FixedDelayBackoffStrategy.create(
                    Duration.ofSeconds(1)))
                .retryCondition(RetryCondition.defaultRetryCondition())
                .build()))
        .build();
}
```
### Throttling Handling
```java
public void handleThrottling(S3Client s3Client, String bucketName, String key) {
    try {
        PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build();
        s3Client.putObject(request, RequestBody.fromString("test"));
    } catch (S3Exception e) {
        // S3 signals throttling with 503 Slow Down; some AWS services use 429
        if (e.statusCode() == 503 || e.statusCode() == 429) {
            try {
                Thread.sleep(1000); // back off before retrying
                // Retry logic here
            } catch (InterruptedException ie) {
                Thread.currentThread().interrupt();
            }
        }
        throw e;
    }
}
```
## Performance Optimization
### Batch Operations
#### Batch Delete Objects
```java
public void batchDeleteObjects(S3Client s3Client, String bucketName,
List<String> keys) {
int batchSize = 1000; // S3 limit for batch operations
int totalBatches = (int) Math.ceil((double) keys.size() / batchSize);
for (int i = 0; i < totalBatches; i++) {
List<String> batchKeys = keys.subList(
i * batchSize,
Math.min((i + 1) * batchSize, keys.size()));
List<ObjectIdentifier> objectIdentifiers = batchKeys.stream()
.map(key -> ObjectIdentifier.builder().key(key).build())
.collect(Collectors.toList());
Delete delete = Delete.builder()
.objects(objectIdentifiers)
.build();
        DeleteObjectsRequest request = DeleteObjectsRequest.builder()
            .bucket(bucketName)
            .delete(delete)
            .build();
        // deleteObjects reports per-key failures in the response instead of throwing
        DeleteObjectsResponse response = s3Client.deleteObjects(request);
        response.errors().forEach(err ->
            System.err.println("Failed to delete " + err.key() + ": " + err.message()));
    }
}
```
### Parallel Uploads
```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public void parallelUploads(S3Client s3Client, String bucketName,
List<String> keys, ExecutorService executor) {
List<CompletableFuture<Void>> futures = new ArrayList<>();
for (String key : keys) {
CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
PutObjectRequest request = PutObjectRequest.builder()
.bucket(bucketName)
.key(key)
.build();
s3Client.putObject(request, RequestBody.fromString("data"));
}, executor);
futures.add(future);
}
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
}
```
## Security Considerations
### Access Control
#### Setting Object ACLs
```java
public void setObjectAcl(S3Client s3Client, String bucketName, String key,
                         ObjectCannedACL acl) {
    // Note: buckets with Object Ownership set to "bucket owner enforced"
    // (the default for new buckets) reject ACL operations
    PutObjectAclRequest request = PutObjectAclRequest.builder()
        .bucket(bucketName)
        .key(key)
        .acl(acl)
        .build();
    s3Client.putObjectAcl(request);
}
// ObjectCannedACL options:
// PRIVATE, PUBLIC_READ, PUBLIC_READ_WRITE, AUTHENTICATED_READ,
// AWS_EXEC_READ, BUCKET_OWNER_READ, BUCKET_OWNER_FULL_CONTROL
#### Encryption
```java
public void encryptedUpload(S3Client s3Client, String bucketName, String key,
String filePath, String kmsKeyId) {
PutObjectRequest request = PutObjectRequest.builder()
.bucket(bucketName)
.key(key)
.serverSideEncryption(ServerSideEncryption.AWS_KMS)
.ssekmsKeyId(kmsKeyId)
.build();
s3Client.putObject(request, RequestBody.fromFile(Paths.get(filePath)));
}
```
## Monitoring and Logging
### Upload Completion Events
```java
public void uploadWithMonitoring(S3Client s3Client, String bucketName, String key,
                                 String filePath) {
    PutObjectRequest request = PutObjectRequest.builder()
        .bucket(bucketName)
        .key(key)
        .build();
    // putObject returns PutObjectResponse directly; there is no Response wrapper
    PutObjectResponse response = s3Client.putObject(request,
        RequestBody.fromFile(Paths.get(filePath)));
    System.out.println("Upload completed with ETag: " + response.eTag());
}
```
## Integration Patterns
### Event Notifications
```java
public void setupEventNotifications(S3Client s3Client, String bucketName) {
NotificationConfiguration configuration = NotificationConfiguration.builder()
.topicConfigurations(TopicConfiguration.builder()
.topicArn("arn:aws:sns:us-east-1:123456789012:my-topic")
.events(Event.S3_OBJECT_CREATED_PUT, Event.S3_OBJECT_CREATED_POST)
.build())
.build();
PutBucketNotificationConfigurationRequest request =
PutBucketNotificationConfigurationRequest.builder()
.bucket(bucketName)
.notificationConfiguration(configuration)
.build();
s3Client.putBucketNotificationConfiguration(request);
}
```

View File

@@ -0,0 +1,668 @@
# S3 Spring Boot Integration Reference
## Advanced Spring Boot Configuration
### Multi-Environment Configuration
```java
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Configuration;
@Configuration
@EnableConfigurationProperties(S3Properties.class)
public class S3Configuration {
private final S3Properties properties;
public S3Configuration(S3Properties properties) {
this.properties = properties;
}
    @Bean
    @ConditionalOnProperty(name = "s3.async-enabled", havingValue = "true")
    public S3AsyncClient s3AsyncClient() {
        S3AsyncClientBuilder builder = S3AsyncClient.builder()
            .region(Region.of(properties.getRegion()))
            .credentialsProvider(StaticCredentialsProvider.create(
                AwsBasicCredentials.create(
                    properties.getAccessKey(),
                    properties.getSecretKey())));
        if (properties.getEndpoint() != null) {
            // Endpoint override is only needed for local stacks such as LocalStack
            builder.endpointOverride(URI.create(properties.getEndpoint()));
        }
        return builder.build();
    }
    @Bean
    @ConditionalOnProperty(name = "s3.sync-enabled", havingValue = "true", matchIfMissing = true)
    public S3Client s3Client() {
        S3ClientBuilder builder = S3Client.builder()
            .region(Region.of(properties.getRegion()))
            .credentialsProvider(StaticCredentialsProvider.create(
                AwsBasicCredentials.create(
                    properties.getAccessKey(),
                    properties.getSecretKey())));
        if (properties.getEndpoint() != null) {
            builder.endpointOverride(URI.create(properties.getEndpoint()));
        }
        return builder.build();
    }
    @Bean
    @ConditionalOnProperty(name = "s3.transfer-manager-enabled", havingValue = "true")
    public S3TransferManager s3TransferManager(S3AsyncClient s3AsyncClient) {
        // The Transfer Manager is built on the async client, not the sync one
        return S3TransferManager.builder()
            .s3Client(s3AsyncClient)
            .build();
    }
    @Bean
    @ConditionalOnProperty(name = "s3.presigner-enabled", havingValue = "true")
    public S3Presigner s3Presigner() {
        return S3Presigner.builder()
            .region(Region.of(properties.getRegion()))
            .build();
    }
}
@ConfigurationProperties(prefix = "s3")
@Data
public class S3Properties {
private String accessKey;
private String secretKey;
private String region = "us-east-1";
private String endpoint = null;
private boolean syncEnabled = true;
private boolean asyncEnabled = false;
private boolean transferManagerEnabled = false;
private boolean presignerEnabled = false;
private int maxConnections = 100;
private int connectionTimeout = 5000;
private int socketTimeout = 30000;
private String defaultBucket;
}
```
### Profile-Specific Configuration
```properties
# application-dev.properties
s3.access-key=${AWS_ACCESS_KEY}
s3.secret-key=${AWS_SECRET_KEY}
s3.region=us-east-1
s3.endpoint=http://localhost:4566
s3.async-enabled=true
s3.transfer-manager-enabled=true
# application-prod.properties
s3.access-key=${AWS_ACCESS_KEY}
s3.secret-key=${AWS_SECRET_KEY}
s3.region=us-east-1
s3.async-enabled=true
s3.presigner-enabled=true
```
## Advanced Service Patterns
### Generic S3 Service Template
```java
import software.amazon.awssdk.core.ResponseBytes;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.services.s3.model.*;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.nio.file.*;
import java.util.*;
import java.util.stream.Collectors;
@Service
@RequiredArgsConstructor
public class S3Service {
private final S3Client s3Client;
private final S3AsyncClient s3AsyncClient;
private final S3TransferManager transferManager;
private final S3Properties s3Properties;
// Basic Operations
    public Mono<Void> uploadObjectAsync(String key, byte[] data) {
        PutObjectRequest request = PutObjectRequest.builder()
            .bucket(s3Properties.getDefaultBucket())
            .key(key)
            .build();
        // The async client takes an AsyncRequestBody and already returns a CompletableFuture
        return Mono.fromFuture(() -> s3AsyncClient.putObject(request,
                AsyncRequestBody.fromBytes(data)))
            .then();
    }
    public Mono<byte[]> downloadObjectAsync(String key) {
        GetObjectRequest request = GetObjectRequest.builder()
            .bucket(s3Properties.getDefaultBucket())
            .key(key)
            .build();
        // Collect the streamed body into memory with AsyncResponseTransformer.toBytes()
        return Mono.fromFuture(() -> s3AsyncClient.getObject(request,
                AsyncResponseTransformer.toBytes()))
            .map(ResponseBytes::asByteArray);
    }
// Advanced Operations
public Mono<UploadResult> uploadWithMetadata(String key,
Path file,
Map<String, String> metadata) {
return Mono.fromFuture(() -> {
PutObjectRequest request = PutObjectRequest.builder()
.bucket(s3Properties.getDefaultBucket())
.key(key)
.metadata(metadata)
.contentType(getContentType(file))
.build();
return s3AsyncClient.putObject(request, AsyncRequestBody.fromFile(file))
.thenApply(response -> new UploadResult(key, response.eTag()));
});
}
public Flux<S3Object> listObjectsWithPrefix(String prefix) {
ListObjectsV2Request request = ListObjectsV2Request.builder()
.bucket(s3Properties.getDefaultBucket())
.prefix(prefix)
.build();
        // The paginator handles continuation tokens; its SdkIterable is a plain Iterable
        return Flux.fromIterable(s3Client.listObjectsV2Paginator(request).contents());
    }
public Mono<Void> batchDelete(List<String> keys) {
return Mono.fromFuture(() -> {
List<ObjectIdentifier> objectIdentifiers = keys.stream()
.map(key -> ObjectIdentifier.builder().key(key).build())
.collect(Collectors.toList());
Delete delete = Delete.builder()
.objects(objectIdentifiers)
.build();
DeleteObjectsRequest request = DeleteObjectsRequest.builder()
.bucket(s3Properties.getDefaultBucket())
.delete(delete)
.build();
            return s3AsyncClient.deleteObjects(request);
        }).then();
    }
// Transfer Manager Operations
public Mono<UploadResult> uploadWithTransferManager(String key, Path file) {
return Mono.fromFuture(() -> {
UploadFileRequest request = UploadFileRequest.builder()
.putObjectRequest(req -> req
.bucket(s3Properties.getDefaultBucket())
.key(key))
.source(file)
.build();
return transferManager.uploadFile(request)
.completionFuture()
.thenApply(result -> new UploadResult(key, result.response().eTag()));
});
}
public Mono<DownloadResult> downloadWithTransferManager(String key, Path destination) {
return Mono.fromFuture(() -> {
DownloadFileRequest request = DownloadFileRequest.builder()
.getObjectRequest(req -> req
.bucket(s3Properties.getDefaultBucket())
.key(key))
.destination(destination)
.build();
return transferManager.downloadFile(request)
.completionFuture()
.thenApply(result -> new DownloadResult(destination, result.response().contentLength()));
});
}
// Utility Methods
private String getContentType(Path file) {
try {
return Files.probeContentType(file);
} catch (IOException e) {
return "application/octet-stream";
}
}
// Records for Results
public record UploadResult(String key, String eTag) {}
public record DownloadResult(Path path, long size) {}
}
```
### Event-Driven S3 Operations
```java
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Mono;
import java.nio.file.Path;
import java.time.Duration;
@Service
@RequiredArgsConstructor
public class S3EventService {
    private final S3Service s3Service;
    private final S3Properties s3Properties;
    private final ApplicationEventPublisher eventPublisher;
    public Mono<UploadResult> uploadAndPublishEvent(String key, Path file) {
        return s3Service.uploadWithTransferManager(key, file)
            .doOnSuccess(result -> eventPublisher.publishEvent(
                new S3UploadEvent(key, result.eTag())))
            .doOnError(error -> eventPublisher.publishEvent(
                new S3UploadFailedEvent(key, error.getMessage())));
    }
    public Mono<String> generatePresignedUrl(String key) {
        // Presigning is a purely local signing operation; no download is needed first
        return Mono.fromCallable(() -> {
            try (S3Presigner presigner = S3Presigner.create()) {
                GetObjectRequest request = GetObjectRequest.builder()
                    .bucket(s3Properties.getDefaultBucket())
                    .key(key)
                    .build();
                GetObjectPresignRequest presignRequest = GetObjectPresignRequest.builder()
                    .signatureDuration(Duration.ofMinutes(10))
                    .getObjectRequest(request)
                    .build();
                return presigner.presignGetObject(presignRequest)
                    .url()
                    .toString();
            }
        });
    }
}
// Event Classes
public class S3UploadEvent extends ApplicationEvent {
private final String key;
private final String eTag;
public S3UploadEvent(String key, String eTag) {
super(key);
this.key = key;
this.eTag = eTag;
}
public String getKey() { return key; }
public String getETag() { return eTag; }
}
public class S3UploadFailedEvent extends ApplicationEvent {
private final String key;
private final String errorMessage;
public S3UploadFailedEvent(String key, String errorMessage) {
super(key);
this.key = key;
this.errorMessage = errorMessage;
}
public String getKey() { return key; }
public String getErrorMessage() { return errorMessage; }
}
```
### Retry and Error Handling
```java
import org.springframework.retry.annotation.*;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Mono;
import reactor.util.retry.Retry;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;
import java.nio.file.Path;
import java.time.Duration;
@Service
@RequiredArgsConstructor
public class ResilientS3Service {
    private final S3Client s3Client;
    private final S3AsyncClient s3AsyncClient;
    // @Retryable only triggers on exceptions thrown from the method call itself,
    // so the blocking method below returns the response directly; for reactive
    // pipelines, use Reactor's retryWhen instead (see copyObjectWithRetry)
    @Retryable(value = {S3Exception.class, SdkClientException.class},
               maxAttempts = 3,
               backoff = @Backoff(delay = 1000, multiplier = 2))
    public PutObjectResponse uploadWithRetry(String key, Path file) {
        PutObjectRequest request = PutObjectRequest.builder()
            .bucket("my-bucket")
            .key(key)
            .build();
        return s3Client.putObject(request, RequestBody.fromFile(file));
    }
    @Recover
    public PutObjectResponse uploadRecover(S3Exception e, String key, Path file) {
        // Log the failure and potentially send notification
        System.err.println("Upload failed after retries: " + e.getMessage());
        throw new S3UploadException("Upload failed after retries", e);
    }
    public Mono<Void> copyObjectWithRetry(String sourceKey, String destinationKey) {
        CopyObjectRequest request = CopyObjectRequest.builder()
            .sourceBucket("source-bucket")
            .sourceKey(sourceKey)
            .destinationBucket("destination-bucket")
            .destinationKey(destinationKey)
            .build();
        // Reactive retry with exponential backoff, handled by Reactor itself
        return Mono.fromFuture(() -> s3AsyncClient.copyObject(request))
            .retryWhen(Retry.backoff(5, Duration.ofSeconds(2)))
            .then();
    }
}
public class S3UploadException extends RuntimeException {
public S3UploadException(String message, Throwable cause) {
super(message, cause);
}
}
```
## Testing Integration
### Test Configuration with LocalStack
```java
import org.testcontainers.containers.localstack.LocalStackContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.test.context.ActiveProfiles;
import org.testcontainers.utility.DockerImageName;
@Testcontainers
@ActiveProfiles("test")
@TestConfiguration
public class S3TestConfig {
@Container
static LocalStackContainer localstack = new LocalStackContainer(
DockerImageName.parse("localstack/localstack:3.0"))
.withServices(LocalStackContainer.Service.S3)
.withEnv("DEFAULT_REGION", "us-east-1");
@Bean
public S3Client testS3Client() {
return S3Client.builder()
.region(Region.US_EAST_1)
.endpointOverride(localstack.getEndpointOverride(LocalStackContainer.Service.S3))
.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create(
localstack.getAccessKey(),
localstack.getSecretKey())))
.build();
}
@Bean
public S3AsyncClient testS3AsyncClient() {
return S3AsyncClient.builder()
.region(Region.US_EAST_1)
.endpointOverride(localstack.getEndpointOverride(LocalStackContainer.Service.S3))
.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create(
localstack.getAccessKey(),
localstack.getSecretKey())))
.build();
}
}
```
### Unit Testing with Mocks
```java
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.test.StepVerifier;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;
import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable;
import java.util.concurrent.CompletableFuture;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class S3ServiceTest {
    @Mock
    private S3Client s3Client;
    @Mock
    private S3AsyncClient s3AsyncClient;
    @Mock
    private S3Properties s3Properties;
    @InjectMocks
    private S3Service s3Service;
    @Test
    void uploadObjectAsync_ShouldComplete() {
        // Arrange
        String key = "test-key";
        byte[] data = "test-content".getBytes();
        PutObjectResponse response = PutObjectResponse.builder()
            .eTag("12345")
            .build();
        when(s3Properties.getDefaultBucket()).thenReturn("test-bucket");
        when(s3AsyncClient.putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class)))
            .thenReturn(CompletableFuture.completedFuture(response));
        // Act & Assert
        s3Service.uploadObjectAsync(key, data)
            .as(StepVerifier::create)
            .verifyComplete();
    }
    @Test
    void listObjectsWithPrefix_ShouldReturnObjectList() {
        // Arrange
        String prefix = "documents/";
        ListObjectsV2Response response = ListObjectsV2Response.builder()
            .contents(S3Object.builder().key("documents/file1.txt").build(),
                      S3Object.builder().key("documents/file2.txt").build())
            .build();
        when(s3Properties.getDefaultBucket()).thenReturn("test-bucket");
        // The paginator lazily delegates each page to listObjectsV2
        when(s3Client.listObjectsV2(any(ListObjectsV2Request.class))).thenReturn(response);
        when(s3Client.listObjectsV2Paginator(any(ListObjectsV2Request.class)))
            .thenAnswer(invocation -> new ListObjectsV2Iterable(
                s3Client, invocation.getArgument(0)));
        // Act & Assert
        s3Service.listObjectsWithPrefix(prefix)
            .collectList()
            .as(StepVerifier::create)
            .assertNext(objects -> {
                assertEquals(2, objects.size());
                assertTrue(objects.stream().allMatch(obj -> obj.key().startsWith(prefix)));
            })
            .verifyComplete();
    }
}
```
### Integration Testing
```java
import org.junit.jupiter.api.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.ActiveProfiles;
import reactor.test.StepVerifier;
import java.io.IOException;
import java.nio.file.*;
import java.util.Arrays;
import java.util.Map;
@SpringBootTest
@ActiveProfiles("test")
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
class S3IntegrationTest {
@Autowired
private S3Service s3Service;
private static final String TEST_BUCKET = "test-bucket";
private static final String TEST_FILE = "test-document.txt";
@BeforeAll
static void setup() throws Exception {
// Create test file
Files.write(Paths.get(TEST_FILE), "Test content".getBytes());
}
@Test
@Order(1)
void uploadFile_ShouldSucceed() {
// Act & Assert
s3Service.uploadWithMetadata(TEST_FILE, Paths.get(TEST_FILE),
Map.of("author", "test", "type", "document"))
.as(StepVerifier::create)
.expectNextMatches(result ->
result.key().equals(TEST_FILE) && result.eTag() != null)
.verifyComplete();
}
@Test
@Order(2)
    void downloadFile_ShouldReturnContent() {
        // Act & Assert (byte arrays need content comparison, not reference equality)
        s3Service.downloadObjectAsync(TEST_FILE)
            .as(StepVerifier::create)
            .expectNextMatches(bytes -> Arrays.equals("Test content".getBytes(), bytes))
            .verifyComplete();
    }
}
@Test
@Order(3)
void listObjects_ShouldReturnFiles() {
// Act & Assert
s3Service.listObjectsWithPrefix("")
.as(StepVerifier::create)
.expectNextCount(1)
.verifyComplete();
}
@AfterAll
static void cleanup() {
try {
Files.deleteIfExists(Paths.get(TEST_FILE));
} catch (IOException e) {
// Ignore
}
}
}
```
## Advanced Configuration Patterns
### Environment-Specific Configuration
```java
import org.springframework.boot.autoconfigure.condition.*;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import software.amazon.awssdk.auth.credentials.*;
@Configuration
public class EnvironmentAwareS3Config {
@Bean
@ConditionalOnMissingBean
public AwsCredentialsProvider awsCredentialsProvider(S3Properties properties) {
if (properties.getAccessKey() != null && properties.getSecretKey() != null) {
return StaticCredentialsProvider.create(
AwsBasicCredentials.create(
properties.getAccessKey(),
properties.getSecretKey()));
}
return DefaultCredentialsProvider.create();
}
@Bean
@ConditionalOnMissingBean
@ConditionalOnProperty(name = "s3.region")
public Region region(S3Properties properties) {
return Region.of(properties.getRegion());
}
@Bean
@ConditionalOnMissingBean
@ConditionalOnProperty(name = "s3.endpoint")
public String endpoint(S3Properties properties) {
return properties.getEndpoint();
}
}
```
### Multi-Bucket Support
```java
import org.springframework.stereotype.Service;
import software.amazon.awssdk.services.s3.S3Client;
import reactor.core.publisher.Mono;
import java.nio.file.Path;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@Service
public class MultiBucketS3Service {
    private final Map<String, S3Client> bucketClients = new ConcurrentHashMap<>();
    private final S3Client defaultS3Client;
    public MultiBucketS3Service(S3Client defaultS3Client) {
        this.defaultS3Client = defaultS3Client;
    }
    public S3Client getClientForBucket(String bucketName) {
        // Reuse the default client's region and credentials for per-bucket clients
        return bucketClients.computeIfAbsent(bucketName, name ->
            S3Client.builder()
                .region(defaultS3Client.serviceClientConfiguration().region())
                .credentialsProvider(defaultS3Client.serviceClientConfiguration()
                    .credentialsProvider())
                .build());
    }
    public Mono<UploadResult> uploadToBucket(String bucketName, String key, Path file) {
        S3Client client = getClientForBucket(bucketName);
        // Upload implementation using the specific client
        return Mono.empty(); // Implementation
    }
}
```

View File

@@ -0,0 +1,473 @@
# S3 Transfer Patterns Reference
## S3 Transfer Manager Advanced Patterns
### Configuration and Optimization
#### Custom Transfer Manager Configuration
```java
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.transfer.s3.S3TransferManager;
import java.time.Duration;
public S3TransferManager createOptimizedTransferManager() {
    // HTTP tuning belongs on the async client; the Transfer Manager wraps it
    S3AsyncClient s3AsyncClient = S3AsyncClient.builder()
        .httpClientBuilder(NettyNioAsyncHttpClient.builder()
            .maxConcurrency(200)
            .connectionTimeout(Duration.ofSeconds(5))
            .readTimeout(Duration.ofSeconds(60)))
        .build();
    return S3TransferManager.builder()
        .s3Client(s3AsyncClient)
        .build();
}
```
#### Parallel Upload Configuration
```java
public void configureParallelUploads() {
    S3TransferManager transferManager = S3TransferManager.create();
    FileUpload upload = transferManager.uploadFile(
        UploadFileRequest.builder()
            .putObjectRequest(req -> req
                .bucket("my-bucket")
                .key("large-file.bin"))
            .source(Paths.get("large-file.bin"))
            // Track upload progress as parts complete
            .addTransferListener(LoggingTransferListener.create())
            .build());
    // Handle completion
    upload.completionFuture().thenAccept(result ->
        System.out.println("Upload completed with ETag: " +
            result.response().eTag()));
}
```
### Advanced Upload Patterns
#### Multipart Upload with Progress Monitoring
```java
public void multipartUploadWithProgress(S3Client s3Client, String bucketName,
String key, String filePath) {
int partSize = 5 * 1024 * 1024; // 5 MB parts
File file = new File(filePath);
CreateMultipartUploadRequest createRequest = CreateMultipartUploadRequest.builder()
.bucket(bucketName)
.key(key)
.build();
CreateMultipartUploadResponse createResponse = s3Client.createMultipartUpload(createRequest);
String uploadId = createResponse.uploadId();
List<CompletedPart> completedParts = new ArrayList<>();
long uploadedBytes = 0;
long totalBytes = file.length();
try (FileInputStream fis = new FileInputStream(file)) {
byte[] buffer = new byte[partSize];
int partNumber = 1;
while (true) {
int bytesRead = fis.read(buffer);
if (bytesRead == -1) break;
byte[] partData = Arrays.copyOf(buffer, bytesRead);
UploadPartRequest uploadRequest = UploadPartRequest.builder()
.bucket(bucketName)
.key(key)
.uploadId(uploadId)
.partNumber(partNumber)
.build();
UploadPartResponse uploadResponse = s3Client.uploadPart(
uploadRequest, RequestBody.fromBytes(partData));
completedParts.add(CompletedPart.builder()
.partNumber(partNumber)
.eTag(uploadResponse.eTag())
.build());
uploadedBytes += bytesRead;
partNumber++;
// Log progress
double progress = (double) uploadedBytes / totalBytes * 100;
System.out.printf("Upload progress: %.2f%%%n", progress);
}
CompleteMultipartUploadRequest completeRequest =
CompleteMultipartUploadRequest.builder()
.bucket(bucketName)
.key(key)
.uploadId(uploadId)
.multipartUpload(CompletedMultipartUpload.builder()
.parts(completedParts)
.build())
.build();
s3Client.completeMultipartUpload(completeRequest);
} catch (Exception e) {
// Abort on failure
AbortMultipartUploadRequest abortRequest =
AbortMultipartUploadRequest.builder()
.bucket(bucketName)
.key(key)
.uploadId(uploadId)
.build();
s3Client.abortMultipartUpload(abortRequest);
throw new RuntimeException("Multipart upload failed", e);
}
}
```
#### Resume Interrupted Uploads
```java
public void resumeUpload(S3Client s3Client, String bucketName, String key,
String filePath, String existingUploadId) {
ListMultipartUploadsRequest listRequest = ListMultipartUploadsRequest.builder()
.bucket(bucketName)
.prefix(key)
.build();
ListMultipartUploadsResponse listResponse = s3Client.listMultipartUploads(listRequest);
// Check if upload already exists
boolean uploadExists = listResponse.uploads().stream()
.anyMatch(upload -> upload.key().equals(key) &&
upload.uploadId().equals(existingUploadId));
if (uploadExists) {
// Resume existing upload
continueExistingUpload(s3Client, bucketName, key, existingUploadId, filePath);
} else {
// Start new upload
multipartUploadWithProgress(s3Client, bucketName, key, filePath);
}
}
private void continueExistingUpload(S3Client s3Client, String bucketName,
String key, String uploadId, String filePath) {
// List already uploaded parts
ListPartsRequest listPartsRequest = ListPartsRequest.builder()
.bucket(bucketName)
.key(key)
.uploadId(uploadId)
.build();
ListPartsResponse listPartsResponse = s3Client.listParts(listPartsRequest);
List<CompletedPart> completedParts = listPartsResponse.parts().stream()
.map(part -> CompletedPart.builder()
.partNumber(part.partNumber())
.eTag(part.eTag())
.build())
.collect(Collectors.toList());
// Upload remaining parts
// ... implementation of remaining parts upload
}
```
### Advanced Download Patterns
#### Partial File Download
```java
public void downloadPartialFile(S3Client s3Client, String bucketName, String key,
                                String destPath, long startByte, long endByte) throws IOException {
GetObjectRequest request = GetObjectRequest.builder()
.bucket(bucketName)
.key(key)
.range("bytes=" + startByte + "-" + endByte)
.build();
try (ResponseInputStream<GetObjectResponse> response = s3Client.getObject(request);
OutputStream outputStream = new FileOutputStream(destPath)) {
response.transferTo(outputStream);
System.out.println("Partial download completed: " +
(endByte - startByte + 1) + " bytes");
}
}
```
#### Parallel Downloads
```java
import java.util.concurrent.*;
import java.util.stream.*;
public void parallelDownloads(S3Client s3Client, String bucketName,
                              String key, String destPath, int chunkCount) throws IOException {
    long fileSize = getFileSize(s3Client, bucketName, key); // HeadObject-based helper (see the object operations reference)
    long chunkSize = fileSize / chunkCount;
    ExecutorService executor = Executors.newFixedThreadPool(chunkCount);
    List<Future<Void>> futures = new ArrayList<>();
    for (int i = 0; i < chunkCount; i++) {
        final int chunkIndex = i; // the lambda below needs an effectively final copy
        long start = i * chunkSize;
        long end = (i == chunkCount - 1) ? fileSize - 1 : start + chunkSize - 1;
        Future<Void> future = executor.submit(() -> {
            downloadPartialFile(s3Client, bucketName, key,
                destPath + ".part" + chunkIndex, start, end);
            return null;
        });
        futures.add(future);
    }
// Wait for all downloads to complete
for (Future<Void> future : futures) {
try {
future.get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException("Download failed", e);
}
}
// Combine chunks
combineChunks(destPath, chunkCount);
executor.shutdown();
}
private void combineChunks(String baseName, int chunkCount) throws IOException {
try (OutputStream outputStream = new FileOutputStream(baseName)) {
for (int i = 0; i < chunkCount; i++) {
String chunkFile = baseName + ".part" + i;
try (InputStream inputStream = new FileInputStream(chunkFile)) {
inputStream.transferTo(outputStream);
}
new File(chunkFile).delete();
}
}
}
```
### Error Handling and Retry
#### Upload with Exponential Backoff
```java
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.core.retry.backoff.FullJitterBackoffStrategy;
import software.amazon.awssdk.core.retry.conditions.RetryCondition;
public void resilientUpload(S3Client s3Client, String bucketName, String key,
                            String filePath) {
    PutObjectRequest request = PutObjectRequest.builder()
        .bucket(bucketName)
        .key(key)
        .build();
    // Exponential backoff with jitter; the default retry condition already
    // covers throttling and 5xx server errors
    S3Client retryS3Client = S3Client.builder()
        .overrideConfiguration(b -> b
            .retryPolicy(RetryPolicy.builder()
                .numRetries(5)
                .backoffStrategy(FullJitterBackoffStrategy.builder()
                    .baseDelay(Duration.ofSeconds(1))
                    .maxBackoffTime(Duration.ofSeconds(30))
                    .build())
                .retryCondition(RetryCondition.defaultRetryCondition())
                .build()))
        .build();
    retryS3Client.putObject(request, RequestBody.fromFile(Paths.get(filePath)));
}
```
#### Upload with Checkpoint
```java
import java.nio.file.*;
public void uploadWithCheckpoint(S3Client s3Client, String bucketName,
String key, String filePath) {
String checkpointFile = filePath + ".checkpoint";
Path checkpointPath = Paths.get(checkpointFile);
long startPos = 0;
if (Files.exists(checkpointPath)) {
// Read checkpoint
try {
startPos = Long.parseLong(Files.readString(checkpointPath));
} catch (IOException e) {
throw new RuntimeException("Failed to read checkpoint", e);
}
}
if (startPos > 0) {
// Resume upload
continueUploadFromCheckpoint(s3Client, bucketName, key, filePath, startPos);
} else {
// Start new upload
startNewUpload(s3Client, bucketName, key, filePath);
}
// Update checkpoint
long endPos = new File(filePath).length();
try {
Files.writeString(checkpointPath, String.valueOf(endPos));
} catch (IOException e) {
throw new RuntimeException("Failed to write checkpoint", e);
}
}
private void continueUploadFromCheckpoint(S3Client s3Client, String bucketName,
String key, String filePath, long startPos) {
// Implement resume logic
}
private void startNewUpload(S3Client s3Client, String bucketName,
String key, String filePath) {
// Implement initial upload logic
}
```
### Performance Tuning
#### Buffer Configuration
```java
public S3Client configureLargeBuffer() {
    return S3Client.builder()
        .overrideConfiguration(b -> b
            .apiCallAttemptTimeout(Duration.ofMinutes(5))
            .apiCallTimeout(Duration.ofMinutes(10)))
        .build();
}
public S3TransferManager configureHighThroughput() {
    // Multipart thresholds live on the CRT-based async client, not the Transfer Manager
    S3AsyncClient crtClient = S3AsyncClient.crtBuilder()
        .thresholdInBytes(8L * 1024 * 1024)        // use multipart above 8 MB
        .minimumPartSizeInBytes(10L * 1024 * 1024) // 10 MB parts
        .build();
    return S3TransferManager.builder()
        .s3Client(crtClient)
        .build();
}
```
#### Network Optimization
```java
public S3Client createOptimizedS3Client() {
    return S3Client.builder()
        .httpClientBuilder(ApacheHttpClient.builder()
            .maxConnections(200)
            .socketTimeout(Duration.ofSeconds(30))
            .connectionTimeout(Duration.ofSeconds(5))
            .connectionAcquisitionTimeout(Duration.ofSeconds(30)))
        .region(Region.US_EAST_1)
        .build();
}
```
### Monitoring and Metrics
#### Upload Progress Tracking
```java
public void uploadWithProgressTracking(S3TransferManager transferManager,
                                       String bucketName, String key, String filePath) {
    // The synchronous S3Client has no progress-listener hook in SDK v2;
    // progress events are surfaced through the Transfer Manager's TransferListener
    UploadFileRequest request = UploadFileRequest.builder()
        .putObjectRequest(req -> req
            .bucket(bucketName)
            .key(key))
        .source(Paths.get(filePath))
        .addTransferListener(LoggingTransferListener.create())
        .build();
    FileUpload upload = transferManager.uploadFile(request);
    CompletedFileUpload result = upload.completionFuture().join();
    System.out.println("Upload complete. ETag: " + result.response().eTag());
}
```
#### Throughput Measurement
```java
public void measureUploadThroughput(S3Client s3Client, String bucketName,
String key, String filePath) {
long startTime = System.currentTimeMillis();
long fileSize = new File(filePath).length();
PutObjectRequest request = PutObjectRequest.builder()
.bucket(bucketName)
.key(key)
.build();
s3Client.putObject(request, RequestBody.fromFile(Paths.get(filePath)));
long endTime = System.currentTimeMillis();
long duration = endTime - startTime;
double throughput = (fileSize * 1000.0) / duration / (1024 * 1024); // MB/s
System.out.printf("Upload throughput: %.2f MB/s%n", throughput);
}
```
## Testing and Validation
### Upload Validation
```java
public void validateUpload(S3Client s3Client, String bucketName, String key,
                           String localFilePath) throws IOException {
    // Download the object from S3 into memory
    byte[] s3Content = s3Client.getObjectAsBytes(GetObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build())
        .asByteArray();
    // Read the local file
    byte[] localContent = Files.readAllBytes(Paths.get(localFilePath));
    // Verify file size first (cheap check)
    if (s3Content.length != localContent.length) {
        throw new RuntimeException("Upload validation failed: size mismatch");
    }
    // Validate content matches
    if (!Arrays.equals(s3Content, localContent)) {
        throw new RuntimeException("Upload validation failed: content mismatch");
    }
    System.out.println("Upload validation successful");
}
```

View File

@@ -0,0 +1,342 @@
---
name: aws-sdk-java-v2-secrets-manager
description: AWS Secrets Manager patterns using AWS SDK for Java 2.x. Use when storing/retrieving secrets (passwords, API keys, tokens), rotating secrets automatically, managing database credentials, or integrating secret management into Spring Boot applications.
category: aws
tags: [aws, secrets-manager, java, sdk, security, credentials, spring-boot]
version: 1.1.0
allowed-tools: Read, Write, Glob, Bash
---
# AWS SDK for Java 2.x - AWS Secrets Manager
## When to Use
Use this skill when:
- Storing and retrieving application secrets programmatically
- Managing database credentials securely without hardcoding
- Implementing automatic secret rotation with Lambda functions
- Integrating AWS Secrets Manager with Spring Boot applications
- Setting up secret caching for improved performance
- Creating secure configuration management systems
- Working with multi-region secret deployments
- Implementing audit logging for secret access
## Dependencies
### Maven
```xml
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>secretsmanager</artifactId>
</dependency>
<!-- For secret caching (recommended for production) -->
<dependency>
<groupId>com.amazonaws.secretsmanager</groupId>
<artifactId>aws-secretsmanager-caching-java</artifactId>
    <version>2.0.0</version> <!-- use the SDK v2 compatible version -->
</dependency>
```
### Gradle
```gradle
implementation 'software.amazon.awssdk:secretsmanager'
implementation 'com.amazonaws.secretsmanager:aws-secretsmanager-caching-java:2.0.0'
```
## Quick Start
### Basic Client Setup
```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
SecretsManagerClient secretsClient = SecretsManagerClient.builder()
.region(Region.US_EAST_1)
.build();
```
### Store a Secret
```java
import software.amazon.awssdk.services.secretsmanager.model.*;
public String createSecret(String secretName, String secretValue) {
CreateSecretRequest request = CreateSecretRequest.builder()
.name(secretName)
.secretString(secretValue)
.build();
CreateSecretResponse response = secretsClient.createSecret(request);
return response.arn();
}
```
### Retrieve a Secret
```java
public String getSecretValue(String secretName) {
GetSecretValueRequest request = GetSecretValueRequest.builder()
.secretId(secretName)
.build();
GetSecretValueResponse response = secretsClient.getSecretValue(request);
return response.secretString();
}
```
## Core Operations
### Secret Management
- Create secrets with `createSecret()`
- Retrieve secrets with `getSecretValue()`
- Update secrets with `updateSecret()`
- Delete secrets with `deleteSecret()`
- List secrets with `listSecrets()`
- Restore deleted secrets with `restoreSecret()`
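A minimal sketch of the update/delete/restore flow (the secret name is illustrative):

```java
// Update the secret's value (creates a new version labeled AWSCURRENT)
secretsClient.updateSecret(UpdateSecretRequest.builder()
    .secretId("my-app/api-key")
    .secretString("{\"api_key\":\"new-value\"}")
    .build());
// Schedule deletion with a 7-day recovery window
secretsClient.deleteSecret(DeleteSecretRequest.builder()
    .secretId("my-app/api-key")
    .recoveryWindowInDays(7L)
    .build());
// Cancel the pending deletion while still inside the recovery window
secretsClient.restoreSecret(RestoreSecretRequest.builder()
    .secretId("my-app/api-key")
    .build());
```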
### Secret Versioning
- Access specific versions by `versionId`
- Access versions by stage (e.g., "AWSCURRENT", "AWSPENDING")
- Automatically manage version history
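For example, a previous version can be read by staging label (assuming the secret has been rotated or updated at least once):

```java
// AWSPREVIOUS points at the version that was current before the last update
GetSecretValueRequest request = GetSecretValueRequest.builder()
    .secretId("my-app/api-key")
    .versionStage("AWSPREVIOUS") // or versionId(...) for an exact version
    .build();
String previousValue = secretsClient.getSecretValue(request).secretString();
```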
### Secret Rotation
- Configure automatic rotation schedules
- Lambda-based rotation functions
- Immediate rotation with `rotateSecret()`
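As a sketch, wiring up rotation against an existing rotation Lambda (the ARN is illustrative) looks like:

```java
secretsClient.rotateSecret(RotateSecretRequest.builder()
    .secretId("prod/database/credentials")
    .rotationLambdaARN("arn:aws:lambda:us-east-1:123456789012:function:rotate-db-secret")
    .rotationRules(RotationRulesType.builder()
        .automaticallyAfterDays(30L) // rotate every 30 days
        .build())
    .build());
```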
## Caching for Performance
### Setup Cache
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
public class CachedSecrets {
private final SecretCache cache;
public CachedSecrets(SecretsManagerClient secretsClient) {
this.cache = new SecretCache(secretsClient);
}
public String getCachedSecret(String secretName) {
return cache.getSecretString(secretName);
}
}
```
### Cache Configuration
```java
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
SecretCacheConfiguration config = new SecretCacheConfiguration()
    .withMaxCacheSize(1000)
    .withCacheItemTTL(3600000); // 1 hour
```
## Spring Boot Integration
### Configuration
```java
@Configuration
public class SecretsManagerConfiguration {
    @Value("${aws.secrets.region}")
    private String region;
    @Bean
    public SecretsManagerClient secretsManagerClient() {
        return SecretsManagerClient.builder()
            .region(Region.of(region))
            .build();
    }
    @Bean
    public SecretCache secretCache(SecretsManagerClient secretsClient) {
        return new SecretCache(secretsClient);
    }
}
```
### Service Layer
```java
@Service
public class SecretsService {
    private final SecretCache cache;
    private final ObjectMapper objectMapper;
    public SecretsService(SecretCache cache, ObjectMapper objectMapper) {
        this.cache = cache;
        this.objectMapper = objectMapper;
    }
    public <T> T getSecretAsObject(String secretName, Class<T> type) {
        try {
            String secretJson = cache.getSecretString(secretName);
            return objectMapper.readValue(secretJson, type);
        } catch (JsonProcessingException e) {
            throw new IllegalStateException("Failed to parse secret: " + secretName, e);
        }
    }
}
```
### Database Configuration
```java
@Configuration
public class DatabaseConfiguration {
@Bean
public DataSource dataSource(SecretsService secretsService) {
Map<String, String> credentials = secretsService.getSecretAsMap(
"prod/database/credentials");
HikariConfig config = new HikariConfig();
config.setJdbcUrl(credentials.get("url"));
config.setUsername(credentials.get("username"));
config.setPassword(credentials.get("password"));
return new HikariDataSource(config);
}
}
```
## Examples
### Database Credentials Structure
```json
{
"engine": "postgres",
"host": "mydb.us-east-1.rds.amazonaws.com",
"port": 5432,
"username": "admin",
"password": "MySecurePassword123!",
"dbname": "mydatabase",
"url": "jdbc:postgresql://mydb.us-east-1.rds.amazonaws.com:5432/mydatabase"
}
```
### API Keys Structure
```json
{
"api_key": "abcd1234-5678-90ef-ghij-klmnopqrstuv",
"api_secret": "MySecretKey123!",
"api_token": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
```
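These JSON structures map cleanly onto typed objects. A minimal sketch for the database credentials shape above, parsed via the `getSecretAsObject` helper shown earlier (the record type is hypothetical):

```java
// Requires Jackson 2.12+ for record deserialization
public record DatabaseCredentials(String engine, String host, int port,
                                  String username, String password,
                                  String dbname, String url) {}

DatabaseCredentials creds = secretsService.getSecretAsObject(
    "prod/database/credentials", DatabaseCredentials.class);
String jdbcUrl = creds.url();
```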
## Common Patterns
### Error Handling
```java
try {
String secret = secretsClient.getSecretValue(request).secretString();
} catch (SecretsManagerException e) {
if (e.awsErrorDetails().errorCode().equals("ResourceNotFoundException")) {
// Handle missing secret
}
throw e;
}
```
### Batch Operations
```java
List<String> secretNames = List.of("secret1", "secret2", "secret3");
Map<String, String> secrets = secretNames.stream()
.collect(Collectors.toMap(
Function.identity(),
name -> cache.getSecretString(name)
));
```
## Best Practices
1. **Secret Management**:
- Use descriptive secret names with hierarchical structure
- Implement versioning and rotation
- Add tags for organization and billing
2. **Caching**:
- Always use caching in production environments
- Configure appropriate TTL values based on secret sensitivity
- Monitor cache hit rates
3. **Security**:
- Never log secret values
- Use KMS encryption for sensitive secrets
- Implement least privilege IAM policies
- Enable CloudTrail logging
4. **Performance**:
- Reuse SecretsManagerClient instances
- Use async operations when appropriate (see the sketch after this list)
- Monitor API throttling limits
5. **Spring Boot Integration**:
- Use `@Value` annotations for secret names
- Implement proper exception handling
- Use configuration properties for secret names
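For the async recommendation above, a minimal sketch with the non-blocking client looks like:

```java
// Non-blocking retrieval; the future completes when the secret arrives
SecretsManagerAsyncClient asyncClient = SecretsManagerAsyncClient.builder()
    .region(Region.US_EAST_1)
    .build();
CompletableFuture<String> secret = asyncClient.getSecretValue(
        GetSecretValueRequest.builder().secretId("my-secret").build())
    .thenApply(GetSecretValueResponse::secretString);
```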
## Testing Strategies
### Unit Testing
```java
@ExtendWith(MockitoExtension.class)
class SecretsServiceTest {
@Mock
private SecretCache cache;
@InjectMocks
private SecretsService secretsService;
@Test
void shouldGetSecret() {
when(cache.getSecretString("test-secret")).thenReturn("secret-value");
String result = secretsService.getSecret("test-secret");
assertEquals("secret-value", result);
}
}
```
### Integration Testing
```java
@SpringBootTest(classes = TestSecretsConfiguration.class)
class SecretsManagerIntegrationTest {
@Autowired
private SecretsService secretsService;
@Test
void shouldRetrieveSecret() {
String secret = secretsService.getSecret("test-secret");
assertNotNull(secret);
}
}
```
## Troubleshooting
### Common Issues
- **Access Denied**: Check IAM permissions
- **Resource Not Found**: Verify secret name and region
- **Decryption Failure**: Ensure KMS key permissions
- **Throttling**: Implement retry logic and backoff
### Debug Commands
```bash
# Check secret exists
aws secretsmanager describe-secret --secret-id my-secret
# List all secrets
aws secretsmanager list-secrets
# Get secret value (CLI)
aws secretsmanager get-secret-value --secret-id my-secret
```
## References
For detailed information and advanced patterns, see:
- [API Reference](./references/api-reference.md) - Complete API documentation
- [Caching Guide](./references/caching-guide.md) - Performance optimization strategies
- [Spring Boot Integration](./references/spring-boot-integration.md) - Complete Spring integration patterns
## Related Skills
- `aws-sdk-java-v2-core` - Core AWS SDK patterns and best practices
- `aws-sdk-java-v2-kms` - KMS encryption and key management
- `spring-boot-dependency-injection` - Spring dependency injection patterns

View File

@@ -0,0 +1,38 @@
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class {{ConfigClass}} {
    @Value("${aws.secrets.region}")
    private String region;
    @Value("${aws.accessKeyId}")
    private String accessKeyId;
    @Value("${aws.secretKey}")
    private String secretKey;
    @Bean
    public SecretsManagerClient secretsManagerClient() {
        return SecretsManagerClient.builder()
            .region(Region.of(region))
            .credentialsProvider(StaticCredentialsProvider.create(
                AwsBasicCredentials.create(accessKeyId, secretKey)))
            .build();
    }
    @Bean
    public SecretCache secretCache(SecretsManagerClient secretsClient) {
        SecretCacheConfiguration config = new SecretCacheConfiguration()
            .withMaxCacheSize(100)
            .withCacheItemTTL(3600000); // 1 hour
        return new SecretCache(secretsClient, config);
    }
}

View File

@@ -0,0 +1,126 @@
# AWS Secrets Manager API Reference
## Overview
AWS Secrets Manager is a service for storing, managing, and retrieving secrets; the SDK clients below target API version 2017-10-17.
## Core Classes
### SecretsManagerClient
- **Purpose**: Synchronous client for AWS Secrets Manager
- **Location**: `software.amazon.awssdk.services.secretsmanager.SecretsManagerClient`
- **Builder**: `SecretsManagerClient.builder()`
### SecretsManagerAsyncClient
- **Purpose**: Asynchronous client for AWS Secrets Manager
- **Location**: `software.amazon.awssdk.services.secretsmanager.SecretsManagerAsyncClient`
- **Builder**: `SecretsManagerAsyncClient.builder()`
## Configuration Classes
### SecretsManagerClientBuilder
- Methods:
- `region(Region region)` - Set AWS region
- `credentialsProvider(AwsCredentialsProvider credentialsProvider)` - Set credentials
- `build()` - Create client instance
### SecretsManagerServiceClientConfiguration
- Service client settings and configuration
## Request Types
### CreateSecretRequest
- **Fields**:
- `name(String name)` - Secret name (required)
- `secretString(String secretString)` - Secret value
- `secretBinary(SdkBytes secretBinary)` - Binary secret value
- `description(String description)` - Secret description
- `kmsKeyId(String kmsKeyId)` - KMS key for encryption
- `tags(List<Tag> tags)` - Tags for organization
### GetSecretValueRequest
- **Fields**:
- `secretId(String secretId)` - Secret name or ARN
- `versionId(String versionId)` - Specific version ID
- `versionStage(String versionStage)` - Version stage (e.g., "AWSCURRENT")
### UpdateSecretRequest
- **Fields**:
- `secretId(String secretId)` - Secret name or ARN
- `secretString(String secretString)` - New secret value
- `secretBinary(SdkBytes secretBinary)` - New binary secret value
- `kmsKeyId(String kmsKeyId)` - KMS key for encryption
### DeleteSecretRequest
- **Fields**:
- `secretId(String secretId)` - Secret name or ARN
- `recoveryWindowInDays(Long recoveryWindowInDays)` - Recovery period
- `forceDeleteWithoutRecovery(Boolean forceDeleteWithoutRecovery)` - Immediate deletion
### RotateSecretRequest
- **Fields**:
- `secretId(String secretId)` - Secret name or ARN
- `rotationLambdaARN(String rotationLambdaArn)` - Lambda ARN for rotation
- `rotationRules(RotationRulesType rotationRules)` - Rotation configuration (interval in days or a schedule expression)
## Response Types
### CreateSecretResponse
- **Fields**:
- `arn()` - Secret ARN
- `name()` - Secret name
- `versionId()` - Version ID
### GetSecretValueResponse
- **Fields**:
- `arn()` - Secret ARN
- `name()` - Secret name
- `versionId()` - Version ID
- `secretString()` - Secret value as string
- `secretBinary()` - Secret value as binary
- `versionStages()` - Version stages
### UpdateSecretResponse
- **Fields**:
- `arn()` - Secret ARN
- `name()` - Secret name
- `versionId()` - New version ID
### DeleteSecretResponse
- **Fields**:
- `arn()` - Secret ARN
- `name()` - Secret name
- `deletionDate()` - Deletion date/time
### RotateSecretResponse
- **Fields**:
- `arn()` - Secret ARN
- `name()` - Secret name
- `versionId()` - New version ID
## Paginated Operations
### ListSecretsRequest
- **Fields**:
- `maxResults(Integer maxResults)` - Maximum results per page
- `nextToken(String nextToken)` - Token for next page
- `filters(List<Filter> filters)` - Filter criteria
### ListSecretsResponse
- **Fields**:
- `secretList()` - List of secrets
- `nextToken()` - Token for next page
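The SDK also generates a paginator that hides the token handling; a minimal sketch:

```java
// Iterate every secret across all pages without managing nextToken manually
secretsClient.listSecretsPaginator(ListSecretsRequest.builder().build())
    .stream()
    .flatMap(page -> page.secretList().stream())
    .forEach(secret -> System.out.println(secret.name()));
```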
## Error Handling
### SecretsManagerException
- Common error codes:
- `ResourceNotFoundException` - Secret not found
- `InvalidParameterException` - Invalid parameters
- `MalformedPolicyDocumentException` - Invalid policy document
- `InternalServiceErrorException` - Internal service error
- `InvalidRequestException` - Invalid request
- `DecryptionFailure` - Decryption failed
- `ResourceExistsException` - Resource already exists
- `ResourceConflictException` - Resource conflict
- `ValidationException` - Validation failed
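Modeled exception classes let you catch specific failures instead of string-matching error codes; a minimal sketch:

```java
try {
    secretsClient.getSecretValue(GetSecretValueRequest.builder()
        .secretId("my-secret")
        .build());
} catch (ResourceNotFoundException e) {
    // Secret does not exist in this account/region
} catch (SecretsManagerException e) {
    // Fall back to the error code for everything else
    System.err.println(e.awsErrorDetails().errorCode());
    throw e;
}
```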

View File

@@ -0,0 +1,304 @@
# AWS Secrets Manager Caching Guide
## Overview
The AWS Secrets Manager Java caching client enables in-process caching of secrets for Java applications, reducing API calls and improving performance.
## Prerequisites
- Java 8+ development environment
- AWS account with Secrets Manager access
- Appropriate IAM permissions
## Installation
### Maven Dependency
```xml
<dependency>
<groupId>com.amazonaws.secretsmanager</groupId>
<artifactId>aws-secretsmanager-caching-java</artifactId>
    <version>2.0.0</version> <!-- use the latest version compatible with SDK v2 -->
</dependency>
```
### Gradle Dependency
```gradle
implementation 'com.amazonaws.secretsmanager:aws-secretsmanager-caching-java:2.0.0'
```
## Basic Usage
### Simple Cache Setup
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
public class SimpleCacheExample {
private final SecretCache cache = new SecretCache();
public String getSecret(String secretId) {
return cache.getSecretString(secretId);
}
}
```
### Cache with Custom SecretsManagerClient
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
public class ClientAwareCacheExample {
private final SecretCache cache;
public ClientAwareCacheExample(SecretsManagerClient secretsClient) {
this.cache = new SecretCache(secretsClient);
}
public String getSecret(String secretId) {
return cache.getSecretString(secretId);
}
}
```
## Cache Configuration
### SecretCacheConfiguration
```java
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
public class ConfiguredCacheExample {
private final SecretCache cache;
public ConfiguredCacheExample(SecretsManagerClient secretsClient) {
SecretCacheConfiguration config = new SecretCacheConfiguration()
.withMaxCacheSize(1000) // Maximum number of cached secrets
.withCacheItemTTL(3600000); // 1 hour TTL in milliseconds
this.cache = new SecretCache(secretsClient, config);
}
}
```
### Configuration Options
| Property | Type | Default | Description |
|----------|------|---------|-------------|
| `maxCacheSize` | Integer | 1000 | Maximum number of cached secrets |
| `cacheItemTTL` | Long | 300000 (5 min) | Cache item TTL in milliseconds |
| `cacheSizeEvictionPercentage` | Integer | 10 | Percentage of items to evict when cache is full |
## Advanced Caching Patterns
### Multi-Layer Cache
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import java.util.concurrent.ConcurrentHashMap;
public class MultiLayerCache {
    private final SecretCache secretsManagerCache;
    private final ConcurrentHashMap<String, CachedValue> localCache;
    private final long localCacheTtlMillis = 30_000; // 30 seconds
    public MultiLayerCache(SecretsManagerClient secretsClient) {
        this.secretsManagerCache = new SecretCache(secretsClient);
        this.localCache = new ConcurrentHashMap<>();
    }
    public String getSecret(String secretId) {
        // Check the local cache first, honoring the local TTL
        CachedValue cached = localCache.get(secretId);
        if (cached != null && !cached.isExpired(localCacheTtlMillis)) {
            return cached.value();
        }
        // Fall back to the Secrets Manager cache
        String secret = secretsManagerCache.getSecretString(secretId);
        if (secret != null) {
            localCache.put(secretId, new CachedValue(secret, System.currentTimeMillis()));
        }
        return secret;
    }
    private record CachedValue(String value, long storedAt) {
        boolean isExpired(long ttlMillis) {
            return System.currentTimeMillis() - storedAt > ttlMillis;
        }
    }
}
```
### Cache Statistics
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import java.util.concurrent.atomic.AtomicLong;
// Note: the caching library does not expose hit/miss statistics directly.
// A thin wrapper can approximate them: cache hits return in microseconds,
// while misses incur a network call (the 5 ms threshold is a heuristic).
public class InstrumentedSecretCache {
    private final SecretCache cache;
    private final AtomicLong lookups = new AtomicLong();
    private final AtomicLong slowLookups = new AtomicLong();
    public InstrumentedSecretCache(SecretCache cache) {
        this.cache = cache;
    }
    public String getSecretString(String secretId) {
        long start = System.nanoTime();
        String value = cache.getSecretString(secretId);
        lookups.incrementAndGet();
        if (System.nanoTime() - start > 5_000_000L) {
            slowLookups.incrementAndGet();
        }
        return value;
    }
    public double estimatedHitRatio() {
        long total = lookups.get();
        return total == 0 ? 1.0 : 1.0 - (double) slowLookups.get() / total;
    }
}
```
## Error Handling and Cache Management
### Cache Refresh Strategy
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
public class CacheRefreshManager {
    private final SecretCache cache;
    private final List<String> secretIds;
    private final ScheduledExecutorService scheduler;
    public CacheRefreshManager(SecretsManagerClient secretsClient, List<String> secretIds) {
        this.cache = new SecretCache(secretsClient);
        this.secretIds = secretIds;
        this.scheduler = Executors.newScheduledThreadPool(1);
    }
    public void startRefreshSchedule() {
        // Refresh the tracked secrets every hour
        scheduler.scheduleAtFixedRate(this::refreshCache, 1, 1, TimeUnit.HOURS);
    }
    private void refreshCache() {
        // The cache refreshes per secret; there is no global refresh method
        for (String secretId : secretIds) {
            try {
                cache.refreshNow(secretId);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }
    public void shutdown() {
        scheduler.shutdown();
    }
}
```
### Fallback Mechanism
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
public class FallbackCacheExample {
private final SecretCache cache;
private final SecretsManagerClient fallbackClient;
public FallbackCacheExample(SecretsManagerClient primaryClient, SecretsManagerClient fallbackClient) {
this.cache = new SecretCache(primaryClient);
this.fallbackClient = fallbackClient;
}
public String getSecretWithFallback(String secretId) {
try {
// Try cached value first
return cache.getSecretString(secretId);
} catch (Exception e) {
// Fallback to direct API call
return getSecretDirect(secretId);
}
}
private String getSecretDirect(String secretId) {
GetSecretValueRequest request = GetSecretValueRequest.builder()
.secretId(secretId)
.build();
return fallbackClient.getSecretValue(request).secretString();
}
}
```
## Performance Optimization
### Batch Secret Retrieval
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import java.util.List;
import java.util.ArrayList;
public class BatchSecretRetrieval {
private final SecretCache cache;
public List<String> getMultipleSecrets(List<String> secretIds) {
List<String> results = new ArrayList<>();
for (String secretId : secretIds) {
String secret = cache.getSecretString(secretId);
results.add(secret != null ? secret : "NOT_FOUND");
}
return results;
}
public Map<String, String> getSecretsAsMap(List<String> secretIds) {
Map<String, String> secretMap = new HashMap<>();
for (String secretId : secretIds) {
String secret = cache.getSecretString(secretId);
if (secret != null) {
secretMap.put(secretId, secret);
}
}
return secretMap;
}
}
```
## Monitoring and Debugging
### Cache Monitoring
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
public class CacheMonitor {
    // Wrapper from the Cache Statistics example above; the library itself
    // exposes no hit/miss or size metrics
    private final InstrumentedSecretCache cache;
    public CacheMonitor(InstrumentedSecretCache cache) {
        this.cache = cache;
    }
    public void monitorCachePerformance() {
        // Report the wrapper's latency-based estimate
        System.out.println("Estimated Cache Hit Ratio: " + cache.estimatedHitRatio());
    }
    public void printCacheContents() {
        // Note: SecretCache doesn't provide direct access to all cached items.
        // This is a security feature to prevent accidental exposure of secrets.
        System.out.println("Cache contents are protected and cannot be directly inspected");
    }
}
}
```
## Best Practices
1. **Cache Size Configuration**:
- Adjust `maxCacheSize` based on available memory
- Monitor memory usage and adjust accordingly
- Consider using heap analysis tools
2. **TTL Configuration**:
- Balance between performance and freshness
- Shorter TTL for frequently changing secrets
- Longer TTL for stable secrets
3. **Error Handling**:
- Implement fallback mechanisms
- Handle cache misses gracefully
- Log errors without exposing sensitive information
4. **Security Considerations**:
- Never log secret values
- Use appropriate IAM permissions
- Consider encryption at rest for cached data
5. **Memory Management**:
- Monitor memory usage
- Consider cache eviction strategies
- Implement proper cleanup in shutdown hooks
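For the shutdown-hook recommendation above, a minimal sketch (`SecretCache` implements `AutoCloseable`):

```java
SecretCache cache = new SecretCache();
// Release the cache's background resources when the JVM exits
Runtime.getRuntime().addShutdownHook(new Thread(cache::close));
```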

View File

@@ -0,0 +1,535 @@
# AWS Secrets Manager Spring Boot Integration
## Overview
Integrate AWS Secrets Manager with Spring Boot applications using the caching library for optimal performance and security.
## Dependencies
### Required Dependencies
```xml
<!-- AWS Secrets Manager -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>secretsmanager</artifactId>
</dependency>
<!-- AWS Secrets Manager Caching -->
<dependency>
<groupId>com.amazonaws.secretsmanager</groupId>
<artifactId>aws-secretsmanager-caching-java</artifactId>
<version>2.0.0</version> <!-- Use the latest version compatible with SDK v2 -->
</dependency>
<!-- Spring Boot Starter -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- Jackson for JSON processing -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<!-- Connection Pooling -->
<dependency>
<groupId>com.zaxxer</groupId>
<artifactId>HikariCP</artifactId>
</dependency>
```
## Configuration Properties
### application.yml
```yaml
spring:
  application:
    name: aws-secrets-manager-app
  datasource:
    url: jdbc:postgresql://localhost:5432/mydb
    username: ${db.username}
    password: ${db.password}
    hikari:
      maximum-pool-size: 10
      minimum-idle: 5
aws:
  secrets:
    region: us-east-1
    # Secret names for different environments
    database-credentials: prod/database/credentials
    api-keys: prod/external-api/keys
    redis-config: prod/redis/config
app:
  external-api:
    secret-name: prod/external/credentials
    base-url: https://api.example.com
```
## Core Components
### SecretsManager Configuration
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
@Configuration
public class SecretsManagerConfiguration {
@Value("${aws.secrets.region}")
private String region;
@Bean
public SecretsManagerClient secretsManagerClient() {
return SecretsManagerClient.builder()
.region(Region.of(region))
.build();
}
@Bean
public SecretCache secretCache(SecretsManagerClient secretsClient) {
        // SecretCacheConfiguration uses with* setters rather than a builder
        SecretCacheConfiguration config = new SecretCacheConfiguration()
                .withClient(secretsClient)
                .withMaxCacheSize(100)
                .withCacheItemTTL(3_600_000L); // 1 hour
        return new SecretCache(config);
}
}
```
### Secrets Service
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.stereotype.Service;
import java.util.Map;
@Service
public class SecretsService {
private final SecretCache secretCache;
private final ObjectMapper objectMapper;
public SecretsService(SecretCache secretCache, ObjectMapper objectMapper) {
this.secretCache = secretCache;
this.objectMapper = objectMapper;
}
/**
* Get secret as string
*/
public String getSecret(String secretName) {
try {
return secretCache.getSecretString(secretName);
} catch (Exception e) {
throw new RuntimeException("Failed to retrieve secret: " + secretName, e);
}
}
/**
* Get secret as object of specified type
*/
public <T> T getSecretAsObject(String secretName, Class<T> type) {
try {
String secretJson = secretCache.getSecretString(secretName);
return objectMapper.readValue(secretJson, type);
} catch (Exception e) {
throw new RuntimeException("Failed to parse secret: " + secretName, e);
}
}
/**
* Get secret as Map
*/
public Map<String, String> getSecretAsMap(String secretName) {
try {
String secretJson = secretCache.getSecretString(secretName);
return objectMapper.readValue(secretJson,
new TypeReference<Map<String, String>>() {});
} catch (Exception e) {
throw new RuntimeException("Failed to parse secret map: " + secretName, e);
}
}
/**
* Get secret with fallback
*/
public String getSecretWithFallback(String secretName, String defaultValue) {
try {
String secret = secretCache.getSecretString(secretName);
return secret != null ? secret : defaultValue;
} catch (Exception e) {
return defaultValue;
}
}
}
```
## Database Configuration Integration
### Dynamic DataSource Configuration
```java
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import javax.sql.DataSource;
import java.util.Map;
@Configuration
public class DatabaseConfiguration {
private final SecretsService secretsService;
@Value("${aws.secrets.database-credentials}")
private String dbSecretName;
public DatabaseConfiguration(SecretsService secretsService) {
this.secretsService = secretsService;
}
@Bean
public DataSource dataSource() {
Map<String, String> credentials = secretsService.getSecretAsMap(dbSecretName);
HikariConfig config = new HikariConfig();
config.setJdbcUrl(credentials.get("url"));
config.setUsername(credentials.get("username"));
config.setPassword(credentials.get("password"));
config.setMaximumPoolSize(10);
config.setMinimumIdle(5);
config.setConnectionTimeout(30000);
config.setIdleTimeout(600000);
config.setMaxLifetime(1800000);
config.setLeakDetectionThreshold(15000);
return new HikariDataSource(config);
}
}
```
### Configuration Properties with Secrets
```java
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
@Component
@ConfigurationProperties(prefix = "app")
public class AppProperties {
    private final SecretsService secretsService;
    @Value("${app.external-api.secret-name}")
    private String apiSecretName;
    public AppProperties(SecretsService secretsService) {
        this.secretsService = secretsService;
    }
    private volatile String apiKey;
    public String getApiKey() {
        // Lazily resolve the key; volatile keeps the cached read safe across threads
        if (apiKey == null) {
            apiKey = secretsService.getSecret(apiSecretName);
        }
        return apiKey;
    }
    public String getExternalApiSecretName() {
        return apiSecretName;
    }
    // Additional application properties
    private String externalApiBaseUrl;
    public String getExternalApiBaseUrl() {
        return externalApiBaseUrl;
    }
    public void setExternalApiBaseUrl(String externalApiBaseUrl) {
        this.externalApiBaseUrl = externalApiBaseUrl;
    }
}
```
## Property Source Integration
### Custom Property Source
```java
import jakarta.annotation.PostConstruct;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.EnumerablePropertySource;
import org.springframework.core.env.Environment;
import org.springframework.core.env.PropertySource;
import org.springframework.stereotype.Component;
@Component
public class SecretsManagerPropertySource extends PropertySource<SecretsService> {
    public static final String SECRETS_MANAGER_PROPERTY_SOURCE_NAME = "secretsManagerPropertySource";
    private final SecretsService secretsService;
    private final Environment environment;
    public SecretsManagerPropertySource(SecretsService secretsService, Environment environment) {
        super(SECRETS_MANAGER_PROPERTY_SOURCE_NAME, secretsService);
        this.secretsService = secretsService;
        this.environment = environment;
    }
    @PostConstruct
    public void loadSecrets() {
        // Environment exposes no getPropertyNames(); enumerate the names through
        // the EnumerablePropertySources of a ConfigurableEnvironment instead
        if (!(environment instanceof ConfigurableEnvironment configurable)) {
            return;
        }
        String secretPrefix = "aws.secrets.";
        for (PropertySource<?> source : configurable.getPropertySources()) {
            if (source instanceof EnumerablePropertySource<?> enumerable) {
                for (String propertyName : enumerable.getPropertyNames()) {
                    if (propertyName.startsWith(secretPrefix)) {
                        // Resolve eagerly so misconfigured secrets fail at startup
                        secretsService.getSecret(environment.getProperty(propertyName));
                    }
                }
            }
        }
    }
@Override
public Object getProperty(String name) {
if (name.startsWith("aws.secret.")) {
String secretName = name.substring("aws.secret.".length());
return secretsService.getSecret(secretName);
}
return null;
}
}
```
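
Pre-resolving in `@PostConstruct` does not by itself make `aws.secret.*` names resolvable through the `Environment`; the source must also be registered with the environment's `PropertySources`. A minimal sketch, assuming the class above (registered this way, the source only serves lookups performed after startup):
```java
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.stereotype.Component;

@Component
public class SecretsPropertySourceRegistrar {
    public SecretsPropertySourceRegistrar(ConfigurableEnvironment environment,
                                          SecretsService secretsService) {
        // Put the source first so aws.secret.* names win over other sources
        environment.getPropertySources().addFirst(
                new SecretsManagerPropertySource(secretsService, environment));
    }
}
```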
## API Integration
### REST Client with Secrets
```java
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import java.util.Map;
@Service
public class ExternalApiClient {
private final SecretsService secretsService;
private final RestTemplate restTemplate;
private final AppProperties appProperties;
public ExternalApiClient(SecretsService secretsService,
RestTemplate restTemplate,
AppProperties appProperties) {
this.secretsService = secretsService;
this.restTemplate = restTemplate;
this.appProperties = appProperties;
}
public String callExternalApi(String endpoint) {
Map<String, String> apiCredentials = secretsService.getSecretAsMap(
appProperties.getExternalApiSecretName());
HttpHeaders headers = new HttpHeaders();
headers.set("Authorization", "Bearer " + apiCredentials.get("api_token"));
headers.set("X-API-Key", apiCredentials.get("api_key"));
headers.set("Content-Type", "application/json");
HttpEntity<String> entity = new HttpEntity<>(headers);
ResponseEntity<String> response = restTemplate.exchange(
endpoint,
HttpMethod.GET,
entity,
String.class);
return response.getBody();
}
}
```
### Configuration for REST Template
```java
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;
@Configuration
public class RestTemplateConfiguration {
@Bean
public RestTemplate restTemplate() {
return new RestTemplate();
}
}
```
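The default `RestTemplate` above has no timeouts, so a hung external API blocks callers indefinitely (see the circuit-breaker note under Best Practices). A hedged alternative sketch with explicit timeouts:
```java
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.web.client.RestTemplate;

@Configuration
public class TimeoutRestTemplateConfiguration {
    @Bean
    public RestTemplate restTemplate() {
        SimpleClientHttpRequestFactory factory = new SimpleClientHttpRequestFactory();
        factory.setConnectTimeout(2_000); // ms to establish the connection
        factory.setReadTimeout(5_000);    // ms to wait for response data
        return new RestTemplate(factory);
    }
}
```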
## Security Configuration
### Security Setup
```java
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.web.SecurityFilterChain;
@Configuration
@EnableWebSecurity
public class SecurityConfiguration {
@Bean
public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
http
.authorizeHttpRequests(auth -> auth
.requestMatchers("/api/secrets/**").hasRole("ADMIN")
.anyRequest().permitAll()
)
            // Spring Security 6 removed the and() chaining style; use the lambda DSL
            .httpBasic(Customizer.withDefaults())
            .csrf(csrf -> csrf.disable());
return http.build();
}
}
```
## Testing Configuration
### Test Configuration
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Primary;
import org.springframework.mock.env.MockEnvironment;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;
import static org.mockito.Mockito.*;
@TestConfiguration
public class TestSecretsConfiguration {
@Bean
@Primary
public SecretsManagerClient secretsManagerClient() {
SecretsManagerClient mockClient = mock(SecretsManagerClient.class);
        // Mock successful secret retrieval; the explicit matcher type
        // disambiguates the overloaded getSecretValue methods
        when(mockClient.getSecretValue(any(GetSecretValueRequest.class)))
.thenReturn(GetSecretValueResponse.builder()
.secretString("{\"username\":\"test\",\"password\":\"testpass\"}")
.build());
return mockClient;
}
@Bean
@Primary
public SecretCache secretCache(SecretsManagerClient mockClient) {
SecretCache mockCache = mock(SecretCache.class);
when(mockCache.getSecretString(anyString()))
.thenReturn("{\"username\":\"test\",\"password\":\"testpass\"}");
return mockCache;
}
@Bean
public MockEnvironment mockEnvironment() {
MockEnvironment env = new MockEnvironment();
env.setProperty("aws.secrets.region", "us-east-1");
env.setProperty("aws.secrets.database-credentials", "test-db-credentials");
return env;
}
}
```
### Unit Tests
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.Map;
import static org.mockito.Mockito.*;
import static org.junit.jupiter.api.Assertions.*;
@ExtendWith(MockitoExtension.class)
class SecretsServiceTest {
    @Mock
    private SecretCache secretCache;
    @Spy
    private ObjectMapper objectMapper = new ObjectMapper(); // real mapper so JSON parsing works
    @InjectMocks
    private SecretsService secretsService;
@Test
void shouldGetSecret() {
String secretName = "test-secret";
String expectedValue = "secret-value";
when(secretCache.getSecretString(secretName))
.thenReturn(expectedValue);
String result = secretsService.getSecret(secretName);
assertEquals(expectedValue, result);
verify(secretCache).getSecretString(secretName);
}
@Test
void shouldGetSecretAsMap() throws Exception {
String secretName = "test-secret";
String secretJson = "{\"key\":\"value\"}";
Map<String, String> expectedMap = Map.of("key", "value");
when(secretCache.getSecretString(secretName))
.thenReturn(secretJson);
Map<String, String> result = secretsService.getSecretAsMap(secretName);
assertEquals(expectedMap, result);
}
}
```
## Best Practices
1. **Environment-Specific Configuration**:
- Use different secret names for development, staging, and production
- Implement proper environment variable management
- Use Spring profiles for environment-specific configurations
2. **Security Considerations**:
- Never log secret values
- Use appropriate IAM roles and policies
- Enable encryption in transit and at rest
- Implement proper access controls
3. **Performance Optimization**:
- Use caching for frequently accessed secrets
- Configure appropriate TTL values
- Monitor cache hit rates and adjust accordingly
- Use connection pooling for database connections
4. **Error Handling**:
- Implement fallback mechanisms for critical secrets
- Handle partial secret retrieval gracefully
- Provide meaningful error messages without exposing sensitive information
- Implement circuit breakers for external API calls
5. **Monitoring and Logging**:
- Monitor secret retrieval performance
- Track cache hit/miss ratios
- Log secret access patterns (without values)
- Set up alerts for abnormal secret access patterns
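
As a small illustration of logging access patterns without exposing values (assuming SLF4J is on the classpath):
```java
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SecretAccessLogger {
    private static final Logger log = LoggerFactory.getLogger(SecretAccessLogger.class);

    public void logAccess(String secretName, boolean found) {
        // Log only the name and outcome; never log the secret value itself
        log.info("Secret access: name={}, found={}", secretName, found);
    }
}
```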

View File

@@ -0,0 +1,336 @@
---
name: unit-test-application-events
description: Testing Spring application events (ApplicationEvent) with @EventListener and ApplicationEventPublisher. Test event publishing, listening, and async event handling in Spring Boot applications. Use when validating event-driven workflows in your Spring Boot services.
category: testing
tags: [junit-5, application-events, event-driven, listeners, publishers]
version: 1.0.1
---
# Unit Testing Application Events
Test Spring ApplicationEvent publishers and event listeners using JUnit 5. Verify event publishing, listener execution, and event propagation without full context startup.
## When to Use This Skill
Use this skill when:
- Testing ApplicationEventPublisher event publishing
- Testing @EventListener method invocation
- Verifying event listener logic and side effects
- Testing event propagation through listeners
- Want fast event-driven architecture tests
- Testing both synchronous and asynchronous event handling
## Setup: Event Testing
### Maven
```xml
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
```
### Gradle
```kotlin
dependencies {
implementation("org.springframework.boot:spring-boot-starter")
testImplementation("org.junit.jupiter:junit-jupiter")
testImplementation("org.mockito:mockito-core")
testImplementation("org.assertj:assertj-core")
}
```
## Basic Pattern: Event Publishing and Listening
### Custom Event and Publisher
```java
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
// Custom application event
public class UserCreatedEvent extends ApplicationEvent {
private final User user;
public UserCreatedEvent(Object source, User user) {
super(source);
this.user = user;
}
public User getUser() {
return user;
}
}
// Service that publishes events
@Service
public class UserService {
private final ApplicationEventPublisher eventPublisher;
private final UserRepository userRepository;
public UserService(ApplicationEventPublisher eventPublisher, UserRepository userRepository) {
this.eventPublisher = eventPublisher;
this.userRepository = userRepository;
}
public User createUser(String name, String email) {
User user = new User(name, email);
User savedUser = userRepository.save(user);
eventPublisher.publishEvent(new UserCreatedEvent(this, savedUser));
return savedUser;
}
}
// Unit test
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class UserServiceEventTest {
@Mock
private ApplicationEventPublisher eventPublisher;
@Mock
private UserRepository userRepository;
@InjectMocks
private UserService userService;
@Test
void shouldPublishUserCreatedEvent() {
User newUser = new User(1L, "Alice", "alice@example.com");
when(userRepository.save(any(User.class))).thenReturn(newUser);
ArgumentCaptor<UserCreatedEvent> eventCaptor = ArgumentCaptor.forClass(UserCreatedEvent.class);
userService.createUser("Alice", "alice@example.com");
verify(eventPublisher).publishEvent(eventCaptor.capture());
UserCreatedEvent capturedEvent = eventCaptor.getValue();
assertThat(capturedEvent.getUser()).isEqualTo(newUser);
}
}
```
## Testing Event Listeners
### @EventListener Annotation
```java
// Event listener
@Component
public class UserEventListener {
private final EmailService emailService;
public UserEventListener(EmailService emailService) {
this.emailService = emailService;
}
@EventListener
public void onUserCreated(UserCreatedEvent event) {
User user = event.getUser();
emailService.sendWelcomeEmail(user.getEmail());
}
}
// Unit test for listener
class UserEventListenerTest {
@Test
void shouldSendWelcomeEmailWhenUserCreated() {
EmailService emailService = mock(EmailService.class);
UserEventListener listener = new UserEventListener(emailService);
User newUser = new User(1L, "Alice", "alice@example.com");
UserCreatedEvent event = new UserCreatedEvent(this, newUser);
listener.onUserCreated(event);
verify(emailService).sendWelcomeEmail("alice@example.com");
}
@Test
void shouldNotThrowExceptionWhenEmailServiceFails() {
EmailService emailService = mock(EmailService.class);
doThrow(new RuntimeException("Email service down"))
.when(emailService).sendWelcomeEmail(any());
UserEventListener listener = new UserEventListener(emailService);
User newUser = new User(1L, "Alice", "alice@example.com");
UserCreatedEvent event = new UserCreatedEvent(this, newUser);
// Should handle exception gracefully
assertThatCode(() -> listener.onUserCreated(event))
.doesNotThrowAnyException();
}
}
```
## Testing Multiple Listeners
### Event Propagation
```java
class UserCreatedEvent extends ApplicationEvent {
private final User user;
private final List<String> notifications = new ArrayList<>();
public UserCreatedEvent(Object source, User user) {
super(source);
this.user = user;
}
public void addNotification(String notification) {
notifications.add(notification);
}
public List<String> getNotifications() {
return notifications;
}
}
class MultiListenerTest {
@Test
void shouldNotifyMultipleListenersSequentially() {
EmailService emailService = mock(EmailService.class);
NotificationService notificationService = mock(NotificationService.class);
AnalyticsService analyticsService = mock(AnalyticsService.class);
        // One listener class per downstream service (NotificationEventListener
        // and AnalyticsEventListener are hypothetical, mirroring UserEventListener)
        UserEventListener emailListener = new UserEventListener(emailService);
        NotificationEventListener notificationListener = new NotificationEventListener(notificationService);
        AnalyticsEventListener analyticsListener = new AnalyticsEventListener(analyticsService);
        User user = new User(1L, "Alice", "alice@example.com");
        UserCreatedEvent event = new UserCreatedEvent(this, user);
        emailListener.onUserCreated(event);
        notificationListener.onUserCreated(event);
        analyticsListener.onUserCreated(event);
        verify(emailService).sendWelcomeEmail(user.getEmail());
        verify(notificationService).notify(user);
        verify(analyticsService).track(user);
}
}
```
## Testing Conditional Event Listeners
### @EventListener with Condition
```java
@Component
public class ConditionalEventListener {
@EventListener(condition = "#event.user.age > 18")
public void onAdultUserCreated(UserCreatedEvent event) {
// Handle adult user
}
}
class ConditionalListenerTest {
    @Test
    void shouldProcessEventWhenConditionMatches() {
        // The SpEL condition is evaluated by Spring's event infrastructure, not
        // by a direct method call; a plain unit test can only exercise the
        // listener body, so invoke it with an event that satisfies the condition
    }
    @Test
    void shouldSkipEventWhenConditionDoesNotMatch() {
        // To exercise the condition itself, publish through a real
        // ApplicationContext (e.g. ApplicationContextRunner) and verify the
        // listener was not invoked for a non-matching event
    }
}
```
## Testing Async Event Listeners
### @Async with @EventListener
```java
@Component
public class AsyncEventListener {
    private final SlowService slowService;
    public AsyncEventListener(SlowService slowService) {
        this.slowService = slowService;
    }
    @EventListener
    @Async
    public void onUserCreatedAsync(UserCreatedEvent event) {
        slowService.processUser(event.getUser());
    }
}
class AsyncEventListenerTest {
@Test
void shouldProcessEventAsynchronously() throws Exception {
SlowService slowService = mock(SlowService.class);
AsyncEventListener listener = new AsyncEventListener(slowService);
User user = new User(1L, "Alice", "alice@example.com");
UserCreatedEvent event = new UserCreatedEvent(this, user);
        // Called directly (no Spring proxy), @Async does not apply and the
        // method runs synchronously, so no waiting is needed here
        listener.onUserCreatedAsync(event);
        verify(slowService).processUser(user);
}
}
```
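When the listener is invoked through a real Spring context so `@Async` actually applies, prefer Awaitility over `Thread.sleep`. A sketch, assuming the `org.awaitility:awaitility` test dependency and an `eventPublisher` obtained from a context configured with `@EnableAsync`:
```java
import org.junit.jupiter.api.Test;
import java.time.Duration;
import static org.awaitility.Awaitility.await;
import static org.mockito.Mockito.*;

class AsyncEventListenerAwaitilityTest {
    @Test
    void shouldEventuallyProcessUserWhenListenerRunsAsync() {
        SlowService slowService = mock(SlowService.class);
        // eventPublisher is assumed to come from a context with @EnableAsync,
        // so the listener really runs on another thread
        eventPublisher.publishEvent(new UserCreatedEvent(this, new User(1L, "Alice", "alice@example.com")));
        // Poll until the async listener has run, failing after 2 seconds
        await().atMost(Duration.ofSeconds(2))
               .untilAsserted(() -> verify(slowService).processUser(any()));
    }
}
```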
## Best Practices
- **Mock ApplicationEventPublisher** in unit tests
- **Capture published events** using ArgumentCaptor
- **Test listener side effects** explicitly
- **Test error handling** in listeners
- **Keep event listeners focused** on single responsibility
- **Verify event data integrity** when capturing
- **Test both sync and async** event processing
## Common Pitfalls
- Testing actual event publishing without mocking publisher
- Not verifying listener invocation
- Not capturing event details
- Testing listener registration instead of logic
- Not handling listener exceptions
## Troubleshooting
**Event not being captured**: Verify ArgumentCaptor type matches event class.
**Listener not invoked**: Ensure event is actually published and listener is registered.
**Async listener timing issues**: Prefer Awaitility (or a CountDownLatch) over fixed Thread.sleep() calls when waiting for async completion.
## References
- [Spring ApplicationEvent](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/context/ApplicationEvent.html)
- [Spring ApplicationEventPublisher](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/context/ApplicationEventPublisher.html)
- [@EventListener Documentation](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/context/event/EventListener.html)

View File

@@ -0,0 +1,476 @@
---
name: unit-test-bean-validation
description: Unit testing Jakarta Bean Validation (@Valid, @NotNull, @Min, @Max, etc.) with custom validators and constraint violations. Test validation logic without Spring context. Use when ensuring data integrity and validation rules are correct.
category: testing
tags: [junit-5, validation, bean-validation, jakarta-validation, constraints]
version: 1.0.1
---
# Unit Testing Bean Validation and Custom Validators
Test validation annotations and custom validator implementations using JUnit 5. Verify constraint violations, error messages, and validation logic in isolation.
## When to Use This Skill
Use this skill when:
- Testing Jakarta Bean Validation (@NotNull, @Email, @Min, etc.)
- Testing custom @Constraint validators
- Verifying constraint violation error messages
- Testing cross-field validation logic
- Want fast validation tests without Spring context
- Testing complex validation scenarios and edge cases
## Setup: Bean Validation
### Maven
```xml
<dependency>
<groupId>jakarta.validation</groupId>
<artifactId>jakarta.validation-api</artifactId>
</dependency>
<dependency>
<groupId>org.hibernate.validator</groupId>
<artifactId>hibernate-validator</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
```
### Gradle
```kotlin
dependencies {
implementation("jakarta.validation:jakarta.validation-api")
testImplementation("org.hibernate.validator:hibernate-validator")
testImplementation("org.junit.jupiter:junit-jupiter")
testImplementation("org.assertj:assertj-core")
}
```
## Basic Pattern: Testing Validation Constraints
### Setup Validator
```java
import jakarta.validation.Validator;
import jakarta.validation.ValidatorFactory;
import jakarta.validation.Validation;
import jakarta.validation.ConstraintViolation;
import java.util.Set;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.*;
class UserValidationTest {
private Validator validator;
@BeforeEach
void setUp() {
ValidatorFactory factory = Validation.buildDefaultValidatorFactory();
validator = factory.getValidator();
}
@Test
void shouldPassValidationWithValidUser() {
User user = new User("Alice", "alice@example.com", 25);
Set<ConstraintViolation<User>> violations = validator.validate(user);
assertThat(violations).isEmpty();
}
@Test
void shouldFailValidationWhenNameIsNull() {
User user = new User(null, "alice@example.com", 25);
Set<ConstraintViolation<User>> violations = validator.validate(user);
assertThat(violations)
.hasSize(1)
.extracting(ConstraintViolation::getMessage)
.contains("must not be blank");
}
}
```
## Testing Individual Constraint Annotations
### Test @NotNull, @NotBlank, @Email
```java
class UserDtoTest {
private Validator validator;
@BeforeEach
void setUp() {
validator = Validation.buildDefaultValidatorFactory().getValidator();
}
@Test
void shouldFailWhenEmailIsInvalid() {
UserDto dto = new UserDto("Alice", "invalid-email");
Set<ConstraintViolation<UserDto>> violations = validator.validate(dto);
assertThat(violations)
.extracting(ConstraintViolation::getPropertyPath)
.extracting(Path::toString)
.contains("email");
assertThat(violations)
.extracting(ConstraintViolation::getMessage)
.contains("must be a valid email address");
}
@Test
void shouldFailWhenNameIsBlank() {
UserDto dto = new UserDto(" ", "alice@example.com");
Set<ConstraintViolation<UserDto>> violations = validator.validate(dto);
assertThat(violations)
.extracting(ConstraintViolation::getPropertyPath)
.extracting(Path::toString)
.contains("name");
}
@Test
void shouldFailWhenAgeIsNegative() {
UserDto dto = new UserDto("Alice", "alice@example.com", -5);
Set<ConstraintViolation<UserDto>> violations = validator.validate(dto);
assertThat(violations)
.extracting(ConstraintViolation::getMessage)
.contains("must be greater than or equal to 0");
}
@Test
void shouldPassWhenAllConstraintsSatisfied() {
UserDto dto = new UserDto("Alice", "alice@example.com", 25);
Set<ConstraintViolation<UserDto>> violations = validator.validate(dto);
assertThat(violations).isEmpty();
}
}
```
## Testing @Min, @Max, @Size Constraints
```java
class ProductDtoTest {
private Validator validator;
@BeforeEach
void setUp() {
validator = Validation.buildDefaultValidatorFactory().getValidator();
}
@Test
void shouldFailWhenPriceIsBelowMinimum() {
ProductDto product = new ProductDto("Laptop", -100.0);
Set<ConstraintViolation<ProductDto>> violations = validator.validate(product);
assertThat(violations)
.extracting(ConstraintViolation::getMessage)
.contains("must be greater than 0");
}
@Test
void shouldFailWhenQuantityExceedsMaximum() {
ProductDto product = new ProductDto("Laptop", 1000.0, 999999);
Set<ConstraintViolation<ProductDto>> violations = validator.validate(product);
assertThat(violations)
.extracting(ConstraintViolation::getMessage)
.contains("must be less than or equal to 10000");
}
@Test
void shouldFailWhenDescriptionTooLong() {
String longDescription = "x".repeat(1001);
ProductDto product = new ProductDto("Laptop", 1000.0, longDescription);
Set<ConstraintViolation<ProductDto>> violations = validator.validate(product);
assertThat(violations)
.extracting(ConstraintViolation::getMessage)
.contains("size must be between 0 and 1000");
}
}
```
## Testing Custom Validators
### Create and Test Custom Constraint
```java
// Custom constraint annotation
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@Constraint(validatedBy = PhoneNumberValidator.class)
public @interface ValidPhoneNumber {
String message() default "invalid phone number format";
Class<?>[] groups() default {};
Class<? extends Payload>[] payload() default {};
}
// Custom validator implementation
public class PhoneNumberValidator implements ConstraintValidator<ValidPhoneNumber, String> {
private static final String PHONE_PATTERN = "^\\d{3}-\\d{3}-\\d{4}$";
@Override
public boolean isValid(String value, ConstraintValidatorContext context) {
if (value == null) return true; // null values handled by @NotNull
return value.matches(PHONE_PATTERN);
}
}
// Unit test for custom validator
class PhoneNumberValidatorTest {
private Validator validator;
@BeforeEach
void setUp() {
validator = Validation.buildDefaultValidatorFactory().getValidator();
}
@Test
void shouldAcceptValidPhoneNumber() {
Contact contact = new Contact("Alice", "555-123-4567");
Set<ConstraintViolation<Contact>> violations = validator.validate(contact);
assertThat(violations).isEmpty();
}
@Test
void shouldRejectInvalidPhoneNumberFormat() {
Contact contact = new Contact("Alice", "5551234567"); // No dashes
Set<ConstraintViolation<Contact>> violations = validator.validate(contact);
assertThat(violations)
.extracting(ConstraintViolation::getMessage)
.contains("invalid phone number format");
}
@Test
void shouldRejectPhoneNumberWithLetters() {
Contact contact = new Contact("Alice", "ABC-DEF-GHIJ");
Set<ConstraintViolation<Contact>> violations = validator.validate(contact);
assertThat(violations).isNotEmpty();
}
@Test
void shouldAllowNullPhoneNumber() {
Contact contact = new Contact("Alice", null);
Set<ConstraintViolation<Contact>> violations = validator.validate(contact);
assertThat(violations).isEmpty();
}
}
```
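The validator class can also be unit tested directly, without building a `Validator` factory at all. `isValid` ignores its context argument here, so passing `null` is safe in this sketch:
```java
class PhoneNumberValidatorDirectTest {
    @Test
    void shouldValidateFormatWithoutValidatorFactory() {
        PhoneNumberValidator validator = new PhoneNumberValidator();
        assertThat(validator.isValid("555-123-4567", null)).isTrue();
        assertThat(validator.isValid("5551234567", null)).isFalse();
        // Nulls are deliberately valid here; @NotNull handles them separately
        assertThat(validator.isValid(null, null)).isTrue();
    }
}
```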
## Testing Cross-Field Validation
### Custom Multi-Field Constraint
```java
// Custom constraint for cross-field validation
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Constraint(validatedBy = PasswordMatchValidator.class)
public @interface PasswordsMatch {
String message() default "passwords do not match";
Class<?>[] groups() default {};
Class<? extends Payload>[] payload() default {};
}
// Validator implementation
public class PasswordMatchValidator implements ConstraintValidator<PasswordsMatch, ChangePasswordRequest> {
@Override
public boolean isValid(ChangePasswordRequest value, ConstraintValidatorContext context) {
if (value == null) return true;
return value.getNewPassword().equals(value.getConfirmPassword());
}
}
// Unit test
class PasswordValidationTest {
private Validator validator;
@BeforeEach
void setUp() {
validator = Validation.buildDefaultValidatorFactory().getValidator();
}
@Test
void shouldPassWhenPasswordsMatch() {
ChangePasswordRequest request = new ChangePasswordRequest("oldPass", "newPass123", "newPass123");
Set<ConstraintViolation<ChangePasswordRequest>> violations = validator.validate(request);
assertThat(violations).isEmpty();
}
@Test
void shouldFailWhenPasswordsDoNotMatch() {
ChangePasswordRequest request = new ChangePasswordRequest("oldPass", "newPass123", "differentPass");
Set<ConstraintViolation<ChangePasswordRequest>> violations = validator.validate(request);
assertThat(violations)
.extracting(ConstraintViolation::getMessage)
.contains("passwords do not match");
}
}
```
## Testing Validation Groups
### Conditional Validation
```java
// Validation groups are plain marker interfaces; @Target/@Retention apply
// only to annotation types and would not compile here
public interface CreateValidation {}
public interface UpdateValidation {}
class UserDto {
    @NotNull(groups = CreateValidation.class)
    private String name;
    @Min(value = 1, groups = {CreateValidation.class, UpdateValidation.class})
    private int age;
    UserDto(String name, int age) {
        this.name = name;
        this.age = age;
    }
}
class ValidationGroupsTest {
private Validator validator;
@BeforeEach
void setUp() {
validator = Validation.buildDefaultValidatorFactory().getValidator();
}
@Test
void shouldRequireNameOnlyDuringCreation() {
UserDto user = new UserDto(null, 25);
Set<ConstraintViolation<UserDto>> violations = validator.validate(user, CreateValidation.class);
assertThat(violations)
.extracting(ConstraintViolation::getPropertyPath)
.extracting(Path::toString)
.contains("name");
}
@Test
void shouldAllowNullNameDuringUpdate() {
UserDto user = new UserDto(null, 25);
Set<ConstraintViolation<UserDto>> violations = validator.validate(user, UpdateValidation.class);
assertThat(violations).isEmpty();
}
}
```
## Testing Parameterized Validation Scenarios
```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
class EmailValidationTest {
private Validator validator;
@BeforeEach
void setUp() {
validator = Validation.buildDefaultValidatorFactory().getValidator();
}
@ParameterizedTest
@ValueSource(strings = {
"user@example.com",
"john.doe+tag@example.co.uk",
"admin123@subdomain.example.com"
})
void shouldAcceptValidEmails(String email) {
UserDto user = new UserDto("Alice", email);
Set<ConstraintViolation<UserDto>> violations = validator.validate(user);
assertThat(violations).isEmpty();
}
@ParameterizedTest
@ValueSource(strings = {
"invalid-email",
"user@",
"@example.com",
"user name@example.com"
})
void shouldRejectInvalidEmails(String email) {
UserDto user = new UserDto("Alice", email);
Set<ConstraintViolation<UserDto>> violations = validator.validate(user);
assertThat(violations).isNotEmpty();
}
}
```
## Best Practices
- **Validate at unit test level** before testing service/controller layers
- **Test both valid and invalid cases** for every constraint
- **Use custom validators** for business-specific validation rules
- **Test error messages** to ensure they're user-friendly
- **Test edge cases**: null, empty string, whitespace-only strings
- **Use validation groups** for conditional validation rules
- **Keep validator logic simple** - complex validation belongs in service tests
## Common Pitfalls
- Forgetting to test null values
- Not extracting violation details (message, property, constraint type)
- Testing validation at service/controller level instead of unit tests
- Creating overly complex custom validators
- Not documenting constraint purposes in error messages
## Troubleshooting
**ValidatorFactory not found**: Ensure `jakarta.validation-api` and `hibernate-validator` are on classpath.
**Custom validator not invoked**: Verify `@Constraint(validatedBy = YourValidator.class)` is correctly specified.
**Null handling confusion**: By default, `@NotNull` checks null, other constraints ignore null (use `@NotNull` with others for mandatory fields).
## References
- [Jakarta Bean Validation Spec](https://jakarta.ee/specifications/bean-validation/)
- [Hibernate Validator Documentation](https://hibernate.org/validator/)
- [Custom Constraints](https://docs.jboss.org/hibernate/stable/validator/reference/en-US/html_single/#validator-customconstraints)

View File

@@ -0,0 +1,453 @@
---
name: unit-test-boundary-conditions
description: Edge case and boundary testing patterns for unit tests. Testing minimum/maximum values, null cases, empty collections, and numeric precision. Pure JUnit 5 unit tests. Use when ensuring code handles limits and special cases correctly.
category: testing
tags: [junit-5, boundary-testing, edge-cases, parameterized-test]
version: 1.0.1
---
# Unit Testing Boundary Conditions and Edge Cases
Test boundary conditions, edge cases, and limit values systematically. Verify code behavior at limits, with null/empty inputs, and overflow scenarios.
## When to Use This Skill
Use this skill when:
- Testing minimum and maximum values
- Testing null and empty inputs
- Testing whitespace-only strings
- Testing overflow/underflow scenarios
- Testing collections with zero/one/many items
- Verifying behavior at API boundaries
- Want comprehensive edge case coverage
## Setup: Boundary Testing
### Maven
```xml
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
```
### Gradle
```kotlin
dependencies {
testImplementation("org.junit.jupiter:junit-jupiter")
testImplementation("org.junit.jupiter:junit-jupiter-params")
testImplementation("org.assertj:assertj-core")
}
```
## Numeric Boundary Testing
### Integer Limits
```java
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import static org.assertj.core.api.Assertions.*;
class IntegerBoundaryTest {
@ParameterizedTest
@ValueSource(ints = {Integer.MIN_VALUE, Integer.MIN_VALUE + 1, 0, Integer.MAX_VALUE - 1, Integer.MAX_VALUE})
    void shouldHandleIntegerBoundaries(int value) {
        // A primitive int can never be null; assert the meaningful invariant instead
        assertThat(value).isBetween(Integer.MIN_VALUE, Integer.MAX_VALUE);
}
    @Test
    void shouldHandleIntegerOverflow() {
        // Math.addExact throws instead of silently wrapping on overflow
        assertThatThrownBy(() -> Math.addExact(Integer.MAX_VALUE, 1))
            .isInstanceOf(ArithmeticException.class);
    }
@Test
void shouldHandleIntegerUnderflow() {
assertThatThrownBy(() -> Math.subtractExact(Integer.MIN_VALUE, 1))
.isInstanceOf(ArithmeticException.class);
}
@Test
void shouldHandleZero() {
int result = MathUtils.divide(0, 5);
assertThat(result).isZero();
assertThatThrownBy(() -> MathUtils.divide(5, 0))
.isInstanceOf(ArithmeticException.class);
}
}
```
## String Boundary Testing
### Null, Empty, and Whitespace
```java
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import static org.assertj.core.api.Assertions.*;
class StringBoundaryTest {
@ParameterizedTest
@ValueSource(strings = {"", " ", " ", "\t", "\n"})
void shouldConsiderEmptyAndWhitespaceAsInvalid(String input) {
boolean result = StringUtils.isNotBlank(input);
assertThat(result).isFalse();
}
@Test
void shouldHandleNullString() {
String result = StringUtils.trim(null);
assertThat(result).isNull();
}
@Test
void shouldHandleSingleCharacter() {
String result = StringUtils.capitalize("a");
assertThat(result).isEqualTo("A");
String result2 = StringUtils.trim("x");
assertThat(result2).isEqualTo("x");
}
@Test
void shouldHandleVeryLongString() {
String longString = "x".repeat(1000000);
assertThat(longString.length()).isEqualTo(1000000);
assertThat(StringUtils.isNotBlank(longString)).isTrue();
}
@Test
void shouldHandleSpecialCharacters() {
String special = "!@#$%^&*()_+-={}[]|\\:;<>?,./";
        assertThat(StringUtils.length(special)).isEqualTo(28); // 28 characters; "\\" is a single backslash
}
}
```
## Collection Boundary Testing
### Empty, Single, and Large Collections
```java
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import static org.assertj.core.api.Assertions.*;
class CollectionBoundaryTest {
@Test
void shouldHandleEmptyList() {
List<String> empty = List.of();
assertThat(empty).isEmpty();
assertThat(CollectionUtils.first(empty)).isNull();
assertThat(CollectionUtils.count(empty)).isZero();
}
@Test
void shouldHandleSingleElementList() {
List<String> single = List.of("only");
assertThat(single).hasSize(1);
assertThat(CollectionUtils.first(single)).isEqualTo("only");
assertThat(CollectionUtils.last(single)).isEqualTo("only");
}
@Test
void shouldHandleLargeList() {
List<Integer> large = new ArrayList<>();
for (int i = 0; i < 100000; i++) {
large.add(i);
}
assertThat(large).hasSize(100000);
assertThat(CollectionUtils.first(large)).isZero();
assertThat(CollectionUtils.last(large)).isEqualTo(99999);
}
@Test
void shouldHandleNullInCollection() {
List<String> withNull = new ArrayList<>(List.of("a", null, "c"));
        assertThat(withNull).containsNull();
assertThat(CollectionUtils.filterNonNull(withNull)).hasSize(2);
}
@Test
void shouldHandleDuplicatesInCollection() {
List<Integer> duplicates = List.of(1, 1, 2, 2, 3, 3);
assertThat(duplicates).hasSize(6);
Set<Integer> unique = new HashSet<>(duplicates);
assertThat(unique).hasSize(3);
}
}
```
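`CollectionUtils` above is not a JDK or Spring class; a minimal sketch of the helpers these tests assume:
```java
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

final class CollectionUtils {
    static <T> T first(List<T> list) {
        return list.isEmpty() ? null : list.get(0);
    }
    static <T> T last(List<T> list) {
        return list.isEmpty() ? null : list.get(list.size() - 1);
    }
    static int count(List<?> list) {
        return list.size();
    }
    static <T> List<T> filterNonNull(List<T> list) {
        return list.stream().filter(Objects::nonNull).collect(Collectors.toList());
    }
}
```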
## Floating Point Boundary Testing
### Precision and Special Values
```java
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.*;
class FloatingPointBoundaryTest {
@Test
void shouldHandleFloatingPointPrecision() {
double result = 0.1 + 0.2;
// Floating point comparison needs tolerance
assertThat(result).isCloseTo(0.3, within(0.0001));
}
@Test
void shouldHandleSpecialFloatingPointValues() {
        assertThat(Double.POSITIVE_INFINITY).isGreaterThan(Double.MAX_VALUE);
        // Double.MIN_VALUE is the smallest positive double, not the most negative
        assertThat(Double.NEGATIVE_INFINITY).isLessThan(-Double.MAX_VALUE);
        // Primitive comparison: NaN != NaN (Double.equals would treat them as equal)
        assertThat(Double.NaN != Double.NaN).isTrue();
}
@Test
void shouldHandleVerySmallAndLargeNumbers() {
double tiny = Double.MIN_VALUE;
double huge = Double.MAX_VALUE;
assertThat(tiny).isGreaterThan(0);
assertThat(huge).isPositive();
}
@Test
void shouldHandleZeroInDivision() {
double result = 1.0 / 0.0;
assertThat(result).isEqualTo(Double.POSITIVE_INFINITY);
double result2 = -1.0 / 0.0;
assertThat(result2).isEqualTo(Double.NEGATIVE_INFINITY);
double result3 = 0.0 / 0.0;
assertThat(result3).isNaN();
}
}
```
## Date/Time Boundary Testing
### Min/Max Dates and Edge Cases
```java
import java.time.DateTimeException;
import java.time.LocalDate;
import java.time.LocalTime;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.*;
class DateTimeBoundaryTest {
@Test
void shouldHandleMinAndMaxDates() {
LocalDate min = LocalDate.MIN;
LocalDate max = LocalDate.MAX;
assertThat(min).isBefore(max);
assertThat(DateUtils.isValid(min)).isTrue();
assertThat(DateUtils.isValid(max)).isTrue();
}
@Test
void shouldHandleLeapYearBoundary() {
        LocalDate leapDay = LocalDate.of(2024, 2, 29);
        assertThat(leapDay).isNotNull();
        assertThat(LocalDate.of(2024, 2, 29)).isEqualTo(leapDay);
}
@Test
void shouldHandleInvalidDateInNonLeapYear() {
assertThatThrownBy(() -> LocalDate.of(2023, 2, 29))
.isInstanceOf(DateTimeException.class);
}
@Test
void shouldHandleYearBoundaries() {
LocalDate newYear = LocalDate.of(2024, 1, 1);
LocalDate lastDay = LocalDate.of(2024, 12, 31);
assertThat(newYear).isBefore(lastDay);
}
@Test
void shouldHandleMidnightBoundary() {
        LocalTime midnight = LocalTime.MIDNIGHT; // 00:00, the first instant of the day
        LocalTime almostMidnight = LocalTime.of(23, 59, 59);
        assertThat(midnight).isBefore(almostMidnight);
}
}
```
## Array Index Boundary Testing
### First, Last, and Out of Bounds
```java
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.*;
class ArrayBoundaryTest {
@Test
void shouldHandleFirstElementAccess() {
int[] array = {1, 2, 3, 4, 5};
assertThat(array[0]).isEqualTo(1);
}
@Test
void shouldHandleLastElementAccess() {
int[] array = {1, 2, 3, 4, 5};
assertThat(array[array.length - 1]).isEqualTo(5);
}
@Test
void shouldThrowOnNegativeIndex() {
int[] array = {1, 2, 3};
assertThatThrownBy(() -> {
int value = array[-1];
}).isInstanceOf(ArrayIndexOutOfBoundsException.class);
}
@Test
void shouldThrowOnOutOfBoundsIndex() {
int[] array = {1, 2, 3};
assertThatThrownBy(() -> {
int value = array[10];
}).isInstanceOf(ArrayIndexOutOfBoundsException.class);
}
@Test
void shouldHandleEmptyArray() {
int[] empty = {};
assertThat(empty.length).isZero();
assertThatThrownBy(() -> {
int value = empty[0];
}).isInstanceOf(ArrayIndexOutOfBoundsException.class);
}
}
```
## Concurrent and Thread Boundary Testing
### Null and Race Conditions
```java
import java.util.concurrent.*;
import java.util.List;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.*;
class ConcurrentBoundaryTest {
@Test
void shouldHandleNullInConcurrentMap() {
ConcurrentHashMap<String, String> map = new ConcurrentHashMap<>();
map.put("key", "value");
assertThat(map.get("nonexistent")).isNull();
}
@Test
void shouldHandleConcurrentModification() {
List<Integer> list = new CopyOnWriteArrayList<>(List.of(1, 2, 3, 4, 5));
// Should not throw ConcurrentModificationException
for (int num : list) {
if (num == 3) {
list.add(6);
}
}
assertThat(list).hasSize(6);
}
@Test
void shouldHandleEmptyBlockingQueue() throws InterruptedException {
BlockingQueue<String> queue = new LinkedBlockingQueue<>();
assertThat(queue.poll()).isNull();
}
}
```
## Parameterized Boundary Testing
### Multiple Boundary Cases
```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import static org.assertj.core.api.Assertions.*;
class ParameterizedBoundaryTest {
    @ParameterizedTest
    @CsvSource(nullValues = "null", value = {
        "null, false",   // mapped to a real null via nullValues, not the literal string
        "'', false",     // empty
        "' ', false",    // whitespace
        "a, true",       // single char
        "abc, true"      // normal
    })
void shouldValidateStringBoundaries(String input, boolean expected) {
boolean result = StringValidator.isValid(input);
assertThat(result).isEqualTo(expected);
}
@ParameterizedTest
@ValueSource(ints = {Integer.MIN_VALUE, 0, 1, -1, Integer.MAX_VALUE})
    void shouldHandleNumericBoundaries(int value) {
        assertThat(value).isBetween(Integer.MIN_VALUE, Integer.MAX_VALUE);
}
}
```
## Best Practices
- **Test explicitly at boundaries** - don't rely on random testing
- **Test null and empty separately** from valid inputs
- **Use parameterized tests** for multiple boundary cases
- **Test both sides of boundaries** (just below, at, just above)
- **Verify error messages** are helpful for invalid boundaries
- **Document why** specific boundaries matter
- **Test overflow/underflow** for numeric operations
## Common Pitfalls
- Testing only "happy path" without boundary cases
- Forgetting null/empty cases
- Not testing floating point precision
- Not testing collection boundaries (empty, single, many)
- Not testing string boundaries (null, empty, whitespace)
## Troubleshooting
**Floating point comparison fails**: Use `isCloseTo(expected, within(tolerance))`.
**Collection boundaries unclear**: List cases explicitly: empty (0), single (1), many (>1).
**Date boundary confusing**: Use `LocalDate.MIN`, `LocalDate.MAX` for clear boundaries.
## References
- [Integer.MIN_VALUE/MAX_VALUE](https://docs.oracle.com/javase/8/docs/api/java/lang/Integer.html)
- [Double.MIN_VALUE/MAX_VALUE](https://docs.oracle.com/javase/8/docs/api/java/lang/Double.html)
- [AssertJ Floating Point Assertions](https://assertj.github.io/assertj-core-features-highlight.html#assertions-on-numbers)
- [Boundary Value Analysis](https://en.wikipedia.org/wiki/Boundary-value_analysis)

View File

@@ -0,0 +1,401 @@
---
name: unit-test-caching
description: Unit tests for caching behavior using Spring Cache annotations (@Cacheable, @CachePut, @CacheEvict). Use when validating cache configuration and cache hit/miss scenarios.
category: testing
tags: [junit-5, caching, cacheable, cache-evict, cache-put]
version: 1.0.1
---
# Unit Testing Spring Caching
Test Spring caching annotations (@Cacheable, @CacheEvict, @CachePut) without full Spring context. Verify cache behavior, hits/misses, and invalidation strategies.
## When to Use This Skill
Use this skill when:
- Testing @Cacheable method caching
- Testing @CacheEvict cache invalidation
- Testing @CachePut cache updates
- Verifying cache key generation
- Testing conditional caching
- Want fast caching tests without Redis or cache infrastructure
## Setup: Caching Testing
### Maven
```xml
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-cache</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
```
### Gradle
```kotlin
dependencies {
implementation("org.springframework.boot:spring-boot-starter-cache")
testImplementation("org.springframework.boot:spring-boot-starter-test")
testImplementation("org.mockito:mockito-core")
testImplementation("org.assertj:assertj-core")
}
```
## Basic Pattern: Testing @Cacheable
### Cache Hit and Miss Behavior
```java
// Service with caching
@Service
public class UserService {
private final UserRepository userRepository;
public UserService(UserRepository userRepository) {
this.userRepository = userRepository;
}
@Cacheable("users")
public User getUserById(Long id) {
return userRepository.findById(id).orElse(null);
}
}
// Test caching behavior
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.springframework.cache.CacheManager;
import org.springframework.cache.annotation.EnableCaching;
import org.springframework.cache.concurrent.ConcurrentMapCacheManager;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.Optional;
import static org.mockito.Mockito.*;
import static org.assertj.core.api.Assertions.*;
@Configuration
@EnableCaching
class CacheTestConfig {
@Bean
public CacheManager cacheManager() {
return new ConcurrentMapCacheManager("users");
}
}
class UserServiceCachingTest {
    private UserRepository userRepository;
    private UserService userService;
    private AnnotationConfigApplicationContext context;
    @BeforeEach
    void setUp() {
        userRepository = mock(UserRepository.class);
        // @Cacheable only applies through a Spring proxy; a service built with
        // `new` would bypass caching and hit the repository every time
        context = new AnnotationConfigApplicationContext();
        context.register(CacheTestConfig.class);
        context.registerBean(UserRepository.class, () -> userRepository);
        context.registerBean(UserService.class);
        context.refresh();
        userService = context.getBean(UserService.class);
    }
    @AfterEach
    void tearDown() {
        context.close();
    }
@Test
void shouldCacheUserAfterFirstCall() {
User user = new User(1L, "Alice");
when(userRepository.findById(1L)).thenReturn(Optional.of(user));
User firstCall = userService.getUserById(1L);
User secondCall = userService.getUserById(1L);
assertThat(firstCall).isEqualTo(secondCall);
verify(userRepository, times(1)).findById(1L); // Called only once due to cache
}
@Test
void shouldReturnCachedValueOnSecondCall() {
User user = new User(1L, "Alice");
when(userRepository.findById(1L)).thenReturn(Optional.of(user));
userService.getUserById(1L); // First call - hits database
User cachedResult = userService.getUserById(1L); // Second call - hits cache
assertThat(cachedResult).isEqualTo(user);
verify(userRepository, times(1)).findById(1L);
}
}
```
## Testing @CacheEvict
### Cache Invalidation
```java
@Service
public class ProductService {
private final ProductRepository productRepository;
public ProductService(ProductRepository productRepository) {
this.productRepository = productRepository;
}
@Cacheable("products")
public Product getProductById(Long id) {
return productRepository.findById(id).orElse(null);
}
@CacheEvict("products")
public void deleteProduct(Long id) {
productRepository.deleteById(id);
}
@CacheEvict(value = "products", allEntries = true)
public void clearAllProducts() {
// Clear entire cache
}
}
class ProductCacheEvictTest {
    private ProductRepository productRepository;
    private ProductService productService;
    private AnnotationConfigApplicationContext context;
    @BeforeEach
    void setUp() {
        productRepository = mock(ProductRepository.class);
        // Same proxied setup as UserServiceCachingTest, so the cache
        // annotations actually apply
        context = new AnnotationConfigApplicationContext();
        context.register(CacheTestConfig.class);
        context.registerBean(ProductRepository.class, () -> productRepository);
        context.registerBean(ProductService.class);
        context.refresh();
        productService = context.getBean(ProductService.class);
    }
@Test
void shouldEvictProductFromCacheWhenDeleted() {
Product product = new Product(1L, "Laptop", 999.99);
when(productRepository.findById(1L)).thenReturn(Optional.of(product));
productService.getProductById(1L); // Cache the product
productService.deleteProduct(1L); // Evict from cache
        productService.getProductById(1L); // Re-fetch after eviction
// After eviction, repository should be called again
verify(productRepository, times(2)).findById(1L);
}
@Test
void shouldClearAllEntriesFromCache() {
Product product1 = new Product(1L, "Laptop", 999.99);
Product product2 = new Product(2L, "Mouse", 29.99);
when(productRepository.findById(1L)).thenReturn(Optional.of(product1));
when(productRepository.findById(2L)).thenReturn(Optional.of(product2));
productService.getProductById(1L);
productService.getProductById(2L);
productService.clearAllProducts(); // Clear all cache entries
productService.getProductById(1L);
productService.getProductById(2L);
// Repository called twice for each product
verify(productRepository, times(2)).findById(1L);
verify(productRepository, times(2)).findById(2L);
}
}
```
## Testing @CachePut
### Cache Update
```java
@Service
public class OrderService {
private final OrderRepository orderRepository;
public OrderService(OrderRepository orderRepository) {
this.orderRepository = orderRepository;
}
@Cacheable("orders")
public Order getOrder(Long id) {
return orderRepository.findById(id).orElse(null);
}
@CachePut(value = "orders", key = "#order.id")
public Order updateOrder(Order order) {
return orderRepository.save(order);
}
}
class OrderCachePutTest {
    private OrderRepository orderRepository;
    private OrderService orderService;
    private AnnotationConfigApplicationContext context;
    @BeforeEach
    void setUp() {
        orderRepository = mock(OrderRepository.class);
        // Proxied setup again; @CachePut is a no-op on a plain new OrderService
        context = new AnnotationConfigApplicationContext();
        context.register(CacheTestConfig.class);
        context.registerBean(OrderRepository.class, () -> orderRepository);
        context.registerBean(OrderService.class);
        context.refresh();
        orderService = context.getBean(OrderService.class);
    }
@Test
void shouldUpdateCacheWhenOrderIsUpdated() {
Order originalOrder = new Order(1L, "Pending", 100.0);
Order updatedOrder = new Order(1L, "Shipped", 100.0);
when(orderRepository.findById(1L)).thenReturn(Optional.of(originalOrder));
when(orderRepository.save(updatedOrder)).thenReturn(updatedOrder);
orderService.getOrder(1L);
Order result = orderService.updateOrder(updatedOrder);
assertThat(result.getStatus()).isEqualTo("Shipped");
// Next call should return updated version from cache
Order cachedOrder = orderService.getOrder(1L);
assertThat(cachedOrder.getStatus()).isEqualTo("Shipped");
}
}
```
## Testing Conditional Caching
### Cache with Conditions
```java
@Service
public class DataService {
    private final DataRepository dataRepository;
    private final UserRepository userRepository;
    public DataService(DataRepository dataRepository, UserRepository userRepository) {
        this.dataRepository = dataRepository;
        this.userRepository = userRepository;
    }
    @Cacheable(value = "data", unless = "#result == null")
    public Data getData(Long id) {
        return dataRepository.findById(id).orElse(null);
    }
    @Cacheable(value = "users", condition = "#id > 0")
    public User getUser(Long id) {
        return userRepository.findById(id).orElse(null);
    }
}
class ConditionalCachingTest {
    // buildProxiedDataService is a hypothetical helper that wires DataService
    // through a context with @EnableCaching, as in the setUp methods above;
    // without the proxy these tests would pass vacuously
    @Test
    void shouldNotCacheNullResults() {
        DataRepository dataRepository = mock(DataRepository.class);
        when(dataRepository.findById(999L)).thenReturn(Optional.empty());
        DataService service = buildProxiedDataService(dataRepository, mock(UserRepository.class));
        service.getData(999L);
        service.getData(999L);
        // Repository called twice because "unless" keeps null results out of the cache
        verify(dataRepository, times(2)).findById(999L);
    }
    @Test
    void shouldNotCacheWhenConditionIsFalse() {
        UserRepository userRepository = mock(UserRepository.class);
        User user = new User(1L, "Alice");
        when(userRepository.findById(-1L)).thenReturn(Optional.of(user));
        DataService service = buildProxiedDataService(mock(DataRepository.class), userRepository);
        service.getUser(-1L);
        service.getUser(-1L);
        // Repository called twice because id <= 0 fails the caching condition
        verify(userRepository, times(2)).findById(-1L);
    }
}
```
## Testing Cache Keys
### Verify Cache Key Generation
```java
@Service
public class InventoryService {
private final InventoryRepository inventoryRepository;
public InventoryService(InventoryRepository inventoryRepository) {
this.inventoryRepository = inventoryRepository;
}
@Cacheable(value = "inventory", key = "#productId + '-' + #warehouseId")
public InventoryItem getInventory(Long productId, Long warehouseId) {
return inventoryRepository.findByProductAndWarehouse(productId, warehouseId);
}
}
class CacheKeyTest {
@Test
void shouldGenerateCorrectCacheKey() {
InventoryRepository repository = mock(InventoryRepository.class);
InventoryItem item = new InventoryItem(1L, 1L, 100);
when(repository.findByProductAndWarehouse(1L, 1L)).thenReturn(item);
        // Key-based caching only applies through a proxied bean, so build a
        // minimal context rather than instantiating the service directly
        AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
        context.register(CacheTestConfig.class);
        context.registerBean(InventoryRepository.class, () -> repository);
        context.registerBean(InventoryService.class);
        context.refresh();
        InventoryService service = context.getBean(InventoryService.class);
service.getInventory(1L, 1L); // Cache: "1-1"
service.getInventory(1L, 1L); // Hit cache: "1-1"
service.getInventory(2L, 1L); // Miss cache: "2-1"
verify(repository, times(2)).findByProductAndWarehouse(any(), any());
}
}
```
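It can also be asserted directly that an entry landed under the expected key, by pulling the cache from the same `CacheManager` the context uses. A sketch, assuming the context-based setup above:
```java
import org.springframework.cache.Cache;
import org.springframework.cache.CacheManager;

@Test
void shouldStoreEntryUnderCompositeKey() {
    // Assumes `context` and `service` were built as in CacheKeyTest
    service.getInventory(1L, 1L);
    Cache inventoryCache = context.getBean(CacheManager.class).getCache("inventory");
    // The SpEL key "#productId + '-' + #warehouseId" evaluates to "1-1"
    assertThat(inventoryCache.get("1-1")).isNotNull();
}
```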
## Best Practices
- **Use in-memory CacheManager** for unit tests
- **Verify repository calls** to confirm cache hits/misses
- **Test both positive and negative** cache scenarios
- **Test cache invalidation** thoroughly
- **Test conditional caching** with various conditions
- **Keep cache configuration simple** in tests
- **Mock dependencies** that services use
## Common Pitfalls
- Testing actual cache infrastructure instead of caching logic
- Not verifying repository call counts
- Forgetting to test cache eviction
- Not testing conditional caching
- Not resetting cache between tests
## Troubleshooting
**Cache not working in tests**: Ensure the test configuration has `@EnableCaching` and that the service under test is a Spring-proxied bean, not a directly constructed instance.
**Wrong cache key generated**: Use `SpEL` syntax correctly in `@Cacheable(key = "...")`.
**Cache not evicting**: Verify `@CacheEvict` key matches stored key exactly.
## References
- [Spring Caching Documentation](https://docs.spring.io/spring-framework/docs/current/reference/html/integration.html#cache)
- [Spring Cache Abstractions](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/cache/annotation/Cacheable.html)
- [SpEL in Caching](https://docs.spring.io/spring-framework/docs/current/reference/html/core.html#expressions)

View File

@@ -0,0 +1,458 @@
---
name: unit-test-config-properties
description: Unit tests for @ConfigurationProperties classes using ApplicationContextRunner. Use when validating application configuration binding and validation.
category: testing
tags: [junit-5, configuration-properties, spring-profiles, property-binding]
version: 1.0.1
---
# Unit Testing Configuration Properties and Profiles
Test @ConfigurationProperties bindings, environment-specific configurations, and property validation using JUnit 5. Verify configuration loading without full Spring context startup.
## When to Use This Skill
Use this skill when:
- Testing @ConfigurationProperties property binding
- Testing property name mapping and type conversions
- Verifying configuration validation
- Testing environment-specific configurations
- Testing nested property structures
- Want fast configuration tests without Spring context
## Setup: Configuration Testing
### Maven
```xml
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-configuration-processor</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
```
### Gradle
```kotlin
dependencies {
annotationProcessor("org.springframework.boot:spring-boot-configuration-processor")
testImplementation("org.springframework.boot:spring-boot-starter-test")
testImplementation("org.junit.jupiter:junit-jupiter")
testImplementation("org.assertj:assertj-core")
}
```
## Basic Pattern: Testing ConfigurationProperties
### Simple Property Binding
```java
// Configuration properties class
@ConfigurationProperties(prefix = "app.security")
@Data
public class SecurityProperties {
private String jwtSecret;
private long jwtExpirationMs;
private int maxLoginAttempts;
private boolean enableTwoFactor;
}
// Unit test
import org.junit.jupiter.api.Test;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import static org.assertj.core.api.Assertions.*;
class SecurityPropertiesTest {
@Test
void shouldBindPropertiesFromEnvironment() {
new ApplicationContextRunner()
.withPropertyValues(
"app.security.jwtSecret=my-secret-key",
"app.security.jwtExpirationMs=3600000",
"app.security.maxLoginAttempts=5",
"app.security.enableTwoFactor=true"
)
.withUserConfiguration(TestConfig.class)
.run(context -> {
SecurityProperties props = context.getBean(SecurityProperties.class);
assertThat(props.getJwtSecret()).isEqualTo("my-secret-key");
assertThat(props.getJwtExpirationMs()).isEqualTo(3600000L);
assertThat(props.getMaxLoginAttempts()).isEqualTo(5);
assertThat(props.isEnableTwoFactor()).isTrue();
});
}
@Test
void shouldUseDefaultValuesWhenPropertiesNotProvided() {
new ApplicationContextRunner()
.withPropertyValues("app.security.jwtSecret=key")
.withUserConfiguration(TestConfig.class)
.run(context -> {
SecurityProperties props = context.getBean(SecurityProperties.class);
assertThat(props.getJwtSecret()).isEqualTo("key");
assertThat(props.getMaxLoginAttempts()).isZero();
});
}
@EnableConfigurationProperties(SecurityProperties.class)
static class TestConfig {
}
}
```
## Testing Nested Configuration Properties
### Complex Property Structure
```java
@ConfigurationProperties(prefix = "app.database")
@Data
public class DatabaseProperties {
private String url;
private String username;
private Pool pool = new Pool();
private List<Replica> replicas = new ArrayList<>();
@Data
public static class Pool {
private int maxSize = 10;
private int minIdle = 5;
private long connectionTimeout = 30000;
}
@Data
public static class Replica {
private String name;
private String url;
private int priority;
}
}
class NestedPropertiesTest {
@Test
void shouldBindNestedProperties() {
new ApplicationContextRunner()
.withPropertyValues(
"app.database.url=jdbc:mysql://localhost/db",
"app.database.username=admin",
"app.database.pool.maxSize=20",
"app.database.pool.minIdle=10",
"app.database.pool.connectionTimeout=60000"
)
.withUserConfiguration(TestConfig.class)
.run(context -> {
DatabaseProperties props = context.getBean(DatabaseProperties.class);
assertThat(props.getUrl()).isEqualTo("jdbc:mysql://localhost/db");
assertThat(props.getPool().getMaxSize()).isEqualTo(20);
assertThat(props.getPool().getConnectionTimeout()).isEqualTo(60000L);
});
}
@Test
void shouldBindListOfReplicas() {
new ApplicationContextRunner()
.withPropertyValues(
"app.database.replicas[0].name=replica-1",
"app.database.replicas[0].url=jdbc:mysql://replica1/db",
"app.database.replicas[0].priority=1",
"app.database.replicas[1].name=replica-2",
"app.database.replicas[1].url=jdbc:mysql://replica2/db",
"app.database.replicas[1].priority=2"
)
.withUserConfiguration(TestConfig.class)
.run(context -> {
DatabaseProperties props = context.getBean(DatabaseProperties.class);
assertThat(props.getReplicas()).hasSize(2);
assertThat(props.getReplicas().get(0).getName()).isEqualTo("replica-1");
assertThat(props.getReplicas().get(1).getPriority()).isEqualTo(2);
});
}
@EnableConfigurationProperties(DatabaseProperties.class)
static class TestConfig {
}
}
```
## Testing Property Validation
### Validate Configuration with Constraints
Bean Validation constraints require a provider on the test classpath (e.g. `spring-boot-starter-validation`); with `@Validated` on the properties class, binding fails when a constraint is violated.
```java
@ConfigurationProperties(prefix = "app.server")
@Data
@Validated
public class ServerProperties {
@NotBlank
private String host;
@Min(1)
@Max(65535)
private int port = 8080;
@Positive
private int threadPoolSize;
@Email
private String adminEmail;
}
class ConfigurationValidationTest {
@Test
void shouldFailValidationWhenHostIsBlank() {
new ApplicationContextRunner()
.withPropertyValues(
"app.server.host=",
"app.server.port=8080",
"app.server.threadPoolSize=10"
)
.withUserConfiguration(TestConfig.class)
.run(context -> {
assertThat(context).hasFailed()
.getFailure()
.hasMessageContaining("host");
});
}
@Test
void shouldFailValidationWhenPortOutOfRange() {
new ApplicationContextRunner()
.withPropertyValues(
"app.server.host=localhost",
"app.server.port=99999",
"app.server.threadPoolSize=10"
)
.withUserConfiguration(TestConfig.class)
.run(context -> {
assertThat(context).hasFailed();
});
}
@Test
void shouldPassValidationWithValidConfiguration() {
new ApplicationContextRunner()
.withPropertyValues(
"app.server.host=localhost",
"app.server.port=8080",
"app.server.threadPoolSize=10",
"app.server.adminEmail=admin@example.com"
)
.withUserConfiguration(TestConfig.class)
.run(context -> {
assertThat(context).hasNotFailed();
ServerProperties props = context.getBean(ServerProperties.class);
assertThat(props.getHost()).isEqualTo("localhost");
});
}
@EnableConfigurationProperties(ServerProperties.class)
static class TestConfig {
}
}
```
## Testing Profile-Specific Configurations
### Environment-Specific Properties
```java
@Configuration
@Profile("prod")
class ProductionConfiguration {
@Bean
public SecurityProperties securityProperties() {
SecurityProperties props = new SecurityProperties();
props.setEnableTwoFactor(true);
props.setMaxLoginAttempts(3);
return props;
}
}
@Configuration
@Profile("dev")
class DevelopmentConfiguration {
@Bean
public SecurityProperties securityProperties() {
SecurityProperties props = new SecurityProperties();
props.setEnableTwoFactor(false);
props.setMaxLoginAttempts(999);
return props;
}
}
class ProfileBasedConfigurationTest {
@Test
void shouldLoadProductionConfiguration() {
new ApplicationContextRunner()
.withPropertyValues("spring.profiles.active=prod")
.withUserConfiguration(ProductionConfiguration.class)
.run(context -> {
SecurityProperties props = context.getBean(SecurityProperties.class);
assertThat(props.isEnableTwoFactor()).isTrue();
assertThat(props.getMaxLoginAttempts()).isEqualTo(3);
});
}
@Test
void shouldLoadDevelopmentConfiguration() {
new ApplicationContextRunner()
.withPropertyValues("spring.profiles.active=dev")
.withUserConfiguration(DevelopmentConfiguration.class)
.run(context -> {
SecurityProperties props = context.getBean(SecurityProperties.class);
assertThat(props.isEnableTwoFactor()).isFalse();
assertThat(props.getMaxLoginAttempts()).isEqualTo(999);
});
}
}
```
## Testing Type Conversion
### Property Type Binding
```java
@ConfigurationProperties(prefix = "app.features")
@Data
public class FeatureProperties {
private Duration cacheExpiry = Duration.ofMinutes(10);
private DataSize maxUploadSize = DataSize.ofMegabytes(100);
private List<String> enabledFeatures;
private Map<String, String> featureFlags;
private Charset fileEncoding = StandardCharsets.UTF_8;
}
class TypeConversionTest {
@Test
void shouldConvertStringToDuration() {
new ApplicationContextRunner()
.withPropertyValues("app.features.cacheExpiry=30s")
.withUserConfiguration(TestConfig.class)
.run(context -> {
FeatureProperties props = context.getBean(FeatureProperties.class);
assertThat(props.getCacheExpiry()).isEqualTo(Duration.ofSeconds(30));
});
}
@Test
void shouldConvertStringToDataSize() {
new ApplicationContextRunner()
.withPropertyValues("app.features.maxUploadSize=50MB")
.withUserConfiguration(TestConfig.class)
.run(context -> {
FeatureProperties props = context.getBean(FeatureProperties.class);
assertThat(props.getMaxUploadSize()).isEqualTo(DataSize.ofMegabytes(50));
});
}
@Test
void shouldConvertCommaDelimitedListToList() {
new ApplicationContextRunner()
.withPropertyValues("app.features.enabledFeatures=feature1,feature2,feature3")
.withUserConfiguration(TestConfig.class)
.run(context -> {
FeatureProperties props = context.getBean(FeatureProperties.class);
assertThat(props.getEnabledFeatures())
.containsExactly("feature1", "feature2", "feature3");
});
}
@EnableConfigurationProperties(FeatureProperties.class)
static class TestConfig {
}
}
```
## Testing Property Binding with Default Values
### Verify Default Configuration
```java
@ConfigurationProperties(prefix = "app.cache")
@Data
public class CacheProperties {
private long ttlSeconds = 300;
private int maxSize = 1000;
private boolean enabled = true;
private String cacheType = "IN_MEMORY";
}
class DefaultValuesTest {
@Test
void shouldUseDefaultValuesWhenNotSpecified() {
new ApplicationContextRunner()
.withUserConfiguration(TestConfig.class)
.run(context -> {
CacheProperties props = context.getBean(CacheProperties.class);
assertThat(props.getTtlSeconds()).isEqualTo(300L);
assertThat(props.getMaxSize()).isEqualTo(1000);
assertThat(props.isEnabled()).isTrue();
assertThat(props.getCacheType()).isEqualTo("IN_MEMORY");
});
}
@Test
void shouldOverrideDefaultValuesWithProvidedProperties() {
new ApplicationContextRunner()
.withPropertyValues(
"app.cache.ttlSeconds=600",
"app.cache.cacheType=REDIS"
)
.withUserConfiguration(TestConfig.class)
.run(context -> {
CacheProperties props = context.getBean(CacheProperties.class);
assertThat(props.getTtlSeconds()).isEqualTo(600L);
assertThat(props.getCacheType()).isEqualTo("REDIS");
assertThat(props.getMaxSize()).isEqualTo(1000); // Default unchanged
});
}
@EnableConfigurationProperties(CacheProperties.class)
static class TestConfig {
}
}
```
## Best Practices
- **Test all property bindings** including nested structures
- **Test validation constraints** thoroughly
- **Test both default and custom values**
- **Use ApplicationContextRunner** for context-free testing (a Binder-based alternative is sketched below)
- **Test profile-specific configurations** separately
- **Verify type conversions** work correctly
- **Test edge cases** (empty strings, null values, type mismatches)
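For simple bindings you can skip the application context entirely and drive Spring Boot's `Binder` yourself. A minimal sketch using the `CacheProperties` class above; note that the raw `Binder` applies no `@Validated` constraints:
```java
import java.util.Map;

import org.junit.jupiter.api.Test;
import org.springframework.boot.context.properties.bind.Binder;
import org.springframework.boot.context.properties.source.MapConfigurationPropertySource;

import static org.assertj.core.api.Assertions.assertThat;

class BinderBasedBindingTest {
    @Test
    void shouldBindWithoutApplicationContext() {
        // Property names use the canonical kebab-case form
        MapConfigurationPropertySource source = new MapConfigurationPropertySource(
            Map.of("app.cache.ttl-seconds", "600", "app.cache.cache-type", "REDIS"));
        CacheProperties props = new Binder(source)
            .bind("app.cache", CacheProperties.class)
            .get();
        assertThat(props.getTtlSeconds()).isEqualTo(600L);
        assertThat(props.getCacheType()).isEqualTo("REDIS");
    }
}
```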
## Common Pitfalls
- Not testing validation constraints
- Forgetting to test default values
- Not testing nested property structures
- Testing with wrong property prefix
- Not handling type conversion properly
## Troubleshooting
**Properties not binding**: Verify prefix and property names match exactly (including kebab-case to camelCase conversion).
**Validation not triggered**: Ensure `@Validated` is present and validation dependencies are on classpath.
**ApplicationContextRunner not found**: Verify `spring-boot-starter-test` is in test dependencies.
## References
- [Spring Boot ConfigurationProperties](https://docs.spring.io/spring-boot/docs/current/reference/html/configuration-metadata.html)
- [ApplicationContextRunner Testing](https://docs.spring.io/spring-boot/docs/current/api/org/springframework/boot/test/context/runner/ApplicationContextRunner.html)
- [Spring Profiles](https://docs.spring.io/spring-boot/docs/current/reference/html/features.html#features.profiles)

View File

@@ -0,0 +1,351 @@
---
name: unit-test-controller-layer
description: Unit tests for REST controllers using MockMvc and @WebMvcTest. Test request/response mapping, validation, and exception handling. Use when testing web layer endpoints in isolation.
category: testing
tags: [junit-5, mockito, unit-testing, controller, rest, mockmvc]
version: 1.0.1
---
# Unit Testing REST Controllers with MockMvc
Test @RestController and @Controller classes by mocking service dependencies and verifying HTTP responses, status codes, and serialization. Use MockMvc for lightweight controller testing without loading the full Spring context.
## When to Use This Skill
Use this skill when:
- Testing REST controller request/response handling
- Verifying HTTP status codes and response formats
- Testing request parameter binding and validation
- Mocking service layer for isolated controller tests
- Testing content negotiation and response headers
- Want fast controller tests without integration test overhead
## Setup: MockMvc + Mockito
### Maven
```xml
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
```
### Gradle
```kotlin
dependencies {
testImplementation("org.springframework.boot:spring-boot-starter-test")
testImplementation("org.mockito:mockito-core")
}
```
## Basic Pattern: Testing GET Endpoint
### Simple GET Endpoint Test
```java
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import static org.mockito.Mockito.*;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*;
@ExtendWith(MockitoExtension.class)
class UserControllerTest {
@Mock
private UserService userService;
@InjectMocks
private UserController userController;
private MockMvc mockMvc;
@BeforeEach
void setUp() {
mockMvc = MockMvcBuilders.standaloneSetup(userController).build();
}
@Test
void shouldReturnAllUsers() throws Exception {
List<UserDto> users = List.of(
new UserDto(1L, "Alice"),
new UserDto(2L, "Bob")
);
when(userService.getAllUsers()).thenReturn(users);
mockMvc.perform(get("/api/users"))
.andExpect(status().isOk())
.andExpect(jsonPath("$").isArray())
.andExpect(jsonPath("$[0].id").value(1))
.andExpect(jsonPath("$[0].name").value("Alice"))
.andExpect(jsonPath("$[1].id").value(2));
verify(userService, times(1)).getAllUsers();
}
@Test
void shouldReturnUserById() throws Exception {
UserDto user = new UserDto(1L, "Alice");
when(userService.getUserById(1L)).thenReturn(user);
mockMvc.perform(get("/api/users/1"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.id").value(1))
.andExpect(jsonPath("$.name").value("Alice"));
verify(userService).getUserById(1L);
}
}
```
## Testing POST Endpoint
### Create Resource with Request Body
```java
@Test
void shouldCreateUserAndReturn201() throws Exception {
UserCreateRequest request = new UserCreateRequest("Alice", "alice@example.com");
UserDto createdUser = new UserDto(1L, "Alice", "alice@example.com");
when(userService.createUser(any(UserCreateRequest.class)))
.thenReturn(createdUser);
mockMvc.perform(post("/api/users")
.contentType("application/json")
.content("{\"name\":\"Alice\",\"email\":\"alice@example.com\"}"))
.andExpect(status().isCreated())
.andExpect(jsonPath("$.id").value(1))
.andExpect(jsonPath("$.name").value("Alice"))
.andExpect(jsonPath("$.email").value("alice@example.com"));
verify(userService).createUser(any(UserCreateRequest.class));
}
```
## Testing Error Scenarios
### Handle 404 Not Found
```java
@Test
void shouldReturn404WhenUserNotFound() throws Exception {
when(userService.getUserById(999L))
.thenThrow(new UserNotFoundException("User not found"));
mockMvc.perform(get("/api/users/999"))
.andExpect(status().isNotFound())
.andExpect(jsonPath("$.error").value("User not found"));
verify(userService).getUserById(999L);
}
```
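With `standaloneSetup(userController)` alone, a thrown `UserNotFoundException` surfaces as a servlet error rather than the 404 body asserted above. This test (and the 400 test below) assumes an exception handler is registered; a sketch with a hypothetical advice class:
```java
// Hypothetical advice mapping UserNotFoundException to the 404 body asserted above.
@RestControllerAdvice
class UserExceptionAdvice {
    @ExceptionHandler(UserNotFoundException.class)
    @ResponseStatus(HttpStatus.NOT_FOUND)
    Map<String, String> handleNotFound(UserNotFoundException ex) {
        return Map.of("error", ex.getMessage());
    }
}
```
Register it when building MockMvc: `MockMvcBuilders.standaloneSetup(userController).setControllerAdvice(new UserExceptionAdvice()).build()`.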
### Handle 400 Bad Request
```java
@Test
void shouldReturn400WhenRequestBodyInvalid() throws Exception {
mockMvc.perform(post("/api/users")
.contentType("application/json")
.content("{\"name\":\"\"}")) // Empty name
.andExpect(status().isBadRequest())
.andExpect(jsonPath("$.errors").isArray());
}
```
## Testing PUT/PATCH Endpoints
### Update Resource
```java
@Test
void shouldUpdateUserAndReturn200() throws Exception {
UserUpdateRequest request = new UserUpdateRequest("Alice Updated");
UserDto updatedUser = new UserDto(1L, "Alice Updated");
when(userService.updateUser(eq(1L), any(UserUpdateRequest.class)))
.thenReturn(updatedUser);
mockMvc.perform(put("/api/users/1")
.contentType("application/json")
.content("{\"name\":\"Alice Updated\"}"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.id").value(1))
.andExpect(jsonPath("$.name").value("Alice Updated"));
verify(userService).updateUser(eq(1L), any(UserUpdateRequest.class));
}
```
## Testing DELETE Endpoint
### Delete Resource
```java
@Test
void shouldDeleteUserAndReturn204() throws Exception {
doNothing().when(userService).deleteUser(1L);
mockMvc.perform(delete("/api/users/1"))
.andExpect(status().isNoContent());
verify(userService).deleteUser(1L);
}
@Test
void shouldReturn404WhenDeletingNonExistentUser() throws Exception {
doThrow(new UserNotFoundException("User not found"))
.when(userService).deleteUser(999L);
mockMvc.perform(delete("/api/users/999"))
.andExpect(status().isNotFound());
}
```
## Testing Request Parameters
### Query Parameters
```java
@Test
void shouldFilterUsersByName() throws Exception {
List<UserDto> users = List.of(new UserDto(1L, "Alice"));
when(userService.searchUsers("Alice")).thenReturn(users);
mockMvc.perform(get("/api/users/search?name=Alice"))
.andExpect(status().isOk())
.andExpect(jsonPath("$").isArray())
.andExpect(jsonPath("$[0].name").value("Alice"));
verify(userService).searchUsers("Alice");
}
```
### Path Variables
```java
@Test
void shouldGetUserByIdFromPath() throws Exception {
UserDto user = new UserDto(123L, "Alice");
when(userService.getUserById(123L)).thenReturn(user);
mockMvc.perform(get("/api/users/{id}", 123L))
.andExpect(status().isOk())
.andExpect(jsonPath("$.id").value(123));
}
```
## Testing Response Headers
### Verify Response Headers
```java
@Test
void shouldReturnCustomHeaders() throws Exception {
when(userService.getAllUsers()).thenReturn(List.of());
mockMvc.perform(get("/api/users"))
.andExpect(status().isOk())
.andExpect(header().exists("X-Total-Count"))
.andExpect(header().string("X-Total-Count", "0"))
.andExpect(header().string("Content-Type", containsString("application/json")));
}
```
## Testing Request Headers
### Send Request Headers
```java
@Test
void shouldRequireAuthorizationHeader() throws Exception {
// Assumes the controller (or a registered interceptor) enforces the Authorization
// header; standalone MockMvc applies no security filters by itself.
mockMvc.perform(get("/api/users"))
.andExpect(status().isUnauthorized());
mockMvc.perform(get("/api/users")
.header("Authorization", "Bearer token123"))
.andExpect(status().isOk());
}
```
## Content Negotiation
### Test Different Accept Headers
```java
@Test
void shouldReturnJsonWhenAcceptHeaderIsJson() throws Exception {
UserDto user = new UserDto(1L, "Alice");
when(userService.getUserById(1L)).thenReturn(user);
mockMvc.perform(get("/api/users/1")
.accept("application/json"))
.andExpect(status().isOk())
.andExpect(content().contentType("application/json"));
}
```
## Advanced: Testing Multiple Status Codes
```java
@Test
void shouldReturnDifferentStatusCodesForDifferentScenarios() throws Exception {
// Successful response
when(userService.getUserById(1L)).thenReturn(new UserDto(1L, "Alice"));
mockMvc.perform(get("/api/users/1"))
.andExpect(status().isOk());
// Not found
when(userService.getUserById(999L))
.thenThrow(new UserNotFoundException("Not found"));
mockMvc.perform(get("/api/users/999"))
.andExpect(status().isNotFound());
// Unauthorized (assumes auth is enforced by the controller or an interceptor)
mockMvc.perform(get("/api/admin/users"))
.andExpect(status().isUnauthorized());
}
```
## Best Practices
- **Use standalone setup** when testing single controller: `MockMvcBuilders.standaloneSetup()`
- **Mock service layer** - controllers should focus on HTTP handling
- **Test happy path and error paths** thoroughly
- **Verify service method calls** to ensure controller delegates correctly
- **Use content() matchers** for response body validation
- **Keep tests focused** on one endpoint behavior per test
- **Use JsonPath** for fluent JSON response assertions
## Common Pitfalls
- **Testing business logic in controller**: Move to service tests
- **Not mocking service layer**: Always mock service dependencies
- **Testing framework behavior**: Focus on your code, not Spring code
- **Hardcoding URLs**: Use MockMvcRequestBuilders helpers
- **Not verifying mock interactions**: Always verify service was called correctly
## Troubleshooting
**Content type mismatch**: Ensure `contentType()` matches controller's `@PostMapping(consumes=...)` or use default.
**JsonPath not matching**: Use `mockMvc.perform(...).andDo(print())` to see actual response content.
**Status code assertions fail**: Check controller `@RequestMapping`, `@PostMapping` status codes and error handling.
## References
- [Spring MockMvc Documentation](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/test/web/servlet/MockMvc.html)
- [JsonPath for REST Assertions](https://goessner.net/articles/JsonPath/)
- [Spring Testing Best Practices](https://docs.spring.io/spring-boot/docs/current/reference/html/features.html#features.testing)

View File

@@ -0,0 +1,466 @@
---
name: unit-test-exception-handler
description: Unit tests for @ExceptionHandler and @ControllerAdvice for global exception handling. Use when validating error response formatting and HTTP status codes.
category: testing
tags: [junit-5, exception-handler, controller-advice, error-handling, mockmvc]
version: 1.0.1
---
# Unit Testing ExceptionHandler and ControllerAdvice
Test exception handlers and global exception handling logic using MockMvc. Verify error response formatting, HTTP status codes, and exception-to-response mapping.
## When to Use This Skill
Use this skill when:
- Testing @ExceptionHandler methods in @ControllerAdvice
- Testing exception-to-error-response transformations
- Verifying HTTP status codes for different exception types
- Testing error message formatting and localization
- Want fast exception handler tests without full integration tests
## Setup: Exception Handler Testing
### Maven
```xml
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
```
### Gradle
```kotlin
dependencies {
implementation("org.springframework.boot:spring-boot-starter-web")
testImplementation("org.springframework.boot:spring-boot-starter-test")
testImplementation("org.assertj:assertj-core")
}
```
## Basic Pattern: Global Exception Handler
### Create Exception Handler
```java
// Global exception handler (@RestControllerAdvice so returned bodies are serialized as JSON)
@RestControllerAdvice
public class GlobalExceptionHandler {
@ExceptionHandler(ResourceNotFoundException.class)
@ResponseStatus(HttpStatus.NOT_FOUND)
public ErrorResponse handleResourceNotFound(ResourceNotFoundException ex) {
return new ErrorResponse(
HttpStatus.NOT_FOUND.value(),
"Resource not found",
ex.getMessage()
);
}
@ExceptionHandler(ValidationException.class)
@ResponseStatus(HttpStatus.BAD_REQUEST)
public ErrorResponse handleValidationException(ValidationException ex) {
return new ErrorResponse(
HttpStatus.BAD_REQUEST.value(),
"Validation failed",
ex.getMessage()
);
}
}
// Error response DTO
public record ErrorResponse(
int status,
String error,
String message
) {}
```
### Unit Test Exception Handler
```java
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*;
class GlobalExceptionHandlerTest {
private final GlobalExceptionHandler exceptionHandler = new GlobalExceptionHandler();
private MockMvc mockMvc;
@BeforeEach
void setUp() {
mockMvc = MockMvcBuilders
.standaloneSetup(new TestController())
.setControllerAdvice(exceptionHandler)
.build();
}
@Test
void shouldReturnNotFoundWhenResourceNotFoundException() throws Exception {
mockMvc.perform(get("/api/users/999"))
.andExpect(status().isNotFound())
.andExpect(jsonPath("$.status").value(404))
.andExpect(jsonPath("$.error").value("Resource not found"))
.andExpect(jsonPath("$.message").value("User not found"));
}
@Test
void shouldReturnBadRequestWhenValidationException() throws Exception {
mockMvc.perform(post("/api/users")
.contentType("application/json")
.content("{\"name\":\"\"}"))
.andExpect(status().isBadRequest())
.andExpect(jsonPath("$.status").value(400))
.andExpect(jsonPath("$.error").value("Validation failed"));
}
}
// Test controller that throws the exceptions exercised above
@RestController
@RequestMapping("/api")
class TestController {
@GetMapping("/users/{id}")
public User getUser(@PathVariable Long id) {
throw new ResourceNotFoundException("User not found");
}
@PostMapping("/users")
public User createUser(@RequestBody UserDto request) {
throw new ValidationException("Name must not be blank");
}
}
```
## Testing Multiple Exception Types
### Handle Various Exception Types
```java
@RestControllerAdvice
public class GlobalExceptionHandler {
@ExceptionHandler(ResourceNotFoundException.class)
@ResponseStatus(HttpStatus.NOT_FOUND)
public ErrorResponse handleResourceNotFound(ResourceNotFoundException ex) {
return new ErrorResponse(404, "Not found", ex.getMessage());
}
@ExceptionHandler(DuplicateResourceException.class)
@ResponseStatus(HttpStatus.CONFLICT)
public ErrorResponse handleDuplicateResource(DuplicateResourceException ex) {
return new ErrorResponse(409, "Conflict", ex.getMessage());
}
@ExceptionHandler(UnauthorizedException.class)
@ResponseStatus(HttpStatus.UNAUTHORIZED)
public ErrorResponse handleUnauthorized(UnauthorizedException ex) {
return new ErrorResponse(401, "Unauthorized", ex.getMessage());
}
@ExceptionHandler(AccessDeniedException.class)
@ResponseStatus(HttpStatus.FORBIDDEN)
public ErrorResponse handleAccessDenied(AccessDeniedException ex) {
return new ErrorResponse(403, "Forbidden", ex.getMessage());
}
@ExceptionHandler(Exception.class)
@ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR)
public ErrorResponse handleGenericException(Exception ex) {
return new ErrorResponse(500, "Internal server error", "An unexpected error occurred");
}
}
class MultiExceptionHandlerTest {
private MockMvc mockMvc;
private GlobalExceptionHandler handler;
@BeforeEach
void setUp() {
handler = new GlobalExceptionHandler();
mockMvc = MockMvcBuilders
.standaloneSetup(new TestController())
.setControllerAdvice(handler)
.build();
}
@Test
void shouldReturn404ForNotFound() throws Exception {
mockMvc.perform(get("/api/users/999"))
.andExpect(status().isNotFound())
.andExpect(jsonPath("$.status").value(404));
}
@Test
void shouldReturn409ForDuplicate() throws Exception {
mockMvc.perform(post("/api/users")
.contentType("application/json")
.content("{\"email\":\"existing@example.com\"}"))
.andExpect(status().isConflict())
.andExpect(jsonPath("$.status").value(409));
}
@Test
void shouldReturn401ForUnauthorized() throws Exception {
mockMvc.perform(get("/api/admin/dashboard"))
.andExpect(status().isUnauthorized())
.andExpect(jsonPath("$.status").value(401));
}
@Test
void shouldReturn403ForAccessDenied() throws Exception {
mockMvc.perform(get("/api/admin/users"))
.andExpect(status().isForbidden())
.andExpect(jsonPath("$.status").value(403));
}
@Test
void shouldReturn500ForGenericException() throws Exception {
mockMvc.perform(get("/api/error"))
.andExpect(status().isInternalServerError())
.andExpect(jsonPath("$.status").value(500));
}
}
```
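These tests assume a stand-in controller whose endpoints exist only to raise each exception type (the structure and localization tests below assume similar stand-ins); a hypothetical sketch:
```java
// Hypothetical stand-in controller; each endpoint raises one exception type.
@RestController
@RequestMapping("/api")
class TestController {
    @GetMapping("/users/{id}")
    public User getUser(@PathVariable Long id) {
        throw new ResourceNotFoundException("User not found");
    }
    @PostMapping("/users")
    public User createUser(@RequestBody UserDto request) {
        throw new DuplicateResourceException("Email already registered");
    }
    @GetMapping("/admin/dashboard")
    public String dashboard() {
        throw new UnauthorizedException("Login required");
    }
    @GetMapping("/admin/users")
    public String adminUsers() {
        throw new AccessDeniedException("Admin role required");
    }
    @GetMapping("/error")
    public String fail() {
        throw new IllegalStateException("Unexpected failure");
    }
}
```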
## Testing Error Response Structure
### Verify Error Response Format
```java
@RestControllerAdvice
public class GlobalExceptionHandler {
@ExceptionHandler(BadRequestException.class)
@ResponseStatus(HttpStatus.BAD_REQUEST)
public ResponseEntity<ErrorDetails> handleBadRequest(BadRequestException ex) {
ErrorDetails details = new ErrorDetails(
System.currentTimeMillis(),
HttpStatus.BAD_REQUEST.value(),
"Bad Request",
ex.getMessage(),
new Date()
);
return new ResponseEntity<>(details, HttpStatus.BAD_REQUEST);
}
}
class ErrorResponseStructureTest {
private MockMvc mockMvc;
@BeforeEach
void setUp() {
mockMvc = MockMvcBuilders
.standaloneSetup(new TestController())
.setControllerAdvice(new GlobalExceptionHandler())
.build();
}
@Test
void shouldIncludeTimestampInErrorResponse() throws Exception {
mockMvc.perform(post("/api/data")
.contentType("application/json")
.content("{}"))
.andExpect(status().isBadRequest())
.andExpect(jsonPath("$.timestamp").exists())
.andExpect(jsonPath("$.status").value(400))
.andExpect(jsonPath("$.error").value("Bad Request"))
.andExpect(jsonPath("$.message").exists())
.andExpect(jsonPath("$.date").exists());
}
@Test
void shouldIncludeAllRequiredErrorFields() throws Exception {
MvcResult result = mockMvc.perform(get("/api/invalid"))
.andExpect(status().isBadRequest())
.andReturn();
String response = result.getResponse().getContentAsString();
assertThat(response).contains("timestamp");
assertThat(response).contains("status");
assertThat(response).contains("error");
assertThat(response).contains("message");
}
}
```
## Testing Validation Error Handling
### Handle MethodArgumentNotValidException
```java
@RestControllerAdvice
public class GlobalExceptionHandler {
@ExceptionHandler(MethodArgumentNotValidException.class)
@ResponseStatus(HttpStatus.BAD_REQUEST)
public ValidationErrorResponse handleValidationException(
MethodArgumentNotValidException ex) {
Map<String, String> errors = new HashMap<>();
ex.getBindingResult().getFieldErrors().forEach(error ->
errors.put(error.getField(), error.getDefaultMessage())
);
return new ValidationErrorResponse(
HttpStatus.BAD_REQUEST.value(),
"Validation failed",
errors
);
}
}
class ValidationExceptionHandlerTest {
private MockMvc mockMvc;
@BeforeEach
void setUp() {
mockMvc = MockMvcBuilders
.standaloneSetup(new UserController())
.setControllerAdvice(new GlobalExceptionHandler())
.build();
}
@Test
void shouldReturnValidationErrorsForInvalidInput() throws Exception {
mockMvc.perform(post("/api/users")
.contentType("application/json")
.content("{\"name\":\"\",\"age\":-5}"))
.andExpect(status().isBadRequest())
.andExpect(jsonPath("$.status").value(400))
.andExpect(jsonPath("$.errors.name").exists())
.andExpect(jsonPath("$.errors.age").exists());
}
@Test
void shouldIncludeErrorMessageForEachField() throws Exception {
mockMvc.perform(post("/api/users")
.contentType("application/json")
.content("{\"name\":\"\",\"email\":\"invalid\"}"))
.andExpect(status().isBadRequest())
.andExpect(jsonPath("$.errors.name").value("must not be blank"))
.andExpect(jsonPath("$.errors.email").value("must be valid email"));
}
}
```
## Testing Exception Handler with Custom Logic
### Exception Handler with Context
```java
@RestControllerAdvice
public class GlobalExceptionHandler {
private final MessageService messageService;
private final LoggingService loggingService;
public GlobalExceptionHandler(MessageService messageService, LoggingService loggingService) {
this.messageService = messageService;
this.loggingService = loggingService;
}
@ExceptionHandler(BusinessException.class)
@ResponseStatus(HttpStatus.BAD_REQUEST)
public ErrorResponse handleBusinessException(BusinessException ex, HttpServletRequest request) {
loggingService.logException(ex, request.getRequestURI());
String localizedMessage = messageService.getMessage(ex.getErrorCode());
return new ErrorResponse(
HttpStatus.BAD_REQUEST.value(),
"Business error",
localizedMessage
);
}
}
class ExceptionHandlerWithContextTest {
private MockMvc mockMvc;
private GlobalExceptionHandler handler;
private MessageService messageService;
private LoggingService loggingService;
@BeforeEach
void setUp() {
messageService = mock(MessageService.class);
loggingService = mock(LoggingService.class);
handler = new GlobalExceptionHandler(messageService, loggingService);
mockMvc = MockMvcBuilders
.standaloneSetup(new TestController())
.setControllerAdvice(handler)
.build();
}
@Test
void shouldLocalizeErrorMessage() throws Exception {
when(messageService.getMessage("USER_NOT_FOUND"))
.thenReturn("L'utilisateur n'a pas été trouvé");
mockMvc.perform(get("/api/users/999"))
.andExpect(status().isBadRequest())
.andExpect(jsonPath("$.message").value("L'utilisateur n'a pas été trouvé"));
verify(messageService).getMessage("USER_NOT_FOUND");
}
@Test
void shouldLogExceptionOccurrence() throws Exception {
mockMvc.perform(get("/api/users/999"))
.andExpect(status().isBadRequest());
verify(loggingService).logException(any(BusinessException.class), anyString());
}
}
```
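This test likewise assumes an endpoint that raises a `BusinessException` carrying the `USER_NOT_FOUND` error code; a hypothetical sketch:
```java
// Hypothetical endpoint raising the BusinessException the tests above expect.
@RestController
@RequestMapping("/api")
class TestController {
    @GetMapping("/users/{id}")
    public User getUser(@PathVariable Long id) {
        throw new BusinessException("USER_NOT_FOUND");
    }
}
```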
## Best Practices
- **Test all exception handlers** with real exception throws
- **Verify HTTP status codes** for each exception type
- **Test error response structure** to ensure consistency
- **Verify logging** is triggered appropriately
- **Use mock controllers** to throw exceptions in tests
- **Test both happy and error paths**
- **Keep error messages user-friendly** and consistent
## Common Pitfalls
- Not testing the full request path (use MockMvc with controller)
- Forgetting to include `@ControllerAdvice` in MockMvc setup
- Not verifying all required fields in error response
- Testing handler logic instead of exception handling behavior
- Not testing edge cases (null exceptions, unusual messages)
## Troubleshooting
**Exception handler not invoked**: Ensure controller is registered with MockMvc and actually throws the exception.
**JsonPath matchers not matching**: Use `.andDo(print())` to see actual response structure.
**Status code mismatch**: Verify `@ResponseStatus` annotation on handler method.
## References
- [Spring ControllerAdvice Documentation](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/web/bind/annotation/ControllerAdvice.html)
- [Spring ExceptionHandler](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/web/bind/annotation/ExceptionHandler.html)
- [MockMvc Testing](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/test/web/servlet/MockMvc.html)

View File

@@ -0,0 +1,398 @@
---
name: unit-test-json-serialization
description: Unit tests for JSON serialization/deserialization with Jackson and @JsonTest. Use when validating JSON mapping, custom serializers, and date format handling.
category: testing
tags: [junit-5, json-test, jackson, serialization, deserialization]
version: 1.0.1
---
# Unit Testing JSON Serialization with @JsonTest
Test JSON serialization and deserialization of POJOs using Spring's @JsonTest. Verify Jackson configuration, custom serializers, and JSON mapping accuracy.
## When to Use This Skill
Use this skill when:
- Testing JSON serialization of DTOs
- Testing JSON deserialization to objects
- Testing custom Jackson serializers/deserializers
- Verifying JSON field names and formats
- Testing null handling in JSON
- Want fast JSON mapping tests without full Spring context
## Setup: JSON Testing
### Maven
```xml
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-json</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
```
### Gradle
```kotlin
dependencies {
implementation("org.springframework.boot:spring-boot-starter-json")
implementation("com.fasterxml.jackson.core:jackson-databind")
testImplementation("org.springframework.boot:spring-boot-starter-test")
}
```
## Basic Pattern: @JsonTest
### Test JSON Serialization
```java
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.json.JsonTest;
import org.springframework.boot.test.json.JacksonTester;
import org.springframework.boot.test.json.JsonContent;
import static org.assertj.core.api.Assertions.*;
@JsonTest
class UserDtoJsonTest {
@Autowired
private JacksonTester<UserDto> json;
@Test
void shouldSerializeUserToJson() throws Exception {
UserDto user = new UserDto(1L, "Alice", "alice@example.com", 25);
JsonContent<UserDto> result = json.write(user);
assertThat(result).extractingJsonPathNumberValue("$.id").isEqualTo(1);
assertThat(result).extractingJsonPathStringValue("$.name").isEqualTo("Alice");
assertThat(result).extractingJsonPathStringValue("$.email").isEqualTo("alice@example.com");
assertThat(result).extractingJsonPathNumberValue("$.age").isEqualTo(25);
}
@Test
void shouldDeserializeJsonToUser() throws Exception {
String jsonContent = "{\"id\":1,\"name\":\"Alice\",\"email\":\"alice@example.com\",\"age\":25}";
UserDto user = json.parse(jsonContent).getObject();
assertThat(user)
.isNotNull()
.hasFieldOrPropertyWithValue("id", 1L)
.hasFieldOrPropertyWithValue("name", "Alice")
.hasFieldOrPropertyWithValue("email", "alice@example.com")
.hasFieldOrPropertyWithValue("age", 25);
}
@Test
void shouldHandleNullFields() throws Exception {
String jsonContent = "{\"id\":1,\"name\":null,\"email\":\"alice@example.com\",\"age\":null}";
UserDto user = json.parse(jsonContent).getObject();
assertThat(user.getName()).isNull();
assertThat(user.getAge()).isNull();
}
}
```
## Testing Custom JSON Properties
### @JsonProperty and @JsonIgnore
```java
public class Order {
@JsonProperty("order_id")
private Long id;
@JsonProperty("total_amount")
private BigDecimal amount;
@JsonIgnore
private String internalNote;
private LocalDateTime createdAt;
}
@JsonTest
class OrderJsonTest {
@Autowired
private JacksonTester<Order> json;
@Test
void shouldMapJsonPropertyNames() throws Exception {
String jsonContent = "{\"order_id\":123,\"total_amount\":99.99,\"createdAt\":\"2024-01-15T10:30:00\"}";
Order order = json.parse(jsonContent).getObject();
assertThat(order.getId()).isEqualTo(123L);
assertThat(order.getAmount()).isEqualByComparingTo(new BigDecimal("99.99"));
}
@Test
void shouldIgnoreJsonIgnoreAnnotatedFields() throws Exception {
Order order = new Order(123L, new BigDecimal("99.99"));
order.setInternalNote("Secret note");
JsonContent<Order> result = json.write(order);
assertThat(result.getJson()).doesNotContain("internalNote");
}
}
```
## Testing List Deserialization
### JSON Arrays
```java
@JsonTest
class UserListJsonTest {
@Autowired
private JacksonTester<List<UserDto>> json;
@Test
void shouldDeserializeUserList() throws Exception {
String jsonArray = "[{\"id\":1,\"name\":\"Alice\"},{\"id\":2,\"name\":\"Bob\"}]";
List<UserDto> users = json.parseObject(jsonArray);
assertThat(users)
.hasSize(2)
.extracting(UserDto::getName)
.containsExactly("Alice", "Bob");
}
@Test
void shouldSerializeUserListToJson() throws Exception {
List<UserDto> users = List.of(
new UserDto(1L, "Alice"),
new UserDto(2L, "Bob")
);
JsonContent<List<UserDto>> result = json.write(users);
result.json.contains("Alice").contains("Bob");
}
}
```
## Testing Nested Objects
### Complex JSON Structures
```java
public class Product {
private Long id;
private String name;
private Category category;
private List<Review> reviews;
}
public class Category {
private Long id;
private String name;
}
public class Review {
private String reviewer;
private int rating;
private String comment;
}
@JsonTest
class ProductJsonTest {
@Autowired
private JacksonTester<Product> json;
@Test
void shouldSerializeNestedObjects() throws Exception {
Category category = new Category(1L, "Electronics");
Product product = new Product(1L, "Laptop", category);
JsonContent<Product> result = json.write(product);
assertThat(result).extractingJsonPathNumberValue("$.id").isEqualTo(1);
assertThat(result).extractingJsonPathStringValue("$.name").isEqualTo("Laptop");
assertThat(result).extractingJsonPathNumberValue("$.category.id").isEqualTo(1);
assertThat(result).extractingJsonPathStringValue("$.category.name").isEqualTo("Electronics");
}
@Test
void shouldDeserializeNestedObjects() throws Exception {
String jsonContent = "{\"id\":1,\"name\":\"Laptop\",\"category\":{\"id\":1,\"name\":\"Electronics\"}}";
Product product = json.parse(jsonContent).getObject();
assertThat(product.getCategory())
.isNotNull()
.hasFieldOrPropertyWithValue("name", "Electronics");
}
@Test
void shouldHandleListOfNestedObjects() throws Exception {
String jsonContent = "{\"id\":1,\"name\":\"Laptop\",\"reviews\":[{\"reviewer\":\"John\",\"rating\":5},{\"reviewer\":\"Jane\",\"rating\":4}]}";
Product product = json.parse(jsonContent).getObject();
assertThat(product.getReviews())
.hasSize(2)
.extracting(Review::getRating)
.containsExactly(5, 4);
}
}
```
## Testing Date/Time Formatting
### LocalDateTime and Other Temporal Types
```java
@JsonTest
class DateTimeJsonTest {
@Autowired
private JacksonTester<Event> json;
@Test
void shouldFormatDateTimeCorrectly() throws Exception {
LocalDateTime dateTime = LocalDateTime.of(2024, 1, 15, 10, 30, 0);
Event event = new Event("Conference", dateTime);
JsonContent<Event> result = json.write(event);
result.extractingJsonPathStringValue("$.scheduledAt").isEqualTo("2024-01-15T10:30:00");
}
@Test
void shouldDeserializeDateTimeFromJson() throws Exception {
String jsonContent = "{\"name\":\"Conference\",\"scheduledAt\":\"2024-01-15T10:30:00\"}";
Event event = json.parse(jsonContent).getObject();
assertThat(event.getScheduledAt())
.isEqualTo(LocalDateTime.of(2024, 1, 15, 10, 30, 0));
}
}
```
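Spring Boot's default Jackson setup writes `LocalDateTime` as the ISO-8601 string asserted above. If a different pattern is required, pinning it with `@JsonFormat` keeps the test deterministic; a sketch of the `Event` class assumed above:
```java
public class Event {
    private String name;
    // Pins the JSON representation regardless of global Jackson settings
    @JsonFormat(pattern = "yyyy-MM-dd'T'HH:mm:ss")
    private LocalDateTime scheduledAt;
}
```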
## Testing Custom Serializers
### Custom JsonSerializer Implementation
```java
public class CustomMoneySerializer extends JsonSerializer<BigDecimal> {
@Override
public void serialize(BigDecimal value, JsonGenerator gen, SerializerProvider serializers) throws IOException {
if (value == null) {
gen.writeNull();
} else {
gen.writeString(String.format("$%.2f", value));
}
}
}
public class Price {
@JsonSerialize(using = CustomMoneySerializer.class)
private BigDecimal amount;
}
@JsonTest
class CustomSerializerTest {
@Autowired
private JacksonTester<Price> json;
@Test
void shouldUseCustomSerializer() throws Exception {
Price price = new Price(new BigDecimal("99.99"));
JsonContent<Price> result = json.write(price);
result.extractingJsonPathStringValue("$.amount").isEqualTo("$99.99");
}
}
```
## Testing Polymorphic Deserialization
### Type Information in JSON
```java
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
@JsonSubTypes({
@JsonSubTypes.Type(value = CreditCard.class, name = "credit_card"),
@JsonSubTypes.Type(value = PayPal.class, name = "paypal")
})
public abstract class PaymentMethod {
private String id;
}
@JsonTest
class PolymorphicJsonTest {
@Autowired
private JacksonTester<PaymentMethod> json;
@Test
void shouldDeserializeCreditCard() throws Exception {
String jsonContent = "{\"type\":\"credit_card\",\"id\":\"card123\",\"cardNumber\":\"****1234\"}";
PaymentMethod method = json.parse(jsonContent).getObject();
assertThat(method).isInstanceOf(CreditCard.class);
}
@Test
void shouldDeserializePayPal() throws Exception {
String jsonContent = "{\"type\":\"paypal\",\"id\":\"pp123\",\"email\":\"user@paypal.com\"}";
PaymentMethod method = json.parse(jsonContent).getObject();
assertThat(method).isInstanceOf(PayPal.class);
}
}
```
## Best Practices
- **Use @JsonTest** for focused JSON testing
- **Test both serialization and deserialization**
- **Test null handling** and missing fields
- **Test nested and complex structures**
- **Verify field name mapping** with @JsonProperty
- **Test date/time formatting** thoroughly
- **Test edge cases** (empty strings, empty collections)
## Common Pitfalls
- Not testing null values
- Not testing nested objects
- Forgetting to test field name mappings
- Not verifying JSON property presence/absence
- Not testing deserialization of invalid JSON
## Troubleshooting
**JacksonTester not available**: Ensure class is annotated with `@JsonTest`.
**Field name doesn't match**: Check @JsonProperty annotation and Jackson configuration.
**DateTime parsing fails**: Verify date format matches Jackson's expected format.
## References
- [Spring @JsonTest Documentation](https://docs.spring.io/spring-boot/docs/current/api/org/springframework/boot/test/autoconfigure/json/JsonTest.html)
- [Jackson ObjectMapper](https://fasterxml.github.io/jackson-databind/javadoc/2.15/com/fasterxml/jackson/databind/ObjectMapper.html)
- [JSON Annotations](https://fasterxml.github.io/jackson-annotations/javadoc/2.15/)

View File

@@ -0,0 +1,434 @@
---
name: unit-test-mapper-converter
description: Unit tests for mappers and converters (MapStruct, custom mappers). Test object transformation logic in isolation. Use when ensuring correct data transformation between DTOs and domain objects.
category: testing
tags: [junit-5, mapstruct, mapper, dto, entity, converter]
version: 1.0.1
---
# Unit Testing Mappers and Converters
Test MapStruct mappers and custom converter classes. Verify field mapping accuracy, null handling, type conversions, and nested object transformations.
## When to Use This Skill
Use this skill when:
- Testing MapStruct mapper implementations
- Testing custom entity-to-DTO converters
- Testing nested object mapping
- Verifying null handling in mappers
- Testing type conversions and transformations
- Want comprehensive mapping test coverage before integration tests
## Setup: Testing Mappers
### Maven
```xml
<dependency>
<groupId>org.mapstruct</groupId>
<artifactId>mapstruct</artifactId>
<version>1.5.5.Final</version>
</dependency>
<dependency>
<groupId>org.mapstruct</groupId>
<artifactId>mapstruct-processor</artifactId>
<version>1.5.5.Final</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
```
### Gradle
```kotlin
dependencies {
implementation("org.mapstruct:mapstruct:1.5.5.Final")
testImplementation("org.junit.jupiter:junit-jupiter")
testImplementation("org.assertj:assertj-core")
}
```
## Basic Pattern: Testing MapStruct Mapper
### Simple Entity to DTO Mapping
```java
// Mapper interface
@Mapper(componentModel = "spring")
public interface UserMapper {
UserDto toDto(User user);
User toEntity(UserDto dto);
List<UserDto> toDtos(List<User> users);
}
// Unit test
import org.junit.jupiter.api.Test;
import org.mapstruct.factory.Mappers;
import static org.assertj.core.api.Assertions.*;
class UserMapperTest {
private final UserMapper userMapper = Mappers.getMapper(UserMapper.class);
@Test
void shouldMapUserToDto() {
User user = new User(1L, "Alice", "alice@example.com", 25);
UserDto dto = userMapper.toDto(user);
assertThat(dto)
.isNotNull()
.extracting("id", "name", "email", "age")
.containsExactly(1L, "Alice", "alice@example.com", 25);
}
@Test
void shouldMapDtoToEntity() {
UserDto dto = new UserDto(1L, "Alice", "alice@example.com", 25);
User user = userMapper.toEntity(dto);
assertThat(user)
.isNotNull()
.hasFieldOrPropertyWithValue("id", 1L)
.hasFieldOrPropertyWithValue("name", "Alice");
}
@Test
void shouldMapListOfUsers() {
List<User> users = List.of(
new User(1L, "Alice", "alice@example.com", 25),
new User(2L, "Bob", "bob@example.com", 30)
);
List<UserDto> dtos = userMapper.toDtos(users);
assertThat(dtos)
.hasSize(2)
.extracting(UserDto::getName)
.containsExactly("Alice", "Bob");
}
@Test
void shouldHandleNullEntity() {
UserDto dto = userMapper.toDto(null);
assertThat(dto).isNull();
}
}
```
## Testing Nested Object Mapping
### Map Complex Hierarchies
```java
// Entities with nesting
class User {
private Long id;
private String name;
private Address address;
private List<Phone> phones;
}
// Mapper with nested mapping
@Mapper(componentModel = "spring")
public interface UserMapper {
UserDto toDto(User user);
User toEntity(UserDto dto);
}
// Unit test for nested objects
class NestedObjectMapperTest {
private final UserMapper userMapper = Mappers.getMapper(UserMapper.class);
@Test
void shouldMapNestedAddress() {
Address address = new Address("123 Main St", "New York", "NY", "10001");
User user = new User(1L, "Alice", address);
UserDto dto = userMapper.toDto(user);
assertThat(dto.getAddress())
.isNotNull()
.hasFieldOrPropertyWithValue("street", "123 Main St")
.hasFieldOrPropertyWithValue("city", "New York");
}
@Test
void shouldMapListOfNestedPhones() {
List<Phone> phones = List.of(
new Phone("123-456-7890", "MOBILE"),
new Phone("987-654-3210", "HOME")
);
User user = new User(1L, "Alice", null, phones);
UserDto dto = userMapper.toDto(user);
assertThat(dto.getPhones())
.hasSize(2)
.extracting(PhoneDto::getNumber)
.containsExactly("123-456-7890", "987-654-3210");
}
@Test
void shouldHandleNullNestedObjects() {
User user = new User(1L, "Alice", null);
UserDto dto = userMapper.toDto(user);
assertThat(dto.getAddress()).isNull();
}
}
```
## Testing Custom Mapping Methods
### Mapper with @Mapping Annotations
```java
@Mapper(componentModel = "spring")
public interface ProductMapper {
@Mapping(source = "name", target = "productName")
@Mapping(source = "price", target = "salePrice")
@Mapping(target = "discount", expression = "java(product.getPrice() * 0.1)")
ProductDto toDto(Product product);
@Mapping(source = "productName", target = "name")
@Mapping(source = "salePrice", target = "price")
Product toEntity(ProductDto dto);
}
class CustomMappingTest {
private final ProductMapper mapper = Mappers.getMapper(ProductMapper.class);
@Test
void shouldMapFieldsWithCustomNames() {
Product product = new Product(1L, "Laptop", 999.99);
ProductDto dto = mapper.toDto(product);
assertThat(dto)
.hasFieldOrPropertyWithValue("productName", "Laptop")
.hasFieldOrPropertyWithValue("salePrice", 999.99);
}
@Test
void shouldCalculateDiscountFromExpression() {
Product product = new Product(1L, "Laptop", 100.0);
ProductDto dto = mapper.toDto(product);
assertThat(dto.getDiscount()).isEqualTo(10.0);
}
@Test
void shouldReverseMapCustomFields() {
ProductDto dto = new ProductDto(1L, "Laptop", 999.99);
Product product = mapper.toEntity(dto);
assertThat(product)
.hasFieldOrPropertyWithValue("name", "Laptop")
.hasFieldOrPropertyWithValue("price", 999.99);
}
}
```
## Testing Enum Mapping
### Map Enums Between Entity and DTO
```java
// Enum with different representation
enum UserStatus { ACTIVE, INACTIVE, SUSPENDED }
enum UserStatusDto { ENABLED, DISABLED, LOCKED }
@Mapper(componentModel = "spring")
public interface UserMapper {
@ValueMapping(source = "ACTIVE", target = "ENABLED")
@ValueMapping(source = "INACTIVE", target = "DISABLED")
@ValueMapping(source = "SUSPENDED", target = "LOCKED")
UserStatusDto toStatusDto(UserStatus status);
}
class EnumMapperTest {
private final UserMapper mapper = Mappers.getMapper(UserMapper.class);
@Test
void shouldMapActiveToEnabled() {
UserStatusDto dto = mapper.toStatusDto(UserStatus.ACTIVE);
assertThat(dto).isEqualTo(UserStatusDto.ENABLED);
}
@Test
void shouldMapSuspendedToLocked() {
UserStatusDto dto = mapper.toStatusDto(UserStatus.SUSPENDED);
assertThat(dto).isEqualTo(UserStatusDto.LOCKED);
}
}
```
## Testing Custom Type Conversions
### Non-MapStruct Custom Converter
```java
// Custom converter class
public class DateFormatter {
private static final DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd");
public static String format(LocalDate date) {
return date != null ? date.format(formatter) : null;
}
public static LocalDate parse(String dateString) {
return dateString != null ? LocalDate.parse(dateString, formatter) : null;
}
}
// Unit test
class DateFormatterTest {
@Test
void shouldFormatLocalDateToString() {
LocalDate date = LocalDate.of(2024, 1, 15);
String result = DateFormatter.format(date);
assertThat(result).isEqualTo("2024-01-15");
}
@Test
void shouldParseStringToLocalDate() {
String dateString = "2024-01-15";
LocalDate result = DateFormatter.parse(dateString);
assertThat(result).isEqualTo(LocalDate.of(2024, 1, 15));
}
@Test
void shouldHandleNullInFormat() {
String result = DateFormatter.format(null);
assertThat(result).isNull();
}
@Test
void shouldHandleInvalidDateFormat() {
assertThatThrownBy(() -> DateFormatter.parse("invalid-date"))
.isInstanceOf(DateTimeParseException.class);
}
}
```
## Testing Bidirectional Mapping
### Entity ↔ DTO Round Trip
```java
class BidirectionalMapperTest {
private final UserMapper mapper = Mappers.getMapper(UserMapper.class);
@Test
void shouldMaintainDataInRoundTrip() {
User original = new User(1L, "Alice", "alice@example.com", 25);
UserDto dto = mapper.toDto(original);
User restored = mapper.toEntity(dto);
assertThat(restored)
.hasFieldOrPropertyWithValue("id", original.getId())
.hasFieldOrPropertyWithValue("name", original.getName())
.hasFieldOrPropertyWithValue("email", original.getEmail())
.hasFieldOrPropertyWithValue("age", original.getAge());
}
@Test
void shouldPreserveAllFieldsInBothDirections() {
Address address = new Address("123 Main", "NYC", "NY", "10001");
User user = new User(1L, "Alice", "alice@example.com", 25, address);
UserDto dto = mapper.toDto(user);
User restored = mapper.toEntity(dto);
assertThat(restored).usingRecursiveComparison().isEqualTo(user);
}
}
```
## Testing Partial Mapping
### Update Existing Entity from DTO
```java
@Mapper(componentModel = "spring")
public interface UserMapper {
void updateEntity(@MappingTarget User entity, UserDto dto);
}
class PartialMapperTest {
private final UserMapper mapper = Mappers.getMapper(UserMapper.class);
@Test
void shouldUpdateExistingEntity() {
User existing = new User(1L, "Alice", "alice@old.com", 25);
UserDto dto = new UserDto(1L, "Alice", "alice@new.com", 26);
mapper.updateEntity(existing, dto);
assertThat(existing)
.hasFieldOrPropertyWithValue("email", "alice@new.com")
.hasFieldOrPropertyWithValue("age", 26);
}
@Test
void shouldNotUpdateFieldsNotInDto() {
User existing = new User(1L, "Alice", "alice@example.com", 25);
UserDto dto = new UserDto(1L, "Bob", null, 0);
mapper.updateEntity(existing, dto);
// Null DTO properties are skipped because of NullValuePropertyMappingStrategy.IGNORE above
assertThat(existing.getEmail()).isEqualTo("alice@example.com");
}
}
```
## Best Practices
- **Test all mapper methods** comprehensively
- **Verify null handling** for every nullable field
- **Test nested objects** independently and together
- **Use recursive comparison** for complex nested structures
- **Test bidirectional mapping** to catch asymmetries
- **Keep mapper tests simple and focused** on transformation correctness
- **Use Mappers.getMapper()** for non-Spring standalone tests
## Common Pitfalls
- Not testing null input cases
- Not verifying nested object mappings
- Assuming bidirectional mapping is symmetric
- Not testing edge cases (empty collections, etc.)
- Tight coupling of mapper tests to MapStruct internals
## Troubleshooting
**Null pointer exceptions during mapping**: Check `nullValuePropertyMappingStrategy` and `nullValueCheckStrategy` in `@Mapper`.
**Enum mapping not working**: Verify `@ValueMapping` annotations correctly map source to target values.
**Nested mapping produces null**: Ensure nested mapper interfaces are also mapped in parent mapper.
## References
- [MapStruct Official Documentation](https://mapstruct.org/)
- [MapStruct Mapping Strategies](https://mapstruct.org/documentation/stable/reference/html/)
- [JUnit 5 Best Practices](https://junit.org/junit5/docs/current/user-guide/)

View File

@@ -0,0 +1,374 @@
---
name: unit-test-parameterized
description: Parameterized testing patterns with @ParameterizedTest, @ValueSource, @CsvSource. Run single test method with multiple input combinations. Use when testing multiple scenarios with similar logic.
category: testing
tags: [junit-5, parameterized-test, value-source, csv-source, method-source]
version: 1.0.1
---
# Parameterized Unit Tests with JUnit 5
Write efficient parameterized unit tests that run the same test logic with multiple input values. Reduce test duplication and improve test coverage using @ParameterizedTest.
## When to Use This Skill
Use this skill when:
- Testing methods with multiple valid inputs
- Testing boundary values systematically
- Testing multiple invalid inputs for error cases
- Want to reduce test duplication
- Testing multiple scenarios with similar assertions
- Need data-driven testing approach
## Setup: Parameterized Testing
### Maven
```xml
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
```
### Gradle
```kotlin
dependencies {
testImplementation("org.junit.jupiter:junit-jupiter")
testImplementation("org.assertj:assertj-core")
}
```
## Basic Pattern: @ValueSource
### Simple Value Testing
```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import static org.assertj.core.api.Assertions.*;
class StringUtilsTest {
@ParameterizedTest
@ValueSource(strings = {"hello", "world", "test"})
void shouldCapitalizeAllStrings(String input) {
String result = StringUtils.capitalize(input);
assertThat(result).startsWith(input.substring(0, 1).toUpperCase());
}
@ParameterizedTest
@ValueSource(ints = {1, 2, 3, 4, 5})
void shouldBePositive(int number) {
assertThat(number).isPositive();
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void shouldHandleBothBooleanValues(boolean value) {
assertThat(value).isNotNull();
}
}
```
## @MethodSource for Complex Data
### Factory Method Data Source
```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.stream.Stream;
class CalculatorTest {
    static Stream<Arguments> additionTestCases() {
return Stream.of(
Arguments.of(1, 2, 3),
Arguments.of(0, 0, 0),
Arguments.of(-1, 1, 0),
Arguments.of(100, 200, 300),
Arguments.of(-5, -10, -15)
);
}
@ParameterizedTest
@MethodSource("additionTestCases")
void shouldAddNumbersCorrectly(int a, int b, int expected) {
int result = Calculator.add(a, b);
assertThat(result).isEqualTo(expected);
}
}
```
## @CsvSource for Tabular Data
### CSV-Based Test Data
```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
class UserValidationTest {
@ParameterizedTest
@CsvSource({
"alice@example.com, true",
"bob@gmail.com, true",
"invalid-email, false",
"user@, false",
"@example.com, false",
"user name@example.com, false"
})
void shouldValidateEmailAddresses(String email, boolean expected) {
boolean result = UserValidator.isValidEmail(email);
assertThat(result).isEqualTo(expected);
}
@ParameterizedTest
@CsvSource({
"123-456-7890, true",
"555-123-4567, true",
"1234567890, false",
"123-45-6789, false",
"abc-def-ghij, false"
})
void shouldValidatePhoneNumbers(String phone, boolean expected) {
boolean result = PhoneValidator.isValid(phone);
assertThat(result).isEqualTo(expected);
}
}
```
## @CsvFileSource for External Data
### CSV File-Based Testing
```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvFileSource;
class PriceCalculationTest {
@ParameterizedTest
@CsvFileSource(resources = "/test-data/prices.csv", numLinesToSkip = 1)
void shouldCalculateTotalPrice(String product, double price, int quantity, double expected) {
double total = PriceCalculator.calculateTotal(price, quantity);
assertThat(total).isEqualTo(expected);
}
}
// test-data/prices.csv:
// product,price,quantity,expected
// Laptop,999.99,1,999.99
// Mouse,29.99,3,89.97
// Keyboard,79.99,2,159.98
```
## @EnumSource for Enum Testing
### Enum-Based Test Data
```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
enum Status { ACTIVE, INACTIVE, PENDING, DELETED }
class StatusHandlerTest {
@ParameterizedTest
@EnumSource(Status.class)
void shouldHandleAllStatuses(Status status) {
assertThat(status).isNotNull();
}
@ParameterizedTest
@EnumSource(value = Status.class, names = {"ACTIVE", "INACTIVE"})
void shouldHandleSpecificStatuses(Status status) {
assertThat(status).isIn(Status.ACTIVE, Status.INACTIVE);
}
@ParameterizedTest
@EnumSource(value = Status.class, mode = EnumSource.Mode.EXCLUDE, names = {"DELETED"})
void shouldHandleStatusesExcludingDeleted(Status status) {
assertThat(status).isNotEqualTo(Status.DELETED);
}
}
```
## Custom Display Names
### Readable Test Output
```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
class DiscountCalculationTest {
@ParameterizedTest(name = "Discount of {0}% should be calculated correctly")
@ValueSource(ints = {5, 10, 15, 20})
void shouldApplyDiscount(int discountPercent) {
double originalPrice = 100.0;
double discounted = DiscountCalculator.apply(originalPrice, discountPercent);
double expected = originalPrice * (1 - discountPercent / 100.0);
assertThat(discounted).isEqualTo(expected);
}
@ParameterizedTest(name = "User role {0} should have {1} permissions")
@CsvSource({
"ADMIN, 100",
"MANAGER, 50",
"USER, 10"
})
void shouldHaveCorrectPermissions(String role, int expectedPermissions) {
User user = new User(role);
assertThat(user.getPermissionCount()).isEqualTo(expectedPermissions);
}
}
```
## Combining Multiple Sources
### ArgumentsProvider for Complex Scenarios
```java
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;
import java.util.stream.Stream;
class RangeValidatorArgumentProvider implements ArgumentsProvider {
@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
return Stream.of(
Arguments.of(0, 0, 100, true), // Min boundary
Arguments.of(100, 0, 100, true), // Max boundary
Arguments.of(50, 0, 100, true), // Middle value
Arguments.of(-1, 0, 100, false), // Below range
Arguments.of(101, 0, 100, false) // Above range
);
}
}
class RangeValidatorTest {
@ParameterizedTest
@ArgumentsSource(RangeValidatorArgumentProvider.class)
void shouldValidateRangeCorrectly(int value, int min, int max, boolean expected) {
boolean result = RangeValidator.isInRange(value, min, max);
assertThat(result).isEqualTo(expected);
}
}
```
## Testing Edge Cases with Parameters
### Boundary Value Analysis
```java
class BoundaryValueTest {
@ParameterizedTest
@ValueSource(ints = {
Integer.MIN_VALUE, // Absolute minimum
Integer.MIN_VALUE + 1, // Just above minimum
-1, // Negative boundary
0, // Zero boundary
1, // Just above zero
Integer.MAX_VALUE - 1, // Just below maximum
Integer.MAX_VALUE // Absolute maximum
})
void shouldHandleAllBoundaryValues(int value) {
int incremented = MathUtils.increment(value);
assertThat(incremented).isNotLessThan(value);
}
@ParameterizedTest
@CsvSource({
", false", // null
"'', false", // empty
"' ', false", // whitespace only
"a, true", // single character
"abc, true" // normal
})
void shouldValidateStrings(String input, boolean expected) {
boolean result = StringValidator.isValid(input);
assertThat(result).isEqualTo(expected);
}
}
```
## Repeat Tests
### Run Same Test Multiple Times
```java
import org.junit.jupiter.api.RepeatedTest;
import java.util.concurrent.atomic.AtomicInteger;
class ConcurrencyTest {
    @RepeatedTest(100)
    void shouldHandleConcurrentAccess() {
        // Repeated runs raise the odds of surfacing flaky, timing-dependent behavior
        AtomicInteger counter = new AtomicInteger(0);
counter.incrementAndGet();
assertThat(counter.get()).isEqualTo(1);
}
}
```
## Best Practices
- **Use @ParameterizedTest** to reduce test duplication
- **Use descriptive display names** with `(name = "...")`
- **Test boundary values** systematically
- **Keep test logic simple** - focus on single assertion
- **Organize test data logically** - group similar scenarios
- **Use @MethodSource** for complex test data
- **Use @CsvSource** for tabular test data
- **Document expected behavior** in test names
## Common Patterns
**Testing error conditions**:
```java
@ParameterizedTest
@NullSource
@ValueSource(strings = {"", " "}) // @ValueSource cannot hold null; @NullSource supplies the null case
void shouldThrowExceptionForInvalidInput(String input) {
    assertThatThrownBy(() -> Parser.parse(input))
        .isInstanceOf(IllegalArgumentException.class);
}
```
**Testing multiple valid inputs**:
```java
@ParameterizedTest
@ValueSource(ints = {1, 2, 3, 5, 8, 13})
void shouldBeInFibonacciSequence(int number) {
assertThat(FibonacciChecker.isFibonacci(number)).isTrue();
}
```
## Troubleshooting
**Parameter not matching**: Verify number and type of parameters match test method signature.
**Display name not showing**: Check parameter syntax in `name = "..."`.
**CSV parsing error**: Ensure CSV format is correct and quote strings containing commas.
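For instance, single-quote any CSV value that itself contains a comma (a short sketch; `NameValidator` is an illustrative class):
```java
@ParameterizedTest
@CsvSource({
    "'Doe, Jane', true",  // single quotes keep the embedded comma out of column splitting
    "Alice, true"
})
void shouldAcceptWellFormedNames(String name, boolean expected) {
    assertThat(NameValidator.isValid(name)).isEqualTo(expected);
}
```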
## References
- [JUnit 5 Parameterized Tests](https://junit.org/junit5/docs/current/user-guide/#writing-tests-parameterized-tests)
- [@ParameterizedTest Documentation](https://junit.org/junit5/docs/current/api/org.junit.jupiter.params/org/junit/jupiter/params/ParameterizedTest.html)
- [Boundary Value Analysis](https://en.wikipedia.org/wiki/Boundary-value_analysis)

View File

@@ -0,0 +1,434 @@
---
name: unit-test-scheduled-async
description: Unit tests for scheduled and async tasks using @Scheduled and @Async. Mock task execution and timing. Use when validating asynchronous operations and scheduling behavior.
category: testing
tags: [junit-5, scheduled, async, concurrency, completablefuture]
version: 1.0.1
---
# Unit Testing @Scheduled and @Async Methods
Test scheduled tasks and async methods using JUnit 5 without running the actual scheduler. Verify execution logic, timing, and asynchronous behavior.
## When to Use This Skill
Use this skill when:
- Testing @Scheduled method logic
- Testing @Async method behavior
- Verifying CompletableFuture results
- Testing async error handling
- Want fast tests without actual scheduling
- Testing background task logic in isolation
## Setup: Async/Scheduled Testing
### Maven
```xml
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
```
### Gradle
```kotlin
dependencies {
implementation("org.springframework.boot:spring-boot-starter")
testImplementation("org.junit.jupiter:junit-jupiter")
testImplementation("org.awaitility:awaitility")
testImplementation("org.assertj:assertj-core")
}
```
## Testing @Async Methods
### Basic Async Testing with CompletableFuture
```java
// Service with async methods
@Service
public class EmailService {
@Async
public CompletableFuture<Boolean> sendEmailAsync(String to, String subject) {
return CompletableFuture.supplyAsync(() -> {
// Simulate email sending
System.out.println("Sending email to " + to);
return true;
});
}
@Async
public void notifyUser(String userId) {
System.out.println("Notifying user: " + userId);
}
}
// Unit test
import org.junit.jupiter.api.Test;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import static org.assertj.core.api.Assertions.*;
class EmailServiceAsyncTest {
@Test
void shouldReturnCompletedFutureWhenSendingEmail() throws Exception {
EmailService service = new EmailService();
CompletableFuture<Boolean> result = service.sendEmailAsync("test@example.com", "Hello");
Boolean success = result.get(); // Wait for completion
assertThat(success).isTrue();
}
@Test
void shouldCompleteWithinTimeout() {
    EmailService service = new EmailService();
    CompletableFuture<Boolean> result = service.sendEmailAsync("test@example.com", "Hello");
    // isCompletedWithValue() would race the async computation; succeedsWithin() waits up to the timeout
    assertThat(result)
        .succeedsWithin(Duration.ofSeconds(1))
        .isEqualTo(true);
}
}
```
## Testing Async with Mocked Dependencies
### Async Service with Dependencies
```java
@Service
public class UserNotificationService {
private final EmailService emailService;
private final SmsService smsService;
public UserNotificationService(EmailService emailService, SmsService smsService) {
this.emailService = emailService;
this.smsService = smsService;
}
@Async
public CompletableFuture<String> notifyUserAsync(String userId) {
return CompletableFuture.supplyAsync(() -> {
emailService.send(userId);
smsService.send(userId);
return "Notification sent";
});
}
}
// Unit test
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class UserNotificationServiceAsyncTest {
@Mock
private EmailService emailService;
@Mock
private SmsService smsService;
@InjectMocks
private UserNotificationService notificationService;
@Test
void shouldNotifyUserAsynchronously() throws Exception {
CompletableFuture<String> result = notificationService.notifyUserAsync("user123");
String message = result.get();
assertThat(message).isEqualTo("Notification sent");
verify(emailService).send("user123");
verify(smsService).send("user123");
}
@Test
void shouldHandleAsyncExceptionGracefully() {
doThrow(new RuntimeException("Email service failed"))
.when(emailService).send(any());
CompletableFuture<String> result = notificationService.notifyUserAsync("user123");
assertThatThrownBy(result::get)
.isInstanceOf(ExecutionException.class)
.hasCauseInstanceOf(RuntimeException.class);
}
}
```
## Testing @Scheduled Methods
### Mock Task Execution
```java
// Scheduled task
@Component
public class DataRefreshTask {
private final DataRepository dataRepository;
public DataRefreshTask(DataRepository dataRepository) {
this.dataRepository = dataRepository;
}
@Scheduled(fixedDelay = 60000)
public void refreshCache() {
List<Data> data = dataRepository.findAll();
// Update cache
}
@Scheduled(cron = "0 0 * * * *") // Every hour
public void cleanupOldData() {
dataRepository.deleteOldData(LocalDateTime.now().minusDays(30));
}
}
// Unit test - test logic without actual scheduling
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class DataRefreshTaskTest {
@Mock
private DataRepository dataRepository;
@InjectMocks
private DataRefreshTask dataRefreshTask;
@Test
void shouldRefreshCacheFromRepository() {
List<Data> expectedData = List.of(new Data(1L, "item1"));
when(dataRepository.findAll()).thenReturn(expectedData);
dataRefreshTask.refreshCache(); // Call method directly
verify(dataRepository).findAll();
}
@Test
void shouldCleanupOldData() {
dataRefreshTask.cleanupOldData();
verify(dataRepository).deleteOldData(any(LocalDateTime.class));
}
}
```
## Testing Async with Awaitility
### Wait for Async Completion
```java
import org.awaitility.Awaitility;
import org.awaitility.core.ConditionTimeoutException;
import java.time.Duration;
import java.util.concurrent.atomic.AtomicInteger;
@Service
public class BackgroundWorker {
private final AtomicInteger processedCount = new AtomicInteger(0);
@Async
public void processItems(List<String> items) {
items.forEach(item -> {
// Process item
processedCount.incrementAndGet();
});
}
public int getProcessedCount() {
return processedCount.get();
}
}
class AwaitilityAsyncTest {
@Test
void shouldProcessAllItemsAsynchronously() {
BackgroundWorker worker = new BackgroundWorker();
List<String> items = List.of("item1", "item2", "item3");
worker.processItems(items);
// Wait for async operation to complete (up to 5 seconds)
Awaitility.await()
.atMost(Duration.ofSeconds(5))
.pollInterval(Duration.ofMillis(100))
.untilAsserted(() -> {
assertThat(worker.getProcessedCount()).isEqualTo(3);
});
}
@Test
void shouldTimeoutWhenProcessingTakesTooLong() {
BackgroundWorker worker = new BackgroundWorker();
List<String> items = List.of("item1", "item2", "item3");
worker.processItems(items);
assertThatThrownBy(() ->
Awaitility.await()
.atMost(Duration.ofMillis(100))
.until(() -> worker.getProcessedCount() == 10)
).isInstanceOf(ConditionTimeoutException.class);
}
}
```
## Testing Async Error Handling
### Handle Exceptions in Async Methods
```java
@Service
public class DataProcessingService {
@Async
public CompletableFuture<Boolean> processDataAsync(String data) {
return CompletableFuture.supplyAsync(() -> {
if (data == null || data.isEmpty()) {
throw new IllegalArgumentException("Data cannot be empty");
}
// Process data
return true;
});
}
@Async
public CompletableFuture<String> safeFetchData(String id) {
return CompletableFuture.supplyAsync(() -> {
try {
return fetchData(id);
} catch (Exception e) {
return "Error: " + e.getMessage();
}
});
}
}
class AsyncErrorHandlingTest {
@Test
void shouldPropagateExceptionFromAsyncMethod() {
DataProcessingService service = new DataProcessingService();
CompletableFuture<Boolean> result = service.processDataAsync(null);
assertThatThrownBy(result::get)
.isInstanceOf(ExecutionException.class)
.hasCauseInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Data cannot be empty");
}
@Test
void shouldHandleExceptionGracefullyWithFallback() throws Exception {
DataProcessingService service = new DataProcessingService();
CompletableFuture<String> result = service.safeFetchData("invalid");
String message = result.get();
assertThat(message).startsWith("Error:");
}
}
```
## Testing Scheduled Task Timing
### Test Schedule Configuration
```java
@Component
public class HealthCheckTask {
private final HealthCheckService healthCheckService;
private int executionCount = 0;
public HealthCheckTask(HealthCheckService healthCheckService) {
this.healthCheckService = healthCheckService;
}
@Scheduled(fixedRate = 5000) // Every 5 seconds
public void checkHealth() {
executionCount++;
healthCheckService.check();
}
public int getExecutionCount() {
return executionCount;
}
}
class ScheduledTaskTimingTest {
@Test
void shouldExecuteTaskMultipleTimes() {
HealthCheckService mockService = mock(HealthCheckService.class);
HealthCheckTask task = new HealthCheckTask(mockService);
// Execute manually multiple times
task.checkHealth();
task.checkHealth();
task.checkHealth();
assertThat(task.getExecutionCount()).isEqualTo(3);
verify(mockService, times(3)).check();
}
}
```
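To sanity-check the schedule configuration itself rather than the task logic, one option is to read the `@Scheduled` annotation reflectively and validate its cron expression — a sketch assuming Spring's `CronExpression` (Spring Framework 5.3+):
```java
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.scheduling.support.CronExpression;

class ScheduleConfigurationTest {
    @Test
    void cleanupCronExpressionShouldBeValid() throws Exception {
        Scheduled scheduled = DataRefreshTask.class
            .getMethod("cleanupOldData")
            .getAnnotation(Scheduled.class);
        // CronExpression.parse() throws IllegalArgumentException on malformed expressions
        assertThatCode(() -> CronExpression.parse(scheduled.cron()))
            .doesNotThrowAnyException();
    }
}
```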
## Best Practices
- **Test async method logic directly** without Spring async executor
- **Use CompletableFuture.get()** to wait for results in tests
- **Mock dependencies** that async methods use
- **Test error paths** for async operations
- **Use Awaitility** when testing actual async behavior is needed
- **Mock scheduled tasks** by calling methods directly in tests
- **Verify task execution count** for testing scheduling logic
## Common Pitfalls
- Testing with actual @Async executor (use direct method calls instead)
- Not waiting for CompletableFuture completion in tests
- Forgetting to test exception handling in async methods
- Not mocking dependencies that async methods call
- Trying to test actual scheduling timing (test logic instead)
## Troubleshooting
**CompletableFuture hangs in test**: Ensure methods complete or set a timeout with `.get(timeout, unit)` (see the sketch below).
**Async method not executing**: Call method directly instead of relying on @Async in tests.
**Awaitility timeout**: Increase timeout duration or reduce polling interval.
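For example, bound the wait instead of calling a bare `get()` — a short sketch reusing the `EmailService` from above:
```java
import java.util.concurrent.TimeUnit;

@Test
void shouldFailFastInsteadOfHanging() throws Exception {
    EmailService service = new EmailService();
    CompletableFuture<Boolean> result = service.sendEmailAsync("test@example.com", "Hello");
    // Throws TimeoutException after 5 s instead of hanging the build indefinitely
    Boolean success = result.get(5, TimeUnit.SECONDS);
    assertThat(success).isTrue();
}
```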
## References
- [Spring @Async Documentation](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/scheduling/annotation/Async.html)
- [Spring @Scheduled Documentation](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/scheduling/annotation/Scheduled.html)
- [Awaitility Testing Library](https://github.com/awaitility/awaitility)
- [CompletableFuture API](https://docs.oracle.com/javase/8/docs/api/java/util/concurrent/CompletableFuture.html)

View File

@@ -0,0 +1,476 @@
---
name: unit-test-security-authorization
description: Unit tests for Spring Security with @PreAuthorize, @Secured, @RolesAllowed. Test role-based access control and authorization policies. Use when validating security configurations and access control logic.
category: testing
tags: [junit-5, spring-security, authorization, roles, preauthorize, mockmvc]
version: 1.0.1
---
# Unit Testing Security and Authorization
Test Spring Security authorization logic using @PreAuthorize, @Secured, and custom permission evaluators. Verify access control decisions without full security infrastructure.
## When to Use This Skill
Use this skill when:
- Testing @PreAuthorize and @Secured method-level security
- Testing role-based access control (RBAC)
- Testing custom permission evaluators
- Verifying access denied scenarios
- Testing authorization with authenticated principals
- Want fast authorization tests without full Spring Security context
## Setup: Security Testing
### Maven
```xml
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-security</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-test</artifactId>
<scope>test</scope>
</dependency>
```
### Gradle
```kotlin
dependencies {
implementation("org.springframework.boot:spring-boot-starter-security")
testImplementation("org.springframework.boot:spring-boot-starter-test")
testImplementation("org.springframework.security:spring-security-test")
}
```
## Basic Pattern: Testing @PreAuthorize
### Simple Role-Based Access Control
```java
// Service with security annotations
@Service
public class UserService {
@PreAuthorize("hasRole('ADMIN')")
public void deleteUser(Long userId) {
// delete logic
}
@PreAuthorize("hasRole('USER')")
public User getCurrentUser() {
// get user logic
}
@PreAuthorize("hasAnyRole('ADMIN', 'MANAGER')")
public List<User> listAllUsers() {
// list logic
}
}
// Unit test — note: @PreAuthorize is enforced through Spring proxies, so these examples
// assume the service comes from a test context with method security enabled (see the
// configuration sketch under Best Practices); a directly instantiated object is not secured
import org.junit.jupiter.api.Test;
import org.springframework.security.test.context.support.WithMockUser;
import static org.assertj.core.api.Assertions.*;
class UserServiceSecurityTest {
@Test
@WithMockUser(roles = "ADMIN")
void shouldAllowAdminToDeleteUser() {
UserService service = new UserService();
assertThatCode(() -> service.deleteUser(1L))
.doesNotThrowAnyException();
}
@Test
@WithMockUser(roles = "USER")
void shouldDenyUserFromDeletingUser() {
UserService service = new UserService();
assertThatThrownBy(() -> service.deleteUser(1L))
.isInstanceOf(AccessDeniedException.class);
}
@Test
@WithMockUser(roles = "ADMIN")
void shouldAllowAdminAndManagerToListUsers() {
UserService service = new UserService();
assertThatCode(() -> service.listAllUsers())
.doesNotThrowAnyException();
}
@Test
void shouldDenyAnonymousUserAccess() {
UserService service = new UserService();
assertThatThrownBy(() -> service.deleteUser(1L))
.isInstanceOf(AccessDeniedException.class);
}
}
```
## Testing @Secured Annotation
### Legacy Security Configuration
```java
@Service
public class OrderService {
@Secured("ROLE_ADMIN")
public Order approveOrder(Long orderId) {
// approval logic
}
@Secured({"ROLE_ADMIN", "ROLE_MANAGER"})
public List<Order> getOrders() {
// get orders
}
}
class OrderSecurityTest {
@Test
@WithMockUser(roles = "ADMIN")
void shouldAllowAdminToApproveOrder() {
OrderService service = new OrderService();
assertThatCode(() -> service.approveOrder(1L))
.doesNotThrowAnyException();
}
@Test
@WithMockUser(roles = "USER")
void shouldDenyUserFromApprovingOrder() {
OrderService service = new OrderService();
assertThatThrownBy(() -> service.approveOrder(1L))
.isInstanceOf(AccessDeniedException.class);
}
}
```
## Testing Controller Security with MockMvc
### Secure REST Endpoints
```java
@RestController
@RequestMapping("/api/admin")
public class AdminController {
@GetMapping("/users")
@PreAuthorize("hasRole('ADMIN')")
public List<UserDto> listAllUsers() {
// logic
}
@DeleteMapping("/users/{id}")
@PreAuthorize("hasRole('ADMIN')")
public void deleteUser(@PathVariable Long id) {
// delete logic
}
}
// Testing with MockMvc — springSecurity() needs the application context's filter chain,
// so use webAppContextSetup() rather than standaloneSetup()
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.web.context.WebApplicationContext;
import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*;
@SpringBootTest
class AdminControllerSecurityTest {
    @Autowired
    private WebApplicationContext context;
    private MockMvc mockMvc;
    @BeforeEach
    void setUp() {
        mockMvc = MockMvcBuilders
            .webAppContextSetup(context)
            .apply(springSecurity())
            .build();
    }
@Test
@WithMockUser(roles = "ADMIN")
void shouldAllowAdminToListUsers() throws Exception {
mockMvc.perform(get("/api/admin/users"))
.andExpect(status().isOk());
}
@Test
@WithMockUser(roles = "USER")
void shouldDenyUserFromListingUsers() throws Exception {
mockMvc.perform(get("/api/admin/users"))
.andExpect(status().isForbidden());
}
@Test
void shouldDenyAnonymousAccessToAdminEndpoint() throws Exception {
mockMvc.perform(get("/api/admin/users"))
.andExpect(status().isUnauthorized());
}
@Test
@WithMockUser(roles = "ADMIN")
void shouldAllowAdminToDeleteUser() throws Exception {
mockMvc.perform(delete("/api/admin/users/1"))
.andExpect(status().isOk());
}
}
```
## Testing Expression-Based Authorization
### Complex Permission Expressions
```java
@Service
public class DocumentService {
@PreAuthorize("hasRole('ADMIN') or authentication.principal.username == #owner")
public Document getDocument(String owner, Long docId) {
// get document
}
@PreAuthorize("hasPermission(#docId, 'Document', 'WRITE')")
public void updateDocument(Long docId, String content) {
// update logic
}
@PreAuthorize("#userId == authentication.principal.id")
public UserProfile getUserProfile(Long userId) {
// get profile
}
}
class ExpressionBasedSecurityTest {
@Test
@WithMockUser(username = "alice", roles = "ADMIN")
void shouldAllowAdminToAccessAnyDocument() {
DocumentService service = new DocumentService();
assertThatCode(() -> service.getDocument("bob", 1L))
.doesNotThrowAnyException();
}
@Test
@WithMockUser(username = "alice")
void shouldAllowOwnerToAccessOwnDocument() {
DocumentService service = new DocumentService();
assertThatCode(() -> service.getDocument("alice", 1L))
.doesNotThrowAnyException();
}
@Test
@WithMockUser(username = "alice")
void shouldDenyUserAccessToOtherUserDocument() {
DocumentService service = new DocumentService();
assertThatThrownBy(() -> service.getDocument("bob", 1L))
.isInstanceOf(AccessDeniedException.class);
}
// @WithMockUser cannot set a principal id, so the `authentication.principal.id` expression
// needs a custom principal — e.g. a hand-rolled @WithMockCustomUser (illustrative name)
// built with @WithSecurityContext
@Test
@WithMockCustomUser(username = "alice", id = 1L)
void shouldAllowUserToAccessOwnProfile() {
    DocumentService service = new DocumentService();
    assertThatCode(() -> service.getUserProfile(1L))
        .doesNotThrowAnyException();
}
@Test
@WithMockCustomUser(username = "alice", id = 1L)
void shouldDenyUserAccessToOtherProfile() {
    DocumentService service = new DocumentService();
    assertThatThrownBy(() -> service.getUserProfile(999L))
        .isInstanceOf(AccessDeniedException.class);
}
}
```
## Testing Custom Permission Evaluator
### Create and Test Custom Permission Logic
```java
// Custom permission evaluator
@Component
public class DocumentPermissionEvaluator implements PermissionEvaluator {
private final DocumentRepository documentRepository;
public DocumentPermissionEvaluator(DocumentRepository documentRepository) {
this.documentRepository = documentRepository;
}
@Override
public boolean hasPermission(Authentication authentication, Object targetDomainObject, Object permission) {
if (authentication == null) return false;
Document document = (Document) targetDomainObject;
String userUsername = authentication.getName();
return document.getOwner().getUsername().equals(userUsername) ||
userHasRole(authentication, "ADMIN");
}
@Override
public boolean hasPermission(Authentication authentication, Serializable targetId, String targetType, Object permission) {
if (authentication == null) return false;
if (!"Document".equals(targetType)) return false;
Document document = documentRepository.findById((Long) targetId).orElse(null);
if (document == null) return false;
return hasPermission(authentication, document, permission);
}
private boolean userHasRole(Authentication authentication, String role) {
return authentication.getAuthorities().stream()
.anyMatch(auth -> auth.getAuthority().equals("ROLE_" + role));
}
}
// Unit test for custom evaluator
class DocumentPermissionEvaluatorTest {
private DocumentPermissionEvaluator evaluator;
private DocumentRepository documentRepository;
private Authentication adminAuth;
private Authentication userAuth;
private Document document;
@BeforeEach
void setUp() {
documentRepository = mock(DocumentRepository.class);
evaluator = new DocumentPermissionEvaluator(documentRepository);
document = new Document(1L, "Test Doc", new User("alice"));
adminAuth = new UsernamePasswordAuthenticationToken(
"admin",
null,
List.of(new SimpleGrantedAuthority("ROLE_ADMIN"))
);
userAuth = new UsernamePasswordAuthenticationToken(
"alice",
null,
List.of(new SimpleGrantedAuthority("ROLE_USER"))
);
}
@Test
void shouldGrantPermissionToDocumentOwner() {
boolean hasPermission = evaluator.hasPermission(userAuth, document, "WRITE");
assertThat(hasPermission).isTrue();
}
@Test
void shouldDenyPermissionToNonOwner() {
Authentication otherUserAuth = new UsernamePasswordAuthenticationToken(
"bob",
null,
List.of(new SimpleGrantedAuthority("ROLE_USER"))
);
boolean hasPermission = evaluator.hasPermission(otherUserAuth, document, "WRITE");
assertThat(hasPermission).isFalse();
}
@Test
void shouldGrantPermissionToAdmin() {
boolean hasPermission = evaluator.hasPermission(adminAuth, document, "WRITE");
assertThat(hasPermission).isTrue();
}
@Test
void shouldDenyNullAuthentication() {
boolean hasPermission = evaluator.hasPermission(null, document, "WRITE");
assertThat(hasPermission).isFalse();
}
}
```
## Testing Multiple Roles
### Parameterized Role Testing
```java
import java.util.List;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.test.context.TestSecurityContextHolder;
class RoleBasedAccessTest {
    private AdminService service;
    @BeforeEach
    void setUp() {
        service = new AdminService();
    }
    @AfterEach
    void tearDown() {
        TestSecurityContextHolder.clearContext();
    }
    // @WithMockUser is fixed at compile time, so set the authentication per parameter instead
    private void authenticateWithRole(String role) {
        TestSecurityContextHolder.setAuthentication(new UsernamePasswordAuthenticationToken(
            "testuser", null, List.of(new SimpleGrantedAuthority("ROLE_" + role))));
    }
    @ParameterizedTest
    @ValueSource(strings = {"ADMIN", "SUPER_ADMIN", "SYSTEM"})
    void shouldAllowPrivilegedRolesToDeleteUser(String role) {
        authenticateWithRole(role);
        assertThatCode(() -> service.deleteUser(1L))
            .doesNotThrowAnyException();
    }
    @ParameterizedTest
    @ValueSource(strings = {"USER", "GUEST", "READONLY"})
    void shouldDenyUnprivilegedRolesToDeleteUser(String role) {
        authenticateWithRole(role);
        assertThatThrownBy(() -> service.deleteUser(1L))
            .isInstanceOf(AccessDeniedException.class);
    }
}
```
## Best Practices
- **Use @WithMockUser** for setting authenticated user context
- **Test both allow and deny cases** for each security rule
- **Test with different roles** to verify role-based decisions
- **Test expression-based security** comprehensively
- **Mock external dependencies** (permission evaluators, etc.)
- **Test anonymous access separately** from authenticated access
- **Enable method security in the test configuration** — `@EnableMethodSecurity` (or `@EnableGlobalMethodSecurity(prePostEnabled = true)` on older Spring Security); see the sketch below
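A minimal sketch of such a test configuration, assuming Spring Security 5.6+ (`@EnableMethodSecurity` replaces `@EnableGlobalMethodSecurity` in 6.x); the bean wiring is illustrative:
```java
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity;
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

@SpringJUnitConfig
class MethodSecurityConfigTest {
    @Configuration
    @EnableMethodSecurity // enables @PreAuthorize/@PostAuthorize on Spring-managed beans
    static class TestConfig {
        @Bean
        UserService userService() {
            return new UserService(); // proxied by Spring, so the annotations are enforced
        }
    }
    @Autowired
    private UserService userService;
    @Test
    @WithMockUser(roles = "USER")
    void preAuthorizeIsEnforcedThroughTheProxy() {
        assertThatThrownBy(() -> userService.deleteUser(1L))
            .isInstanceOf(AccessDeniedException.class);
    }
}
```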
## Common Pitfalls
- Forgetting to enable method security in test configuration
- Not testing both allow and deny scenarios
- Testing framework code instead of authorization logic
- Not handling null authentication in tests
- Mixing authentication and authorization tests unnecessarily
## Troubleshooting
**AccessDeniedException not thrown**: Ensure method security is enabled (`@EnableMethodSecurity`, or `@EnableGlobalMethodSecurity(prePostEnabled = true)` on older versions) and that the service under test is a Spring-managed bean rather than a directly instantiated object.
**@WithMockUser not working**: Verify Spring Security test dependencies are on classpath.
**Custom PermissionEvaluator not invoked**: Check `@EnableGlobalMethodSecurity(securedEnabled = true, prePostEnabled = true)`.
## References
- [Spring Security Method Security](https://docs.spring.io/spring-security/site/docs/current/reference/html5/#jc-method)
- [Spring Security Testing](https://docs.spring.io/spring-security/site/docs/current/reference/html5/#test)
- [@WithMockUser Documentation](https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/test/context/support/WithMockUser.html)

View File

@@ -0,0 +1,329 @@
---
name: unit-test-service-layer
description: Unit tests for service layer with Mockito. Test business logic in isolation by mocking dependencies. Use when validating service behaviors and business logic without database or external services.
category: testing
tags: [junit-5, mockito, unit-testing, service-layer, business-logic]
version: 1.0.1
---
# Unit Testing Service Layer with Mockito
Test @Service annotated classes by mocking all injected dependencies. Focus on business logic validation without starting the Spring container.
## When to Use This Skill
Use this skill when:
- Testing business logic in @Service classes
- Mocking repository and external client dependencies
- Verifying service interactions with mocked collaborators
- Testing complex workflows and orchestration logic
- Want fast, isolated unit tests (no database, no API calls)
- Testing error handling and edge cases in services
## Setup with Mockito and JUnit 5
### Maven
```xml
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
```
### Gradle
```kotlin
dependencies {
testImplementation("org.junit.jupiter:junit-jupiter")
testImplementation("org.mockito:mockito-core")
testImplementation("org.mockito:mockito-junit-jupiter")
testImplementation("org.assertj:assertj-core")
}
```
## Basic Pattern: Service with Mocked Dependencies
### Single Dependency
```java
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import static org.mockito.Mockito.*;
import static org.assertj.core.api.Assertions.*;
@ExtendWith(MockitoExtension.class)
class UserServiceTest {
@Mock
private UserRepository userRepository;
@InjectMocks
private UserService userService;
@Test
void shouldReturnAllUsers() {
// Arrange
List<User> expectedUsers = List.of(
new User(1L, "Alice"),
new User(2L, "Bob")
);
when(userRepository.findAll()).thenReturn(expectedUsers);
// Act
List<User> result = userService.getAllUsers();
// Assert
assertThat(result).hasSize(2);
assertThat(result).containsExactly(
new User(1L, "Alice"),
new User(2L, "Bob")
);
verify(userRepository, times(1)).findAll();
}
}
```
### Multiple Dependencies
```java
@ExtendWith(MockitoExtension.class)
class UserEnrichmentServiceTest {
@Mock
private UserRepository userRepository;
@Mock
private EmailService emailService;
@Mock
private AnalyticsClient analyticsClient;
@InjectMocks
private UserEnrichmentService enrichmentService;
@Test
void shouldCreateUserAndSendWelcomeEmail() {
User newUser = new User(1L, "Alice", "alice@example.com");
when(userRepository.save(any(User.class))).thenReturn(newUser);
doNothing().when(emailService).sendWelcomeEmail(newUser.getEmail());
User result = enrichmentService.registerNewUser("Alice", "alice@example.com");
assertThat(result.getId()).isEqualTo(1L);
assertThat(result.getName()).isEqualTo("Alice");
verify(userRepository).save(any(User.class));
verify(emailService).sendWelcomeEmail("alice@example.com");
verify(analyticsClient, never()).trackUserRegistration(any());
}
}
```
## Testing Exception Handling
### Service Throws Expected Exception
```java
@Test
void shouldThrowExceptionWhenUserNotFound() {
when(userRepository.findById(999L))
.thenThrow(new UserNotFoundException("User not found"));
assertThatThrownBy(() -> userService.getUserDetails(999L))
.isInstanceOf(UserNotFoundException.class)
.hasMessageContaining("User not found");
verify(userRepository).findById(999L);
}
@Test
void shouldRethrowRepositoryException() {
when(userRepository.findAll())
.thenThrow(new DataAccessResourceFailureException("Database connection failed")); // a concrete subclass; DataAccessException itself is abstract
assertThatThrownBy(() -> userService.getAllUsers())
.isInstanceOf(DataAccessException.class)
.hasMessageContaining("Database connection failed");
}
```
## Testing Complex Workflows
### Multiple Service Method Calls
```java
@Test
void shouldTransferMoneyBetweenAccounts() {
Account fromAccount = new Account(1L, 1000.0);
Account toAccount = new Account(2L, 500.0);
when(accountRepository.findById(1L)).thenReturn(Optional.of(fromAccount));
when(accountRepository.findById(2L)).thenReturn(Optional.of(toAccount));
when(accountRepository.save(any(Account.class)))
.thenAnswer(invocation -> invocation.getArgument(0));
moneyTransferService.transfer(1L, 2L, 200.0);
// Verify both accounts were updated
verify(accountRepository, times(2)).save(any(Account.class));
assertThat(fromAccount.getBalance()).isEqualTo(800.0);
assertThat(toAccount.getBalance()).isEqualTo(700.0);
}
```
## Argument Capturing and Verification
### Capture Arguments Passed to Mock
```java
import org.mockito.ArgumentCaptor;
@Test
void shouldCaptureUserDataWhenSaving() {
ArgumentCaptor<User> userCaptor = ArgumentCaptor.forClass(User.class);
when(userRepository.save(any(User.class)))
.thenAnswer(invocation -> invocation.getArgument(0));
userService.createUser("Alice", "alice@example.com");
verify(userRepository).save(userCaptor.capture());
User capturedUser = userCaptor.getValue();
assertThat(capturedUser.getName()).isEqualTo("Alice");
assertThat(capturedUser.getEmail()).isEqualTo("alice@example.com");
}
@Test
void shouldCaptureMultipleArgumentsAcrossMultipleCalls() {
ArgumentCaptor<User> userCaptor = ArgumentCaptor.forClass(User.class);
userService.createUser("Alice", "alice@example.com");
userService.createUser("Bob", "bob@example.com");
verify(userRepository, times(2)).save(userCaptor.capture());
List<User> capturedUsers = userCaptor.getAllValues();
assertThat(capturedUsers).hasSize(2);
assertThat(capturedUsers.get(0).getName()).isEqualTo("Alice");
assertThat(capturedUsers.get(1).getName()).isEqualTo("Bob");
}
```
## Verification Patterns
### Verify Call Order and Frequency
```java
import org.mockito.InOrder;
@Test
void shouldCallMethodsInCorrectOrder() {
InOrder inOrder = inOrder(userRepository, emailService);
userService.registerNewUser("Alice", "alice@example.com");
inOrder.verify(userRepository).save(any(User.class));
inOrder.verify(emailService).sendWelcomeEmail(any());
}
@Test
void shouldCallMethodExactlyOnce() {
userService.getUserDetails(1L);
verify(userRepository, times(1)).findById(1L);
verify(userRepository, never()).findAll();
}
```
## Testing Async/Reactive Services
### Service with CompletableFuture
```java
@Test
void shouldReturnCompletableFutureWhenFetchingAsyncData() {
List<User> users = List.of(new User(1L, "Alice"));
when(userRepository.findAllAsync())
.thenReturn(CompletableFuture.completedFuture(users));
CompletableFuture<List<User>> result = userService.getAllUsersAsync();
assertThat(result).isCompletedWithValue(users);
}
```
## Best Practices
- **Use @ExtendWith(MockitoExtension.class)** for JUnit 5 integration
- **Construct service manually** instead of using reflection when possible
- **Mock only direct dependencies** of the service under test
- **Verify interactions** to ensure correct collaboration
- **Use descriptive variable names**: `expectedUser`, `actualUser`, `captor`
- **Test one behavior per test method** - keep tests focused
- **Avoid testing framework code** - focus on business logic
## Common Patterns
**Partial Mock with Spy**:
```java
@Spy
@InjectMocks
private UserService userService; // Real instance, but can stub some methods
@Test
void shouldUseRealMethodButMockDependency() {
when(userRepository.findById(any())).thenReturn(Optional.of(new User()));
// Calls real userService methods but userRepository is mocked
}
```
**Constructor Injection for Testing**:
```java
// In your service (production code)
public class UserService {
private final UserRepository userRepository;
public UserService(UserRepository userRepository) {
this.userRepository = userRepository;
}
}
// In your test - can inject mocks directly
@Test
void test() {
UserRepository mockRepo = mock(UserRepository.class);
UserService service = new UserService(mockRepo);
}
```
## Troubleshooting
**UnfinishedStubbingException**: Ensure all `when()` calls are completed with `thenReturn()`, `thenThrow()`, or `thenAnswer()`.
**UnnecessaryStubbingException**: Remove unused stub definitions, or annotate the test class with `@MockitoSettings(strictness = Strictness.LENIENT)` if you intentionally keep unused stubs (see the sketch below).
**NullPointerException in test**: Verify `@InjectMocks` correctly injects all mocked dependencies into the service constructor.
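A minimal sketch of the lenient setting, reusing `UserRepository` from the examples above:
```java
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;

@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT) // tolerate stubs that a given test never uses
class LenientStubbingTest {
    @Mock
    private UserRepository userRepository;
    @Test
    void passesDespiteUnusedStub() {
        when(userRepository.findById(1L)).thenReturn(Optional.empty()); // unused, but no UnnecessaryStubbingException
        assertThat(1 + 1).isEqualTo(2);
    }
}
```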
## References
- [Mockito Documentation](https://javadoc.io/doc/org.mockito/mockito-core/latest/org/mockito/Mockito.html)
- [JUnit 5 User Guide](https://junit.org/junit5/docs/current/user-guide/)
- [AssertJ Assertions](https://assertj.github.io/assertj-core-features-highlight.html)

View File

@@ -0,0 +1,389 @@
---
name: unit-test-utility-methods
description: Unit tests for utility/helper classes and static methods. Test pure functions and helper logic. Use when validating utility code correctness.
category: testing
tags: [junit-5, unit-testing, utility, static-methods, pure-functions]
version: 1.0.1
---
# Unit Testing Utility Classes and Static Methods
Test static utility methods using JUnit 5. Focus on pure functions without side effects, edge cases, and boundary conditions.
## When to Use This Skill
Use this skill when:
- Testing utility classes with static helper methods
- Testing pure functions with no state or side effects
- Testing string manipulation and formatting utilities
- Testing calculation and conversion utilities
- Testing collections and array utilities
- Want simple, fast tests without mocking complexity
- Testing data transformation and validation helpers
## Basic Pattern: Static Utility Testing
### Simple String Utility
```java
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.*;
class StringUtilsTest {
@Test
void shouldCapitalizeFirstLetter() {
String result = StringUtils.capitalize("hello");
assertThat(result).isEqualTo("Hello");
}
@Test
void shouldHandleEmptyString() {
String result = StringUtils.capitalize("");
assertThat(result).isEmpty();
}
@Test
void shouldHandleNullInput() {
String result = StringUtils.capitalize(null);
assertThat(result).isNull();
}
@Test
void shouldHandleSingleCharacter() {
String result = StringUtils.capitalize("a");
assertThat(result).isEqualTo("A");
}
@Test
void shouldNotChangePascalCase() {
String result = StringUtils.capitalize("Hello");
assertThat(result).isEqualTo("Hello");
}
}
```
## Testing Null Handling
### Null-Safe Utility Methods
```java
class NullSafeUtilsTest {
@Test
void shouldReturnDefaultValueWhenNull() {
Object result = NullSafeUtils.getOrDefault(null, "default");
assertThat(result).isEqualTo("default");
}
@Test
void shouldReturnValueWhenNotNull() {
Object result = NullSafeUtils.getOrDefault("value", "default");
assertThat(result).isEqualTo("value");
}
@Test
void shouldReturnFalseWhenStringIsNull() {
boolean result = NullSafeUtils.isNotBlank(null);
assertThat(result).isFalse();
}
@Test
void shouldReturnTrueWhenStringHasContent() {
boolean result = NullSafeUtils.isNotBlank(" text ");
assertThat(result).isTrue();
}
}
```
## Testing Calculations and Conversions
### Math Utilities
```java
class MathUtilsTest {
@Test
void shouldCalculatePercentage() {
double result = MathUtils.percentage(25, 100);
assertThat(result).isEqualTo(25.0);
}
@Test
void shouldHandleZeroDivisor() {
double result = MathUtils.percentage(50, 0);
assertThat(result).isZero();
}
@Test
void shouldRoundToTwoDecimalPlaces() {
double result = MathUtils.round(3.14159, 2);
assertThat(result).isEqualTo(3.14);
}
@Test
void shouldHandleNegativeNumbers() {
int result = MathUtils.absoluteValue(-42);
assertThat(result).isEqualTo(42);
}
}
```
## Testing Collection Utilities
### List/Set/Map Operations
```java
class CollectionUtilsTest {
@Test
void shouldFilterList() {
List<Integer> numbers = List.of(1, 2, 3, 4, 5);
List<Integer> evenNumbers = CollectionUtils.filter(numbers, n -> n % 2 == 0);
assertThat(evenNumbers).containsExactly(2, 4);
}
@Test
void shouldReturnEmptyListWhenNoMatches() {
List<Integer> numbers = List.of(1, 3, 5);
List<Integer> evenNumbers = CollectionUtils.filter(numbers, n -> n % 2 == 0);
assertThat(evenNumbers).isEmpty();
}
@Test
void shouldHandleNullList() {
List<Integer> result = CollectionUtils.filter(null, n -> true);
assertThat(result).isEmpty();
}
@Test
void shouldJoinStringsWithSeparator() {
String result = CollectionUtils.join(List.of("a", "b", "c"), "-");
assertThat(result).isEqualTo("a-b-c");
}
@Test
void shouldHandleEmptyList() {
String result = CollectionUtils.join(List.of(), "-");
assertThat(result).isEmpty();
}
@Test
void shouldDeduplicateList() {
List<String> input = List.of("apple", "banana", "apple", "cherry", "banana");
Set<String> unique = CollectionUtils.deduplicate(input);
assertThat(unique).containsExactlyInAnyOrder("apple", "banana", "cherry");
}
}
```
## Testing String Transformations
### Format and Parse Utilities
```java
class FormatUtilsTest {
@Test
void shouldFormatCurrencyWithSymbol() {
String result = FormatUtils.formatCurrency(1234.56);
assertThat(result).isEqualTo("$1,234.56");
}
@Test
void shouldHandleNegativeCurrency() {
String result = FormatUtils.formatCurrency(-100.00);
assertThat(result).isEqualTo("-$100.00");
}
@Test
void shouldParsePhoneNumber() {
String result = FormatUtils.parsePhoneNumber("5551234567");
assertThat(result).isEqualTo("(555) 123-4567");
}
@Test
void shouldFormatDate() {
LocalDate date = LocalDate.of(2024, 1, 15);
String result = FormatUtils.formatDate(date, "yyyy-MM-dd");
assertThat(result).isEqualTo("2024-01-15");
}
@Test
void shouldSluggifyString() {
String result = FormatUtils.sluggify("Hello World! 123");
assertThat(result).isEqualTo("hello-world-123");
}
}
```
## Testing Data Validation
### Validator Utilities
```java
class ValidatorUtilsTest {
@Test
void shouldValidateEmailFormat() {
boolean valid = ValidatorUtils.isValidEmail("user@example.com");
assertThat(valid).isTrue();
boolean invalid = ValidatorUtils.isValidEmail("invalid-email");
assertThat(invalid).isFalse();
}
@Test
void shouldValidatePhoneNumber() {
boolean valid = ValidatorUtils.isValidPhone("555-123-4567");
assertThat(valid).isTrue();
boolean invalid = ValidatorUtils.isValidPhone("12345");
assertThat(invalid).isFalse();
}
@Test
void shouldValidateUrlFormat() {
boolean valid = ValidatorUtils.isValidUrl("https://example.com");
assertThat(valid).isTrue();
boolean invalid = ValidatorUtils.isValidUrl("not a url");
assertThat(invalid).isFalse();
}
@Test
void shouldValidateCreditCardNumber() {
boolean valid = ValidatorUtils.isValidCreditCard("4532015112830366");
assertThat(valid).isTrue();
boolean invalid = ValidatorUtils.isValidCreditCard("1234567890123456");
assertThat(invalid).isFalse();
}
}
```
## Testing Parameterized Scenarios
### Multiple Test Cases with @ParameterizedTest
```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.NullAndEmptySource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junit.jupiter.params.provider.CsvSource;
class StringUtilsParameterizedTest {
    @ParameterizedTest
    @NullAndEmptySource // @ValueSource arrays cannot contain null, so null and "" come from here
    @ValueSource(strings = {" ", "\t"})
    void shouldConsiderBlankValuesAsEmpty(String input) {
        boolean result = StringUtils.isEmpty(input);
        assertThat(result).isTrue();
    }
@ParameterizedTest
@CsvSource({
"hello,HELLO",
"world,WORLD",
"javaScript,JAVASCRIPT",
"123ABC,123ABC"
})
void shouldConvertToUpperCase(String input, String expected) {
String result = StringUtils.toUpperCase(input);
assertThat(result).isEqualTo(expected);
}
}
```
## Testing with Mockito for External Dependencies
### Utility with Dependency (Rare Case)
```java
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class DateUtilsTest {
@Mock
private Clock clock;
@Test
void shouldGetCurrentDateFromClock() {
Instant fixedTime = Instant.parse("2024-01-15T10:30:00Z");
when(clock.instant()).thenReturn(fixedTime);
when(clock.getZone()).thenReturn(ZoneOffset.UTC); // needed if today() uses LocalDate.now(clock), which reads the zone too
LocalDate result = DateUtils.today(clock);
assertThat(result).isEqualTo(LocalDate.of(2024, 1, 15));
}
}
```
## Edge Cases and Boundary Testing
```java
class MathUtilsEdgeCaseTest {
@Test
void shouldHandleMaxIntegerValue() {
int result = MathUtils.increment(Integer.MAX_VALUE);
assertThat(result).isEqualTo(Integer.MAX_VALUE);
}
@Test
void shouldHandleMinIntegerValue() {
int result = MathUtils.decrement(Integer.MIN_VALUE);
assertThat(result).isEqualTo(Integer.MIN_VALUE);
}
@Test
void shouldHandleVeryLargeNumbers() {
BigDecimal result = MathUtils.add(
new BigDecimal("999999999999.99"),
new BigDecimal("0.01")
);
assertThat(result).isEqualTo(new BigDecimal("1000000000000.00"));
}
@Test
void shouldHandleFloatingPointPrecision() {
double result = MathUtils.multiply(0.1, 0.2);
assertThat(result).isCloseTo(0.02, within(0.0001));
}
}
```
## Best Practices
- **Test pure functions exclusively** - no side effects or state
- **Cover happy path and edge cases** - null, empty, extreme values
- **Use descriptive test names** - clearly state what's being tested
- **Keep tests simple and short** - utility tests should be quick to understand
- **Use @ParameterizedTest** for testing multiple similar scenarios
- **Avoid mocking when not needed** - only mock external dependencies
- **Test boundary conditions** - min/max values, empty collections, null inputs
## Common Pitfalls
- Testing framework behavior instead of utility logic
- Over-mocking when pure functions need no mocks
- Not testing null/empty edge cases
- Not testing negative numbers and extreme values
- Test methods too large - split complex scenarios
## Troubleshooting
**Floating point precision issues**: Use `isCloseTo()` with delta instead of exact equality.
**Null handling inconsistency**: Decide whether utility returns null or throws exception, then test consistently.
**Complex utility logic belongs elsewhere**: Consider refactoring into testable units.
## References
- [JUnit 5 Parameterized Tests](https://junit.org/junit5/docs/current/user-guide/#writing-tests-parameterized-tests)
- [AssertJ Assertions](https://assertj.github.io/assertj-core-features-highlight.html)
- [Testing Edge Cases and Boundaries](https://www.baeldung.com/testing-properties-methods-using-mockito)

View File

@@ -0,0 +1,170 @@
---
name: unit-test-wiremock-rest-api
description: Unit tests for external REST APIs using WireMock to mock HTTP endpoints. Use when testing service integrations with external APIs.
category: testing
tags: [junit-5, wiremock, unit-testing, rest-api, mocking, http-stubbing]
version: 1.0.1
---
# Unit Testing REST APIs with WireMock
Test interactions with third-party REST APIs without making real network calls using WireMock. This skill focuses on pure unit tests (no Spring context) that stub HTTP responses and verify requests.
## When to Use This Skill
Use this skill when:
- Testing services that call external REST APIs
- Need to stub HTTP responses for predictable test behavior
- Want to test error scenarios (timeouts, 500 errors, malformed responses)
- Need to verify request details (headers, query params, request body)
- Integrating with third-party services (payment gateways, weather APIs, etc.)
- Testing without network dependencies or rate limits
- Building unit tests that run fast in CI/CD pipelines
## Core Dependencies
### Maven
```xml
<dependency>
<groupId>org.wiremock</groupId>
<artifactId>wiremock</artifactId>
<version>3.4.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
```
### Gradle
```kotlin
dependencies {
testImplementation("org.wiremock:wiremock:3.4.1")
testImplementation("org.junit.jupiter:junit-jupiter")
testImplementation("org.assertj:assertj-core")
}
```
## Basic Pattern: Stubbing and Verifying
### Simple Stub with WireMock Extension
```java
import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import static com.github.tomakehurst.wiremock.client.WireMock.*;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.assertj.core.api.Assertions.assertThat;
class ExternalWeatherServiceTest {
@RegisterExtension
static WireMockExtension wireMock = WireMockExtension.newInstance()
.options(wireMockConfig().dynamicPort())
.build();
@Test
void shouldFetchWeatherDataFromExternalApi() {
wireMock.stubFor(get(urlEqualTo("/weather?city=London"))
.withHeader("Accept", containing("application/json"))
.willReturn(aResponse()
.withStatus(200)
.withHeader("Content-Type", "application/json")
.withBody("{\"city\":\"London\",\"temperature\":15,\"condition\":\"Cloudy\"}")));
String baseUrl = wireMock.getRuntimeInfo().getHttpBaseUrl();
WeatherApiClient client = new WeatherApiClient(baseUrl);
WeatherData weather = client.getWeather("London");
assertThat(weather.getCity()).isEqualTo("London");
assertThat(weather.getTemperature()).isEqualTo(15);
wireMock.verify(getRequestedFor(urlEqualTo("/weather?city=London"))
.withHeader("Accept", containing("application/json")));
}
}
```
## Testing Error Scenarios
### Test 4xx and 5xx Responses
```java
@Test
void shouldHandleNotFoundError() {
wireMock.stubFor(get(urlEqualTo("/api/users/999"))
.willReturn(aResponse()
.withStatus(404)
.withBody("{\"error\":\"User not found\"}")));
ApiClient client = new ApiClient(wireMock.getRuntimeInfo().getHttpBaseUrl());
assertThatThrownBy(() -> client.getUser(999))
.isInstanceOf(UserNotFoundException.class)
.hasMessageContaining("User not found");
}
@Test
void shouldThrowOnServerError() {
wireMock.stubFor(get(urlEqualTo("/api/data"))
.willReturn(aResponse()
.withStatus(500)
.withBody("{\"error\":\"Internal server error\"}")));
ApiClient client = new ApiClient(wireMock.getRuntimeInfo().getHttpBaseUrl());
assertThatThrownBy(() -> client.fetchData())
.isInstanceOf(ServerErrorException.class);
}
```
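WireMock can also simulate slow responses for timeout testing — a sketch using a fixed delay; the client, its read timeout, and how it wraps the underlying `SocketTimeoutException` are assumptions:
```java
@Test
void shouldTimeOutOnSlowResponse() {
    wireMock.stubFor(get(urlEqualTo("/api/slow"))
        .willReturn(aResponse()
            .withStatus(200)
            .withFixedDelay(5000))); // respond after 5 s, longer than the client's read timeout
    ApiClient client = new ApiClient(wireMock.getRuntimeInfo().getHttpBaseUrl());
    assertThatThrownBy(() -> client.fetchData())
        .hasRootCauseInstanceOf(java.net.SocketTimeoutException.class);
}
```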
## Request Verification
### Verify Request Details and Payload
```java
@Test
void shouldVerifyRequestBody() {
wireMock.stubFor(post(urlEqualTo("/api/users"))
.willReturn(aResponse()
.withStatus(201)
.withBody("{\"id\":123,\"name\":\"Alice\"}")));
ApiClient client = new ApiClient(wireMock.getRuntimeInfo().getHttpBaseUrl());
UserResponse response = client.createUser("Alice");
assertThat(response.getId()).isEqualTo(123);
wireMock.verify(postRequestedFor(urlEqualTo("/api/users"))
.withRequestBody(matchingJsonPath("$.name", equalTo("Alice")))
.withHeader("Content-Type", containing("application/json")));
}
```
## Best Practices
- **Use dynamic port** to avoid port conflicts in parallel test execution
- **Verify requests** to ensure correct API usage
- **Test error scenarios** thoroughly
- **Keep stubs focused** - one concern per test
- **Reset WireMock** between tests automatically via `@RegisterExtension`
- **Never call real APIs** - always stub third-party endpoints
## Troubleshooting
**WireMock not intercepting requests**: Ensure your HTTP client uses the stubbed URL from `wireMock.getRuntimeInfo().getHttpBaseUrl()`.
**Port conflicts**: Always use `wireMockConfig().dynamicPort()` to let WireMock choose available port.
## References
- [WireMock Official Documentation](https://wiremock.org/)
- [WireMock Stubs and Mocking](https://wiremock.org/docs/stubbing/)
- [JUnit 5 Extensions](https://junit.org/junit5/docs/current/user-guide/#extensions)

View File

@@ -0,0 +1,145 @@
---
name: langchain4j-ai-services-patterns
description: Build declarative AI Services with LangChain4j using interface-based patterns, annotations, memory management, tools integration, and advanced application patterns. Use when implementing type-safe AI-powered features with minimal boilerplate code in Java applications.
category: ai-development
tags: [langchain4j, ai-services, annotations, declarative, tools, memory, function-calling, llm, java]
version: 1.1.0
allowed-tools: Read, Write, Bash
---
# LangChain4j AI Services Patterns
This skill provides guidance for building declarative AI Services with LangChain4j using interface-based patterns, annotations for system and user messages, memory management, tools integration, and advanced AI application patterns that abstract away low-level LLM interactions.
## When to Use
Use this skill when:
- Building declarative AI-powered interfaces with minimal boilerplate code
- Creating type-safe AI services with Java interfaces and annotations
- Implementing conversational AI systems with memory management
- Designing AI services that can call external tools and functions
- Building multi-agent systems with specialized AI components
- Creating AI services with different personas and behaviors
- Implementing RAG (Retrieval-Augmented Generation) patterns declaratively
- Building production AI applications with proper error handling and validation
- Creating AI services that return structured data types (enums, POJOs, lists)
- Implementing streaming AI responses with reactive patterns
## Overview
LangChain4j AI Services allow you to define AI-powered functionality using plain Java interfaces with annotations, eliminating the need for manual prompt construction and response parsing. This pattern provides type-safe, declarative AI capabilities with minimal boilerplate code.
## Quick Start
### Basic AI Service Definition
```java
interface Assistant {
String chat(String userMessage);
}
// Create instance - LangChain4j generates implementation
Assistant assistant = AiServices.create(Assistant.class, chatModel);
// Use the service
String response = assistant.chat("Hello, how are you?");
```
### System Message and Templates
```java
interface CustomerSupportBot {
@SystemMessage("You are a helpful customer support agent for TechCorp")
String handleInquiry(String customerMessage);
@UserMessage("Analyze sentiment: {{it}}")
String analyzeSentiment(String feedback);
}
CustomerSupportBot bot = AiServices.create(CustomerSupportBot.class, chatModel);
```
### Memory Management
```java
interface MultiUserAssistant {
    String chat(@MemoryId String userId, @UserMessage String userMessage);
}
MultiUserAssistant assistant = AiServices.builder(MultiUserAssistant.class)
    .chatModel(model)
    .chatMemoryProvider(userId -> MessageWindowChatMemory.withMaxMessages(10))
    .build();
```
### Tool Integration
```java
class Calculator {
@Tool("Add two numbers") double add(double a, double b) { return a + b; }
}
interface MathGenius {
String ask(String question);
}
MathGenius mathGenius = AiServices.builder(MathGenius.class)
.chatModel(model)
.tools(new Calculator())
.build();
```
## Examples
See [examples.md](references/examples.md) for comprehensive practical examples including:
- Basic chat interfaces
- Stateful assistants with memory
- Multi-user scenarios
- Structured output extraction
- Tool calling and function execution
- Streaming responses
- Error handling
- RAG integration
- Production patterns
## API Reference
Complete API documentation, annotations, interfaces, and configuration patterns are available in [references.md](references/references.md).
## Best Practices
1. **Use type-safe interfaces** instead of string-based prompts
2. **Implement proper memory management** with appropriate limits
3. **Design clear tool descriptions** with parameter documentation
4. **Handle errors gracefully** with custom error handlers
5. **Use structured output** for predictable responses (see the sketch after this list)
6. **Implement validation** for user inputs
7. **Monitor performance** for production deployments
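As a quick illustration of practice 5, an AI Service method can return a typed value directly, and LangChain4j parses the model output into it; a minimal sketch (the enum and interface names are illustrative):
```java
enum Priority { LOW, MEDIUM, HIGH }

interface TicketTriage {
    @UserMessage("Classify the priority of this support ticket: {{it}}")
    Priority classify(String ticketText);
}

// The returned text is parsed into the enum constant
TicketTriage triage = AiServices.create(TicketTriage.class, chatModel);
Priority priority = triage.classify("Production database is down!");
```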
## Dependencies
```xml
<!-- Maven -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>1.8.0</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<version>1.8.0</version>
</dependency>
```
```gradle
// Gradle
implementation 'dev.langchain4j:langchain4j:1.8.0'
implementation 'dev.langchain4j:langchain4j-open-ai:1.8.0'
```
## References
- [LangChain4j Documentation](https://docs.langchain4j.dev/)
- [LangChain4j AI Services - API References](references/references.md)
- [LangChain4j AI Services - Practical Examples](references/examples.md)

View File

@@ -0,0 +1,534 @@
# LangChain4j AI Services - Practical Examples
This document provides practical, production-ready examples for LangChain4j AI Services patterns.
## 1. Basic Chat Interface
**Scenario**: Simple conversational interface without memory.
```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.model.openai.OpenAiChatModel;
interface SimpleChat {
String chat(String userMessage);
}
public class BasicChatExample {
public static void main(String[] args) {
var chatModel = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("gpt-4o-mini")
.temperature(0.7)
.build();
var chat = AiServices.builder(SimpleChat.class)
.chatModel(chatModel)
.build();
String response = chat.chat("What is Spring Boot?");
System.out.println(response);
}
}
```
## 2. Stateful Assistant with Memory
**Scenario**: Multi-turn conversation with 10-message history.
```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.openai.OpenAiChatModel;
interface ConversationalAssistant {
String chat(String userMessage);
}
public class StatefulAssistantExample {
public static void main(String[] args) {
var chatModel = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("gpt-4o-mini")
.build();
var assistant = AiServices.builder(ConversationalAssistant.class)
.chatModel(chatModel)
.chatMemory(MessageWindowChatMemory.withMaxMessages(10))
.build();
// Multi-turn conversation
System.out.println(assistant.chat("My name is Alice"));
System.out.println(assistant.chat("What is my name?")); // Remembers: "Your name is Alice"
System.out.println(assistant.chat("What year was Spring Boot released?")); // Answers: "2014"
System.out.println(assistant.chat("Tell me more about it")); // Context aware
}
}
```
## 3. Multi-User Memory with @MemoryId
**Scenario**: Separate conversation history per user.
```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.openai.OpenAiChatModel;
interface MultiUserAssistant {
    String chat(@MemoryId int userId, @UserMessage String userMessage);
}
public class MultiUserMemoryExample {
public static void main(String[] args) {
var chatModel = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("gpt-4o-mini")
.build();
var assistant = AiServices.builder(MultiUserAssistant.class)
.chatModel(chatModel)
.chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(20))
.build();
// User 1 conversation
System.out.println(assistant.chat(1, "I like Java"));
System.out.println(assistant.chat(1, "What language do I prefer?")); // Java
// User 2 conversation - separate memory
System.out.println(assistant.chat(2, "I prefer Python"));
System.out.println(assistant.chat(2, "What language do I prefer?")); // Python
// User 1 - still remembers Java
System.out.println(assistant.chat(1, "What about me?")); // Java
}
}
```
## 4. System Message & Template Variables
**Scenario**: Configurable system prompt with dynamic template variables.
```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.V;
import dev.langchain4j.model.openai.OpenAiChatModel;
interface TemplatedAssistant {
@SystemMessage("You are a {{role}} expert. Be concise and professional.")
String chat(@V("role") String role, @UserMessage String userMessage);
@SystemMessage("You are a helpful assistant. Translate to {{language}}")
@UserMessage("Translate this: {{text}}")
String translate(@V("text") String text, @V("language") String language);
}
public class TemplatedAssistantExample {
public static void main(String[] args) {
var chatModel = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("gpt-4o-mini")
.temperature(0.3)
.build();
var assistant = AiServices.create(TemplatedAssistant.class, chatModel);
// Dynamic role
System.out.println(assistant.chat("Java", "Explain dependency injection"));
System.out.println(assistant.chat("DevOps", "Explain Docker containers"));
// Translation with template
System.out.println(assistant.translate("Hello, how are you?", "Spanish"));
System.out.println(assistant.translate("Good morning", "French"));
}
}
```
## 5. Structured Output Extraction
**Scenario**: Extract structured data (POJO, enum, list) from LLM responses.
```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.model.openai.OpenAiChatModel;
import java.util.List;
enum Sentiment {
POSITIVE, NEGATIVE, NEUTRAL
}
class ContactInfo {
@Description("Person's full name")
String fullName;
@Description("Email address")
String email;
@Description("Phone number with country code")
String phone;
}
interface DataExtractor {
    @UserMessage("Analyze sentiment: {{it}}")
    Sentiment extractSentiment(String text);

    @UserMessage("Extract contact from: {{it}}")
    ContactInfo extractContact(String text);

    @UserMessage("List all technologies in: {{it}}")
    List<String> extractTechnologies(String text);

    @UserMessage("Count items in: {{it}}")
    int countItems(String text);
}
public class StructuredOutputExample {
public static void main(String[] args) {
var chatModel = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("gpt-4o-mini")
.responseFormat("json_object")
.build();
var extractor = AiServices.create(DataExtractor.class, chatModel);
// Enum extraction
Sentiment sentiment = extractor.extractSentiment("This product is amazing!");
System.out.println("Sentiment: " + sentiment); // POSITIVE
// POJO extraction
ContactInfo contact = extractor.extractContact(
"John Smith, john@example.com, +1-555-1234");
System.out.println("Name: " + contact.fullName);
System.out.println("Email: " + contact.email);
// List extraction
List<String> techs = extractor.extractTechnologies(
"We use Java, Spring Boot, PostgreSQL, and Docker");
System.out.println("Technologies: " + techs); // [Java, Spring Boot, PostgreSQL, Docker]
// Primitive type
int count = extractor.countItems("I have 3 apples, 5 oranges, and 2 bananas");
System.out.println("Total items: " + count); // 10
}
}
```
## 6. Tool Calling / Function Calling
**Scenario**: LLM calls Java methods to solve problems.
```java
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.P;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.openai.OpenAiChatModel;
import java.time.LocalDate;
class Calculator {
@Tool("Add two numbers")
int add(@P("first number") int a, @P("second number") int b) {
return a + b;
}
@Tool("Multiply two numbers")
int multiply(@P("first") int a, @P("second") int b) {
return a * b;
}
}
class WeatherService {
@Tool("Get weather for a city")
String getWeather(@P("city name") String city) {
// Simulate API call
return "Weather in " + city + ": 22°C, Sunny";
}
}
class DateService {
@Tool("Get current date")
String getCurrentDate() {
return LocalDate.now().toString();
}
}
interface ToolUsingAssistant {
String chat(String userMessage);
}
public class ToolCallingExample {
public static void main(String[] args) {
var chatModel = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("gpt-4o-mini")
.temperature(0.0)
.build();
var assistant = AiServices.builder(ToolUsingAssistant.class)
.chatModel(chatModel)
.chatMemory(MessageWindowChatMemory.withMaxMessages(10))
.tools(new Calculator(), new WeatherService(), new DateService())
.build();
// LLM calls tools automatically
System.out.println(assistant.chat("What is 25 + 37?"));
// Uses Calculator.add() → "25 + 37 equals 62"
System.out.println(assistant.chat("What's the weather in Paris?"));
// Uses WeatherService.getWeather() → "Weather in Paris: 22°C, Sunny"
System.out.println(assistant.chat("Calculate (5 + 3) * 4"));
// Uses add() and multiply() → "Result is 32"
System.out.println(assistant.chat("What's today's date?"));
// Uses getCurrentDate() → Shows current date
}
}
```
## 7. Streaming Responses
**Scenario**: Real-time token-by-token streaming for UI responsiveness.
```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
interface StreamingAssistant {
TokenStream streamChat(String userMessage);
}
public class StreamingExample {
public static void main(String[] args) {
var streamingModel = OpenAiStreamingChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("gpt-4o-mini")
.temperature(0.7)
.build();
var assistant = AiServices.builder(StreamingAssistant.class)
.streamingChatModel(streamingModel)
.build();
// Stream response token by token
assistant.streamChat("Tell me a short story about a robot")
.onNext(token -> System.out.print(token)) // Print each token
.onCompleteResponse(response -> {
System.out.println("\n--- Complete ---");
System.out.println("Tokens used: " + response.tokenUsage().totalTokenCount());
})
.onError(error -> System.err.println("Error: " + error.getMessage()))
.start();
// Wait for completion
try {
Thread.sleep(5000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}
```
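Instead of the fixed `Thread.sleep`, a `CountDownLatch` gives a deterministic wait for the stream to finish; a small sketch (uses `java.util.concurrent`, and `await` needs `InterruptedException` handling):
```java
CountDownLatch latch = new CountDownLatch(1);
assistant.streamChat("Tell me a short story about a robot")
        .onNext(System.out::print)
        .onCompleteResponse(response -> latch.countDown()) // release on success
        .onError(error -> latch.countDown())               // release on failure too
        .start();
latch.await(30, TimeUnit.SECONDS); // wait at most 30 seconds
```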
## 8. System Persona with Context
**Scenario**: Different assistants with distinct personalities and knowledge domains.
```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.model.openai.OpenAiChatModel;
interface JavaExpert {
@SystemMessage("""
You are a Java expert with 15+ years experience.
Focus on best practices, performance, and clean code.
Provide code examples when relevant.
""")
String answer(String question);
}
interface SecurityExpert {
@SystemMessage("""
You are a cybersecurity expert specializing in application security.
Always consider OWASP principles and threat modeling.
Provide practical security recommendations.
""")
String answer(String question);
}
interface DevOpsExpert {
@SystemMessage("""
You are a DevOps engineer with expertise in cloud deployment,
CI/CD pipelines, containerization, and infrastructure as code.
""")
String answer(String question);
}
public class PersonaExample {
public static void main(String[] args) {
var chatModel = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("gpt-4o-mini")
.temperature(0.5)
.build();
var javaExpert = AiServices.create(JavaExpert.class, chatModel);
var securityExpert = AiServices.create(SecurityExpert.class, chatModel);
var devopsExpert = AiServices.create(DevOpsExpert.class, chatModel);
var question = "How should I handle database connections?";
System.out.println("=== Java Expert ===");
System.out.println(javaExpert.answer(question));
System.out.println("\n=== Security Expert ===");
System.out.println(securityExpert.answer(question));
System.out.println("\n=== DevOps Expert ===");
System.out.println(devopsExpert.answer(question));
}
}
```
## 9. Error Handling & Tool Execution Errors
**Scenario**: Graceful handling of tool failures and LLM errors.
```java
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.model.openai.OpenAiChatModel;
class DataAccessService {
@Tool("Query database for user")
String queryUser(String userId) {
// Simulate potential error
if (!userId.matches("\\d+")) {
throw new IllegalArgumentException("Invalid user ID format");
}
return "User " + userId + ": John Doe";
}
@Tool("Update user email")
String updateEmail(String userId, String email) {
if (!email.contains("@")) {
throw new IllegalArgumentException("Invalid email format");
}
return "Updated email for user " + userId;
}
}
interface ResilientAssistant {
String execute(String command);
}
public class ErrorHandlingExample {
public static void main(String[] args) {
var chatModel = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("gpt-4o-mini")
.build();
var assistant = AiServices.builder(ResilientAssistant.class)
.chatModel(chatModel)
.tools(new DataAccessService())
.toolExecutionErrorHandler((request, exception) -> {
System.err.println("Tool error: " + exception.getMessage());
return "Error: " + exception.getMessage();
})
.build();
// Will handle tool errors gracefully
System.out.println(assistant.execute("Get details for user abc"));
System.out.println(assistant.execute("Update user 123 with invalid-email"));
}
}
```
## 10. RAG Integration with AI Services
**Scenario**: AI Service with content retrieval for knowledge-based Q&A.
```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
interface KnowledgeBaseAssistant {
String askAbout(String question);
}
public class RAGIntegrationExample {
public static void main(String[] args) {
// Setup embedding store
var embeddingStore = new InMemoryEmbeddingStore<TextSegment>();
// Setup models
var embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("text-embedding-3-small")
.build();
var chatModel = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("gpt-4o-mini")
.build();
// Ingest documents
var ingestor = EmbeddingStoreIngestor.builder()
.embeddingModel(embeddingModel)
.embeddingStore(embeddingStore)
.build();
ingestor.ingest(Document.from("Spring Boot is a framework for building Java applications."));
ingestor.ingest(Document.from("Spring Data JPA simplifies database access."));
// Create retriever
var contentRetriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.maxResults(3)
.minScore(0.7)
.build();
// Create AI Service with RAG
var assistant = AiServices.builder(KnowledgeBaseAssistant.class)
.chatModel(chatModel)
.contentRetriever(contentRetriever)
.build();
String answer = assistant.askAbout("What is Spring Boot?");
System.out.println(answer);
}
}
```
## Best Practices Summary
1. **Always use @SystemMessage** for consistent behavior across every call to the service
2. **Set temperature=0** for deterministic tasks (extraction, calculations)
3. **Use MessageWindowChatMemory** for conversation history management
4. **Implement error handling** for tool failures
5. **Use structured output** when you need typed responses
6. **Stream long responses** for better UX
7. **Use @MemoryId** for multi-user scenarios
8. **Template variables** for dynamic system prompts
9. **Tool descriptions** should be clear and actionable
10. **Always validate** tool parameters before execution (see the sketch below)
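For practice 10, validation belongs inside the tool itself, because the model, not your code, chooses the argument values. A minimal sketch (the tool, limits, and IBAN regex are illustrative):
```java
class PaymentTools {

    @Tool("Transfer an amount in EUR to a target IBAN")
    String transfer(@P("amount in EUR") double amount, @P("target IBAN") String iban) {
        // Validate LLM-supplied arguments before any irreversible action
        if (amount <= 0 || amount > 10_000) {
            return "Rejected: amount must be between 0 and 10,000 EUR";
        }
        if (!iban.matches("[A-Z]{2}\\d{2}[A-Z0-9]{11,30}")) {
            return "Rejected: invalid IBAN format";
        }
        return "Transferred " + amount + " EUR to " + iban;
    }
}
```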

View File

@@ -0,0 +1,433 @@
# LangChain4j AI Services - API References
Complete API reference for LangChain4j AI Services patterns.
## Core Interfaces and Classes
### AiServices Builder
**Purpose**: Creates implementations of custom Java interfaces backed by LLM capabilities.
```java
public class AiServices {

    // Create a builder for an AI service interface
    static <T> AiServicesBuilder<T> builder(Class<T> aiService)

    // Quick creation with just a chat model
    static <T> T create(Class<T> aiService, ChatModel chatModel)
}

// Builder configuration (fluent API)
AiServices.builder(MyAiService.class)
    .chatModel(chatModel)                   // Required for sync calls
    .streamingChatModel(streamingChatModel) // Required for streaming
    .chatMemory(chatMemory)                 // Single shared memory
    .chatMemoryProvider(memoryProvider)     // Per-user memory
    .tools(tool1, tool2)                    // Register tool objects
    .toolProvider(toolProvider)             // Dynamic tool selection
    .contentRetriever(contentRetriever)     // For RAG
    .retrievalAugmentor(retrievalAugmentor) // Advanced RAG
    .moderationModel(moderationModel)       // Content moderation
    .build()                                // Build the implementation
```
### Core Annotations
**@SystemMessage**: Define system prompt for the AI service.
```java
@SystemMessage("You are a helpful Java developer")
String chat(String userMessage);
// Template variables
@SystemMessage("You are a {{expertise}} expert")
String explain(@V("expertise") String domain, String question);
```
**@UserMessage**: Define user message template.
```java
@UserMessage("Translate to {{language}}: {{text}}")
String translate(@V("language") String lang, @V("text") String text);
// With method parameters matching template
@UserMessage("Summarize: {{it}}")
String summarize(String text); // {{it}} refers to parameter
```
**@MemoryId**: Create separate memory context per identifier.
```java
interface MultiUserChat {
    String chat(@MemoryId String userId, @UserMessage String message);
    String chat(@MemoryId int sessionId, @UserMessage String message);
}
```
**@V**: Map method parameter to template variable.
```java
@UserMessage("Write {{type}} code for {{language}}")
String writeCode(@V("type") String codeType, @V("language") String lang);
```
### ChatMemory Implementations
**MessageWindowChatMemory**: Keeps last N messages.
```java
ChatMemory memory = MessageWindowChatMemory.withMaxMessages(10);
// Or with explicit builder
ChatMemory memory = MessageWindowChatMemory.builder()
.maxMessages(10)
.build();
```
**ChatMemoryProvider**: Factory for creating per-user memory.
```java
ChatMemoryProvider provider = memoryId ->
MessageWindowChatMemory.withMaxMessages(20);
```
### Tool Integration
**@Tool**: Mark methods that LLM can call.
```java
@Tool("Calculate sum of two numbers")
int add(@P("first number") int a, @P("second number") int b) {
return a + b;
}
```
**@P**: Parameter description for LLM.
```java
@Tool("Search documents")
List<Document> search(
@P("search query") String query,
@P("max results") int limit
) { ... }
```
**ToolProvider**: Dynamic tool selection based on context.
```java
interface DynamicToolAssistant {
String execute(String command);
}
ToolProvider provider = request -> {
    // Expose tools per request (toolSpecification and toolExecutor are built elsewhere)
    if (request.userMessage().singleText().contains("calculate")) {
        return ToolProviderResult.builder()
                .add(toolSpecification, toolExecutor)
                .build();
    }
    return null; // no tools for this request
};
```
### Structured Output
**@Description**: Annotate output fields for extraction.
```java
class Person {
@Description("Person's full name")
String name;
@Description("Age in years")
int age;
}
interface Extractor {
@UserMessage("Extract person from: {{it}}")
Person extract(String text);
}
```
### Error Handling
**ToolExecutionErrorHandler**: Handle tool execution failures.
```java
.toolExecutionErrorHandler((request, exception) -> {
logger.error("Tool failed: " + request.name(), exception);
return "Tool execution failed: " + exception.getMessage();
})
```
**ToolArgumentsErrorHandler**: Handle malformed tool arguments.
```java
.toolArgumentsErrorHandler((request, exception) -> {
logger.warn("Invalid arguments for " + request.name());
return "Please provide valid arguments";
})
```
## Streaming APIs
### TokenStream
**Purpose**: Handle streaming LLM responses token-by-token.
```java
interface StreamingAssistant {
TokenStream streamChat(String message);
}
TokenStream stream = assistant.streamChat("Tell me a story");
stream
.onNext(token -> {
// Process each token
System.out.print(token);
})
.onCompleteResponse(response -> {
// Full response available
System.out.println("\nTokens used: " + response.tokenUsage());
})
.onError(error -> {
System.err.println("Error: " + error);
})
.onToolExecuted(toolExecution -> {
System.out.println("Tool: " + toolExecution.request().name());
})
.onRetrieved(contents -> {
// RAG content retrieved
contents.forEach(c -> System.out.println(c.textSegment()));
})
.start();
```
### StreamingChatResponseHandler
**Purpose**: Callback-based streaming without TokenStream.
```java
streamingModel.chat(request, new StreamingChatResponseHandler() {
@Override
public void onPartialResponse(String partialResponse) {
System.out.print(partialResponse);
}
@Override
public void onCompleteResponse(ChatResponse response) {
System.out.println("\nComplete!");
}
@Override
public void onError(Throwable error) {
error.printStackTrace();
}
});
```
## Content Retrieval
### ContentRetriever Interface
**Purpose**: Fetch relevant content for RAG.
```java
interface ContentRetriever {
    List<Content> retrieve(Query query);
}
```
### EmbeddingStoreContentRetriever
```java
ContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.maxResults(5) // Default max results
.minScore(0.7) // Similarity threshold
.dynamicMaxResults(query -> 10) // Query-dependent
.dynamicMinScore(query -> 0.8) // Query-dependent
.filter(new IsEqualTo("userId", "123")) // Metadata filter
.dynamicFilter(query -> {...}) // Dynamic filter
.build();
```
### RetrievalAugmentor
**Purpose**: Advanced RAG pipeline with query transformation and re-ranking.
```java
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
.queryTransformer(new CompressingQueryTransformer(chatModel))
.contentRetriever(contentRetriever)
.contentAggregator(ReRankingContentAggregator.builder()
.scoringModel(scoringModel)
.minScore(0.8)
.build())
.build();
// Use with AI Service
var assistant = AiServices.builder(Assistant.class)
.chatModel(chatModel)
.retrievalAugmentor(augmentor)
.build();
```
## Request/Response Models
### ChatRequest
**Purpose**: Build complex chat requests with multiple messages.
```java
ChatRequest request = ChatRequest.builder()
.messages(
SystemMessage.from("You are helpful"),
UserMessage.from("What is AI?"),
AiMessage.from("AI is...")
)
.temperature(0.7)
.maxTokens(500)
.topP(0.95)
.build();
ChatResponse response = chatModel.chat(request);
```
### ChatResponse
**Purpose**: Access chat model responses and metadata.
```java
String content = response.aiMessage().text();
TokenUsage usage = response.tokenUsage();
System.out.println("Tokens: " + usage.totalTokenCount());
System.out.println("Prompt tokens: " + usage.inputTokenCount());
System.out.println("Completion tokens: " + usage.outputTokenCount());
System.out.println("Finish reason: " + response.finishReason());
```
## Query and Content
### Query
**Purpose**: Represent a user query in retrieval context.
```java
// Query object contains:
String text()                  // The query text
Metadata metadata() // Query metadata (e.g., userId)
Object metadata(String key) // Get metadata value
Object metadata(String key, Object defaultValue)
```
### Content
**Purpose**: Retrieved content with metadata.
```java
TextSegment textSegment()      // The retrieved text segment
double score() // Relevance score
Metadata metadata() // Content metadata (e.g., source)
Map<String, Object> source() // Original source data
```
## Message Types
### SystemMessage
```java
SystemMessage message = SystemMessage.from("You are a code reviewer");
```
### UserMessage
```java
UserMessage message = UserMessage.from("Review this code");
// With images
UserMessage message = UserMessage.from(
TextContent.from("Analyze this"),
ImageContent.from("http://...", "image/png")
);
```
### AiMessage
```java
AiMessage message = AiMessage.from("Here's my analysis");
// With tool execution requests (the model asking for a tool call)
AiMessage message = AiMessage.from(
    ToolExecutionRequest.builder()
        .name("calculator")
        .arguments("{\"a\": 2, \"b\": 3}")
        .build()
);
```
## Configuration Patterns
### Chat Model Configuration
```java
ChatModel model = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("gpt-4o-mini") // Model selection
.temperature(0.7) // Creativity (0-2)
.topP(0.95) // Diversity (0-1)
.maxTokens(2000) // Max generation
.frequencyPenalty(0.0) // Reduce repetition
.presencePenalty(0.0) // Reduce topic switching
.seed(42) // Reproducibility
.logRequests(true) // Debug logging
.logResponses(true) // Debug logging
.build();
```
### Embedding Model Configuration
```java
EmbeddingModel embedder = OpenAiEmbeddingModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("text-embedding-3-small")
.dimensions(512) // Custom dimensions
.build();
```
## Best Practices for API Usage
1. **Type Safety**: Define typed interfaces so mismatches surface at compile time
2. **Separation of Concerns**: Use different interfaces for different domains
3. **Error Handling**: Always implement error handlers for tools
4. **Memory Management**: Choose appropriate memory implementation for use case
5. **Token Optimization**: Use temperature=0 for deterministic tasks
6. **Testing**: Mock ChatModel for unit tests (see the sketch after this list)
7. **Logging**: Enable request/response logging in development
8. **Rate Limiting**: Implement backoff strategies for API calls
9. **Caching**: Cache responses for frequently asked questions
10. **Monitoring**: Track token usage for cost management
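A minimal sketch of practice 6, assuming Mockito and AssertJ on the test classpath (the stubbed response shape is illustrative; adjust to the `ChatModel` methods your LangChain4j version exposes):
```java
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

ChatModel mockModel = mock(ChatModel.class);
// Every chat request gets the same canned answer
when(mockModel.chat(any(ChatRequest.class)))
        .thenReturn(ChatResponse.builder()
                .aiMessage(AiMessage.from("stubbed answer"))
                .build());

Assistant assistant = AiServices.create(Assistant.class, mockModel);
assertThat(assistant.chat("anything")).isEqualTo("stubbed answer");
```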
## Common Patterns
### Factory Pattern for Multiple Assistants
```java
public class AssistantFactory {
static JavaExpert createJavaExpert() {
return AiServices.create(JavaExpert.class, chatModel);
}
static PythonExpert createPythonExpert() {
return AiServices.create(PythonExpert.class, chatModel);
}
}
```
### Decorator Pattern for Enhanced Functionality
```java
public class LoggingAssistant implements Assistant {
private final Assistant delegate;
public String chat(String message) {
logger.info("User: " + message);
String response = delegate.chat(message);
logger.info("Assistant: " + response);
return response;
}
}
```
### Builder Pattern for Complex Configurations
```java
var assistant = AiServices.builder(ComplexAssistant.class)
.chatModel(getChatModel())
.chatMemory(getMemory())
.tools(getTool1(), getTool2())
.contentRetriever(getRetriever())
.build();
```
## Resources
- [LangChain4j Documentation](https://docs.langchain4j.dev)
- [OpenAI API Reference](https://platform.openai.com/docs)
- [LangChain4j GitHub](https://github.com/langchain4j/langchain4j)
- [LangChain4j Examples](https://github.com/langchain4j/langchain4j-examples)

View File

@@ -0,0 +1,393 @@
---
name: langchain4j-mcp-server-patterns
description: Model Context Protocol (MCP) server implementation patterns with LangChain4j. Use when building MCP servers to extend AI capabilities with custom tools, resources, and prompt templates.
category: ai-integration
tags: [langchain4j, mcp, model-context-protocol, tools, resources, prompts, ai-services, java, spring-boot, enterprise]
version: 1.1.0
allowed-tools: Read, Write, Bash, WebFetch
---
# LangChain4j MCP Server Implementation Patterns
Implement Model Context Protocol (MCP) servers with LangChain4j to extend AI capabilities with standardized tools, resources, and prompt templates.
## When to Use
Use this skill when building:
- AI applications requiring external tool integration
- Enterprise MCP servers with multi-domain support (GitHub, databases, APIs)
- Dynamic tool providers with context-aware filtering
- Resource-based data access systems for AI models
- Prompt template servers for standardized AI interactions
- Scalable AI agents with resilient tool execution
- Multi-modal AI applications with diverse data sources
- Spring Boot applications with MCP integration
- Production-ready MCP servers with security and monitoring
## Quick Start
### Basic MCP Server
Create a simple MCP server with one tool:
```java
MCPServer server = MCPServer.builder()
.server(new StdioServer.Builder())
.addToolProvider(new SimpleWeatherToolProvider())
.build();
server.start();
```
### Spring Boot Integration
Configure MCP server in Spring Boot:
```java
@Bean
public MCPSpringConfig mcpServer(List<ToolProvider> tools) {
return MCPSpringConfig.builder()
.tools(tools)
.server(new StdioServer.Builder())
.build();
}
```
## Core Concepts
### MCP Architecture
MCP standardizes how AI applications connect to external tools and data:
- **Tools**: Executable functions (database queries, API calls)
- **Resources**: Data sources (files, schemas, documentation)
- **Prompts**: Pre-configured templates for tasks
- **Transport**: Communication layer (stdio, HTTP, WebSocket)
```
AI Application ←→ MCP Client ←→ Transport ←→ MCP Server ←→ External Service
```
### Key Components
- **MCPServer**: Main server instance with configuration
- **ToolProvider**: Tool specification and execution interface
- **ResourceListProvider/ResourceReadHandler**: Resource access
- **PromptListProvider/PromptGetHandler**: Template management
- **Transport**: Communication mechanisms (stdio, HTTP)
## Implementation Patterns
### Tool Provider Pattern
Create tools with proper schema validation:
```java
class WeatherToolProvider implements ToolProvider {
@Override
public List<ToolSpecification> listTools() {
return List.of(ToolSpecification.builder()
.name("get_weather")
.description("Get weather for a city")
.inputSchema(Map.of(
"type", "object",
"properties", Map.of(
"city", Map.of("type", "string", "description", "City name")
),
"required", List.of("city")
))
.build());
}
@Override
public String executeTool(String name, String arguments) {
// Parse arguments and execute tool logic
return "Weather data result";
}
}
```
### Resource Provider Pattern
Provide static and dynamic resources:
```java
class CompanyResourceProvider
implements ResourceListProvider, ResourceReadHandler {
@Override
public List<McpResource> listResources() {
return List.of(
McpResource.builder()
.uri("policies")
.name("Company Policies")
.mimeType("text/plain")
.build()
);
}
@Override
public String readResource(String uri) {
return loadResourceContent(uri);
}
}
```
### Prompt Template Pattern
Create reusable prompt templates:
```java
class PromptTemplateProvider
implements PromptListProvider, PromptGetHandler {
@Override
public List<Prompt> listPrompts() {
return List.of(
Prompt.builder()
.name("code-review")
.description("Review code for quality")
.build()
);
}
@Override
public String getPrompt(String name, Map<String, String> args) {
return applyTemplate(name, args);
}
}
```
## Transport Configuration
### Stdio Transport
Local process communication:
```java
McpTransport transport = new StdioMcpTransport.Builder()
.command(List.of("npm", "exec", "@modelcontextprotocol/server-everything"))
.logEvents(true)
.build();
```
### HTTP Transport
Remote server communication:
```java
McpTransport transport = new HttpMcpTransport.Builder()
.sseUrl("http://localhost:3001/sse")
.logRequests(true)
.logResponses(true)
.build();
```
## Client Integration
### MCP Client Setup
Connect to MCP servers:
```java
McpClient client = new DefaultMcpClient.Builder()
.key("my-client")
.transport(transport)
.cacheToolList(true)
.build();
// List available tools
List<ToolSpecification> tools = client.listTools();
```
### Tool Provider Integration
Bridge MCP servers to LangChain4j AI services:
```java
McpToolProvider provider = McpToolProvider.builder()
.mcpClients(mcpClient)
.failIfOneServerFails(false)
.filter((client, tool) -> filterByPermissions(tool))
.build();
// Integrate with AI service
AIAssistant assistant = AiServices.builder(AIAssistant.class)
.chatModel(chatModel)
.toolProvider(provider)
.build();
```
## Security & Best Practices
### Tool Security
Implement secure tool filtering:
```java
McpToolProvider secureProvider = McpToolProvider.builder()
.mcpClients(mcpClient)
.filter((client, tool) -> {
if (tool.name().startsWith("admin_") && !isAdmin()) {
return false;
}
return true;
})
.build();
```
### Resource Security
Apply access controls to resources:
```java
public boolean canAccessResource(String uri, User user) {
return resourceService.hasAccess(uri, user);
}
```
### Error Handling
Implement robust error handling:
```java
try {
String result = mcpClient.executeTool(request);
} catch (McpException e) {
log.error("MCP execution failed: {}", e.getMessage());
return fallbackResult();
}
```
## Advanced Patterns
### Multi-Server Configuration
Configure multiple MCP servers:
```java
@Bean
public List<McpClient> mcpClients(List<ServerConfig> configs) {
return configs.stream()
.map(this::createMcpClient)
.collect(Collectors.toList());
}
@Bean
public McpToolProvider multiServerProvider(List<McpClient> clients) {
return McpToolProvider.builder()
.mcpClients(clients)
.failIfOneServerFails(false)
.build();
}
```
### Dynamic Tool Discovery
Runtime tool filtering based on context:
```java
McpToolProvider contextualProvider = McpToolProvider.builder()
.mcpClients(clients)
.filter((client, tool) -> isToolAllowed(user, tool, context))
.build();
```
### Health Monitoring
Monitor MCP server health:
```java
@Component
public class McpHealthChecker {
@Scheduled(fixedRate = 30000) // 30 seconds
public void checkServers() {
mcpClients.forEach(client -> {
try {
client.listTools();
markHealthy(client.key());
} catch (Exception e) {
markUnhealthy(client.key(), e.getMessage());
}
});
}
}
```
## Configuration
### Application Properties
Configure MCP servers in application.yml:
```yaml
mcp:
servers:
github:
type: docker
command: ["/usr/local/bin/docker", "run", "-e", "GITHUB_TOKEN", "-i", "mcp/github"]
log-events: true
database:
type: stdio
command: ["/usr/bin/npm", "exec", "@modelcontextprotocol/server-sqlite"]
log-events: false
```
### Spring Boot Configuration
Configure MCP with Spring Boot:
```java
@Configuration
@EnableConfigurationProperties(McpProperties.class)
public class McpConfiguration {
@Bean
public MCPServer mcpServer(List<ToolProvider> providers) {
return MCPServer.builder()
.server(new StdioServer.Builder())
.addToolProvider(providers)
.enableLogging(true)
.build();
}
}
```
## Examples
Refer to [examples.md](./references/examples.md) for comprehensive implementation examples including:
- Basic MCP server setup
- Multi-tool enterprise servers
- Resource and prompt providers
- Spring Boot integration
- Error handling patterns
- Security implementations
## API Reference
Complete API documentation is available in [api-reference.md](./references/api-reference.md) covering:
- Core MCP classes and interfaces
- Transport configuration
- Client and server patterns
- Error handling strategies
- Configuration management
- Testing and validation
## Best Practices
1. **Resource Management**: Always close MCP clients properly using try-with-resources
2. **Error Handling**: Implement graceful degradation when servers fail
3. **Security**: Use tool filtering and resource access controls
4. **Performance**: Enable caching and optimize tool execution
5. **Monitoring**: Implement health checks and observability
6. **Testing**: Create comprehensive test suites with mocks
7. **Documentation**: Document tools, resources, and prompts clearly
8. **Configuration**: Use structured configuration for maintainability
## References
- [LangChain4j Documentation](https://docs.langchain4j.dev/)
- [Model Context Protocol Specification](https://modelcontextprotocol.io/)
- [API Reference](./references/api-reference.md)
- [Examples](./references/examples.md)

View File

@@ -0,0 +1,315 @@
package com.example.mcp;
import dev.langchain4j.mcp.*;
import dev.langchain4j.mcp.transport.*;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import java.util.List;
// Helper imports
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Map;
/**
* Template for creating MCP servers with LangChain4j.
*
* This template provides a starting point for building MCP servers with:
* - Tool providers
* - Resource providers
* - Prompt providers
* - Spring Boot integration
* - Configuration management
*/
@SpringBootApplication
public class MCPServerTemplate {
public static void main(String[] args) {
SpringApplication.run(MCPServerTemplate.class, args);
}
/**
* Configure and build the main MCP server instance.
*/
@Bean
public MCPServer mcpServer(
List<ToolProvider> toolProviders,
List<ResourceListProvider> resourceProviders,
List<PromptListProvider> promptProviders) {
return MCPServer.builder()
.server(new StdioServer.Builder())
.addToolProvider(toolProviders)
.addResourceProvider(resourceProviders)
.addPromptProvider(promptProviders)
.enableLogging(true)
.build();
}
/**
* Configure MCP clients for connecting to external MCP servers.
*/
@Bean
public McpClient mcpClient() {
StdioMcpTransport transport = new StdioMcpTransport.Builder()
.command(List.of("npm", "exec", "@modelcontextprotocol/server-everything@0.6.2"))
.logEvents(true)
.build();
return new DefaultMcpClient.Builder()
.key("template-client")
.transport(transport)
.cacheToolList(true)
.build();
}
/**
* Configure MCP tool provider for AI services integration.
*/
@Bean
public McpToolProvider mcpToolProvider(McpClient mcpClient) {
return McpToolProvider.builder()
.mcpClients(mcpClient)
.failIfOneServerFails(false)
.build();
}
}
/**
* Example tool provider implementing a simple calculator.
*/
class CalculatorToolProvider implements ToolProvider {
@Override
public List<ToolSpecification> listTools() {
return List.of(
ToolSpecification.builder()
.name("add")
.description("Add two numbers")
.inputSchema(Map.of(
"type", "object",
"properties", Map.of(
"a", Map.of("type", "number", "description", "First number"),
"b", Map.of("type", "number", "description", "Second number")
),
"required", List.of("a", "b")
))
.build(),
ToolSpecification.builder()
.name("multiply")
.description("Multiply two numbers")
.inputSchema(Map.of(
"type", "object",
"properties", Map.of(
"a", Map.of("type", "number", "description", "First number"),
"b", Map.of("type", "number", "description", "Second number")
),
"required", List.of("a", "b")
))
.build()
);
}
@Override
public String executeTool(String name, String arguments) {
try {
// Parse JSON arguments
ObjectMapper mapper = new ObjectMapper();
JsonNode argsNode = mapper.readTree(arguments);
double a = argsNode.get("a").asDouble();
double b = argsNode.get("b").asDouble();
switch (name) {
case "add":
return String.valueOf(a + b);
case "multiply":
return String.valueOf(a * b);
default:
throw new UnsupportedOperationException("Unknown tool: " + name);
}
} catch (Exception e) {
return "Error executing tool: " + e.getMessage();
}
}
}
/**
* Example resource provider for static company information.
*/
class CompanyResourceProvider implements ResourceListProvider, ResourceReadHandler {
@Override
public List<McpResource> listResources() {
return List.of(
McpResource.builder()
.uri("company-info")
.name("Company Information")
.description("Basic company details and contact information")
.mimeType("text/plain")
.build(),
McpResource.builder()
.uri("policies")
.name("Company Policies")
.description("Company policies and procedures")
.mimeType("text/markdown")
.build()
);
}
@Override
public String readResource(String uri) {
switch (uri) {
case "company-info":
return loadCompanyInfo();
case "policies":
return loadPolicies();
default:
throw new ResourceNotFoundException("Resource not found: " + uri);
}
}
private String loadCompanyInfo() {
return """
Company Information:
===================
Name: Example Corporation
Founded: 2020
Industry: Technology
Employees: 100+
Contact:
- Email: info@example.com
- Phone: +1-555-0123
- Website: https://example.com
Mission: To deliver innovative AI solutions
""";
}
private String loadPolicies() {
return """
Company Policies:
=================
1. Code of Conduct
- Treat all team members with respect
- Maintain professional communication
- Report any concerns to management
2. Security Policy
- Use strong passwords
- Enable 2FA when available
- Report security incidents immediately
3. Work Environment
- Flexible working hours
- Remote work options
- Support for continuous learning
""";
}
}
/**
* Example prompt template provider for common AI tasks.
*/
class PromptTemplateProvider implements PromptListProvider, PromptGetHandler {
@Override
public List<Prompt> listPrompts() {
return List.of(
Prompt.builder()
.name("code-review")
.description("Review code for quality, security, and best practices")
.build(),
Prompt.builder()
.name("documentation-generation")
.description("Generate technical documentation from code")
.build(),
Prompt.builder()
.name("bug-analysis")
.description("Analyze and explain potential bugs in code")
.build()
);
}
@Override
public String getPrompt(String name, Map<String, String> arguments) {
switch (name) {
case "code-review":
return createCodeReviewPrompt(arguments);
case "documentation-generation":
return createDocumentationPrompt(arguments);
case "bug-analysis":
return createBugAnalysisPrompt(arguments);
default:
throw new PromptNotFoundException("Prompt not found: " + name);
}
}
private String createCodeReviewPrompt(Map<String, String> args) {
String code = args.getOrDefault("code", "");
String language = args.getOrDefault("language", "unknown");
return String.format("""
Review the following %s code for quality, security, and best practices:
```%s
%s
```
Please analyze:
1. Code quality and readability
2. Security vulnerabilities
3. Performance optimizations
4. Best practices compliance
5. Error handling
Provide specific recommendations for improvements.
""", language, language, code);
}
private String createDocumentationPrompt(Map<String, String> args) {
String code = args.getOrDefault("code", "");
String component = args.getOrDefault("component", "function");
return String.format("""
Generate comprehensive documentation for the following %s:
```%s
%s
```
Include:
1. Function/method signatures
2. Parameters and return values
3. Purpose and usage examples
4. Dependencies and requirements
5. Error conditions and handling
""", component, "java", code);
}
private String createBugAnalysisPrompt(Map<String, String> args) {
String code = args.getOrDefault("code", "");
return String.format("""
Analyze the following code for potential bugs and issues:
```java
%s
```
Look for:
1. Null pointer exceptions
2. Logic errors
3. Resource leaks
4. Race conditions
5. Edge cases
6. Type mismatches
Explain each issue found and suggest fixes.
""", code);
}
}

View File

@@ -0,0 +1,435 @@
# LangChain4j MCP Server API Reference
This document provides comprehensive API documentation for implementing MCP servers with LangChain4j.
## Core MCP Classes
### McpClient Interface
Primary interface for communicating with MCP servers.
**Key Methods:**
```java
// Tool Management
List<ToolSpecification> listTools();
String executeTool(ToolExecutionRequest request);
// Resource Management
List<McpResource> listResources();
String getResource(String uri);
List<McpResourceTemplate> listResourceTemplates();
// Prompt Management
List<Prompt> listPrompts();
String getPrompt(String name);
// Lifecycle Management
void close();
```
### DefaultMcpClient.Builder
Builder for creating MCP clients with configuration options.
**Configuration Methods:**
```java
McpClient client = new DefaultMcpClient.Builder()
.key("unique-client-id") // Unique identifier
.transport(transport) // Transport mechanism
.cacheToolList(true) // Enable tool caching
.logMessageHandler(handler) // Custom logging
.build();
```
### McpToolProvider.Builder
Builder for creating tool providers that bridge MCP servers to LangChain4j AI services.
**Configuration Methods:**
```java
McpToolProvider provider = McpToolProvider.builder()
.mcpClients(client1, client2) // Add MCP clients
.failIfOneServerFails(false) // Configure failure handling
.filterToolNames("tool1", "tool2") // Filter by names
.filter((client, tool) -> logic) // Custom filtering
.build();
```
## Transport Configuration
### StdioMcpTransport.Builder
For local process communication with npm packages or Docker containers.
```java
McpTransport transport = new StdioMcpTransport.Builder()
.command(List.of("npm", "exec", "@modelcontextprotocol/server-everything@0.6.2"))
.logEvents(true)
.build();
```
### HttpMcpTransport.Builder
For HTTP-based communication with remote MCP servers.
```java
McpTransport transport = new HttpMcpTransport.Builder()
.sseUrl("http://localhost:3001/sse")
.logRequests(true)
.logResponses(true)
.build();
```
### StreamableHttpMcpTransport.Builder
For streamable HTTP transport with enhanced performance.
```java
McpTransport transport = new StreamableHttpMcpTransport.Builder()
.url("http://localhost:3001/mcp")
.logRequests(true)
.logResponses(true)
.build();
```
## AI Service Integration
### AiServices.builder()
Create AI services integrated with MCP tool providers.
**Integration Methods:**
```java
AIAssistant assistant = AiServices.builder(AIAssistant.class)
.chatModel(chatModel)
.toolProvider(toolProvider)
.chatMemoryProvider(memoryProvider)
.build();
```
## Error Handling and Management
### Exception Handling
Handle MCP-specific exceptions gracefully:
```java
try {
String result = mcpClient.executeTool(request);
} catch (McpException e) {
log.error("MCP execution failed: {}", e.getMessage());
// Implement fallback logic
}
```
### Retry and Resilience
Implement retry logic for unreliable MCP servers:
```java
RetryTemplate retryTemplate = RetryTemplate.builder()
.maxAttempts(3)
.exponentialBackoff(1000, 2, 10000)
.build();
String result = retryTemplate.execute(context ->
mcpClient.executeTool(request));
```
## Configuration Properties
### Application Configuration
```yaml
mcp:
fail-if-one-server-fails: false
cache-tools: true
servers:
github:
type: docker
command: ["/usr/local/bin/docker", "run", "-e", "GITHUB_TOKEN", "-i", "mcp/github"]
log-events: true
database:
type: stdio
command: ["/usr/bin/npm", "exec", "@modelcontextprotocol/server-sqlite"]
log-events: false
```
### Spring Boot Configuration
```java
@Configuration
@EnableConfigurationProperties(McpProperties.class)
public class McpConfiguration {
@Bean
public List<McpClient> mcpClients(McpProperties properties) {
return properties.getServers().entrySet().stream()
.map(entry -> createMcpClient(entry.getKey(), entry.getValue()))
.collect(Collectors.toList());
}
@Bean
public McpToolProvider mcpToolProvider(List<McpClient> mcpClients, McpProperties properties) {
return McpToolProvider.builder()
.mcpClients(mcpClients)
.failIfOneServerFails(properties.isFailIfOneServerFails())
.build();
}
}
```
## Tool Specification and Execution
### Tool Specification
Define tools with proper schema:
```java
ToolSpecification toolSpec = ToolSpecification.builder()
.name("database_query")
.description("Execute SQL queries against the database")
.inputSchema(Map.of(
"type", "object",
"properties", Map.of(
"sql", Map.of(
"type", "string",
"description", "SQL query to execute"
)
)
))
.build();
```
### Tool Execution
Execute tools with structured requests:
```java
ToolExecutionRequest request = ToolExecutionRequest.builder()
.name("database_query")
.arguments("{\"sql\": \"SELECT * FROM users LIMIT 10\"}")
.build();
String result = mcpClient.executeTool(request);
```
## Resource Handling
### Resource Access
Access and utilize MCP resources:
```java
// List available resources
List<McpResource> resources = mcpClient.listResources();
// Get specific resource content
String content = mcpClient.getResource("resource://schema/database");
// Work with resource templates
List<McpResourceTemplate> templates = mcpClient.listResourceTemplates();
```
### Resource as Tools
Convert MCP resources to tools automatically:
```java
DefaultMcpResourcesAsToolsPresenter presenter =
new DefaultMcpResourcesAsToolsPresenter();
mcpToolProvider.provideTools(presenter);
// Adds 'list_resources' and 'get_resource' tools automatically
```
## Security and Filtering
### Tool Filtering
Implement security-conscious tool filtering:
```java
McpToolProvider secureProvider = McpToolProvider.builder()
.mcpClients(mcpClient)
.filter((client, tool) -> {
// Check user permissions
if (tool.name().startsWith("admin_") && !currentUser.hasRole("ADMIN")) {
return false;
}
return true;
})
.build();
```
### Resource Security
Apply security controls to resource access:
```java
public boolean canAccessResource(String uri, User user) {
if (uri.contains("sensitive/") && !user.hasRole("ADMIN")) {
return false;
}
return true;
}
```
## Performance Optimization
### Caching Strategy
Implement intelligent caching:
```java
// Enable tool caching for performance
McpClient client = new DefaultMcpClient.Builder()
.transport(transport)
.cacheToolList(true)
.build();
// Periodic cache refresh
@Scheduled(fixedRate = 300000) // 5 minutes
public void refreshToolCache() {
mcpClients.forEach(client -> {
try {
client.invalidateCache();
client.listTools(); // Preload cache
} catch (Exception e) {
log.warn("Cache refresh failed: {}", e.getMessage());
}
});
}
```
### Connection Pooling
Optimize connection management:
```java
@Bean
public Executor mcpExecutor() {
return Executors.newFixedThreadPool(10); // Dedicated thread pool
}
```
## Testing and Validation
### Mock Configuration
Setup for testing:
```java
@TestConfiguration
public class MockMcpConfiguration {
@Bean
@Primary
public McpClient mockMcpClient() {
McpClient mock = Mockito.mock(McpClient.class);
when(mock.listTools()).thenReturn(List.of(
ToolSpecification.builder()
.name("test_tool")
.description("Test tool")
.build()
));
when(mock.executeTool(any(ToolExecutionRequest.class)))
.thenReturn("Mock result");
return mock;
}
}
```
### Integration Testing
Test MCP integrations:
```java
@SpringBootTest
class McpIntegrationTest {
@Autowired
private AIAssistant assistant;
@Test
void shouldExecuteToolsSuccessfully() {
String response = assistant.chat("Execute test tool");
assertThat(response).contains("Mock result");
}
}
```
## Monitoring and Observability
### Health Checks
Monitor MCP server health:
```java
@Component
public class McpHealthChecker {
@EventListener
@Async
public void checkHealth() {
mcpClients.forEach(client -> {
try {
client.listTools(); // Simple health check
healthRegistry.markHealthy(client.key());
} catch (Exception e) {
healthRegistry.markUnhealthy(client.key(), e.getMessage());
}
});
}
}
```
### Metrics Collection
Collect execution metrics:
```java
@Bean
public Counter toolExecutionCounter(MeterRegistry meterRegistry) {
return meterRegistry.counter("mcp.tool.execution", "type", "total");
}
@Bean
public Timer toolExecutionTimer(MeterRegistry meterRegistry) {
return meterRegistry.timer("mcp.tool.execution.time");
}
```
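A sketch of wiring these Micrometer beans around a tool call (assumes the counter and timer above are injected where the call happens):
```java
toolExecutionCounter.increment(); // count every attempted execution
String result = toolExecutionTimer.record(
        () -> mcpClient.executeTool(request)); // time the MCP round trip
```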
## Migration and Versioning
### Version Compatibility
Handle version compatibility:
```java
public class VersionedMcpClient {
public boolean isCompatible(String serverVersion) {
return semanticVersionChecker.isCompatible(
REQUIRED_MCP_VERSION, serverVersion);
}
public McpClient createClient(ServerConfig config) {
if (!isCompatible(config.getVersion())) {
throw new IncompatibleVersionException(
"Server version " + config.getVersion() +
" is not compatible with required " + REQUIRED_MCP_VERSION);
}
return new DefaultMcpClient.Builder()
.transport(createTransport(config))
.build();
}
}
```
This API reference provides the complete foundation for implementing MCP servers and clients with LangChain4j, covering all major aspects from basic setup to advanced enterprise patterns.

View File

@@ -0,0 +1,592 @@
# LangChain4j MCP Server Implementation Examples
This document provides comprehensive, production-ready examples for implementing MCP servers with LangChain4j.
## Basic MCP Server Setup
### Simple MCP Server Implementation
Create a basic MCP server with single tool functionality:
```java
import dev.langchain4j.mcp.MCPServer;
import dev.langchain4j.mcp.ToolProvider;
import dev.langchain4j.mcp.server.StdioServer;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import java.util.List;
import java.util.Map;
public class BasicMcpServer {
public static void main(String[] args) {
MCPServer server = MCPServer.builder()
.server(new StdioServer.Builder())
.addToolProvider(new SimpleWeatherToolProvider())
.build();
// Start the server
server.start();
}
}
class SimpleWeatherToolProvider implements ToolProvider {
@Override
public List<ToolSpecification> listTools() {
return List.of(ToolSpecification.builder()
.name("get_weather")
.description("Get weather information for a city")
.inputSchema(Map.of(
"type", "object",
"properties", Map.of(
"city", Map.of(
"type", "string",
"description", "City name to get weather for"
)
),
"required", List.of("city")
))
.build());
}
@Override
public String executeTool(String name, String arguments) {
if ("get_weather".equals(name)) {
JsonObject args = JsonParser.parseString(arguments).getAsJsonObject();
String city = args.get("city").getAsString();
// Simulate weather API call
return String.format("Weather in %s: Sunny, 22°C", city);
}
throw new UnsupportedOperationException("Unknown tool: " + name);
}
}
```
### Spring Boot MCP Server Integration
Integrate MCP server with Spring Boot application:
```java
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import java.util.List;
import java.util.Map;
@SpringBootApplication
public class McpSpringBootApplication {
public static void main(String[] args) {
SpringApplication.run(McpSpringBootApplication.class, args);
}
@Bean
public MCPServer mcpServer() {
return MCPServer.builder()
.server(new StdioServer.Builder())
.addToolProvider(new DatabaseToolProvider())
.addToolProvider(new FileToolProvider())
.build();
}
}
@Component
class DatabaseToolProvider implements ToolProvider {

    private final JdbcTemplate jdbcTemplate; // Spring-injected JDBC helper

    DatabaseToolProvider(JdbcTemplate jdbcTemplate) {
        this.jdbcTemplate = jdbcTemplate;
    }
@Override
public List<ToolSpecification> listTools() {
return List.of(ToolSpecification.builder()
.name("query_database")
.description("Execute SQL queries against the database")
.inputSchema(Map.of(
"type", "object",
"properties", Map.of(
"sql", Map.of(
"type", "string",
"description", "SQL query to execute"
)
),
"required", List.of("sql")
))
.build());
}
@Override
public String executeTool(String name, String arguments) {
if ("query_database".equals(name)) {
JsonObject args = JsonParser.parseString(arguments).getAsJsonObject();
String sql = args.get("sql").getAsString();
// Execute database query
return executeDatabaseQuery(sql);
}
throw new UnsupportedOperationException("Unknown tool: " + name);
}
private String executeDatabaseQuery(String sql) {
// Implementation using Spring's JdbcTemplate
try {
return jdbcTemplate.queryForObject(sql, String.class);
} catch (Exception e) {
return "Error executing query: " + e.getMessage();
}
}
}
```
## Multi-Tool MCP Server
### Enterprise MCP Server with Multiple Tools
Create a comprehensive MCP server with multiple tool providers:
```java
@Configuration
public class EnterpriseMcpServer {
@Bean
public MCPServer enterpriseMcpServer(
GitHubToolProvider githubToolProvider,
DatabaseToolProvider databaseToolProvider,
FileToolProvider fileToolProvider,
EmailToolProvider emailToolProvider) {
return MCPServer.builder()
.server(new StdioServer.Builder())
.addToolProvider(githubToolProvider)
.addToolProvider(databaseToolProvider)
.addToolProvider(fileToolProvider)
.addToolProvider(emailToolProvider)
.enableLogging(true)
.setLogHandler(new CustomLogHandler())
.build();
}
}
@Component
class GitHubToolProvider implements ToolProvider {
@Override
public List<ToolSpecification> listTools() {
return List.of(
ToolSpecification.builder()
.name("get_issue")
.description("Get GitHub issue details")
.inputSchema(Map.of(
"type", "object",
"properties", Map.of(
"owner", Map.of(
"type", "string",
"description", "Repository owner"
),
"repo", Map.of(
"type", "string",
"description", "Repository name"
),
"issue_number", Map.of(
"type", "integer",
"description", "Issue number"
)
),
"required", List.of("owner", "repo", "issue_number")
))
.build(),
ToolSpecification.builder()
.name("list_issues")
.description("List GitHub issues for a repository")
.inputSchema(Map.of(
"type", "object",
"properties", Map.of(
"owner", Map.of(
"type", "string",
"description", "Repository owner"
),
"repo", Map.of(
"type", "string",
"description", "Repository name"
),
"state", Map.of(
"type", "string",
"description", "Issue state: open, closed, all",
"enum", List.of("open", "closed", "all")
)
),
"required", List.of("owner", "repo")
))
.build()
);
}
@Override
public String executeTool(String name, String arguments) {
switch (name) {
case "get_issue":
return getIssueDetails(arguments);
case "list_issues":
return listRepositoryIssues(arguments);
default:
throw new UnsupportedOperationException("Unknown tool: " + name);
}
}
private String getIssueDetails(String arguments) {
JsonObject args = JsonParser.parseString(arguments).getAsJsonObject();
String owner = args.get("owner").getAsString();
String repo = args.get("repo").getAsString();
int issueNumber = args.get("issue_number").getAsInt();
// Call GitHub API
GitHubIssue issue = githubService.getIssue(owner, repo, issueNumber);
return "Issue #" + issueNumber + ": " + issue.getTitle() +
"\nState: " + issue.getState() +
"\nCreated: " + issue.getCreatedAt();
}
private String listRepositoryIssues(String arguments) {
JsonObject args = JsonParser.parseString(arguments).getAsJsonObject();
String owner = args.get("owner").getAsString();
String repo = args.get("repo").getAsString();
String state = args.has("state") ? args.get("state").getAsString() : "open";
List<GitHubIssue> issues = githubService.listIssues(owner, repo, state);
return issues.stream()
.map(issue -> "#%d: %s (%s)".formatted(issue.getNumber(), issue.getTitle(), issue.getState()))
.collect(Collectors.joining("\n"));
}
}
```
## Resource Provider Implementation
### Static Resource Provider
Provide static resources for context enhancement:
```java
@Component
class StaticResourceProvider implements ResourceListProvider, ResourceReadHandler {
private final Map<String, String> resources = new HashMap<>();
public StaticResourceProvider() {
// Initialize with static resources
resources.put("company-policies", loadCompanyPolicies());
resources.put("api-documentation", loadApiDocumentation());
resources.put("best-practices", loadBestPractices());
}
@Override
public List<McpResource> listResources() {
return resources.keySet().stream()
.map(uri -> McpResource.builder()
.uri(uri)
.name(uri.replace("-", " "))
.description("Documentation resource")
.mimeType("text/plain")
.build())
.collect(Collectors.toList());
}
@Override
public String readResource(String uri) {
if (!resources.containsKey(uri)) {
throw new ResourceNotFoundException("Resource not found: " + uri);
}
return resources.get(uri);
}
private String loadCompanyPolicies() {
// Load company policies from file or database
return "Company Policies:\n1. Work hours: 9-5\n2. Security compliance\n3. Data privacy";
}
private String loadApiDocumentation() {
// Load API documentation
return "API Documentation:\nGET /api/users - List users\nPOST /api/users - Create user";
}
private String loadBestPractices() {
// Load best practices guide
return "Best Practices:\n1. Require code review\n2. Write tests first\n3. Document public APIs";
}
}
```
### Dynamic Resource Provider
Create dynamic resources that update in real-time:
```java
@Component
class DynamicResourceProvider implements ResourceListProvider, ResourceReadHandler {
@Autowired
private MetricsService metricsService;
@Override
public List<McpResource> listResources() {
return List.of(
McpResource.builder()
.uri("system-metrics")
.name("System Metrics")
.description("Real-time system performance metrics")
.mimeType("application/json")
.build(),
McpResource.builder()
.uri("user-analytics")
.name("User Analytics")
.description("User behavior and usage statistics")
.mimeType("application/json")
.build()
);
}
@Override
public String readResource(String uri) {
switch (uri) {
case "system-metrics":
return metricsService.getCurrentSystemMetrics();
case "user-analytics":
return metricsService.getUserAnalytics();
default:
throw new ResourceNotFoundException("Resource not found: " + uri);
}
}
}
```
## Prompt Template Provider
### Prompt Template Server
Create prompt templates for common AI tasks:
```java
@Component
class PromptTemplateProvider implements PromptListProvider, PromptGetHandler {
private final Map<String, PromptTemplate> templates = new HashMap<>();
public PromptTemplateProvider() {
templates.put("code-review", PromptTemplate.builder()
.name("Code Review")
.description("Review code for quality, security, and best practices")
.template("Review the following code for:\n" +
"1. Code quality and readability\n" +
"2. Security vulnerabilities\n" +
"3. Performance optimizations\n" +
"4. Best practices compliance\n\n" +
"Code:\n" +
"```{code}```\n\n" +
"Provide a detailed analysis with specific recommendations.")
.build());
templates.put("documentation-generation", PromptTemplate.builder()
.name("Documentation Generator")
.description("Generate technical documentation from code")
.template("Generate comprehensive documentation for the following code:\n" +
"{code}\n\n" +
"Include:\n" +
"1. Function/method signatures\n" +
"2. Parameters and return values\n" +
"3. Purpose and usage examples\n" +
"4. Dependencies and requirements")
.build());
}
@Override
public List<Prompt> listPrompts() {
return templates.entrySet().stream()
.map(entry -> Prompt.builder()
.name(entry.getKey()) // expose the lookup key so getPrompt(name) resolves
.description(entry.getValue().getDescription())
.build())
.collect(Collectors.toList());
}
@Override
public String getPrompt(String name, Map<String, String> arguments) {
PromptTemplate template = templates.get(name);
if (template == null) {
throw new PromptNotFoundException("Prompt not found: " + name);
}
// Replace template variables
String content = template.getTemplate();
for (Map.Entry<String, String> entry : arguments.entrySet()) {
content = content.replace("{" + entry.getKey() + "}", entry.getValue());
}
return content;
}
}
```
## Error Handling and Validation
### Robust Error Handling
Implement comprehensive error handling and validation:
```java
@Component
class RobustToolProvider implements ToolProvider {
@Autowired
private UserRepository userRepository; // permission checks (type assumed)
@Autowired
private UserDataService userDataService; // secure data retrieval (type assumed)
@Override
public List<ToolSpecification> listTools() {
return List.of(ToolSpecification.builder()
.name("secure_data_access")
.description("Access sensitive data with proper validation")
.inputSchema(Map.of(
"type", "object",
"properties", Map.of(
"data_type", Map.of(
"type", "string",
"description", "Type of data to access",
"enum", List.of("user_data", "system_data", "admin_data")
),
"user_id", Map.of(
"type", "string",
"description", "User ID requesting access"
)
),
"required", List.of("data_type", "user_id")
))
.build());
}
@Override
public String executeTool(String name, String arguments) {
if ("secure_data_access".equals(name)) {
try {
JsonObject args = JsonParser.parseString(arguments).getAsJsonObject();
String dataType = args.get("data_type").getAsString();
String userId = args.get("user_id").getAsString();
// Validate user permissions
if (!hasPermission(userId, dataType)) {
return "Access denied: Insufficient permissions";
}
// Get data securely
return getSecureData(dataType, userId);
} catch (JsonParseException e) {
return "Invalid JSON format: " + e.getMessage();
} catch (Exception e) {
return "Error accessing data: " + e.getMessage();
}
}
throw new UnsupportedOperationException("Unknown tool: " + name);
}
private boolean hasPermission(String userId, String dataType) {
// Implement permission checking
if ("admin_data".equals(dataType)) {
return userRepository.isAdmin(userId);
}
return true;
}
private String getSecureData(String dataType, String userId) {
// Implement secure data retrieval
if ("user_data".equals(dataType)) {
return userDataService.getUserData(userId);
}
return "Data not available";
}
}
```
## Advanced Server Configuration
### Multi-Transport Server Configuration
Configure MCP server with multiple transport options:
```java
@Configuration
public class AdvancedMcpConfiguration {
@Bean
public MCPServer advancedMcpServer(
List<ToolProvider> toolProviders,
List<ResourceListProvider> resourceProviders,
List<PromptListProvider> promptProviders) {
return MCPServer.builder()
.server(new StdioServer.Builder())
.addToolProviders(toolProviders)
.addResourceProviders(resourceProviders)
.addPromptProviders(promptProviders)
.enableLogging(true)
.setLogHandler(new StructuredLogHandler())
.enableHealthChecks(true)
.setHealthCheckInterval(30) // seconds
.setMaxConcurrentRequests(100)
.setRequestTimeout(30) // seconds
.build();
}
@Bean
public HttpMcpTransport httpTransport() {
return new HttpMcpTransport.Builder()
.sseUrl("http://localhost:8080/mcp/sse")
.logRequests(true)
.logResponses(true)
.build();
}
}
```
## Client Integration Patterns
### Multi-Client MCP Integration
Integrate with multiple MCP servers for comprehensive functionality:
```java
@Slf4j
@Service
public class MultiMcpIntegrationService {
private final List<McpClient> mcpClients;
private final ChatModel chatModel;
private final McpToolProvider toolProvider;
private final ChatMemoryProvider memoryProvider;
public MultiMcpIntegrationService(List<McpClient> mcpClients, ChatModel chatModel, ChatMemoryProvider memoryProvider) {
this.mcpClients = mcpClients;
this.chatModel = chatModel;
this.memoryProvider = memoryProvider;
// Create tool provider with multiple MCP clients
this.toolProvider = McpToolProvider.builder()
.mcpClients(mcpClients)
.failIfOneServerFails(false) // Continue with available servers
.filter((client, tool) -> {
// Implement cross-server tool filtering
return !tool.name().startsWith("deprecated_");
})
.build();
}
public String processUserQuery(String userId, String query) {
// Create AI service with multiple MCP integrations
AIAssistant assistant = AiServices.builder(AIAssistant.class)
.chatModel(chatModel)
.toolProvider(toolProvider)
.chatMemoryProvider(memoryProvider)
.build();
return assistant.chat(userId, query);
}
public List<ToolSpecification> getAvailableTools() {
return mcpClients.stream()
.flatMap(client -> {
try {
return client.listTools().stream();
} catch (Exception e) {
log.warn("Failed to list tools from client {}: {}", client.key(), e.getMessage());
return Stream.empty();
}
})
.distinct()
.collect(Collectors.toList());
}
}
```
These comprehensive examples provide a solid foundation for implementing MCP servers with LangChain4j, covering everything from basic setup to advanced enterprise patterns.

View File

@@ -0,0 +1,349 @@
---
name: langchain4j-rag-implementation-patterns
description: Implement Retrieval-Augmented Generation (RAG) systems with LangChain4j. Build document ingestion pipelines, embedding stores, vector search strategies, and knowledge-enhanced AI applications. Use when creating question-answering systems over document collections or AI assistants with external knowledge bases.
allowed-tools: Read, Write, Bash
category: ai-development
tags: [langchain4j, rag, retrieval-augmented-generation, embedding, vector-search, document-ingestion, java]
version: 1.1.0
---
# LangChain4j RAG Implementation Patterns
## When to Use This Skill
Use this skill when:
- Building knowledge-based AI applications requiring external document access
- Implementing question-answering systems over large document collections
- Creating AI assistants with access to company knowledge bases
- Building semantic search capabilities for document repositories
- Implementing chat systems that reference specific information sources
- Creating AI applications requiring source attribution
- Building domain-specific AI systems with curated knowledge
- Implementing hybrid search combining vector similarity with traditional search
- Creating AI applications requiring real-time document updates
- Building multi-modal RAG systems with text, images, and other content types
## Overview
Implement complete Retrieval-Augmented Generation (RAG) systems with LangChain4j. RAG enhances language models by providing relevant context from external knowledge sources, improving accuracy and reducing hallucinations.
## Instructions
### Initialize RAG Project
Create a new Spring Boot project with required dependencies:
**pom.xml**:
```xml
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-spring-boot-starter</artifactId>
<version>1.8.0</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<version>1.8.0</version>
</dependency>
```
### Setup Document Ingestion
Configure document loading and processing:
```java
@Configuration
public class RAGConfiguration {
@Bean
public EmbeddingModel embeddingModel() {
return OpenAiEmbeddingModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("text-embedding-3-small")
.build();
}
@Bean
public EmbeddingStore<TextSegment> embeddingStore() {
return new InMemoryEmbeddingStore<>();
}
}
```
Create document ingestion service:
```java
@Service
@RequiredArgsConstructor
public class DocumentIngestionService {
private final EmbeddingModel embeddingModel;
private final EmbeddingStore<TextSegment> embeddingStore;
public void ingestDocument(String filePath, Map<String, Object> metadata) {
Document document = FileSystemDocumentLoader.loadDocument(filePath);
metadata.forEach((key, value) -> document.metadata().put(key, String.valueOf(value)));
DocumentSplitter splitter = DocumentSplitters.recursive(
500, 50, new OpenAiTokenCountEstimator("text-embedding-3-small")
);
List<TextSegment> segments = splitter.split(document);
List<Embedding> embeddings = embeddingModel.embedAll(segments).content();
embeddingStore.addAll(embeddings, segments);
}
}
```
### Configure Content Retrieval
Setup content retrieval with filtering:
```java
@Configuration
public class ContentRetrieverConfiguration {
@Bean
public ContentRetriever contentRetriever(
EmbeddingStore<TextSegment> embeddingStore,
EmbeddingModel embeddingModel) {
return EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.maxResults(5)
.minScore(0.7)
.build();
}
}
```
### Create RAG-Enabled AI Service
Define AI service with context retrieval:
```java
interface KnowledgeAssistant {
@SystemMessage("""
You are a knowledgeable assistant with access to a comprehensive knowledge base.
When answering questions:
1. Use the provided context from the knowledge base
2. If information is not in the context, clearly state this
3. Provide accurate, helpful responses
4. When possible, reference specific sources
5. If the context is insufficient, ask for clarification
""")
String answerQuestion(String question);
}
@Service
public class KnowledgeService {
private final KnowledgeAssistant assistant;
public KnowledgeService(ChatModel chatModel, ContentRetriever contentRetriever) {
this.assistant = AiServices.builder(KnowledgeAssistant.class)
.chatModel(chatModel)
.contentRetriever(contentRetriever)
.build();
}
public String answerQuestion(String question) {
return assistant.answerQuestion(question);
}
}
```
## Examples
### Basic Document Processing
```java
public class BasicRAGExample {
public static void main(String[] args) {
var embeddingStore = new InMemoryEmbeddingStore<TextSegment>();
var embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("text-embedding-3-small")
.build();
var ingestor = EmbeddingStoreIngestor.builder()
.embeddingModel(embeddingModel)
.embeddingStore(embeddingStore)
.build();
ingestor.ingest(Document.from("Spring Boot is a framework for building Java applications with minimal configuration."));
var retriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.build();
}
}
```
### Multi-Domain Assistant
```java
interface MultiDomainAssistant {
@SystemMessage("""
You are an expert assistant with access to multiple knowledge domains:
- Technical documentation
- Company policies
- Product information
- Customer support guides
Tailor your response based on the type of question and available context.
Always indicate which domain the information comes from.
""")
String answerQuestion(@MemoryId String userId, String question);
}
```
### Hierarchical RAG
```java
@Service
@RequiredArgsConstructor
public class HierarchicalRAGService {
private final EmbeddingStore<TextSegment> chunkStore;
private final EmbeddingStore<TextSegment> summaryStore;
private final EmbeddingModel embeddingModel;
public String performHierarchicalRetrieval(String query) {
List<EmbeddingMatch<TextSegment>> summaryMatches = searchSummaries(query);
List<TextSegment> relevantChunks = new ArrayList<>();
for (EmbeddingMatch<TextSegment> summaryMatch : summaryMatches) {
String documentId = summaryMatch.embedded().metadata().getString("documentId");
List<EmbeddingMatch<TextSegment>> chunkMatches = searchChunksInDocument(query, documentId);
chunkMatches.stream()
.map(EmbeddingMatch::embedded)
.forEach(relevantChunks::add);
}
return generateResponseWithChunks(query, relevantChunks);
}
}
```
## Best Practices
### Document Segmentation
- Use recursive splitting with 500-1000 token chunks for most applications
- Maintain 20-50 token overlap between chunks for context preservation
- Consider document structure (headings, paragraphs) when splitting
- Use token-aware splitters for optimal embedding generation
### Metadata Strategy
- Include rich metadata for filtering and attribution (see the sketch after this list):
- User and tenant identifiers for multi-tenancy
- Document type and category classification
- Creation and modification timestamps
- Version and author information
- Confidentiality and access level tags
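A minimal sketch of this metadata strategy, using the ingestor's `documentTransformer` hook (the tenant, user, and access-level values are illustrative):
```java
EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
    .documentTransformer(doc -> {
        // Attribution and multi-tenancy metadata (illustrative values)
        doc.metadata().put("tenantId", "acme-corp");
        doc.metadata().put("userId", "user123");
        doc.metadata().put("documentType", "policy");
        doc.metadata().put("ingestedAt", java.time.Instant.now().toString());
        doc.metadata().put("accessLevel", "internal");
        return doc;
    })
    .documentSplitter(DocumentSplitters.recursive(500, 50))
    .embeddingModel(embeddingModel)
    .embeddingStore(embeddingStore)
    .build();
```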
### Query Processing
- Implement query preprocessing and cleaning (see the sketch after this list)
- Consider query expansion for better recall
- Apply dynamic filtering based on user context
- Use re-ranking for improved result quality
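One way to sketch the preprocessing step is a custom `QueryTransformer` (the normalization rule is illustrative; `contentRetriever` is assumed to be configured as above). Query expansion and re-ranking are covered in the advanced example later in this skill.
```java
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.query.transformer.QueryTransformer;
import java.util.List;

// Preprocessing expressed as a QueryTransformer: trim and collapse whitespace
QueryTransformer cleaningTransformer = query ->
        List.of(Query.from(query.text().trim().replaceAll("\\s+", " ")));

RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
        .queryTransformer(cleaningTransformer)
        .contentRetriever(contentRetriever)
        .build();
```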
### Performance Optimization
- Cache embeddings for repeated queries (see the sketch after this list)
- Use batch embedding generation for bulk operations
- Implement pagination for large result sets
- Consider asynchronous processing for long operations
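A sketch of the embedding-cache idea (the wrapper class is hypothetical, not a LangChain4j type):
```java
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical helper: caches query embeddings so repeated questions
// skip the embedding API call entirely
class CachingEmbeddingHelper {

    private final EmbeddingModel embeddingModel;
    private final Map<String, Embedding> cache = new ConcurrentHashMap<>();

    CachingEmbeddingHelper(EmbeddingModel embeddingModel) {
        this.embeddingModel = embeddingModel;
    }

    Embedding embed(String text) {
        return cache.computeIfAbsent(text, t -> embeddingModel.embed(t).content());
    }
}
```
In production a bounded cache (e.g., Caffeine) is preferable to an unbounded map.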
## Common Patterns
### Simple RAG Pipeline
```java
@RequiredArgsConstructor
@Service
public class SimpleRAGPipeline {
private final EmbeddingModel embeddingModel;
private final EmbeddingStore<TextSegment> embeddingStore;
private final ChatModel chatModel;
public String answerQuestion(String question) {
Embedding queryEmbedding = embeddingModel.embed(question).content();
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(queryEmbedding)
.maxResults(3)
.build();
List<TextSegment> segments = embeddingStore.search(request).matches().stream()
.map(EmbeddingMatch::embedded)
.collect(Collectors.toList());
String context = segments.stream()
.map(TextSegment::text)
.collect(Collectors.joining("\n\n"));
return chatModel.chat(context + "\n\nQuestion: " + question + "\nAnswer:");
}
}
```
### Hybrid Search (Vector + Keyword)
```java
@Service
@RequiredArgsConstructor
public class HybridSearchService {
private final EmbeddingStore<TextSegment> vectorStore;
private final FullTextSearchEngine keywordEngine;
private final EmbeddingModel embeddingModel;
public List<Content> hybridSearch(String query, int maxResults) {
// Vector search
List<Content> vectorResults = performVectorSearch(query, maxResults);
// Keyword search
List<Content> keywordResults = performKeywordSearch(query, maxResults);
// Combine and re-rank using RRF algorithm
return combineResults(vectorResults, keywordResults, maxResults);
}
}
```
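The `combineResults` step above can be implemented as plain Reciprocal Rank Fusion. A minimal sketch that could slot into the service (imports: `java.util.*`, `java.util.stream.Collectors`; assumes `Content` instances from both rankings compare equal when they wrap the same segment; k = 60 is the conventional damping constant):
```java
private List<Content> combineResults(List<Content> vectorResults,
                                     List<Content> keywordResults,
                                     int maxResults) {
    int k = 60; // conventional RRF damping constant
    Map<Content, Double> scores = new LinkedHashMap<>();
    for (List<Content> ranking : List.of(vectorResults, keywordResults)) {
        for (int rank = 0; rank < ranking.size(); rank++) {
            // Higher-ranked items contribute larger reciprocal scores
            scores.merge(ranking.get(rank), 1.0 / (k + rank + 1), Double::sum);
        }
    }
    return scores.entrySet().stream()
            .sorted(Map.Entry.<Content, Double>comparingByValue().reversed())
            .limit(maxResults)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
}
```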
## Troubleshooting
### Common Issues
**Poor Retrieval Results**
- Check document chunk size and overlap settings
- Verify embedding model compatibility
- Ensure metadata filters are not too restrictive
- Consider adding re-ranking step
**Slow Performance**
- Use cached embeddings for frequent queries
- Optimize database indexing for vector stores
- Implement pagination for large datasets
- Consider async processing for bulk operations
**High Memory Usage**
- Use disk-based embedding stores for large datasets
- Implement proper pagination and filtering
- Clean up unused embeddings periodically
- Monitor and optimize chunk sizes
## References
- [API Reference](references/references.md) - Complete API documentation and interfaces
- [Examples](references/examples.md) - Production-ready examples and patterns
- [Official LangChain4j Documentation](https://docs.langchain4j.dev/)

View File

@@ -0,0 +1,482 @@
# LangChain4j RAG Implementation - Practical Examples
Production-ready examples for implementing Retrieval-Augmented Generation (RAG) systems with LangChain4j.
## 1. Simple In-Memory RAG
**Scenario**: Quick RAG setup with documents in memory for development/testing.
```java
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
interface DocumentAssistant {
String answer(String question);
}
public class SimpleRagExample {
public static void main(String[] args) {
// Setup
var embeddingStore = new InMemoryEmbeddingStore<TextSegment>();
var embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("text-embedding-3-small")
.build();
var chatModel = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("gpt-4o-mini")
.build();
// Ingest documents
var ingestor = EmbeddingStoreIngestor.builder()
.embeddingModel(embeddingModel)
.embeddingStore(embeddingStore)
.build();
ingestor.ingest(Document.from("Spring Boot is a framework for building Java applications with minimal configuration."));
ingestor.ingest(Document.from("Spring Data JPA provides data access abstraction using repositories."));
ingestor.ingest(Document.from("Spring Cloud enables building distributed systems and microservices."));
// Create retriever and AI service
var contentRetriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.maxResults(3)
.minScore(0.7)
.build();
var assistant = AiServices.builder(DocumentAssistant.class)
.chatModel(chatModel)
.contentRetriever(contentRetriever)
.build();
// Query with RAG
System.out.println(assistant.answer("What is Spring Boot?"));
System.out.println(assistant.answer("What does Spring Data JPA do?"));
}
}
```
## 2. Vector Database RAG (Pinecone)
**Scenario**: Production RAG with persistent vector database.
```java
import dev.langchain4j.store.embedding.pinecone.PineconeEmbeddingStore;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.Metadata;
public class PineconeRagExample {
public static void main(String[] args) {
// Production vector store
var embeddingStore = PineconeEmbeddingStore.builder()
.apiKey(System.getenv("PINECONE_API_KEY"))
.index("docs-index")
.namespace("production")
.build();
var embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.build();
// Ingest with metadata
var ingestor = EmbeddingStoreIngestor.builder()
.documentTransformer(doc -> {
doc.metadata().put("source", "documentation");
doc.metadata().put("date", LocalDate.now().toString());
return doc;
})
.documentSplitter(DocumentSplitters.recursive(1000, 200))
.embeddingModel(embeddingModel)
.embeddingStore(embeddingStore)
.build();
ingestor.ingest(Document.from("Your large document..."));
// Retrieve with filters
var retriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.maxResults(5)
.dynamicFilter(query ->
new IsEqualTo("source", "documentation")
)
.build();
}
}
```
## 3. Document Loading and Splitting
**Scenario**: Load documents from various sources and split intelligently.
```java
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.openai.OpenAiTokenCountEstimator;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
public class DocumentProcessingExample {
public static void main(String[] args) {
// Load from filesystem
Path docPath = Paths.get("documents");
List<Document> documents = FileSystemDocumentLoader.loadDocuments(docPath);
// Smart recursive splitting with token counting
DocumentSplitter splitter = DocumentSplitters.recursive(
500, // Max tokens per segment
50, // Overlap tokens
new OpenAiTokenCountEstimator("gpt-4o-mini")
);
// Process documents
for (Document doc : documents) {
List<TextSegment> segments = splitter.split(doc);
System.out.println("Document split into " + segments.size() + " segments");
segments.forEach(segment -> {
System.out.println("Text: " + segment.text());
System.out.println("Metadata: " + segment.metadata());
});
}
// Alternative: Character-based splitting
DocumentSplitter charSplitter = DocumentSplitters.recursive(
1000, // Max characters
100 // Overlap characters
);
// Alternative: Paragraph-based splitting
DocumentSplitter paraSplitter = new DocumentByParagraphSplitter(500, 50);
}
}
```
## 4. Metadata Filtering in RAG
**Scenario**: Search with complex metadata filters for multi-tenant RAG.
```java
import dev.langchain4j.store.embedding.filter.comparison.*;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
public class MetadataFilteringExample {
public static void main(String[] args) {
var retriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
// Note: each .filter(...) below shows an alternative; the last one set wins
// Single filter: user isolation
.filter(new IsEqualTo("userId", "user123"))
// Complex AND filter
.filter(new And(
new IsEqualTo("department", "engineering"),
new IsEqualTo("status", "active")
))
// OR filter: multiple categories
.filter(new Or(
new IsEqualTo("category", "tutorial"),
new IsEqualTo("category", "guide")
))
// NOT filter: exclude deprecated
.filter(new Not(
new IsEqualTo("deprecated", "true")
))
// Numeric filters
.filter(new IsGreaterThan("relevance", 0.8))
.filter(new IsLessThanOrEqualTo("createdDaysAgo", 30))
// Multiple conditions
.dynamicFilter(query -> {
String userId = extractUserFromQuery(query);
return new And(
new IsEqualTo("userId", userId),
new IsGreaterThan("score", 0.7)
);
})
.build();
}
private static String extractUserFromQuery(Object query) {
// Extract user context
return "user123";
}
}
```
## 5. Document Transformation Pipeline
**Scenario**: Transform documents with custom metadata before ingestion.
```java
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.segment.TextSegment;
import java.time.LocalDate;
public class DocumentTransformationExample {
public static void main(String[] args) {
var ingestor = EmbeddingStoreIngestor.builder()
// Add metadata to each document
.documentTransformer(doc -> {
doc.metadata().put("ingested_date", LocalDate.now().toString());
doc.metadata().put("source_system", "internal");
doc.metadata().put("version", "1.0");
return doc;
})
// Split documents intelligently
.documentSplitter(DocumentSplitters.recursive(500, 50))
// Transform each segment (e.g., add filename)
.textSegmentTransformer(segment -> {
String fileName = segment.metadata().getString("file_name"); // may be null if absent
String enrichedText = "File: " + (fileName != null ? fileName : "unknown") + "\n" + segment.text();
return TextSegment.from(enrichedText, segment.metadata());
})
.embeddingModel(embeddingModel)
.embeddingStore(embeddingStore)
.build();
// Ingest with tracking
Document document = Document.from("Sample content to ingest."); // illustrative input
IngestionResult result = ingestor.ingest(document);
System.out.println("Tokens ingested: " + result.tokenUsage().totalTokenCount());
}
}
```
## 6. Hybrid Search (Vector + Full-Text)
**Scenario**: Combine semantic search with keyword search for better recall.
```java
import dev.langchain4j.store.embedding.neo4j.Neo4jEmbeddingStore;
public class HybridSearchExample {
public static void main(String[] args) {
// Configure Neo4j for hybrid search
var embeddingStore = Neo4jEmbeddingStore.builder()
.withBasicAuth("bolt://localhost:7687", "neo4j", "password")
.dimension(1536)
// Enable full-text search
.fullTextIndexName("documents_fulltext")
.autoCreateFullText(true)
// Query for full-text context
.fullTextQuery("Spring OR Boot")
.build();
var retriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.maxResults(5)
.build();
// Search combines both vector similarity and full-text keywords
}
}
```
## 7. Advanced RAG with Query Transformation
**Scenario**: Transform user queries before retrieval for better results.
```java
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.query.transformer.CompressingQueryTransformer;
import dev.langchain4j.rag.content.aggregator.ReRankingContentAggregator;
import dev.langchain4j.model.cohere.CohereScoringModel;
public class AdvancedRagExample {
public static void main(String[] args) {
// Scoring model for re-ranking
var scoringModel = CohereScoringModel.builder()
.apiKey(System.getenv("COHERE_API_KEY"))
.build();
// Advanced retrieval augmentor
var augmentor = DefaultRetrievalAugmentor.builder()
// Transform query for better context
.queryTransformer(new CompressingQueryTransformer(chatModel))
// Retrieve relevant content
.contentRetriever(EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.maxResults(10)
.minScore(0.6)
.build())
// Re-rank results by relevance
.contentAggregator(ReRankingContentAggregator.builder()
.scoringModel(scoringModel)
.minScore(0.8)
.build())
.build();
// Use with AI Service
var assistant = AiServices.builder(QuestionAnswering.class)
.chatModel(chatModel)
.retrievalAugmentor(augmentor)
.build();
}
}
```
## 8. Multi-User RAG with Isolation
**Scenario**: Per-user vector stores for data isolation.
```java
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import java.util.HashMap;
import java.util.Map;
public class MultiUserRagExample {
private final Map<String, EmbeddingStore<TextSegment>> userStores = new HashMap<>();
public void ingestForUser(String userId, Document document) {
var store = userStores.computeIfAbsent(userId,
k -> new InMemoryEmbeddingStore<>());
var ingestor = EmbeddingStoreIngestor.builder()
.embeddingModel(embeddingModel)
.embeddingStore(store)
.build();
ingestor.ingest(document);
}
public String askQuestion(String userId, String question) {
var store = userStores.get(userId);
var retriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(store)
.embeddingModel(embeddingModel)
.maxResults(3)
.build();
var assistant = AiServices.builder(QuestionAnswering.class)
.chatModel(chatModel)
.contentRetriever(retriever)
.build();
return assistant.answer(question);
}
}
```
## 9. Streaming RAG with Content Access
**Scenario**: Stream RAG responses while accessing retrieved content.
```java
import dev.langchain4j.service.TokenStream;
interface StreamingRagAssistant {
TokenStream streamAnswer(String question);
}
public class StreamingRagExample {
public static void main(String[] args) {
var assistant = AiServices.builder(StreamingRagAssistant.class)
.streamingChatModel(streamingModel)
.contentRetriever(contentRetriever)
.build();
assistant.streamAnswer("What is Spring Boot?")
.onRetrieved(contents -> {
System.out.println("=== Retrieved Content ===");
contents.forEach(content ->
System.out.println("Score: " + content.score() +
", Text: " + content.textSegment().text()));
})
.onNext(token -> System.out.print(token))
.onCompleteResponse(response ->
System.out.println("\n=== Complete ==="))
.onError(error -> System.err.println("Error: " + error))
.start();
try {
Thread.sleep(5000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}
```
## 10. Batch Document Ingestion
**Scenario**: Efficiently ingest large document collections.
```java
import dev.langchain4j.data.document.Document;
import java.util.List;
import java.util.ArrayList;
public class BatchIngestionExample {
public static void main(String[] args) {
var ingestor = EmbeddingStoreIngestor.builder()
.embeddingModel(embeddingModel)
.embeddingStore(embeddingStore)
.documentSplitter(DocumentSplitters.recursive(500, 50))
.build();
// Load batch of documents
List<Document> documents = new ArrayList<>();
for (int i = 1; i <= 100; i++) {
documents.add(Document.from("Content " + i));
}
// Ingest all at once
IngestionResult result = ingestor.ingest(documents);
System.out.println("Documents ingested: " + documents.size());
System.out.println("Total tokens: " + result.tokenUsage().totalTokenCount());
// Track progress
long tokensPerDoc = result.tokenUsage().totalTokenCount() / documents.size();
System.out.println("Average tokens per document: " + tokensPerDoc);
}
}
```
## Performance Considerations
1. **Batch Processing**: Ingest documents in batches to optimize embedding API calls
2. **Document Splitting**: Use recursive splitting for better semantic chunks
3. **Metadata**: Add minimal metadata to reduce embedding overhead
4. **Vector DB**: Choose appropriate vector DB based on scale (in-memory for dev, Pinecone/Weaviate for prod)
5. **Similarity Threshold**: Adjust minScore based on use case (0.7-0.85 typical)
6. **Max Results**: Return top 3-5 results unless specific needs require more
7. **Caching**: Cache frequently retrieved content to reduce API calls
8. **Async Ingestion**: Use async ingestion for large datasets (see the sketch after this list)
9. **Monitoring**: Track token usage and retrieval quality metrics
10. **Testing**: Use in-memory store for unit tests, external DB for integration tests
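As a sketch of item 8 (async ingestion), assuming the `ingestor` built above and a `batches` list of type `List<List<Document>>` prepared upstream:
```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;

// Ingest each batch on a worker thread; pool size is illustrative
ExecutorService executor = Executors.newFixedThreadPool(4);
List<CompletableFuture<Void>> futures = batches.stream()
        .map(batch -> CompletableFuture.runAsync(() -> ingestor.ingest(batch), executor))
        .collect(Collectors.toList());

// Wait for all batches, then release the pool
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
executor.shutdown();
```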

View File

@@ -0,0 +1,506 @@
# LangChain4j RAG Implementation - API References
Complete API reference for implementing RAG systems with LangChain4j.
## Document Loading
### Document Loaders
**FileSystemDocumentLoader**: Load from filesystem.
```java
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import java.nio.file.Paths;
List<Document> documents = FileSystemDocumentLoader.loadDocuments(Paths.get("documents"));
Document single = FileSystemDocumentLoader.loadDocument(Paths.get("document.pdf"));
```
**ClassPathDocumentLoader**: Load from classpath resources.
```java
List<Document> resources = ClassPathDocumentLoader.load("documents");
```
**UrlDocumentLoader**: Load from web URLs.
```java
Document webDoc = UrlDocumentLoader.load("https://example.com/doc.html", new TextDocumentParser());
```
## Document Splitting
### DocumentSplitter Interface
```java
interface DocumentSplitter {
List<TextSegment> split(Document document);
List<TextSegment> splitAll(Collection<Document> documents);
}
```
### DocumentSplitters Factory
**Recursive Split**: Smart recursive splitting by paragraphs, sentences, words.
```java
DocumentSplitter splitter = DocumentSplitters.recursive(
500, // Max segment size (tokens or characters)
50 // Overlap size
);
// With token counting
DocumentSplitter splitter = DocumentSplitters.recursive(
500,
50,
new OpenAiTokenCountEstimator("gpt-4o-mini")
);
```
**Paragraph Split**: Split by paragraphs.
```java
DocumentSplitter splitter = new DocumentByParagraphSplitter(500, 50);
```
**Sentence Split**: Split by sentences.
```java
DocumentSplitter splitter = new DocumentBySentenceSplitter(500, 50);
```
**Line Split**: Split by lines.
```java
DocumentSplitter splitter = new DocumentByLineSplitter(500, 50);
```
## Embedding Models
### EmbeddingModel Interface
```java
public interface EmbeddingModel {
// Embed single text
Response<Embedding> embed(String text);
Response<Embedding> embed(TextSegment textSegment);
// Batch embedding
Response<List<Embedding>> embedAll(List<TextSegment> textSegments);
// Model dimension
int dimension();
}
```
### OpenAI Embedding Model
```java
EmbeddingModel model = OpenAiEmbeddingModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("text-embedding-3-small") // or text-embedding-3-large
.dimensions(512) // Optional: reduce dimensions
.timeout(Duration.ofSeconds(30))
.logRequests(true)
.logResponses(true)
.build();
```
### Other Embedding Models
```java
// Google Vertex AI
EmbeddingModel google = VertexAiEmbeddingModel.builder()
.project("PROJECT_ID")
.location("us-central1")
.modelName("textembedding-gecko")
.build();
// Ollama (local)
EmbeddingModel ollama = OllamaEmbeddingModel.builder()
.baseUrl("http://localhost:11434")
.modelName("all-minilm")
.build();
// AllMiniLmL6V2 (offline)
EmbeddingModel offline = new AllMiniLmL6V2EmbeddingModel();
```
## Vector Stores (EmbeddingStore)
### EmbeddingStore Interface
```java
public interface EmbeddingStore<Embedded> {
// Add embeddings
String add(Embedding embedding);
String add(String id, Embedding embedding);
String add(Embedding embedding, Embedded embedded);
List<String> addAll(List<Embedding> embeddings);
List<String> addAll(List<Embedding> embeddings, List<Embedded> embeddeds);
List<String> addAll(List<String> ids, List<Embedding> embeddings, List<Embedded> embeddeds);
// Search embeddings
EmbeddingSearchResult<Embedded> search(EmbeddingSearchRequest request);
// Remove embeddings
void remove(String id);
void removeAll(Collection<String> ids);
void removeAll(Filter filter);
void removeAll();
}
```
### In-Memory Store
```java
EmbeddingStore<TextSegment> store = new InMemoryEmbeddingStore<>();
// Merge stores
InMemoryEmbeddingStore<TextSegment> merged = InMemoryEmbeddingStore.merge(
store1, store2, store3
);
```
### Pinecone
```java
EmbeddingStore<TextSegment> store = PineconeEmbeddingStore.builder()
.apiKey(System.getenv("PINECONE_API_KEY"))
.index("my-index")
.namespace("production")
.environment("gcp-starter") // or "aws-us-east-1"
.build();
```
### Weaviate
```java
EmbeddingStore<TextSegment> store = WeaviateEmbeddingStore.builder()
.host("localhost")
.port(8080)
.scheme("http")
.collectionName("Documents")
.build();
```
### Qdrant
```java
EmbeddingStore<TextSegment> store = QdrantEmbeddingStore.builder()
.host("localhost")
.port(6333)
.collectionName("documents")
.build();
```
### Chroma
```java
EmbeddingStore<TextSegment> store = ChromaEmbeddingStore.builder()
.baseUrl("http://localhost:8000")
.collectionName("my-collection")
.build();
```
### Neo4j
```java
EmbeddingStore<TextSegment> store = Neo4jEmbeddingStore.builder()
.withBasicAuth("bolt://localhost:7687", "neo4j", "password")
.dimension(1536)
.label("Document")
.build();
```
### MongoDB Atlas
```java
EmbeddingStore<TextSegment> store = MongoDbEmbeddingStore.builder()
.databaseName("search")
.collectionName("documents")
.indexName("vector_index")
.createIndex(true)
.fromClient(mongoClient)
.build();
```
### PostgreSQL (pgvector)
```java
EmbeddingStore<TextSegment> store = PgVectorEmbeddingStore.builder()
.host("localhost")
.port(5432)
.database("embeddings")
.user("postgres")
.password("password")
.table("embeddings")
.createTableIfNotExists(true)
.build();
```
### Milvus
```java
EmbeddingStore<TextSegment> store = MilvusEmbeddingStore.builder()
.host("localhost")
.port(19530)
.collectionName("documents")
.dimension(1536)
.build();
```
## Document Ingestion
### EmbeddingStoreIngestor
```java
public class EmbeddingStoreIngestor {
public static Builder builder();
public IngestionResult ingest(Document document);
public IngestionResult ingest(Document... documents);
public IngestionResult ingest(Collection<Document> documents);
}
```
### Building an Ingestor
```java
EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
// Document transformation
.documentTransformer(doc -> {
doc.metadata().put("source", "manual");
return doc;
})
// Document splitting strategy
.documentSplitter(DocumentSplitters.recursive(500, 50))
// Text segment transformation
.textSegmentTransformer(segment -> {
String enhanced = "Category: Spring\n" + segment.text();
return TextSegment.from(enhanced, segment.metadata());
})
// Embedding model (required)
.embeddingModel(embeddingModel)
// Embedding store (required)
.embeddingStore(embeddingStore)
.build();
```
### IngestionResult
```java
IngestionResult result = ingestor.ingest(documents);
// Access results
TokenUsage usage = result.tokenUsage();
long totalTokens = usage.totalTokenCount();
long inputTokens = usage.inputTokenCount();
```
## Content Retrieval
### EmbeddingSearchRequest
```java
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(embedding) // Required
.maxResults(5) // Default: 3
.minScore(0.7) // Threshold 0-1
.filter(new IsEqualTo("category", "tutorial"))
.build();
```
### EmbeddingSearchResult
```java
EmbeddingSearchResult<TextSegment> result = store.search(request);
List<EmbeddingMatch<TextSegment>> matches = result.matches();
for (EmbeddingMatch<TextSegment> match : matches) {
double score = match.score(); // Relevance 0-1
TextSegment segment = match.embedded(); // Retrieved content
String id = match.embeddingId(); // Store ID
}
```
### ContentRetriever Interface
```java
public interface ContentRetriever {
List<Content> retrieve(Query query);
}
```
### EmbeddingStoreContentRetriever
```java
ContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
// Static configuration
.maxResults(5)
.minScore(0.7)
// Dynamic configuration per query
.dynamicMaxResults(query -> 10)
.dynamicMinScore(query -> 0.8)
.dynamicFilter(query ->
new IsEqualTo("userId", extractUserId(query))
)
.build();
```
## Advanced RAG
### RetrievalAugmentor
```java
public interface RetrievalAugmentor {
AugmentationResult augment(AugmentationRequest augmentationRequest);
}
```
### DefaultRetrievalAugmentor
```java
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
// Query transformation
.queryTransformer(new CompressingQueryTransformer(chatModel))
// Content retrieval
.contentRetriever(contentRetriever)
// Content aggregation and re-ranking
.contentAggregator(ReRankingContentAggregator.builder()
.scoringModel(scoringModel)
.minScore(0.8)
.build())
// Parallelization
.executor(customExecutor)
.build();
```
### Use with AI Services
```java
Assistant assistant = AiServices.builder(Assistant.class)
.chatModel(chatModel)
.retrievalAugmentor(augmentor)
.build();
```
## Metadata and Filtering
### Metadata Object
```java
// Create from map
Metadata meta = Metadata.from(Map.of(
"userId", "user123",
"category", "tutorial",
"score", 0.95
));
// Add entries
meta.put("status", "active");
meta.put("version", 2);
// Retrieve entries
String userId = meta.getString("userId");
int version = meta.getInt("version");
double score = meta.getDouble("score");
// Check existence
boolean has = meta.containsKey("userId");
// Remove entry
meta.remove("userId");
// Merge
Metadata other = Metadata.from(Map.of("source", "db"));
meta.merge(other);
```
### Filter Operations
```java
import dev.langchain4j.store.embedding.filter.comparison.*;
import dev.langchain4j.store.embedding.filter.logical.*;
// Equality
Filter active = new IsEqualTo("status", "active");
Filter notDeprecated = new IsNotEqualTo("deprecated", "true");
// Comparison
Filter highScore = new IsGreaterThan("score", 0.8);
Filter recent = new IsLessThanOrEqualTo("daysOld", 30);
Filter highPriority = new IsGreaterThanOrEqualTo("priority", 5);
Filter lowErrorRate = new IsLessThan("errorRate", 0.01);
// Membership
Filter inCategories = new IsIn("category", Arrays.asList("tech", "guide"));
Filter notArchived = new IsNotIn("status", Arrays.asList("archived"));
// String operations
Filter mentionsSpring = new ContainsString("content", "Spring");
// Logical operations
Filter userWithScore = new And(
new IsEqualTo("userId", "123"),
new IsGreaterThan("score", 0.7)
);
Filter docOrGuide = new Or(
new IsEqualTo("type", "doc"),
new IsEqualTo("type", "guide")
);
Filter notArchivedFlag = new Not(new IsEqualTo("archived", "true"));
```
## TextSegment
### Creating TextSegments
```java
// Text only
TextSegment segment = TextSegment.from("This is the content");
// With metadata
Metadata metadata = Metadata.from(Map.of("source", "docs"));
TextSegment segment = TextSegment.from("Content", metadata);
// Accessing
String text = segment.text();
Metadata meta = segment.metadata();
```
## Best Practices
1. **Chunk Size**: Use 300-500 tokens per chunk for optimal balance
2. **Overlap**: Use 10-50 token overlap for semantic continuity
3. **Metadata**: Include source and timestamp for traceability
4. **Batch Processing**: Ingest documents in batches when possible
5. **Similarity Threshold**: Adjust minScore (0.7-0.85) based on precision/recall needs
6. **Vector DB Selection**: In-memory for dev/test, Pinecone/Qdrant for production
7. **Filtering**: Pre-filter by metadata to reduce search space
8. **Re-ranking**: Use scoring models for better relevance in production
9. **Monitoring**: Track retrieval quality metrics
10. **Testing**: Use small in-memory stores for unit tests
## Performance Tips
- Use recursive splitting for semantic coherence
- Enable batch processing for large datasets
- Use dynamic max results based on query complexity
- Cache embedding model for frequently accessed content
- Implement async ingestion for large document collections
- Monitor token usage for cost optimization
- Use appropriate vector DB indexes for scale

View File

@@ -0,0 +1,130 @@
---
name: langchain4j-spring-boot-integration
description: Integration patterns for LangChain4j with Spring Boot. Auto-configuration, dependency injection, and Spring ecosystem integration. Use when embedding LangChain4j into Spring Boot applications.
category: ai-development
tags: [langchain4j, spring-boot, ai, llm, rag, chatbot, integration, configuration, java]
version: 1.1.0
allowed-tools: Read, Write, Bash, Grep
---
# LangChain4j Spring Boot Integration
This skill covers integrating LangChain4j with Spring Boot applications: auto-configuration, declarative AI Services, chat models, embedding stores, and production-ready patterns for building AI-powered applications.
## When to Use
Use this skill when:
- Integrating LangChain4j into existing Spring Boot applications
- Building AI-powered microservices with Spring Boot
- Setting up auto-configuration for AI models and services
- Creating declarative AI Services with Spring dependency injection
- Configuring multiple AI providers (OpenAI, Azure, Ollama, etc.)
- Implementing RAG systems with Spring Boot
- Setting up observability and monitoring for AI components
- Building production-ready AI applications with Spring Boot
## Overview
LangChain4j Spring Boot integration provides declarative AI Services through Spring Boot starters, enabling automatic configuration of AI components based on properties. The integration combines the power of Spring dependency injection with LangChain4j's AI capabilities, allowing developers to create AI-powered applications using interface-based definitions with annotations.
## Core Concepts
To set up LangChain4j with Spring Boot:
**Add Dependencies:**
```xml
<!-- Core LangChain4j -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-spring-boot-starter</artifactId>
<version>1.8.0</version> <!-- use the latest version -->
</dependency>
<!-- OpenAI Integration -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai-spring-boot-starter</artifactId>
<version>1.8.0</version>
</dependency>
```
**Configure Properties:**
```properties
# application.properties
langchain4j.open-ai.chat-model.api-key=${OPENAI_API_KEY}
langchain4j.open-ai.chat-model.model-name=gpt-4o-mini
langchain4j.open-ai.chat-model.temperature=0.7
```
**Create Declarative AI Service:**
```java
@AiService
interface CustomerSupportAssistant {
@SystemMessage("You are a helpful customer support agent for TechCorp.")
String handleInquiry(String customerMessage);
}
```
## Configuration
To configure LangChain4j in Spring Boot:
**Property-Based Configuration:** Configure AI models through application properties for different providers.
**Manual Bean Configuration:** For advanced configurations, define beans manually using @Configuration.
**Multiple Providers:** Support for multiple AI providers with explicit wiring when needed.
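As a sketch of the manual route (model name and temperature are illustrative):
```java
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class AiModelConfiguration {

    @Bean
    public ChatModel chatModel() {
        // Explicit bean definition instead of property-driven auto-configuration
        return OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .modelName("gpt-4o-mini")
                .temperature(0.7)
                .build();
    }
}
```
A bean defined this way can coexist with auto-configured models from other providers; explicit wiring (next section) selects between them.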
## Declarative AI Services
To define AI services as interfaces:
**Basic AI Service:** Create interfaces with @AiService annotation and define methods with message templates.
**Streaming AI Service:** Implement streaming responses with Project Reactor (e.g., a `Flux<String>` return type).
**Explicit Wiring:** Specify which model to use with @AiService(wiringMode = EXPLICIT, chatModel = "modelBeanName").
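A sketch of explicit wiring and streaming (the bean name is illustrative and matches the configuration sketch above; the `Flux<String>` return type assumes the langchain4j-reactor module is on the classpath):
```java
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.spring.AiService;
import dev.langchain4j.service.spring.AiServiceWiringMode;
import reactor.core.publisher.Flux;

// Explicitly bound to a named ChatModel bean
@AiService(wiringMode = AiServiceWiringMode.EXPLICIT, chatModel = "chatModel")
interface WiredAssistant {

    @SystemMessage("You are a concise technical assistant.")
    String answer(String question);
}

// Streams tokens as they are generated
@AiService
interface StreamingAssistant {

    Flux<String> streamAnswer(String question);
}
```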
## RAG Implementation
To implement a RAG system:
**Embedding Stores:** Configure various embedding stores (PostgreSQL/pgvector, Neo4j, Pinecone, etc.).
**Document Ingestion:** Implement document processing and embedding generation.
**Content Retrieval:** Set up content retrieval mechanisms for knowledge augmentation.
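A minimal sketch wiring these pieces as Spring beans (the in-memory store suits development; swap in pgvector or another store for production):
```java
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class RagBeansConfiguration {

    @Bean
    public EmbeddingStore<TextSegment> embeddingStore() {
        return new InMemoryEmbeddingStore<>();
    }

    @Bean
    public ContentRetriever contentRetriever(EmbeddingStore<TextSegment> store,
                                             EmbeddingModel embeddingModel) {
        return EmbeddingStoreContentRetriever.builder()
                .embeddingStore(store)
                .embeddingModel(embeddingModel)
                .maxResults(5)
                .minScore(0.7)
                .build();
    }
}
```
The `ContentRetriever` bean can then be picked up by an `@AiService` interface in the same application context.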
## Tool Integration
To integrate tools with AI services:
**Spring Component Tools:** Define tools as Spring components with @Tool annotations.
**Database Access Tools:** Create tools for database operations and business logic.
**Tool Registration:** Automatically register tools with AI services.
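A sketch of a Spring component tool (the repository type and method are illustrative):
```java
import dev.langchain4j.agent.tool.Tool;
import org.springframework.stereotype.Component;

@Component
public class OrderTools {

    private final OrderRepository orderRepository; // illustrative dependency

    public OrderTools(OrderRepository orderRepository) {
        this.orderRepository = orderRepository;
    }

    @Tool("Look up the current status of an order by its ID")
    public String getOrderStatus(String orderId) {
        // Delegates to business logic; the description above guides the LLM
        return orderRepository.findStatusById(orderId)
                .orElse("No order found with ID " + orderId);
    }
}
```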
## Examples
To understand implementation patterns, refer to the comprehensive examples in [references/examples.md](references/examples.md).
## Best Practices
To build production-ready AI applications:
- **Use Property-Based Configuration:** External configuration over hardcoded values
- **Implement Proper Error Handling:** Graceful degradation and meaningful error responses
- **Use Profiles for Different Environments:** Separate configurations for development, testing, and production
- **Implement Proper Logging:** Debug AI service calls and monitor performance
- **Secure API Keys:** Use environment variables and never commit to version control
- **Handle Failures:** Implement retry mechanisms and fallback strategies (see the sketch after this list)
- **Monitor Performance:** Add metrics and health checks for observability
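A sketch of the failure-handling practice around the `CustomerSupportAssistant` defined earlier (retry count and fallback text are illustrative; Spring Retry or Resilience4j are the production-grade options):
```java
public String answerWithFallback(String customerMessage) {
    int maxAttempts = 3;
    for (int attempt = 1; attempt <= maxAttempts; attempt++) {
        try {
            return assistant.handleInquiry(customerMessage);
        } catch (RuntimeException e) {
            // Retry; degrade gracefully on the final attempt
            if (attempt == maxAttempts) {
                return "The assistant is temporarily unavailable. Please try again later.";
            }
        }
    }
    return ""; // unreachable
}
```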
## References
For detailed API references, advanced configurations, and additional patterns, refer to:
- [API Reference](references/references.md) - Complete API reference and configurations
- [Examples](references/examples.md) - Comprehensive implementation examples
- [Configuration Guide](references/configuration.md) - Deep dive into configuration options

Some files were not shown because too many files have changed in this diff.