Initial commit

skills/ai/chunking-strategy/SKILL.md (new file, 194 lines)
@@ -0,0 +1,194 @@
---
name: chunking-strategy
description: Implement optimal chunking strategies in RAG systems and document processing pipelines. Use when building retrieval-augmented generation systems, vector databases, or processing large documents that require breaking into semantically meaningful segments for embeddings and search.
allowed-tools: Read, Write, Bash
category: artificial-intelligence
tags: [rag, chunking, vector-search, embeddings, document-processing]
version: 1.0.0
---

# Chunking Strategy for RAG Systems

## Overview

Implement optimal chunking strategies for Retrieval-Augmented Generation (RAG) systems and document processing pipelines. This skill provides a comprehensive framework for breaking large documents into smaller, semantically meaningful segments that preserve context while enabling efficient retrieval and search.

## When to Use

Use this skill when building RAG systems, optimizing vector search performance, implementing document processing pipelines, handling multi-modal content, or tuning existing RAG systems that show poor retrieval quality.

## Instructions

### Choose Chunking Strategy

Select an appropriate chunking strategy based on document type and use case:

1. **Fixed-Size Chunking** (Level 1)
   - Use for simple documents without clear structure
   - Start with 512 tokens and 10-20% overlap
   - Adjust size based on query type: 256 tokens for factoid queries, 1024 for analytical ones

2. **Recursive Character Chunking** (Level 2)
   - Use for documents with clear structural boundaries
   - Implement hierarchical separators: paragraphs → sentences → words
   - Customize separators for document types such as HTML and Markdown (see the sketch after this list)

3. **Structure-Aware Chunking** (Level 3)
   - Use for structured documents (Markdown, code, tables, PDFs)
   - Preserve semantic units: functions, sections, table blocks
   - Validate structure preservation after splitting

4. **Semantic Chunking** (Level 4)
   - Use for complex documents with thematic shifts
   - Implement embedding-based boundary detection
   - Configure the similarity threshold (around 0.8) and buffer size (3-5 sentences)

5. **Advanced Methods** (Level 5)
   - Use Late Chunking for long-context embedding models
   - Apply Contextual Retrieval for high-precision requirements
   - Monitor computational costs against retrieval improvements

Reference detailed strategy implementations in [references/strategies.md](references/strategies.md).
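
A minimal Level 2 sketch, assuming the same LangChain splitter used in the Examples section below; the separator list and sizes are illustrative, not prescriptive:

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Try the largest separator first (headings), then fall back to
# paragraphs, sentences, words, and finally single characters.
markdown_splitter = RecursiveCharacterTextSplitter(
    separators=["\n## ", "\n### ", "\n\n", ". ", " ", ""],
    chunk_size=512,
    chunk_overlap=64,  # ~12% overlap, within the 10-20% guideline
)

chunks = markdown_splitter.split_text(markdown_text)  # markdown_text: your document string
```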

### Implement Chunking Pipeline

Follow these steps to implement effective chunking (a wiring sketch follows this list):

1. **Pre-process documents**
   - Analyze document structure and content types
   - Identify multi-modal content (tables, images, code)
   - Assess information density and complexity

2. **Select strategy parameters**
   - Choose chunk size based on the embedding model's context window
   - Set overlap percentage (10-20% for most cases)
   - Configure strategy-specific parameters

3. **Process and validate**
   - Apply the chosen chunking strategy
   - Validate the semantic coherence of chunks
   - Test with representative documents

4. **Evaluate and iterate**
   - Measure retrieval precision and recall
   - Monitor processing latency and resource usage
   - Optimize based on specific use case requirements

Reference detailed implementation guidelines in [references/implementation.md](references/implementation.md).
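
A minimal sketch of how these steps wire together, assuming the `DocumentAnalyzer` and `StrategySelector` classes sketched in references/implementation.md (the function name and the chunk-dict shape are assumptions, not a fixed API):

```python
def process_document(content: str):
    analysis = DocumentAnalyzer().analyze(content)         # 1. pre-process
    strategy = StrategySelector().get_strategy(analysis)   # 2. select strategy + parameters
    chunks = strategy.chunk(content, analysis)             # 3. process
    assert all(c['content'].strip() for c in chunks)       # 3. validate (minimal sanity check)
    return chunks                                          # 4. feed into retrieval evaluation
```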

### Evaluate Performance

Use these metrics to evaluate chunking effectiveness:

- **Retrieval Precision**: Fraction of retrieved chunks that are relevant
- **Retrieval Recall**: Fraction of relevant chunks that are retrieved
- **End-to-End Accuracy**: Quality of final RAG responses
- **Processing Time**: Latency impact on the overall system
- **Resource Usage**: Memory and computational costs

Reference the detailed evaluation framework in [references/evaluation.md](references/evaluation.md).

## Examples

### Basic Fixed-Size Chunking

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Configure for factoid queries: small chunks, ~10% overlap
splitter = RecursiveCharacterTextSplitter(
    chunk_size=256,
    chunk_overlap=25,
    length_function=len
)

chunks = splitter.split_documents(documents)
```

### Structure-Aware Code Chunking

```python
import ast


def chunk_python_code(code):
    """Split Python code into semantic chunks (functions and classes)."""
    tree = ast.parse(code)
    chunks = []

    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
            # get_source_segment requires Python 3.8+
            chunks.append(ast.get_source_segment(code, node))

    return chunks
```

### Semantic Chunking with Embeddings

```python
def semantic_chunk(text, similarity_threshold=0.8):
    """Chunk text based on semantic boundaries.

    split_into_sentences, generate_embeddings, and cosine_similarity
    are assumed helpers (e.g. from nltk, sentence-transformers, scikit-learn).
    """
    sentences = split_into_sentences(text)
    if not sentences:
        return []
    embeddings = generate_embeddings(sentences)

    chunks = []
    current_chunk = [sentences[0]]

    for i in range(1, len(sentences)):
        similarity = cosine_similarity(embeddings[i-1], embeddings[i])

        if similarity < similarity_threshold:
            chunks.append(" ".join(current_chunk))
            current_chunk = [sentences[i]]
        else:
            current_chunk.append(sentences[i])

    chunks.append(" ".join(current_chunk))
    return chunks
```

## Best Practices

### Core Principles
- Balance context preservation with retrieval precision
- Maintain semantic coherence within chunks
- Optimize for embedding model constraints
- Preserve document structure when beneficial

### Implementation Guidelines
- Start simple with fixed-size chunking (512 tokens, 10-20% overlap)
- Test thoroughly with representative documents
- Monitor both accuracy metrics and computational costs
- Iterate based on specific document characteristics

### Common Pitfalls to Avoid
- Over-chunking: creating too many small, context-poor chunks
- Under-chunking: missing relevant information due to oversized chunks
- Ignoring document structure and semantic boundaries
- Using a one-size-fits-all approach for diverse content types
- Neglecting overlap for boundary-crossing information

## Constraints

### Resource Considerations
- Semantic and contextual methods require significant computational resources
- Late chunking needs long-context embedding models
- Complex strategies increase processing latency
- Monitor memory usage when processing large documents

### Quality Requirements
- Validate chunk semantic coherence after processing
- Test with domain-specific documents before deployment
- Ensure chunks maintain standalone meaning where possible
- Implement proper error handling for edge cases

## References

Reference detailed documentation in the [references/](references/) folder:
- [strategies.md](references/strategies.md) - Detailed strategy implementations
- [implementation.md](references/implementation.md) - Complete implementation guidelines
- [evaluation.md](references/evaluation.md) - Performance evaluation framework
- [tools.md](references/tools.md) - Recommended libraries and frameworks
- [research.md](references/research.md) - Key research papers and findings
- [advanced-strategies.md](references/advanced-strategies.md) - 11 comprehensive chunking methods
- [semantic-methods.md](references/semantic-methods.md) - Semantic and contextual approaches
- [visualization-tools.md](references/visualization-tools.md) - Evaluation and visualization tools

skills/ai/chunking-strategy/references/advanced-strategies.md (new file, 1358 lines)
File diff suppressed because it is too large.

skills/ai/chunking-strategy/references/evaluation.md (new file, 904 lines)
@@ -0,0 +1,904 @@

# Performance Evaluation Framework

This document provides comprehensive methodologies for evaluating chunking strategy performance and effectiveness.

## Evaluation Metrics

### Core Retrieval Metrics

#### Retrieval Precision
Measures the fraction of retrieved chunks that are relevant to the query.

```python
from typing import Dict, List  # shared by the snippets in this document


def calculate_precision(retrieved_chunks: List[Dict], relevant_chunks: List[Dict]) -> float:
    """
    Calculate retrieval precision
    Precision = |Relevant ∩ Retrieved| / |Retrieved|
    """
    retrieved_ids = {chunk.get('id') for chunk in retrieved_chunks}
    relevant_ids = {chunk.get('id') for chunk in relevant_chunks}

    intersection = retrieved_ids & relevant_ids

    if not retrieved_ids:
        return 0.0

    return len(intersection) / len(retrieved_ids)
```

#### Retrieval Recall
Measures the fraction of relevant chunks that are successfully retrieved.

```python
def calculate_recall(retrieved_chunks: List[Dict], relevant_chunks: List[Dict]) -> float:
    """
    Calculate retrieval recall
    Recall = |Relevant ∩ Retrieved| / |Relevant|
    """
    retrieved_ids = {chunk.get('id') for chunk in retrieved_chunks}
    relevant_ids = {chunk.get('id') for chunk in relevant_chunks}

    intersection = retrieved_ids & relevant_ids

    if not relevant_ids:
        return 0.0

    return len(intersection) / len(relevant_ids)
```

#### F1-Score
Harmonic mean of precision and recall.

```python
def calculate_f1_score(precision: float, recall: float) -> float:
    """
    Calculate F1-score
    F1 = 2 * (Precision * Recall) / (Precision + Recall)
    """
    if precision + recall == 0:
        return 0.0

    return 2 * (precision * recall) / (precision + recall)
```
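
A short usage sketch combining the three metrics, with hypothetical chunk dicts (only the 'id' key matters here):

```python
retrieved = [{'id': 'c1'}, {'id': 'c2'}, {'id': 'c3'}]
relevant = [{'id': 'c2'}, {'id': 'c4'}]

p = calculate_precision(retrieved, relevant)  # 1/3
r = calculate_recall(retrieved, relevant)     # 1/2
f1 = calculate_f1_score(p, r)                 # 0.4
print(f"P={p:.2f} R={r:.2f} F1={f1:.2f}")
```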

### Mean Reciprocal Rank (MRR)
Measures the rank of the first relevant result.

```python
def calculate_mrr(queries: List[Dict], results: List[List[Dict]]) -> float:
    """
    Calculate Mean Reciprocal Rank
    """
    reciprocal_ranks = []

    for query, query_results in zip(queries, results):
        relevant_found = False

        for rank, result in enumerate(query_results, 1):
            if result.get('is_relevant', False):
                reciprocal_ranks.append(1.0 / rank)
                relevant_found = True
                break

        if not relevant_found:
            reciprocal_ranks.append(0.0)

    return sum(reciprocal_ranks) / len(reciprocal_ranks)
```

### Mean Average Precision (MAP)
Considers both precision and the ranking of relevant documents.

```python
def calculate_average_precision(retrieved_chunks: List[Dict], relevant_chunks: List[Dict]) -> float:
    """
    Calculate Average Precision for a single query
    """
    relevant_ids = {chunk.get('id') for chunk in relevant_chunks}

    if not relevant_ids:
        return 0.0

    precisions = []
    relevant_count = 0

    for rank, chunk in enumerate(retrieved_chunks, 1):
        if chunk.get('id') in relevant_ids:
            relevant_count += 1
            precision_at_rank = relevant_count / rank
            precisions.append(precision_at_rank)

    return sum(precisions) / len(relevant_ids)


def calculate_map(queries: List[Dict], results: List[List[Dict]]) -> float:
    """
    Calculate Mean Average Precision across multiple queries
    """
    average_precisions = []

    for query, query_results in zip(queries, results):
        ap = calculate_average_precision(query_results, query.get('relevant_chunks', []))
        average_precisions.append(ap)

    return sum(average_precisions) / len(average_precisions) if average_precisions else 0.0
```

### Normalized Discounted Cumulative Gain (NDCG)
Measures ranking quality with emphasis on highly relevant results.

```python
import numpy as np


def calculate_dcg(retrieved_chunks: List[Dict]) -> float:
    """
    Calculate Discounted Cumulative Gain
    """
    dcg = 0.0

    for rank, chunk in enumerate(retrieved_chunks, 1):
        relevance = chunk.get('relevance_score', 0)
        dcg += relevance / np.log2(rank + 1)

    return dcg


def calculate_ndcg(retrieved_chunks: List[Dict], ideal_chunks: List[Dict]) -> float:
    """
    Calculate Normalized Discounted Cumulative Gain
    """
    dcg = calculate_dcg(retrieved_chunks)
    idcg = calculate_dcg(ideal_chunks)

    if idcg == 0:
        return 0.0

    return dcg / idcg
```

## End-to-End RAG Evaluation

### Answer Quality Metrics

#### Factual Consistency
Measures how well the generated answer aligns with retrieved chunks.

```python
from transformers import pipeline


class FactualConsistencyEvaluator:
    def __init__(self):
        self.nli_pipeline = pipeline("text-classification",
                                     model="roberta-large-mnli")

    def evaluate_consistency(self, answer: str, retrieved_chunks: List[str]) -> float:
        """
        Evaluate factual consistency between answer and retrieved context
        """
        if not retrieved_chunks:
            return 0.0

        # Combine retrieved chunks as context
        context = " ".join(retrieved_chunks[:3])  # Use top 3 chunks

        # Use Natural Language Inference to check consistency
        result = self.nli_pipeline(f"premise: {context} hypothesis: {answer}")

        # Extract consistency score (entailment probability)
        for item in result:
            if item['label'] == 'ENTAILMENT':
                return item['score']
            elif item['label'] == 'CONTRADICTION':
                return 1.0 - item['score']

        return 0.5  # Neutral if NLI is inconclusive
```

#### Answer Completeness
Measures how completely the answer addresses the user's query.

```python
import re


def evaluate_completeness(answer: str, query: str, reference_answer: str = None) -> float:
    """
    Evaluate answer completeness
    """
    # Extract key entities from query and answer
    query_entities = extract_entities(query)
    answer_entities = extract_entities(answer)

    # Calculate entity coverage
    if not query_entities:
        return 0.5  # Neutral if no entities in query

    covered_entities = query_entities & answer_entities
    entity_coverage = len(covered_entities) / len(query_entities)

    # If a reference answer is available, compare against it
    if reference_answer:
        reference_entities = extract_entities(reference_answer)
        answer_reference_overlap = len(answer_entities & reference_entities) / max(len(reference_entities), 1)
        return (entity_coverage + answer_reference_overlap) / 2

    return entity_coverage


def extract_entities(text: str) -> set:
    """
    Extract named entities from text (simplified)
    """
    # This would use a proper NER model in practice;
    # simple capitalized noun-phrase extraction as a placeholder
    noun_phrases = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', text)
    return set(noun_phrases)
```

#### Response Relevance
Measures how relevant the answer is to the original query.

```python
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


class RelevanceEvaluator:
    def __init__(self, model_name="all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)

    def evaluate_relevance(self, query: str, answer: str) -> float:
        """
        Evaluate semantic relevance between query and answer
        """
        # Generate embeddings
        query_embedding = self.model.encode([query])
        answer_embedding = self.model.encode([answer])

        # Calculate cosine similarity
        similarity = cosine_similarity(query_embedding, answer_embedding)[0][0]

        return float(similarity)
```
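
A quick usage sketch (the model downloads on first run; the inputs are illustrative):

```python
evaluator = RelevanceEvaluator()
score = evaluator.evaluate_relevance(
    "What chunk size works for factoid queries?",
    "Smaller chunks around 256 tokens tend to suit factoid queries.",
)
print(f"relevance: {score:.2f}")  # cosine similarity, roughly in [-1, 1]
```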

## Performance Metrics

### Processing Time

```python
import time
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class PerformanceMetrics:
    total_time: float
    chunking_time: float
    embedding_time: float
    search_time: float
    generation_time: float
    throughput: float  # documents per second


class PerformanceProfiler:
    def __init__(self):
        self.timings = {}
        self.start_times = {}

    def start_timer(self, operation: str):
        self.start_times[operation] = time.time()

    def end_timer(self, operation: str):
        if operation in self.start_times:
            duration = time.time() - self.start_times[operation]
            if operation not in self.timings:
                self.timings[operation] = []
            self.timings[operation].append(duration)
            return duration
        return 0.0

    def get_performance_metrics(self, document_count: int) -> PerformanceMetrics:
        total_time = sum(sum(times) for times in self.timings.values())

        return PerformanceMetrics(
            total_time=total_time,
            chunking_time=sum(self.timings.get('chunking', [0])),
            embedding_time=sum(self.timings.get('embedding', [0])),
            search_time=sum(self.timings.get('search', [0])),
            generation_time=sum(self.timings.get('generation', [0])),
            throughput=document_count / total_time if total_time > 0 else 0
        )
```
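
Typical usage brackets each pipeline stage with matching timer names (the stage names must match the keys that `get_performance_metrics` reads):

```python
profiler = PerformanceProfiler()

profiler.start_timer('chunking')
# ... run the chunker over a batch of documents ...
profiler.end_timer('chunking')

profiler.start_timer('embedding')
# ... embed the resulting chunks ...
profiler.end_timer('embedding')

metrics = profiler.get_performance_metrics(document_count=100)
print(f"{metrics.throughput:.1f} docs/sec, chunking took {metrics.chunking_time:.2f}s")
```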

### Memory Usage

```python
import os
import time
from typing import Dict, List

import psutil


class MemoryProfiler:
    def __init__(self):
        self.process = psutil.Process(os.getpid())
        self.memory_snapshots = []

    def take_memory_snapshot(self, label: str):
        """Take a snapshot of current memory usage"""
        memory_info = self.process.memory_info()
        memory_mb = memory_info.rss / 1024 / 1024  # Convert to MB

        self.memory_snapshots.append({
            'label': label,
            'memory_mb': memory_mb,
            'timestamp': time.time()
        })

    def get_peak_memory_usage(self) -> float:
        """Get peak memory usage in MB"""
        if not self.memory_snapshots:
            return 0.0
        return max(snapshot['memory_mb'] for snapshot in self.memory_snapshots)

    def get_memory_usage_by_operation(self) -> Dict[str, float]:
        """Get memory usage breakdown by operation"""
        if not self.memory_snapshots:
            return {}

        memory_by_op = {}
        for i in range(1, len(self.memory_snapshots)):
            prev_snapshot = self.memory_snapshots[i-1]
            curr_snapshot = self.memory_snapshots[i]

            operation = curr_snapshot['label']
            memory_delta = curr_snapshot['memory_mb'] - prev_snapshot['memory_mb']

            if operation not in memory_by_op:
                memory_by_op[operation] = []
            memory_by_op[operation].append(memory_delta)

        return {op: sum(deltas) for op, deltas in memory_by_op.items()}
```

## Evaluation Datasets

### Standardized Test Sets

#### Question-Answer Pairs

```python
import json
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class EvaluationQuery:
    id: str
    question: str
    reference_answer: Optional[str]
    relevant_chunk_ids: List[str]
    query_type: str  # factoid, analytical, comparative
    difficulty: str  # easy, medium, hard
    domain: str      # finance, medical, legal, technical


class EvaluationDataset:
    def __init__(self, name: str):
        self.name = name
        self.queries: List[EvaluationQuery] = []
        self.documents: Dict[str, str] = {}
        self.chunks: Dict[str, Dict] = {}

    def add_query(self, query: EvaluationQuery):
        self.queries.append(query)

    def add_document(self, doc_id: str, content: str):
        self.documents[doc_id] = content

    def add_chunk(self, chunk_id: str, content: str, doc_id: str, metadata: Dict):
        self.chunks[chunk_id] = {
            'id': chunk_id,
            'content': content,
            'doc_id': doc_id,
            'metadata': metadata
        }

    def save_to_file(self, filepath: str):
        data = {
            'name': self.name,
            'queries': [
                {
                    'id': q.id,
                    'question': q.question,
                    'reference_answer': q.reference_answer,
                    'relevant_chunk_ids': q.relevant_chunk_ids,
                    'query_type': q.query_type,
                    'difficulty': q.difficulty,
                    'domain': q.domain
                }
                for q in self.queries
            ],
            'documents': self.documents,
            'chunks': self.chunks
        }

        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)

    @classmethod
    def load_from_file(cls, filepath: str):
        with open(filepath, 'r') as f:
            data = json.load(f)

        dataset = cls(data['name'])
        dataset.documents = data['documents']
        dataset.chunks = data['chunks']

        for q_data in data['queries']:
            query = EvaluationQuery(
                id=q_data['id'],
                question=q_data['question'],
                reference_answer=q_data.get('reference_answer'),
                relevant_chunk_ids=q_data['relevant_chunk_ids'],
                query_type=q_data['query_type'],
                difficulty=q_data['difficulty'],
                domain=q_data['domain']
            )
            dataset.add_query(query)

        return dataset
```
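
A minimal round-trip sketch (ids and the filename are hypothetical):

```python
dataset = EvaluationDataset("chunking-eval-v1")
dataset.add_document("doc1", "Chunk size trades context for precision...")
dataset.add_chunk("doc1-c0", "Chunk size trades context for precision", "doc1", {"offset": 0})
dataset.add_query(EvaluationQuery(
    id="q1",
    question="What does chunk size trade off?",
    reference_answer=None,
    relevant_chunk_ids=["doc1-c0"],
    query_type="factoid",
    difficulty="easy",
    domain="technical",
))

dataset.save_to_file("eval_dataset.json")
restored = EvaluationDataset.load_from_file("eval_dataset.json")
assert restored.queries[0].question == "What does chunk size trade off?"
```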

### Dataset Generation

#### Synthetic Query Generation

```python
import random
import re
from typing import Dict, List


class SyntheticQueryGenerator:
    def __init__(self):
        self.query_templates = {
            'factoid': [
                "What is {concept}?",
                "When did {event} occur?",
                "Who developed {technology}?",
                "How many {items} are mentioned?",
                "What is the value of {metric}?"
            ],
            'analytical': [
                "Compare and contrast {concept1} and {concept2}.",
                "Analyze the impact of {concept} on {domain}.",
                "What are the advantages and disadvantages of {technology}?",
                "Explain the relationship between {concept1} and {concept2}.",
                "Evaluate the effectiveness of {approach} for {problem}."
            ],
            'comparative': [
                "Which is better: {option1} or {option2}?",
                "How does {method1} differ from {method2}?",
                "Compare the performance of {system1} and {system2}.",
                "What are the key differences between {approach1} and {approach2}?"
            ]
        }

    def generate_queries_from_chunks(self, chunks: List[Dict], num_queries: int = 100) -> List[EvaluationQuery]:
        """Generate synthetic queries from document chunks"""
        queries = []

        # Extract entities and concepts from chunks
        entities = self._extract_entities_from_chunks(chunks)

        for i in range(num_queries):
            query_type = random.choice(['factoid', 'analytical', 'comparative'])
            template = random.choice(self.query_templates[query_type])

            # Fill template with extracted entities
            query_text = self._fill_template(template, entities)

            # Find relevant chunks for this query
            relevant_chunks = self._find_relevant_chunks(query_text, chunks)

            query = EvaluationQuery(
                id=f"synthetic_{i}",
                question=query_text,
                reference_answer=None,  # Would need a generation model
                relevant_chunk_ids=[chunk['id'] for chunk in relevant_chunks],
                query_type=query_type,
                difficulty=random.choice(['easy', 'medium', 'hard']),
                domain='synthetic'
            )

            queries.append(query)

        return queries

    def _extract_entities_from_chunks(self, chunks: List[Dict]) -> Dict[str, List[str]]:
        """Extract entities, concepts, and relationships from chunks"""
        # This would use proper NER in practice
        entities = {
            'concepts': [],
            'technologies': [],
            'methods': [],
            'metrics': [],
            'events': []
        }

        for chunk in chunks:
            content = chunk['content']
            # Simplified entity extraction
            words = content.split()
            entities['concepts'].extend([word for word in words if len(word) > 6])
            entities['technologies'].extend([word for word in words if 'technology' in word.lower()])
            entities['methods'].extend([word for word in words if 'method' in word.lower()])
            entities['metrics'].extend([word for word in words if '%' in word or '$' in word])

        # Remove duplicates and limit
        for key in entities:
            entities[key] = list(set(entities[key]))[:50]

        return entities

    def _fill_template(self, template: str, entities: Dict[str, List[str]]) -> str:
        """Fill query template with random entities"""

        def replace_placeholder(match):
            placeholder = match.group(1)

            # Map placeholders to entity types
            entity_mapping = {
                'concept': 'concepts',
                'concept1': 'concepts',
                'concept2': 'concepts',
                'technology': 'technologies',
                'method': 'methods',
                'method1': 'methods',
                'method2': 'methods',
                'metric': 'metrics',
                'event': 'events',
                'items': 'concepts',
                'option1': 'concepts',
                'option2': 'concepts',
                'approach': 'methods',
                'approach1': 'methods',
                'approach2': 'methods',
                'problem': 'concepts',
                'domain': 'concepts',
                'system1': 'concepts',
                'system2': 'concepts'
            }

            entity_type = entity_mapping.get(placeholder, 'concepts')
            available_entities = entities.get(entity_type, [])

            if available_entities:
                return random.choice(available_entities)
            return 'something'

        return re.sub(r'\{(\w+)\}', replace_placeholder, template)

    def _find_relevant_chunks(self, query: str, chunks: List[Dict], k: int = 3) -> List[Dict]:
        """Find chunks most relevant to the query"""
        # Simple keyword matching for synthetic generation
        query_words = set(query.lower().split())

        chunk_scores = []
        for chunk in chunks:
            chunk_words = set(chunk['content'].lower().split())
            overlap = len(query_words & chunk_words)
            chunk_scores.append((overlap, chunk))

        # Sort by overlap and return top k
        chunk_scores.sort(key=lambda x: x[0], reverse=True)
        return [chunk for _, chunk in chunk_scores[:k]]
```

## A/B Testing Framework

### Statistical Significance Testing

```python
import numpy as np
from scipy import stats
from typing import Dict, List


class ABTestAnalyzer:
    def __init__(self):
        self.significance_level = 0.05

    def compare_metrics(self, control_metrics: List[float],
                        treatment_metrics: List[float],
                        metric_name: str) -> Dict:
        """
        Compare metrics between control and treatment groups
        """
        control_mean = np.mean(control_metrics)
        treatment_mean = np.mean(treatment_metrics)

        control_std = np.std(control_metrics)
        treatment_std = np.std(treatment_metrics)

        # Perform t-test
        t_statistic, p_value = stats.ttest_ind(control_metrics, treatment_metrics)

        # Calculate effect size (Cohen's d)
        pooled_std = np.sqrt(((len(control_metrics) - 1) * control_std**2 +
                              (len(treatment_metrics) - 1) * treatment_std**2) /
                             (len(control_metrics) + len(treatment_metrics) - 2))

        cohens_d = (treatment_mean - control_mean) / pooled_std if pooled_std > 0 else 0

        # Determine significance
        is_significant = p_value < self.significance_level

        return {
            'metric_name': metric_name,
            'control_mean': control_mean,
            'treatment_mean': treatment_mean,
            'absolute_difference': treatment_mean - control_mean,
            'relative_difference': ((treatment_mean - control_mean) / control_mean * 100) if control_mean != 0 else 0,
            'control_std': control_std,
            'treatment_std': treatment_std,
            't_statistic': t_statistic,
            'p_value': p_value,
            'is_significant': is_significant,
            'effect_size': cohens_d,
            'significance_level': self.significance_level
        }

    def analyze_ab_test_results(self,
                                control_results: Dict[str, List[float]],
                                treatment_results: Dict[str, List[float]]) -> Dict:
        """
        Analyze A/B test results across multiple metrics
        """
        analysis_results = {}

        # Only compare metrics present in both groups
        all_metrics = set(control_results.keys()) & set(treatment_results.keys())

        for metric in all_metrics:
            analysis_results[metric] = self.compare_metrics(
                control_results[metric],
                treatment_results[metric],
                metric
            )

        # Calculate overall summary
        significant_improvements = sum(1 for result in analysis_results.values()
                                       if result['is_significant'] and result['relative_difference'] > 0)
        significant_degradations = sum(1 for result in analysis_results.values()
                                       if result['is_significant'] and result['relative_difference'] < 0)

        analysis_results['summary'] = {
            'total_metrics_compared': len(all_metrics),
            'significant_improvements': significant_improvements,
            'significant_degradations': significant_degradations,
            'no_significant_change': len(all_metrics) - significant_improvements - significant_degradations
        }

        return analysis_results
```
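
A usage sketch comparing per-query F1 scores from two chunking strategies (the numbers are made up for illustration):

```python
analyzer = ABTestAnalyzer()
control = {'f1_score': [0.61, 0.58, 0.64, 0.60, 0.59]}    # e.g. fixed-size chunking
treatment = {'f1_score': [0.66, 0.69, 0.63, 0.70, 0.67]}  # e.g. semantic chunking

report = analyzer.analyze_ab_test_results(control, treatment)
print(report['f1_score']['p_value'], report['f1_score']['is_significant'])
```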

## Automated Evaluation Pipeline

### End-to-End Evaluation

```python
import random
from typing import Any, Dict, List

import numpy as np


class ChunkingEvaluationPipeline:
    def __init__(self, strategies: Dict[str, Any], dataset: EvaluationDataset):
        self.strategies = strategies
        self.dataset = dataset
        self.results = {}
        self.profiler = PerformanceProfiler()
        self.memory_profiler = MemoryProfiler()

    def run_evaluation(self) -> Dict:
        """Run comprehensive evaluation of all strategies"""
        evaluation_results = {}

        for strategy_name, strategy in self.strategies.items():
            print(f"Evaluating strategy: {strategy_name}")

            # Reset profilers for each strategy
            self.profiler = PerformanceProfiler()
            self.memory_profiler = MemoryProfiler()

            # Evaluate strategy
            strategy_results = self._evaluate_strategy(strategy, strategy_name)
            evaluation_results[strategy_name] = strategy_results

        # Compare strategies
        comparison_results = self._compare_strategies(evaluation_results)

        return {
            'individual_results': evaluation_results,
            'comparison': comparison_results,
            'recommendations': self._generate_recommendations(comparison_results)
        }

    def _evaluate_strategy(self, strategy: Any, strategy_name: str) -> Dict:
        """Evaluate a single chunking strategy"""
        results = {
            'strategy_name': strategy_name,
            'retrieval_metrics': {},
            'quality_metrics': {},
            'performance_metrics': {}
        }

        # Track memory usage
        self.memory_profiler.take_memory_snapshot(f"{strategy_name}_start")

        # Process all documents
        self.profiler.start_timer('total_processing')

        all_chunks = {}
        for doc_id, content in self.dataset.documents.items():
            self.profiler.start_timer('chunking')
            chunks = strategy.chunk(content)
            self.profiler.end_timer('chunking')

            all_chunks[doc_id] = chunks

        self.memory_profiler.take_memory_snapshot(f"{strategy_name}_after_chunking")

        # Generate embeddings for chunks
        self.profiler.start_timer('embedding')
        chunk_embeddings = self._generate_embeddings(all_chunks)
        self.profiler.end_timer('embedding')

        self.memory_profiler.take_memory_snapshot(f"{strategy_name}_after_embedding")

        # Evaluate retrieval performance
        retrieval_results = self._evaluate_retrieval(all_chunks, chunk_embeddings)
        results['retrieval_metrics'] = retrieval_results

        # Evaluate chunk quality
        quality_results = self._evaluate_chunk_quality(all_chunks)
        results['quality_metrics'] = quality_results

        # Get performance metrics
        self.profiler.end_timer('total_processing')
        performance_metrics = self.profiler.get_performance_metrics(len(self.dataset.documents))
        results['performance_metrics'] = performance_metrics.__dict__

        # Get memory metrics
        self.memory_profiler.take_memory_snapshot(f"{strategy_name}_end")
        results['memory_metrics'] = {
            'peak_memory_mb': self.memory_profiler.get_peak_memory_usage(),
            'memory_by_operation': self.memory_profiler.get_memory_usage_by_operation()
        }

        return results

    def _evaluate_retrieval(self, all_chunks: Dict, chunk_embeddings: Dict) -> Dict:
        """Evaluate retrieval performance"""
        retrieval_metrics = {
            'precision': [],
            'recall': [],
            'f1_score': []
        }

        for query in self.dataset.queries:
            # Perform retrieval
            self.profiler.start_timer('search')
            retrieved_chunks = self._retrieve_chunks(query.question, chunk_embeddings, k=10)
            self.profiler.end_timer('search')

            # The ground-truth relevant set comes from the query annotations,
            # not from the retrieved results (otherwise recall is always 1.0)
            relevant_chunks = [{'id': chunk_id} for chunk_id in query.relevant_chunk_ids]

            # Calculate metrics
            precision = calculate_precision(retrieved_chunks, relevant_chunks)
            recall = calculate_recall(retrieved_chunks, relevant_chunks)
            f1 = calculate_f1_score(precision, recall)

            retrieval_metrics['precision'].append(precision)
            retrieval_metrics['recall'].append(recall)
            retrieval_metrics['f1_score'].append(f1)

        # Calculate averages
        return {metric: np.mean(values) for metric, values in retrieval_metrics.items()}

    def _evaluate_chunk_quality(self, all_chunks: Dict) -> Dict:
        """Evaluate quality of generated chunks"""
        quality_assessor = ChunkQualityAssessor()
        quality_scores = []

        for doc_id, chunks in all_chunks.items():
            # Analyze document
            content = self.dataset.documents[doc_id]
            analyzer = DocumentAnalyzer()
            analysis = analyzer.analyze(content)

            # Assess chunk quality
            scores = quality_assessor.assess_chunks(chunks, analysis)
            quality_scores.append(scores)

        # Aggregate quality scores
        if quality_scores:
            avg_scores = {}
            for metric in quality_scores[0].keys():
                avg_scores[metric] = np.mean([scores[metric] for scores in quality_scores])
            return avg_scores

        return {}

    def _compare_strategies(self, evaluation_results: Dict) -> Dict:
        """Compare performance across strategies"""
        ab_analyzer = ABTestAnalyzer()

        comparison = {}

        # Compare each metric across strategies
        strategy_names = list(evaluation_results.keys())

        for i in range(len(strategy_names)):
            for j in range(i + 1, len(strategy_names)):
                strategy1 = strategy_names[i]
                strategy2 = strategy_names[j]

                comparison_key = f"{strategy1}_vs_{strategy2}"
                comparison[comparison_key] = {}

                # Compare retrieval metrics (note: a real significance test needs
                # per-query score lists rather than the single averaged value used here)
                for metric in ['precision', 'recall', 'f1_score']:
                    if (metric in evaluation_results[strategy1]['retrieval_metrics'] and
                            metric in evaluation_results[strategy2]['retrieval_metrics']):

                        comparison[comparison_key][f"retrieval_{metric}"] = ab_analyzer.compare_metrics(
                            [evaluation_results[strategy1]['retrieval_metrics'][metric]],
                            [evaluation_results[strategy2]['retrieval_metrics'][metric]],
                            f"retrieval_{metric}"
                        )

        return comparison

    def _generate_recommendations(self, comparison_results: Dict) -> Dict:
        """Generate recommendations based on evaluation results"""
        recommendations = {
            'best_overall': None,
            'best_for_precision': None,
            'best_for_recall': None,
            'best_for_performance': None,
            'trade_offs': []
        }

        # This would analyze the comparison results and generate specific recommendations
        # Implementation depends on specific use case requirements

        return recommendations

    def _generate_embeddings(self, all_chunks: Dict) -> Dict:
        """Generate embeddings for all chunks"""
        # This would use the actual embedding model
        # Placeholder implementation
        embeddings = {}

        for doc_id, chunks in all_chunks.items():
            embeddings[doc_id] = []
            for chunk in chunks:
                # Generate embedding for chunk content
                embedding = np.random.rand(384)  # Placeholder
                embeddings[doc_id].append({
                    'chunk': chunk,
                    'embedding': embedding
                })

        return embeddings

    def _retrieve_chunks(self, query: str, chunk_embeddings: Dict, k: int = 10) -> List[Dict]:
        """Retrieve most relevant chunks for a query"""
        # This would use actual similarity search
        # Placeholder implementation
        all_chunks = []

        for doc_embeddings in chunk_embeddings.values():
            for chunk_data in doc_embeddings:
                all_chunks.append(chunk_data['chunk'])

        # Simple random selection as placeholder
        selected = random.sample(all_chunks, min(k, len(all_chunks)))

        return selected
```

This evaluation framework provides the tools needed to assess chunking strategies across multiple dimensions: retrieval effectiveness, answer quality, system performance, and statistical significance. The modular design allows for easy extension and customization based on specific requirements and use cases.

skills/ai/chunking-strategy/references/implementation.md (new file, 709 lines)
@@ -0,0 +1,709 @@

# Complete Implementation Guidelines

This document provides comprehensive implementation guidance for building effective chunking systems.

## System Architecture

### Core Components

```
Document Processor
├── Ingestion Layer
│   ├── Document Type Detection
│   ├── Format Parsing (PDF, HTML, Markdown, etc.)
│   └── Content Extraction
├── Analysis Layer
│   ├── Structure Analysis
│   ├── Content Type Identification
│   └── Complexity Assessment
├── Strategy Selection Layer
│   ├── Rule-based Selection
│   ├── ML-based Prediction
│   └── Adaptive Configuration
├── Chunking Layer
│   ├── Strategy Implementation
│   ├── Parameter Optimization
│   └── Quality Validation
└── Output Layer
    ├── Chunk Metadata Generation
    ├── Embedding Integration
    └── Storage Preparation
```

## Pre-processing Pipeline

### Document Analysis Framework

```python
import re
from dataclasses import dataclass
from typing import List


@dataclass
class DocumentAnalysis:
    doc_type: str
    structure_score: float   # 0-1, higher means more structured
    complexity_score: float  # 0-1, higher means more complex
    content_types: List[str]
    language: str
    estimated_tokens: int
    has_multimodal: bool


class DocumentAnalyzer:
    def __init__(self):
        self.structure_patterns = {
            'markdown': [r'^#+\s', r'^\*\*.*\*\*$', r'^\* ', r'^\d+\. '],
            'html': [r'<h[1-6]>', r'<p>', r'<div>', r'<table>'],
            'latex': [r'\\section', r'\\subsection', r'\\begin\{', r'\\end\{'],
            'academic': [r'^\d+\.', r'^\d+\.\d+', r'^[A-Z]\.', r'^Figure \d+']
        }

    def analyze(self, content: str) -> DocumentAnalysis:
        doc_type = self.detect_document_type(content)
        structure_score = self.calculate_structure_score(content, doc_type)
        complexity_score = self.calculate_complexity_score(content)
        content_types = self.identify_content_types(content)
        language = self.detect_language(content)
        estimated_tokens = self.estimate_tokens(content)
        has_multimodal = self.detect_multimodal_content(content)

        return DocumentAnalysis(
            doc_type=doc_type,
            structure_score=structure_score,
            complexity_score=complexity_score,
            content_types=content_types,
            language=language,
            estimated_tokens=estimated_tokens,
            has_multimodal=has_multimodal
        )

    def detect_document_type(self, content: str) -> str:
        content_lower = content.lower()

        if '<html' in content_lower or '<body' in content_lower:
            return 'html'
        elif '#' in content and '##' in content:
            return 'markdown'
        elif '\\documentclass' in content_lower or '\\begin{' in content_lower:
            return 'latex'
        elif any(keyword in content_lower for keyword in ['abstract', 'introduction', 'conclusion', 'references']):
            return 'academic'
        elif 'def ' in content or 'class ' in content or 'function ' in content_lower:
            return 'code'
        else:
            return 'plain'

    def calculate_structure_score(self, content: str, doc_type: str) -> float:
        patterns = self.structure_patterns.get(doc_type, [])
        if not patterns:
            return 0.5  # Default for unstructured content

        line_count = len(content.split('\n'))
        structured_lines = 0

        for line in content.split('\n'):
            for pattern in patterns:
                if re.search(pattern, line.strip()):
                    structured_lines += 1
                    break

        return min(structured_lines / max(line_count, 1), 1.0)

    def calculate_complexity_score(self, content: str) -> float:
        # Factors that increase complexity
        avg_sentence_length = self.calculate_avg_sentence_length(content)
        vocabulary_richness = self.calculate_vocabulary_richness(content)
        nested_structure = self.detect_nested_structure(content)

        # Normalize and combine
        complexity = (
            min(avg_sentence_length / 30, 1.0) * 0.3 +
            vocabulary_richness * 0.4 +
            nested_structure * 0.3
        )

        return min(complexity, 1.0)

    def identify_content_types(self, content: str) -> List[str]:
        types = []

        if '```' in content or 'def ' in content or 'function ' in content.lower():
            types.append('code')
        if '|' in content and '\n' in content:
            types.append('tables')
        if re.search(r'!\[.*\]\(.*\)', content):
            types.append('images')
        if re.search(r'http[s]?://', content):
            types.append('links')
        if re.search(r'\d+\.\d+', content) or re.search(r'\$\d', content):
            types.append('numbers')

        return types if types else ['text']

    def detect_language(self, content: str) -> str:
        # Simple language detection - can be enhanced with proper language detection libraries
        if re.search(r'[\u4e00-\u9fff]', content):
            return 'chinese'
        elif re.search(r'[\u0600-\u06ff]', content):
            return 'arabic'
        elif re.search(r'[\u0400-\u04ff]', content):
            return 'russian'
        else:
            return 'english'  # Default assumption

    def estimate_tokens(self, content: str) -> int:
        # Rough estimation - actual tokenization varies by model
        word_count = len(content.split())
        return int(word_count * 1.3)  # Average tokens per word

    def detect_multimodal_content(self, content: str) -> bool:
        multimodal_indicators = [
            r'!\[.*\]\(.*\)',  # Images
            r'<iframe',        # Embedded content
            r'<object',        # Embedded objects
            r'<embed',         # Embedded media
        ]

        return any(re.search(pattern, content) for pattern in multimodal_indicators)

    def calculate_avg_sentence_length(self, content: str) -> float:
        sentences = re.split(r'[.!?]+', content)
        sentences = [s.strip() for s in sentences if s.strip()]
        if not sentences:
            return 0
        return sum(len(s.split()) for s in sentences) / len(sentences)

    def calculate_vocabulary_richness(self, content: str) -> float:
        words = content.lower().split()
        if not words:
            return 0
        unique_words = set(words)
        return len(unique_words) / len(words)

    def detect_nested_structure(self, content: str) -> float:
        # Detect nested lists, indented content, etc.
        lines = content.split('\n')
        indented_lines = 0

        for line in lines:
            if line.strip() and line.startswith(' '):
                indented_lines += 1

        return indented_lines / max(len(lines), 1)
```
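
For example, running the analyzer on a small Markdown document (the sample text is hypothetical):

```python
doc = """# RAG Notes

## Chunking

- Fixed-size is the simplest baseline.
- Semantic chunking finds topic boundaries.
"""

analysis = DocumentAnalyzer().analyze(doc)
print(analysis.doc_type)          # 'markdown'
print(analysis.structure_score)   # fraction of lines matching the markdown patterns
print(analysis.estimated_tokens)  # ~1.3 tokens per word
```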

### Strategy Selection Engine

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class ChunkingStrategy(ABC):
    @abstractmethod
    def chunk(self, content: str, analysis: DocumentAnalysis) -> List[Dict[str, Any]]:
        pass


class StrategySelector:
    def __init__(self):
        self.strategies = {
            'fixed_size': FixedSizeStrategy(),
            'recursive': RecursiveStrategy(),
            'structure_aware': StructureAwareStrategy(),
            'semantic': SemanticStrategy(),
            'adaptive': AdaptiveStrategy()
        }

    def select_strategy(self, analysis: DocumentAnalysis) -> str:
        # Rule-based selection logic
        if analysis.structure_score > 0.8 and analysis.doc_type in ['markdown', 'html', 'latex']:
            return 'structure_aware'
        elif analysis.complexity_score > 0.7 and analysis.estimated_tokens < 10000:
            return 'semantic'
        elif analysis.doc_type == 'code':
            return 'structure_aware'
        elif analysis.structure_score < 0.3:
            return 'fixed_size'
        elif analysis.complexity_score > 0.5:
            return 'recursive'
        else:
            return 'adaptive'

    def get_strategy(self, analysis: DocumentAnalysis) -> ChunkingStrategy:
        strategy_name = self.select_strategy(analysis)
        return self.strategies[strategy_name]


# Example strategy implementations
class FixedSizeStrategy(ChunkingStrategy):
    def __init__(self, default_size=512, default_overlap=50):
        self.default_size = default_size
        self.default_overlap = default_overlap

    def chunk(self, content: str, analysis: DocumentAnalysis) -> List[Dict[str, Any]]:
        # Adjust parameters based on analysis
        if analysis.complexity_score > 0.7:
            chunk_size = 1024
        elif analysis.complexity_score < 0.3:
            chunk_size = 256
        else:
            chunk_size = self.default_size

        overlap = int(chunk_size * 0.1)  # 10% overlap

        return self._fixed_size_chunk(content, chunk_size, overlap)

    def _fixed_size_chunk(self, content: str, chunk_size: int, overlap: int) -> List[Dict[str, Any]]:
        # Implementation using RecursiveCharacterTextSplitter or custom logic
        pass


class AdaptiveStrategy(ChunkingStrategy):
    def chunk(self, content: str, analysis: DocumentAnalysis) -> List[Dict[str, Any]]:
        # Combine multiple strategies based on content characteristics
        structured_chunks: List[Dict[str, Any]] = []
        unstructured_chunks: List[Dict[str, Any]] = []

        if analysis.structure_score > 0.6:
            # Use structure-aware for structured parts
            structured_chunks = self._chunk_structured_parts(content, analysis)
        else:
            # Use fixed-size for unstructured parts
            unstructured_chunks = self._chunk_unstructured_parts(content, analysis)

        # Merge and optimize
        return self._merge_chunks(structured_chunks + unstructured_chunks)

    def _chunk_structured_parts(self, content: str, analysis: DocumentAnalysis) -> List[Dict[str, Any]]:
        # Implementation for structured content
        pass

    def _chunk_unstructured_parts(self, content: str, analysis: DocumentAnalysis) -> List[Dict[str, Any]]:
        # Implementation for unstructured content
        pass

    def _merge_chunks(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Implementation for merging and optimizing chunks
        pass
```
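
The `_fixed_size_chunk` stub above could be filled in with a simple word-window splitter; a minimal sketch, counting tokens as whitespace-separated words (which only approximates real tokenizers):

```python
def _fixed_size_chunk(self, content: str, chunk_size: int, overlap: int) -> List[Dict[str, Any]]:
    words = content.split()
    step = max(chunk_size - overlap, 1)  # advance by size minus overlap
    chunks = []
    for start in range(0, len(words), step):
        window = words[start:start + chunk_size]
        if not window:
            break
        chunks.append({
            'content': ' '.join(window),
            'metadata': {'start_word': start, 'num_words': len(window)}
        })
    return chunks
```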

## Quality Assurance Framework

### Chunk Quality Metrics

```python
import re
from typing import Any, Dict, List

import numpy as np


class ChunkQualityAssessor:
    def __init__(self):
        self.quality_weights = {
            'coherence': 0.3,
            'completeness': 0.25,
            'size_appropriateness': 0.2,
            'semantic_similarity': 0.15,
            'boundary_quality': 0.1
        }

    def assess_chunks(self, chunks: List[Dict[str, Any]], analysis: DocumentAnalysis) -> Dict[str, float]:
        scores = {}

        # Coherence: do chunks make sense on their own?
        scores['coherence'] = self._assess_coherence(chunks)

        # Completeness: do chunks preserve important information?
        scores['completeness'] = self._assess_completeness(chunks, analysis)

        # Size appropriateness: are chunks within the optimal size range?
        scores['size_appropriateness'] = self._assess_size(chunks)

        # Semantic similarity: are chunks thematically consistent?
        scores['semantic_similarity'] = self._assess_semantic_consistency(chunks)

        # Boundary quality: are chunk boundaries placed well?
        scores['boundary_quality'] = self._assess_boundary_quality(chunks)

        # Calculate overall quality score
        overall_score = sum(
            score * self.quality_weights[metric]
            for metric, score in scores.items()
        )

        scores['overall'] = overall_score
        return scores

    def _assess_coherence(self, chunks: List[Dict[str, Any]]) -> float:
        # Simple heuristic-based coherence assessment
        coherence_scores = []

        for chunk in chunks:
            content = chunk['content']

            # Check for complete sentences
            sentences = re.split(r'[.!?]+', content)
            complete_sentences = sum(1 for s in sentences if s.strip())
            coherence = complete_sentences / max(len(sentences), 1)

            coherence_scores.append(coherence)

        return np.mean(coherence_scores)

    def _assess_completeness(self, chunks: List[Dict[str, Any]], analysis: DocumentAnalysis) -> float:
        # Check if important structural elements are preserved
        if analysis.doc_type in ['markdown', 'html']:
            return self._assess_structure_preservation(chunks, analysis)
        else:
            return self._assess_content_preservation(chunks)

    def _assess_structure_preservation(self, chunks: List[Dict[str, Any]], analysis: DocumentAnalysis) -> float:
        # Check if headings, lists, and other structural elements are preserved
        preserved_elements = 0
        total_elements = 0

        for chunk in chunks:
            content = chunk['content']

            # Count preserved structural elements
            headings = len(re.findall(r'^#+\s', content, re.MULTILINE))
            lists = len(re.findall(r'^\s*[-*+]\s', content, re.MULTILINE))

            preserved_elements += headings + lists
            total_elements += 1  # Simplified count

        return preserved_elements / max(total_elements, 1)

    def _assess_content_preservation(self, chunks: List[Dict[str, Any]]) -> float:
        # Simple check based on content ratio
        total_content = ''.join(chunk['content'] for chunk in chunks)
        # This would need comparison with the original content
        return 0.8  # Placeholder

    def _assess_size(self, chunks: List[Dict[str, Any]]) -> float:
        optimal_min = 100   # tokens
        optimal_max = 1000  # tokens

        size_scores = []
        for chunk in chunks:
            token_count = self._estimate_tokens(chunk['content'])
            if optimal_min <= token_count <= optimal_max:
                score = 1.0
            elif token_count < optimal_min:
                score = token_count / optimal_min
            else:
                score = max(0, 1 - (token_count - optimal_max) / optimal_max)

            size_scores.append(score)

        return np.mean(size_scores)

    def _assess_semantic_consistency(self, chunks: List[Dict[str, Any]]) -> float:
        # This would require embedding models for an actual implementation
        # Placeholder implementation
        return 0.7

    def _assess_boundary_quality(self, chunks: List[Dict[str, Any]]) -> float:
        # Check that boundaries don't split important content
        boundary_scores = []

        for chunk in chunks:
            content = chunk['content']

            # Check for incomplete sentences at boundaries
            if not content.strip().endswith(('.', '!', '?', '>', '}')):
                boundary_scores.append(0.5)
            else:
                boundary_scores.append(1.0)

        return np.mean(boundary_scores)

    def _estimate_tokens(self, content: str) -> int:
        # Simple token estimation (~4 tokens per 3 words)
        return len(content.split()) * 4 // 3
```

## Error Handling and Edge Cases

### Robust Error Handling

```python
import logging
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class ChunkingError:
    error_type: str
    message: str
    chunk_index: Optional[int] = None
    recovery_action: Optional[str] = None

class ChunkingErrorHandler:
    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.error_handlers = {
            'empty_content': self._handle_empty_content,
            'oversized_chunk': self._handle_oversized_chunk,
            'encoding_error': self._handle_encoding_error,
            'memory_error': self._handle_memory_error,
            'structure_parsing_error': self._handle_structure_parsing_error
        }

    def handle_error(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        error_type = self._classify_error(error)
        handler = self.error_handlers.get(error_type, self._handle_generic_error)
        return handler(error, context)

    def _classify_error(self, error: Exception) -> str:
        if isinstance(error, ValueError) and 'empty' in str(error).lower():
            return 'empty_content'
        elif isinstance(error, MemoryError):
            return 'memory_error'
        elif isinstance(error, UnicodeError):
            return 'encoding_error'
        elif 'too large' in str(error).lower():
            return 'oversized_chunk'
        elif 'parsing' in str(error).lower():
            return 'structure_parsing_error'
        else:
            return 'generic_error'

    def _handle_empty_content(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        self.logger.warning(f"Empty content encountered: {error}")
        return ChunkingError(
            error_type='empty_content',
            message=str(error),
            recovery_action='skip_empty_content'
        )

    def _handle_oversized_chunk(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        self.logger.warning(f"Oversized chunk detected: {error}")
        return ChunkingError(
            error_type='oversized_chunk',
            message=str(error),
            chunk_index=context.get('chunk_index'),
            recovery_action='reduce_chunk_size'
        )

    def _handle_encoding_error(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        self.logger.error(f"Encoding error: {error}")
        return ChunkingError(
            error_type='encoding_error',
            message=str(error),
            recovery_action='fallback_encoding'
        )

    def _handle_memory_error(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        self.logger.error(f"Memory error during chunking: {error}")
        return ChunkingError(
            error_type='memory_error',
            message=str(error),
            recovery_action='process_in_batches'
        )

    def _handle_structure_parsing_error(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        self.logger.warning(f"Structure parsing failed: {error}")
        return ChunkingError(
            error_type='structure_parsing_error',
            message=str(error),
            recovery_action='fallback_to_fixed_size'
        )

    def _handle_generic_error(self, error: Exception, context: Dict[str, Any]) -> ChunkingError:
        self.logger.error(f"Unexpected error during chunking: {error}")
        return ChunkingError(
            error_type='generic_error',
            message=str(error),
            recovery_action='skip_and_continue'
        )
```
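
A minimal usage sketch (assuming a `chunk_document` function of your own) showing how the handler turns exceptions into recovery actions:

```python
# Hypothetical wiring: chunk_document is your own chunking entry point
handler = ChunkingErrorHandler()

def safe_chunk(document: str, index: int):
    try:
        return chunk_document(document)
    except Exception as exc:
        report = handler.handle_error(exc, {'chunk_index': index})
        if report.recovery_action == 'reduce_chunk_size':
            # Retry with a smaller chunk size, or hand off to a fallback splitter
            pass
        return []  # Skip and continue by default
```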

## Performance Optimization

### Caching and Memoization

```python
import hashlib
import json
import logging
import pickle
from typing import Any, Dict, List, Optional

import redis

class ChunkingCache:
    def __init__(self, redis_url: Optional[str] = None):
        if redis_url:
            self.redis_client = redis.from_url(redis_url)
        else:
            self.redis_client = None
        self.local_cache = {}

    def _generate_cache_key(self, content: str, strategy: str, params: Dict[str, Any]) -> str:
        # md5 is fine here: the hash is a cache key, not a security boundary
        content_hash = hashlib.md5(content.encode()).hexdigest()
        params_str = json.dumps(params, sort_keys=True)
        params_hash = hashlib.md5(params_str.encode()).hexdigest()
        return f"chunking:{strategy}:{content_hash}:{params_hash}"

    def get(self, content: str, strategy: str, params: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
        cache_key = self._generate_cache_key(content, strategy, params)

        # Try local cache first
        if cache_key in self.local_cache:
            return self.local_cache[cache_key]

        # Try Redis cache
        if self.redis_client:
            try:
                cached_data = self.redis_client.get(cache_key)
                if cached_data:
                    chunks = pickle.loads(cached_data)
                    self.local_cache[cache_key] = chunks  # Cache locally too
                    return chunks
            except Exception as e:
                logging.warning(f"Redis cache error: {e}")

        return None

    def set(self, content: str, strategy: str, params: Dict[str, Any], chunks: List[Dict[str, Any]]) -> None:
        cache_key = self._generate_cache_key(content, strategy, params)

        # Store in local cache
        self.local_cache[cache_key] = chunks

        # Store in Redis cache
        if self.redis_client:
            try:
                cached_data = pickle.dumps(chunks)
                self.redis_client.setex(cache_key, 3600, cached_data)  # 1 hour TTL
            except Exception as e:
                logging.warning(f"Redis cache set error: {e}")

    def clear_local_cache(self):
        self.local_cache.clear()

    def clear_redis_cache(self):
        if self.redis_client:
            pattern = "chunking:*"
            keys = self.redis_client.keys(pattern)
            if keys:
                self.redis_client.delete(*keys)
```
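
A quick usage sketch (the `chunker` object and `text` are assumed; the Redis URL is illustrative):

```python
cache = ChunkingCache()  # Or ChunkingCache("redis://localhost:6379/0")
params = {"chunk_size": 512, "chunk_overlap": 50}

chunks = cache.get(text, "recursive", params)
if chunks is None:
    chunks = chunker.chunk(text)  # Cache miss: chunk and store
    cache.set(text, "recursive", params, chunks)
```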

### Batch Processing

```python
import asyncio
import concurrent.futures
import logging
from typing import Any, Callable, Dict, List

class BatchChunkingProcessor:
    def __init__(self, max_workers: int = 4, batch_size: int = 10):
        self.max_workers = max_workers
        self.batch_size = batch_size

    def process_documents_batch(self, documents: List[str],
                                chunking_function: Callable[[str], List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
        """Process multiple documents in parallel, preserving input order"""
        results = []

        # Process in batches to avoid memory issues
        for i in range(0, len(documents), self.batch_size):
            batch = documents[i:i + self.batch_size]

            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = [executor.submit(chunking_function, doc) for doc in batch]

                batch_results = []
                # Collect in submission order so results stay aligned with inputs
                for future in futures:
                    try:
                        chunks = future.result()
                        batch_results.append(chunks)
                    except Exception as e:
                        logging.error(f"Error processing document: {e}")
                        batch_results.append([])  # Empty result for failed processing

            results.extend(batch_results)

        return results

    async def process_documents_async(self, documents: List[str],
                                      chunking_function: Callable[[str], List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
        """Process documents asynchronously"""
        semaphore = asyncio.Semaphore(self.max_workers)

        async def process_single_document(doc: str) -> List[Dict[str, Any]]:
            async with semaphore:
                # Run the synchronous chunking function in an executor
                loop = asyncio.get_event_loop()
                return await loop.run_in_executor(None, chunking_function, doc)

        tasks = [process_single_document(doc) for doc in documents]
        # return_exceptions=True keeps one failed document from cancelling the rest
        return await asyncio.gather(*tasks, return_exceptions=True)
```
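
A usage sketch (assuming a `documents` list and a `chunk_one` function of your own):

```python
processor = BatchChunkingProcessor(max_workers=4, batch_size=10)

# Thread-pooled batch processing; results align with the input order
all_chunks = processor.process_documents_batch(documents, chunk_one)

# Asynchronous variant
# all_chunks = asyncio.run(processor.process_documents_async(documents, chunk_one))
```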

## Monitoring and Observability

### Metrics Collection

```python
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict

@dataclass
class ChunkingMetrics:
    total_documents: int
    total_chunks: int
    avg_chunk_size: float
    processing_time: float
    memory_usage: float
    error_count: int
    strategy_distribution: Dict[str, int]

class MetricsCollector:
    def __init__(self):
        self.metrics = defaultdict(list)
        self.strategy_counts = defaultdict(int)  # Counters kept separate from list metrics
        self.start_time = None

    def start_timing(self):
        self.start_time = time.time()

    def end_timing(self) -> float:
        if self.start_time:
            duration = time.time() - self.start_time
            self.metrics['processing_time'].append(duration)
            self.start_time = None
            return duration
        return 0.0

    def record_chunk_count(self, count: int):
        self.metrics['chunk_count'].append(count)

    def record_chunk_size(self, size: int):
        self.metrics['chunk_size'].append(size)

    def record_strategy_usage(self, strategy: str):
        self.strategy_counts[strategy] += 1

    def record_error(self, error_type: str):
        self.metrics['errors'].append(error_type)

    def record_memory_usage(self, memory_mb: float):
        self.metrics['memory_usage'].append(memory_mb)

    def get_summary(self) -> ChunkingMetrics:
        return ChunkingMetrics(
            total_documents=len(self.metrics['processing_time']),
            total_chunks=sum(self.metrics['chunk_count']),
            avg_chunk_size=sum(self.metrics['chunk_size']) / max(len(self.metrics['chunk_size']), 1),
            processing_time=sum(self.metrics['processing_time']),
            memory_usage=sum(self.metrics['memory_usage']) / max(len(self.metrics['memory_usage']), 1),
            error_count=len(self.metrics['errors']),
            strategy_distribution=dict(self.strategy_counts)
        )

    def reset(self):
        self.metrics.clear()
        self.strategy_counts.clear()
        self.start_time = None
```
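
A usage sketch (assuming `documents` and a `chunk_document` function of your own):

```python
collector = MetricsCollector()

for doc in documents:
    collector.start_timing()
    chunks = chunk_document(doc)
    collector.end_timing()
    collector.record_chunk_count(len(chunks))
    collector.record_strategy_usage("recursive")

summary = collector.get_summary()
print(f"{summary.total_chunks} chunks across {summary.total_documents} documents")
```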

This implementation guide provides a comprehensive foundation for building robust, scalable chunking systems that can handle various document types and use cases while maintaining high quality and performance.

366
skills/ai/chunking-strategy/references/research.md
Normal file
@@ -0,0 +1,366 @@

# Key Research Papers and Findings

This document summarizes important research papers and findings related to chunking strategies for RAG systems.

## Seminal Papers

### "Reconstructing Context: Evaluating Advanced Chunking Strategies for RAG" (arXiv:2504.19754)

**Key Findings**:
- Page-level chunking achieved the highest average accuracy (0.648) with the lowest variance across different query types
- Optimal chunk size varies significantly by document type and query complexity
- Factoid queries perform better with smaller chunks (256-512 tokens)
- Complex analytical queries benefit from larger chunks (1024+ tokens)

**Methodology**:
- Evaluated 7 different chunking strategies across multiple document types
- Tested with both factoid and analytical queries
- Measured end-to-end RAG performance

**Practical Implications**:
- Start with page-level chunking for general-purpose RAG systems
- Adapt chunk size based on expected query patterns
- Consider hybrid approaches for mixed query types

### "Lost in the Middle: How Language Models Use Long Contexts"

**Key Findings**:
- Language models tend to pay more attention to information at the beginning and end of the context
- Information in the middle of long contexts is often ignored
- Performance degradation is most severe for centrally located information

**Practical Implications**:
- Place the most important information at chunk boundaries
- Consider chunk overlap so that important context appears multiple times
- Use ranking to prioritize relevant chunks for inclusion in context, as in the sketch below
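
One way to act on this finding is to reorder retrieved chunks so the highest-ranked ones sit at the edges of the assembled context. A minimal sketch, assuming chunks arrive sorted best-first:

```python
from typing import List

def order_for_long_context(ranked_chunks: List[str]) -> List[str]:
    """Alternate chunks between the front and back of the context so the
    weakest chunks end up in the middle, where attention is lowest."""
    front, back = [], []
    for i, chunk in enumerate(ranked_chunks):
        (front if i % 2 == 0 else back).append(chunk)
    return front + back[::-1]

# ranked_chunks[0] lands first, ranked_chunks[1] lands last, and so on.
```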

### "Grounded Language Learning in a Simulated 3D World"

**Related Concepts**:
- Importance of grounding text in visual/contextual information
- Multi-modal learning approaches for better understanding

**Relevance to Chunking**:
- Supports contextual chunking approaches that preserve visual/contextual relationships
- Validates the importance of maintaining document structure and relationships

## Industry Research

### NVIDIA Research: "Finding the Best Chunking Strategy for Accurate AI Responses"

**Key Findings**:
- Page-level chunking outperformed sentence- and paragraph-level approaches
- Fixed-size chunking showed consistent but suboptimal performance
- Semantic chunking provided improvements for complex documents

**Technical Details**:
- Tested chunk sizes from 128 to 2048 tokens
- Evaluated across financial, technical, and legal documents
- Measured both retrieval accuracy and generation quality

**Recommendations**:
- Use 512-1024 token chunks as a starting point
- Implement adaptive chunking based on document complexity
- Consider page boundaries as natural chunk separators

### Cohere Research: "Effective Chunking Strategies for RAG"

**Key Findings**:
- Recursive character splitting provides a good balance of performance and simplicity
- Document structure awareness improves retrieval by 15-20%
- Overlap of 10-20% provides optimal context preservation

**Methodology**:
- Compared 12 chunking strategies across 6 document types
- Measured retrieval precision, recall, and F1-score
- Tested with both dense and sparse retrieval

**Best Practices Identified**:
- Start with recursive character splitting with 10-20% overlap
- Preserve document structure (headings, lists, tables)
- Customize chunk size based on the embedding model's context window

### Anthropic: "Contextual Retrieval"

**Key Innovation**:
- Enhance each chunk with LLM-generated contextual information before embedding
- Improves retrieval precision by 25-30% for complex documents
- Particularly effective for technical and academic content

**Implementation Approach**:
1. Split the document using traditional methods
2. For each chunk, generate contextual information using an LLM
3. Prepend the context to the chunk before embedding
4. Use hybrid search (dense + sparse) with weighted ranking

**Trade-offs**:
- Significant computational overhead (2-3x processing time)
- Higher embedding storage requirements
- Improved retrieval precision justifies the cost for high-value applications

## Algorithmic Advances

### Semantic Chunking Algorithms

#### "Semantic Segmentation of Text Documents"

**Core Idea**: Use cosine similarity between consecutive sentence embeddings to identify natural boundaries.

**Algorithm**:
1. Split the document into sentences
2. Generate embeddings for each sentence
3. Calculate similarity between consecutive sentences
4. Create boundaries where similarity drops below a threshold
5. Merge short segments with neighbors (see the sketch after this list)
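
A compact sketch of this algorithm, including the merge step (the model name and `min_len` threshold are illustrative):

```python
from sentence_transformers import SentenceTransformer, util

def semantic_segments(sentences, threshold=0.75, min_len=3):
    """Boundaries where consecutive-sentence similarity drops below the
    threshold; segments shorter than min_len are merged into their neighbor."""
    model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode(sentences)

    segments, current = [], [sentences[0]]
    for i in range(1, len(sentences)):
        if util.cos_sim(embeddings[i - 1], embeddings[i]).item() < threshold:
            segments.append(current)
            current = []
        current.append(sentences[i])
    segments.append(current)

    merged = [segments[0]]
    for seg in segments[1:]:
        if len(seg) < min_len:
            merged[-1].extend(seg)  # Merge short segments backward
        else:
            merged.append(seg)
    return [" ".join(seg) for seg in merged]
```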

**Performance**: 20-30% improvement in retrieval relevance over fixed-size chunking for technical documents.

#### "Hierarchical Semantic Chunking"

**Core Idea**: Multi-level semantic segmentation for document organization.

**Algorithm**:
1. Document-level semantic analysis
2. Section-level boundary detection
3. Paragraph-level segmentation
4. Sentence-level refinement

**Benefits**: Maintains document hierarchy while adapting to semantic structure.

### Advanced Embedding Techniques

#### "Late Chunking: Contextual Chunk Embeddings"

**Core Innovation**: Generate embeddings for the entire document first, then create chunk embeddings from the token-level embeddings.

**Advantages**:
- Preserves global document context
- Reduces context fragmentation
- Better for documents with complex inter-relationships

**Requirements**:
- Long-context embedding models (8k+ tokens)
- Significant computational resources
- Specialized implementation

#### "Hierarchical Embedding Retrieval"

**Approach**: Create embeddings at multiple granularities (document, section, paragraph, sentence).

**Implementation**:
1. Generate embeddings at each level
2. Store them in a hierarchical vector database
3. Query at the appropriate granularity based on information needs (see the sketch after this list)
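
An illustrative sketch of the multi-granularity idea (the splitting heuristics and model are placeholders, not the paper's method):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

def build_hierarchical_index(document: str) -> dict:
    levels = {
        "document": [document],
        "paragraph": [p for p in document.split("\n\n") if p.strip()],
        "sentence": [s.strip() + "." for s in document.split(".") if s.strip()],
    }
    # One (text, embedding) list per granularity level
    return {
        level: list(zip(texts, model.encode(texts)))
        for level, texts in levels.items()
    }
```

A factoid query would then search the sentence level, while a summary-style query would search the paragraph or document level.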

**Performance**: 15-25% improvement in precision for complex queries.

## Evaluation Methodologies

### Retrieval-Augmented Generation Assessment Frameworks

#### RAGAS Framework

**Metrics**:
- **Faithfulness**: Consistency between the generated answer and the retrieved context
- **Answer Relevancy**: Relevance of the generated answer to the question
- **Context Relevancy**: Relevance of the retrieved context to the question
- **Context Recall**: Coverage of relevant information in the retrieved context

**Evaluation Process**:
1. Generate questions from the document corpus
2. Retrieve relevant chunks using different strategies
3. Generate answers using the retrieved chunks
4. Evaluate using automated metrics and human judgment

#### ARES Framework

**Innovation**: Automated evaluation using synthetic questions and LLM-based assessment.

**Key Features**:
- Generates diverse question types (factoid, analytical, comparative)
- Uses LLMs to evaluate answer quality
- Provides scalable evaluation without human annotation

### Benchmark Datasets

#### Natural Questions (NQ)

**Description**: Real user questions from Google Search with relevant Wikipedia passages.

**Relevance**: Natural language queries with authentic relevance judgments.

#### MS MARCO

**Description**: Large-scale passage ranking dataset with real search queries.

**Relevance**: High-quality relevance judgments for passage retrieval.

#### HotpotQA

**Description**: Multi-hop question answering requiring information from multiple documents.

**Relevance**: Tests the ability to retrieve and synthesize information from multiple chunks.

## Domain-Specific Research

### Medical Documents

#### "Optimal Chunking for Medical Question Answering"

**Key Findings**:
- Medical terminology requires specialized handling
- Section-based chunking (History, Diagnosis, Treatment) is most effective
- Preserving doctor-patient dialogue context is crucial

**Recommendations**:
- Use medical-specific tokenizers
- Preserve section headers and structure
- Maintain temporal relationships in medical histories

### Legal Documents

#### "Chunking Strategies for Legal Document Analysis"

**Key Findings**:
- Legal citations and cross-references require special handling
- Contract clause boundaries serve as natural chunk separators
- Case law benefits from hierarchical chunking

**Best Practices**:
- Preserve legal citation structure
- Use clause and section boundaries
- Maintain context for legal definitions and references

### Financial Documents

#### "SEC Filing Chunking for Financial Analysis"

**Key Findings**:
- Table preservation is critical for financial data
- XBRL tagging provides natural segmentation
- Risk factors sections benefit from specialized treatment

**Approach**:
- Preserve complete tables when possible
- Use XBRL tags for structured data
- Create specialized chunks for risk sections

## Emerging Trends

### Multi-Modal Chunking

#### "Integrating Text, Tables, and Images in RAG Systems"

**Innovation**: Unified chunking approach for mixed-modal content.

**Approach**:
- Extract and describe images using vision models
- Preserve table structure and relationships
- Create unified embeddings for mixed content

**Results**: 35% improvement in complex document understanding.

### Adaptive Chunking

#### "Machine Learning-Based Chunk Size Optimization"

**Core Idea**: Use ML models to predict optimal chunking parameters.

**Features**:
- Document length and complexity
- Query type distribution
- Embedding model characteristics
- Performance requirements (an illustrative heuristic follows below)
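
As an illustration only (this heuristic is not from the paper), such a predictor can be as simple as a rule over a few document and workload features:

```python
def suggest_chunk_size(doc_tokens: int, avg_sentence_tokens: float,
                       factoid_query_share: float) -> int:
    """Toy heuristic: smaller chunks for factoid-heavy workloads,
    larger chunks for long, discursive documents."""
    if factoid_query_share > 0.6:
        return 256
    if doc_tokens > 20_000 and avg_sentence_tokens > 25:
        return 1024
    return 512
```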

**Benefits**: Dynamic optimization based on use case and content.

### Real-time Chunking

#### "Streaming Chunking for Live Document Processing"

**Innovation**: Process documents as they become available.

**Techniques**:
- Incremental boundary detection
- Dynamic chunk size adjustment
- Context preservation across chunks

**Applications**: Live news feeds, social media analysis, meeting transcripts.

## Implementation Challenges

### Computational Efficiency

#### "Scalable Chunking for Large Document Collections"

**Challenges**:
- Processing millions of documents efficiently
- Memory usage optimization
- Distributed processing requirements

**Solutions**:
- Batch processing with parallel execution
- Streaming approaches for large documents
- Distributed chunking with load balancing

### Quality Assurance

#### "Evaluating Chunk Quality at Scale"

**Challenges**:
- Automated quality assessment
- Detecting poor chunk boundaries
- Maintaining consistency across document types

**Approaches**:
- Heuristic-based quality metrics
- LLM-based evaluation
- Human-in-the-loop validation

## Future Research Directions

### Context-Aware Chunking

**Open Questions**:
- How can cross-chunk relationships be optimally preserved?
- Can chunk quality be predicted without human evaluation?
- What is the optimal balance between size and context?

### Domain Adaptation

**Research Areas**:
- Automatic domain detection and adaptation
- Transfer learning across domains
- Zero-shot chunking for new document types

### Evaluation Standards

**Needs**:
- Standardized evaluation benchmarks
- Cross-paper comparison methodologies
- Real-world performance metrics

## Practical Recommendations Based on Research

### Starting Points

1. **For General RAG Systems**: Page-level or recursive character chunking with 512-1024 tokens and 10-20% overlap
2. **For Technical Documents**: Structure-aware chunking with semantic boundary detection
3. **For High-Value Applications**: Contextual retrieval with LLM-generated context

### Evolution Strategy

1. **Begin**: Simple fixed-size chunking (512 tokens)
2. **Improve**: Add document structure awareness
3. **Optimize**: Implement semantic boundaries
4. **Advanced**: Consider contextual retrieval for critical use cases

### Key Success Factors

1. **Match strategy to document type and query patterns**
2. **Preserve document structure when beneficial**
3. **Use overlap to maintain context across boundaries**
4. **Monitor both accuracy and computational costs**
5. **Iterate based on specific use case requirements**

This research foundation provides evidence-based guidance for implementing effective chunking strategies across various domains and use cases.

1315
skills/ai/chunking-strategy/references/semantic-methods.md
Normal file
File diff suppressed because it is too large

423
skills/ai/chunking-strategy/references/strategies.md
Normal file
@@ -0,0 +1,423 @@
# Detailed Chunking Strategies

This document provides comprehensive implementation details for all chunking strategies mentioned in the main skill.

## Level 1: Fixed-Size Chunking

### Implementation

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

class FixedSizeChunker:
    def __init__(self, chunk_size=512, chunk_overlap=50):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", " ", ""]
        )

    def chunk(self, documents):
        return self.splitter.split_documents(documents)
```

### Parameter Recommendations

| Use Case | Chunk Size | Overlap | Rationale |
|----------|------------|---------|-----------|
| Factoid Queries | 256 | 25 | Small chunks for precise answers |
| General Q&A | 512 | 50 | Balanced approach for most cases |
| Analytical Queries | 1024 | 100 | Larger context for complex analysis |
| Code Documentation | 300 | 30 | Preserve code context while maintaining focus |

### Best Practices

- Start with 512 tokens and 10-20% overlap
- Adjust based on the embedding model's context window
- Use overlap for queries where context might span boundaries
- Monitor token count vs. character count based on the model (see the tiktoken-based sketch below)

## Level 2: Recursive Character Chunking

### Implementation

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

class RecursiveChunker:
    def __init__(self, chunk_size=512, separators=None):
        self.chunk_size = chunk_size
        self.separators = separators or ["\n\n", "\n", " ", ""]
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=0,
            length_function=len,
            separators=self.separators
        )

    def chunk(self, text):
        return self.splitter.create_documents([text])

# Document-specific configurations
def get_chunker_for_document_type(doc_type):
    configurations = {
        "markdown": ["\n## ", "\n### ", "\n\n", "\n", " ", ""],
        "html": ["</div>", "</p>", "\n\n", "\n", " ", ""],
        "code": ["\n\n", "\n", " ", ""],
        "plain": ["\n\n", "\n", " ", ""]
    }
    return RecursiveChunker(separators=configurations.get(doc_type, ["\n\n", "\n", " ", ""]))
```

### Customization Guidelines

- **Markdown**: Use headings as primary separators
- **HTML**: Use block-level tags as separators
- **Code**: Preserve function and class boundaries
- **Academic papers**: Prioritize paragraph and section breaks

## Level 3: Structure-Aware Chunking

### Markdown Documents

```python
import markdown
from bs4 import BeautifulSoup

class MarkdownChunker:
    def __init__(self, max_chunk_size=512):
        self.max_chunk_size = max_chunk_size

    def chunk(self, markdown_text):
        html = markdown.markdown(markdown_text)
        soup = BeautifulSoup(html, 'html.parser')

        chunks = []
        current_chunk = ""
        current_heading = "Introduction"

        for element in soup.find_all(['h1', 'h2', 'h3', 'p', 'pre', 'table']):
            if element.name.startswith('h'):
                # A new heading closes the current chunk
                if current_chunk.strip():
                    chunks.append({
                        "content": current_chunk.strip(),
                        "heading": current_heading
                    })
                current_heading = element.get_text().strip()
                current_chunk = f"{element}\n"
            elif element.name in ['pre', 'table']:
                # Preserve code blocks and tables intact
                if len(current_chunk) + len(str(element)) > self.max_chunk_size:
                    if current_chunk.strip():
                        chunks.append({
                            "content": current_chunk.strip(),
                            "heading": current_heading
                        })
                    current_chunk = f"{element}\n"
                else:
                    current_chunk += f"{element}\n"
            else:
                current_chunk += str(element)

        if current_chunk.strip():
            chunks.append({
                "content": current_chunk.strip(),
                "heading": current_heading
            })

        return chunks
```

### Code Documents

```python
import ast
import re

class CodeChunker:
    def __init__(self, language='python'):
        self.language = language

    def chunk_python(self, code):
        tree = ast.parse(code)
        chunks = []

        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
                start_line = node.lineno - 1
                end_line = node.end_lineno if hasattr(node, 'end_lineno') else start_line + 10
                lines = code.split('\n')
                chunk_lines = lines[start_line:end_line]
                chunks.append('\n'.join(chunk_lines))

        return chunks

    def chunk_javascript(self, code):
        # Use regex for languages without AST parsers
        function_pattern = r'(function\s+\w+\s*\([^)]*\)\s*\{[^}]*\})'
        class_pattern = r'(class\s+\w+\s*\{[^}]*\})'

        patterns = [function_pattern, class_pattern]
        chunks = []

        for pattern in patterns:
            matches = re.finditer(pattern, code, re.MULTILINE | re.DOTALL)
            for match in matches:
                chunks.append(match.group(1))

        return chunks

    def chunk(self, code):
        if self.language == 'python':
            return self.chunk_python(code)
        elif self.language == 'javascript':
            return self.chunk_javascript(code)
        else:
            # Fallback to line-based chunking
            return self.chunk_by_lines(code)

    def chunk_by_lines(self, code, max_lines=50):
        lines = code.split('\n')
        chunks = []

        for i in range(0, len(lines), max_lines):
            chunk = '\n'.join(lines[i:i+max_lines])
            chunks.append(chunk)

        return chunks
```

### Tabular Data

```python
from io import StringIO

import pandas as pd

class TableChunker:
    def __init__(self, max_rows=100, summary_rows=5):
        self.max_rows = max_rows
        self.summary_rows = summary_rows

    def chunk(self, table_data):
        if isinstance(table_data, str):
            df = pd.read_csv(StringIO(table_data))
        else:
            df = table_data

        chunks = []

        if len(df) <= self.max_rows:
            # Small table - keep intact
            chunks.append({
                "type": "full_table",
                "content": df.to_string(),
                "metadata": {
                    "rows": len(df),
                    "columns": len(df.columns)
                }
            })
        else:
            # Large table - create summary + chunks
            summary = df.head(self.summary_rows)
            chunks.append({
                "type": "table_summary",
                "content": f"Table Summary ({len(df)} rows, {len(df.columns)} columns):\n{summary.to_string()}",
                "metadata": {
                    "total_rows": len(df),
                    "summary_rows": self.summary_rows,
                    "columns": list(df.columns)
                }
            })

            # Chunk the remaining data
            for i in range(self.summary_rows, len(df), self.max_rows):
                chunk_df = df.iloc[i:i+self.max_rows]
                chunks.append({
                    "type": "table_chunk",
                    "content": f"Rows {i+1}-{min(i+self.max_rows, len(df))}:\n{chunk_df.to_string()}",
                    "metadata": {
                        "start_row": i + 1,
                        "end_row": min(i + self.max_rows, len(df)),
                        "columns": list(df.columns)
                    }
                })

        return chunks
```

## Level 4: Semantic Chunking

### Implementation

```python
import re

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

class SemanticChunker:
    def __init__(self, model_name="all-MiniLM-L6-v2", similarity_threshold=0.8, buffer_size=3):
        self.model = SentenceTransformer(model_name)
        self.similarity_threshold = similarity_threshold
        self.buffer_size = buffer_size

    def split_into_sentences(self, text):
        # Simple sentence splitting - can be enhanced with nltk/spacy
        sentences = re.split(r'[.!?]+', text)
        return [s.strip() for s in sentences if s.strip()]

    def chunk(self, text):
        sentences = self.split_into_sentences(text)

        if len(sentences) <= self.buffer_size:
            return [text]

        # Create embeddings
        embeddings = self.model.encode(sentences)

        chunks = []
        current_chunk_sentences = []

        for i in range(len(sentences)):
            current_chunk_sentences.append(sentences[i])

            # Check if we should create a boundary
            if i < len(sentences) - 1:
                similarity = cosine_similarity(
                    [embeddings[i]],
                    [embeddings[i + 1]]
                )[0][0]

                if similarity < self.similarity_threshold and len(current_chunk_sentences) >= 2:
                    chunks.append(' '.join(current_chunk_sentences))
                    current_chunk_sentences = []

        # Add remaining sentences
        if current_chunk_sentences:
            chunks.append(' '.join(current_chunk_sentences))

        return chunks
```

### Parameter Tuning

| Parameter | Range | Effect |
|-----------|-------|--------|
| similarity_threshold | 0.5-0.9 | Higher values create more chunks |
| buffer_size | 1-10 | Larger buffers provide more context |
| model_name | Various | Different models suit different domains |

### Optimization Tips

- Use domain-specific models for specialized content
- Adjust the threshold based on content complexity
- Cache embeddings for repeated processing
- Consider batch processing for large documents

## Level 5: Advanced Contextual Methods

### Late Chunking

```python
import torch
from transformers import AutoTokenizer, AutoModel

class LateChunker:
    # Note: the model below is illustrative; late chunking works best with a
    # long-context embedding model (8k+ tokens)
    def __init__(self, model_name="microsoft/DialoGPT-medium"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def chunk(self, text, chunk_size=512):
        # Tokenize the entire document
        tokens = self.tokenizer(text, return_tensors="pt", truncation=False)

        # Get token-level embeddings
        with torch.no_grad():
            outputs = self.model(**tokens, output_hidden_states=True)
            token_embeddings = outputs.last_hidden_state[0]

        # Create chunk embeddings from token embeddings
        chunks = []
        for i in range(0, len(token_embeddings), chunk_size):
            chunk_tokens = token_embeddings[i:i+chunk_size]
            chunk_embedding = torch.mean(chunk_tokens, dim=0)
            chunks.append({
                "content": self.tokenizer.decode(tokens["input_ids"][0][i:i+chunk_size]),
                "embedding": chunk_embedding.numpy()
            })

        return chunks
```

### Contextual Retrieval

```python
import openai

class ContextualChunker:
    def __init__(self, api_key):
        self.client = openai.OpenAI(api_key=api_key)

    def generate_context(self, chunk, full_document):
        prompt = f"""
        Given the following document and a chunk from it, provide a brief context
        that helps understand the chunk's meaning within the full document.

        Document:
        {full_document[:2000]}...

        Chunk:
        {chunk}

        Context (max 50 words):
        """

        response = self.client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=100,
            temperature=0
        )

        return response.choices[0].message.content.strip()

    def chunk_with_context(self, text, base_chunker):
        # First create base chunks
        base_chunks = base_chunker.chunk(text)

        # Then add context to each chunk
        contextualized_chunks = []
        for chunk in base_chunks:
            context = self.generate_context(chunk.page_content, text)
            contextualized_content = f"Context: {context}\n\nContent: {chunk.page_content}"

            contextualized_chunks.append({
                "content": contextualized_content,
                "original_content": chunk.page_content,
                "context": context
            })

        return contextualized_chunks
```

## Performance Considerations

### Computational Cost Analysis

| Strategy | Time Complexity | Space Complexity | Relative Cost |
|----------|-----------------|------------------|---------------|
| Fixed-Size | O(n) | O(n) | Low |
| Recursive | O(n) | O(n) | Low |
| Structure-Aware | O(n log n) | O(n) | Medium |
| Semantic | O(n²) | O(n²) | High |
| Late Chunking | O(n) | O(n) | Very High |
| Contextual | O(n²) | O(n²) | Very High |

### Optimization Strategies

1. **Parallel Processing**: Process chunks concurrently when possible
2. **Caching**: Store embeddings and intermediate results
3. **Batch Operations**: Group similar operations together
4. **Progressive Loading**: Process large documents in streaming fashion (see the sketch after this list)
5. **Model Selection**: Choose appropriate models for task complexity
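
A minimal sketch of progressive loading: chunks are emitted while the file is being read, so the whole document never has to sit in memory (the boundary heuristic is illustrative):

```python
from typing import Iterable, Iterator

def stream_chunks(lines: Iterable[str], max_chars: int = 2000) -> Iterator[str]:
    """Prefer paragraph breaks as boundaries; max_chars is a hard limit."""
    buffer = ""
    for line in lines:
        buffer += line
        at_paragraph_break = line.strip() == "" and len(buffer) > max_chars // 2
        if at_paragraph_break or len(buffer) >= max_chars:
            yield buffer.strip()
            buffer = ""
    if buffer.strip():
        yield buffer.strip()

# with open("large_document.txt") as f:
#     for chunk in stream_chunks(f):
#         index(chunk)  # index() is your own sink
```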

867
skills/ai/chunking-strategy/references/tools.md
Normal file
@@ -0,0 +1,867 @@
# Recommended Libraries and Frameworks

This document provides a comprehensive guide to tools, libraries, and frameworks for implementing chunking strategies.

## Core Chunking Libraries

### LangChain

**Overview**: Comprehensive framework for building applications with large language models; includes robust text splitting utilities.

**Installation**:
```bash
pip install langchain langchain-text-splitters
```

**Key Features**:
- Multiple text splitting strategies
- Integration with various document loaders
- Support for different content types (code, markdown, etc.)
- Customizable separators and parameters

**Example Usage**:

```python
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    CharacterTextSplitter,
    TokenTextSplitter,
    MarkdownTextSplitter,
    PythonCodeTextSplitter
)

# Basic recursive splitting
splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
    length_function=len,
    separators=["\n\n", "\n", " ", ""]
)

chunks = splitter.split_text(large_text)

# Markdown-specific splitting
markdown_splitter = MarkdownTextSplitter(
    chunk_size=1000,
    chunk_overlap=100
)

# Code-specific splitting
code_splitter = PythonCodeTextSplitter(
    chunk_size=1000,
    chunk_overlap=100
)
```

**Pros**:
- Well-maintained and actively developed
- Extensive documentation and examples
- Integrates well with other LangChain components
- Supports multiple document types

**Cons**:
- Can be a heavy dependency for simple use cases
- Some advanced features require the LangChain ecosystem

### LlamaIndex

**Overview**: Data framework for LLM applications with advanced indexing and retrieval capabilities.

**Installation**:
```bash
pip install llama-index
```

**Key Features**:
- Advanced semantic chunking
- Hierarchical indexing
- Context-aware retrieval
- Integration with vector databases

**Example Usage**:

```python
from llama_index.core.node_parser import (
    SentenceSplitter,
    SemanticSplitterNodeParser
)
from llama_index.core import SimpleDirectoryReader
from llama_index.embeddings.openai import OpenAIEmbedding

# Basic sentence splitting
splitter = SentenceSplitter(
    chunk_size=1024,
    chunk_overlap=20
)

# Semantic chunking with embeddings
embed_model = OpenAIEmbedding()
semantic_splitter = SemanticSplitterNodeParser(
    buffer_size=1,
    breakpoint_percentile_threshold=95,
    embed_model=embed_model
)

# Load and process documents
documents = SimpleDirectoryReader("./data").load_data()
nodes = semantic_splitter.get_nodes_from_documents(documents)
```

**Pros**:
- Excellent semantic chunking capabilities
- Built for production RAG systems
- Strong vector database integration
- Active community support

**Cons**:
- More complex setup for basic use cases
- Semantic chunking requires an embedding model setup

### Unstructured

**Overview**: Open-source library for processing unstructured documents, especially strong with multi-modal content.

**Installation**:
```bash
pip install "unstructured[pdf,png,jpg]"
```

**Key Features**:
- Multi-modal document processing
- Support for PDFs, images, and various formats
- Structure preservation
- Table extraction and processing

**Example Usage**:

```python
from unstructured.partition.auto import partition
from unstructured.chunking.title import chunk_by_title

# Partition document by type
elements = partition(filename="document.pdf")

# Chunk by title/heading structure
chunks = chunk_by_title(
    elements,
    combine_text_under_n_chars=2000,
    max_characters=10000,
    new_after_n_chars=1500,
    multipage_sections=True
)

# Access chunked content
for chunk in chunks:
    print(f"Category: {chunk.category}")
    print(f"Content: {chunk.text[:200]}...")
```

**Pros**:
- Excellent for PDF and image processing
- Preserves document structure
- Handles tables and figures well
- Strong multi-modal capabilities

**Cons**:
- Can be slower for large documents
- Requires additional dependencies for some formats

## Text Processing Libraries

### NLTK (Natural Language Toolkit)

**Installation**:
```bash
pip install nltk
```

**Key Features**:
- Sentence tokenization
- Language detection
- Text preprocessing
- Linguistic analysis

**Example Usage**:

```python
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords

# Download required data
nltk.download('punkt')
nltk.download('stopwords')

# Sentence and word tokenization
text = "This is a sample sentence. This is another sentence."
sentences = sent_tokenize(text)
words = word_tokenize(text)

# Stop words removal
stop_words = set(stopwords.words('english'))
filtered_words = [word for word in words if word.lower() not in stop_words]
```

### spaCy

**Installation**:
```bash
pip install spacy
python -m spacy download en_core_web_sm
```

**Key Features**:
- Industrial-strength NLP
- Named entity recognition
- Dependency parsing
- Sentence boundary detection

**Example Usage**:

```python
import spacy

# Load language model
nlp = spacy.load("en_core_web_sm")

# Process text
doc = nlp("This is a sample sentence. This is another sentence.")

# Extract sentences
sentences = [sent.text for sent in doc.sents]

# Named entities
entities = [(ent.text, ent.label_) for ent in doc.ents]

# Dependency parsing for better chunking
for token in doc:
    print(f"{token.text}: {token.dep_} (head: {token.head.text})")
```

### Sentence Transformers

**Installation**:
```bash
pip install sentence-transformers
```

**Key Features**:
- Pre-trained sentence embeddings
- Semantic similarity calculation
- Multi-lingual support
- Custom model training

**Example Usage**:

```python
from sentence_transformers import SentenceTransformer, util

# Load pre-trained model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Generate embeddings
sentences = ["This is a sentence.", "This is another sentence."]
embeddings = model.encode(sentences)

# Calculate semantic similarity
similarity = util.cos_sim(embeddings[0], embeddings[1])

# Find semantic boundaries for chunking
def find_semantic_boundaries(text, model, threshold=0.8):
    sentences = [s.strip() for s in text.split('.') if s.strip()]
    embeddings = model.encode(sentences)

    boundaries = [0]
    for i in range(1, len(sentences)):
        similarity = util.cos_sim(embeddings[i-1], embeddings[i])
        if similarity < threshold:
            boundaries.append(i)

    return boundaries
```

## Vector Databases and Search

### ChromaDB

**Installation**:
```bash
pip install chromadb
```

**Key Features**:
- In-memory and persistent storage
- Built-in embedding functions
- Similarity search
- Metadata filtering

**Example Usage**:

```python
import chromadb
from chromadb.utils import embedding_functions

# Initialize client
client = chromadb.Client()

# Create collection
collection = client.create_collection(
    name="document_chunks",
    embedding_function=embedding_functions.DefaultEmbeddingFunction()
)

# Add chunks (each chunk is a dict with "id", "content", and optional "metadata")
collection.add(
    documents=[chunk["content"] for chunk in chunks],
    metadatas=[chunk.get("metadata", {}) for chunk in chunks],
    ids=[chunk["id"] for chunk in chunks]
)

# Search
results = collection.query(
    query_texts=["What is chunking?"],
    n_results=5
)
```

### Pinecone

**Installation**:
```bash
pip install pinecone-client
```

**Key Features**:
- Managed vector database service
- High-performance similarity search
- Metadata filtering
- Scalable infrastructure

**Example Usage**:

```python
import pinecone
from sentence_transformers import SentenceTransformer

# Initialize
pinecone.init(api_key="your-api-key", environment="your-environment")
index_name = "document-chunks"

# Create index if it doesn't exist
if index_name not in pinecone.list_indexes():
    pinecone.create_index(
        name=index_name,
        dimension=384,  # Match embedding model
        metric="cosine"
    )

index = pinecone.Index(index_name)

# Generate embeddings and upsert
model = SentenceTransformer('all-MiniLM-L6-v2')
for chunk in chunks:
    embedding = model.encode(chunk["content"])
    index.upsert(
        vectors=[{
            "id": chunk["id"],
            "values": embedding.tolist(),
            "metadata": chunk.get("metadata", {})
        }]
    )

# Search
query_embedding = model.encode("search query")
results = index.query(
    vector=query_embedding.tolist(),
    top_k=5,
    include_metadata=True
)
```

### Weaviate

**Installation**:
```bash
pip install weaviate-client
```

**Key Features**:
- GraphQL API
- Hybrid search (dense + sparse)
- Real-time updates
- Schema validation

**Example Usage**:

```python
import weaviate

# Connect to Weaviate
client = weaviate.Client("http://localhost:8080")

# Define schema
client.schema.create_class({
    "class": "DocumentChunk",
    "description": "A chunk of document content",
    "properties": [
        {
            "name": "content",
            "dataType": ["text"]
        },
        {
            "name": "source",
            "dataType": ["string"]
        }
    ]
})

# Add data
for chunk in chunks:
    client.data_object.create(
        data_object={
            "content": chunk["content"],
            "source": chunk.get("source", "unknown")
        },
        class_name="DocumentChunk"
    )

# Search
results = client.query.get(
    "DocumentChunk",
    ["content", "source"]
).with_near_text({
    "concepts": ["search query"]
}).with_limit(5).do()
```

## Evaluation and Testing

### RAGAS

**Installation**:
```bash
pip install ragas
```

**Key Features**:
- RAG evaluation metrics
- Answer quality assessment
- Context relevance measurement
- Faithfulness evaluation

**Example Usage**:

```python
from ragas import evaluate
from ragas.metrics import (
    faithfulness,
    answer_relevancy,
    context_relevancy,
    context_recall
)
from datasets import Dataset

# Prepare evaluation data
dataset = Dataset.from_dict({
    "question": ["What is chunking?"],
    "answer": ["Chunking is the process of breaking large documents into smaller segments"],
    "contexts": [["Chunking involves dividing text into manageable pieces for better processing"]],
    "ground_truth": ["Chunking is a document processing technique"]
})

# Evaluate
result = evaluate(
    dataset=dataset,
    metrics=[
        faithfulness,
        answer_relevancy,
        context_relevancy,
        context_recall
    ]
)

print(result)
```

### TruEra (TruLens)

**Installation**:
```bash
pip install trulens trulens-apps
```

**Key Features**:
- LLM application evaluation
- Feedback functions
- Hallucination detection
- Performance monitoring

**Example Usage**:

```python
from trulens.core import TruSession
from trulens.apps.custom import instrument
from trulens.feedback import GroundTruthAgreement

# Initialize session
session = TruSession()

# Define feedback functions (ground_truth is your own reference data)
f_groundedness = GroundTruthAgreement(ground_truth)

# Evaluate chunks
@instrument
def chunk_and_query(text, query):
    chunks = chunk_function(text)
    relevant_chunks = search_function(chunks, query)
    answer = generate_function(relevant_chunks, query)
    return answer

# Record evaluation
with session:
    chunk_and_query("large document text", "what is the main topic?")
```

## Document Processing

### PyPDF2

**Installation**:
```bash
pip install PyPDF2
```

**Key Features**:
- PDF text extraction
- Page manipulation
- Metadata extraction
- Form field processing

**Example Usage**:

```python
import PyPDF2

def extract_text_from_pdf(pdf_path):
    text = ""
    with open(pdf_path, 'rb') as file:
        reader = PyPDF2.PdfReader(file)
        for page in reader.pages:
            text += page.extract_text()
    return text

# Extract text by page for better chunking
def extract_pages(pdf_path):
    pages = []
    with open(pdf_path, 'rb') as file:
        reader = PyPDF2.PdfReader(file)
        for i, page in enumerate(reader.pages):
            pages.append({
                "page_number": i + 1,
                "content": page.extract_text()
            })
    return pages
```

### python-docx

**Installation**:
```bash
pip install python-docx
```

**Key Features**:
- Microsoft Word document processing
- Paragraph and table extraction
- Style preservation
- Metadata access

**Example Usage**:

```python
from docx import Document

def extract_from_docx(docx_path):
    doc = Document(docx_path)
    content = []

    for paragraph in doc.paragraphs:
        if paragraph.text.strip():
            content.append({
                "type": "paragraph",
                "text": paragraph.text,
                "style": paragraph.style.name
            })

    for table in doc.tables:
        table_text = []
        for row in table.rows:
            row_text = [cell.text for cell in row.cells]
            table_text.append(" | ".join(row_text))

        content.append({
            "type": "table",
            "text": "\n".join(table_text)
        })

    return content
```

## Specialized Libraries

### tiktoken (OpenAI)

**Installation**:
```bash
pip install tiktoken
```

**Key Features**:
- Accurate token counting for OpenAI models
- Fast encoding/decoding
- Multiple model support
- Language-model-specific tokenization

**Example Usage**:

```python
import tiktoken

# Get encoding for a specific model
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

# Encode text
tokens = encoding.encode("This is a sample text")
print(f"Token count: {len(tokens)}")

# Decode tokens
text = encoding.decode(tokens)

# Count tokens without keeping the encoded form
def count_tokens(text, model="gpt-3.5-turbo"):
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))

# Use in chunking
def chunk_by_tokens(text, max_tokens=1000):
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    tokens = encoding.encode(text)

    chunks = []
    for i in range(0, len(tokens), max_tokens):
        chunk_tokens = tokens[i:i + max_tokens]
        chunk_text = encoding.decode(chunk_tokens)
        chunks.append(chunk_text)

    return chunks
```
|
||||
|
||||
### PDFMiner
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install pdfminer.six
|
||||
```
|
||||
|
||||
**Key Features**:
|
||||
- Detailed PDF analysis
|
||||
- Layout preservation
|
||||
- Font and style information
|
||||
- High-precision text extraction
|
||||
|
||||
**Example Usage**:
|
||||
|
||||
```python
from pdfminer.high_level import extract_pages
from pdfminer.layout import LTTextContainer, LTChar

def find_first_char(element):
    """Recursively locate the first LTChar, which carries font name and size."""
    for obj in element:
        if isinstance(obj, LTChar):
            return obj
        if hasattr(obj, "__iter__"):
            char = find_first_char(obj)
            if char is not None:
                return char
    return None

def extract_structured_text(pdf_path):
    structured_content = []

    for page_layout in extract_pages(pdf_path):
        page_content = []

        for element in page_layout:
            if isinstance(element, LTTextContainer):
                text = element.get_text()
                # Font details live on individual characters, not on the container
                char = find_first_char(element)
                font_info = {
                    "font_size": char.size if char else element.height,
                    "is_bold": "Bold" in char.fontname if char else False,
                    "x0": element.x0,
                    "y0": element.y0
                }
                page_content.append({
                    "text": text.strip(),
                    "font_info": font_info
                })

        structured_content.append({
            "page_number": page_layout.pageid,
            "content": page_content
        })

    return structured_content
```

## Performance and Optimization

### Dask

**Installation**:
```bash
pip install "dask[complete]"
```

**Key Features**:
- Parallel processing
- Out-of-core computation
- Distributed computing
- Integration with pandas

**Example Usage**:

```python
import dask.bag as db
from dask.distributed import Client

# Set up a distributed client
client = Client(n_workers=4)

# Parallel chunking of multiple documents
def chunk_document(document):
    # Your chunking logic here
    return chunk_function(document)

# Process documents in parallel
documents = ["doc1", "doc2", "doc3", ...]  # List of document contents
document_bag = db.from_sequence(documents)

# Apply the chunking function in parallel
chunked_documents = document_bag.map(chunk_document)

# Compute results
results = chunked_documents.compute()
```

### Ray

**Installation**:
```bash
pip install ray
```

**Key Features**:
- Distributed computing
- Actor model
- Autoscaling
- ML pipeline integration

**Example Usage**:

```python
import ray

# Initialize Ray
ray.init()

@ray.remote
class ChunkingWorker:
    def __init__(self, strategy):
        self.strategy = strategy

    def chunk_documents(self, documents):
        results = []
        for doc in documents:
            chunks = self.strategy.chunk(doc)
            results.append(chunks)
        return results

# Create workers (assumes `strategy` is a chunker instance defined elsewhere)
workers = [ChunkingWorker.remote(strategy) for _ in range(4)]

# Distribute work (assumes `documents` is a list of document contents)
documents_batch = [documents[i::4] for i in range(4)]
futures = [worker.chunk_documents.remote(batch)
           for worker, batch in zip(workers, documents_batch)]

# Get results
results = ray.get(futures)
```

## Development and Testing

### pytest

**Installation**:
```bash
pip install pytest pytest-asyncio
```

**Example Tests**:

```python
import pytest
from your_chunking_module import FixedSizeChunker, SemanticChunker

class TestFixedSizeChunker:
    def test_chunk_size_respect(self):
        chunker = FixedSizeChunker(chunk_size=100, chunk_overlap=10)
        text = "word " * 50  # 50 words

        chunks = chunker.chunk(text)

        for chunk in chunks:
            assert len(chunk.split()) <= 100  # Account for word boundaries

    def test_overlap_consistency(self):
        chunker = FixedSizeChunker(chunk_size=50, chunk_overlap=10)
        # Use distinct words so the set intersection actually measures overlap
        text = " ".join(f"word{i}" for i in range(30))

        chunks = chunker.chunk(text)

        # Check overlap between consecutive chunks
        for i in range(1, len(chunks)):
            chunk1_words = set(chunks[i-1].split()[-10:])
            chunk2_words = set(chunks[i].split()[:10])
            overlap = len(chunk1_words & chunk2_words)
            assert overlap >= 5  # Allow some tolerance

@pytest.mark.asyncio
async def test_semantic_chunker():
    chunker = SemanticChunker()
    text = "First topic sentence. Another sentence about first topic. " \
           "Now switching to second topic. More about second topic."

    chunks = await chunker.chunk_async(text)

    # Should detect the topic change and create a boundary
    assert len(chunks) >= 2
    assert "first topic" in chunks[0].lower()
    assert "second topic" in chunks[1].lower()
```

### Memory Profiler

**Installation**:
```bash
pip install memory-profiler
```

**Example Usage**:

```python
from memory_profiler import profile

@profile
def chunk_large_document():
    chunker = FixedSizeChunker(chunk_size=1000)
    large_text = "word " * 100000  # Large document

    chunks = chunker.chunk(large_text)
    return chunks

# Run with: python -m memory_profiler your_script.py
```

This comprehensive toolset provides everything needed to implement, test, and optimize chunking strategies for various use cases, from simple text processing to production-grade RAG systems.
1403
skills/ai/chunking-strategy/references/visualization-tools.md
Normal file
File diff suppressed because it is too large

302
skills/ai/prompt-engineering/SKILL.md
Normal file
@@ -0,0 +1,302 @@
---
name: prompt-engineering
category: backend
tags: [prompt-engineering, few-shot-learning, chain-of-thought, optimization, templates, system-prompts, llm-performance, ai-patterns]
version: 1.0.0
description: This skill should be used when creating, optimizing, or implementing advanced prompt patterns including few-shot learning, chain-of-thought reasoning, prompt optimization workflows, template systems, and system prompt design. It provides comprehensive frameworks for building production-ready prompts with measurable performance improvements.
---

# Prompt Engineering

This skill provides comprehensive frameworks for creating, optimizing, and implementing advanced prompt patterns that significantly improve LLM performance across various tasks and models.

## When to Use This Skill

Use this skill when:
- Creating new prompts for complex reasoning or analytical tasks
- Optimizing existing prompts for better accuracy or efficiency
- Implementing few-shot learning with strategic example selection
- Designing chain-of-thought reasoning for multi-step problems
- Building reusable prompt templates and systems
- Developing system prompts for consistent model behavior
- Troubleshooting poor prompt performance or failure modes
- Scaling prompt systems for production use cases

## Core Prompt Engineering Patterns

### 1. Few-Shot Learning Implementation

Select examples using semantic similarity and diversity sampling to maximize learning within context window constraints; a minimal sketch of assembling selected examples into a prompt follows the template structure below.

#### Example Selection Strategy
- Use `references/few-shot-patterns.md` for comprehensive selection frameworks
- Balance example count (3-5 optimal) with context window limitations
- Include edge cases and boundary conditions in example sets
- Prioritize diverse examples that cover problem space variations
- Order examples from simple to complex for progressive learning

#### Few-Shot Template Structure
```
Example 1 (Basic case):
Input: {representative_input}
Output: {expected_output}

Example 2 (Edge case):
Input: {challenging_input}
Output: {robust_output}

Example 3 (Error case):
Input: {problematic_input}
Output: {corrected_output}

Now handle: {target_input}
```

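To make the assembly step concrete, here is a minimal sketch of filling the template above from pre-selected examples (the `Example` dataclass and its fields are illustrative assumptions, not part of this skill's API):

```python
from dataclasses import dataclass

@dataclass
class Example:
    label: str   # e.g. "Basic case", "Edge case", "Error case"
    input: str
    output: str

def render_few_shot_prompt(examples, target_input):
    """Fill the few-shot template with examples already ordered simple-to-complex."""
    blocks = [
        f"Example {i + 1} ({ex.label}):\nInput: {ex.input}\nOutput: {ex.output}"
        for i, ex in enumerate(examples)
    ]
    return "\n\n".join(blocks) + f"\n\nNow handle: {target_input}"

# Usage
prompt = render_few_shot_prompt(
    [Example("Basic case", "2 + 2", "4"), Example("Edge case", "2 / 0", "undefined")],
    "3 * 3",
)
```
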
### 2. Chain-of-Thought Reasoning

Elicit step-by-step reasoning for complex problem-solving through structured thinking patterns.

#### Implementation Patterns
- Reference `references/cot-patterns.md` for detailed reasoning frameworks
- Use "Let's think step by step" for zero-shot CoT initiation
- Provide complete reasoning traces for few-shot CoT demonstrations
- Implement self-consistency by sampling multiple reasoning paths
- Include verification and validation steps in reasoning chains

#### CoT Template Structure
```
Let's approach this step-by-step:

Step 1: {break_down_the_problem}
Analysis: {detailed_reasoning}

Step 2: {identify_key_components}
Analysis: {component_analysis}

Step 3: {synthesize_solution}
Analysis: {solution_justification}

Final Answer: {conclusion_with_confidence}
```

### 3. Prompt Optimization Workflows

Implement iterative refinement processes with measurable performance metrics and systematic A/B testing.

#### Optimization Process
- Use `references/optimization-frameworks.md` for comprehensive optimization strategies
- Measure baseline performance before optimization attempts
- Implement single-variable changes for accurate attribution
- Track metrics: accuracy, consistency, latency, token efficiency
- Use statistical significance testing for A/B validation
- Document optimization iterations and their impacts

#### Performance Metrics Framework
- **Accuracy**: Task completion rate and output correctness
- **Consistency**: Response stability across multiple runs
- **Efficiency**: Token usage and response time optimization
- **Robustness**: Performance across edge cases and variations
- **Safety**: Adherence to guidelines and harm prevention

### 4. Template Systems Architecture

Build modular, reusable prompt components with variable interpolation and conditional sections.

#### Template Design Principles
- Reference `references/template-systems.md` for modular template frameworks
- Use clear variable naming conventions (e.g., `{user_input}`, `{context}`)
- Implement conditional sections for different scenario handling (see the renderer sketch after the example below)
- Design role-based templates for specific use cases
- Create hierarchical template composition patterns

#### Template Structure Example
```
# System Context
You are a {role} with {expertise_level} expertise in {domain}.

# Task Context
{if background_information}
Background: {background_information}
{endif}

# Instructions
{task_instructions}

# Examples
{example_count}

# Output Format
{output_specification}

# Input
{user_query}
```

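The `{if}`/`{endif}` syntax above is skill-specific rather than standard Python formatting; a minimal sketch of a renderer that supports it (regex-based, an illustrative assumption rather than this skill's actual engine):

```python
import re

def render_template(template: str, variables: dict) -> str:
    """Resolve {if var}...{endif} conditional sections, then interpolate {var} placeholders."""
    def resolve_conditional(match: re.Match) -> str:
        var, body = match.group(1), match.group(2)
        # Keep the section only when the controlling variable is truthy
        return body if variables.get(var) else ""

    text = re.sub(r"\{if (\w+)\}(.*?)\{endif\}", resolve_conditional, template, flags=re.DOTALL)

    # Interpolate remaining {var} placeholders, leaving unknown ones intact
    return re.sub(r"\{(\w+)\}", lambda m: str(variables.get(m.group(1), m.group(0))), text)

# Usage
print(render_template(
    "# Task Context\n{if background_information}Background: {background_information}\n{endif}Input: {user_query}",
    {"background_information": "", "user_query": "Summarize the report."},
))
```
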
### 5. System Prompt Design

Design comprehensive system prompts that establish consistent model behavior, output formats, and safety constraints.

#### System Prompt Components
- Use `references/system-prompt-design.md` for detailed design guidelines
- Define clear role specification and expertise boundaries
- Establish output format requirements and structural constraints
- Include safety guidelines and content policy adherence
- Set context for background information and domain knowledge

#### System Prompt Framework
```
You are an expert {role} specializing in {domain} with {experience_level} of experience.

## Core Capabilities
- List specific capabilities and expertise areas
- Define scope of knowledge and limitations

## Behavioral Guidelines
- Specify interaction style and communication approach
- Define error handling and uncertainty protocols
- Establish quality standards and verification requirements

## Output Requirements
- Specify format expectations and structural requirements
- Define content inclusion and exclusion criteria
- Establish consistency and validation requirements

## Safety and Ethics
- Include content policy adherence
- Specify bias mitigation requirements
- Define harm prevention protocols
```

## Implementation Workflows

### Workflow 1: Create New Prompt from Requirements

1. **Analyze Requirements**
   - Identify task complexity and reasoning requirements
   - Determine target model capabilities and limitations
   - Define success criteria and evaluation metrics
   - Assess need for few-shot learning or CoT reasoning

2. **Select Pattern Strategy**
   - Use few-shot learning for classification or transformation tasks
   - Apply CoT for complex reasoning or multi-step problems
   - Implement template systems for reusable prompt architecture
   - Design system prompts for consistent behavior requirements

3. **Draft Initial Prompt**
   - Structure prompt with clear sections and logical flow
   - Include relevant examples or reasoning demonstrations
   - Specify output format and quality requirements
   - Incorporate safety guidelines and constraints

4. **Validate and Test**
   - Test with diverse input scenarios including edge cases
   - Measure performance against defined success criteria
   - Iterate refinement based on testing results
   - Document optimization decisions and their rationale

### Workflow 2: Optimize Existing Prompt

1. **Performance Analysis**
   - Measure current prompt performance metrics
   - Identify failure modes and error patterns
   - Analyze token efficiency and response latency
   - Assess consistency across multiple runs

2. **Optimization Strategy**
   - Apply systematic A/B testing with single-variable changes
   - Use few-shot learning to improve task adherence
   - Implement CoT reasoning for complex task components
   - Refine template structure for better clarity

3. **Implementation and Testing**
   - Deploy optimized prompts with controlled rollout
   - Monitor performance metrics in production environment
   - Compare against baseline using statistical significance testing
   - Document improvements and lessons learned

### Workflow 3: Scale Prompt Systems

1. **Modular Architecture Design**
   - Decompose complex prompts into reusable components
   - Create template inheritance hierarchies
   - Implement dynamic example selection systems
   - Build automated quality assurance frameworks

2. **Production Integration**
   - Implement prompt versioning and rollback capabilities
   - Create performance monitoring and alerting systems
   - Build automated testing frameworks for prompt validation
   - Establish update and deployment workflows

## Quality Assurance

### Validation Requirements
- Test prompts with at least 10 diverse scenarios
- Include edge cases, boundary conditions, and failure modes
- Verify output format compliance and structural consistency
- Validate safety guideline adherence and harm prevention
- Measure performance across multiple model runs

### Performance Standards
- Achieve >90% task completion for well-defined use cases
- Maintain <5% variance across multiple runs for consistency
- Optimize token usage without sacrificing accuracy
- Ensure response latency meets application requirements
- Demonstrate robust handling of edge cases and unexpected inputs

## Integration with Other Skills

This skill integrates seamlessly with:
- **langchain4j-ai-services-patterns**: Interface-based prompt design
- **langchain4j-rag-implementation-patterns**: Context-enhanced prompting
- **langchain4j-testing-strategies**: Prompt validation frameworks
- **unit-test-parameterized**: Systematic prompt testing approaches

## Resources and References

- `references/few-shot-patterns.md`: Comprehensive few-shot learning frameworks
- `references/cot-patterns.md`: Chain-of-thought reasoning patterns and examples
- `references/optimization-frameworks.md`: Systematic prompt optimization methodologies
- `references/template-systems.md`: Modular template design and implementation
- `references/system-prompt-design.md`: System prompt architecture and best practices

## Usage Examples

### Example 1: Classification Task with Few-Shot Learning
```
Classify customer feedback into categories using semantic similarity for example selection and diversity sampling for edge case coverage.
```

### Example 2: Complex Reasoning with Chain-of-Thought
```
Implement step-by-step reasoning for financial analysis with verification steps and confidence scoring.
```

### Example 3: Template System for Customer Service
```
Create modular templates with role-based components and conditional sections for different inquiry types.
```

### Example 4: System Prompt for Code Generation
```
Design comprehensive system prompt with behavioral guidelines, output requirements, and safety constraints.
```

## Common Pitfalls and Solutions

- **Overfitting examples**: Use diverse example sets with semantic variety
- **Context window overflow**: Implement strategic example selection and compression
- **Inconsistent outputs**: Specify clear output formats and validation requirements
- **Poor generalization**: Include edge cases and boundary conditions in training examples
- **Safety violations**: Incorporate comprehensive content policies and harm prevention

## Performance Optimization

- Monitor token usage and implement compression strategies
- Use caching for repeated prompt components (a minimal sketch follows this list)
- Optimize example selection for maximum learning efficiency
- Implement progressive disclosure for complex prompt systems
- Balance prompt complexity with response quality requirements

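For the caching point above, a minimal sketch using the standard library (the rendering function and its parameters are illustrative assumptions, not this skill's API):

```python
from functools import lru_cache

@lru_cache(maxsize=256)
def render_system_prompt(role: str, domain: str) -> str:
    """Cache rendered prompt components that repeat across many requests."""
    return f"You are an expert {role} specializing in {domain}."

# Repeated calls with the same arguments reuse the cached string
render_system_prompt("data engineer", "ETL pipelines")
```
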
This skill provides the foundational patterns and methodologies for building production-ready prompt systems that consistently deliver high performance across diverse use cases and model types.
426
skills/ai/prompt-engineering/references/cot-patterns.md
Normal file
@@ -0,0 +1,426 @@
# Chain-of-Thought Reasoning Patterns

This reference provides comprehensive frameworks for implementing effective chain-of-thought (CoT) reasoning that improves model performance on complex, multi-step problems.

## Core Principles

### Step-by-Step Reasoning Elicitation

#### Problem Decomposition Strategy
- Break complex problems into manageable sub-problems
- Identify dependencies and relationships between components
- Establish logical flow and sequence of reasoning steps
- Define clear decision points and validation criteria

#### Verification and Validation Integration
- Include self-checking mechanisms at critical junctures
- Implement consistency checks across reasoning steps
- Add confidence scoring for uncertain conclusions
- Provide fallback strategies for ambiguous situations

## Zero-Shot Chain-of-Thought Patterns

### Basic CoT Initiation
```
Let's think step by step to solve this problem:

1. First, I need to understand what the question is asking for
2. Then, I'll identify the key information and constraints
3. Next, I'll consider different approaches to solve it
4. I'll work through the solution methodically
5. Finally, I'll verify my answer makes sense

Problem: {problem_statement}

Step 1: Understanding the question
{analysis}

Step 2: Key information and constraints
{information_analysis}

Step 3: Solution approach
{approach_analysis}

Step 4: Working through the solution
{detailed_solution}

Step 5: Verification
{verification}

Final Answer: {conclusion}
```

### Enhanced CoT with Confidence
```
Let me think through this systematically, breaking down the problem and checking my reasoning at each step.

**Problem**: {problem_description}

**Step 1: Problem Analysis**
- What am I being asked to solve?
- What information is provided?
- What are the constraints?
- My confidence in understanding: {score}/10

**Step 2: Strategy Selection**
- Possible approaches:
  1. {approach_1}
  2. {approach_2}
  3. {approach_3}
- Selected approach: {chosen_approach}
- Rationale: {reasoning_for_choice}

**Step 3: Execution**
- {detailed_step_by_step_solution}

**Step 4: Verification**
- Does the answer make sense?
- Have I addressed all parts of the question?
- Confidence in final answer: {score}/10

**Final Answer**: {solution_with_confidence_score}
```

## Few-Shot Chain-of-Thought Patterns

### Mathematical Reasoning Template
```
Solve the following math problem step by step.

Example 1:
Problem: A store sells apples and oranges at whole-dollar prices. If John buys 4 apples and 2 oranges and spends exactly $14, what could each fruit cost?

Step 1: Set up the equation
Let a = cost of apples, o = cost of oranges
4a + 2o = 14

Step 2: Simplify the equation
Divide both sides by 2: 2a + o = 7

Step 3: Test integer solutions
If a = 2, then 2(2) + o = 7 → o = 3
If a = 3, then 2(3) + o = 7 → o = 1

Step 4: Verify against the original problem
For a=2, o=3: 4(2) + 2(3) = 8 + 6 = 14 ✓
For a=3, o=1: 4(3) + 2(1) = 12 + 2 = 14 ✓

Both tested solutions work, so apples could cost $2 and oranges $3, OR apples $3 and oranges $1; the prices cannot be determined uniquely from this information.

Example 2:
Problem: Sarah is twice as old as Tom was 3 years ago. In 5 years, Sarah will be 1.5 times Tom's current age. How old are they now?

Step 1: Define variables
Let S = Sarah's current age, T = Tom's current age

Step 2: Set up equations
"Sarah is twice as old as Tom was 3 years ago": S = 2(T - 3)
"In 5 years, Sarah will be 1.5 times Tom's current age": S + 5 = 1.5T

Step 3: Solve the system
From the first equation: S = 2T - 6
Substitute into the second: (2T - 6) + 5 = 1.5T
2T - 1 = 1.5T
0.5T = 1
T = 2

Step 4: Find Sarah's age
S = 2(2) - 6 = 4 - 6 = -2

Step 5: Check for validity
A negative age is impossible, so I should re-check my work.
Re-deriving: S = 2(T - 3) = 2T - 6 and S + 5 = 1.5T again give 0.5T = 1, so T = 2 and S = -2.
The algebra is correct, which means the premises are inconsistent: this problem has no valid solution.

Now solve:
Problem: {math_problem}

Step 1: Define variables
{variable_definitions}

Step 2: Set up equations
{equation_setup}

Step 3: Solve the system
{solution_process}

Step 4: Verify the solution
{verification}

Final Answer: {answer}
```

### Logical Reasoning Template
```
Analyze the logical argument and determine if it's valid.

Example 1:
Premise 1: All birds can fly
Premise 2: Penguins are birds
Conclusion: Therefore, penguins can fly

Step 1: Analyze the structure
This is a syllogism with the form:
All A are B
C is A
Therefore, C is B

Step 2: Evaluate premise validity
Premise 1: "All birds can fly" - This is false (penguins and ostriches cannot fly)
Premise 2: "Penguins are birds" - This is true

Step 3: Check logical validity
The logical structure is valid, but since Premise 1 is false, the conclusion is not guaranteed to be true

Step 4: Real-world verification
In reality, penguins cannot fly despite being birds

Conclusion: The argument is logically valid but not sound, because Premise 1 is false

Example 2:
Premise 1: If it rains, then the ground gets wet
Premise 2: It is raining
Conclusion: Therefore, the ground gets wet

Step 1: Analyze the structure
This is modus ponens:
If P, then Q
P
Therefore, Q

Step 2: Evaluate premise validity
Premise 1: "If it rains, then the ground gets wet" - Generally true
Premise 2: "It is raining" - Given as true

Step 3: Check logical validity
Modus ponens is a valid argument form

Step 4: Verify the conclusion
Given the premises, the conclusion follows logically

Conclusion: The argument is both logically valid and sound

Now analyze:
Argument: {logical_argument}

Step 1: Analyze the argument structure
{structure_analysis}

Step 2: Evaluate premise validity
{premise_evaluation}

Step 3: Check logical validity
{validity_check}

Step 4: Verify the conclusion
{conclusion_verification}

Final Assessment: {argument_validity_assessment}
```

## Self-Consistency Techniques

### Multiple Reasoning Paths
```
I'll solve this problem using three different approaches and see which result is most reliable.

**Problem**: {complex_problem}

**Approach 1: Direct Calculation**
{first_approach_reasoning}
Result 1: {result_1}

**Approach 2: Logical Deduction**
{second_approach_reasoning}
Result 2: {result_2}

**Approach 3: Pattern Recognition**
{third_approach_reasoning}
Result 3: {result_3}

**Consistency Analysis:**
- Approach 1 and 2 agree: {yes/no}
- Approach 1 and 3 agree: {yes/no}
- Approach 2 and 3 agree: {yes/no}

**Final Decision:**
{majority_result} appears in {count} out of 3 approaches.
Confidence: {high/medium/low}

Most Likely Answer: {final_answer_with_confidence}
```

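The template above can also be automated: sample several independent reasoning paths and take a majority vote over the extracted final answers. A minimal sketch, assuming `generate` wraps a sampling LLM call and that the final answer sits on the trace's last line:

```python
import random
from collections import Counter

def self_consistent_answer(generate, prompt, num_paths=5,
                           extract_answer=lambda text: text.strip().splitlines()[-1]):
    """Sample num_paths reasoning traces; return the most common answer and its vote share."""
    answers = []
    for _ in range(num_paths):
        trace = generate(prompt)          # one sampled chain-of-thought completion
        answers.append(extract_answer(trace))

    (best, votes), = Counter(answers).most_common(1)
    return best, votes / num_paths

# Usage with a stubbed generator
stub = lambda p: random.choice(["...reasoning...\n42", "...reasoning...\n42", "...reasoning...\n41"])
answer, confidence = self_consistent_answer(stub, "Let's think step by step: ...")
```
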
### Verification Loop Pattern
```
Let me solve this step by step and verify each step.

**Problem**: {problem_description}

**Step 1: Initial Analysis**
{initial_analysis}

Verification: Does this make sense? {verification_1}

**Step 2: Solution Development**
{solution_development}

Verification: Does this logically follow from step 1? {verification_2}

**Step 3: Result Calculation**
{result_calculation}

Verification: Does this answer the original question? {verification_3}

**Step 4: Cross-Check**
Let me try a different approach to confirm:
{alternative_approach}

Results comparison: {comparison_analysis}

**Final Answer:**
{conclusion_with_verification_status}
```

## Specialized CoT Patterns

### Code Debugging CoT
```
Debug the following code by analyzing it step by step.

**Code:**
{code_snippet}

**Step 1: Understand the Code's Purpose**
{purpose_analysis}

**Step 2: Identify Expected Behavior**
{expected_behavior}

**Step 3: Trace the Execution**
{execution_trace}

**Step 4: Find the Error**
{error_identification}

**Step 5: Propose Fix**
{fix_proposal}

**Step 6: Verify the Fix**
{fix_verification}

**Fixed Code:**
{corrected_code}
```

### Data Analysis CoT
```
Analyze this data systematically to draw meaningful conclusions.

**Data:**
{dataset}

**Step 1: Understand the Data Structure**
{data_structure_analysis}

**Step 2: Identify Patterns and Trends**
{pattern_identification}

**Step 3: Calculate Key Metrics**
{metrics_calculation}

**Step 4: Compare with Benchmarks**
{benchmark_comparison}

**Step 5: Formulate Insights**
{insight_generation}

**Step 6: Validate Conclusions**
{conclusion_validation}

**Key Findings:**
{summary_of_insights}
```

### Creative Problem Solving CoT
```
Generate creative solutions to this challenging problem.

**Problem:**
{creative_problem}

**Step 1: Reframe the Problem**
{problem_reframing}

**Step 2: Brainstorm Multiple Angles**
- Technical approach: {technical_ideas}
- Business approach: {business_ideas}
- User experience approach: {ux_ideas}
- Unconventional approach: {unconventional_ideas}

**Step 3: Evaluate Each Approach**
{approach_evaluation}

**Step 4: Synthesize Best Elements**
{synthesis_process}

**Step 5: Develop Final Solution**
{solution_development}

**Step 6: Test for Feasibility**
{feasibility_testing}

**Recommended Solution:**
{final_creative_solution}
```

## Implementation Guidelines

### When to Use Chain-of-Thought
- **Multi-step problems**: Tasks requiring sequential reasoning
- **Complex calculations**: Mathematical or logical derivations
- **Problem decomposition**: Tasks that benefit from breaking down
- **Verification needs**: When accuracy is critical
- **Educational contexts**: When showing reasoning is valuable

### CoT Effectiveness Factors
- **Problem complexity**: Higher benefit for complex problems
- **Task type**: Mathematical, logical, and analytical tasks benefit most
- **Model capability**: Newer models handle CoT more effectively
- **Context window**: Ensure sufficient space for reasoning steps
- **Output requirements**: Detailed explanations benefit from CoT

### Common Pitfalls to Avoid
- **Over-explaining simple steps**: Keep detail proportional to difficulty
- **Circular reasoning**: Ensure logical progression
- **Missing verification**: Always include validation steps
- **Inconsistent confidence**: Use realistic confidence scoring
- **Premature conclusions**: Don't jump to answers without full reasoning

## Integration with Other Techniques

### CoT + Few-Shot Learning
- Include reasoning traces in examples
- Show step-by-step problem-solving demonstrations
- Teach verification and self-checking patterns

### CoT + Template Systems
- Embed CoT patterns within structured templates
- Use conditional CoT based on problem complexity
- Implement adaptive reasoning depth

### CoT + Prompt Optimization
- Test different CoT formulations
- Optimize reasoning step granularity
- Balance detail with efficiency

This framework provides comprehensive patterns for implementing effective chain-of-thought reasoning across diverse problem types and applications.
273
skills/ai/prompt-engineering/references/few-shot-patterns.md
Normal file
@@ -0,0 +1,273 @@
# Few-Shot Learning Patterns

This reference provides comprehensive frameworks for implementing effective few-shot learning strategies that maximize model performance within context window constraints.

## Core Principles

### Example Selection Strategy

#### Semantic Similarity Selection
- Use embedding similarity to find examples closest to target input
- Cluster similar examples to avoid redundancy
- Select diverse representatives from different semantic regions
- Prioritize examples that cover key variations in problem space

#### Diversity Sampling Approach
- Ensure coverage of different input types and patterns
- Include boundary cases and edge conditions
- Balance simple and complex examples
- Select examples that demonstrate different solution strategies

#### Progressive Complexity Ordering
- Start with simplest, most straightforward examples
- Progress to increasingly complex scenarios
- Include challenging edge cases last
- Use this ordering to build understanding incrementally

## Example Templates

### Classification Tasks

#### Binary Classification Template
```
Classify if the text expresses positive or negative sentiment.

Example 1:
Text: "I love this product! It works exactly as advertised and exceeded my expectations."
Sentiment: Positive
Reasoning: Contains enthusiastic language, positive adjectives, and satisfaction indicators

Example 2:
Text: "The customer service was terrible and the product broke after one day of use."
Sentiment: Negative
Reasoning: Contains negative adjectives, complaint language, and dissatisfaction indicators

Example 3:
Text: "It's okay, nothing special but does the basic job."
Sentiment: Negative
Reasoning: Contains lukewarm language, lack of enthusiasm, minimal positive elements

Now classify:
Text: {input_text}
Sentiment:
Reasoning:
```

#### Multi-Class Classification Template
```
Categorize the customer inquiry into one of: Technical Support, Billing, Sales, or General.

Example 1:
Inquiry: "My account was charged twice for the same subscription this month"
Category: Billing
Key indicators: "charged twice", "subscription", "account", financial terms

Example 2:
Inquiry: "The app keeps crashing when I try to upload files larger than 10MB"
Category: Technical Support
Key indicators: "crashing", "upload files", "technical issue", "error report"

Example 3:
Inquiry: "What are your pricing plans for enterprise customers?"
Category: Sales
Key indicators: "pricing plans", "enterprise", business inquiry, sales question

Now categorize:
Inquiry: {inquiry_text}
Category:
Key indicators:
```

### Transformation Tasks

#### Text Transformation Template
```
Convert formal business text into casual, friendly language.

Example 1:
Formal: "We regret to inform you that your request cannot be processed at this time due to insufficient documentation."
Casual: "Sorry, but we can't process your request right now because some documents are missing."

Example 2:
Formal: "The aforementioned individual has demonstrated exceptional proficiency in the designated responsibilities."
Casual: "They've done a great job with their tasks and really know what they're doing."

Example 3:
Formal: "Please be advised that the scheduled meeting has been postponed pending further notice."
Casual: "Hey, just letting you know that we've put off the meeting for now and will let you know when it's rescheduled."

Now convert:
Formal: {formal_text}
Casual:
```

#### Data Extraction Template
```
Extract key information from the job posting into structured format.

Example 1:
Job Posting: "We are seeking a Senior Software Engineer with 5+ years of experience in Python and cloud technologies. This is a remote position offering $120k-$150k salary plus equity."

Extracted:
- Position: Senior Software Engineer
- Experience Required: 5+ years
- Skills: Python, cloud technologies
- Location: Remote
- Salary: $120k-$150k plus equity

Example 2:
Job Posting: "Marketing Manager needed for growing startup. Must have 3 years experience in digital marketing, social media management, and content creation. San Francisco office, competitive compensation."

Extracted:
- Position: Marketing Manager
- Experience Required: 3 years
- Skills: Digital marketing, social media management, content creation
- Location: San Francisco
- Salary: Competitive compensation

Now extract:
Job Posting: {job_posting_text}
Extracted:
```

### Generation Tasks

#### Creative Writing Template
```
Generate compelling product descriptions following the shown patterns.

Example 1:
Product: Wireless headphones with noise cancellation
Description: "Immerse yourself in crystal-clear audio with our premium wireless headphones. Advanced noise cancellation technology blocks out distractions while 30-hour battery life keeps you connected all day long."

Example 2:
Product: Smart home security camera
Description: "Protect what matters most with intelligent monitoring that alerts you to activity instantly. AI-powered detection distinguishes between people, pets, and vehicles for truly smart security."

Example 3:
Product: Portable espresso maker
Description: "Barista-quality espresso anywhere, anytime. Compact design meets professional-grade extraction in this revolutionary portable machine that delivers perfect shots in under 30 seconds."

Now generate:
Product: {product_description}
Description:
```

### Error Correction Patterns

#### Error Detection and Correction Template
```
Identify and correct errors in the given text.

Example 1:
Text with errors: "Their going to the park to play there new game with they're friends."
Correction: "They're going to the park to play their new game with their friends."
Errors fixed: "Their → They're", "there → their", "they're → their"

Example 2:
Text with errors: "The company's new policy effects every employee and there morale."
Correction: "The company's new policy affects every employee and their morale."
Errors fixed: "effects → affects", "there → their"

Example 3:
Text with errors: "Its important to review you're work carefully before submiting."
Correction: "It's important to review your work carefully before submitting."
Errors fixed: "Its → It's", "you're → your", "submiting → submitting"

Now correct:
Text with errors: {text_with_errors}
Correction:
Errors fixed:
```

## Advanced Strategies

### Dynamic Example Selection

#### Context-Aware Selection
```python
def select_examples(input_text, example_database, max_examples=3):
    """
    Select most relevant examples based on semantic similarity and diversity.
    """
    # 1. Calculate similarity scores
    similarities = calculate_similarity(input_text, example_database)

    # 2. Sort by similarity
    sorted_examples = sort_by_similarity(similarities)

    # 3. Apply diversity sampling
    diverse_examples = diversity_sampling(sorted_examples, max_examples)

    # 4. Order by complexity
    final_examples = order_by_complexity(diverse_examples)

    return final_examples
```

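A concrete sketch of the helper functions above using sentence embeddings (the model name, the dict shape of `example_database` entries, the 0.9 similarity cutoff, and the length-based complexity proxy are all illustrative assumptions):

```python
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

_model = SentenceTransformer("all-MiniLM-L6-v2")

def select_examples_concrete(input_text, example_database, max_examples=3):
    """Rank examples by cosine similarity to the input, then greedily pick a diverse subset."""
    texts = [ex["input"] for ex in example_database]
    embeddings = _model.encode([input_text] + texts)
    query, candidates = embeddings[0:1], embeddings[1:]

    sims = cosine_similarity(query, candidates)[0]
    ranked = sorted(range(len(texts)), key=lambda i: sims[i], reverse=True)

    # Greedy diversity: skip candidates too similar to already-selected examples
    selected = []
    for i in ranked:
        if all(cosine_similarity(candidates[i:i+1], candidates[j:j+1])[0][0] < 0.9
               for j in selected):
            selected.append(i)
        if len(selected) == max_examples:
            break

    # Order selected examples simple-to-complex, approximated here by input length
    selected.sort(key=lambda i: len(texts[i]))
    return [example_database[i] for i in selected]
```
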
#### Adaptive Example Count
```python
def determine_example_count(input_complexity, context_limit):
    """
    Adjust example count based on input complexity and available context.
    """
    base_count = 3

    # Complex inputs benefit from more examples
    if input_complexity > 0.8:
        return min(base_count + 2, context_limit)
    elif input_complexity > 0.5:
        return base_count + 1
    else:
        return max(base_count - 1, 2)
```

### Quality Metrics for Examples

#### Example Effectiveness Scoring
```python
def score_example_effectiveness(example, test_cases):
    """
    Score how effectively an example teaches the desired pattern.
    """
    metrics = {
        'coverage': measure_pattern_coverage(example),
        'clarity': measure_instructional_clarity(example),
        'uniqueness': measure_uniqueness_from_other_examples(example),
        'difficulty': measure_appropriateness_difficulty(example)
    }

    return weighted_average(metrics, weights=[0.3, 0.3, 0.2, 0.2])
```

## Best Practices

### Example Quality Guidelines
- **Clarity**: Examples should clearly demonstrate the desired pattern
- **Accuracy**: Input-output pairs must be correct and consistent
- **Relevance**: Examples should be representative of target task
- **Diversity**: Include variation in input types and complexity levels
- **Completeness**: Cover edge cases and boundary conditions

### Context Management
- **Token Efficiency**: Optimize example length while maintaining clarity
- **Progressive Disclosure**: Start simple, increase complexity gradually
- **Redundancy Elimination**: Remove overlapping or duplicate examples
- **Compression**: Use concise representations where possible

### Common Pitfalls to Avoid
- **Overfitting**: Don't include too many examples from same pattern
- **Under-representation**: Ensure coverage of important variations
- **Ambiguity**: Examples should have clear, unambiguous solutions
- **Context Overflow**: Balance example count with window limitations
- **Poor Ordering**: Place examples in logical progression order

## Integration with Other Patterns

Few-shot learning combines effectively with:
- **Chain-of-Thought**: Add reasoning steps to examples
- **Template Systems**: Use few-shot within structured templates
- **Prompt Optimization**: Test different example selections
- **System Prompts**: Establish few-shot learning expectations in system prompts

This framework provides the foundation for implementing effective few-shot learning across diverse tasks and model types.
488
skills/ai/prompt-engineering/references/optimization-frameworks.md
Normal file
@@ -0,0 +1,488 @@
# Prompt Optimization Frameworks

This reference provides systematic methodologies for iteratively improving prompt performance through structured testing, measurement, and refinement processes.

## Optimization Process Overview

### Iterative Improvement Cycle
```mermaid
graph TD
    A[Baseline Measurement] --> B[Hypothesis Generation]
    B --> C[Controlled Test]
    C --> D[Performance Analysis]
    D --> E[Statistical Validation]
    E --> F[Implementation Decision]
    F --> G[Monitor Impact]
    G --> H[Learn & Iterate]
    H --> B
```

### Core Optimization Principles
- **Single Variable Testing**: Change one element at a time for accurate attribution
- **Measurable Metrics**: Define quantitative success criteria
- **Statistical Significance**: Use proper sample sizes and validation methods
- **Controlled Environment**: Test conditions must be consistent
- **Baseline Comparison**: Always measure against established baseline

## Performance Metrics Framework

### Primary Metrics

#### Task Success Rate
```python
def calculate_success_rate(results, expected_outputs):
    """
    Measure percentage of tasks completed correctly.
    """
    correct = sum(1 for result, expected in zip(results, expected_outputs)
                  if result == expected)
    return (correct / len(results)) * 100
```

#### Response Consistency
```python
def measure_consistency(prompt, test_cases, num_runs=5):
    """
    Measure response stability across multiple runs.
    """
    responses = {}
    for test_case in test_cases:
        test_responses = []
        for _ in range(num_runs):
            response = execute_prompt(prompt, test_case)
            test_responses.append(response)

        # Calculate similarity score for consistency
        consistency = calculate_similarity(test_responses)
        responses[test_case] = consistency

    return sum(responses.values()) / len(responses)
```

#### Token Efficiency
```python
def calculate_token_efficiency(prompt, test_cases):
    """
    Measure token usage per successful task completion.
    """
    total_tokens = 0
    successful_tasks = 0

    for test_case in test_cases:
        response = execute_prompt_with_metrics(prompt, test_case)
        total_tokens += response.token_count
        if response.is_successful:
            successful_tasks += 1

    return total_tokens / successful_tasks if successful_tasks > 0 else float('inf')
```

#### Response Latency
```python
import time

def measure_response_time(prompt, test_cases):
    """
    Measure average response time.
    """
    times = []
    for test_case in test_cases:
        start_time = time.time()
        execute_prompt(prompt, test_case)
        end_time = time.time()
        times.append(end_time - start_time)

    return sum(times) / len(times)
```

### Secondary Metrics

#### Output Quality Score
```python
def assess_output_quality(response, criteria):
    """
    Multi-dimensional quality assessment.
    """
    scores = {
        'accuracy': measure_accuracy(response),
        'completeness': measure_completeness(response),
        'coherence': measure_coherence(response),
        'relevance': measure_relevance(response),
        'format_compliance': measure_format_compliance(response)
    }

    weights = [0.3, 0.2, 0.2, 0.2, 0.1]
    return sum(score * weight for score, weight in zip(scores.values(), weights))
```

#### Safety Compliance
```python
def check_safety_compliance(response):
    """
    Measure adherence to safety guidelines.
    """
    violations = []

    # Check for various safety issues
    if contains_harmful_content(response):
        violations.append('harmful_content')
    if has_bias(response):
        violations.append('bias')
    if violates_privacy(response):
        violations.append('privacy_violation')

    safety_score = max(0, 100 - len(violations) * 25)
    return safety_score, violations
```

## A/B Testing Methodology

### Controlled Test Design
```python
import random

def design_ab_test(baseline_prompt, variant_prompt, test_cases):
    """
    Design controlled A/B test with proper statistical power.
    """
    # Calculate required sample size
    effect_size = estimate_effect_size(baseline_prompt, variant_prompt)
    sample_size = calculate_sample_size(effect_size, power=0.8, alpha=0.05)

    # Random assignment
    randomized_cases = random.sample(test_cases, sample_size)
    split_point = len(randomized_cases) // 2

    group_a = randomized_cases[:split_point]
    group_b = randomized_cases[split_point:]

    return {
        'baseline_group': group_a,
        'variant_group': group_b,
        'sample_size': sample_size,
        'statistical_power': 0.8,
        'significance_level': 0.05
    }
```

### Statistical Analysis
```python
import numpy as np
from scipy import stats

def analyze_ab_results(baseline_results, variant_results):
    """
    Perform statistical analysis of A/B test results.
    """
    # Calculate means and standard deviations
    baseline_mean = np.mean(baseline_results)
    variant_mean = np.mean(variant_results)
    baseline_std = np.std(baseline_results)
    variant_std = np.std(variant_results)

    # Perform t-test
    t_statistic, p_value = stats.ttest_ind(baseline_results, variant_results)

    # Calculate effect size (Cohen's d)
    pooled_std = np.sqrt(((len(baseline_results) - 1) * baseline_std**2 +
                          (len(variant_results) - 1) * variant_std**2) /
                         (len(baseline_results) + len(variant_results) - 2))
    cohens_d = (variant_mean - baseline_mean) / pooled_std

    return {
        'baseline_mean': baseline_mean,
        'variant_mean': variant_mean,
        'improvement': ((variant_mean - baseline_mean) / baseline_mean) * 100,
        'p_value': p_value,
        'statistical_significance': p_value < 0.05,
        'effect_size': cohens_d,
        'recommendation': 'implement_variant' if p_value < 0.05 and cohens_d > 0.2 else 'keep_baseline'
    }
```

## Optimization Strategies

### Strategy 1: Progressive Enhancement

#### Stepwise Improvement Process
```python
def progressive_optimization(base_prompt, test_cases, max_iterations=10):
    """
    Incrementally improve prompt through systematic testing.
    """
    current_prompt = base_prompt
    current_performance = evaluate_prompt(current_prompt, test_cases)
    optimization_history = []

    for iteration in range(max_iterations):
        # Generate improvement hypotheses
        hypotheses = generate_improvement_hypotheses(current_prompt, current_performance)

        best_improvement = None
        best_performance = current_performance

        for hypothesis in hypotheses:
            # Test hypothesis
            test_prompt = apply_hypothesis(current_prompt, hypothesis)
            test_performance = evaluate_prompt(test_prompt, test_cases)

            # Validate improvement
            if is_statistically_significant(current_performance, test_performance):
                if test_performance.overall_score > best_performance.overall_score:
                    best_improvement = hypothesis
                    best_performance = test_performance

        # Apply best improvement if found
        if best_improvement:
            current_prompt = apply_hypothesis(current_prompt, best_improvement)
            optimization_history.append({
                'iteration': iteration,
                'hypothesis': best_improvement,
                'performance_before': current_performance,
                'performance_after': best_performance,
                'improvement': best_performance.overall_score - current_performance.overall_score
            })
            current_performance = best_performance
        else:
            break  # No further improvements found

    return current_prompt, optimization_history
```

### Strategy 2: Multi-Objective Optimization

#### Pareto Optimization Framework
```python
def multi_objective_optimization(prompt_variants, objectives):
    """
    Optimize for multiple competing objectives using Pareto efficiency.
    """
    results = []

    for variant in prompt_variants:
        scores = {}
        for objective in objectives:
            scores[objective] = evaluate_objective(variant, objective)

        results.append({
            'prompt': variant,
            'scores': scores,
            'dominates': []
        })

    # Find Pareto optimal solutions
    pareto_optimal = []
    for i, result_i in enumerate(results):
        is_dominated = False
        for j, result_j in enumerate(results):
            if i != j and dominates(result_j, result_i):
                is_dominated = True
                break

        if not is_dominated:
            pareto_optimal.append(result_i)

    return pareto_optimal

def dominates(result_a, result_b):
    """
    Check if result_a Pareto-dominates result_b: at least as good in every
    objective and strictly better in at least one.
    """
    at_least_as_good = all(result_a['scores'][obj] >= result_b['scores'][obj]
                           for obj in result_a['scores'])
    strictly_better = any(result_a['scores'][obj] > result_b['scores'][obj]
                          for obj in result_a['scores'])
    return at_least_as_good and strictly_better
```

### Strategy 3: Adaptive Testing

#### Dynamic Test Allocation
```python
def adaptive_testing(prompt_variants, initial_budget):
    """
    Dynamically allocate testing budget to promising variants.
    """
    # Initial exploration phase
    exploration_results = {}
    budget_per_variant = initial_budget // len(prompt_variants)

    for variant in prompt_variants:
        exploration_results[variant] = test_prompt(variant, budget_per_variant)

    # Exploitation phase - allocate more budget to promising variants
    total_budget_spent = len(prompt_variants) * budget_per_variant
    remaining_budget = initial_budget - total_budget_spent

    # Sort by performance
    sorted_variants = sorted(exploration_results.items(),
                             key=lambda x: x[1].overall_score, reverse=True)

    # Allocate remaining budget proportionally to performance
    final_results = {}
    for i, (variant, initial_result) in enumerate(sorted_variants):
        if remaining_budget > 0:
            additional_budget = max(1, remaining_budget // (len(sorted_variants) - i))
            final_results[variant] = test_prompt(variant, additional_budget)
            remaining_budget -= additional_budget
        else:
            final_results[variant] = initial_result

    return final_results
```

## Optimization Hypotheses

### Common Optimization Areas

#### Instruction Clarity
```python
instruction_clarity_hypotheses = [
    "Add numbered steps to instructions",
    "Include specific output format examples",
    "Clarify role and expertise level",
    "Add context and background information",
    "Specify constraints and boundaries",
    "Include success criteria and evaluation standards"
]
```

#### Example Quality
```python
example_optimization_hypotheses = [
    "Increase number of examples from 3 to 5",
    "Add edge case examples",
    "Reorder examples by complexity",
    "Include negative examples",
    "Add reasoning traces to examples",
    "Improve example diversity and coverage"
]
```

#### Structure Optimization
```python
structure_hypotheses = [
    "Add clear section headings",
    "Reorganize content flow",
    "Include summary at the beginning",
    "Add checklist for verification",
    "Separate instructions from examples",
    "Add troubleshooting section"
]
```

#### Model-Specific Optimization
```python
model_specific_hypotheses = {
    'claude': [
        "Use XML tags for structure",
        "Add <thinking> sections for reasoning",
        "Include constitutional AI principles",
        "Use system message format",
        "Add safety guidelines and constraints"
    ],
    'gpt-4': [
        "Use numbered sections with ### headers",
        "Include JSON format specifications",
        "Add function calling patterns",
        "Use bullet points for clarity",
        "Include error handling instructions"
    ],
    'gemini': [
        "Use bold headers with ** formatting",
        "Include step-by-step process descriptions",
        "Add validation checkpoints",
        "Use conversational tone",
        "Include confidence scoring"
    ]
}
```

## Continuous Monitoring

### Production Performance Tracking
```python
def setup_monitoring(prompt, alert_thresholds):
    """
    Set up continuous monitoring for deployed prompts.
    """
    monitors = {
        'success_rate': MetricMonitor('success_rate', alert_thresholds['success_rate']),
        'response_time': MetricMonitor('response_time', alert_thresholds['response_time']),
        'token_cost': MetricMonitor('token_cost', alert_thresholds['token_cost']),
        'safety_score': MetricMonitor('safety_score', alert_thresholds['safety_score'])
    }

    def monitor_performance():
        recent_data = collect_recent_performance(prompt)
        alerts = []

        for metric_name, monitor in monitors.items():
            if metric_name in recent_data:
                alert = monitor.check(recent_data[metric_name])
                if alert:
                    alerts.append(alert)

        return alerts

    return monitor_performance
```

### Automated Rollback System
|
||||
```python
|
||||
def automated_rollback_system(prompts, monitoring_data):
|
||||
"""
|
||||
Automatically rollback to previous version if performance degrades.
|
||||
"""
|
||||
def check_and_rollback(current_prompt, baseline_prompt):
|
||||
current_metrics = monitoring_data.get_metrics(current_prompt)
|
||||
baseline_metrics = monitoring_data.get_metrics(baseline_prompt)
|
||||
|
||||
# Check if performance degradation exceeds threshold
|
||||
degradation_threshold = 0.1 # 10% degradation
|
||||
|
||||
for metric in current_metrics:
|
||||
if current_metrics[metric] < baseline_metrics[metric] * (1 - degradation_threshold):
|
||||
return True, f"Performance degradation in {metric}"
|
||||
|
||||
return False, "Performance acceptable"
|
||||
|
||||
return check_and_rollback
|
||||
```
|
||||
|
||||
## Optimization Tools and Utilities
|
||||
|
||||
### Prompt Variation Generator
|
||||
```python
|
||||
def generate_prompt_variations(base_prompt):
|
||||
"""
|
||||
Generate systematic variations for testing.
|
||||
"""
|
||||
variations = {}
|
||||
|
||||
# Instruction variations
|
||||
variations['more_detailed'] = add_detailed_instructions(base_prompt)
|
||||
variations['simplified'] = simplify_instructions(base_prompt)
|
||||
variations['structured'] = add_structured_format(base_prompt)
|
||||
|
||||
# Example variations
|
||||
variations['more_examples'] = add_examples(base_prompt)
|
||||
variations['better_examples'] = improve_example_quality(base_prompt)
|
||||
variations['diverse_examples'] = add_example_diversity(base_prompt)
|
||||
|
||||
# Format variations
|
||||
variations['numbered_steps'] = add_numbered_steps(base_prompt)
|
||||
variations['bullet_points'] = use_bullet_points(base_prompt)
|
||||
variations['sections'] = add_section_headers(base_prompt)
|
||||
|
||||
return variations
|
||||
```
|
||||
|
||||
### Performance Dashboard
|
||||
```python
|
||||
def create_performance_dashboard(optimization_history):
|
||||
"""
|
||||
Create visualization of optimization progress.
|
||||
"""
|
||||
# Generate performance metrics over time
|
||||
metrics_over_time = {
|
||||
'iterations': [h['iteration'] for h in optimization_history],
|
||||
'success_rates': [h['performance_after'].success_rate for h in optimization_history],
|
||||
'token_efficiency': [h['performance_after'].token_efficiency for h in optimization_history],
|
||||
'response_times': [h['performance_after'].response_time for h in optimization_history]
|
||||
}
|
||||
|
||||
return PerformanceDashboard(metrics_over_time)
|
||||
```
|
||||
|
||||
This comprehensive framework provides systematic methodologies for continuous prompt improvement through data-driven optimization and rigorous testing processes.
|
||||
494
skills/ai/prompt-engineering/references/system-prompt-design.md
Normal file
494
skills/ai/prompt-engineering/references/system-prompt-design.md
Normal file
@@ -0,0 +1,494 @@
|
||||
# System Prompt Design
|
||||
|
||||
This reference provides comprehensive frameworks for designing effective system prompts that establish consistent model behavior, define clear boundaries, and ensure reliable performance across diverse applications.
|
||||
|
||||
## System Prompt Architecture
|
||||
|
||||
### Core Components Structure
|
||||
```
|
||||
1. Role Definition & Expertise
|
||||
2. Behavioral Guidelines & Constraints
|
||||
3. Interaction Protocols
|
||||
4. Output Format Specifications
|
||||
5. Safety & Ethical Guidelines
|
||||
6. Context & Background Information
|
||||
7. Quality Standards & Verification
|
||||
8. Error Handling & Uncertainty Protocols
|
||||
```
|
||||
|
||||
## Component Design Patterns
|
||||
|
||||
### 1. Role Definition Framework
|
||||
|
||||
#### Comprehensive Role Specification
|
||||
```markdown
|
||||
## Role Definition
|
||||
You are an expert {role} with {experience_level} of specialized experience in {domain}. Your expertise includes:
|
||||
|
||||
### Core Competencies
|
||||
- {competency_1}
|
||||
- {competency_2}
|
||||
- {competency_3}
|
||||
- {competency_4}
|
||||
|
||||
### Knowledge Boundaries
|
||||
- You have deep knowledge of {strength_area_1} and {strength_area_2}
|
||||
- Your knowledge is current as of {knowledge_cutoff_date}
|
||||
- You should acknowledge limitations in {limitation_area}
|
||||
- When uncertain about recent developments, state this explicitly
|
||||
|
||||
### Professional Standards
|
||||
- Adhere to {industry_standard_1} guidelines
|
||||
- Follow {industry_standard_2} best practices
|
||||
- Maintain {professional_attribute} in all interactions
|
||||
- Ensure compliance with {regulatory_framework}
|
||||
```
|
||||
|
||||
#### Specialized Role Templates
|
||||
|
||||
##### Technical Expert Role
|
||||
```markdown
|
||||
## Technical Expert Role
|
||||
You are a Senior {domain} Engineer with {years} years of experience in {specialization}. Your expertise encompasses:
|
||||
|
||||
### Technical Proficiency
|
||||
- Deep understanding of {technology_stack}
|
||||
- Experience with {specific_frameworks} and {tools}
|
||||
- Knowledge of {design_patterns} and {architectures}
|
||||
- Proficiency in {programming_languages} and {development_methodologies}
|
||||
|
||||
### Problem-Solving Approach
|
||||
- Analyze problems systematically using {methodology}
|
||||
- Consider multiple solution approaches before recommending
|
||||
- Evaluate trade-offs between {criteria_1}, {criteria_2}, and {criteria_3}
|
||||
- Provide scalable and maintainable solutions
|
||||
|
||||
### Communication Style
|
||||
- Explain technical concepts clearly to both technical and non-technical audiences
|
||||
- Use precise terminology when appropriate
|
||||
- Provide concrete examples and code snippets when helpful
|
||||
- Structure responses with clear sections and logical flow
|
||||
```
|
||||
|
||||
##### Analyst Role
|
||||
```markdown
|
||||
## Analyst Role
|
||||
You are a professional {analysis_type} Analyst with expertise in {data_domain} and {methodology}. Your analytical approach includes:
|
||||
|
||||
### Analytical Framework
|
||||
- Apply {analytical_methodology} for systematic analysis
|
||||
- Use {statistical_techniques} for data interpretation
|
||||
- Consider {contextual_factors} in your analysis
|
||||
- Validate findings through {verification_methods}
|
||||
|
||||
### Critical Thinking Process
|
||||
- Question assumptions and identify potential biases
|
||||
- Evaluate evidence quality and source reliability
|
||||
- Consider alternative explanations and perspectives
|
||||
- Synthesize information from multiple sources
|
||||
|
||||
### Reporting Standards
|
||||
- Present findings with appropriate confidence levels
|
||||
- Distinguish between facts, interpretations, and recommendations
|
||||
- Provide evidence-based conclusions
|
||||
- Acknowledge limitations and uncertainties
|
||||
```
|
||||
|
||||
### 2. Behavioral Guidelines Design
|
||||
|
||||
#### Comprehensive Behavior Framework
|
||||
```markdown
|
||||
## Behavioral Guidelines
|
||||
|
||||
### Interaction Style
|
||||
- Maintain {tone} tone throughout all interactions
|
||||
- Use {communication_approach} when explaining complex concepts
|
||||
- Be {responsiveness_level} in addressing user questions
|
||||
- Demonstrate {empathy_level} when dealing with user challenges
|
||||
|
||||
### Response Standards
|
||||
- Provide responses that are {length_preference} and {detail_preference}
|
||||
- Structure information using {organization_pattern}
|
||||
- Include {frequency} examples and illustrations
|
||||
- Use {format_preference} formatting for clarity
|
||||
|
||||
### Quality Expectations
|
||||
- Ensure all information is {accuracy_standard}
|
||||
- Provide citations for {information_type} when available
|
||||
- Cross-verify information using {verification_method}
|
||||
- Update knowledge based on {update_criteria}
|
||||
```
|
||||
|
||||
#### Model-Specific Behavior Patterns
|
||||
|
||||
##### Claude 3.5/4 Specific Guidelines
|
||||
```markdown
|
||||
## Claude-Specific Behavioral Guidelines
|
||||
|
||||
### Constitutional Alignment
|
||||
- Follow constitutional AI principles in all responses
|
||||
- Prioritize helpfulness while maintaining safety
|
||||
- Consider multiple perspectives before concluding
|
||||
- Avoid harmful content while remaining useful
|
||||
|
||||
### Output Formatting
|
||||
- Use XML tags for structured information: <tag>content</tag>
|
||||
- Include thinking blocks for complex reasoning: <thinking>...</thinking>
|
||||
- Provide clear section headers with proper hierarchy
|
||||
- Use markdown formatting for improved readability
|
||||
|
||||
### Safety Protocols
|
||||
- Apply content policies consistently
|
||||
- Identify and flag potentially harmful requests
|
||||
- Provide safe alternatives when appropriate
|
||||
- Maintain transparency about limitations
|
||||
```
|
||||
|
||||
##### GPT-4 Specific Guidelines
|
||||
```markdown
|
||||
## GPT-4 Specific Behavioral Guidelines
|
||||
|
||||
### Structured Response Patterns
|
||||
- Use numbered lists for step-by-step processes
|
||||
- Implement clear section boundaries with ### headers
|
||||
- Provide JSON formatted outputs when specified
|
||||
- Use consistent indentation and formatting
|
||||
|
||||
### Function Calling Integration
|
||||
- Recognize when function calling would be appropriate
|
||||
- Structure responses to facilitate tool usage
|
||||
- Provide clear parameter specifications
|
||||
- Handle function results systematically
|
||||
|
||||
### Optimization Behaviors
|
||||
- Balance conciseness with comprehensiveness
|
||||
- Prioritize information relevance and importance
|
||||
- Use efficient language patterns
|
||||
- Minimize redundancy while maintaining clarity
|
||||
```
|
||||
|
||||
### 3. Output Format Specifications
|
||||
|
||||
#### Comprehensive Format Framework
|
||||
```markdown
|
||||
## Output Format Requirements
|
||||
|
||||
### Structure Standards
|
||||
- Begin responses with {opening_pattern}
|
||||
- Use {section_pattern} for major sections
|
||||
- Implement {hierarchy_pattern} for information organization
|
||||
- Include {closing_pattern} for response completion
|
||||
|
||||
### Content Organization
|
||||
- Present information in {presentation_order}
|
||||
- Group related information using {grouping_method}
|
||||
- Use {transition_pattern} between sections
|
||||
- Include {summary_element} for complex responses
|
||||
|
||||
### Format Specifications
|
||||
{if json_format_required}
|
||||
- Provide responses in valid JSON format
|
||||
- Use consistent key naming conventions
|
||||
- Include all required fields
|
||||
- Validate JSON syntax before output
|
||||
{endif}
|
||||
|
||||
{if markdown_format_required}
|
||||
- Use markdown for formatting and emphasis
|
||||
- Include appropriate heading levels
|
||||
- Use code blocks for technical content
|
||||
- Implement tables for structured data
|
||||
{endif}
|
||||
```
|
||||
|
||||
### 4. Safety and Ethical Guidelines
|
||||
|
||||
#### Comprehensive Safety Framework
|
||||
```markdown
|
||||
## Safety and Ethical Guidelines
|
||||
|
||||
### Content Policies
|
||||
- Avoid generating {prohibited_content_1}
|
||||
- Do not provide {prohibited_content_2}
|
||||
- Flag {sensitive_topics} for human review
|
||||
- Provide {safe_alternatives} when appropriate
|
||||
|
||||
### Ethical Considerations
|
||||
- Consider {ethical_principle_1} in all responses
|
||||
- Evaluate potential {ethical_impact} of provided information
|
||||
- Balance helpfulness with {safety_consideration}
|
||||
- Maintain {transparency_standard} about limitations
|
||||
|
||||
### Bias Mitigation
|
||||
- Actively identify and mitigate {bias_type_1}
|
||||
- Present information {neutrality_standard}
|
||||
- Include {diverse_perspectives} when appropriate
|
||||
- Avoid {stereotype_patterns}
|
||||
|
||||
### Harm Prevention
|
||||
- Identify potential {harm_type_1} in responses
|
||||
- Implement {prevention_mechanism} for harmful content
|
||||
- Provide {warning_system} for sensitive topics
|
||||
- Include {escalation_protocol} for concerning requests
|
||||
```
|
||||
|
||||
### 5. Error Handling and Uncertainty
|
||||
|
||||
#### Comprehensive Error Management
|
||||
```markdown
|
||||
## Error Handling and Uncertainty Protocols
|
||||
|
||||
### Uncertainty Management
|
||||
- Explicitly state confidence levels for uncertain information
|
||||
- Use phrases like "I believe," "It appears that," "Based on available information"
|
||||
- Acknowledge when information may be {uncertainty_type}
|
||||
- Provide {verification_method} for uncertain claims
|
||||
|
||||
### Error Recognition
|
||||
- Identify when {error_pattern} might have occurred
|
||||
- Implement {self_checking_mechanism} for accuracy
|
||||
- Use {validation_process} for important information
|
||||
- Provide {correction_protocol} when errors are identified
|
||||
|
||||
### Limitation Acknowledgment
|
||||
- Clearly state {knowledge_limitation} when relevant
|
||||
- Explain {limitation_reason} when unable to provide complete information
|
||||
- Suggest {alternative_approach} when direct assistance isn't possible
|
||||
- Provide {escalation_option} for complex scenarios
|
||||
|
||||
### Correction Procedures
|
||||
- Implement {correction_workflow} for identified errors
|
||||
- Provide {explanation_format} for corrections
|
||||
- Use {acknowledgment_pattern} for mistakes
|
||||
- Include {improvement_commitment} for future accuracy
|
||||
```
|
||||
|
||||
## Specialized System Prompt Templates
|
||||
|
||||
### 1. Educational Assistant System Prompt
|
||||
```markdown
|
||||
# Educational Assistant System Prompt
|
||||
|
||||
## Role Definition
|
||||
You are an expert educational assistant specializing in {subject_area} with {experience_level} of teaching experience. Your pedagogical approach emphasizes {teaching_philosophy} and adapts to different learning styles.
|
||||
|
||||
## Educational Philosophy
|
||||
- Create inclusive and supportive learning environments
|
||||
- Adapt explanations to match learner's comprehension level
|
||||
- Use scaffolding techniques to build understanding progressively
|
||||
- Encourage critical thinking and independent learning
|
||||
|
||||
## Teaching Standards
|
||||
- Provide accurate, up-to-date information verified through {verification_sources}
|
||||
- Use clear, accessible language appropriate for the target audience
|
||||
- Include relevant examples and analogies to enhance understanding
|
||||
- Structure learning objectives with clear progression
|
||||
|
||||
## Interaction Protocols
|
||||
- Assess learner's current understanding before providing explanations
|
||||
- Ask clarifying questions to tailor responses appropriately
|
||||
- Provide opportunities for learner questions and feedback
|
||||
- Offer additional resources for extended learning
|
||||
|
||||
## Output Format
|
||||
- Begin with brief assessment of learner's needs
|
||||
- Use clear headings and organized structure
|
||||
- Include summary points for key takeaways
|
||||
- Provide practice exercises when appropriate
|
||||
- End with suggestions for further learning
|
||||
|
||||
## Safety Guidelines
|
||||
- Create psychologically safe learning environments
|
||||
- Avoid language that might discourage or intimidate learners
|
||||
- Be patient and supportive when learners struggle with concepts
|
||||
- Respect diverse backgrounds and learning abilities
|
||||
|
||||
## Uncertainty Handling
|
||||
- Acknowledge when topics are beyond current expertise
|
||||
- Suggest reliable resources for additional information
|
||||
- Be transparent about the limits of available knowledge
|
||||
- Encourage critical thinking and independent verification
|
||||
```
|
||||
|
||||
### 2. Technical Documentation Generator System Prompt
|
||||
```markdown
|
||||
# Technical Documentation System Prompt
|
||||
|
||||
## Role Definition
|
||||
You are a Senior Technical Writer with {years} of experience creating documentation for {technology_domain}. Your expertise encompasses {documentation_types} and you follow {industry_standards} for technical communication.
|
||||
|
||||
## Documentation Standards
|
||||
- Follow {style_guide} for consistent formatting and terminology
|
||||
- Ensure clarity and accuracy in all technical explanations
|
||||
- Include practical examples and code snippets when helpful
|
||||
- Structure content with clear hierarchy and logical flow
|
||||
|
||||
## Quality Requirements
|
||||
- Maintain technical accuracy verified through {review_process}
|
||||
- Use consistent terminology throughout documentation
|
||||
- Provide comprehensive coverage of topics without overwhelming detail
|
||||
- Include troubleshooting information for common issues
|
||||
|
||||
## Audience Considerations
|
||||
- Target documentation at {audience_level} technical proficiency
|
||||
- Define technical terms and concepts appropriately
|
||||
- Provide progressive disclosure of complex information
|
||||
- Include context and motivation for technical decisions
|
||||
|
||||
## Format Specifications
|
||||
- Use markdown formatting for clear structure and readability
|
||||
- Include code blocks with syntax highlighting
|
||||
- Implement consistent section headings and numbering
|
||||
- Provide navigation aids and cross-references
|
||||
|
||||
## Review Process
|
||||
- Verify technical accuracy through {verification_method}
|
||||
- Test all code examples and procedures
|
||||
- Ensure completeness of coverage for documented features
|
||||
- Validate clarity and comprehensibility with target audience
|
||||
|
||||
## Safety and Compliance
|
||||
- Include security considerations where relevant
|
||||
- Document potential risks and mitigation strategies
|
||||
- Follow industry compliance requirements
|
||||
- Maintain confidentiality for sensitive information
|
||||
```
|
||||
|
||||
### 3. Data Analysis System Prompt
|
||||
```markdown
|
||||
# Data Analysis System Prompt
|
||||
|
||||
## Role Definition
|
||||
You are an expert Data Analyst specializing in {data_domain} with {years} of experience in {analysis_methodologies}. Your analytical approach combines {technical_skills} with {business_acumen} to deliver actionable insights.
|
||||
|
||||
## Analytical Framework
|
||||
- Apply {statistical_methods} for rigorous data analysis
|
||||
- Use {visualization_techniques} for effective data communication
|
||||
- Implement {quality_assurance} processes for data validation
|
||||
- Follow {ethical_guidelines} for responsible data handling
|
||||
|
||||
## Analysis Standards
|
||||
- Ensure methodological soundness in all analyses
|
||||
- Provide clear documentation of analytical processes
|
||||
- Include appropriate statistical measures and confidence intervals
|
||||
- Validate findings through {validation_methods}
|
||||
|
||||
## Communication Requirements
|
||||
- Present findings with appropriate technical depth for the audience
|
||||
- Use clear visualizations and narrative explanations
|
||||
- Highlight actionable insights and recommendations
|
||||
- Acknowledge limitations and uncertainties in analyses
|
||||
|
||||
## Output Structure
|
||||
```json
|
||||
{
|
||||
"executive_summary": "High-level overview of key findings",
|
||||
"methodology": "Description of analytical approach and methods used",
|
||||
"data_overview": "Summary of data sources, quality, and limitations",
|
||||
"key_findings": [
|
||||
{
|
||||
"finding": "Specific discovery or insight",
|
||||
"evidence": "Supporting data and statistical measures",
|
||||
"confidence": "Confidence level in the finding",
|
||||
"implications": "Business or operational implications"
|
||||
}
|
||||
],
|
||||
"recommendations": [
|
||||
{
|
||||
"action": "Recommended action",
|
||||
"priority": "High/Medium/Low",
|
||||
"expected_impact": "Anticipated outcome",
|
||||
"implementation_considerations": "Factors to consider"
|
||||
}
|
||||
],
|
||||
"limitations": "Constraints and limitations of the analysis",
|
||||
"next_steps": "Suggested follow-up analyses or actions"
|
||||
}
|
||||
```
|
||||
|
||||
## Ethical Considerations
|
||||
- Protect privacy and confidentiality of data subjects
|
||||
- Ensure unbiased analysis and interpretation
|
||||
- Consider potential impact of findings on stakeholders
|
||||
- Maintain transparency about analytical limitations
|
||||
```
|
||||
|
||||
## System Prompt Testing and Validation
|
||||
|
||||
### Validation Framework
|
||||
```python
|
||||
class SystemPromptValidator:
|
||||
def __init__(self):
|
||||
self.validation_criteria = {
|
||||
'role_clarity': 0.2,
|
||||
'instruction_specificity': 0.2,
|
||||
'safety_completeness': 0.15,
|
||||
'output_format_clarity': 0.15,
|
||||
'error_handling_coverage': 0.1,
|
||||
'behavioral_consistency': 0.1,
|
||||
'ethical_considerations': 0.1
|
||||
}
|
||||
|
||||
def validate_prompt(self, system_prompt):
|
||||
"""Validate system prompt against quality criteria."""
|
||||
scores = {}
|
||||
|
||||
scores['role_clarity'] = self.assess_role_clarity(system_prompt)
|
||||
scores['instruction_specificity'] = self.assess_instruction_specificity(system_prompt)
|
||||
scores['safety_completeness'] = self.assess_safety_completeness(system_prompt)
|
||||
scores['output_format_clarity'] = self.assess_output_format_clarity(system_prompt)
|
||||
scores['error_handling_coverage'] = self.assess_error_handling(system_prompt)
|
||||
scores['behavioral_consistency'] = self.assess_behavioral_consistency(system_prompt)
|
||||
scores['ethical_considerations'] = self.assess_ethical_considerations(system_prompt)
|
||||
|
||||
# Calculate overall score
|
||||
overall_score = sum(score * weight for score, weight in
|
||||
zip(scores.values(), self.validation_criteria.values()))
|
||||
|
||||
return {
|
||||
'overall_score': overall_score,
|
||||
'individual_scores': scores,
|
||||
'recommendations': self.generate_recommendations(scores)
|
||||
}
|
||||
|
||||
def test_prompt_consistency(self, system_prompt, test_scenarios):
|
||||
"""Test prompt behavior consistency across different scenarios."""
|
||||
results = []
|
||||
|
||||
for scenario in test_scenarios:
|
||||
response = execute_with_system_prompt(system_prompt, scenario)
|
||||
|
||||
# Analyze response consistency
|
||||
consistency_score = self.analyze_response_consistency(response, system_prompt)
|
||||
results.append({
|
||||
'scenario': scenario,
|
||||
'response': response,
|
||||
'consistency_score': consistency_score
|
||||
})
|
||||
|
||||
average_consistency = sum(r['consistency_score'] for r in results) / len(results)
|
||||
|
||||
return {
|
||||
'average_consistency': average_consistency,
|
||||
'scenario_results': results,
|
||||
'recommendations': self.generate_consistency_recommendations(results)
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices Summary
|
||||
|
||||
### Design Principles
|
||||
- **Clarity First**: Ensure role and instructions are unambiguous
|
||||
- **Comprehensive Coverage**: Address all aspects of model behavior
|
||||
- **Consistency Focus**: Maintain consistent behavior across scenarios
|
||||
- **Safety Priority**: Include robust safety guidelines and constraints
|
||||
- **Flexibility Built-in**: Allow for adaptation to different contexts
|
||||
|
||||
### Common Pitfalls to Avoid
|
||||
- **Vague Instructions**: Be specific about expected behaviors
|
||||
- **Over-constraining**: Allow room for intelligent adaptation
|
||||
- **Missing Safety Guidelines**: Always include comprehensive safety measures
|
||||
- **Inconsistent Formatting**: Use consistent structure throughout
|
||||
- **Ignoring Model Capabilities**: Design prompts that leverage model strengths
|
||||
|
||||
This comprehensive system prompt design framework provides the foundation for creating effective, reliable, and safe AI system behaviors across diverse applications and use cases.
|
||||
599
skills/ai/prompt-engineering/references/template-systems.md
Normal file
599
skills/ai/prompt-engineering/references/template-systems.md
Normal file
@@ -0,0 +1,599 @@
|
||||
# Template Systems Architecture
|
||||
|
||||
This reference provides comprehensive frameworks for building modular, reusable prompt templates with variable interpolation, conditional sections, and hierarchical composition.
|
||||
|
||||
## Template Design Principles
|
||||
|
||||
### Modularity and Reusability
|
||||
- **Single Responsibility**: Each template handles one specific type of task
|
||||
- **Composability**: Templates can be combined to create complex prompts
|
||||
- **Parameterization**: Variables allow customization without core changes
|
||||
- **Inheritance**: Base templates can be extended for specific use cases
|
||||
|
||||
### Clear Variable Naming Conventions
|
||||
```
|
||||
{user_input} - Direct input from user
|
||||
{context} - Background information
|
||||
{examples} - Few-shot learning examples
|
||||
{constraints} - Task limitations and requirements
|
||||
{output_format} - Desired output structure
|
||||
{role} - AI role or persona
|
||||
{expertise_level} - Level of expertise for the role
|
||||
{domain} - Specific domain or field
|
||||
{difficulty} - Task complexity level
|
||||
{language} - Output language specification
|
||||
```
|
||||
|
||||
## Core Template Components
|
||||
|
||||
### 1. Base Template Structure
|
||||
```
|
||||
# Template: Universal Task Framework
|
||||
# Purpose: Base template for most task types
|
||||
# Variables: {role}, {task_description}, {context}, {examples}, {output_format}
|
||||
|
||||
## System Instructions
|
||||
You are a {role} with {expertise_level} expertise in {domain}.
|
||||
|
||||
## Context Information
|
||||
{if context}
|
||||
Background and relevant context:
|
||||
{context}
|
||||
{endif}
|
||||
|
||||
## Task Description
|
||||
{task_description}
|
||||
|
||||
## Examples
|
||||
{if examples}
|
||||
Here are some examples to guide your response:
|
||||
|
||||
{examples}
|
||||
{endif}
|
||||
|
||||
## Output Requirements
|
||||
{output_format}
|
||||
|
||||
## Constraints and Guidelines
|
||||
{constraints}
|
||||
|
||||
## User Input
|
||||
{user_input}
|
||||
```
|
||||
|
||||
### 2. Conditional Sections Framework
|
||||
```python
|
||||
def process_conditional_template(template, variables):
|
||||
"""
|
||||
Process template with conditional sections.
|
||||
"""
|
||||
# Process if/endif blocks
|
||||
while '{if ' in template:
|
||||
start = template.find('{if ')
|
||||
end_condition = template.find('}', start)
|
||||
condition = template[start+4:end_condition].strip()
|
||||
|
||||
start_endif = template.find('{endif}', end_condition)
|
||||
if_content = template[end_condition+1:start_endif].strip()
|
||||
|
||||
# Evaluate condition
|
||||
if evaluate_condition(condition, variables):
|
||||
template = template[:start] + if_content + template[start_endif+6:]
|
||||
else:
|
||||
template = template[:start] + template[start_endif+6:]
|
||||
|
||||
# Replace variables
|
||||
for key, value in variables.items():
|
||||
template = template.replace(f'{{{key}}}', str(value))
|
||||
|
||||
return template
|
||||
```
|
||||
|
||||
### 3. Variable Interpolation System
|
||||
```python
|
||||
class TemplateEngine:
|
||||
def __init__(self):
|
||||
self.variables = {}
|
||||
self.functions = {
|
||||
'upper': str.upper,
|
||||
'lower': str.lower,
|
||||
'capitalize': str.capitalize,
|
||||
'pluralize': self.pluralize,
|
||||
'format_date': self.format_date,
|
||||
'truncate': self.truncate
|
||||
}
|
||||
|
||||
def set_variable(self, name, value):
|
||||
"""Set a template variable."""
|
||||
self.variables[name] = value
|
||||
|
||||
def render(self, template):
|
||||
"""Render template with variable substitution."""
|
||||
# Process function calls {variable|function}
|
||||
template = self.process_functions(template)
|
||||
|
||||
# Replace variables
|
||||
for key, value in self.variables.items():
|
||||
template = template.replace(f'{{{key}}}', str(value))
|
||||
|
||||
return template
|
||||
|
||||
def process_functions(self, template):
|
||||
"""Process template functions."""
|
||||
import re
|
||||
pattern = r'\{(\w+)\|(\w+)\}'
|
||||
|
||||
def replace_function(match):
|
||||
var_name, func_name = match.groups()
|
||||
value = self.variables.get(var_name, '')
|
||||
if func_name in self.functions:
|
||||
return self.functions[func_name](str(value))
|
||||
return value
|
||||
|
||||
return re.sub(pattern, replace_function, template)
|
||||
```
|
||||
|
||||
## Specialized Template Types
|
||||
|
||||
### 1. Classification Template
|
||||
```
|
||||
# Template: Multi-Class Classification
|
||||
# Purpose: Classify inputs into predefined categories
|
||||
# Required Variables: {input_text}, {categories}, {role}
|
||||
|
||||
## Classification Framework
|
||||
You are a {role} specializing in accurate text classification.
|
||||
|
||||
## Classification Categories
|
||||
{categories}
|
||||
|
||||
## Classification Process
|
||||
1. Analyze the input text carefully
|
||||
2. Identify key indicators and features
|
||||
3. Match against category definitions
|
||||
4. Select the most appropriate category
|
||||
5. Provide confidence score
|
||||
|
||||
## Input to Classify
|
||||
{input_text}
|
||||
|
||||
## Output Format
|
||||
```json
|
||||
{{
|
||||
"category": "selected_category",
|
||||
"confidence": 0.95,
|
||||
"reasoning": "Brief explanation of classification logic",
|
||||
"key_indicators": ["indicator1", "indicator2"]
|
||||
}}
|
||||
```
|
||||
```
|
||||
|
||||
### 2. Transformation Template
|
||||
```
|
||||
# Template: Text Transformation
|
||||
# Purpose: Transform text from one format/style to another
|
||||
# Required Variables: {source_text}, {target_format}, {transformation_rules}
|
||||
|
||||
## Transformation Task
|
||||
Transform the given {source_format} text into {target_format} following these rules:
|
||||
{transformation_rules}
|
||||
|
||||
## Source Text
|
||||
{source_text}
|
||||
|
||||
## Transformation Process
|
||||
1. Analyze the structure and content of the source text
|
||||
2. Apply the specified transformation rules
|
||||
3. Maintain the core meaning and intent
|
||||
4. Ensure proper {target_format} formatting
|
||||
5. Verify completeness and accuracy
|
||||
|
||||
## Transformed Output
|
||||
```
|
||||
|
||||
### 3. Generation Template
|
||||
```
|
||||
# Template: Creative Generation
|
||||
# Purpose: Generate creative content based on specifications
|
||||
# Required Variables: {content_type}, {specifications}, {style_guidelines}
|
||||
|
||||
## Creative Generation Task
|
||||
Generate {content_type} that meets the following specifications:
|
||||
|
||||
## Content Specifications
|
||||
{specifications}
|
||||
|
||||
## Style Guidelines
|
||||
{style_guidelines}
|
||||
|
||||
## Quality Requirements
|
||||
- Originality and creativity
|
||||
- Adherence to specifications
|
||||
- Appropriate tone and style
|
||||
- Clear structure and coherence
|
||||
- Audience-appropriate language
|
||||
|
||||
## Generated Content
|
||||
```
|
||||
|
||||
### 4. Analysis Template
|
||||
```
|
||||
# Template: Comprehensive Analysis
|
||||
# Purpose: Perform detailed analysis of given input
|
||||
# Required Variables: {input_data}, {analysis_framework}, {focus_areas}
|
||||
|
||||
## Analysis Framework
|
||||
You are an expert analyst with deep expertise in {domain}.
|
||||
|
||||
## Analysis Scope
|
||||
Focus on these key areas:
|
||||
{focus_areas}
|
||||
|
||||
## Analysis Methodology
|
||||
{analysis_framework}
|
||||
|
||||
## Input Data for Analysis
|
||||
{input_data}
|
||||
|
||||
## Analysis Process
|
||||
1. Initial assessment and context understanding
|
||||
2. Detailed examination of each focus area
|
||||
3. Pattern and trend identification
|
||||
4. Comparative analysis with benchmarks
|
||||
5. Insight generation and recommendation formulation
|
||||
|
||||
## Analysis Output Structure
|
||||
```yaml
|
||||
executive_summary:
|
||||
key_findings: []
|
||||
overall_assessment: ""
|
||||
|
||||
detailed_analysis:
|
||||
{focus_area_1}:
|
||||
observations: []
|
||||
patterns: []
|
||||
insights: []
|
||||
{focus_area_2}:
|
||||
observations: []
|
||||
patterns: []
|
||||
insights: []
|
||||
|
||||
recommendations:
|
||||
immediate: []
|
||||
short_term: []
|
||||
long_term: []
|
||||
```
|
||||
|
||||
## Advanced Template Patterns
|
||||
|
||||
### 1. Hierarchical Template Composition
|
||||
```python
|
||||
class HierarchicalTemplate:
|
||||
def __init__(self, name, content, parent=None):
|
||||
self.name = name
|
||||
self.content = content
|
||||
self.parent = parent
|
||||
self.children = []
|
||||
self.variables = {}
|
||||
|
||||
def add_child(self, child_template):
|
||||
"""Add a child template."""
|
||||
child_template.parent = self
|
||||
self.children.append(child_template)
|
||||
|
||||
def render(self, variables=None):
|
||||
"""Render template with inherited variables."""
|
||||
# Combine variables from parent hierarchy
|
||||
combined_vars = {}
|
||||
|
||||
# Collect variables from parents
|
||||
current = self.parent
|
||||
while current:
|
||||
combined_vars.update(current.variables)
|
||||
current = current.parent
|
||||
|
||||
# Add current variables
|
||||
combined_vars.update(self.variables)
|
||||
|
||||
# Override with provided variables
|
||||
if variables:
|
||||
combined_vars.update(variables)
|
||||
|
||||
# Render content
|
||||
rendered_content = self.render_content(self.content, combined_vars)
|
||||
|
||||
# Render children
|
||||
for child in self.children:
|
||||
child_rendered = child.render(combined_vars)
|
||||
rendered_content = rendered_content.replace(
|
||||
f'{{child:{child.name}}}', child_rendered
|
||||
)
|
||||
|
||||
return rendered_content
|
||||
```
|
||||
|
||||
### 2. Role-Based Template System
|
||||
```python
|
||||
class RoleBasedTemplate:
|
||||
def __init__(self):
|
||||
self.roles = {
|
||||
'analyst': {
|
||||
'persona': 'You are a professional analyst with expertise in data interpretation and pattern recognition.',
|
||||
'approach': 'systematic',
|
||||
'output_style': 'detailed and evidence-based',
|
||||
'verification': 'Always cross-check findings and cite sources'
|
||||
},
|
||||
'creative_writer': {
|
||||
'persona': 'You are a creative writer with a talent for engaging storytelling and vivid descriptions.',
|
||||
'approach': 'imaginative',
|
||||
'output_style': 'descriptive and engaging',
|
||||
'verification': 'Ensure narrative consistency and flow'
|
||||
},
|
||||
'technical_expert': {
|
||||
'persona': 'You are a technical expert with deep knowledge of {domain} and practical implementation experience.',
|
||||
'approach': 'methodical',
|
||||
'output_style': 'precise and technical',
|
||||
'verification': 'Include technical accuracy and best practices'
|
||||
}
|
||||
}
|
||||
|
||||
def create_prompt(self, role, task, domain=None):
|
||||
"""Create role-specific prompt template."""
|
||||
role_config = self.roles.get(role, self.roles['analyst'])
|
||||
|
||||
template = f"""
|
||||
## Role Definition
|
||||
{role_config['persona']}
|
||||
|
||||
## Approach
|
||||
Use a {role_config['approach']} approach to this task.
|
||||
|
||||
## Task
|
||||
{task}
|
||||
|
||||
## Output Style
|
||||
{role_config['output_style']}
|
||||
|
||||
## Verification
|
||||
{role_config['verification']}
|
||||
"""
|
||||
|
||||
if domain and '{domain}' in role_config['persona']:
|
||||
template = template.replace('{domain}', domain)
|
||||
|
||||
return template
|
||||
```
|
||||
|
||||
### 3. Dynamic Template Selection
|
||||
```python
|
||||
class DynamicTemplateSelector:
|
||||
def __init__(self):
|
||||
self.templates = {}
|
||||
self.selection_rules = {}
|
||||
|
||||
def register_template(self, name, template, selection_criteria):
|
||||
"""Register a template with selection criteria."""
|
||||
self.templates[name] = template
|
||||
self.selection_rules[name] = selection_criteria
|
||||
|
||||
def select_template(self, task_characteristics):
|
||||
"""Select the most appropriate template based on task characteristics."""
|
||||
best_template = None
|
||||
best_score = 0
|
||||
|
||||
for name, criteria in self.selection_rules.items():
|
||||
score = self.calculate_match_score(task_characteristics, criteria)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_template = name
|
||||
|
||||
return self.templates[best_template] if best_template else None
|
||||
|
||||
def calculate_match_score(self, task_characteristics, criteria):
|
||||
"""Calculate how well task matches template criteria."""
|
||||
score = 0
|
||||
total_weight = 0
|
||||
|
||||
for characteristic, weight in criteria.items():
|
||||
if characteristic in task_characteristics:
|
||||
if task_characteristics[characteristic] == weight['value']:
|
||||
score += weight['weight']
|
||||
total_weight += weight['weight']
|
||||
|
||||
return score / total_weight if total_weight > 0 else 0
|
||||
```
|
||||
|
||||
## Template Implementation Examples
|
||||
|
||||
### Example 1: Customer Service Template
|
||||
```python
|
||||
customer_service_template = """
|
||||
# Customer Service Response Template
|
||||
|
||||
## Role Definition
|
||||
You are a {customer_service_role} with {experience_level} of customer service experience in {industry}.
|
||||
|
||||
## Context
|
||||
{if customer_history}
|
||||
Customer History:
|
||||
{customer_history}
|
||||
{endif}
|
||||
|
||||
{if issue_context}
|
||||
Issue Context:
|
||||
{issue_context}
|
||||
{endif}
|
||||
|
||||
## Response Guidelines
|
||||
- Maintain {tone} tone throughout
|
||||
- Address all aspects of the customer's inquiry
|
||||
- Provide {level_of_detail} explanation
|
||||
- Include {additional_elements}
|
||||
- Follow company {communication_style} style
|
||||
|
||||
## Customer Inquiry
|
||||
{customer_inquiry}
|
||||
|
||||
## Response Structure
|
||||
1. Greeting and acknowledgment
|
||||
2. Understanding and empathy
|
||||
3. Solution or explanation
|
||||
4. Additional assistance offered
|
||||
5. Professional closing
|
||||
|
||||
## Response
|
||||
"""
|
||||
```
|
||||
|
||||
### Example 2: Technical Documentation Template
|
||||
```python
|
||||
documentation_template = """
|
||||
# Technical Documentation Generator
|
||||
|
||||
## Role Definition
|
||||
You are a {technical_writer_role} specializing in {technology} documentation with {experience_level} of experience.
|
||||
|
||||
## Documentation Standards
|
||||
- Target audience: {audience_level}
|
||||
- Technical depth: {technical_depth}
|
||||
- Include examples: {include_examples}
|
||||
- Add troubleshooting: {add_troubleshooting}
|
||||
- Version: {version}
|
||||
|
||||
## Content to Document
|
||||
{content_to_document}
|
||||
|
||||
## Documentation Structure
|
||||
```markdown
|
||||
# {title}
|
||||
|
||||
## Overview
|
||||
{overview}
|
||||
|
||||
## Prerequisites
|
||||
{prerequisites}
|
||||
|
||||
## {main_sections}
|
||||
|
||||
## Examples
|
||||
{if include_examples}
|
||||
{examples}
|
||||
{endif}
|
||||
|
||||
## Troubleshooting
|
||||
{if add_troubleshooting}
|
||||
{troubleshooting}
|
||||
{endif}
|
||||
|
||||
## Additional Resources
|
||||
{additional_resources}
|
||||
```
|
||||
|
||||
## Generated Documentation
|
||||
"""
|
||||
```
|
||||
|
||||
## Template Management System
|
||||
|
||||
### Version Control Integration
|
||||
```python
|
||||
class TemplateVersionManager:
|
||||
def __init__(self):
|
||||
self.versions = {}
|
||||
self.current_versions = {}
|
||||
|
||||
def create_version(self, template_name, template_content, author, description):
|
||||
"""Create a new version of a template."""
|
||||
import datetime
|
||||
import hashlib
|
||||
|
||||
version_id = hashlib.md5(template_content.encode()).hexdigest()[:8]
|
||||
timestamp = datetime.datetime.now().isoformat()
|
||||
|
||||
version_info = {
|
||||
'version_id': version_id,
|
||||
'content': template_content,
|
||||
'author': author,
|
||||
'description': description,
|
||||
'timestamp': timestamp,
|
||||
'parent_version': self.current_versions.get(template_name)
|
||||
}
|
||||
|
||||
if template_name not in self.versions:
|
||||
self.versions[template_name] = []
|
||||
|
||||
self.versions[template_name].append(version_info)
|
||||
self.current_versions[template_name] = version_id
|
||||
|
||||
return version_id
|
||||
|
||||
def rollback(self, template_name, version_id):
|
||||
"""Rollback to a specific version."""
|
||||
if template_name in self.versions:
|
||||
for version in self.versions[template_name]:
|
||||
if version['version_id'] == version_id:
|
||||
self.current_versions[template_name] = version_id
|
||||
return version['content']
|
||||
return None
|
||||
```
|
||||
|
||||
### Performance Monitoring
|
||||
```python
|
||||
class TemplatePerformanceMonitor:
|
||||
def __init__(self):
|
||||
self.usage_stats = {}
|
||||
self.performance_metrics = {}
|
||||
|
||||
def track_usage(self, template_name, execution_time, success):
|
||||
"""Track template usage and performance."""
|
||||
if template_name not in self.usage_stats:
|
||||
self.usage_stats[template_name] = {
|
||||
'usage_count': 0,
|
||||
'total_time': 0,
|
||||
'success_count': 0,
|
||||
'failure_count': 0
|
||||
}
|
||||
|
||||
stats = self.usage_stats[template_name]
|
||||
stats['usage_count'] += 1
|
||||
stats['total_time'] += execution_time
|
||||
|
||||
if success:
|
||||
stats['success_count'] += 1
|
||||
else:
|
||||
stats['failure_count'] += 1
|
||||
|
||||
def get_performance_report(self, template_name):
|
||||
"""Generate performance report for a template."""
|
||||
if template_name not in self.usage_stats:
|
||||
return None
|
||||
|
||||
stats = self.usage_stats[template_name]
|
||||
avg_time = stats['total_time'] / stats['usage_count']
|
||||
success_rate = stats['success_count'] / stats['usage_count']
|
||||
|
||||
return {
|
||||
'template_name': template_name,
|
||||
'total_usage': stats['usage_count'],
|
||||
'average_execution_time': avg_time,
|
||||
'success_rate': success_rate,
|
||||
'failure_rate': 1 - success_rate
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Template Quality Guidelines
|
||||
- **Clear Documentation**: Include purpose, variables, and usage examples
|
||||
- **Consistent Naming**: Use standardized variable naming conventions
|
||||
- **Error Handling**: Include fallback mechanisms for missing variables
|
||||
- **Performance Optimization**: Minimize template complexity and rendering time
|
||||
- **Testing**: Implement comprehensive template testing frameworks
|
||||
|
||||
### Security Considerations
|
||||
- **Input Validation**: Sanitize all template variables
|
||||
- **Injection Prevention**: Prevent code injection in template rendering
|
||||
- **Access Control**: Implement proper authorization for template modifications
|
||||
- **Audit Trail**: Track template changes and usage
|
||||
|
||||
This comprehensive template system architecture provides the foundation for building scalable, maintainable prompt templates that can be efficiently managed and optimized across diverse use cases.
|
||||
286
skills/ai/rag/SKILL.md
Normal file
286
skills/ai/rag/SKILL.md
Normal file
@@ -0,0 +1,286 @@
|
||||
---
|
||||
name: rag-implementation
|
||||
description: Build Retrieval-Augmented Generation (RAG) systems for AI applications with vector databases and semantic search. Use when implementing knowledge-grounded AI, building document Q&A systems, or integrating LLMs with external knowledge bases.
|
||||
allowed-tools: Read, Write, Bash
|
||||
category: ai-engineering
|
||||
tags: [rag, vector-databases, embeddings, retrieval, semantic-search]
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# RAG Implementation
|
||||
|
||||
Build Retrieval-Augmented Generation systems that extend AI capabilities with external knowledge sources.
|
||||
|
||||
## Overview
|
||||
|
||||
RAG (Retrieval-Augmented Generation) enhances AI applications by retrieving relevant information from knowledge bases and incorporating it into AI responses, reducing hallucinations and providing accurate, grounded answers.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use this skill when:
|
||||
|
||||
- Building Q&A systems over proprietary documents
|
||||
- Creating chatbots with current, factual information
|
||||
- Implementing semantic search with natural language queries
|
||||
- Reducing hallucinations with grounded responses
|
||||
- Enabling AI systems to access domain-specific knowledge
|
||||
- Building documentation assistants
|
||||
- Creating research tools with source citation
|
||||
- Developing knowledge management systems
|
||||
|
||||
## Core Components
|
||||
|
||||
### Vector Databases
|
||||
Store and efficiently retrieve document embeddings for semantic search.
|
||||
|
||||
**Key Options:**
|
||||
- **Pinecone**: Managed, scalable, production-ready
|
||||
- **Weaviate**: Open-source, hybrid search capabilities
|
||||
- **Milvus**: High performance, on-premise deployment
|
||||
- **Chroma**: Lightweight, easy local development
|
||||
- **Qdrant**: Fast, advanced filtering
|
||||
- **FAISS**: Meta's library, full control
|
||||
|
||||
### Embedding Models
|
||||
Convert text to numerical vectors for similarity search.
|
||||
|
||||
**Popular Models:**
|
||||
- **text-embedding-ada-002** (OpenAI): General purpose, 1536 dimensions
|
||||
- **all-MiniLM-L6-v2**: Fast, lightweight, 384 dimensions
|
||||
- **e5-large-v2**: High quality, multilingual
|
||||
- **bge-large-en-v1.5**: State-of-the-art performance
|
||||
|
||||
### Retrieval Strategies
|
||||
Find relevant content based on user queries.
|
||||
|
||||
**Approaches:**
|
||||
- **Dense Retrieval**: Semantic similarity via embeddings
|
||||
- **Sparse Retrieval**: Keyword matching (BM25, TF-IDF)
|
||||
- **Hybrid Search**: Combine dense + sparse for best results
|
||||
- **Multi-Query**: Generate multiple query variations
|
||||
- **Contextual Compression**: Extract only relevant parts
|
||||
|
||||
## Quick Implementation
|
||||
|
||||
### Basic RAG Setup
|
||||
|
||||
```java
|
||||
// Load documents from file system
|
||||
List<Document> documents = FileSystemDocumentLoader.loadDocuments("/path/to/docs");
|
||||
|
||||
// Create embedding store
|
||||
InMemoryEmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
|
||||
|
||||
// Ingest documents into the store
|
||||
EmbeddingStoreIngestor.ingest(documents, embeddingStore);
|
||||
|
||||
// Create AI service with RAG capability
|
||||
Assistant assistant = AiServices.builder(Assistant.class)
|
||||
.chatModel(chatModel)
|
||||
.chatMemory(MessageWindowChatMemory.withMaxMessages(10))
|
||||
.contentRetriever(EmbeddingStoreContentRetriever.from(embeddingStore))
|
||||
.build();
|
||||
```
|
||||
|
||||
### Document Processing Pipeline
|
||||
|
||||
```java
|
||||
// Split documents into chunks
|
||||
DocumentSplitter splitter = new RecursiveCharacterTextSplitter(
|
||||
500, // chunk size
|
||||
100 // overlap
|
||||
);
|
||||
|
||||
// Create embedding model
|
||||
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey("your-api-key")
|
||||
.build();
|
||||
|
||||
// Create embedding store
|
||||
EmbeddingStore<TextSegment> embeddingStore = PgVectorEmbeddingStore.builder()
|
||||
.host("localhost")
|
||||
.database("postgres")
|
||||
.user("postgres")
|
||||
.password("password")
|
||||
.table("embeddings")
|
||||
.dimension(1536)
|
||||
.build();
|
||||
|
||||
// Process and store documents
|
||||
for (Document document : documents) {
|
||||
List<TextSegment> segments = splitter.split(document);
|
||||
for (TextSegment segment : segments) {
|
||||
Embedding embedding = embeddingModel.embed(segment).content();
|
||||
embeddingStore.add(embedding, segment);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Implementation Patterns
|
||||
|
||||
### Pattern 1: Simple Document Q&A
|
||||
|
||||
Create a basic Q&A system over your documents.
|
||||
|
||||
```java
|
||||
public interface DocumentAssistant {
|
||||
String answer(String question);
|
||||
}
|
||||
|
||||
DocumentAssistant assistant = AiServices.builder(DocumentAssistant.class)
|
||||
.chatModel(chatModel)
|
||||
.contentRetriever(retriever)
|
||||
.build();
|
||||
```
|
||||
|
||||
### Pattern 2: Metadata-Filtered Retrieval
|
||||
|
||||
Filter results based on document metadata.
|
||||
|
||||
```java
|
||||
// Add metadata during document loading
|
||||
Document document = Document.builder()
|
||||
.text("Content here")
|
||||
.metadata("source", "technical-manual.pdf")
|
||||
.metadata("category", "technical")
|
||||
.metadata("date", "2024-01-15")
|
||||
.build();
|
||||
|
||||
// Filter during retrieval
|
||||
EmbeddingStoreContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
|
||||
.embeddingStore(embeddingStore)
|
||||
.embeddingModel(embeddingModel)
|
||||
.maxResults(5)
|
||||
.minScore(0.7)
|
||||
.filter(metadataKey("category").isEqualTo("technical"))
|
||||
.build();
|
||||
```
|
||||
|
||||
### Pattern 3: Multi-Source Retrieval
|
||||
|
||||
Combine results from multiple knowledge sources.
|
||||
|
||||
```java
|
||||
ContentRetriever webRetriever = EmbeddingStoreContentRetriever.from(webStore);
|
||||
ContentRetriever documentRetriever = EmbeddingStoreContentRetriever.from(documentStore);
|
||||
ContentRetriever databaseRetriever = EmbeddingStoreContentRetriever.from(databaseStore);
|
||||
|
||||
// Combine results
|
||||
List<Content> allResults = new ArrayList<>();
|
||||
allResults.addAll(webRetriever.retrieve(query));
|
||||
allResults.addAll(documentRetriever.retrieve(query));
|
||||
allResults.addAll(databaseRetriever.retrieve(query));
|
||||
|
||||
// Rerank combined results
|
||||
List<Content> rerankedResults = reranker.reorder(query, allResults);
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Document Preparation
|
||||
- Clean and preprocess documents before ingestion
|
||||
- Remove irrelevant content and formatting artifacts
|
||||
- Standardize document structure for consistent processing
|
||||
- Add relevant metadata for filtering and context
|
||||
|
||||
### Chunking Strategy
|
||||
- Use 500-1000 tokens per chunk for optimal balance
|
||||
- Include 10-20% overlap to preserve context at boundaries
|
||||
- Consider document structure when determining chunk boundaries
|
||||
- Test different chunk sizes for your specific use case
|
||||
|
||||
### Retrieval Optimization
|
||||
- Start with high k values (10-20) then filter/rerank
|
||||
- Use metadata filtering to improve relevance
|
||||
- Combine multiple retrieval strategies for better coverage
|
||||
- Monitor retrieval quality and user feedback
|
||||
|
||||
### Performance Considerations
|
||||
- Cache embeddings for frequently accessed content
|
||||
- Use batch processing for document ingestion
|
||||
- Optimize vector store configuration for your scale
|
||||
- Monitor query performance and system resources
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### Poor Retrieval Quality
|
||||
**Problem**: Retrieved documents don't match user queries
|
||||
**Solutions**:
|
||||
- Improve document preprocessing and cleaning
|
||||
- Adjust chunk size and overlap parameters
|
||||
- Try different embedding models
|
||||
- Use hybrid search combining semantic and keyword matching
|
||||
|
||||
### Irrelevant Results
|
||||
**Problem**: Retrieved documents contain relevant information but are not specific enough
|
||||
**Solutions**:
|
||||
- Add metadata filtering for domain-specific constraints
|
||||
- Implement reranking with cross-encoder models
|
||||
- Use contextual compression to extract relevant parts
|
||||
- Fine-tune retrieval parameters (k values, similarity thresholds)
|
||||
|
||||
### Performance Issues
|
||||
**Problem**: Slow response times during retrieval
|
||||
**Solutions**:
|
||||
- Optimize vector store configuration and indexing
|
||||
- Implement caching for frequently retrieved content
|
||||
- Use smaller embedding models for faster inference
|
||||
- Consider approximate nearest neighbor algorithms
|
||||
|
||||
### Hallucination Prevention
|
||||
**Problem**: AI generates information not present in retrieved documents
|
||||
**Solutions**:
|
||||
- Improve prompt engineering to emphasize grounding
|
||||
- Add verification steps to check answer alignment
|
||||
- Include confidence scoring for responses
|
||||
- Implement fact-checking mechanisms
|
||||
|
||||
## Evaluation Framework
|
||||
|
||||
### Retrieval Metrics
|
||||
- **Precision@k**: Percentage of relevant documents in top-k results
|
||||
- **Recall@k**: Percentage of all relevant documents found in top-k results
|
||||
- **Mean Reciprocal Rank (MRR)**: Average rank of first relevant result
|
||||
- **Normalized Discounted Cumulative Gain (nDCG)**: Ranking quality metric
|
||||
|
||||
### Answer Quality Metrics
|
||||
- **Faithfulness**: Degree to which answers are grounded in retrieved documents
|
||||
- **Answer Relevance**: How well answers address user questions
|
||||
- **Context Recall**: Percentage of relevant context used in answers
|
||||
- **Context Precision**: Percentage of retrieved context that is relevant
|
||||
|
||||
### User Experience Metrics
|
||||
- **Response Time**: Time from query to answer
|
||||
- **User Satisfaction**: Feedback ratings on answer quality
|
||||
- **Task Completion**: Rate of successful task completion
|
||||
- **Engagement**: User interaction patterns with the system
|
||||
|
||||
## Resources
|
||||
|
||||
### Reference Documentation
|
||||
- [Vector Database Comparison](references/vector-databases.md) - Detailed comparison of vector database options
|
||||
- [Embedding Models Guide](references/embedding-models.md) - Model selection and optimization
|
||||
- [Retrieval Strategies](references/retrieval-strategies.md) - Advanced retrieval techniques
|
||||
- [Document Chunking](references/document-chunking.md) - Chunking strategies and best practices
|
||||
- [LangChain4j RAG Guide](references/langchain4j-rag-guide.md) - Official implementation patterns
|
||||
|
||||
### Assets
|
||||
- `assets/vector-store-config.yaml` - Configuration templates for different vector stores
|
||||
- `assets/retriever-pipeline.java` - Complete RAG pipeline implementation
|
||||
- `assets/evaluation-metrics.java` - Evaluation framework code
|
||||
|
||||
## Constraints and Limitations
|
||||
|
||||
1. **Token Limits**: Respect model context window limitations
|
||||
2. **API Rate Limits**: Manage external API rate limits and costs
|
||||
3. **Data Privacy**: Ensure compliance with data protection regulations
|
||||
4. **Resource Requirements**: Consider memory and computational requirements
|
||||
5. **Maintenance**: Plan for regular updates and system monitoring
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- Secure access to vector databases and embedding services
|
||||
- Implement proper authentication and authorization
|
||||
- Validate and sanitize user inputs
|
||||
- Monitor for abuse and unusual usage patterns
|
||||
- Regular security audits and penetration testing
|
||||
307
skills/ai/rag/assets/retriever-pipeline.java
Normal file
307
skills/ai/rag/assets/retriever-pipeline.java
Normal file
@@ -0,0 +1,307 @@
package com.example.rag;

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import dev.langchain4j.store.embedding.pinecone.PineconeEmbeddingStore;
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
import dev.langchain4j.store.embedding.qdrant.QdrantEmbeddingStore;

import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;

import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Complete RAG Pipeline Implementation
 *
 * This class provides a comprehensive implementation of a RAG (Retrieval-Augmented Generation)
 * system with support for multiple vector stores and advanced retrieval strategies.
 */
public class RAGPipeline {

    private final EmbeddingModel embeddingModel;
    private final EmbeddingStore<TextSegment> embeddingStore;
    private final DocumentSplitter documentSplitter;
    private final RAGConfig config;

    /**
     * Configuration class for RAG pipeline
     */
    public static class RAGConfig {
        private String vectorStoreType = "chroma";
        private String openAiApiKey;
        private String pineconeApiKey;
        private String pineconeEnvironment;
        private String pineconeIndex = "rag-documents";
        private String chromaCollection = "rag-documents";
        private String chromaPersistPath = "./chroma_db";
        private String qdrantHost = "localhost";
        private int qdrantPort = 6333;
        private String qdrantCollection = "rag-documents";
        private int chunkSize = 1000;
        private int chunkOverlap = 200;
        private int embeddingDimension = 1536;

        // Getters and setters
        public String getVectorStoreType() { return vectorStoreType; }
        public void setVectorStoreType(String vectorStoreType) { this.vectorStoreType = vectorStoreType; }
        public String getOpenAiApiKey() { return openAiApiKey; }
        public void setOpenAiApiKey(String openAiApiKey) { this.openAiApiKey = openAiApiKey; }
        public String getPineconeApiKey() { return pineconeApiKey; }
        public void setPineconeApiKey(String pineconeApiKey) { this.pineconeApiKey = pineconeApiKey; }
        public String getPineconeEnvironment() { return pineconeEnvironment; }
        public void setPineconeEnvironment(String pineconeEnvironment) { this.pineconeEnvironment = pineconeEnvironment; }
        public String getPineconeIndex() { return pineconeIndex; }
        public void setPineconeIndex(String pineconeIndex) { this.pineconeIndex = pineconeIndex; }
        public String getChromaCollection() { return chromaCollection; }
        public void setChromaCollection(String chromaCollection) { this.chromaCollection = chromaCollection; }
        public String getChromaPersistPath() { return chromaPersistPath; }
        public void setChromaPersistPath(String chromaPersistPath) { this.chromaPersistPath = chromaPersistPath; }
        public String getQdrantHost() { return qdrantHost; }
        public void setQdrantHost(String qdrantHost) { this.qdrantHost = qdrantHost; }
        public int getQdrantPort() { return qdrantPort; }
        public void setQdrantPort(int qdrantPort) { this.qdrantPort = qdrantPort; }
        public String getQdrantCollection() { return qdrantCollection; }
        public void setQdrantCollection(String qdrantCollection) { this.qdrantCollection = qdrantCollection; }
        public int getChunkSize() { return chunkSize; }
        public void setChunkSize(int chunkSize) { this.chunkSize = chunkSize; }
        public int getChunkOverlap() { return chunkOverlap; }
        public void setChunkOverlap(int chunkOverlap) { this.chunkOverlap = chunkOverlap; }
        public int getEmbeddingDimension() { return embeddingDimension; }
        public void setEmbeddingDimension(int embeddingDimension) { this.embeddingDimension = embeddingDimension; }
    }

    /**
     * Constructor
     */
    public RAGPipeline(RAGConfig config) {
        this.config = config;
        this.embeddingModel = createEmbeddingModel();
        this.embeddingStore = createEmbeddingStore();
        this.documentSplitter = createDocumentSplitter();
    }

    /**
     * Create embedding model based on configuration
     */
    private EmbeddingModel createEmbeddingModel() {
        return OpenAiEmbeddingModel.builder()
                .apiKey(config.getOpenAiApiKey())
                .modelName("text-embedding-ada-002")
                .build();
    }

    /**
     * Create embedding store based on configuration
     */
    private EmbeddingStore<TextSegment> createEmbeddingStore() {
        switch (config.getVectorStoreType().toLowerCase()) {
            case "pinecone":
                return PineconeEmbeddingStore.builder()
                        .apiKey(config.getPineconeApiKey())
                        .environment(config.getPineconeEnvironment())
                        .index(config.getPineconeIndex())
                        .dimension(config.getEmbeddingDimension())
                        .build();

            case "chroma":
                return ChromaEmbeddingStore.builder()
                        .collectionName(config.getChromaCollection())
                        .persistDirectory(config.getChromaPersistPath())
                        .build();

            case "qdrant":
                return QdrantEmbeddingStore.builder()
                        .host(config.getQdrantHost())
                        .port(config.getQdrantPort())
                        .collectionName(config.getQdrantCollection())
                        .dimension(config.getEmbeddingDimension())
                        .build();

            case "memory":
            default:
                return new InMemoryEmbeddingStore<>();
        }
    }

    /**
     * Create document splitter using LangChain4j's recursive splitter
     * (paragraphs, then sentences, then words)
     */
    private DocumentSplitter createDocumentSplitter() {
        return DocumentSplitters.recursive(
                config.getChunkSize(),
                config.getChunkOverlap()
        );
    }

    /**
     * Load documents from directory
     */
    public List<Document> loadDocuments(String directoryPath) {
        try {
            Path directory = Paths.get(directoryPath);
            List<Document> documents = FileSystemDocumentLoader.loadDocuments(directory);

            // Documents are immutable, so collect enriched copies into a new list;
            // reassigning the loop variable would not update the original list
            List<Document> enriched = new ArrayList<>(documents.size());
            for (Document document : documents) {
                Map<String, Object> metadata = new HashMap<>(document.metadata().toMap());
                metadata.put("loaded_at", System.currentTimeMillis());
                metadata.put("source_directory", directoryPath);
                enriched.add(Document.from(document.text(), Metadata.from(metadata)));
            }

            return enriched;
        } catch (Exception e) {
            throw new RuntimeException("Failed to load documents from " + directoryPath, e);
        }
    }

    /**
     * Process and ingest documents
     */
    public void ingestDocuments(List<Document> documents) {
        // Split documents into segments
        List<TextSegment> segments = documentSplitter.splitAll(documents);

        // Add additional metadata to segments
        for (int i = 0; i < segments.size(); i++) {
            TextSegment segment = segments.get(i);
            Map<String, Object> metadata = new HashMap<>(segment.metadata().toMap());
            metadata.put("segment_index", i);
            metadata.put("total_segments", segments.size());
            metadata.put("processed_at", System.currentTimeMillis());

            segments.set(i, TextSegment.from(segment.text(), Metadata.from(metadata)));
        }

        // Embed the segments and store them alongside their embeddings
        List<Embedding> embeddings = embeddingModel.embedAll(segments).content();
        embeddingStore.addAll(embeddings, segments);

        System.out.println("Ingested " + documents.size() + " documents into " +
                segments.size() + " segments");
    }

    /**
     * Search documents with optional filtering (filter may be null)
     */
    public List<TextSegment> search(String query, int maxResults, Filter filter) {
        Embedding queryEmbedding = embeddingModel.embed(query).content();

        EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
                .queryEmbedding(queryEmbedding)
                .maxResults(maxResults)
                .filter(filter)
                .build();

        return embeddingStore.search(request).matches().stream()
                .map(EmbeddingMatch::embedded)
                .collect(Collectors.toList());
    }

    /**
     * Search documents with metadata filtering
     */
    public List<TextSegment> searchWithMetadataFilter(String query, int maxResults,
                                                      Map<String, Object> metadataFilters) {
        Filter filter = null;

        if (metadataFilters != null && !metadataFilters.isEmpty()) {
            for (Map.Entry<String, Object> entry : metadataFilters.entrySet()) {
                String key = entry.getKey();
                Object value = entry.getValue();

                Filter condition;
                if (value instanceof String) {
                    condition = metadataKey(key).isEqualTo((String) value);
                } else if (value instanceof Number) {
                    condition = metadataKey(key).isEqualTo(((Number) value).doubleValue());
                } else {
                    // Add more type handling as needed
                    continue;
                }

                // Combine individual conditions with logical AND
                filter = (filter == null) ? condition : filter.and(condition);
            }
        }

        return search(query, maxResults, filter);
    }

    /**
     * Get statistics about the stored documents
     */
    public RAGStatistics getStatistics() {
        // This is a simplified implementation
        // In practice, you might want to track more detailed statistics
        return new RAGStatistics(
                config.getVectorStoreType(),
                embeddingStore.getClass().getSimpleName()
        );
    }

    /**
     * Statistics holder class
     */
    public static class RAGStatistics {
        private final String storeType;
        private final String implementation;

        public RAGStatistics(String storeType, String implementation) {
            this.storeType = storeType;
            this.implementation = implementation;
        }

        public String getStoreType() { return storeType; }
        public String getImplementation() { return implementation; }

        @Override
        public String toString() {
            return "RAGStatistics{" +
                    "storeType='" + storeType + '\'' +
                    ", implementation='" + implementation + '\'' +
                    '}';
        }
    }

    /**
     * Example usage
     */
    public static void main(String[] args) {
        // Configure the pipeline
        RAGConfig config = new RAGConfig();
        config.setVectorStoreType("chroma"); // or "pinecone", "qdrant", "memory"
        config.setOpenAiApiKey("your-openai-api-key");
        config.setChunkSize(1000);
        config.setChunkOverlap(200);

        // Create pipeline
        RAGPipeline pipeline = new RAGPipeline(config);

        // Load documents
        List<Document> documents = pipeline.loadDocuments("./documents");

        // Ingest documents
        pipeline.ingestDocuments(documents);

        // Search for relevant content
        List<TextSegment> results = pipeline.search("What is machine learning?", 5, null);

        // Print results
        for (int i = 0; i < results.size(); i++) {
            TextSegment segment = results.get(i);
            System.out.println("Result " + (i + 1) + ":");
            System.out.println("Content: " + segment.text().substring(0, Math.min(200, segment.text().length())) + "...");
            System.out.println("Metadata: " + segment.metadata());
            System.out.println();
        }

        // Print statistics
        System.out.println("Pipeline Statistics: " + pipeline.getStatistics());
    }
}
127
skills/ai/rag/assets/vector-store-config.yaml
Normal file
@@ -0,0 +1,127 @@
# Vector Store Configuration Templates
# This file contains configuration templates for different vector databases

# Chroma (Local/Development)
chroma:
  type: chroma
  settings:
    persist_directory: "./chroma_db"
    collection_name: "rag_documents"
    host: "localhost"
    port: 8000

# Recommended for: Development, small-scale applications
# Pros: Easy setup, local deployment, free
# Cons: Limited scalability, single-node only

# Pinecone (Cloud/Production)
pinecone:
  type: pinecone
  settings:
    api_key: "${PINECONE_API_KEY}"
    environment: "us-west1-gcp"
    index_name: "rag-documents"
    dimension: 1536
    metric: "cosine"
    pods: 1
    pod_type: "p1.x1"

# Recommended for: Production applications, large-scale
# Pros: Managed service, scalable, fast
# Cons: Cost, requires internet connection

# Weaviate (Open-source/Cloud)
weaviate:
  type: weaviate
  settings:
    url: "http://localhost:8080"
    api_key: "${WEAVIATE_API_KEY}"
    class_name: "Document"
    text_key: "content"
    vectorizer: "text2vec-openai"
    module_config:
      text2vec-openai:
        model: "ada"
        modelVersion: "002"
        type: "text"
        baseUrl: "https://api.openai.com/v1"

# Recommended for: Hybrid search, GraphQL API
# Pros: Open-source, hybrid search, flexible
# Cons: More complex setup

# Qdrant (Performance-focused)
qdrant:
  type: qdrant
  settings:
    host: "localhost"
    port: 6333
    collection_name: "rag_documents"
    vector_size: 1536
    distance: "Cosine"
    api_key: "${QDRANT_API_KEY}"

# Recommended for: Performance, advanced filtering
# Pros: Fast, good filtering, open-source
# Cons: Newer project, smaller community

# Milvus (Enterprise/Scale)
milvus:
  type: milvus
  settings:
    host: "localhost"
    port: 19530
    collection_name: "rag_documents"
    dimension: 1536
    index_type: "IVF_FLAT"
    metric_type: "COSINE"
    nlist: 1024

# Recommended for: Enterprise, large-scale deployments
# Pros: High performance, distributed
# Cons: Complex setup, resource intensive

# FAISS (Local/Research)
faiss:
  type: faiss
  settings:
    index_type: "IndexFlatL2"
    dimension: 1536
    save_path: "./faiss_index"

# Recommended for: Research, local processing
# Pros: Fast, local, no external dependencies
# Cons: Library only (no server); persistence must be handled manually via save_path

# Common Configuration Parameters
common:
  chunking:
    chunk_size: 1000
    chunk_overlap: 200
    separators: ["\n\n", "\n", " ", ""]

  embedding:
    model: "text-embedding-ada-002"
    batch_size: 100
    max_retries: 3
    timeout: 30

  retrieval:
    default_k: 5
    similarity_threshold: 0.7
    max_results: 20

  performance:
    cache_embeddings: true
    cache_size: 1000
    parallel_processing: true
    batch_size: 50

# Environment Variables Template
# Copy this to a .env file and fill in your values
environment:
  OPENAI_API_KEY: "your-openai-api-key-here"
  PINECONE_API_KEY: "your-pinecone-api-key-here"
  PINECONE_ENVIRONMENT: "us-west1-gcp"
  WEAVIATE_API_KEY: "your-weaviate-api-key-here"
  QDRANT_API_KEY: "your-qdrant-api-key-here"
137
skills/ai/rag/references/document-chunking.md
Normal file
@@ -0,0 +1,137 @@
# Document Chunking Strategies

## Overview
Document chunking is the process of breaking large documents into smaller, manageable pieces that can be effectively embedded and retrieved.

## Chunking Strategies

### 1. Recursive Character Text Splitter
**Method**: Split text based on character count, trying separators in order
**Use Case**: General-purpose text splitting
**Advantages**: Preserves sentence and paragraph boundaries when possible

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
    length_function=len,
    separators=["\n\n", "\n", " ", ""]  # Try these in order
)
chunks = splitter.split_documents(documents)
```

### 2. Token-Based Splitting
**Method**: Split based on token count rather than characters
**Use Case**: When working with token limits of language models
**Advantages**: Better control over context window usage

```python
from langchain.text_splitter import TokenTextSplitter

splitter = TokenTextSplitter(
    chunk_size=512,
    chunk_overlap=50
)
chunks = splitter.split_documents(documents)
```

### 3. Semantic Chunking
**Method**: Split based on semantic similarity
**Use Case**: When maintaining semantic coherence is important
**Advantages**: Chunks are more semantically meaningful

```python
from langchain_experimental.text_splitter import SemanticChunker
from langchain.embeddings import OpenAIEmbeddings

splitter = SemanticChunker(
    embeddings=OpenAIEmbeddings(),
    breakpoint_threshold_type="percentile"
)
chunks = splitter.split_documents(documents)
```

### 4. Markdown Header Splitter
**Method**: Split based on markdown headers
**Use Case**: Structured documents with clear hierarchical organization
**Advantages**: Maintains document structure and context

```python
from langchain.text_splitter import MarkdownHeaderTextSplitter

headers_to_split_on = [
    ("#", "Header 1"),
    ("##", "Header 2"),
    ("###", "Header 3"),
]

splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
chunks = splitter.split_text(markdown_text)  # operates on a raw markdown string
```

### 5. HTML Splitter
**Method**: Split based on HTML tags
**Use Case**: Web pages and HTML documents
**Advantages**: Preserves HTML structure and metadata

```python
from langchain.text_splitter import HTMLHeaderTextSplitter

headers_to_split_on = [
    ("h1", "Header 1"),
    ("h2", "Header 2"),
    ("h3", "Header 3"),
]

splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
chunks = splitter.split_text(html_text)  # operates on a raw HTML string
```

## Parameter Tuning

### Chunk Size
- **Small chunks (200-400 tokens)**: More precise retrieval, but may lose context
- **Medium chunks (500-1000 tokens)**: Good balance of precision and context
- **Large chunks (1000-2000 tokens)**: More context, but less precise retrieval

### Chunk Overlap
- **Purpose**: Preserve context at chunk boundaries
- **Typical range**: 10-20% of chunk size
- **Higher overlap**: Better context preservation, but more redundancy
- **Lower overlap**: Less redundancy, but may lose important context

### Separators
- **Hierarchical separators**: Start with larger boundaries (paragraphs), then smaller (sentences)
- **Custom separators**: Add domain-specific separators for better results
- **Language-specific**: Adjust for different languages and writing styles

## Best Practices

1. **Preserve Context**: Ensure chunks contain enough surrounding context
2. **Maintain Coherence**: Keep semantically related content together
3. **Respect Boundaries**: Avoid breaking sentences or important phrases
4. **Consider Query Types**: Adapt chunking strategy to typical user queries
5. **Test and Iterate**: Evaluate different chunking strategies for your specific use case

## Evaluation Metrics

1. **Retrieval Quality**: How well chunks answer user queries
2. **Context Preservation**: Whether important context is maintained
3. **Chunk Distribution**: Evenness of chunk sizes (see the sketch below)
4. **Boundary Quality**: How natural chunk boundaries are
5. **Retrieval Efficiency**: Impact on retrieval speed and accuracy
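To check metric 3 quickly, the following minimal sketch (assuming `chunks` is a list of LangChain `Document` objects produced by any of the splitters above) reports basic size statistics:

```python
import statistics

def chunk_size_report(chunks):
    """Summarize chunk sizes to spot overly uneven distributions."""
    sizes = [len(c.page_content) for c in chunks]
    return {
        "count": len(sizes),
        "mean": statistics.mean(sizes),
        "stdev": statistics.stdev(sizes) if len(sizes) > 1 else 0.0,
        "min": min(sizes),
        "max": max(sizes),
    }
```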
## Advanced Techniques

### Adaptive Chunking
Adjust chunk size based on document structure and content density (see the sketch at the end of this section).

### Hierarchical Chunking
Create multiple levels of chunks for different retrieval scenarios.

### Query-Aware Chunking
Optimize chunk boundaries based on typical query patterns.

### Domain-Specific Splitting
Use specialized splitters for specific document types (legal, medical, technical).
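As an illustration of adaptive chunking, here is a minimal sketch; the density heuristic (average word length) is an assumption for demonstration, not an established rule:

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

def adaptive_split(documents, dense_size=500, sparse_size=1500):
    """Pick a smaller chunk size for dense text, a larger one for sparse text."""
    chunks = []
    for doc in documents:
        words = doc.page_content.split()
        avg_word_len = sum(len(w) for w in words) / max(len(words), 1)
        # Crude density proxy: longer average words suggest denser, technical prose
        size = dense_size if avg_word_len > 6 else sparse_size
        splitter = RecursiveCharacterTextSplitter(chunk_size=size, chunk_overlap=size // 5)
        chunks.extend(splitter.split_documents([doc]))
    return chunks
```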
88
skills/ai/rag/references/embedding-models.md
Normal file
@@ -0,0 +1,88 @@
# Embedding Models Guide

## Overview
Embedding models convert text into numerical vectors that capture semantic meaning for similarity search in RAG systems.

## Popular Embedding Models

### 1. text-embedding-ada-002 (OpenAI)
- **Dimensions**: 1536
- **Type**: General purpose
- **Use Case**: Most applications requiring high-quality embeddings
- **Performance**: Excellent balance of quality and speed

### 2. all-MiniLM-L6-v2 (Sentence Transformers)
- **Dimensions**: 384
- **Type**: Lightweight
- **Use Case**: Applications requiring fast inference
- **Performance**: Good quality, very fast

### 3. e5-large-v2
- **Dimensions**: 1024
- **Type**: High quality
- **Use Case**: Applications needing superior performance
- **Performance**: Excellent quality, multilingual support

### 4. Instructor
- **Dimensions**: Variable (768)
- **Type**: Task-specific
- **Use Case**: Domain-specific applications
- **Performance**: Can be fine-tuned for specific tasks

### 5. bge-large-en-v1.5
- **Dimensions**: 1024
- **Type**: State-of-the-art
- **Use Case**: Applications requiring the best possible quality
- **Performance**: SOTA performance on benchmarks

## Selection Criteria

1. **Quality vs Speed**: Balance between embedding quality and inference speed
2. **Dimension Size**: Impact on storage and retrieval performance
3. **Domain**: Specific language or domain requirements
4. **Cost**: API costs vs local deployment
5. **Batch Size**: Throughput requirements
6. **Language**: Multilingual support needs

## Usage Examples

### OpenAI Embeddings
```python
from langchain.embeddings import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
vector = embeddings.embed_query("Your text here")
```

### Sentence Transformers
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')
vector = model.encode("Your text here")
```

### Hugging Face Models
```python
from langchain.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
```

## Optimization Tips

1. **Batch Processing**: Process multiple texts together for efficiency (see the sketch below)
2. **Model Quantization**: Reduce model size for faster inference
3. **Caching**: Cache embeddings for frequently used texts
4. **GPU Acceleration**: Use GPU for faster processing when available
5. **Model Selection**: Choose appropriate model size for your use case
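The following is a minimal sketch combining tips 1 and 3: it batches uncached texts and memoizes results in an in-memory dict (a stand-in for a real cache), assuming sentence-transformers is installed:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
_cache = {}  # in-memory stand-in for a persistent embedding cache

def embed_batch(texts, batch_size=64):
    """Embed texts in batches, skipping anything already cached."""
    missing = [t for t in texts if t not in _cache]
    if missing:
        vectors = model.encode(missing, batch_size=batch_size)
        _cache.update(zip(missing, vectors))
    return [_cache[t] for t in texts]
```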
## Evaluation Metrics

1. **Semantic Similarity**: How well embeddings capture meaning
2. **Retrieval Performance**: Quality of retrieved documents
3. **Speed**: Inference time per document
4. **Memory Usage**: RAM requirements for the model
5. **Cost**: API costs or infrastructure requirements
94
skills/ai/rag/references/langchain4j-rag-guide.md
Normal file
@@ -0,0 +1,94 @@
# LangChain4j RAG Implementation Guide

## Overview
RAG (Retrieval-Augmented Generation) extends LLM knowledge by finding relevant information in your data and injecting it into prompts before they are sent to the LLM.

## What is RAG?
RAG helps LLMs answer questions using domain-specific knowledge by retrieving relevant information, which reduces hallucinations.

## RAG Flavors in LangChain4j

### 1. Easy RAG
Simplest way to start with minimal setup. Handles document loading, splitting, and embedding automatically.

### 2. Core RAG APIs
Modular components including:
- Document
- TextSegment
- EmbeddingModel
- EmbeddingStore
- DocumentSplitter

### 3. Advanced RAG
Complex pipelines supporting:
- Query transformation
- Multi-source retrieval
- Re-ranking with components like QueryTransformer and ContentRetriever

## RAG Stages

### 1. Indexing
Pre-process documents for efficient search

### 2. Retrieval
Find relevant content based on user queries

## Core Components

### Documents with metadata
Structured representation of your content with associated metadata for filtering and context.

### Text segments (chunks)
Smaller, manageable pieces of documents that are embedded and stored in vector databases.

### Embedding models
Convert text segments into numerical vectors for similarity search.

### Embedding stores (vector databases)
Store and efficiently retrieve embedded text segments.

### Content retrievers
Find relevant content based on user queries.

### Query transformers
Transform and optimize user queries for better retrieval.

### Content aggregators
Combine and rank retrieved content.

## Advanced Features

- Query transformation and routing
- Multiple retrievers for different data sources
- Re-ranking models for improved relevance
- Metadata filtering for targeted retrieval
- Parallel processing for performance

## Implementation Example (Easy RAG)

```java
// Load documents
List<Document> documents = FileSystemDocumentLoader.loadDocuments("/path/to/docs");

// Create embedding store
InMemoryEmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();

// Ingest documents
EmbeddingStoreIngestor.ingest(documents, embeddingStore);

// Create AI service; Assistant is your own interface,
// e.g. interface Assistant { String chat(String userMessage); }
Assistant assistant = AiServices.builder(Assistant.class)
        .chatModel(chatModel)
        .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
        .contentRetriever(EmbeddingStoreContentRetriever.from(embeddingStore))
        .build();
```

## Best Practices

1. **Document Preparation**: Clean and structure documents before ingestion
2. **Chunk Size**: Balance between context preservation and retrieval precision
3. **Metadata Strategy**: Include relevant metadata for filtering and context
4. **Embedding Model Selection**: Choose models appropriate for your domain
5. **Retrieval Strategy**: Select appropriate k values and filtering criteria
6. **Evaluation**: Continuously evaluate retrieval quality and answer accuracy
161
skills/ai/rag/references/retrieval-strategies.md
Normal file
@@ -0,0 +1,161 @@
# Advanced Retrieval Strategies

## Overview
Different retrieval approaches for finding relevant documents in RAG systems, each with specific strengths and use cases.

## Retrieval Approaches

### 1. Dense Retrieval
**Method**: Semantic similarity via embeddings
**Use Case**: Understanding meaning and context
**Example**: Finding documents about "machine learning" when the query is "AI algorithms"

```python
from langchain.vectorstores import Chroma

vectorstore = Chroma.from_documents(chunks, embeddings)
results = vectorstore.similarity_search("query", k=5)
```

### 2. Sparse Retrieval
**Method**: Keyword matching (BM25, TF-IDF)
**Use Case**: Exact term matching and keyword-specific queries
**Example**: Finding documents containing specific technical terms

```python
from langchain.retrievers import BM25Retriever

bm25_retriever = BM25Retriever.from_documents(chunks)
bm25_retriever.k = 5
results = bm25_retriever.get_relevant_documents("query")
```

### 3. Hybrid Search
**Method**: Combine dense + sparse retrieval
**Use Case**: Balance between semantic understanding and keyword matching

```python
from langchain.retrievers import BM25Retriever, EnsembleRetriever

# Sparse retriever (BM25)
bm25_retriever = BM25Retriever.from_documents(chunks)
bm25_retriever.k = 5

# Dense retriever (embeddings)
embedding_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

# Combine with weights
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, embedding_retriever],
    weights=[0.3, 0.7]
)
```

### 4. Multi-Query Retrieval
**Method**: Generate multiple query variations
**Use Case**: Complex queries that can be interpreted in multiple ways

```python
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.llms import OpenAI

# Generate multiple query perspectives
retriever = MultiQueryRetriever.from_llm(
    retriever=vectorstore.as_retriever(),
    llm=OpenAI()
)

# Single query → multiple variations → combined results
results = retriever.get_relevant_documents("What is the main topic?")
```

### 5. HyDE (Hypothetical Document Embeddings)
**Method**: Generate hypothetical documents for better retrieval
**Use Case**: When queries are very different from document style

```python
# Generate a hypothetical document based on the query,
# then retrieve using the hypothetical document instead of the raw query
hypothetical_doc = llm.predict(f"Write a document about: {query}")
results = vectorstore.similarity_search(hypothetical_doc, k=5)
```

## Advanced Retrieval Patterns

### Contextual Compression
Compress retrieved documents to only include relevant parts

```python
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor

compressor = LLMChainExtractor.from_llm(llm)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=vectorstore.as_retriever()
)
```

### Parent Document Retriever
Store small chunks for retrieval, return larger chunks for context

```python
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore

store = InMemoryStore()
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)

retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter
)
```

## Retrieval Optimization Techniques

### 1. Metadata Filtering
Filter results based on document metadata

```python
results = vectorstore.similarity_search(
    "query",
    filter={"category": "technical", "date": {"$gte": "2023-01-01"}},
    k=5
)
```

### 2. Maximal Marginal Relevance (MMR)
Balance relevance with diversity

```python
results = vectorstore.max_marginal_relevance_search(
    "query",
    k=5,
    fetch_k=20,
    lambda_mult=0.5  # 0 = max diversity, 1 = max relevance
)
```

### 3. Reranking
Improve top results with a cross-encoder

```python
from sentence_transformers import CrossEncoder

query = "your query here"
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
candidates = vectorstore.similarity_search(query, k=20)
pairs = [[query, doc.page_content] for doc in candidates]
scores = reranker.predict(pairs)
reranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)[:5]
```

## Selection Guidelines

1. **Query Type**: Choose strategy based on typical query patterns
2. **Document Type**: Consider document structure and content
3. **Performance Requirements**: Balance quality vs speed
4. **Domain Knowledge**: Leverage domain-specific patterns
5. **User Expectations**: Match retrieval behavior to user expectations
86
skills/ai/rag/references/vector-databases.md
Normal file
@@ -0,0 +1,86 @@
# Vector Database Comparison and Configuration

## Overview
Vector databases store and efficiently retrieve document embeddings for semantic search in RAG systems.

## Popular Vector Database Options

### 1. Pinecone
- **Type**: Managed cloud service
- **Features**: Scalable, fast queries, managed infrastructure
- **Use Case**: Production applications requiring high availability

### 2. Weaviate
- **Type**: Open-source, hybrid search
- **Features**: Combines vector and keyword search, GraphQL API
- **Use Case**: Applications needing both semantic and traditional search

### 3. Milvus
- **Type**: High performance, on-premise
- **Features**: Distributed architecture, GPU acceleration
- **Use Case**: Large-scale deployments with custom infrastructure

### 4. Chroma
- **Type**: Lightweight, easy to use
- **Features**: Local deployment, simple API
- **Use Case**: Development and small-scale applications

### 5. Qdrant
- **Type**: Fast, filtered search
- **Features**: Advanced filtering, payload support
- **Use Case**: Applications requiring complex metadata filtering

### 6. FAISS
- **Type**: Meta's library, local deployment
- **Features**: High performance, CPU/GPU optimized
- **Use Case**: Research and applications needing full control

## Configuration Examples

### Pinecone Setup
```python
import pinecone
from langchain.vectorstores import Pinecone

pinecone.init(api_key="your-api-key", environment="us-west1-gcp")
index = pinecone.Index("your-index-name")
vectorstore = Pinecone(index, embeddings.embed_query, "text")
```

### Weaviate Setup
```python
import weaviate
from langchain.vectorstores import Weaviate

client = weaviate.Client("http://localhost:8080")
vectorstore = Weaviate(client, "Document", "content", embeddings)
```

### Chroma Local Setup
```python
from langchain.vectorstores import Chroma

vectorstore = Chroma(
    collection_name="my_collection",
    embedding_function=embeddings,
    persist_directory="./chroma_db"
)
```

## Selection Criteria

1. **Scale**: Number of documents and expected query volume
2. **Performance**: Latency requirements and throughput needs
3. **Deployment**: Cloud vs on-premise preferences
4. **Features**: Filtering, hybrid search, metadata support
5. **Cost**: Budget constraints and operational overhead
6. **Maintenance**: Team expertise and available resources

## Best Practices

1. **Indexing Strategy**: Choose appropriate distance metrics (cosine, euclidean); see the sketch below for why this choice matters
2. **Sharding**: Distribute data for large-scale deployments
3. **Monitoring**: Track query performance and system health
4. **Backups**: Implement regular backup procedures
5. **Security**: Secure access to sensitive data
6. **Optimization**: Tune parameters for your specific use case
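To make best practice 1 concrete, here is a minimal numpy illustration (an example for this guide, not part of any vector database's API): on vectors pointing in the same direction but with different magnitudes, cosine similarity and euclidean distance disagree, which is why the metric should match how your embeddings were trained and whether they are normalized.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 4.0, 6.0])  # same direction as a, twice the magnitude

cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
euclidean = np.linalg.norm(a - b)

print(f"cosine similarity: {cosine:.3f}")    # 1.000 -> identical by cosine
print(f"euclidean distance: {euclidean:.3f}")  # 3.742 -> far apart by L2
```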
396
skills/aws-java/aws-rds-spring-boot-integration/SKILL.md
Normal file
@@ -0,0 +1,396 @@
---
name: aws-rds-spring-boot-integration
description: Configure AWS RDS (Aurora, MySQL, PostgreSQL) with Spring Boot applications. Use when setting up datasources, connection pooling, security, and production-ready database configuration.
category: aws
tags: [aws, rds, aurora, spring-boot, spring-data-jpa, datasource, configuration, hikari, mysql, postgresql]
version: 1.1.0
allowed-tools: Read, Write, Bash, Glob
---

# AWS RDS Spring Boot Integration

Configure AWS RDS databases (Aurora, MySQL, PostgreSQL) with Spring Boot applications for production-ready connectivity.

## When to Use This Skill

Use this skill when:
- Setting up AWS RDS Aurora with Spring Data JPA
- Configuring datasource properties for Aurora, MySQL, or PostgreSQL endpoints
- Implementing HikariCP connection pooling for RDS
- Setting up environment-specific configurations (dev/prod)
- Configuring SSL connections to AWS RDS
- Troubleshooting RDS connection issues
- Setting up database migrations with Flyway
- Integrating with AWS Secrets Manager for credential management
- Optimizing connection pool settings for RDS workloads
- Implementing read/write split with Aurora

## Prerequisites

Before starting AWS RDS Spring Boot integration:
1. AWS account with RDS access
2. Spring Boot project (3.x)
3. RDS instance created and running (Aurora/MySQL/PostgreSQL)
4. Security group configured for database access
5. Database endpoint information available
6. Database credentials secured (environment variables or Secrets Manager)

## Quick Start

### Step 1: Add Dependencies

**Maven (pom.xml):**
```xml
<dependencies>
    <!-- Spring Data JPA -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-jpa</artifactId>
    </dependency>

    <!-- Aurora MySQL Driver -->
    <dependency>
        <groupId>com.mysql</groupId>
        <artifactId>mysql-connector-j</artifactId>
        <version>8.2.0</version>
        <scope>runtime</scope>
    </dependency>

    <!-- Aurora PostgreSQL Driver (alternative) -->
    <dependency>
        <groupId>org.postgresql</groupId>
        <artifactId>postgresql</artifactId>
        <scope>runtime</scope>
    </dependency>

    <!-- Flyway for database migrations -->
    <dependency>
        <groupId>org.flywaydb</groupId>
        <artifactId>flyway-core</artifactId>
    </dependency>

    <!-- Validation -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-validation</artifactId>
    </dependency>
</dependencies>
```

**Gradle (build.gradle):**
```gradle
dependencies {
    implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
    implementation 'org.springframework.boot:spring-boot-starter-validation'

    // Aurora MySQL
    runtimeOnly 'com.mysql:mysql-connector-j:8.2.0'

    // Aurora PostgreSQL (alternative)
    runtimeOnly 'org.postgresql:postgresql'

    // Flyway
    implementation 'org.flywaydb:flyway-core'
}
```

### Step 2: Basic Datasource Configuration

**application.properties (Aurora MySQL):**
```properties
# Aurora MySQL Datasource - Cluster Endpoint
spring.datasource.url=jdbc:mysql://myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:3306/devops
spring.datasource.username=admin
spring.datasource.password=${DB_PASSWORD}
spring.datasource.driver-class-name=com.mysql.cj.jdbc.Driver

# JPA/Hibernate Configuration
spring.jpa.hibernate.ddl-auto=validate
spring.jpa.show-sql=false
spring.jpa.properties.hibernate.dialect=org.hibernate.dialect.MySQL8Dialect
spring.jpa.properties.hibernate.format_sql=true
spring.jpa.open-in-view=false

# HikariCP Connection Pool
spring.datasource.hikari.maximum-pool-size=20
spring.datasource.hikari.minimum-idle=5
spring.datasource.hikari.connection-timeout=20000
spring.datasource.hikari.idle-timeout=300000
spring.datasource.hikari.max-lifetime=1200000

# Flyway Configuration
spring.flyway.enabled=true
spring.flyway.baseline-on-migrate=true
spring.flyway.locations=classpath:db/migration
```

**application.properties (Aurora PostgreSQL):**
```properties
# Aurora PostgreSQL Datasource
spring.datasource.url=jdbc:postgresql://myapp-aurora-pg-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:5432/devops
spring.datasource.username=admin
spring.datasource.password=${DB_PASSWORD}
spring.datasource.driver-class-name=org.postgresql.Driver

# JPA/Hibernate Configuration
spring.jpa.hibernate.ddl-auto=validate
spring.jpa.show-sql=false
spring.jpa.properties.hibernate.dialect=org.hibernate.dialect.PostgreSQLDialect
spring.jpa.properties.hibernate.jdbc.lob.non_contextual_creation=true
spring.jpa.open-in-view=false
```

### Step 3: Set Up Environment Variables

```bash
# Production environment variables (single quotes prevent history expansion of "!")
export DB_PASSWORD='YourStrongPassword123!'
export SPRING_PROFILES_ACTIVE=prod

# For development
export SPRING_PROFILES_ACTIVE=dev
```

## Configuration Examples

### Simple Aurora Cluster (MySQL)

**application.yml:**
```yaml
spring:
  application:
    name: DevOps

  datasource:
    url: jdbc:mysql://myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:3306/devops
    username: admin
    password: ${DB_PASSWORD}
    driver-class-name: com.mysql.cj.jdbc.Driver

    hikari:
      pool-name: AuroraHikariPool
      maximum-pool-size: 20
      minimum-idle: 5
      connection-timeout: 20000
      idle-timeout: 300000
      max-lifetime: 1200000
      leak-detection-threshold: 60000
      connection-test-query: SELECT 1

  jpa:
    hibernate:
      ddl-auto: validate
    show-sql: false
    open-in-view: false
    properties:
      hibernate:
        dialect: org.hibernate.dialect.MySQL8Dialect
        format_sql: true
        jdbc:
          batch_size: 20
        order_inserts: true
        order_updates: true

  flyway:
    enabled: true
    baseline-on-migrate: true
    locations: classpath:db/migration
    validate-on-migrate: true

logging:
  level:
    org.hibernate.SQL: WARN
    com.zaxxer.hikari: INFO
```

### Read/Write Split Configuration

For read-heavy workloads, use separate writer and reader datasources:

**application.properties:**
```properties
# Aurora MySQL - Writer Endpoint
spring.datasource.writer.jdbc-url=jdbc:mysql://myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:3306/devops
spring.datasource.writer.username=admin
spring.datasource.writer.password=${DB_PASSWORD}
spring.datasource.writer.driver-class-name=com.mysql.cj.jdbc.Driver

# Aurora MySQL - Reader Endpoint (Read Replicas)
spring.datasource.reader.jdbc-url=jdbc:mysql://myapp-aurora-cluster.cluster-ro-abc123xyz.us-east-1.rds.amazonaws.com:3306/devops
spring.datasource.reader.username=admin
spring.datasource.reader.password=${DB_PASSWORD}
spring.datasource.reader.driver-class-name=com.mysql.cj.jdbc.Driver

# HikariCP for Writer
spring.datasource.writer.hikari.maximum-pool-size=15
spring.datasource.writer.hikari.minimum-idle=5

# HikariCP for Reader
spring.datasource.reader.hikari.maximum-pool-size=25
spring.datasource.reader.hikari.minimum-idle=10
```

### SSL Configuration

**Aurora MySQL with SSL:**
```properties
spring.datasource.url=jdbc:mysql://myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:3306/devops?useSSL=true&requireSSL=true&verifyServerCertificate=true
```

**Aurora PostgreSQL with SSL:**
```properties
spring.datasource.url=jdbc:postgresql://myapp-aurora-pg-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:5432/devops?ssl=true&sslmode=require
```

## Environment-Specific Configuration

### Development Profile

**application-dev.properties:**
```properties
# Local MySQL for development
spring.datasource.url=jdbc:mysql://localhost:3306/devops_dev
spring.datasource.username=root
spring.datasource.password=root

# Enable DDL auto-update in development
spring.jpa.hibernate.ddl-auto=update
spring.jpa.show-sql=true

# Smaller connection pool for local dev
spring.datasource.hikari.maximum-pool-size=5
spring.datasource.hikari.minimum-idle=2
```

### Production Profile

**application-prod.properties:**
```properties
# Aurora Cluster Endpoint (Production)
spring.datasource.url=jdbc:mysql://${AURORA_ENDPOINT}:3306/${DB_NAME}
spring.datasource.username=${DB_USERNAME}
spring.datasource.password=${DB_PASSWORD}

# Validate schema only in production
spring.jpa.hibernate.ddl-auto=validate
spring.jpa.show-sql=false
spring.jpa.open-in-view=false

# Production-optimized connection pool
spring.datasource.hikari.maximum-pool-size=30
spring.datasource.hikari.minimum-idle=10
spring.datasource.hikari.connection-timeout=20000
spring.datasource.hikari.idle-timeout=300000
spring.datasource.hikari.max-lifetime=1200000

# Enable Flyway migrations
spring.flyway.enabled=true
spring.flyway.validate-on-migrate=true
```

## Database Migration Setup

Create migration files for Flyway:

```
src/main/resources/db/migration/
├── V1__create_users_table.sql
├── V2__add_phone_column.sql
└── V3__create_orders_table.sql
```

**V1__create_users_table.sql:**
```sql
CREATE TABLE users (
    id BIGINT AUTO_INCREMENT PRIMARY KEY,
    name VARCHAR(100) NOT NULL,
    email VARCHAR(255) NOT NULL UNIQUE,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    INDEX idx_email (email)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
```

## Advanced Features

For advanced configuration, see the reference documents:

- [Multi-datasource, SSL, Secrets Manager integration](references/advanced-configuration.md)
- [Common issues and solutions](references/troubleshooting.md)

## Best Practices

### Connection Pool Optimization

- Use HikariCP with Aurora-optimized settings
- Set appropriate pool sizes based on Aurora instance capacity
- Configure connection timeouts for failover handling
- Enable leak detection

### Security Best Practices

- Never hardcode credentials in configuration files
- Use environment variables or AWS Secrets Manager
- Enable SSL/TLS connections
- Configure proper security group rules
- Use IAM Database Authentication when possible

### Performance Optimization

- Enable batch operations for bulk data operations
- Disable the open-in-view pattern to prevent lazy-loading issues
- Use appropriate indexing for Aurora queries
- Configure connection pooling for high availability

### Monitoring

- Enable Spring Boot Actuator for database metrics
- Monitor connection pool metrics
- Set up proper logging for debugging
- Configure health checks for database connectivity

## Testing

Create a health check endpoint to test database connectivity:

```java
@RestController
@RequestMapping("/api/health")
public class DatabaseHealthController {

    @Autowired
    private DataSource dataSource;

    @GetMapping("/db-connection")
    public ResponseEntity<Map<String, Object>> testDatabaseConnection() {
        Map<String, Object> response = new HashMap<>();

        try (Connection connection = dataSource.getConnection()) {
            response.put("status", "success");
            response.put("database", connection.getCatalog());
            response.put("url", connection.getMetaData().getURL());
            response.put("connected", true);
            return ResponseEntity.ok(response);
        } catch (Exception e) {
            response.put("status", "failed");
            response.put("error", e.getMessage());
            response.put("connected", false);
            return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE).body(response);
        }
    }
}
```

**Test with cURL:**
```bash
curl http://localhost:8080/api/health/db-connection
```

## Support

For detailed troubleshooting and advanced configuration, refer to:

- [AWS RDS Aurora Advanced Configuration](references/advanced-configuration.md)
- [AWS RDS Aurora Troubleshooting Guide](references/troubleshooting.md)
- [AWS SDK for Java Aurora code examples](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/java_aurora_code_examples.html)
- [AWS Aurora RDS with Spring Boot (Baeldung)](https://www.baeldung.com/aws-aurora-rds-java)
279
skills/aws-java/aws-rds-spring-boot-integration/references/advanced-configuration.md
Normal file
@@ -0,0 +1,279 @@
|
||||
# AWS RDS Aurora Advanced Configuration
|
||||
|
||||
## Read/Write Split Configuration
|
||||
|
||||
For applications with heavy read operations, configure separate datasources:
|
||||
|
||||
**Multi-Datasource Configuration Class:**
|
||||
```java
|
||||
@Configuration
|
||||
public class AuroraDataSourceConfig {
|
||||
|
||||
@Primary
|
||||
@Bean(name = "writerDataSource")
|
||||
@ConfigurationProperties("spring.datasource.writer")
|
||||
public DataSource writerDataSource() {
|
||||
return DataSourceBuilder.create().build();
|
||||
}
|
||||
|
||||
@Bean(name = "readerDataSource")
|
||||
@ConfigurationProperties("spring.datasource.reader")
|
||||
public DataSource readerDataSource() {
|
||||
return DataSourceBuilder.create().build();
|
||||
}
|
||||
|
||||
@Primary
|
||||
@Bean(name = "writerEntityManagerFactory")
|
||||
public LocalContainerEntityManagerFactoryBean writerEntityManagerFactory(
|
||||
EntityManagerFactoryBuilder builder,
|
||||
@Qualifier("writerDataSource") DataSource dataSource) {
|
||||
return builder
|
||||
.dataSource(dataSource)
|
||||
.packages("com.example.domain")
|
||||
.persistenceUnit("writer")
|
||||
.build();
|
||||
}
|
||||
|
||||
@Bean(name = "readerEntityManagerFactory")
|
||||
public LocalContainerEntityManagerFactoryBean readerEntityManagerFactory(
|
||||
EntityManagerFactoryBuilder builder,
|
||||
@Qualifier("readerDataSource") DataSource dataSource) {
|
||||
return builder
|
||||
.dataSource(dataSource)
|
||||
.packages("com.example.domain")
|
||||
.persistenceUnit("reader")
|
||||
.build();
|
||||
}
|
||||
|
||||
@Primary
|
||||
@Bean(name = "writerTransactionManager")
|
||||
public PlatformTransactionManager writerTransactionManager(
|
||||
@Qualifier("writerEntityManagerFactory") EntityManagerFactory entityManagerFactory) {
|
||||
return new JpaTransactionManager(entityManagerFactory);
|
||||
}
|
||||
|
||||
@Bean(name = "readerTransactionManager")
|
||||
public PlatformTransactionManager readerTransactionManager(
|
||||
@Qualifier("readerEntityManagerFactory") EntityManagerFactory entityManagerFactory) {
|
||||
return new JpaTransactionManager(entityManagerFactory);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Usage in Repository:**
|
||||
```java
|
||||
@Repository
|
||||
public interface UserReadRepository extends JpaRepository<User, Long> {
|
||||
// Read operations automatically use reader endpoint
|
||||
}
|
||||
|
||||
@Repository
|
||||
public interface UserWriteRepository extends JpaRepository<User, Long> {
|
||||
// Write operations use writer endpoint
|
||||
}
|
||||
```
|
||||
|
||||
## SSL/TLS Configuration
|
||||
|
||||
Enable SSL for secure connections to Aurora:
|
||||
|
||||
**Aurora MySQL with SSL:**
|
||||
```properties
|
||||
spring.datasource.url=jdbc:mysql://myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:3306/devops?useSSL=true&requireSSL=true&verifyServerCertificate=true
|
||||
```
|
||||
|
||||
**Aurora PostgreSQL with SSL:**
|
||||
```properties
|
||||
spring.datasource.url=jdbc:postgresql://myapp-aurora-pg-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com:5432/devops?ssl=true&sslmode=require
|
||||
```
|
||||
|
||||
**Download RDS Certificate:**
|
||||
```bash
|
||||
# Download RDS CA certificate
|
||||
wget https://truststore.pki.rds.amazonaws.com/global/global-bundle.pem
|
||||
|
||||
# Configure in application
|
||||
spring.datasource.url=jdbc:mysql://...?useSSL=true&trustCertificateKeyStoreUrl=file:///path/to/global-bundle.pem
|
||||
```
|
||||
|
||||
## AWS Secrets Manager Integration
|
||||
|
||||
**Add AWS SDK Dependency:**
|
||||
```xml
|
||||
<dependency>
|
||||
<groupId>software.amazon.awssdk</groupId>
|
||||
<artifactId>secretsmanager</artifactId>
|
||||
<version>2.20.0</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
**Secrets Manager Configuration:**
|
||||
```java
|
||||
@Configuration
|
||||
public class AuroraDataSourceConfig {
|
||||
|
||||
@Value("${aws.secretsmanager.secret-name}")
|
||||
private String secretName;
|
||||
|
||||
@Value("${aws.region}")
|
||||
private String region;
|
||||
|
||||
@Bean
|
||||
public DataSource dataSource() {
|
||||
Map<String, String> credentials = getAuroraCredentials();
|
||||
|
||||
HikariConfig config = new HikariConfig();
|
||||
config.setJdbcUrl(credentials.get("url"));
|
||||
config.setUsername(credentials.get("username"));
|
||||
config.setPassword(credentials.get("password"));
|
||||
config.setMaximumPoolSize(20);
|
||||
config.setMinimumIdle(5);
|
||||
config.setConnectionTimeout(20000);
|
||||
|
||||
return new HikariDataSource(config);
|
||||
}
|
||||
|
||||
private Map<String, String> getAuroraCredentials() {
|
||||
SecretsManagerClient client = SecretsManagerClient.builder()
|
||||
.region(Region.of(region))
|
||||
.build();
|
||||
|
||||
GetSecretValueRequest request = GetSecretValueRequest.builder()
|
||||
.secretId(secretName)
|
||||
.build();
|
||||
|
||||
GetSecretValueResponse response = client.getSecretValue(request);
|
||||
String secretString = response.secretString();
|
||||
|
||||
// Parse JSON secret
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
try {
|
||||
return mapper.readValue(secretString, Map.class);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to parse secret", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**application.properties (Secrets Manager):**
|
||||
```properties
|
||||
aws.secretsmanager.secret-name=prod/aurora/credentials
|
||||
aws.region=us-east-1
|
||||
```
|
||||
|
||||
## Database Migration with Flyway
|
||||
|
||||
### Setup Flyway
|
||||
|
||||
**Create Migration Directory:**
|
||||
```
|
||||
src/main/resources/db/migration/
|
||||
├── V1__create_users_table.sql
|
||||
├── V2__add_phone_column.sql
|
||||
└── V3__create_orders_table.sql
|
||||
```
|
||||
|
||||
**V1__create_users_table.sql:**
|
||||
```sql
|
||||
CREATE TABLE users (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(100) NOT NULL,
|
||||
email VARCHAR(255) NOT NULL UNIQUE,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
INDEX idx_email (email)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
```
|
||||
|
||||
**V2__add_phone_column.sql:**
|
||||
```sql
|
||||
ALTER TABLE users ADD COLUMN phone VARCHAR(20);
|
||||
```
|
||||
|
||||
**Flyway Configuration:**
|
||||
```properties
|
||||
spring.jpa.hibernate.ddl-auto=validate
|
||||
spring.flyway.enabled=true
|
||||
spring.flyway.baseline-on-migrate=true
|
||||
spring.flyway.locations=classpath:db/migration
|
||||
spring.flyway.validate-on-migrate=true
|
||||
```
|
||||
|
||||
## Connection Pool Optimization for Aurora
|
||||
|
||||
**Recommended HikariCP Settings:**
|
||||
```properties
|
||||
# Aurora-optimized connection pool
|
||||
spring.datasource.hikari.maximum-pool-size=20
|
||||
spring.datasource.hikari.minimum-idle=5
|
||||
spring.datasource.hikari.connection-timeout=20000
|
||||
spring.datasource.hikari.idle-timeout=300000
|
||||
spring.datasource.hikari.max-lifetime=1200000
|
||||
spring.datasource.hikari.leak-detection-threshold=60000
|
||||
spring.datasource.hikari.connection-test-query=SELECT 1
|
||||
```
|
||||
|
||||
**Formula for Pool Size:**
|
||||
```
connections = (core_count * 2) + effective_spindle_count

Example: a 4-core application host on SSD-backed storage
(effective_spindle_count ≈ 1) suggests (4 * 2) + 1 = 9 connections.

For Aurora: use 20-30 connections per application instance as a starting point
```
|
||||
|
||||
## Failover Handling
|
||||
|
||||
Aurora automatically handles failover between instances. Configure connection retry:
|
||||
|
||||
```properties
|
||||
# Connection retry configuration
|
||||
spring.datasource.hikari.connection-timeout=30000
|
||||
spring.datasource.url=jdbc:mysql://cluster-endpoint:3306/db?failOverReadOnly=false&maxReconnects=3&connectTimeout=30000
|
||||
```
|
||||
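
Beyond the driver settings above, transient errors thrown mid-failover can be retried at the application level. A minimal sketch with exponential backoff; the class and method names here are illustrative, not from any framework:

```java
import java.sql.Connection;
import java.sql.SQLException;
import javax.sql.DataSource;

public class FailoverRetryTemplate {

    private static final int MAX_ATTEMPTS = 3;

    public <T> T retryOnFailover(DataSource dataSource, SqlCallback<T> callback) throws SQLException {
        SQLException last = null;
        for (int attempt = 1; attempt <= MAX_ATTEMPTS; attempt++) {
            try (Connection connection = dataSource.getConnection()) {
                return callback.doInConnection(connection);
            } catch (SQLException e) {
                last = e;
                try {
                    // Exponential backoff: 500ms, 1s, 2s — gives cluster DNS time
                    // to flip to the newly promoted writer
                    Thread.sleep(500L * (1L << (attempt - 1)));
                } catch (InterruptedException ie) {
                    Thread.currentThread().interrupt();
                    throw e;
                }
            }
        }
        throw last;
    }

    @FunctionalInterface
    public interface SqlCallback<T> {
        T doInConnection(Connection connection) throws SQLException;
    }
}
```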
|
||||
## Read Replica Load Balancing
|
||||
|
||||
Use reader endpoint for distributing read traffic across replicas:
|
||||
|
||||
```properties
|
||||
# Reader endpoint for read-heavy workloads
|
||||
spring.datasource.reader.url=jdbc:mysql://cluster-ro-endpoint:3306/db
|
||||
```
|
||||
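
Pointing a second `DataSource` at the reader endpoint only helps if read traffic actually reaches it. A hedged sketch of one common approach, Spring's `AbstractRoutingDataSource` keyed on the current transaction's read-only flag; the "writer"/"reader" keys are illustrative, and with JPA the read-only flag is only visible in time if the routing source is wrapped in a `LazyConnectionDataSourceProxy`:

```java
import org.springframework.jdbc.datasource.lookup.AbstractRoutingDataSource;
import org.springframework.transaction.support.TransactionSynchronizationManager;

public class ReadWriteRoutingDataSource extends AbstractRoutingDataSource {

    @Override
    protected Object determineCurrentLookupKey() {
        // Route @Transactional(readOnly = true) work to the reader endpoint
        return TransactionSynchronizationManager.isCurrentTransactionReadOnly()
                ? "reader"
                : "writer";
    }
}
```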
|
||||
## Performance Optimization
|
||||
|
||||
**Enable batch operations:**
|
||||
```properties
|
||||
spring.jpa.properties.hibernate.jdbc.batch_size=20
|
||||
spring.jpa.properties.hibernate.order_inserts=true
|
||||
spring.jpa.properties.hibernate.order_updates=true
|
||||
spring.jpa.properties.hibernate.batch_versioned_data=true
|
||||
```
|
||||
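
Batching only pays off if the persistence context is flushed and cleared in step with the batch size. A hedged sketch of a writer that cooperates with the settings above, reusing the `User` entity from the Flyway migration:

```java
@Service
public class UserBatchWriter {

    private static final int BATCH_SIZE = 20; // keep in sync with hibernate.jdbc.batch_size

    @PersistenceContext
    private EntityManager entityManager;

    @Transactional
    public void saveAll(List<User> users) {
        for (int i = 0; i < users.size(); i++) {
            entityManager.persist(users.get(i));
            if ((i + 1) % BATCH_SIZE == 0) {
                // Flush the pending batch and clear the context so memory stays flat
                entityManager.flush();
                entityManager.clear();
            }
        }
    }
}
```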
|
||||
**Disable open-in-view pattern:**
|
||||
```properties
|
||||
spring.jpa.open-in-view=false
|
||||
```
|
||||
|
||||
**Production logging configuration:**
|
||||
```properties
|
||||
# Disable SQL logging in production
|
||||
logging.level.org.hibernate.SQL=WARN
|
||||
logging.level.org.springframework.jdbc=WARN
|
||||
|
||||
# Enable HikariCP metrics
|
||||
logging.level.com.zaxxer.hikari=INFO
|
||||
logging.level.com.zaxxer.hikari.pool=DEBUG
|
||||
```
|
||||
|
||||
**Enable Spring Boot Actuator for metrics:**
|
||||
```xml
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-actuator</artifactId>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
```properties
|
||||
management.endpoints.web.exposure.include=health,metrics,info
|
||||
management.endpoint.health.show-details=always
|
||||
```
|
||||
@@ -0,0 +1,180 @@
|
||||
# AWS RDS Aurora Troubleshooting Guide
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### Connection Timeout to Aurora Cluster
|
||||
**Error:** `Communications link failure` or `Connection timed out`
|
||||
|
||||
**Solutions:**
|
||||
- Verify security group inbound rules allow traffic on port 3306 (MySQL) or 5432 (PostgreSQL)
|
||||
- Check Aurora cluster endpoint is correct (cluster vs instance endpoint)
|
||||
- Ensure your IP/CIDR is whitelisted in security group
|
||||
- Verify VPC and subnet configuration
|
||||
- Check if Aurora cluster is in the same VPC or VPC peering is configured
|
||||
|
||||
```bash
|
||||
# Test connection from EC2/local machine
|
||||
telnet myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com 3306
|
||||
```
|
||||
|
||||
### Access Denied for User
|
||||
**Error:** `Access denied for user 'admin'@'...'`
|
||||
|
||||
**Solutions:**
|
||||
- Verify master username and password are correct
|
||||
- Check if IAM authentication is required but not configured
|
||||
- Reset master password in Aurora console if needed
|
||||
- Verify user permissions in database
|
||||
|
||||
```sql
|
||||
-- Check user permissions
|
||||
SHOW GRANTS FOR 'admin'@'%';
|
||||
```
|
||||
|
||||
### Database Not Found
|
||||
**Error:** `Unknown database 'devops'`
|
||||
|
||||
**Solutions:**
|
||||
- Verify initial database name was created with cluster
|
||||
- Create database manually using MySQL/PostgreSQL client
|
||||
- Check database name in JDBC URL matches existing database
|
||||
|
||||
```sql
|
||||
-- Connect to Aurora and create database
|
||||
CREATE DATABASE devops CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
```
|
||||
|
||||
### SSL Connection Issues
|
||||
**Error:** `SSL connection error` or `Certificate validation failed`
|
||||
|
||||
**Solutions:**
|
||||
```properties
|
||||
# Option 1: Disable SSL verification (NOT recommended for production)
|
||||
spring.datasource.url=jdbc:mysql://...?useSSL=false
|
||||
|
||||
# Option 2: Properly configure SSL with RDS certificate
|
||||
spring.datasource.url=jdbc:mysql://...?useSSL=true&requireSSL=true&verifyServerCertificate=true&trustCertificateKeyStoreUrl=file:///path/to/global-bundle.pem
|
||||
|
||||
# Option 3: Trust all certificates (NOT recommended for production)
|
||||
spring.datasource.url=jdbc:mysql://...?useSSL=true&requireSSL=true&verifyServerCertificate=false
|
||||
```
|
||||
|
||||
### Too Many Connections
|
||||
**Error:** `Too many connections` or `Connection pool exhausted`
|
||||
|
||||
**Solutions:**
|
||||
- Review Aurora instance max_connections parameter
|
||||
- Optimize HikariCP pool size
|
||||
- Check for connection leaks in application code
|
||||
|
||||
```properties
|
||||
# Reduce pool size
|
||||
spring.datasource.hikari.maximum-pool-size=15
|
||||
spring.datasource.hikari.minimum-idle=5
|
||||
|
||||
# Enable leak detection
|
||||
spring.datasource.hikari.leak-detection-threshold=60000
|
||||
```
|
||||
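
Leak detection only reports the problem; the fix is to make sure every borrowed connection is returned. A minimal illustration of the safe pattern:

```java
import java.sql.Connection;
import java.sql.PreparedStatement;
import javax.sql.DataSource;

// Try-with-resources returns the connection to the pool even on exceptions;
// a forgotten close() is the most common cause of "connection pool exhausted"
public void queryWithoutLeaking(DataSource dataSource) throws Exception {
    try (Connection connection = dataSource.getConnection();
         PreparedStatement statement = connection.prepareStatement("SELECT 1")) {
        statement.execute();
    }
}
```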
|
||||
**Check Aurora max_connections:**
|
||||
```sql
|
||||
SHOW VARIABLES LIKE 'max_connections';
|
||||
-- Default for Aurora: depends on instance class
|
||||
-- db.r6g.large: ~1000 connections
|
||||
```
|
||||
|
||||
### Slow Query Performance
|
||||
**Error:** Queries taking longer than expected
|
||||
|
||||
**Solutions:**
|
||||
- Enable slow query log in Aurora parameter group
|
||||
- Review connection pool settings
|
||||
- Check Aurora instance metrics in CloudWatch
|
||||
- Optimize queries and add indexes
|
||||
|
||||
```properties
|
||||
# Enable query logging (development only)
|
||||
logging.level.org.hibernate.SQL=DEBUG
|
||||
logging.level.org.hibernate.type.descriptor.sql.BasicBinder=TRACE
|
||||
```
|
||||
|
||||
### Failover Delays
|
||||
**Error:** Application freezes during Aurora failover
|
||||
|
||||
**Solutions:**
|
||||
- Configure connection timeout appropriately
|
||||
- Use cluster endpoint (not instance endpoint)
|
||||
- Implement connection retry logic
|
||||
|
||||
```properties
|
||||
spring.datasource.hikari.connection-timeout=20000
|
||||
spring.datasource.hikari.validation-timeout=5000
|
||||
spring.datasource.url=jdbc:mysql://...?failOverReadOnly=false&maxReconnects=3
|
||||
```
|
||||
|
||||
## Testing Aurora Connection
|
||||
|
||||
### Connection Test with Spring Boot Application
|
||||
|
||||
**Create a Simple Test Endpoint:**
|
||||
```java
|
||||
@RestController
|
||||
@RequestMapping("/api/health")
|
||||
public class DatabaseHealthController {
|
||||
|
||||
@Autowired
|
||||
private DataSource dataSource;
|
||||
|
||||
@GetMapping("/db-connection")
|
||||
public ResponseEntity<Map<String, Object>> testDatabaseConnection() {
|
||||
Map<String, Object> response = new HashMap<>();
|
||||
|
||||
try (Connection connection = dataSource.getConnection()) {
|
||||
response.put("status", "success");
|
||||
response.put("database", connection.getCatalog());
|
||||
response.put("url", connection.getMetaData().getURL());
|
||||
response.put("connected", true);
|
||||
return ResponseEntity.ok(response);
|
||||
} catch (Exception e) {
|
||||
response.put("status", "failed");
|
||||
response.put("error", e.getMessage());
|
||||
response.put("connected", false);
|
||||
return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE).body(response);
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Test with cURL:**
|
||||
```bash
|
||||
curl http://localhost:8080/api/health/db-connection
|
||||
```
|
||||
|
||||
### Verify Aurora Connection with MySQL/PostgreSQL Client
|
||||
|
||||
**MySQL Client Connection:**
|
||||
```bash
|
||||
# Connect to Aurora MySQL cluster
|
||||
mysql -h myapp-aurora-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com \
|
||||
-P 3306 \
|
||||
-u admin \
|
||||
-p devops
|
||||
|
||||
# Verify connection
|
||||
SHOW DATABASES;
|
||||
SELECT @@version;
|
||||
SHOW VARIABLES LIKE 'aurora_version';
|
||||
```
|
||||
|
||||
**PostgreSQL Client Connection:**
|
||||
```bash
|
||||
# Connect to Aurora PostgreSQL
|
||||
psql -h myapp-aurora-pg-cluster.cluster-abc123xyz.us-east-1.rds.amazonaws.com \
|
||||
-p 5432 \
|
||||
-U admin \
|
||||
-d devops
|
||||
|
||||
# Verify connection
|
||||
\l
|
||||
SELECT version();
|
||||
```
|
||||
377
skills/aws-java/aws-sdk-java-v2-bedrock/SKILL.md
Normal file
@@ -0,0 +1,377 @@
|
||||
---
|
||||
name: aws-sdk-java-v2-bedrock
|
||||
description: Amazon Bedrock patterns using AWS SDK for Java 2.x. Use when working with foundation models (listing, invoking), text generation, image generation, embeddings, streaming responses, or integrating generative AI with Spring Boot applications.
|
||||
category: aws
|
||||
tags: [aws, bedrock, java, sdk, generative-ai, foundation-models]
|
||||
version: 2.0.0
|
||||
allowed-tools: Read, Write, Bash
|
||||
---
|
||||
|
||||
# AWS SDK for Java 2.x - Amazon Bedrock
|
||||
|
||||
## When to Use
|
||||
|
||||
Use this skill when:
|
||||
- Listing and inspecting foundation models on Amazon Bedrock
|
||||
- Invoking foundation models for text generation (Claude, Llama, Titan)
|
||||
- Generating images with AI models (Stable Diffusion)
|
||||
- Creating text embeddings for RAG applications
|
||||
- Implementing streaming responses for real-time generation
|
||||
- Working with multiple AI providers through unified API
|
||||
- Integrating generative AI into Spring Boot applications
|
||||
- Building AI-powered chatbots and assistants
|
||||
|
||||
## Overview
|
||||
|
||||
Amazon Bedrock provides access to foundation models from leading AI providers through a unified API. This skill covers patterns for working with various models including Claude, Llama, Titan, and Stable Diffusion using AWS SDK for Java 2.x.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Dependencies
|
||||
|
||||
```xml
|
||||
<!-- Bedrock (model management) -->
|
||||
<dependency>
|
||||
<groupId>software.amazon.awssdk</groupId>
|
||||
<artifactId>bedrock</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- Bedrock Runtime (model invocation) -->
|
||||
<dependency>
|
||||
<groupId>software.amazon.awssdk</groupId>
|
||||
<artifactId>bedrockruntime</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- For JSON processing -->
|
||||
<dependency>
|
||||
<groupId>org.json</groupId>
|
||||
<artifactId>json</artifactId>
|
||||
<version>20231013</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
### Basic Client Setup
|
||||
|
||||
```java
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.services.bedrock.BedrockClient;
|
||||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
|
||||
|
||||
// Model management client
|
||||
BedrockClient bedrockClient = BedrockClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
|
||||
// Model invocation client
|
||||
BedrockRuntimeClient bedrockRuntimeClient = BedrockRuntimeClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
```
|
||||
|
||||
## Core Patterns
|
||||
|
||||
### Model Discovery
|
||||
|
||||
```java
|
||||
import software.amazon.awssdk.services.bedrock.model.*;
|
||||
import java.util.List;
|
||||
|
||||
public List<FoundationModelSummary> listFoundationModels(BedrockClient bedrockClient) {
|
||||
return bedrockClient.listFoundationModels().modelSummaries();
|
||||
}
|
||||
```
|
||||
|
||||
### Model Invocation
|
||||
|
||||
```java
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
import software.amazon.awssdk.services.bedrockruntime.model.*;
|
||||
import org.json.JSONObject;
|
||||
|
||||
public String invokeModel(BedrockRuntimeClient client, String modelId, String prompt) {
|
||||
JSONObject payload = createPayload(modelId, prompt);
|
||||
|
||||
InvokeModelResponse response = client.invokeModel(request -> request
|
||||
.modelId(modelId)
|
||||
.body(SdkBytes.fromUtf8String(payload.toString())));
|
||||
|
||||
return extractTextFromResponse(modelId, response.body().asUtf8String());
|
||||
}
|
||||
|
||||
private JSONObject createPayload(String modelId, String prompt) {
|
||||
if (modelId.startsWith("anthropic.claude")) {
|
||||
return new JSONObject()
|
||||
.put("anthropic_version", "bedrock-2023-05-31")
|
||||
.put("max_tokens", 1000)
|
||||
.put("messages", new JSONObject[]{
|
||||
new JSONObject().put("role", "user").put("content", prompt)
|
||||
});
|
||||
} else if (modelId.startsWith("amazon.titan")) {
|
||||
return new JSONObject()
|
||||
.put("inputText", prompt)
|
||||
.put("textGenerationConfig", new JSONObject()
|
||||
.put("maxTokenCount", 512)
|
||||
.put("temperature", 0.7));
|
||||
} else if (modelId.startsWith("meta.llama")) {
|
||||
return new JSONObject()
|
||||
.put("prompt", "[INST] " + prompt + " [/INST]")
|
||||
.put("max_gen_len", 512)
|
||||
.put("temperature", 0.7);
|
||||
}
|
||||
throw new IllegalArgumentException("Unsupported model: " + modelId);
|
||||
}
|
||||
```
|
||||
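
The snippet above delegates to `extractTextFromResponse`, which is not shown. A minimal sketch, assuming the standard response shape for each provider (the same shapes the parsers later in this skill rely on):

```java
import org.json.JSONObject;

private String extractTextFromResponse(String modelId, String responseBody) {
    JSONObject json = new JSONObject(responseBody);
    if (modelId.startsWith("anthropic.claude")) {
        // Claude Messages API: {"content":[{"type":"text","text":"..."}], ...}
        return json.getJSONArray("content").getJSONObject(0).getString("text");
    } else if (modelId.startsWith("amazon.titan")) {
        // Titan: {"results":[{"outputText":"..."}], ...}
        return json.getJSONArray("results").getJSONObject(0).getString("outputText");
    } else if (modelId.startsWith("meta.llama")) {
        // Llama: {"generation":"..."}
        return json.getString("generation");
    }
    throw new IllegalArgumentException("Unsupported model: " + modelId);
}
```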
|
||||
### Streaming Responses
|
||||
|
||||
```java
// Response streaming is only available on the asynchronous client; the
// synchronous BedrockRuntimeClient does not expose invokeModelWithResponseStream.
public void streamResponse(BedrockRuntimeAsyncClient client, String modelId, String prompt) {
    JSONObject payload = createPayload(modelId, prompt);

    InvokeModelWithResponseStreamRequest streamRequest =
        InvokeModelWithResponseStreamRequest.builder()
            .modelId(modelId)
            .body(SdkBytes.fromUtf8String(payload.toString()))
            .build();

    client.invokeModelWithResponseStream(streamRequest,
        InvokeModelWithResponseStreamResponseHandler.builder()
            .onEventStream(stream -> stream.forEach(event -> {
                if (event instanceof PayloadPart payloadPart) {
                    String chunk = payloadPart.bytes().asUtf8String();
                    processChunk(modelId, chunk);
                }
            }))
            .build())
        .join(); // the async call returns a CompletableFuture<Void>
}
```
|
||||
|
||||
### Text Embeddings
|
||||
|
||||
```java
import org.json.JSONArray;
import org.json.JSONObject;

public double[] createEmbeddings(BedrockRuntimeClient client, String text) {
|
||||
String modelId = "amazon.titan-embed-text-v1";
|
||||
|
||||
JSONObject payload = new JSONObject().put("inputText", text);
|
||||
|
||||
InvokeModelResponse response = client.invokeModel(request -> request
|
||||
.modelId(modelId)
|
||||
.body(SdkBytes.fromUtf8String(payload.toString())));
|
||||
|
||||
JSONObject responseBody = new JSONObject(response.body().asUtf8String());
|
||||
JSONArray embeddingArray = responseBody.getJSONArray("embedding");
|
||||
|
||||
double[] embeddings = new double[embeddingArray.length()];
|
||||
for (int i = 0; i < embeddingArray.length(); i++) {
|
||||
embeddings[i] = embeddingArray.getDouble(i);
|
||||
}
|
||||
|
||||
return embeddings;
|
||||
}
|
||||
```
|
||||
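
Once embeddings are created, retrieval typically ranks candidates by cosine similarity. A small helper for that comparison; it assumes both vectors came from the same embedding model:

```java
// Cosine similarity between two embedding vectors — the usual relevance
// score in RAG retrieval; 1.0 means identical direction, 0.0 unrelated
public double cosineSimilarity(double[] a, double[] b) {
    double dot = 0.0, normA = 0.0, normB = 0.0;
    for (int i = 0; i < a.length; i++) {
        dot += a[i] * b[i];
        normA += a[i] * a[i];
        normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}
```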
|
||||
### Spring Boot Integration
|
||||
|
||||
```java
|
||||
@Configuration
|
||||
public class BedrockConfiguration {
|
||||
|
||||
@Bean
|
||||
public BedrockClient bedrockClient() {
|
||||
return BedrockClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public BedrockRuntimeClient bedrockRuntimeClient() {
|
||||
return BedrockRuntimeClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
@Service
|
||||
public class BedrockAIService {
|
||||
|
||||
private final BedrockRuntimeClient bedrockRuntimeClient;
|
||||
|
||||
@Value("${bedrock.default-model-id:anthropic.claude-sonnet-4-5-20250929-v1:0}")
|
||||
private String defaultModelId;
|
||||
|
||||
public BedrockAIService(BedrockRuntimeClient bedrockRuntimeClient) {
|
||||
this.bedrockRuntimeClient = bedrockRuntimeClient;
|
||||
}
|
||||
|
||||
public String generateText(String prompt) {
|
||||
return generateText(prompt, defaultModelId);
|
||||
}
|
||||
|
||||
    public String generateText(String prompt, String modelId) {
        try {
            Map<String, Object> payload = createPayload(modelId, prompt);
            String payloadJson = new ObjectMapper().writeValueAsString(payload);

            InvokeModelResponse response = bedrockRuntimeClient.invokeModel(
                request -> request
                    .modelId(modelId)
                    .body(SdkBytes.fromUtf8String(payloadJson)));

            return extractTextFromResponse(modelId, response.body().asUtf8String());
        } catch (JsonProcessingException e) {
            // writeValueAsString throws a checked exception that must be handled
            throw new IllegalStateException("Failed to serialize request payload", e);
        }
    }
|
||||
}
|
||||
```
|
||||
|
||||
## Basic Usage Example
|
||||
|
||||
```java
|
||||
BedrockRuntimeClient client = BedrockRuntimeClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
|
||||
String prompt = "Explain quantum computing in simple terms";
|
||||
String response = invokeModel(client, "anthropic.claude-sonnet-4-5-20250929-v1:0", prompt);
|
||||
System.out.println(response);
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Model Selection
|
||||
- **Claude 4.5 Sonnet**: Best for complex reasoning, analysis, and creative tasks
|
||||
- **Claude 4.5 Haiku**: Fast and affordable for real-time applications
|
||||
- **Claude 3.7 Sonnet**: Hybrid reasoning with an extended thinking mode
|
||||
- **Llama 3.1**: Latest generation open-source alternative, good for general tasks
|
||||
- **Titan**: AWS native, cost-effective for simple text generation
|
||||
|
||||
### Performance Optimization
|
||||
- Reuse client instances (don't create new clients for each request)
|
||||
- Use async clients for I/O operations
|
||||
- Implement streaming for long responses
|
||||
- Cache foundation model lists
|
||||
|
||||
### Security
|
||||
- Never log sensitive prompt data
|
||||
- Use IAM roles for authentication (never access keys)
|
||||
- Implement rate limiting for public applications
|
||||
- Sanitize user inputs to prevent prompt injection
|
||||
|
||||
### Error Handling
|
||||
- Implement retry logic for throttling (exponential backoff; see the sketch after this list)
|
||||
- Handle model-specific validation errors
|
||||
- Validate responses before processing
|
||||
- Use proper exception handling for different error types
|
||||
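
A minimal sketch of the throttling retry described above; `ThrottlingException` comes from the Bedrock runtime model package, and the attempt count and delays are illustrative:

```java
// Exponential backoff for Bedrock throttling; jitter and a circuit breaker
// are left out for brevity
public String invokeWithBackoff(BedrockRuntimeClient client, String modelId, String payloadJson) {
    long delayMillis = 500;
    for (int attempt = 1; ; attempt++) {
        try {
            return client.invokeModel(request -> request
                    .modelId(modelId)
                    .body(SdkBytes.fromUtf8String(payloadJson)))
                .body().asUtf8String();
        } catch (ThrottlingException e) {
            if (attempt == 3) {
                throw e; // give up after three attempts
            }
            try {
                Thread.sleep(delayMillis);
            } catch (InterruptedException ie) {
                Thread.currentThread().interrupt();
                throw e;
            }
            delayMillis *= 2; // 500ms -> 1s -> 2s
        }
    }
}
```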
|
||||
### Cost Optimization
|
||||
- Use appropriate max_tokens limits
|
||||
- Choose cost-effective models for simple tasks
|
||||
- Cache embeddings when possible
|
||||
- Monitor usage and set budget alerts
|
||||
|
||||
## Common Model IDs
|
||||
|
||||
```java
|
||||
// Claude Models
|
||||
public static final String CLAUDE_SONNET_4_5 = "anthropic.claude-sonnet-4-5-20250929-v1:0";
|
||||
public static final String CLAUDE_HAIKU_4_5 = "anthropic.claude-haiku-4-5-20251001-v1:0";
|
||||
public static final String CLAUDE_OPUS_4_1 = "anthropic.claude-opus-4-1-20250805-v1:0";
|
||||
public static final String CLAUDE_3_7_SONNET = "anthropic.claude-3-7-sonnet-20250219-v1:0";
|
||||
public static final String CLAUDE_OPUS_4 = "anthropic.claude-opus-4-20250514-v1:0";
|
||||
public static final String CLAUDE_SONNET_4 = "anthropic.claude-sonnet-4-20250514-v1:0";
|
||||
public static final String CLAUDE_3_5_SONNET_V2 = "anthropic.claude-3-5-sonnet-20241022-v2:0";
|
||||
public static final String CLAUDE_3_5_HAIKU = "anthropic.claude-3-5-haiku-20241022-v1:0";
|
||||
public static final String CLAUDE_3_OPUS = "anthropic.claude-3-opus-20240229-v1:0";
|
||||
|
||||
// Llama Models
|
||||
public static final String LLAMA_3_3_70B = "meta.llama3-3-70b-instruct-v1:0";
|
||||
public static final String LLAMA_3_2_90B = "meta.llama3-2-90b-instruct-v1:0";
|
||||
public static final String LLAMA_3_2_11B = "meta.llama3-2-11b-instruct-v1:0";
|
||||
public static final String LLAMA_3_2_3B = "meta.llama3-2-3b-instruct-v1:0";
|
||||
public static final String LLAMA_3_2_1B = "meta.llama3-2-1b-instruct-v1:0";
|
||||
public static final String LLAMA_4_MAV_17B = "meta.llama4-maverick-17b-instruct-v1:0";
|
||||
public static final String LLAMA_4_SCOUT_17B = "meta.llama4-scout-17b-instruct-v1:0";
|
||||
public static final String LLAMA_3_1_405B = "meta.llama3-1-405b-instruct-v1:0";
|
||||
public static final String LLAMA_3_1_70B = "meta.llama3-1-70b-instruct-v1:0";
|
||||
public static final String LLAMA_3_1_8B = "meta.llama3-1-8b-instruct-v1:0";
|
||||
public static final String LLAMA_3_70B = "meta.llama3-70b-instruct-v1:0";
|
||||
public static final String LLAMA_3_8B = "meta.llama3-8b-instruct-v1:0";
|
||||
|
||||
// Amazon Titan Models
|
||||
public static final String TITAN_TEXT_EXPRESS = "amazon.titan-text-express-v1";
|
||||
public static final String TITAN_TEXT_LITE = "amazon.titan-text-lite-v1";
|
||||
public static final String TITAN_EMBEDDINGS = "amazon.titan-embed-text-v1";
|
||||
public static final String TITAN_IMAGE_GENERATOR = "amazon.titan-image-generator-v1";
|
||||
|
||||
// Stable Diffusion
|
||||
public static final String STABLE_DIFFUSION_XL = "stability.stable-diffusion-xl-v1";
|
||||
|
||||
// Mistral AI Models
|
||||
public static final String MISTRAL_LARGE_2407 = "mistral.mistral-large-2407-v1:0";
|
||||
public static final String MISTRAL_LARGE_2402 = "mistral.mistral-large-2402-v1:0";
|
||||
public static final String MISTRAL_SMALL_2402 = "mistral.mistral-small-2402-v1:0";
|
||||
public static final String MISTRAL_PIXTRAL_2502 = "mistral.pixtral-large-2502-v1:0";
|
||||
public static final String MISTRAL_MIXTRAL_8X7B = "mistral.mixtral-8x7b-instruct-v0:1";
|
||||
public static final String MISTRAL_7B = "mistral.mistral-7b-instruct-v0:2";
|
||||
|
||||
// Amazon Nova Models
|
||||
public static final String NOVA_PREMIER = "amazon.nova-premier-v1:0";
|
||||
public static final String NOVA_PRO = "amazon.nova-pro-v1:0";
|
||||
public static final String NOVA_LITE = "amazon.nova-lite-v1:0";
|
||||
public static final String NOVA_MICRO = "amazon.nova-micro-v1:0";
|
||||
public static final String NOVA_CANVAS = "amazon.nova-canvas-v1:0";
|
||||
public static final String NOVA_REEL = "amazon.nova-reel-v1:1";
|
||||
|
||||
// Other Models
|
||||
public static final String COHERE_COMMAND = "cohere.command-text-v14";
|
||||
public static final String DEEPSEEK_R1 = "deepseek.r1-v1:0";
|
||||
public static final String DEEPSEEK_V3_1 = "deepseek.v3-v1:0";
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
See the [examples directory](examples/) for comprehensive usage patterns.
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
See the [Advanced Topics](references/advanced-topics.md) for:
|
||||
- Multi-model service patterns
|
||||
- Advanced error handling with retries
|
||||
- Batch processing strategies
|
||||
- Performance optimization techniques
|
||||
- Custom response parsing
|
||||
|
||||
## Model Reference
|
||||
|
||||
See the [Model Reference](references/model-reference.md) for:
|
||||
- Detailed model specifications
|
||||
- Payload/response formats for each provider
|
||||
- Performance characteristics
|
||||
- Model selection guidelines
|
||||
- Configuration templates
|
||||
|
||||
## Testing Strategies
|
||||
|
||||
See the [Testing Strategies](references/testing-strategies.md) for:
|
||||
- Unit testing with mocked clients
|
||||
- Integration testing with LocalStack
|
||||
- Performance testing
|
||||
- Streaming response testing
|
||||
- Test data management
|
||||
|
||||
## Related Skills
|
||||
|
||||
- `aws-sdk-java-v2-core` - Core AWS SDK patterns
|
||||
- `langchain4j-ai-services-patterns` - LangChain4j integration
|
||||
- `spring-boot-dependency-injection` - Spring DI patterns
|
||||
- `spring-boot-test-patterns` - Spring testing patterns
|
||||
|
||||
## References
|
||||
|
||||
- [AWS Bedrock User Guide](references/aws-bedrock-user-guide.md)
|
||||
- [AWS SDK for Java 2.x Documentation](references/aws-sdk-java-bedrock-api.md)
|
||||
- [Bedrock API Reference](references/aws-bedrock-api-reference.md)
|
||||
- [AWS SDK Examples](references/aws-sdk-examples.md)
|
||||
- [Official AWS Examples](bedrock_code_examples.md)
|
||||
- [Supported Models](bedrock_models_supported.md)
|
||||
- [Runtime Examples](bedrock_runtime_code_examples.md)
|
||||
249
skills/aws-java/aws-sdk-java-v2-bedrock/bedrock_code_examples.md
Normal file
@@ -0,0 +1,249 @@
|
||||
# Amazon Bedrock examples using SDK for Java 2.x
|
||||
|
||||
The following code examples show you how to perform actions and implement common scenarios by using
|
||||
the AWS SDK for Java 2.x with Amazon Bedrock.
|
||||
|
||||
_Actions_ are code excerpts from larger programs and must be run in context. While actions show you how to call individual service functions, you can see actions in context in their related scenarios.
|
||||
|
||||
Each example includes a link to the complete source code, where you can find
|
||||
instructions on how to set up and run the code in context.
|
||||
|
||||
|
||||
|
||||
## Actions
|
||||
|
||||
The following code example shows how to use `GetFoundationModel`.
|
||||
|
||||
**SDK for Java 2.x**
|
||||
|
||||
###### Note
|
||||
|
||||
There's more on GitHub. Find the complete example and learn how to set up and run in the
[AWS Code Examples Repository](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/example_code/bedrock#code-examples).
|
||||
|
||||
|
||||
Get details about a foundation model using the synchronous Amazon Bedrock client.
|
||||
|
||||
```java
|
||||
|
||||
/**
|
||||
* Get details about an Amazon Bedrock foundation model.
|
||||
*
|
||||
* @param bedrockClient The service client for accessing Amazon Bedrock.
|
||||
* @param modelIdentifier The model identifier.
|
||||
* @return An object containing the foundation model's details.
|
||||
*/
|
||||
public static FoundationModelDetails getFoundationModel(BedrockClient bedrockClient, String modelIdentifier) {
|
||||
try {
|
||||
GetFoundationModelResponse response = bedrockClient.getFoundationModel(
|
||||
r -> r.modelIdentifier(modelIdentifier)
|
||||
);
|
||||
|
||||
FoundationModelDetails model = response.modelDetails();
|
||||
|
||||
System.out.println(" Model ID: " + model.modelId());
|
||||
System.out.println(" Model ARN: " + model.modelArn());
|
||||
System.out.println(" Model Name: " + model.modelName());
|
||||
System.out.println(" Provider Name: " + model.providerName());
|
||||
System.out.println(" Lifecycle status: " + model.modelLifecycle().statusAsString());
|
||||
System.out.println(" Input modalities: " + model.inputModalities());
|
||||
System.out.println(" Output modalities: " + model.outputModalities());
|
||||
System.out.println(" Supported customizations: " + model.customizationsSupported());
|
||||
System.out.println(" Supported inference types: " + model.inferenceTypesSupported());
|
||||
System.out.println(" Response streaming supported: " + model.responseStreamingSupported());
|
||||
|
||||
return model;
|
||||
|
||||
} catch (ValidationException e) {
|
||||
throw new IllegalArgumentException(e.getMessage());
|
||||
} catch (SdkException e) {
|
||||
System.err.println(e.getMessage());
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
Get details about a foundation model using the asynchronous Amazon Bedrock client.
|
||||
|
||||
```java
|
||||
|
||||
/**
|
||||
* Get details about an Amazon Bedrock foundation model.
|
||||
*
|
||||
* @param bedrockClient The async service client for accessing Amazon Bedrock.
|
||||
* @param modelIdentifier The model identifier.
|
||||
* @return An object containing the foundation model's details.
|
||||
*/
|
||||
public static FoundationModelDetails getFoundationModel(BedrockAsyncClient bedrockClient, String modelIdentifier) {
|
||||
try {
|
||||
CompletableFuture<GetFoundationModelResponse> future = bedrockClient.getFoundationModel(
|
||||
r -> r.modelIdentifier(modelIdentifier)
|
||||
);
|
||||
|
||||
FoundationModelDetails model = future.get().modelDetails();
|
||||
|
||||
System.out.println(" Model ID: " + model.modelId());
|
||||
System.out.println(" Model ARN: " + model.modelArn());
|
||||
System.out.println(" Model Name: " + model.modelName());
|
||||
System.out.println(" Provider Name: " + model.providerName());
|
||||
System.out.println(" Lifecycle status: " + model.modelLifecycle().statusAsString());
|
||||
System.out.println(" Input modalities: " + model.inputModalities());
|
||||
System.out.println(" Output modalities: " + model.outputModalities());
|
||||
System.out.println(" Supported customizations: " + model.customizationsSupported());
|
||||
System.out.println(" Supported inference types: " + model.inferenceTypesSupported());
|
||||
System.out.println(" Response streaming supported: " + model.responseStreamingSupported());
|
||||
|
||||
return model;
|
||||
|
||||
} catch (ExecutionException e) {
|
||||
if (e.getMessage().contains("ValidationException")) {
|
||||
throw new IllegalArgumentException(e.getMessage());
|
||||
} else {
|
||||
System.err.println(e.getMessage());
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
System.err.println(e.getMessage());
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
- For API details, see
|
||||
[GetFoundationModel](https://docs.aws.amazon.com/goto/SdkForJavaV2/bedrock-2023-04-20/GetFoundationModel)
|
||||
in _AWS SDK for Java 2.x API Reference_.
|
||||
|
||||
|
||||
|
||||
The following code example shows how to use `ListFoundationModels`.
|
||||
|
||||
**SDK for Java 2.x**
|
||||
|
||||
###### Note
|
||||
|
||||
There's more on GitHub. Find the complete example and learn how to set up and run in the
[AWS Code Examples Repository](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/example_code/bedrock#code-examples).
|
||||
|
||||
|
||||
List the available Amazon Bedrock foundation models using the synchronous Amazon Bedrock client.
|
||||
|
||||
```java
|
||||
|
||||
/**
|
||||
* Lists Amazon Bedrock foundation models that you can use.
|
||||
* You can filter the results with the request parameters.
|
||||
*
|
||||
* @param bedrockClient The service client for accessing Amazon Bedrock.
|
||||
* @return A list of objects containing the foundation models' details
|
||||
*/
|
||||
public static List<FoundationModelSummary> listFoundationModels(BedrockClient bedrockClient) {
|
||||
|
||||
try {
|
||||
ListFoundationModelsResponse response = bedrockClient.listFoundationModels(r -> {});
|
||||
|
||||
List<FoundationModelSummary> models = response.modelSummaries();
|
||||
|
||||
if (models.isEmpty()) {
|
||||
System.out.println("No available foundation models in " + region.toString());
|
||||
} else {
|
||||
for (FoundationModelSummary model : models) {
|
||||
System.out.println("Model ID: " + model.modelId());
|
||||
System.out.println("Provider: " + model.providerName());
|
||||
System.out.println("Name: " + model.modelName());
|
||||
System.out.println();
|
||||
}
|
||||
}
|
||||
|
||||
return models;
|
||||
|
||||
} catch (SdkClientException e) {
|
||||
System.err.println(e.getMessage());
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
List the available Amazon Bedrock foundation models using the asynchronous Amazon Bedrock client.
|
||||
|
||||
```java
|
||||
|
||||
/**
|
||||
* Lists Amazon Bedrock foundation models that you can use.
|
||||
* You can filter the results with the request parameters.
|
||||
*
|
||||
* @param bedrockClient The async service client for accessing Amazon Bedrock.
|
||||
* @return A list of objects containing the foundation models' details
|
||||
*/
|
||||
public static List<FoundationModelSummary> listFoundationModels(BedrockAsyncClient bedrockClient) {
|
||||
try {
|
||||
CompletableFuture<ListFoundationModelsResponse> future = bedrockClient.listFoundationModels(r -> {});
|
||||
|
||||
List<FoundationModelSummary> models = future.get().modelSummaries();
|
||||
|
||||
if (models.isEmpty()) {
|
||||
System.out.println("No available foundation models in " + region.toString());
|
||||
} else {
|
||||
for (FoundationModelSummary model : models) {
|
||||
System.out.println("Model ID: " + model.modelId());
|
||||
System.out.println("Provider: " + model.providerName());
|
||||
System.out.println("Name: " + model.modelName());
|
||||
System.out.println();
|
||||
}
|
||||
}
|
||||
|
||||
return models;
|
||||
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
System.err.println(e.getMessage());
|
||||
throw new RuntimeException(e);
|
||||
} catch (ExecutionException e) {
|
||||
System.err.println(e.getMessage());
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
- For API details, see
|
||||
[ListFoundationModels](https://docs.aws.amazon.com/goto/SdkForJavaV2/bedrock-2023-04-20/ListFoundationModels)
|
||||
in _AWS SDK for Java 2.x API Reference_.
|
||||
|
||||
|
||||
|
||||
1323
skills/aws-java/aws-sdk-java-v2-bedrock/bedrock_models_supported.md
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,274 @@
|
||||
# Advanced Model Patterns
|
||||
|
||||
## Model-Specific Configuration
|
||||
|
||||
### Claude Models Configuration
|
||||
|
||||
```java
|
||||
// Claude 3 Sonnet
|
||||
public String invokeClaude3Sonnet(BedrockRuntimeClient client, String prompt) {
|
||||
String modelId = "anthropic.claude-3-sonnet-20240229-v1:0";
|
||||
|
||||
JSONObject payload = new JSONObject()
|
||||
.put("anthropic_version", "bedrock-2023-05-31")
|
||||
.put("max_tokens", 1000)
|
||||
.put("temperature", 0.7)
|
||||
.put("top_p", 1.0)
|
||||
.put("messages", new JSONObject[]{
|
||||
new JSONObject()
|
||||
.put("role", "user")
|
||||
.put("content", prompt)
|
||||
});
|
||||
|
||||
InvokeModelResponse response = client.invokeModel(request -> request
|
||||
.modelId(modelId)
|
||||
.body(SdkBytes.fromUtf8String(payload.toString())));
|
||||
|
||||
JSONObject responseBody = new JSONObject(response.body().asUtf8String());
|
||||
return responseBody.getJSONArray("content")
|
||||
.getJSONObject(0)
|
||||
.getString("text");
|
||||
}
|
||||
|
||||
// Claude 3 Haiku (faster, cheaper)
|
||||
public String invokeClaude3Haiku(BedrockRuntimeClient client, String prompt) {
|
||||
String modelId = "anthropic.claude-3-haiku-20240307-v1:0";
|
||||
|
||||
JSONObject payload = new JSONObject()
|
||||
.put("anthropic_version", "bedrock-2023-05-31")
|
||||
.put("max_tokens", 400)
|
||||
.put("messages", new JSONObject[]{
|
||||
new JSONObject()
|
||||
.put("role", "user")
|
||||
.put("content", prompt)
|
||||
});
|
||||

        // Same invocation pattern as the Sonnet example above
        InvokeModelResponse response = client.invokeModel(request -> request
                .modelId(modelId)
                .body(SdkBytes.fromUtf8String(payload.toString())));

        JSONObject responseBody = new JSONObject(response.body().asUtf8String());
        return responseBody.getJSONArray("content")
                .getJSONObject(0)
                .getString("text");
}
|
||||
```
|
||||
|
||||
### Llama Models Configuration
|
||||
|
||||
```java
|
||||
// Llama 3 70B
|
||||
public String invokeLlama3_70B(BedrockRuntimeClient client, String prompt) {
|
||||
String modelId = "meta.llama3-70b-instruct-v1:0";
|
||||
|
||||
JSONObject payload = new JSONObject()
|
||||
.put("prompt", prompt)
|
||||
.put("max_gen_len", 512)
|
||||
.put("temperature", 0.7)
|
||||
.put("top_p", 0.9)
|
||||
.put("stop", new String[]{"[INST]", "[/INST]"}); // Custom stop tokens
|
||||
|
||||
InvokeModelResponse response = client.invokeModel(request -> request
|
||||
.modelId(modelId)
|
||||
.body(SdkBytes.fromUtf8String(payload.toString())));
|
||||
|
||||
JSONObject responseBody = new JSONObject(response.body().asUtf8String());
|
||||
return responseBody.getString("generation");
|
||||
}
|
||||
```
|
||||
|
||||
## Multi-Model Service Layer
|
||||
|
||||
```java
|
||||
@Service
|
||||
public class MultiModelService {
|
||||
|
||||
private final BedrockRuntimeClient bedrockRuntimeClient;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
public MultiModelService(BedrockRuntimeClient bedrockRuntimeClient,
|
||||
ObjectMapper objectMapper) {
|
||||
this.bedrockRuntimeClient = bedrockRuntimeClient;
|
||||
this.objectMapper = objectMapper;
|
||||
}
|
||||
|
||||
public String invokeModel(String modelId, String prompt, Map<String, Object> additionalParams) {
|
||||
Map<String, Object> payload = createModelPayload(modelId, prompt, additionalParams);
|
||||
|
||||
        try {
            // Serialize outside the lambda: writeValueAsString throws a checked
            // exception that a Consumer lambda cannot propagate
            String payloadJson = objectMapper.writeValueAsString(payload);

            InvokeModelResponse response = bedrockRuntimeClient.invokeModel(
                request -> request
                    .modelId(modelId)
                    .body(SdkBytes.fromUtf8String(payloadJson)));
|
||||
|
||||
return extractResponseContent(modelId, response.body().asUtf8String());
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Model invocation failed: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, Object> createModelPayload(String modelId, String prompt,
|
||||
Map<String, Object> additionalParams) {
|
||||
Map<String, Object> payload = new HashMap<>();
|
||||
|
||||
if (modelId.startsWith("anthropic.claude")) {
|
||||
payload.put("anthropic_version", "bedrock-2023-05-31");
|
||||
payload.put("messages", List.of(Map.of("role", "user", "content", prompt)));
|
||||
|
||||
// Add common parameters with defaults
|
||||
payload.putIfAbsent("max_tokens", 1000);
|
||||
payload.putIfAbsent("temperature", 0.7);
|
||||
|
||||
} else if (modelId.startsWith("meta.llama")) {
|
||||
payload.put("prompt", prompt);
|
||||
payload.putIfAbsent("max_gen_len", 512);
|
||||
payload.putIfAbsent("temperature", 0.7);
|
||||
|
||||
} else if (modelId.startsWith("amazon.titan")) {
|
||||
payload.put("inputText", prompt);
|
||||
payload.putIfAbsent("textGenerationConfig",
|
||||
Map.of("maxTokenCount", 512, "temperature", 0.7));
|
||||
}
|
||||
|
||||
// Add additional parameters
|
||||
if (additionalParams != null) {
|
||||
payload.putAll(additionalParams);
|
||||
}
|
||||
|
||||
return payload;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Advanced Error Handling
|
||||
|
||||
```java
|
||||
@Component
|
||||
public class BedrockErrorHandler {
|
||||
|
||||
@Retryable(value = {ThrottlingException.class, SdkClientException.class}, maxAttempts = 3, backoff = @Backoff(delay = 1000, multiplier = 2))
|
||||
public String invokeWithRetry(BedrockRuntimeClient client, String modelId,
|
||||
String payloadJson) {
|
||||
try {
|
||||
InvokeModelResponse response = client.invokeModel(request -> request
|
||||
.modelId(modelId)
|
||||
.body(SdkBytes.fromUtf8String(payloadJson)));
|
||||
return response.body().asUtf8String();
|
||||
|
||||
        } catch (ThrottlingException e) {
            // Rethrow unwrapped so @Retryable can catch it and apply the backoff
            throw e;
|
||||
} catch (ValidationException e) {
|
||||
throw new IllegalArgumentException("Invalid request: " + e.getMessage(), e);
|
||||
} catch (SdkException e) {
|
||||
throw new RuntimeException("AWS SDK error: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
```java
|
||||
@Service
|
||||
public class BedrockBatchService {
|
||||
|
||||
public List<String> processBatch(BedrockRuntimeClient client, String modelId,
|
||||
List<String> prompts) {
|
||||
return prompts.parallelStream()
|
||||
.map(prompt -> invokeModelWithTimeout(client, modelId, prompt, 30))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private String invokeModelWithTimeout(BedrockRuntimeClient client, String modelId,
|
||||
String prompt, int timeoutSeconds) {
|
||||
ExecutorService executor = Executors.newSingleThreadExecutor();
|
||||
Future<String> future = executor.submit(() -> {
|
||||
JSONObject payload = new JSONObject()
|
||||
.put("prompt", prompt)
|
||||
.put("max_tokens", 500);
|
||||
|
||||
InvokeModelResponse response = client.invokeModel(request -> request
|
||||
.modelId(modelId)
|
||||
.body(SdkBytes.fromUtf8String(payload.toString())));
|
||||
|
||||
return response.body().asUtf8String();
|
||||
});
|
||||
|
||||
try {
|
||||
return future.get(timeoutSeconds, TimeUnit.SECONDS);
|
||||
} catch (TimeoutException e) {
|
||||
future.cancel(true);
|
||||
throw new RuntimeException("Model invocation timed out");
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Batch processing error", e);
|
||||
} finally {
|
||||
executor.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Model Performance Optimization
|
||||
|
||||
```java
|
||||
@Configuration
|
||||
public class BedrockOptimizationConfig {
|
||||
|
||||
@Bean
|
||||
public BedrockRuntimeClient optimizedBedrockRuntimeClient() {
|
||||
return BedrockRuntimeClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.overrideConfiguration(ClientOverrideConfiguration.builder()
|
||||
.apiCallTimeout(Duration.ofSeconds(30))
|
||||
.apiCallAttemptTimeout(Duration.ofSeconds(20))
|
||||
.build())
|
||||
.httpClient(ApacheHttpClient.builder()
|
||||
.connectionTimeout(Duration.ofSeconds(10))
|
||||
.socketTimeout(Duration.ofSeconds(30))
|
||||
.build())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Custom Response Parsing
|
||||
|
||||
```java
|
||||
public class BedrockResponseParser {
|
||||
|
||||
public static TextResponse parseTextResponse(String modelId, String responseBody) {
|
||||
try {
|
||||
switch (getModelProvider(modelId)) {
|
||||
case ANTHROPIC:
|
||||
return parseAnthropicResponse(responseBody);
|
||||
case META:
|
||||
return parseMetaResponse(responseBody);
|
||||
case AMAZON:
|
||||
return parseAmazonResponse(responseBody);
|
||||
default:
|
||||
throw new IllegalArgumentException("Unsupported model: " + modelId);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new ResponseParsingException("Failed to parse response for model: " + modelId, e);
|
||||
}
|
||||
}
|
||||
|
||||
private static TextResponse parseAnthropicResponse(String responseBody) throws JSONException {
|
||||
JSONObject json = new JSONObject(responseBody);
|
||||
JSONArray content = json.getJSONArray("content");
|
||||
String text = content.getJSONObject(0).getString("text");
|
||||
int usage = json.getJSONObject("usage").getInt("input_tokens");
|
||||
|
||||
return new TextResponse(text, usage, "anthropic");
|
||||
}
|
||||
|
||||
private static TextResponse parseMetaResponse(String responseBody) throws JSONException {
|
||||
JSONObject json = new JSONObject(responseBody);
|
||||
String text = json.getString("generation");
|
||||
        // Bedrock's Llama responses report generation_token_count; default to 0 if absent
        int tokens = json.optInt("generation_token_count", 0);

        return new TextResponse(text, tokens, "meta");
|
||||
}
|
||||
|
||||
private enum ModelProvider {
|
||||
ANTHROPIC, META, AMAZON
|
||||
}
|
||||
|
||||
public record TextResponse(String content, int tokensUsed, String provider) {}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,372 @@
|
||||
# Advanced Amazon Bedrock Topics
|
||||
|
||||
This document covers advanced patterns and topics for working with Amazon Bedrock using AWS SDK for Java 2.x.
|
||||
|
||||
## Multi-Model Service Pattern
|
||||
|
||||
Create a service that can handle multiple foundation models with unified interfaces.
|
||||
|
||||
```java
|
||||
@Service
|
||||
public class MultiModelAIService {
|
||||
|
||||
private final BedrockRuntimeClient bedrockRuntimeClient;
|
||||
|
||||
public MultiModelAIService(BedrockRuntimeClient bedrockRuntimeClient) {
|
||||
this.bedrockRuntimeClient = bedrockRuntimeClient;
|
||||
}
|
||||
|
||||
public GenerationResult generate(GenerationRequest request) {
|
||||
String modelId = request.getModelId();
|
||||
String prompt = request.getPrompt();
|
||||
|
||||
switch (getModelProvider(modelId)) {
|
||||
case ANTHROPIC:
|
||||
return generateWithAnthropic(modelId, prompt, request.getConfig());
|
||||
case AMAZON:
|
||||
return generateWithAmazon(modelId, prompt, request.getConfig());
|
||||
case META:
|
||||
return generateWithMeta(modelId, prompt, request.getConfig());
|
||||
default:
|
||||
throw new IllegalArgumentException("Unsupported model provider: " + modelId);
|
||||
}
|
||||
}
|
||||
|
||||
private GenerationProvider getModelProvider(String modelId) {
|
||||
if (modelId.startsWith("anthropic.")) return GenerationProvider.ANTHROPIC;
|
||||
if (modelId.startsWith("amazon.")) return GenerationProvider.AMazon;
|
||||
if (modelId.startsWith("meta.")) return GenerationProvider.META;
|
||||
throw new IllegalArgumentException("Unknown provider for model: " + modelId);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Advanced Error Handling with Retries
|
||||
|
||||
Implement robust error handling with exponential backoff:
|
||||
|
||||
```java
|
||||
import software.amazon.awssdk.core.retry.RetryPolicy;
|
||||
import software.amazon.awssdk.core.retry.backoff.BackoffStrategy;
|
||||
import software.amazon.awssdk.core.retry.conditions.RetryOnExceptionsCondition;

import java.util.Set;
|
||||
|
||||
public class BedrockWithRetry {
|
||||
|
||||
private final BedrockRuntimeClient client;
|
||||
private final RetryPolicy retryPolicy;
|
||||
|
||||
public BedrockWithRetry(BedrockRuntimeClient client) {
|
||||
this.client = client;
|
||||
this.retryPolicy = RetryPolicy.builder()
|
||||
.numRetries(3)
|
||||
                .retryCondition(RetryOnExceptionsCondition.create(
                        Set.of(ThrottlingException.class)))
|
||||
.backoffStrategy(BackoffStrategy.defaultStrategy())
|
||||
                .build();
        // Note: the policy only takes effect if attached when the client is built,
        // e.g. ClientOverrideConfiguration.builder().retryPolicy(retryPolicy)
    }
|
||||
|
||||
public String invokeModelWithRetry(String modelId, String payload) {
|
||||
try {
|
||||
InvokeModelRequest request = InvokeModelRequest.builder()
|
||||
.modelId(modelId)
|
||||
.body(SdkBytes.fromUtf8String(payload))
|
||||
.build();
|
||||
|
||||
InvokeModelResponse response = client.invokeModel(request);
|
||||
return response.body().asUtf8String();
|
||||
|
||||
} catch (ThrottlingException e) {
|
||||
throw new BedrockThrottledException("Rate limit exceeded for model: " + modelId, e);
|
||||
} catch (ValidationException e) {
|
||||
throw new BedrockValidationException("Invalid request for model: " + modelId, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Batch Processing Strategies
|
||||
|
||||
Process multiple requests efficiently:
|
||||
|
||||
```java
|
||||
@Service
|
||||
public class BatchGenerationService {
|
||||
|
||||
private final BedrockRuntimeClient bedrockRuntimeClient;
|
||||
|
||||
public BatchGenerationService(BedrockRuntimeClient bedrockRuntimeClient) {
|
||||
this.bedrockRuntimeClient = bedrockRuntimeClient;
|
||||
}
|
||||
|
||||
public List<BatchResult> processBatch(List<BatchRequest> requests) {
|
||||
// Process in parallel
|
||||
return requests.parallelStream()
|
||||
.map(this::processSingleRequest)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private BatchResult processSingleRequest(BatchRequest request) {
|
||||
try {
|
||||
InvokeModelRequest modelRequest = InvokeModelRequest.builder()
|
||||
.modelId(request.getModelId())
|
||||
.body(SdkBytes.fromUtf8String(request.getPayload()))
|
||||
.build();
|
||||
|
||||
InvokeModelResponse response = bedrockRuntimeClient.invokeModel(modelRequest);
|
||||
|
||||
return BatchResult.success(
|
||||
request.getRequestId(),
|
||||
response.body().asUtf8String()
|
||||
);
|
||||
|
||||
} catch (Exception e) {
|
||||
return BatchResult.failure(request.getRequestId(), e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Performance Optimization Techniques
|
||||
|
||||
### Connection Pooling
|
||||
|
||||
```java
|
||||
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
|
||||
public class BedrockClientFactory {
|
||||
|
||||
public static BedrockRuntimeClient createOptimizedClient() {
|
||||
        SdkHttpClient httpClient = ApacheHttpClient.builder()
                .maxConnections(50)
|
||||
.socketTimeout(Duration.ofSeconds(30))
|
||||
.connectionTimeout(Duration.ofSeconds(30))
|
||||
.build();
|
||||
|
||||
return BedrockRuntimeClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.httpClient(httpClient)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Response Caching
|
||||
|
||||
```java
|
||||
import com.github.benmanes.caffeine.cache.Cache;
|
||||
import com.github.benmanes.caffeine.cache.Caffeine;
|
||||
|
||||
@Service
|
||||
public class CachedAIService {
|
||||
|
||||
private final BedrockRuntimeClient bedrockRuntimeClient;
|
||||
private final Cache<String, String> responseCache;
|
||||
|
||||
public CachedAIService(BedrockRuntimeClient bedrockRuntimeClient) {
|
||||
this.bedrockRuntimeClient = bedrockRuntimeClient;
|
||||
this.responseCache = Caffeine.newBuilder()
|
||||
.maximumSize(1000)
|
||||
.expireAfterWrite(1, TimeUnit.HOURS)
|
||||
.build();
|
||||
}
|
||||
|
||||
public String generateText(String prompt, String modelId) {
|
||||
        // Use the full prompt in the key; hashCode() alone can collide and
        // return another prompt's cached response
        String cacheKey = modelId + ":" + prompt;
|
||||
|
||||
return responseCache.get(cacheKey, key -> {
|
||||
String payload = createPayload(modelId, prompt);
|
||||
InvokeModelRequest request = InvokeModelRequest.builder()
|
||||
.modelId(modelId)
|
||||
.body(SdkBytes.fromUtf8String(payload))
|
||||
.build();
|
||||
|
||||
InvokeModelResponse response = bedrockRuntimeClient.invokeModel(request);
|
||||
return response.body().asUtf8String();
|
||||
});
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Custom Response Parsing
|
||||
|
||||
Create specialized parsers for different model responses:
|
||||
|
||||
```java
|
||||
public interface ResponseParser {
|
||||
String parse(String responseJson);
|
||||
}
|
||||
|
||||
public class AnthropicResponseParser implements ResponseParser {
|
||||
@Override
|
||||
public String parse(String responseJson) {
|
||||
try {
|
||||
JSONObject jsonResponse = new JSONObject(responseJson);
|
||||
return jsonResponse.getJSONArray("content")
|
||||
.getJSONObject(0)
|
||||
.getString("text");
|
||||
} catch (Exception e) {
|
||||
throw new ResponseParsingException("Failed to parse Anthropic response", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class AmazonTitanResponseParser implements ResponseParser {
|
||||
@Override
|
||||
public String parse(String responseJson) {
|
||||
try {
|
||||
JSONObject jsonResponse = new JSONObject(responseJson);
|
||||
return jsonResponse.getJSONArray("results")
|
||||
.getJSONObject(0)
|
||||
.getString("outputText");
|
||||
} catch (Exception e) {
|
||||
throw new ResponseParsingException("Failed to parse Amazon Titan response", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class LlamaResponseParser implements ResponseParser {
|
||||
@Override
|
||||
public String parse(String responseJson) {
|
||||
try {
|
||||
JSONObject jsonResponse = new JSONObject(responseJson);
|
||||
return jsonResponse.getString("generation");
|
||||
} catch (Exception e) {
|
||||
throw new ResponseParsingException("Failed to parse Llama response", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Metrics and Monitoring
|
||||
|
||||
Implement comprehensive monitoring:
|
||||
|
||||
```java
|
||||
import io.micrometer.core.instrument.MeterRegistry;
|
||||
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
|
||||
|
||||
@Service
|
||||
public class MonitoredAIService {
|
||||
|
||||
private final BedrockRuntimeClient bedrockRuntimeClient;
|
||||
private final Timer generationTimer;
|
||||
private final Counter errorCounter;
|
||||
|
||||
public MonitoredAIService(BedrockRuntimeClient bedrockRuntimeClient,
|
||||
MeterRegistry meterRegistry) {
|
||||
this.bedrockRuntimeClient = bedrockRuntimeClient;
|
||||
this.generationTimer = Timer.builder("bedrock.generation.time")
|
||||
.description("Time spent generating text with Bedrock")
|
||||
.register(meterRegistry);
|
||||
this.errorCounter = Counter.builder("bedrock.generation.errors")
|
||||
.description("Number of generation errors")
|
||||
.register(meterRegistry);
|
||||
}
|
||||
|
||||
public String generateText(String prompt, String modelId) {
|
||||
return generationTimer.record(() -> {
|
||||
try {
|
||||
String payload = createPayload(modelId, prompt);
|
||||
InvokeModelRequest request = InvokeModelRequest.builder()
|
||||
.modelId(modelId)
|
||||
.body(SdkBytes.fromUtf8String(payload))
|
||||
.build();
|
||||
|
||||
InvokeModelResponse response = bedrockRuntimeClient.invokeModel(request);
|
||||
return response.body().asUtf8String();
|
||||
|
||||
} catch (Exception e) {
|
||||
errorCounter.increment();
|
||||
throw new GenerationException("Failed to generate text", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
```

## Advanced Configuration Management

```java
@Configuration
@ConfigurationProperties(prefix = "bedrock")
public class AdvancedBedrockConfiguration {

    private String defaultRegion = "us-east-1";
    private int maxRetries = 3;
    private Duration timeout = Duration.ofSeconds(30);
    private boolean enableMetrics = true;
    private int maxCacheSize = 1000;
    private Duration cacheExpireAfter = Duration.ofHours(1);

    @Bean
    @Primary
    public BedrockRuntimeClient bedrockRuntimeClient() {
        BedrockRuntimeClient.Builder builder = BedrockRuntimeClient.builder()
                .region(Region.of(defaultRegion))
                .overrideConfiguration(c -> {
                    c.apiCallTimeout(timeout);
                    c.retryPolicy(RetryPolicy.builder().numRetries(maxRetries).build());
                    if (enableMetrics) {
                        // Metric publishers are the SDK's supported mechanism
                        // for emitting client-side metrics
                        c.addMetricPublisher(CloudWatchMetricPublisher.create());
                    }
                });

        return builder.build();
    }

    // Getters and setters
}
```
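
Because the class binds the `bedrock` prefix, the matching settings live in `application.properties`. A sketch of the corresponding keys (values are illustrative; Spring Boot parses `30s` and `1h` into `Duration`):

```properties
bedrock.default-region=us-east-1
bedrock.max-retries=3
bedrock.timeout=30s
bedrock.enable-metrics=true
bedrock.max-cache-size=1000
bedrock.cache-expire-after=1h
```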

## Streaming Response Handling

Stream partial responses as they arrive by bridging the SDK's event-stream callbacks into a Reactor `Flux`. Note that response streaming goes through the asynchronous client:

```java
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;

@Service
public class StreamingAIService {

    private final BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient;

    public StreamingAIService(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) {
        this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient;
    }

    public Flux<String> streamResponse(String modelId, String prompt) {
        InvokeModelWithResponseStreamRequest request =
                InvokeModelWithResponseStreamRequest.builder()
                        .modelId(modelId)
                        .body(SdkBytes.fromUtf8String(createPayload(modelId, prompt)))
                        .build();

        return Flux.create(sink -> {
            InvokeModelWithResponseStreamResponseHandler handler =
                    InvokeModelWithResponseStreamResponseHandler.builder()
                            .subscriber(event -> event.accept(
                                    InvokeModelWithResponseStreamResponseHandler.Visitor.builder()
                                            .onChunk(part -> processChunk(
                                                    part.bytes().asUtf8String(), sink))
                                            .build()))
                            .onComplete(sink::complete)
                            .onError(e -> sink.error(
                                    new StreamingException("Stream failed", e)))
                            .build();

            bedrockRuntimeAsyncClient.invokeModelWithResponseStream(request, handler);
        });
    }

    private void processChunk(String chunk, FluxSink<String> sink) {
        try {
            JSONObject chunkJson = new JSONObject(chunk);
            // Claude streaming emits typed events; only deltas carry text
            if (chunkJson.getString("type").equals("content_block_delta")) {
                String text = chunkJson.getJSONObject("delta").getString("text");
                sink.next(text);
            }
        } catch (Exception e) {
            sink.error(new ChunkProcessingException("Failed to process chunk", e));
        }
    }
}
```
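
A caller consumes the stream reactively. A minimal usage sketch, assuming a Claude model ID and that chunks are simply printed as they arrive:

```java
streamingAIService
        .streamResponse("anthropic.claude-3-5-sonnet-20241022-v2:0", "Tell me a story")
        .doOnNext(System.out::print)  // print each text delta as it arrives
        .blockLast();                 // block only in demos/CLIs, never inside reactive flows
```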
@@ -0,0 +1,340 @@

# Model Reference

## Supported Foundation Models

### Amazon Models

#### Amazon Titan Text

**Model ID:** `amazon.titan-text-express-v1`
- **Description:** High-quality text generation model
- **Context Window:** Up to 8K tokens
- **Languages:** English, Spanish, French, German, Italian, Portuguese

**Payload Format:**
```json
{
  "inputText": "Your prompt here",
  "textGenerationConfig": {
    "maxTokenCount": 512,
    "temperature": 0.7,
    "topP": 0.9
  }
}
```

**Response Format:**
```json
{
  "results": [{
    "outputText": "Generated text"
  }]
}
```

#### Amazon Titan Text Lite

**Model ID:** `amazon.titan-text-lite-v1`
- **Description:** Cost-effective text generation model
- **Context Window:** Up to 4K tokens
- **Use Case:** Simple text generation tasks

#### Amazon Titan Embeddings

**Model ID:** `amazon.titan-embed-text-v1`
- **Description:** High-quality text embeddings
- **Context Window:** 8K tokens
- **Output:** 1536-dimensional vector

**Payload Format:**
```json
{
  "inputText": "Your text here"
}
```

**Response Format:**
```json
{
  "embedding": [0.1, -0.2, 0.3, ...]
}
```
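
Invoking the embeddings model follows the same pattern as text generation. A minimal sketch, assuming an existing `BedrockRuntimeClient` named `bedrockRuntimeClient` and org.json for parsing:

```java
public float[] embed(String text) {
    String payload = new JSONObject().put("inputText", text).toString();

    InvokeModelResponse response = bedrockRuntimeClient.invokeModel(r -> r
            .modelId("amazon.titan-embed-text-v1")
            .body(SdkBytes.fromUtf8String(payload)));

    JSONArray embedding = new JSONObject(response.body().asUtf8String())
            .getJSONArray("embedding");

    // Copy the JSON array into a primitive vector for the vector store
    float[] vector = new float[embedding.length()];
    for (int i = 0; i < vector.length; i++) {
        vector[i] = (float) embedding.getDouble(i);
    }
    return vector;
}
```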

#### Amazon Titan Image Generator

**Model ID:** `amazon.titan-image-generator-v1`
- **Description:** High-quality image generation
- **Image Size:** 512x512, 1024x1024
- **Use Case:** Text-to-image generation

**Payload Format:**
```json
{
  "taskType": "TEXT_IMAGE",
  "textToImageParams": {
    "text": "Your description"
  },
  "imageGenerationConfig": {
    "numberOfImages": 1,
    "quality": "standard",
    "cfgScale": 8.0,
    "height": 512,
    "width": 512,
    "seed": 12345
  }
}
```

### Anthropic Models

#### Claude 3.5 Sonnet

**Model ID:** `anthropic.claude-3-5-sonnet-20241022-v2:0`
- **Description:** High-performance model for complex reasoning, analysis, and creative tasks
- **Context Window:** 200K tokens
- **Languages:** Multiple languages supported
- **Use Case:** Code generation, complex analysis, creative writing, research
- **Features:** Tool use, function calling, JSON mode

**Payload Format:**
```json
{
  "anthropic_version": "bedrock-2023-05-31",
  "max_tokens": 1000,
  "messages": [{
    "role": "user",
    "content": "Your message"
  }]
}
```

**Response Format:**
```json
{
  "content": [{
    "text": "Response content"
  }],
  "usage": {
    "input_tokens": 10,
    "output_tokens": 20
  }
}
```
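
Putting the two formats together, a minimal invoke sketch for Claude on Bedrock (assumes an existing `BedrockRuntimeClient` named `bedrockRuntimeClient` and org.json):

```java
public String askClaude(String message) {
    String payload = new JSONObject()
            .put("anthropic_version", "bedrock-2023-05-31")
            .put("max_tokens", 1000)
            .put("messages", new JSONArray().put(new JSONObject()
                    .put("role", "user")
                    .put("content", message)))
            .toString();

    InvokeModelResponse response = bedrockRuntimeClient.invokeModel(r -> r
            .modelId("anthropic.claude-3-5-sonnet-20241022-v2:0")
            .body(SdkBytes.fromUtf8String(payload)));

    // The generated text lives under content[0].text, per the response format above
    return new JSONObject(response.body().asUtf8String())
            .getJSONArray("content")
            .getJSONObject(0)
            .getString("text");
}
```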

#### Claude 3.5 Haiku

**Model ID:** `anthropic.claude-3-5-haiku-20241022-v1:0`
- **Description:** Fast and affordable model for real-time applications
- **Context Window:** 200K tokens
- **Use Case:** Real-time applications, chatbots, quick responses
- **Features:** Tool use, function calling, JSON mode

#### Claude 3 Opus

**Model ID:** `anthropic.claude-3-opus-20240229-v1:0`
- **Description:** Most capable model
- **Context Window:** 200K tokens
- **Use Case:** Complex reasoning, analysis

#### Claude 3 Sonnet (Legacy)

**Model ID:** `anthropic.claude-3-sonnet-20240229-v1:0`
- **Description:** Previous generation model
- **Context Window:** 200K tokens
- **Use Case:** General purpose applications

### Meta Models

#### Llama 3.1 70B

**Model ID:** `meta.llama3-1-70b-instruct-v1:0`
- **Description:** Latest generation large open-source model
- **Context Window:** 128K tokens
- **Use Case:** General purpose instruction following, complex reasoning
- **Features:** Improved instruction following, larger context window

#### Llama 3.1 8B

**Model ID:** `meta.llama3-1-8b-instruct-v1:0`
- **Description:** Latest generation small, fast model
- **Context Window:** 128K tokens
- **Use Case:** Fast inference, lightweight applications

#### Llama 3 70B

**Model ID:** `meta.llama3-70b-instruct-v1:0`
- **Description:** Previous generation large open-source model
- **Context Window:** 8K tokens
- **Use Case:** General purpose instruction following

**Payload Format:**
```json
{
  "prompt": "[INST] Your prompt here [/INST]",
  "max_gen_len": 512,
  "temperature": 0.7,
  "top_p": 0.9
}
```

**Response Format:**
```json
{
  "generation": "Generated text"
}
```

#### Llama 3 8B

**Model ID:** `meta.llama3-8b-instruct-v1:0`
- **Description:** Smaller, faster version
- **Context Window:** 8K tokens
- **Use Case:** Fast inference, lightweight applications

### Stability AI Models

#### Stable Diffusion XL

**Model ID:** `stability.stable-diffusion-xl-v1`
- **Description:** High-quality image generation
- **Image Size:** Up to 1024x1024
- **Use Case:** Text-to-image generation, art creation

**Payload Format:**
```json
{
  "text_prompts": [{
    "text": "Your description"
  }],
  "style_preset": "photographic",
  "seed": 12345,
  "cfg_scale": 10,
  "steps": 50
}
```

**Response Format:**
```json
{
  "artifacts": [{
    "base64": "base64-encoded-image-data",
    "finishReason": "SUCCESS"
  }]
}
```
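
The response carries the image as base64, so the final step is decoding it to bytes. A minimal sketch, assuming the invoke call already returned the JSON above (uses `java.util.Base64` and `java.nio.file`):

```java
public Path saveFirstImage(String responseJson, Path target) throws IOException {
    String base64 = new JSONObject(responseJson)
            .getJSONArray("artifacts")
            .getJSONObject(0)
            .getString("base64");

    // Decode the base64 payload and write it out, e.g. as a PNG file
    byte[] imageBytes = Base64.getDecoder().decode(base64);
    return Files.write(target, imageBytes);
}
```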

### Other Models

#### Cohere Command

**Model ID:** `cohere.command-text-v14`
- **Description:** Text generation model
- **Context Window:** 128K tokens
- **Use Case:** Content generation, summarization

#### Mistral Models

**Model ID:** `mistral.mistral-7b-instruct-v0:2`
- **Description:** High-performing open-source model
- **Context Window:** 32K tokens
- **Use Case:** Instruction following, code generation

**Model ID:** `mistral.mixtral-8x7b-instruct-v0:1`
- **Description:** Mixture of experts model
- **Context Window:** 32K tokens
- **Use Case:** Complex reasoning tasks

## Model Selection Guide

### Use Case Recommendations

| Use Case | Recommended Models | Notes |
|----------|-------------------|-------|
| **General Chat/Chatbots** | Claude 3.5 Haiku, Llama 3 8B | Fast response times |
| **Content Creation** | Claude 3.5 Sonnet, Cohere | Creative, coherent outputs |
| **Code Generation** | Claude 3.5 Sonnet, Llama 3.1 70B | Excellent code understanding |
| **Analysis & Reasoning** | Claude 3 Opus, Claude 3.5 Sonnet | Complex reasoning |
| **Real-time Applications** | Claude 3.5 Haiku, Titan Lite | Fast inference |
| **Cost-sensitive Apps** | Titan Lite, Claude 3.5 Haiku | Lower cost per token |
| **High Quality** | Claude 3 Opus, Claude 3.5 Sonnet | Premium quality |

### Performance Characteristics

| Model | Speed | Cost | Quality | Context Window |
|-------|-------|------|---------|----------------|
| Claude 3 Opus | Slow | High | Excellent | 200K |
| Claude 3.5 Sonnet | Medium | Medium | Excellent | 200K |
| Claude 3.5 Haiku | Fast | Low | Good | 200K |
| Claude 3 Sonnet (Legacy) | Medium | Medium | Good | 200K |
| Llama 3.1 70B | Medium | Medium | Good | 128K |
| Llama 3.1 8B | Fast | Low | Fair | 128K |
| Llama 3 70B | Medium | Medium | Good | 8K |
| Llama 3 8B | Fast | Low | Fair | 8K |
| Titan Express | Fast | Medium | Good | 8K |
| Titan Lite | Fast | Low | Fair | 4K |

## Model Comparison Matrix

| Feature | Claude 3 | Llama 3 | Titan | Stability |
|---------|----------|---------|-------|-----------|
| **Streaming** | ✅ | ✅ | ✅ | ❌ |
| **Tool Use** | ✅ | ❌ | ❌ | ❌ |
| **Image Generation** | ❌ | ❌ | ✅ | ✅ |
| **Embeddings** | ❌ | ❌ | ✅ | ❌ |
| **Multiple Languages** | ✅ | ✅ | ✅ | ✅ |
| **Context Window** | 200K | 8K | 8K | N/A |
| **Open Source** | ❌ | ✅ | ❌ | ✅ |

## Model Configuration Templates

### Text Generation Template
```java
private static JSONObject createTextGenerationPayload(String modelId, String prompt) {
    JSONObject payload = new JSONObject();

    if (modelId.startsWith("anthropic.claude")) {
        payload.put("anthropic_version", "bedrock-2023-05-31");
        payload.put("max_tokens", 1000);
        payload.put("messages", new JSONArray().put(new JSONObject()
                .put("role", "user")
                .put("content", prompt)));
    } else if (modelId.startsWith("meta.llama")) {
        payload.put("prompt", "[INST] " + prompt + " [/INST]");
        payload.put("max_gen_len", 512);
    } else if (modelId.startsWith("amazon.titan")) {
        payload.put("inputText", prompt);
        payload.put("textGenerationConfig", new JSONObject()
                .put("maxTokenCount", 512)
                .put("temperature", 0.7));
    }

    return payload;
}
```

### Image Generation Template
```java
private static JSONObject createImageGenerationPayload(String modelId, String prompt) {
    JSONObject payload = new JSONObject();

    if (modelId.equals("amazon.titan-image-generator-v1")) {
        payload.put("taskType", "TEXT_IMAGE");
        payload.put("textToImageParams", new JSONObject().put("text", prompt));
        payload.put("imageGenerationConfig", new JSONObject()
                .put("numberOfImages", 1)
                .put("quality", "standard")
                .put("height", 512)
                .put("width", 512));
    } else if (modelId.equals("stability.stable-diffusion-xl-v1")) {
        payload.put("text_prompts", new JSONArray().put(new JSONObject().put("text", prompt)));
        payload.put("style_preset", "photographic");
        payload.put("steps", 50);
        payload.put("cfg_scale", 10);
    }

    return payload;
}
```

@@ -0,0 +1,121 @@

# Model ID Lookup Guide

This document provides a quick lookup for the most commonly used model IDs in Amazon Bedrock.

## Text Generation Models

### Claude (Anthropic)
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Claude 4.5 Sonnet | `anthropic.claude-sonnet-4-5-20250929-v1:0` | Latest high-performance model | Complex reasoning, coding, creative tasks |
| Claude 4.5 Haiku | `anthropic.claude-haiku-4-5-20251001-v1:0` | Latest fast model | Real-time applications, chatbots |
| Claude 3.7 Sonnet | `anthropic.claude-3-7-sonnet-20250219-v1:0` | Most advanced reasoning | High-stakes decisions, complex analysis |
| Claude Opus 4.1 | `anthropic.claude-opus-4-1-20250805-v1:0` | Most powerful creative | Advanced creative tasks |
| Claude 3.5 Sonnet v2 | `anthropic.claude-3-5-sonnet-20241022-v2:0` | High-performance model | General use, coding |
| Claude 3.5 Haiku | `anthropic.claude-3-5-haiku-20241022-v1:0` | Fast and affordable | Real-time applications |

### Llama (Meta)
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Llama 3.3 70B | `meta.llama3-3-70b-instruct-v1:0` | Latest generation | Complex reasoning, general use |
| Llama 3.2 90B | `meta.llama3-2-90b-instruct-v1:0` | Large context | Long context tasks |
| Llama 3.2 11B | `meta.llama3-2-11b-instruct-v1:0` | Medium model | Balanced performance |
| Llama 3.2 3B | `meta.llama3-2-3b-instruct-v1:0` | Small model | Fast inference |
| Llama 3.2 1B | `meta.llama3-2-1b-instruct-v1:0` | Ultra-fast | Quick responses |
| Llama 3.1 70B | `meta.llama3-1-70b-instruct-v1:0` | Previous gen | General use |
| Llama 3.1 8B | `meta.llama3-1-8b-instruct-v1:0` | Fast small model | Lightweight applications |

### Mistral AI
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Mistral Large 2407 | `mistral.mistral-large-2407-v1:0` | Latest large model | Complex reasoning |
| Mistral Large 2402 | `mistral.mistral-large-2402-v1:0` | Previous large model | General use |
| Mistral Pixtral 2502 | `mistral.pixtral-large-2502-v1:0` | Multimodal | Text + image understanding |
| Mistral 7B | `mistral.mistral-7b-instruct-v0:2` | Small fast model | Quick responses |

### Amazon
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Titan Text Express | `amazon.titan-text-express-v1` | Fast text generation | Quick responses |
| Titan Text Lite | `amazon.titan-text-lite-v1` | Cost-effective | Budget-sensitive apps |
| Titan Embeddings | `amazon.titan-embed-text-v1` | Text embeddings | Semantic search |

### Cohere
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Command R+ | `cohere.command-r-plus-v1:0` | High performance | Complex tasks |
| Command R | `cohere.command-r-v1:0` | General purpose | Standard use cases |

## Image Generation Models

### Stability AI
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Stable Diffusion 3.5 Large | `stability.sd3-5-large-v1:0` | Latest image gen | High-quality images |
| Stable Diffusion XL | `stability.stable-diffusion-xl-v1` | Previous generation | General image generation |

### Amazon Nova
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Nova Canvas | `amazon.nova-canvas-v1:0` | Image generation | Creative images |
| Nova Reel | `amazon.nova-reel-v1:1` | Video generation | Video content |

## Embedding Models

### Amazon
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Titan Embeddings | `amazon.titan-embed-text-v1` | Text embeddings | Semantic search |
| Titan Embeddings V2 | `amazon.titan-embed-text-v2:0` | Improved embeddings | Better accuracy |

### Cohere
| Model | Model ID | Description | Use Case |
|-------|----------|-------------|----------|
| Embed English | `cohere.embed-english-v3` | English embeddings | English content |
| Embed Multilingual | `cohere.embed-multilingual-v3` | Multi-language | International use |

## Selection Guide

### By Speed
1. **Fastest**: Llama 3.2 1B, Claude 4.5 Haiku, Titan Lite
2. **Fast**: Mistral 7B, Llama 3.2 3B
3. **Medium**: Claude 3.5 Sonnet, Llama 3.2 11B
4. **Slow**: Claude 4.5 Sonnet, Llama 3.3 70B

### By Quality
1. **Highest**: Claude 4.5 Sonnet, Claude 3.7 Sonnet, Claude Opus 4.1
2. **High**: Claude 3.5 Sonnet, Llama 3.3 70B
3. **Medium**: Mistral Large, Llama 3.2 11B
4. **Basic**: Mistral 7B, Llama 3.2 3B

### By Cost
1. **Most Affordable**: Claude 4.5 Haiku, Llama 3.2 1B
2. **Affordable**: Mistral 7B, Titan Lite
3. **Medium**: Claude 3.5 Haiku, Llama 3.2 3B
4. **Expensive**: Claude 4.5 Sonnet, Llama 3.3 70B

## Common Patterns

### Default Model Selection
```java
// For most applications
String DEFAULT_MODEL = "anthropic.claude-sonnet-4-5-20250929-v1:0";

// For real-time applications
String FAST_MODEL = "anthropic.claude-haiku-4-5-20251001-v1:0";

// For budget-sensitive applications
String CHEAP_MODEL = "amazon.titan-text-lite-v1";

// For complex reasoning
String POWERFUL_MODEL = "anthropic.claude-3-7-sonnet-20250219-v1:0";
```

### Model Fallback Chain
```java
private static final String[] MODEL_CHAIN = {
    "anthropic.claude-sonnet-4-5-20250929-v1:0",  // Primary
    "anthropic.claude-haiku-4-5-20251001-v1:0",   // Fast fallback
    "amazon.titan-text-lite-v1"                   // Cheap fallback
};
```
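
One way to use the chain is to walk it on retryable failures. A minimal sketch, assuming a `generateText(prompt, modelId)` method like the ones in this skill and treating throttling as the signal to fall back:

```java
public String generateWithFallback(String prompt) {
    RuntimeException last = null;
    for (String modelId : MODEL_CHAIN) {
        try {
            return generateText(prompt, modelId);
        } catch (ThrottlingException e) {
            last = e; // try the next, cheaper/faster model in the chain
        }
    }
    throw new IllegalStateException("All models in the fallback chain failed", last);
}
```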

@@ -0,0 +1,365 @@

# Testing Strategies

## Unit Testing

### Mocking Bedrock Clients

```java
@ExtendWith(MockitoExtension.class)
class BedrockServiceTest {

    @Mock
    private BedrockRuntimeClient bedrockRuntimeClient;

    @InjectMocks
    private BedrockAIService aiService;

    @Test
    void shouldGenerateTextWithClaude() {
        // Arrange
        String modelId = "anthropic.claude-3-sonnet-20240229-v1:0";
        String prompt = "Hello, world!";
        String expectedResponse = "Hello! How can I help you today?";

        InvokeModelResponse mockResponse = InvokeModelResponse.builder()
                .body(SdkBytes.fromUtf8String(
                        "{\"content\":[{\"text\":\"" + expectedResponse + "\"}]}"))
                .build();

        when(bedrockRuntimeClient.invokeModel(any(InvokeModelRequest.class)))
                .thenReturn(mockResponse);

        // Act
        String result = aiService.generateText(prompt, modelId);

        // Assert
        assertThat(result).isEqualTo(expectedResponse);
        verify(bedrockRuntimeClient).invokeModel(argThat((InvokeModelRequest request) ->
                request.modelId().equals(modelId)));
    }

    @Test
    void shouldHandleThrottling() {
        // Arrange
        when(bedrockRuntimeClient.invokeModel(any(InvokeModelRequest.class)))
                .thenThrow(ThrottlingException.builder()
                        .message("Rate limit exceeded")
                        .build());

        // Act & Assert
        assertThatThrownBy(() -> aiService.generateText(
                "test", "anthropic.claude-3-sonnet-20240229-v1:0"))
                .isInstanceOf(RuntimeException.class)
                .hasMessageContaining("Rate limit exceeded");
    }
}
```

### Testing Error Conditions

```java
@Test
void shouldHandleInvalidModelId() {
    String invalidModelId = "invalid.model.id";
    String prompt = "test";

    when(bedrockRuntimeClient.invokeModel(any(InvokeModelRequest.class)))
            .thenThrow(ValidationException.builder()
                    .message("Invalid model identifier")
                    .build());

    assertThatThrownBy(() -> aiService.generateText(prompt, invalidModelId))
            .isInstanceOf(IllegalArgumentException.class)
            .hasMessageContaining("Invalid model identifier");
}
```

### Testing Multiple Models

```java
@ParameterizedTest
@EnumSource(ModelProvider.class)
void shouldSupportAllModels(ModelProvider modelProvider) {
    String prompt = "Hello";
    String modelId = modelProvider.getModelId();
    String expectedResponse = "Response";

    InvokeModelResponse mockResponse = InvokeModelResponse.builder()
            .body(SdkBytes.fromUtf8String(createMockResponse(modelProvider, expectedResponse)))
            .build();

    when(bedrockRuntimeClient.invokeModel(any(InvokeModelRequest.class)))
            .thenReturn(mockResponse);

    String result = aiService.generateText(prompt, modelId);

    assertThat(result).isEqualTo(expectedResponse);
}

private enum ModelProvider {
    CLAUDE("anthropic.claude-3-sonnet-20240229-v1:0"),
    LLAMA("meta.llama3-70b-instruct-v1:0"),
    TITAN("amazon.titan-text-express-v1");

    private final String modelId;

    ModelProvider(String modelId) {
        this.modelId = modelId;
    }

    public String getModelId() {
        return modelId;
    }
}
```
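
The test above references a `createMockResponse` helper that is not shown. A minimal sketch that returns each provider's response shape (the same shapes as the fixtures further below):

```java
private static String createMockResponse(ModelProvider provider, String text) {
    return switch (provider) {
        case CLAUDE -> "{\"content\":[{\"text\":\"" + text + "\"}]}";
        case LLAMA -> "{\"generation\":\"" + text + "\"}";
        case TITAN -> "{\"results\":[{\"outputText\":\"" + text + "\"}]}";
    };
}
```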

## Integration Testing

### Testcontainers Integration

```java
@Testcontainers
@SpringBootTest(classes = BedrockConfiguration.class)
@ActiveProfiles("test")
class BedrockIntegrationTest {

    @Container
    static LocalStackContainer localStack = new LocalStackContainer(
            DockerImageName.parse("localstack/localstack:latest"))
            // Bedrock emulation depends on your LocalStack version/edition
            .withServices(LocalStackContainer.EnabledService.named("bedrock-runtime"))
            .withEnv("DEFAULT_REGION", "us-east-1");

    @Autowired
    private BedrockRuntimeClient bedrockRuntimeClient;

    // ListFoundationModels is a control-plane operation, so it lives on
    // BedrockClient rather than BedrockRuntimeClient
    @Autowired
    private BedrockClient bedrockClient;

    @Test
    void shouldConnectToLocalStack() {
        assertThat(bedrockRuntimeClient).isNotNull();
    }

    @Test
    void shouldListFoundationModels() {
        ListFoundationModelsResponse response = bedrockClient.listFoundationModels(
                ListFoundationModelsRequest.builder().build());

        assertThat(response.modelSummaries()).isNotEmpty();
    }
}
```

### LocalStack Configuration

```java
@Configuration
public class LocalStackConfig {

    @Value("${localstack.enabled:true}")
    private boolean localStackEnabled;

    @Bean
    @ConditionalOnProperty(name = "localstack.enabled", havingValue = "true")
    public AwsCredentialsProvider localStackCredentialsProvider() {
        return StaticCredentialsProvider.create(
                AwsBasicCredentials.create("test", "test"));
    }

    @Bean
    @ConditionalOnProperty(name = "localstack.enabled", havingValue = "true")
    public BedrockRuntimeClient localStackBedrockRuntimeClient(
            AwsCredentialsProvider credentialsProvider) {

        return BedrockRuntimeClient.builder()
                .credentialsProvider(credentialsProvider)
                // Endpoint of the running LocalStack container
                .endpointOverride(URI.create("http://localhost:4566"))
                .region(Region.US_EAST_1)
                .build();
    }
}
```

### Performance Testing

```java
@Test
void shouldPerformWithinTimeLimit() {
    String prompt = "Performance test prompt";
    int iterationCount = 100;

    long startTime = System.currentTimeMillis();

    for (int i = 0; i < iterationCount; i++) {
        bedrockRuntimeClient.invokeModel(request -> request
                .modelId("anthropic.claude-3-sonnet-20240229-v1:0")
                .body(SdkBytes.fromUtf8String(createPayload(prompt))));
    }

    long duration = System.currentTimeMillis() - startTime;
    double avgTimePerRequest = (double) duration / iterationCount;

    assertThat(avgTimePerRequest).isLessThan(5000); // Less than 5 seconds per request
    System.out.println("Average response time: " + avgTimePerRequest + "ms");
}
```

## Testing Streaming Responses

### Streaming Handler Testing

Response streaming goes through the async client with a response handler. The sketch below collects chunks and signals completion with a latch (assumes a `BedrockRuntimeAsyncClient` named `bedrockRuntimeAsyncClient`; an assertion on specific content is omitted because real model output varies):

```java
@Test
void shouldStreamResponse() throws InterruptedException {
    String prompt = "Stream this response";

    StringBuilder streamedContent = new StringBuilder();
    CountDownLatch completion = new CountDownLatch(1);

    InvokeModelWithResponseStreamRequest streamRequest =
            InvokeModelWithResponseStreamRequest.builder()
                    .modelId("anthropic.claude-3-sonnet-20240229-v1:0")
                    .body(SdkBytes.fromUtf8String(createPayload(prompt)))
                    .build();

    InvokeModelWithResponseStreamResponseHandler handler =
            InvokeModelWithResponseStreamResponseHandler.builder()
                    .subscriber(event -> event.accept(
                            InvokeModelWithResponseStreamResponseHandler.Visitor.builder()
                                    .onChunk(part -> streamedContent
                                            .append(part.bytes().asUtf8String()))
                                    .build()))
                    .onComplete(completion::countDown)
                    .build();

    bedrockRuntimeAsyncClient.invokeModelWithResponseStream(streamRequest, handler);

    // Wait for streaming to complete
    assertThat(completion.await(10, TimeUnit.SECONDS)).isTrue();
    assertThat(streamedContent.toString()).isNotEmpty();
}
```

## Testing Configuration

### Testing Different Regions

```java
// Region is not a Java enum in SDK v2, so drive the test with strings
@ParameterizedTest
@ValueSource(strings = {"us-east-1", "us-west-2", "eu-west-1"})
void shouldWorkInAllRegions(String regionName) {
    BedrockRuntimeClient client = BedrockRuntimeClient.builder()
            .region(Region.of(regionName))
            .build();

    assertThat(client).isNotNull();
}
```

### Testing Authentication

```java
@Test
void shouldUseIamRoleForAuthentication() {
    // ListFoundationModels is a control-plane call, so use BedrockClient
    BedrockClient client = BedrockClient.builder()
            .region(Region.US_EAST_1)
            .build();

    // Test that the client can make basic calls with the default credential chain
    ListFoundationModelsResponse response = client.listFoundationModels(
            ListFoundationModelsRequest.builder().build());

    assertThat(response).isNotNull();
}
```

## Test Data Management

### Test Response Fixtures

```java
public class BedrockTestFixtures {

    public static String createClaudeResponse() {
        return "{\"content\":[{\"text\":\"Hello! How can I help you today?\"}]}";
    }

    public static String createLlamaResponse() {
        return "{\"generation\":\"Hello! How can I assist you?\"}";
    }

    public static String createTitanResponse() {
        return "{\"results\":[{\"outputText\":\"Hello! How can I help?\"}]}";
    }

    public static String createPayload(String prompt) {
        return new JSONObject()
                .put("anthropic_version", "bedrock-2023-05-31")
                .put("max_tokens", 1000)
                .put("messages", new JSONArray().put(new JSONObject()
                        .put("role", "user")
                        .put("content", prompt)))
                .toString();
    }
}
```

### Integration Test Suite

```java
@Suite
@SelectClasses({
    BedrockAIServiceTest.class,
    BedrockConfigurationTest.class,
    BedrockStreamingTest.class,
    BedrockErrorHandlingTest.class
})
public class BedrockTestSuite {
    // Integration test suite for all Bedrock functionality
}
```

## Testing Guidelines

### Unit Testing Best Practices

1. **Mock External Dependencies:** Always mock AWS SDK clients in unit tests
2. **Test Error Scenarios:** Include tests for throttling, validation errors, and network issues
3. **Parameterized Tests:** Test multiple models and configurations efficiently
4. **Performance Assertions:** Include basic performance benchmarks
5. **Test Data Fixtures:** Reuse test response data across tests

### Integration Testing Best Practices

1. **Use LocalStack:** Test against LocalStack for local development
2. **Test Multiple Regions:** Verify functionality across different AWS regions
3. **Test Edge Cases:** Include timeout, retry, and concurrent request scenarios
4. **Monitor Performance:** Track response times and error rates
5. **Clean Up Resources:** Ensure proper cleanup after integration tests

### Testing Configuration

```properties
# application-test.properties
localstack.enabled=true
aws.region=us-east-1
bedrock.timeout=5000
bedrock.retry.max-attempts=3
```
660
skills/aws-java/aws-sdk-java-v2-core/SKILL.md
Normal file
660
skills/aws-java/aws-sdk-java-v2-core/SKILL.md
Normal file
@@ -0,0 +1,660 @@

---
name: aws-sdk-java-v2-core
description: Core patterns and best practices for AWS SDK for Java 2.x. Use when configuring AWS service clients, setting up authentication, managing credentials, configuring timeouts, HTTP clients, or following AWS SDK best practices.
category: aws
tags: [aws, java, sdk, core, authentication, configuration]
version: 1.1.0
allowed-tools: Read, Write, Bash
---

# AWS SDK for Java 2.x - Core Patterns

## Overview

Configure AWS service clients, authentication, timeouts, HTTP clients, and implement best practices for AWS SDK for Java 2.x applications. This skill provides essential patterns for building robust, performant, and secure integrations with AWS services.

## When to Use

Use this skill when:
- Setting up AWS SDK for Java 2.x service clients with proper configuration
- Configuring authentication and credential management strategies
- Implementing client lifecycle management and resource cleanup
- Optimizing performance with HTTP client configuration and connection pooling
- Setting up proper timeout configurations for API calls
- Implementing error handling and retry policies
- Enabling monitoring and metrics collection
- Integrating AWS SDK with Spring Boot applications
- Testing AWS integrations with LocalStack and Testcontainers

## Quick Start

### Basic Service Client Setup

```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;

// Basic client with region
S3Client s3Client = S3Client.builder()
        .region(Region.US_EAST_1)
        .build();

// Always close clients when done
try (S3Client s3 = S3Client.builder().region(Region.US_EAST_1).build()) {
    // Use client
} // Auto-closed
```

### Basic Authentication

```java
// Uses the default credential provider chain
S3Client s3Client = S3Client.builder()
        .region(Region.US_EAST_1)
        .build(); // Automatically detects credentials
```

## Client Configuration

### Service Client Builder Pattern

```java
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.http.apache.ProxyConfiguration;
import software.amazon.awssdk.metrics.publishers.cloudwatch.CloudWatchMetricPublisher;
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
import java.time.Duration;
import java.net.URI;

// Advanced client configuration
S3Client s3Client = S3Client.builder()
        .region(Region.EU_SOUTH_2)
        .credentialsProvider(EnvironmentVariableCredentialsProvider.create())
        .overrideConfiguration(b -> b
                .apiCallTimeout(Duration.ofSeconds(30))
                .apiCallAttemptTimeout(Duration.ofSeconds(10))
                .addMetricPublisher(CloudWatchMetricPublisher.create()))
        .httpClientBuilder(ApacheHttpClient.builder()
                .maxConnections(100)
                .connectionTimeout(Duration.ofSeconds(5))
                .proxyConfiguration(ProxyConfiguration.builder()
                        .endpoint(URI.create("http://proxy:8080"))
                        .build()))
        .build();
```

### Separate Configuration Objects

```java
import software.amazon.awssdk.http.SdkHttpClient;

ClientOverrideConfiguration clientConfig = ClientOverrideConfiguration.builder()
        .apiCallTimeout(Duration.ofSeconds(30))
        .apiCallAttemptTimeout(Duration.ofSeconds(10))
        .addMetricPublisher(CloudWatchMetricPublisher.create())
        .build();

// ApacheHttpClient.builder().build() returns the SdkHttpClient interface
SdkHttpClient httpClient = ApacheHttpClient.builder()
        .maxConnections(100)
        .connectionTimeout(Duration.ofSeconds(5))
        .build();

S3Client s3Client = S3Client.builder()
        .region(Region.EU_SOUTH_2)
        .credentialsProvider(EnvironmentVariableCredentialsProvider.create())
        .overrideConfiguration(clientConfig)
        .httpClient(httpClient)
        .build();
```

## Authentication and Credentials

### Default Credentials Provider Chain

```java
// The SDK automatically uses the default credential provider chain:
// 1. Java system properties (aws.accessKeyId and aws.secretAccessKey)
// 2. Environment variables (AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY)
// 3. Web identity token from AWS_WEB_IDENTITY_TOKEN_FILE
// 4. Shared credentials and config files (~/.aws/credentials and ~/.aws/config)
// 5. Amazon ECS container credentials
// 6. Amazon EC2 instance profile credentials

S3Client s3Client = S3Client.builder()
        .region(Region.US_EAST_1)
        .build(); // Uses default credential provider chain
```

### Explicit Credentials Providers

```java
import software.amazon.awssdk.auth.credentials.*;

// Environment variables
AwsCredentialsProvider envCredentials = EnvironmentVariableCredentialsProvider.create();

// Profile from ~/.aws/credentials
AwsCredentialsProvider profileCredentials = ProfileCredentialsProvider.create("myprofile");

// Static credentials (NOT recommended for production)
AwsCredentialsProvider staticCredentials = StaticCredentialsProvider.create(
        AwsBasicCredentials.create("accessKeyId", "secretAccessKey"));

// Instance profile (for EC2)
AwsCredentialsProvider instanceProfileCredentials = InstanceProfileCredentialsProvider.create();

// Use with client
S3Client s3Client = S3Client.builder()
        .region(Region.US_EAST_1)
        .credentialsProvider(profileCredentials)
        .build();
```

### SSO Authentication Setup

```properties
# ~/.aws/config
[default]
sso_session = my-sso
sso_account_id = 111122223333
sso_role_name = SampleRole
region = us-east-1
output = json

[sso-session my-sso]
sso_region = us-east-1
sso_start_url = https://provided-domain.awsapps.com/start
sso_registration_scopes = sso:account:access
```

```bash
# Login before running the application
aws sso login

# Verify the active session
aws sts get-caller-identity
```

## HTTP Client Configuration

### Apache HTTP Client (Recommended for Sync)

```java
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;

SdkHttpClient httpClient = ApacheHttpClient.builder()
        .maxConnections(100)
        .connectionTimeout(Duration.ofSeconds(5))
        .socketTimeout(Duration.ofSeconds(30))
        .connectionTimeToLive(Duration.ofMinutes(5))
        .expectContinueEnabled(true)
        .build();

S3Client s3Client = S3Client.builder()
        .httpClient(httpClient)
        .build();
```

### Netty HTTP Client (For Async Operations)

```java
import io.netty.handler.ssl.SslProvider;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;

SdkAsyncHttpClient httpClient = NettyNioAsyncHttpClient.builder()
        .maxConcurrency(100)
        .connectionTimeout(Duration.ofSeconds(5))
        .readTimeout(Duration.ofSeconds(30))
        .writeTimeout(Duration.ofSeconds(30))
        .sslProvider(SslProvider.OPENSSL) // Better performance than the JDK provider
        .build();

S3AsyncClient s3AsyncClient = S3AsyncClient.builder()
        .httpClient(httpClient)
        .build();
```

### URL Connection HTTP Client (Lightweight)

```java
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.urlconnection.UrlConnectionHttpClient;

SdkHttpClient httpClient = UrlConnectionHttpClient.builder()
        .socketTimeout(Duration.ofSeconds(30))
        .build();
```
## Best Practices

### 1. Reuse Service Clients

**DO:**
```java
@Service
public class S3Service {
    private final S3Client s3Client;

    public S3Service() {
        this.s3Client = S3Client.builder()
                .region(Region.US_EAST_1)
                .build();
    }

    // Reuse s3Client for all operations
}
```

**DON'T:**
```java
public void uploadFile(String bucket, String key) {
    // Creates a new client each time - wastes resources!
    S3Client s3 = S3Client.builder().build();
    s3.putObject(...);
    s3.close();
}
```

### 2. Configure API Timeouts

```java
S3Client s3Client = S3Client.builder()
        .overrideConfiguration(b -> b
                .apiCallTimeout(Duration.ofSeconds(30))
                .apiCallAttemptTimeout(Duration.ofMillis(5000)))
        .build();
```

### 3. Close Unused Clients

```java
// Try-with-resources
try (S3Client s3 = S3Client.builder().build()) {
    s3.listBuckets();
}

// Explicit close
S3Client s3Client = S3Client.builder().build();
try {
    s3Client.listBuckets();
} finally {
    s3Client.close();
}
```

### 4. Close Streaming Responses

```java
try (ResponseInputStream<GetObjectResponse> s3Object =
        s3Client.getObject(GetObjectRequest.builder()
                .bucket(bucket)
                .key(key)
                .build())) {

    // Read and process the stream immediately
    byte[] data = s3Object.readAllBytes();

} // Stream auto-closed, connection returned to pool
```

### 5. Optimize SSL for Async Clients

**Add dependency:**
```xml
<dependency>
    <groupId>io.netty</groupId>
    <artifactId>netty-tcnative-boringssl-static</artifactId>
    <version>2.0.61.Final</version>
    <scope>runtime</scope>
</dependency>
```

**Configure SSL:**
```java
SdkAsyncHttpClient httpClient = NettyNioAsyncHttpClient.builder()
        .sslProvider(SslProvider.OPENSSL)
        .build();

S3AsyncClient s3AsyncClient = S3AsyncClient.builder()
        .httpClient(httpClient)
        .build();
```

## Spring Boot Integration
|
||||
|
||||
### Configuration Properties
|
||||
|
||||
```java
|
||||
@ConfigurationProperties(prefix = "aws")
|
||||
public record AwsProperties(
|
||||
String region,
|
||||
String accessKeyId,
|
||||
String secretAccessKey,
|
||||
S3Properties s3,
|
||||
DynamoDbProperties dynamoDb
|
||||
) {
|
||||
public record S3Properties(
|
||||
Integer maxConnections,
|
||||
Integer connectionTimeoutSeconds,
|
||||
Integer apiCallTimeoutSeconds
|
||||
) {}
|
||||
|
||||
public record DynamoDbProperties(
|
||||
Integer maxConnections,
|
||||
Integer readTimeoutSeconds
|
||||
) {}
|
||||
}
|
||||
```
|
||||
|
||||
### Client Configuration Beans

```java
@Configuration
@EnableConfigurationProperties(AwsProperties.class)
public class AwsClientConfiguration {

    private final AwsProperties awsProperties;

    public AwsClientConfiguration(AwsProperties awsProperties) {
        this.awsProperties = awsProperties;
    }

    @Bean
    public S3Client s3Client() {
        return S3Client.builder()
                .region(Region.of(awsProperties.region()))
                .credentialsProvider(credentialsProvider())
                .overrideConfiguration(clientOverrideConfiguration(
                        awsProperties.s3().apiCallTimeoutSeconds()))
                .httpClient(apacheHttpClient(
                        awsProperties.s3().maxConnections(),
                        awsProperties.s3().connectionTimeoutSeconds()))
                .build();
    }

    private AwsCredentialsProvider credentialsProvider() {
        if (awsProperties.accessKeyId() != null &&
                awsProperties.secretAccessKey() != null) {
            return StaticCredentialsProvider.create(
                    AwsBasicCredentials.create(
                            awsProperties.accessKeyId(),
                            awsProperties.secretAccessKey()));
        }
        return DefaultCredentialsProvider.create();
    }

    private ClientOverrideConfiguration clientOverrideConfiguration(
            Integer apiCallTimeoutSeconds) {
        return ClientOverrideConfiguration.builder()
                .apiCallTimeout(Duration.ofSeconds(
                        apiCallTimeoutSeconds != null ? apiCallTimeoutSeconds : 30))
                .apiCallAttemptTimeout(Duration.ofSeconds(10))
                .build();
    }

    private SdkHttpClient apacheHttpClient(
            Integer maxConnections,
            Integer connectionTimeoutSeconds) {
        return ApacheHttpClient.builder()
                .maxConnections(maxConnections != null ? maxConnections : 50)
                .connectionTimeout(Duration.ofSeconds(
                        connectionTimeoutSeconds != null ? connectionTimeoutSeconds : 5))
                .socketTimeout(Duration.ofSeconds(30))
                .build();
    }
}
```
### Application Properties

```yaml
aws:
  region: us-east-1
  s3:
    max-connections: 100
    connection-timeout-seconds: 5
    api-call-timeout-seconds: 30
  dynamo-db:
    max-connections: 50
    read-timeout-seconds: 30
```
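With these beans in place, application code injects the shared client instead of building its own. A minimal sketch of a consumer (the service class, method, and bucket handling are illustrative, not from the original):

```java
@Service
public class DocumentStorageService {

    private final S3Client s3Client; // shared bean from AwsClientConfiguration

    public DocumentStorageService(S3Client s3Client) {
        this.s3Client = s3Client;
    }

    public void store(String bucket, String key, String content) {
        // Reuses the pooled, pre-configured client for every call
        s3Client.putObject(
                PutObjectRequest.builder().bucket(bucket).key(key).build(),
                RequestBody.fromString(content));
    }
}
```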
## Error Handling

```java
import software.amazon.awssdk.services.s3.model.S3Exception;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.exception.SdkServiceException;

try {
    s3Client.getObject(request);

} catch (S3Exception e) {
    // Service-specific exception
    System.err.println("S3 Error: " + e.awsErrorDetails().errorMessage());
    System.err.println("Error Code: " + e.awsErrorDetails().errorCode());
    System.err.println("Status Code: " + e.statusCode());
    System.err.println("Request ID: " + e.requestId());

} catch (SdkServiceException e) {
    // Generic service exception
    System.err.println("AWS Service Error: " + e.getMessage());

} catch (SdkClientException e) {
    // Client-side error (network, timeout, etc.)
    System.err.println("Client Error: " + e.getMessage());
}
```
## Testing Patterns

### LocalStack Integration

```java
@TestConfiguration
public class LocalStackAwsConfig {

    @Bean
    public S3Client s3Client() {
        return S3Client.builder()
                .region(Region.US_EAST_1)
                .endpointOverride(URI.create("http://localhost:4566"))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create("test", "test")))
                .build();
    }
}
```

### Testcontainers with LocalStack

```java
@Testcontainers
@SpringBootTest
class S3IntegrationTest {

    @Container
    static LocalStackContainer localstack = new LocalStackContainer(
            DockerImageName.parse("localstack/localstack:3.0"))
            .withServices(LocalStackContainer.Service.S3);

    @DynamicPropertySource
    static void overrideProperties(DynamicPropertyRegistry registry) {
        registry.add("aws.s3.endpoint",
                () -> localstack.getEndpointOverride(LocalStackContainer.Service.S3));
        registry.add("aws.region", () -> localstack.getRegion());
        registry.add("aws.access-key-id", localstack::getAccessKey);
        registry.add("aws.secret-access-key", localstack::getSecretKey);
    }
}
```
## Maven Dependencies

```xml
<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>software.amazon.awssdk</groupId>
            <artifactId>bom</artifactId>
            <!-- Use latest stable version -->
            <version>2.25.0</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>

<dependencies>
    <!-- Core SDK -->
    <dependency>
        <groupId>software.amazon.awssdk</groupId>
        <artifactId>sdk-core</artifactId>
    </dependency>

    <!-- Apache HTTP Client (recommended for sync) -->
    <dependency>
        <groupId>software.amazon.awssdk</groupId>
        <artifactId>apache-client</artifactId>
    </dependency>

    <!-- Netty HTTP Client (for async) -->
    <dependency>
        <groupId>software.amazon.awssdk</groupId>
        <artifactId>netty-nio-client</artifactId>
    </dependency>

    <!-- URL Connection HTTP Client (lightweight) -->
    <dependency>
        <groupId>software.amazon.awssdk</groupId>
        <artifactId>url-connection-client</artifactId>
    </dependency>

    <!-- CloudWatch Metrics -->
    <dependency>
        <groupId>software.amazon.awssdk</groupId>
        <artifactId>cloudwatch-metric-publisher</artifactId>
    </dependency>

    <!-- OpenSSL for better performance -->
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-tcnative-boringssl-static</artifactId>
        <version>2.0.61.Final</version>
        <scope>runtime</scope>
    </dependency>
</dependencies>
```
## Gradle Dependencies

```gradle
dependencies {
    implementation platform('software.amazon.awssdk:bom:2.25.0')

    implementation 'software.amazon.awssdk:sdk-core'
    implementation 'software.amazon.awssdk:apache-client'
    implementation 'software.amazon.awssdk:netty-nio-client'
    implementation 'software.amazon.awssdk:cloudwatch-metric-publisher'

    runtimeOnly 'io.netty:netty-tcnative-boringssl-static:2.0.61.Final'
}
```
## Examples

### Basic S3 Upload

```java
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;

try (S3Client s3 = S3Client.builder().region(Region.US_EAST_1).build()) {
    PutObjectRequest request = PutObjectRequest.builder()
            .bucket("my-bucket")
            .key("uploads/file.txt")
            .build();

    s3.putObject(request, RequestBody.fromString("Hello, World!"));
}
```

### S3 List Objects with Pagination

```java
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable;

try (S3Client s3 = S3Client.builder().region(Region.US_EAST_1).build()) {
    ListObjectsV2Request request = ListObjectsV2Request.builder()
            .bucket("my-bucket")
            .build();

    // The paginator transparently fetches every page of results
    ListObjectsV2Iterable pages = s3.listObjectsV2Paginator(request);
    pages.contents().forEach(object ->
            System.out.println("Object key: " + object.key()));
}
```

### Async S3 Upload

```java
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;

S3AsyncClient s3AsyncClient = S3AsyncClient.builder().build();

PutObjectRequest request = PutObjectRequest.builder()
        .bucket("my-bucket")
        .key("async-upload.txt")
        .build();

CompletableFuture<PutObjectResponse> future = s3AsyncClient.putObject(
        request, AsyncRequestBody.fromString("Hello, Async World!"));

future.thenAccept(response -> {
    System.out.println("Upload completed: " + response.eTag());
}).exceptionally(error -> {
    System.err.println("Upload failed: " + error.getMessage());
    return null;
});
```
## Performance Considerations

1. **Connection Pooling**: Default max connections is 50. Increase for high-throughput applications.
2. **Timeouts**: Always set both `apiCallTimeout` and `apiCallAttemptTimeout`.
3. **Client Reuse**: Create clients once, reuse throughout application lifecycle.
4. **Stream Handling**: Close streams immediately to prevent connection pool exhaustion.
5. **Async for I/O**: Use async clients for I/O-bound operations.
6. **OpenSSL**: Use OpenSSL with Netty for better SSL performance.
7. **Metrics**: Enable CloudWatch metrics to monitor performance.
## Security Best Practices

1. **Never hardcode credentials**: Use credential providers or environment variables.
2. **Use IAM roles**: Prefer IAM roles over access keys when possible.
3. **Rotate credentials**: Implement credential rotation for long-lived keys.
4. **Least privilege**: Grant minimum required permissions.
5. **Enable SSL**: Always use HTTPS endpoints (default).
6. **Audit logging**: Enable AWS CloudTrail for API call auditing.
## Related Skills

- `aws-sdk-java-v2-s3` - S3-specific patterns and examples
- `aws-sdk-java-v2-dynamodb` - DynamoDB patterns and examples
- `aws-sdk-java-v2-lambda` - Lambda patterns and examples

## References

See [references/](references/) for detailed documentation:

- [Developer Guide](references/developer-guide.md) - Comprehensive guide and architecture overview
- [API Reference](references/api-reference.md) - Complete API documentation for core classes
- [Best Practices](references/best-practices.md) - In-depth best practices and configuration examples

## Additional Resources

- [AWS SDK for Java 2.x Developer Guide](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/home.html)
- [AWS SDK for Java 2.x API Reference](https://sdk.amazonaws.com/java/api/latest/)
- [Best Practices](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/best-practices.html)
- [GitHub Repository](https://github.com/aws/aws-sdk-java-v2)
258
skills/aws-java/aws-sdk-java-v2-core/references/api-reference.md
Normal file
@@ -0,0 +1,258 @@
# AWS SDK for Java 2.x API Reference

## Core Client Classes

### AwsClient
Base interface for all AWS service clients.

```java
public interface AwsClient extends AutoCloseable {
    // Base client interface
}
```

### SdkClient
Enhanced client interface with SDK-specific features.

```java
public interface SdkClient extends AwsClient {
    // Enhanced client methods
}
```

## Client Builders

### ClientBuilder
Base builder interface for all AWS service clients.

**Key Methods:**
- `region(Region region)` - Set AWS region
- `credentialsProvider(AwsCredentialsProvider credentialsProvider)` - Configure authentication
- `overrideConfiguration(ClientOverrideConfiguration overrideConfiguration)` - Override default settings
- `httpClient(SdkHttpClient httpClient)` - Specify HTTP client implementation
- `build()` - Create client instance
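Taken together, a hedged sketch of a concrete builder call exercising each of these methods (region, profile name, and timeout values are assumed examples):

```java
S3Client s3 = S3Client.builder()
    .region(Region.US_EAST_1)                                        // region
    .credentialsProvider(ProfileCredentialsProvider.create("my-profile")) // authentication
    .overrideConfiguration(ClientOverrideConfiguration.builder()     // override settings
        .apiCallTimeout(Duration.ofSeconds(30))
        .build())
    .httpClient(ApacheHttpClient.builder().build())                  // HTTP client choice
    .build();                                                        // create the client
```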
## Configuration Classes

### ClientOverrideConfiguration
Controls client-level configuration including timeouts and metrics.

**Key Properties:**
- `apiCallTimeout(Duration)` - Total timeout for all retry attempts
- `apiCallAttemptTimeout(Duration)` - Timeout per individual attempt
- `retryPolicy(RetryPolicy)` - Retry behavior configuration
- `metricPublishers(MetricPublisher...)` - Enable metrics collection

### Builder Example

```java
ClientOverrideConfiguration config = ClientOverrideConfiguration.builder()
    .apiCallTimeout(Duration.ofSeconds(30))
    .apiCallAttemptTimeout(Duration.ofSeconds(10))
    .addMetricPublisher(CloudWatchMetricPublisher.create())
    .build();
```

## HTTP Client Implementations

### ApacheHttpClient
Synchronous HTTP client with advanced features.

**Builder Configuration:**
- `maxConnections(Integer)` - Maximum concurrent connections
- `connectionTimeout(Duration)` - Connection establishment timeout
- `socketTimeout(Duration)` - Socket read/write timeout
- `connectionTimeToLive(Duration)` - Connection lifetime
- `proxyConfiguration(ProxyConfiguration)` - Proxy settings

### NettyNioAsyncHttpClient
Asynchronous HTTP client for high-performance applications.

**Builder Configuration:**
- `maxConcurrency(Integer)` - Maximum concurrent operations
- `connectionTimeout(Duration)` - Connection timeout
- `readTimeout(Duration)` - Read operation timeout
- `writeTimeout(Duration)` - Write operation timeout
- `sslProvider(SslProvider)` - SSL/TLS implementation

### UrlConnectionHttpClient
Lightweight HTTP client using Java's URLConnection.

**Builder Configuration:**
- `socketTimeout(Duration)` - Socket timeout
- `connectTimeout(Duration)` - Connection timeout
## Authentication and Credentials

### Credential Providers

#### EnvironmentVariableCredentialsProvider
Reads credentials from environment variables.

```java
AwsCredentialsProvider provider = EnvironmentVariableCredentialsProvider.create();
```

#### SystemPropertyCredentialsProvider
Reads credentials from Java system properties.

```java
AwsCredentialsProvider provider = SystemPropertyCredentialsProvider.create();
```

#### ProfileCredentialsProvider
Reads credentials from AWS configuration files.

```java
AwsCredentialsProvider provider = ProfileCredentialsProvider.create("profile-name");
```

#### StaticCredentialsProvider
Provides static credentials (not recommended for production).

```java
AwsBasicCredentials credentials = AwsBasicCredentials.create("key", "secret");
AwsCredentialsProvider provider = StaticCredentialsProvider.create(credentials);
```

#### DefaultCredentialsProvider
Implements the default credential provider chain.

```java
AwsCredentialsProvider provider = DefaultCredentialsProvider.create();
```

### SSO Authentication

#### SSO Profiles
Enables SSO-based authentication through a named profile configured for AWS IAM Identity Center (for example via `aws configure sso`).

```java
AwsCredentialsProvider ssoProvider = ProfileCredentialsProvider.create("my-sso-profile");
```
## Error Handling Classes

### SdkClientException
Client-side exceptions (network, timeout, configuration issues).

```java
try {
    awsOperation();
} catch (SdkClientException e) {
    // Handle client-side errors
}
```

### SdkServiceException
Service-side exceptions (AWS service errors).

```java
try {
    awsOperation();
} catch (SdkServiceException e) {
    // Handle service-side errors
    System.err.println("Status Code: " + e.statusCode());
    System.err.println("Request ID: " + e.requestId());
}
```

### S3Exception
S3-specific exceptions.

```java
try {
    s3Operation();
} catch (S3Exception e) {
    // Handle S3-specific errors
    System.err.println("S3 Error: " + e.awsErrorDetails().errorMessage());
}
```
## Metrics and Monitoring

### CloudWatchMetricPublisher
Publishes metrics to AWS CloudWatch.

```java
CloudWatchMetricPublisher publisher = CloudWatchMetricPublisher.create();
```

### MetricPublisher
Base interface for custom metrics publishers.

```java
public interface MetricPublisher {
    void publish(MetricCollection metricCollection);
}
```

## Utility Classes

### Duration and Time
Configure timeouts using Java Duration.

```java
Duration apiTimeout = Duration.ofSeconds(30);
Duration attemptTimeout = Duration.ofSeconds(10);
```

### Region
AWS regions for service endpoints.

```java
Region region = Region.US_EAST_1;
Region regionEU = Region.EU_WEST_1;
```

### URI
Endpoint configuration and proxy settings.

```java
URI proxyUri = URI.create("http://proxy:8080");
URI endpointOverride = URI.create("http://localhost:4566");
```
## Configuration Best Practices

### Resource Management
Always close clients when no longer needed.

```java
try (S3Client s3 = S3Client.builder().build()) {
    // Use client
} // Auto-closed
```

### Connection Pooling
Reuse clients to avoid connection pool overhead.

```java
@Service
public class AwsService {
    private final S3Client s3Client;

    public AwsService() {
        this.s3Client = S3Client.builder().build();
    }

    // Reuse s3Client throughout application
}
```

### Error Handling
Implement comprehensive error handling for robust applications.

```java
try {
    // AWS operation
} catch (SdkServiceException e) {
    // Handle service errors
} catch (SdkClientException e) {
    // Handle client errors
} catch (Exception e) {
    // Handle other errors
}
```
@@ -0,0 +1,344 @@

# AWS SDK for Java 2.x Best Practices

## Client Configuration

### Timeout Configuration
Always configure both API call and attempt timeouts to prevent hanging requests.

```java
ClientOverrideConfiguration config = ClientOverrideConfiguration.builder()
    .apiCallTimeout(Duration.ofSeconds(30))        // Total for all retries
    .apiCallAttemptTimeout(Duration.ofSeconds(10)) // Per-attempt timeout
    .build();
```

**Best Practices:**
- Set `apiCallTimeout` higher than `apiCallAttemptTimeout`
- Use appropriate timeouts based on your service's characteristics
- Consider network latency and service response times
- Monitor timeout metrics to adjust as needed
### HTTP Client Selection
Choose the appropriate HTTP client for your use case.

#### For Synchronous Applications (Apache HttpClient)

```java
SdkHttpClient httpClient = ApacheHttpClient.builder()
    .maxConnections(100)
    .connectionTimeout(Duration.ofSeconds(5))
    .socketTimeout(Duration.ofSeconds(30))
    .build();
```

**Best Use Cases:**
- Traditional synchronous applications
- Medium-throughput operations
- When blocking behavior is acceptable

#### For Asynchronous Applications (Netty NIO Client)

```java
SdkAsyncHttpClient httpClient = NettyNioAsyncHttpClient.builder()
    .maxConcurrency(100)
    .connectionTimeout(Duration.ofSeconds(5))
    .readTimeout(Duration.ofSeconds(30))
    .writeTimeout(Duration.ofSeconds(30))
    .sslProvider(SslProvider.OPENSSL)
    .build();
```

**Best Use Cases:**
- High-throughput applications
- I/O-bound operations
- When non-blocking behavior is required
- For improved SSL performance

#### For Lightweight Applications (URL Connection Client)

```java
SdkHttpClient httpClient = UrlConnectionHttpClient.builder()
    .socketTimeout(Duration.ofSeconds(30))
    .build();
```

**Best Use Cases:**
- Simple applications with low requirements
- When minimizing dependencies
- For basic operations
## Authentication and Security

### Credential Management

#### Default Provider Chain

```java
// Use default chain (recommended)
S3Client s3Client = S3Client.builder().build();
```

**Default Chain Order:**
1. Java system properties (`aws.accessKeyId`, `aws.secretAccessKey`)
2. Environment variables (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`)
3. Web identity token from `AWS_WEB_IDENTITY_TOKEN_FILE`
4. Shared credentials file (`~/.aws/credentials`)
5. Config file (`~/.aws/config`)
6. Amazon ECS container credentials
7. Amazon EC2 instance profile credentials
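A small sketch of step 1 of the chain — supplying credentials through Java system properties (the values are placeholders, not real keys):

```java
// Hedged sketch: the default chain picks these up without explicit configuration.
System.setProperty("aws.accessKeyId", "AKIAEXAMPLE");        // placeholder
System.setProperty("aws.secretAccessKey", "secretExample");  // placeholder

S3Client s3 = S3Client.builder().build(); // resolved via DefaultCredentialsProvider
```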
#### Explicit Credential Provider

```java
// Use specific credential provider
AwsCredentialsProvider credentials = ProfileCredentialsProvider.create("my-profile");

S3Client s3Client = S3Client.builder()
    .credentialsProvider(credentials)
    .build();
```

#### IAM Roles (Preferred for Production)

```java
// Use IAM role credentials
AwsCredentialsProvider instanceProfile = InstanceProfileCredentialsProvider.create();

S3Client s3Client = S3Client.builder()
    .credentialsProvider(instanceProfile)
    .build();
```

### Security Best Practices

1. **Never hardcode credentials** - Use credential providers or environment variables
2. **Use IAM roles** - Prefer over access keys when possible
3. **Implement credential rotation** - For long-lived access keys
4. **Apply least privilege** - Grant minimum required permissions
5. **Enable SSL** - Always use HTTPS (default behavior)
6. **Monitor access** - Enable AWS CloudTrail for auditing
7. **Use SSO for human users** - Instead of long-term credentials
## Resource Management

### Client Lifecycle

```java
// Option 1: Try-with-resources (recommended)
try (S3Client s3 = S3Client.builder().build()) {
    // Use client
} // Auto-closed

// Option 2: Explicit close
S3Client s3 = S3Client.builder().build();
try {
    // Use client
} finally {
    s3.close();
}
```

### Stream Handling
Close streams immediately to prevent connection pool exhaustion.

```java
try (ResponseInputStream<GetObjectResponse> response =
        s3Client.getObject(GetObjectRequest.builder()
            .bucket(bucketName)
            .key(objectKey)
            .build())) {

    // Read and process data immediately
    byte[] data = response.readAllBytes();

} // Stream auto-closed, connection returned to pool
```
## Performance Optimization

### Connection Pooling

```java
// Configure connection pooling
SdkHttpClient httpClient = ApacheHttpClient.builder()
    .maxConnections(100) // Adjust based on your needs
    .connectionTimeout(Duration.ofSeconds(5))
    .socketTimeout(Duration.ofSeconds(30))
    .connectionTimeToLive(Duration.ofMinutes(5))
    .build();
```

**Best Practices:**
- Set appropriate `maxConnections` based on expected load
- Consider connection time to live (TTL)
- Monitor connection pool metrics
- Use appropriate timeouts

### SSL Optimization
Use OpenSSL with Netty for better SSL performance.

```xml
<!-- Maven dependency -->
<dependency>
    <groupId>io.netty</groupId>
    <artifactId>netty-tcnative-boringssl-static</artifactId>
    <version>2.0.61.Final</version>
    <scope>runtime</scope>
</dependency>
```

```java
// Use OpenSSL for async clients
SdkAsyncHttpClient httpClient = NettyNioAsyncHttpClient.builder()
    .sslProvider(SslProvider.OPENSSL)
    .build();
```

### Async for I/O-Bound Operations

```java
// Use async clients for I/O-bound operations
S3AsyncClient s3AsyncClient = S3AsyncClient.builder()
    .httpClient(httpClient)
    .build();

// Use CompletableFuture for non-blocking operations
CompletableFuture<PutObjectResponse> future = s3AsyncClient.putObject(
        request, AsyncRequestBody.fromString("payload")); // a request body is required
future.thenAccept(response -> {
    // Handle success
}).exceptionally(error -> {
    // Handle error
    return null;
});
```
## Monitoring and Observability

### Enable SDK Metrics

```java
CloudWatchMetricPublisher publisher = CloudWatchMetricPublisher.create();

S3Client s3Client = S3Client.builder()
    .overrideConfiguration(b -> b
        .addMetricPublisher(publisher))
    .build();
```

### CloudWatch Integration
Configure CloudWatch metrics publisher to collect SDK metrics.

```java
CloudWatchMetricPublisher cloudWatchPublisher = CloudWatchMetricPublisher.builder()
    .namespace("MyApplication")
    .cloudWatchClient(CloudWatchAsyncClient.builder()
        .credentialsProvider(credentials)
        .build())
    .build();
```

### Custom Metrics
Implement custom metrics for application-specific monitoring.

```java
public class CustomMetricPublisher implements MetricPublisher {

    @Override
    public void publish(MetricCollection metrics) {
        // A MetricCollection iterates MetricRecord entries
        metrics.forEach(record ->
                System.out.println("Metric: " + record.metric().name() + " = " + record.value()));
    }

    @Override
    public void close() {
        // Release any resources held by the publisher
    }
}
```
## Error Handling

### Comprehensive Error Handling

```java
try {
    awsOperation();
} catch (AwsServiceException e) {
    // Service-specific error (AwsServiceException carries the error details)
    System.err.println("AWS Service Error: " + e.awsErrorDetails().errorMessage());
    System.err.println("Error Code: " + e.awsErrorDetails().errorCode());
    System.err.println("Status Code: " + e.statusCode());
    System.err.println("Request ID: " + e.requestId());

} catch (SdkClientException e) {
    // Client-side error (network, timeout, etc.)
    System.err.println("Client Error: " + e.getMessage());

} catch (Exception e) {
    // Other errors
    System.err.println("Unexpected Error: " + e.getMessage());
}
```

### Retry Configuration

```java
RetryPolicy retryPolicy = RetryPolicy.builder()
    .numRetries(3)
    .retryCondition(RetryCondition.defaultRetryCondition())
    .backoffStrategy(BackoffStrategy.defaultStrategy())
    .build();
```
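On its own the policy does nothing; it has to be attached through the client's override configuration. A minimal sketch:

```java
S3Client s3 = S3Client.builder()
    .overrideConfiguration(b -> b.retryPolicy(retryPolicy)) // attach the policy above
    .build();
```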
## Testing Strategies

### Local Testing with LocalStack

```java
@TestConfiguration
public class LocalStackConfig {

    @Bean
    public S3Client s3Client() {
        return S3Client.builder()
                .endpointOverride(URI.create("http://localhost:4566"))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create("test", "test")))
                .build();
    }
}
```

### Testcontainers Integration

```java
@Testcontainers
@SpringBootTest
public class AwsIntegrationTest {

    @Container
    static LocalStackContainer localstack = new LocalStackContainer(
            DockerImageName.parse("localstack/localstack:3.0"))
            .withServices(LocalStackContainer.Service.S3);

    @DynamicPropertySource
    static void configProperties(DynamicPropertyRegistry registry) {
        registry.add("aws.endpoint",
                () -> localstack.getEndpointOverride(LocalStackContainer.Service.S3));
    }
}
```
## Configuration Templates

### High-Throughput Configuration

```java
SdkHttpClient highThroughputClient = ApacheHttpClient.builder()
    .maxConnections(200)
    .connectionTimeout(Duration.ofSeconds(3))
    .socketTimeout(Duration.ofSeconds(30))
    .connectionTimeToLive(Duration.ofMinutes(10))
    .build();

S3Client s3Client = S3Client.builder()
    .region(Region.US_EAST_1)
    .httpClient(highThroughputClient)
    .overrideConfiguration(b -> b
        .apiCallTimeout(Duration.ofSeconds(45))
        .apiCallAttemptTimeout(Duration.ofSeconds(15)))
    .build();
```

### Low-Latency Configuration

```java
SdkHttpClient lowLatencyClient = ApacheHttpClient.builder()
    .maxConnections(50)
    .connectionTimeout(Duration.ofSeconds(2))
    .socketTimeout(Duration.ofSeconds(10))
    .build();

S3Client s3Client = S3Client.builder()
    .region(Region.US_EAST_1)
    .httpClient(lowLatencyClient)
    .overrideConfiguration(b -> b
        .apiCallTimeout(Duration.ofSeconds(15))
        .apiCallAttemptTimeout(Duration.ofSeconds(3)))
    .build();
```
@@ -0,0 +1,130 @@

# AWS SDK for Java 2.x Developer Guide

## Overview

The AWS SDK for Java 2.x provides a modern, type-safe API for AWS services. Built on Java 8+, it offers improved performance, better error handling, and enhanced security compared to v1.x.

## Key Features

- **Modern Architecture**: Built on Java 8+ with reactive and async support
- **Type Safety**: Comprehensive type annotations and validation
- **Performance Optimized**: Connection pooling, async support, and SSL optimization
- **Enhanced Security**: Better credential management and security practices
- **Extensive Coverage**: Support for all AWS services with regular updates

## Core Concepts

### Service Clients
The primary interface for interacting with AWS services. All clients implement the `SdkClient` interface.

```java
// S3Client example
S3Client s3 = S3Client.builder().region(Region.US_EAST_1).build();
```

### Client Configuration
Configure behavior through builders supporting:
- Timeout settings
- HTTP client selection
- Authentication methods
- Monitoring and metrics
### Credential Providers
Multiple authentication methods:
- Environment variables
- System properties
- Shared credential files
- IAM roles
- SSO integration

### HTTP Clients
Choose from three HTTP implementations:
- Apache HttpClient (synchronous)
- Netty NIO Client (asynchronous)
- URL Connection Client (lightweight)

## Migration from v1.x

The SDK 2.x is not backward compatible with v1.x. Key changes (a before/after sketch follows this list):
- Builder pattern for client creation
- Different package structure
- Enhanced error handling
- New credential system
- Improved resource management
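A hedged before-and-after sketch of the builder-pattern change, with the v1.x calls shown as comments (class names per the 1.x SDK):

```java
// v1.x: client built via AmazonS3ClientBuilder
// AmazonS3 s3v1 = AmazonS3ClientBuilder.standard()
//         .withRegion(Regions.US_EAST_1)
//         .build();

// v2.x: immutable client, builder is the only entry point
S3Client s3 = S3Client.builder()
        .region(Region.US_EAST_1)
        .build();
```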
## Getting Started

Include the BOM (Bill of Materials) for version management:

```xml
<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>software.amazon.awssdk</groupId>
            <artifactId>bom</artifactId>
            <!-- Use latest stable version -->
            <version>2.25.0</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>
```

Add service-specific dependencies:

```xml
<dependencies>
    <!-- S3 Service -->
    <dependency>
        <groupId>software.amazon.awssdk</groupId>
        <artifactId>s3</artifactId>
    </dependency>

    <!-- Core SDK -->
    <dependency>
        <groupId>software.amazon.awssdk</groupId>
        <artifactId>sdk-core</artifactId>
    </dependency>
</dependencies>
```

## Architecture Overview

```
AWS Service Client
├── Configuration Layer
│   ├── Client Override Configuration
│   └── HTTP Client Configuration
├── Authentication Layer
│   ├── Credential Providers
│   └── Security Context
├── Transport Layer
│   ├── HTTP Client (Apache/Netty/URLConn)
│   └── Connection Pool
└── Protocol Layer
    ├── Service Protocol Implementation
    └── Error Handling
```
## Service Discovery

Each AWS service is exposed through a dedicated, generated client interface, with paginators provided for list operations.

### Available Services

All AWS services are available through dedicated client interfaces:
- S3 (Simple Storage Service)
- DynamoDB (NoSQL Database)
- Lambda (Serverless Functions)
- EC2 (Compute Cloud)
- RDS (Managed Databases)
- And 200+ other services

For a complete list, see the AWS Service documentation.

## Support and Community

- **GitHub Issues**: Report bugs and request features
- **AWS Amplify**: For mobile app developers
- **Migration Guide**: Available for v1.x users
- **Changelog**: Track changes on GitHub
392
skills/aws-java/aws-sdk-java-v2-dynamodb/SKILL.md
Normal file
@@ -0,0 +1,392 @@
---
name: aws-sdk-java-v2-dynamodb
description: Amazon DynamoDB patterns using AWS SDK for Java 2.x. Use when creating, querying, scanning, or performing CRUD operations on DynamoDB tables, working with indexes, batch operations, transactions, or integrating with Spring Boot applications.
category: aws
tags: [aws, dynamodb, java, sdk, nosql, database]
version: 1.1.0
allowed-tools: Read, Write, Bash
---

# AWS SDK for Java 2.x - Amazon DynamoDB

## When to Use

Use this skill when:
- Creating, updating, or deleting DynamoDB tables
- Performing CRUD operations on DynamoDB items
- Querying or scanning tables
- Working with Global Secondary Indexes (GSI) or Local Secondary Indexes (LSI)
- Implementing batch operations for efficiency
- Using DynamoDB transactions
- Integrating DynamoDB with Spring Boot applications
- Working with DynamoDB Enhanced Client for type-safe operations

## Dependencies

Add to `pom.xml`:

```xml
<!-- Low-level DynamoDB client -->
<dependency>
    <groupId>software.amazon.awssdk</groupId>
    <artifactId>dynamodb</artifactId>
</dependency>

<!-- Enhanced client (recommended) -->
<dependency>
    <groupId>software.amazon.awssdk</groupId>
    <artifactId>dynamodb-enhanced</artifactId>
</dependency>
```
## Client Setup

### Low-Level Client

```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;

DynamoDbClient dynamoDb = DynamoDbClient.builder()
    .region(Region.US_EAST_1)
    .build();
```

### Enhanced Client (Recommended)

```java
import software.amazon.awssdk.enhanced.dynamodb.DynamoDbEnhancedClient;

DynamoDbEnhancedClient enhancedClient = DynamoDbEnhancedClient.builder()
    .dynamoDbClient(dynamoDb)
    .build();
```
## Entity Mapping

To define DynamoDB entities, use the `@DynamoDbBean` annotation. Note that the enhanced client reads key and attribute annotations from getters, not fields:

```java
@DynamoDbBean
public class Customer {

    private String customerId;
    private String name;
    private String email;
    private String orderId;

    @DynamoDbPartitionKey
    public String getCustomerId() { return customerId; }

    @DynamoDbAttribute("customer_name")
    public String getName() { return name; }

    public String getEmail() { return email; }

    @DynamoDbSortKey
    public String getOrderId() { return orderId; }

    // Setters omitted for brevity
}
```

For complex entity mapping with GSIs and custom converters, see [Entity Mapping Reference](references/entity-mapping.md).
## CRUD Operations

### Basic Operations

```java
// Create or update item
DynamoDbTable<Customer> table = enhancedClient.table("Customers", TableSchema.fromBean(Customer.class));
table.putItem(customer);

// Get item
Customer result = table.getItem(Key.builder().partitionValue(customerId).build());

// Update item
Customer updated = table.updateItem(customer);

// Delete item
table.deleteItem(Key.builder().partitionValue(customerId).build());
```

### Composite Key Operations

```java
// Get item with composite key (from a DynamoDbTable<Order> orderTable)
Order order = orderTable.getItem(Key.builder()
    .partitionValue(customerId)
    .sortValue(orderId)
    .build());
```
## Query Operations

### Basic Query

```java
import software.amazon.awssdk.enhanced.dynamodb.model.QueryConditional;

QueryConditional queryConditional = QueryConditional
    .keyEqualTo(Key.builder()
        .partitionValue(customerId)
        .build());

List<Order> orders = orderTable.query(queryConditional).items().stream()
    .collect(Collectors.toList());
```

### Advanced Query with Filters

```java
import software.amazon.awssdk.enhanced.dynamodb.Expression;

Expression filter = Expression.builder()
    .expression("#status = :pending")
    .putExpressionName("#status", "status") // "status" is a DynamoDB reserved word
    .putExpressionValue(":pending", AttributeValue.builder().s("PENDING").build())
    .build();

List<Order> pendingOrders = orderTable.query(r -> r
        .queryConditional(queryConditional)
        .filterExpression(filter))
    .items().stream()
    .collect(Collectors.toList());
```

For detailed query patterns, see [Advanced Operations Reference](references/advanced-operations.md).
## Scan Operations

```java
// Scan all items
List<Customer> allCustomers = table.scan().items().stream()
    .collect(Collectors.toList());

// Scan with filter
Expression filter = Expression.builder()
    .expression("points >= :minPoints")
    .putExpressionValue(":minPoints", AttributeValue.builder().n("1000").build())
    .build();

List<Customer> vipCustomers = table.scan(r -> r.filterExpression(filter))
    .items().stream()
    .collect(Collectors.toList());
```
## Batch Operations

### Batch Get

```java
import software.amazon.awssdk.enhanced.dynamodb.model.*;

List<Key> keys = customerIds.stream()
    .map(id -> Key.builder().partitionValue(id).build())
    .collect(Collectors.toList());

ReadBatch.Builder<Customer> batchBuilder = ReadBatch.builder(Customer.class)
    .mappedTableResource(table);

keys.forEach(batchBuilder::addGetItem);

BatchGetResultPageIterable result = enhancedClient.batchGetItem(r ->
    r.addReadBatch(batchBuilder.build()));

List<Customer> customers = result.resultsForTable(table).stream()
    .collect(Collectors.toList());
```

### Batch Write

```java
WriteBatch.Builder<Customer> batchBuilder = WriteBatch.builder(Customer.class)
    .mappedTableResource(table);

customers.forEach(batchBuilder::addPutItem);

enhancedClient.batchWriteItem(r -> r.addWriteBatch(batchBuilder.build()));
```
## Transactions

### Transactional Write

```java
enhancedClient.transactWriteItems(r -> r
    .addPutItem(customerTable, customer)
    .addPutItem(orderTable, order));
```

### Transactional Read

```java
TransactGetItemsEnhancedRequest request = TransactGetItemsEnhancedRequest.builder()
    .addGetItem(customerTable, customerKey)
    .addGetItem(orderTable, orderKey)
    .build();

List<Document> results = enhancedClient.transactGetItems(request);
```
## Spring Boot Integration

### Configuration

```java
@Configuration
public class DynamoDbConfiguration {

    @Bean
    public DynamoDbClient dynamoDbClient() {
        return DynamoDbClient.builder()
                .region(Region.US_EAST_1)
                .build();
    }

    @Bean
    public DynamoDbEnhancedClient dynamoDbEnhancedClient(DynamoDbClient dynamoDbClient) {
        return DynamoDbEnhancedClient.builder()
                .dynamoDbClient(dynamoDbClient)
                .build();
    }
}
```

### Repository Pattern

```java
@Repository
public class CustomerRepository {

    private final DynamoDbTable<Customer> customerTable;

    public CustomerRepository(DynamoDbEnhancedClient enhancedClient) {
        this.customerTable = enhancedClient.table("Customers", TableSchema.fromBean(Customer.class));
    }

    public void save(Customer customer) {
        customerTable.putItem(customer);
    }

    public Optional<Customer> findById(String customerId) {
        Key key = Key.builder().partitionValue(customerId).build();
        return Optional.ofNullable(customerTable.getItem(key));
    }
}
```

For comprehensive Spring Boot integration patterns, see [Spring Boot Integration Reference](references/spring-boot-integration.md).
## Testing

### Unit Testing with Mocks

```java
@ExtendWith(MockitoExtension.class)
class CustomerServiceTest {

    @Mock
    private DynamoDbClient dynamoDbClient;

    @Mock
    private DynamoDbEnhancedClient enhancedClient;

    @Mock
    private DynamoDbTable<Customer> customerTable;

    @InjectMocks
    private CustomerService customerService;

    @Test
    void saveCustomer_ShouldReturnSavedCustomer() {
        // Arrange
        when(enhancedClient.table(anyString(), any(TableSchema.class)))
                .thenReturn(customerTable);

        Customer customer = new Customer("123", "John Doe", "john@example.com");

        // Act
        Customer result = customerService.saveCustomer(customer);

        // Assert
        assertNotNull(result);
        verify(customerTable).putItem(customer);
    }
}
```

### Integration Testing with LocalStack

```java
@Testcontainers
@SpringBootTest
class DynamoDbIntegrationTest {

    @Container
    static LocalStackContainer localstack = new LocalStackContainer(
            DockerImageName.parse("localstack/localstack:3.0"))
            .withServices(LocalStackContainer.Service.DYNAMODB);

    @DynamicPropertySource
    static void configureProperties(DynamicPropertyRegistry registry) {
        registry.add("aws.endpoint",
                () -> localstack.getEndpointOverride(LocalStackContainer.Service.DYNAMODB).toString());
    }

    @Autowired
    private DynamoDbEnhancedClient enhancedClient;

    @Test
    void testCustomerCRUDOperations() {
        // Test implementation
    }
}
```

For detailed testing strategies, see [Testing Strategies](references/testing-strategies.md).
## Best Practices

1. **Use Enhanced Client**: Provides type-safe operations with less boilerplate
2. **Design partition keys carefully**: Distribute data evenly across partitions
3. **Use composite keys**: Leverage sort keys for efficient queries
4. **Create GSIs strategically**: Support different access patterns
5. **Use batch operations**: Reduce API calls for multiple items
6. **Implement pagination**: Use pagination for large result sets
7. **Use transactions**: For operations that must be atomic
8. **Avoid scans**: Prefer queries with proper indexes
9. **Handle conditional writes**: Prevent race conditions
10. **Use proper error handling**: Handle exceptions like `ProvisionedThroughputExceeded`
## Common Patterns

### Conditional Operations

```java
Expression condition = Expression.builder()
    .expression("attribute_not_exists(customerId)")
    .build();

PutItemEnhancedRequest<Customer> request = PutItemEnhancedRequest.builder(Customer.class)
    .item(customer)
    .conditionExpression(condition)
    .build();

table.putItem(request);
```

### Pagination

```java
ScanEnhancedRequest request = ScanEnhancedRequest.builder()
    .limit(100)
    .build();

PageIterable<Customer> results = table.scan(request);
results.stream().forEach(page -> {
    // Process each page of items
    page.items().forEach(customer -> { /* ... */ });
});
```
## Performance Considerations

- Monitor read/write capacity units
- Implement exponential backoff for retries
- Use proper pagination for large datasets
- Consider eventual consistency for reads
- Use `ReturnConsumedCapacity` to monitor capacity usage
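For the last point, a minimal sketch with the low-level client that requests consumed-capacity details on a write (the table name and item map are assumed examples):

```java
PutItemResponse response = dynamoDb.putItem(PutItemRequest.builder()
        .tableName("Customers")
        .item(item) // Map<String, AttributeValue> built elsewhere
        .returnConsumedCapacity(ReturnConsumedCapacity.TOTAL)
        .build());

// The response now reports the write capacity units this call consumed
System.out.println("Consumed WCUs: " + response.consumedCapacity().capacityUnits());
```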
## Related Skills

- `aws-sdk-java-v2-core` - Core AWS SDK patterns
- `spring-data-jpa` - Alternative data access patterns
- `unit-test-service-layer` - Service testing patterns
- `unit-test-wiremock-rest-api` - Testing external APIs

## References

- [AWS DynamoDB Documentation](https://docs.aws.amazon.com/dynamodb/)
- [AWS SDK for Java Documentation](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/)
- [DynamoDB Examples](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/example_code/dynamodb)
- [LocalStack for Testing](https://docs.localstack.cloud/user-guide/aws/)

For detailed implementations, see the references folder:
- [Entity Mapping Reference](references/entity-mapping.md)
- [Advanced Operations Reference](references/advanced-operations.md)
- [Spring Boot Integration Reference](references/spring-boot-integration.md)
- [Testing Strategies](references/testing-strategies.md)
@@ -0,0 +1,195 @@

# Advanced Operations Reference

This document covers advanced DynamoDB operations and patterns.

## Query Operations

### Key Conditions

#### keyEqualTo()

```java
QueryConditional equalTo = QueryConditional
    .keyEqualTo(Key.builder()
        .partitionValue("customer123")
        .build());
```

#### sortBetween()

```java
QueryConditional between = QueryConditional
    .sortBetween(
        Key.builder().partitionValue("customer123").sortValue("2023-01-01").build(),
        Key.builder().partitionValue("customer123").sortValue("2023-12-31").build());
```

#### sortBeginsWith()

```java
QueryConditional beginsWith = QueryConditional
    .sortBeginsWith(Key.builder()
        .partitionValue("customer123")
        .sortValue("2023-")
        .build());
```

### Filter Expressions

```java
Expression filter = Expression.builder()
    .expression("#p >= :minPoints AND #s = :status")
    .putExpressionName("#p", "points")
    .putExpressionName("#s", "status") // "status" is a DynamoDB reserved word
    .putExpressionValue(":minPoints", AttributeValue.builder().n("1000").build())
    .putExpressionValue(":status", AttributeValue.builder().s("ACTIVE").build())
    .build();
```

### Projection Expressions

Limit the attributes returned by a query with `attributesToProject` on the request:

```java
QueryEnhancedRequest request = QueryEnhancedRequest.builder()
    .queryConditional(queryConditional)
    .attributesToProject("customerId", "name", "email")
    .build();
```
## Scan Operations

### Pagination

```java
ScanEnhancedRequest request = ScanEnhancedRequest.builder()
    .limit(100)
    .build();

PageIterable<Customer> results = table.scan(request);
results.stream().forEach(page -> {
    // Process each page of results
});
```

### Conditional Scan

```java
Expression filter = Expression.builder()
    .expression("active = :active")
    .putExpressionValue(":active", AttributeValue.builder().bool(true).build())
    .build();

List<Customer> activeCustomers = table.scan(r -> r
        .filterExpression(filter)
        .limit(50))
    .items().stream()
    .collect(Collectors.toList());
```
## Batch Operations

### Batch Get with Unprocessed Keys

```java
List<Key> keys = customerIds.stream()
    .map(id -> Key.builder().partitionValue(id).build())
    .collect(Collectors.toList());

ReadBatch.Builder<Customer> batchBuilder = ReadBatch.builder(Customer.class)
    .mappedTableResource(table);

keys.forEach(batchBuilder::addGetItem);

BatchGetResultPageIterable result = enhancedClient.batchGetItem(r ->
    r.addReadBatch(batchBuilder.build()));

// Handle unprocessed keys
result.stream()
    .flatMap(page -> page.unprocessedKeysForTable(table).stream())
    .forEach(unprocessedKey -> {
        // Retry logic for unprocessed keys
    });
```

### Batch Write with Different Operations

```java
WriteBatch.Builder<Customer> batchBuilder = WriteBatch.builder(Customer.class)
    .mappedTableResource(table);

batchBuilder.addPutItem(customer1);
batchBuilder.addDeleteItem(customer2);
batchBuilder.addPutItem(customer3);

enhancedClient.batchWriteItem(r -> r.addWriteBatch(batchBuilder.build()));
```
## Transactions

### Conditional Writes

```java
Expression condition = Expression.builder()
    .expression("attribute_not_exists(customerId)")
    .build();

PutItemEnhancedRequest<Customer> putRequest = PutItemEnhancedRequest.builder(Customer.class)
    .item(customer)
    .conditionExpression(condition)
    .build();

table.putItem(putRequest);
```

### Multiple Table Operations

```java
TransactWriteItemsEnhancedRequest request = TransactWriteItemsEnhancedRequest.builder()
    .addPutItem(customerTable, customer)
    .addPutItem(orderTable, order)
    .addUpdateItem(productTable, product)
    .addDeleteItem(cartTable, cartKey)
    .build();

enhancedClient.transactWriteItems(request);
```
## Conditional Operations

### Condition Expressions

```java
// Check if an attribute does not exist yet
Expression notExists = Expression.builder()
    .expression("attribute_not_exists(customerId)")
    .build();

// Check attribute values
Expression minPoints = Expression.builder()
    .expression("points > :currentPoints")
    .putExpressionValue(":currentPoints", AttributeValue.builder().n("500").build())
    .build();

// Multiple conditions
Expression combined = Expression.builder()
    .expression("points > :min AND active = :active")
    .putExpressionValue(":min", AttributeValue.builder().n("100").build())
    .putExpressionValue(":active", AttributeValue.builder().bool(true).build())
    .build();
```
## Error Handling

### Provisioned Throughput Exceeded

```java
try {
    table.putItem(customer);
} catch (ProvisionedThroughputExceededException e) {
    // Back off and retry (see below)
} catch (TransactionCanceledException e) {
    // Handle transaction cancellation
} catch (ConditionalCheckFailedException e) {
    // Handle conditional check failure
} catch (ResourceNotFoundException e) {
    // Handle table not found
} catch (DynamoDbException e) {
    // Handle other DynamoDB exceptions
}
```

### Exponential Backoff for Retry

```java
int maxRetries = 3;
long baseDelay = 1000; // 1 second

for (int attempt = 0; attempt < maxRetries; attempt++) {
    try {
        operation();
        break;
    } catch (ProvisionedThroughputExceededException e) {
        long delay = baseDelay * (1L << attempt); // 1s, 2s, 4s
        try {
            Thread.sleep(delay);
        } catch (InterruptedException ie) {
            Thread.currentThread().interrupt();
            break;
        }
    }
}
```
@@ -0,0 +1,120 @@

# Entity Mapping Reference

This document provides detailed information about entity mapping in DynamoDB Enhanced Client.

## @DynamoDbBean Annotation

The `@DynamoDbBean` annotation marks a class as a DynamoDB entity:

```java
@DynamoDbBean
public class Customer {
    // Class implementation
}
```

## Field Annotations

### @DynamoDbPartitionKey
Marks a field as the partition key:

```java
@DynamoDbPartitionKey
public String getCustomerId() {
    return customerId;
}
```

### @DynamoDbSortKey
Marks a field as the sort key (used with composite keys):

```java
@DynamoDbSortKey
@DynamoDbAttribute("order_id")
public String getOrderId() {
    return orderId;
}
```

### @DynamoDbAttribute
Maps a field to a DynamoDB attribute with custom name:

```java
@DynamoDbAttribute("customer_name")
public String getName() {
    return name;
}
```

### @DynamoDbSecondaryPartitionKey
Marks a field as a partition key for a Global Secondary Index:

```java
@DynamoDbSecondaryPartitionKey(indexNames = "category-index")
public String getCategory() {
    return category;
}
```

### @DynamoDbSecondarySortKey
Marks a field as a sort key for a Global Secondary Index:

```java
@DynamoDbSecondarySortKey(indexNames = "category-index")
public BigDecimal getPrice() {
    return price;
}
```

### @DynamoDbConvertedBy
Custom attribute conversion:

```java
@DynamoDbConvertedBy(LocalDateTimeConverter.class)
public LocalDateTime getCreatedAt() {
    return createdAt;
}
```

## Supported Data Types

The enhanced client automatically handles the following data types:

- String → S (String)
- Integer, Long → N (Number)
- BigDecimal → N (Number)
- Boolean → BOOL
- LocalDateTime → S (ISO-8601 format)
- LocalDate → S (ISO-8601 format)
- UUID → S (String)
- Enum → S (String representation)
- Custom types with converters
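A small hedged sketch of a bean exercising several of these mappings (the class and attribute names are illustrative):

```java
@DynamoDbBean
public class Product {

    private String productId;
    private BigDecimal price;
    private Boolean active;
    private LocalDateTime createdAt;

    @DynamoDbPartitionKey
    public String getProductId() { return productId; }
    public void setProductId(String productId) { this.productId = productId; }

    public BigDecimal getPrice() { return price; }             // stored as N
    public void setPrice(BigDecimal price) { this.price = price; }

    public Boolean getActive() { return active; }              // stored as BOOL
    public void setActive(Boolean active) { this.active = active; }

    public LocalDateTime getCreatedAt() { return createdAt; }  // stored as S (ISO-8601)
    public void setCreatedAt(LocalDateTime createdAt) { this.createdAt = createdAt; }
}
```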
## Custom Converters

Create custom converters for complex data types by implementing `AttributeConverter`:

```java
public class LocalDateTimeConverter implements AttributeConverter<LocalDateTime> {

    @Override
    public AttributeValue transformFrom(LocalDateTime input) {
        return AttributeValue.builder().s(input.toString()).build();
    }

    @Override
    public LocalDateTime transformTo(AttributeValue input) {
        return LocalDateTime.parse(input.s());
    }

    @Override
    public EnhancedType<LocalDateTime> type() {
        return EnhancedType.of(LocalDateTime.class);
    }

    @Override
    public AttributeValueType attributeValueType() {
        return AttributeValueType.S;
    }
}
```
@@ -0,0 +1,377 @@
|
||||
# Spring Boot Integration Reference

This document provides detailed information about integrating DynamoDB with Spring Boot applications.

## Configuration

### Basic Configuration

```java
@Configuration
public class DynamoDbConfiguration {

    @Bean
    @Profile("local")
    public DynamoDbClient dynamoDbClient() {
        return DynamoDbClient.builder()
                .region(Region.US_EAST_1)
                .build();
    }

    @Bean
    @Profile("prod")
    public DynamoDbClient dynamoDbClientProd(
            @Value("${aws.region}") String region,
            @Value("${aws.accessKeyId}") String accessKeyId,
            @Value("${aws.secretAccessKey}") String secretAccessKey) {

        return DynamoDbClient.builder()
                .region(Region.of(region))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create(accessKeyId, secretAccessKey)))
                .build();
    }

    @Bean
    public DynamoDbEnhancedClient dynamoDbEnhancedClient(DynamoDbClient dynamoDbClient) {
        return DynamoDbEnhancedClient.builder()
                .dynamoDbClient(dynamoDbClient)
                .build();
    }
}
```

### Properties Configuration

`application-local.properties`:

```properties
aws.region=us-east-1
```

`application-prod.properties`:

```properties
aws.region=us-east-1
aws.accessKeyId=${AWS_ACCESS_KEY_ID}
aws.secretAccessKey=${AWS_SECRET_ACCESS_KEY}
```
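The `local` profile above still points at the real AWS endpoint. When developing against DynamoDB Local or LocalStack, the client needs an endpoint override; a minimal sketch, assuming DynamoDB Local on its default port 8000:

```java
@Bean
@Profile("local")
public DynamoDbClient dynamoDbLocalClient() {
    return DynamoDbClient.builder()
            .region(Region.US_EAST_1)
            .endpointOverride(URI.create("http://localhost:8000"))
            // DynamoDB Local accepts any non-empty credentials
            .credentialsProvider(StaticCredentialsProvider.create(
                    AwsBasicCredentials.create("dummy", "dummy")))
            .build();
}
```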
## Repository Pattern Implementation

### Base Repository Interface

String-typed keys keep the interface simple and match the string overloads of `Key.builder()`:

```java
public interface DynamoDbRepository<T> {
    void save(T entity);
    Optional<T> findById(String partitionKey);
    Optional<T> findById(String partitionKey, String sortKey);
    void delete(String partitionKey);
    void delete(String partitionKey, String sortKey);
    List<T> findAll();
    List<T> findAll(int limit);
    boolean existsById(String partitionKey);
    boolean existsById(String partitionKey, String sortKey);
}

public interface CustomerRepository extends DynamoDbRepository<Customer> {
    List<Customer> findByEmail(String email);
    List<Customer> findByPointsGreaterThan(Integer minPoints);
}
```
### Generic Repository Implementation

```java
// Generic base implementation. Spring cannot autowire the Class and table-name
// constructor arguments, so register one instance per entity via an explicit
// @Bean method (or subclass) rather than relying on component scanning.
public class GenericDynamoDbRepository<T> implements DynamoDbRepository<T> {

    private final DynamoDbTable<T> table;

    public GenericDynamoDbRepository(DynamoDbEnhancedClient enhancedClient,
                                     Class<T> entityClass,
                                     String tableName) {
        this.table = enhancedClient.table(tableName, TableSchema.fromBean(entityClass));
    }

    @Override
    public void save(T entity) {
        table.putItem(entity);
    }

    @Override
    public Optional<T> findById(String partitionKey) {
        Key key = Key.builder().partitionValue(partitionKey).build();
        return Optional.ofNullable(table.getItem(key));
    }

    @Override
    public Optional<T> findById(String partitionKey, String sortKey) {
        Key key = Key.builder()
                .partitionValue(partitionKey)
                .sortValue(sortKey)
                .build();
        return Optional.ofNullable(table.getItem(key));
    }

    @Override
    public void delete(String partitionKey) {
        Key key = Key.builder().partitionValue(partitionKey).build();
        table.deleteItem(key);
    }

    @Override
    public void delete(String partitionKey, String sortKey) {
        Key key = Key.builder()
                .partitionValue(partitionKey)
                .sortValue(sortKey)
                .build();
        table.deleteItem(key);
    }

    @Override
    public List<T> findAll() {
        return table.scan().items().stream()
                .collect(Collectors.toList());
    }

    @Override
    public List<T> findAll(int limit) {
        return table.scan(ScanEnhancedRequest.builder().limit(limit).build())
                .items().stream()
                .collect(Collectors.toList());
    }

    @Override
    public boolean existsById(String partitionKey) {
        return findById(partitionKey).isPresent();
    }

    @Override
    public boolean existsById(String partitionKey, String sortKey) {
        return findById(partitionKey, sortKey).isPresent();
    }
}
```
### Specific Repository Implementation

Extending the generic base class provides the CRUD methods, so only the entity-specific finders need implementing:

```java
@Repository
public class CustomerRepositoryImpl extends GenericDynamoDbRepository<Customer>
        implements CustomerRepository {

    private final DynamoDbTable<Customer> customerTable;

    public CustomerRepositoryImpl(DynamoDbEnhancedClient enhancedClient) {
        super(enhancedClient, Customer.class, "Customers");
        this.customerTable = enhancedClient.table(
                "Customers",
                TableSchema.fromBean(Customer.class));
    }

    @Override
    public List<Customer> findByEmail(String email) {
        Expression filter = Expression.builder()
                .expression("email = :email")
                .putExpressionValue(":email", AttributeValue.builder().s(email).build())
                .build();

        return customerTable.scan(r -> r.filterExpression(filter))
                .items().stream()
                .collect(Collectors.toList());
    }

    @Override
    public List<Customer> findByPointsGreaterThan(Integer minPoints) {
        Expression filter = Expression.builder()
                .expression("points > :minPoints")
                .putExpressionValue(":minPoints", AttributeValue.builder().n(minPoints.toString()).build())
                .build();

        return customerTable.scan(r -> r.filterExpression(filter))
                .items().stream()
                .collect(Collectors.toList());
    }
}
```
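Both finders above scan the whole table and filter afterwards, which reads (and bills) every item. If email lookups are frequent, a Global Secondary Index is the better fit. A minimal sketch, assuming a hypothetical `email-index` GSI whose partition key is `email`:

```java
// Hypothetical GSI; requires @DynamoDbSecondaryPartitionKey(indexNames = "email-index")
// on Customer.getEmail() and the index to exist on the table.
public List<Customer> findByEmailViaIndex(String email) {
    DynamoDbIndex<Customer> emailIndex = customerTable.index("email-index");

    return emailIndex.query(QueryConditional.keyEqualTo(
                    Key.builder().partitionValue(email).build()))
            .stream()
            .flatMap(page -> page.items().stream())
            .collect(Collectors.toList());
}
```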
## Service Layer Implementation

### Service with Transaction Management

Spring's `@Transactional` does not manage DynamoDB operations; atomicity comes from DynamoDB's own `transactWriteItems` API instead:

```java
@Service
public class CustomerService {

    private final CustomerRepository customerRepository;
    private final OrderRepository orderRepository;
    private final DynamoDbEnhancedClient enhancedClient;

    public CustomerService(CustomerRepository customerRepository,
                           OrderRepository orderRepository,
                           DynamoDbEnhancedClient enhancedClient) {
        this.customerRepository = customerRepository;
        this.orderRepository = orderRepository;
        this.enhancedClient = enhancedClient;
    }

    public void createCustomerWithOrder(Customer customer, Order order) {
        // DynamoDB transaction: both puts succeed or neither does
        enhancedClient.transactWriteItems(r -> r
                .addPutItem(getCustomerTable(), customer)
                .addPutItem(getOrderTable(), order));
    }

    private DynamoDbTable<Customer> getCustomerTable() {
        return enhancedClient.table("Customers", TableSchema.fromBean(Customer.class));
    }

    private DynamoDbTable<Order> getOrderTable() {
        return enhancedClient.table("Orders", TableSchema.fromBean(Order.class));
    }
}
```
### Async Operations

```java
@Service
public class AsyncCustomerService {

    private final DynamoDbEnhancedClient enhancedClient;

    public AsyncCustomerService(DynamoDbEnhancedClient enhancedClient) {
        this.enhancedClient = enhancedClient;
    }

    public CompletableFuture<Void> saveCustomerAsync(Customer customer) {
        return CompletableFuture.runAsync(() -> {
            DynamoDbTable<Customer> table = enhancedClient.table(
                    "Customers",
                    TableSchema.fromBean(Customer.class));
            table.putItem(customer);
        });
    }

    public CompletableFuture<List<Customer>> findCustomersByPointsAsync(Integer minPoints) {
        return CompletableFuture.supplyAsync(() -> {
            Expression filter = Expression.builder()
                    .expression("points >= :minPoints")
                    .putExpressionValue(":minPoints", AttributeValue.builder().n(minPoints.toString()).build())
                    .build();

            DynamoDbTable<Customer> table = enhancedClient.table(
                    "Customers",
                    TableSchema.fromBean(Customer.class));

            return table.scan(r -> r.filterExpression(filter))
                    .items().stream()
                    .collect(Collectors.toList());
        });
    }
}
```
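Wrapping blocking calls in `CompletableFuture.runAsync` still parks a thread per request. The SDK also ships a natively non-blocking enhanced client; a minimal sketch, assuming a `DynamoDbAsyncClient` bean is available:

```java
@Configuration
public class DynamoDbAsyncConfiguration {

    @Bean
    public DynamoDbEnhancedAsyncClient dynamoDbEnhancedAsyncClient(DynamoDbAsyncClient asyncClient) {
        return DynamoDbEnhancedAsyncClient.builder()
                .dynamoDbClient(asyncClient)
                .build();
    }
}

@Service
public class NativeAsyncCustomerService {

    private final DynamoDbAsyncTable<Customer> customerTable;

    public NativeAsyncCustomerService(DynamoDbEnhancedAsyncClient enhancedAsyncClient) {
        this.customerTable = enhancedAsyncClient.table(
                "Customers", TableSchema.fromBean(Customer.class));
    }

    public CompletableFuture<Void> save(Customer customer) {
        // Non-blocking I/O; no thread is parked while DynamoDB responds
        return customerTable.putItem(customer);
    }
}
```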
## Testing with LocalStack

### Test Configuration

```java
@TestConfiguration
@Import(LocalStackDynamoDbConfig.class)
public class DynamoDbTestConfig {

    @Bean
    public DynamoDbEnhancedClient dynamoDbEnhancedClient(DynamoDbClient dynamoDbClient) {
        return DynamoDbEnhancedClient.builder()
                .dynamoDbClient(dynamoDbClient)
                .build();
    }
}

@SpringBootTest
@Testcontainers
@Import(DynamoDbTestConfig.class)
public class CustomerRepositoryIntegrationTest {

    @Autowired
    private DynamoDbEnhancedClient enhancedClient;

    @BeforeEach
    void setUp() {
        // Clean up test data
        clearTestData();
    }

    @Test
    void testCustomerOperations() {
        // Test implementation
    }
}
```
### LocalStack Container Setup

```java
// Shared container configuration. @Container and @DynamicPropertySource take
// effect when this class is (a superclass of) a @Testcontainers test.
public class LocalStackDynamoDbConfig {

    @Container
    static LocalStackContainer localstack = new LocalStackContainer(
            DockerImageName.parse("localstack/localstack:3.0"))
            .withServices(LocalStackContainer.Service.DYNAMODB);

    @DynamicPropertySource
    public static void configureProperties(DynamicPropertyRegistry registry) {
        registry.add("aws.region", () -> Region.US_EAST_1.toString());
        registry.add("aws.accessKeyId", () -> localstack.getAccessKey());
        registry.add("aws.secretAccessKey", () -> localstack.getSecretKey());
        registry.add("aws.endpoint",
                () -> localstack.getEndpointOverride(LocalStackContainer.Service.DYNAMODB).toString());
    }

    @Bean
    public DynamoDbClient dynamoDbClient(
            @Value("${aws.region}") String region,
            @Value("${aws.accessKeyId}") String accessKeyId,
            @Value("${aws.secretAccessKey}") String secretAccessKey,
            @Value("${aws.endpoint}") String endpoint) {

        return DynamoDbClient.builder()
                .region(Region.of(region))
                .endpointOverride(URI.create(endpoint))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create(accessKeyId, secretAccessKey)))
                .build();
    }
}
```
## Health Check Integration

### Custom Health Indicator

```java
@Component
public class DynamoDbHealthIndicator implements HealthIndicator {

    private final DynamoDbClient dynamoDbClient;

    public DynamoDbHealthIndicator(DynamoDbClient dynamoDbClient) {
        this.dynamoDbClient = dynamoDbClient;
    }

    @Override
    public Health health() {
        try {
            dynamoDbClient.listTables();
            return Health.up()
                    .withDetail("region", dynamoDbClient.serviceClientConfiguration().region())
                    .build();
        } catch (Exception e) {
            return Health.down()
                    .withException(e)
                    .build();
        }
    }
}
```
## Metrics Collection

### Micrometer Integration

```java
@Component
public class DynamoDbMetricsCollector {

    private final MeterRegistry meterRegistry;

    public DynamoDbMetricsCollector(MeterRegistry meterRegistry) {
        this.meterRegistry = meterRegistry;
    }

    @EventListener
    public void handleDynamoDbOperation(DynamoDbOperationEvent event) {
        // Record the duration carried by the event against a tagged timer
        Timer.builder("dynamodb.operation")
                .tag("operation", event.getOperation())
                .tag("table", event.getTable())
                .register(meterRegistry)
                .record(event.getDuration(), TimeUnit.MILLISECONDS);
    }
}

public class DynamoDbOperationEvent {
    private String operation;
    private String table;
    private long duration; // milliseconds

    // Getters and setters
}
```
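Events are published by whatever code wraps the DynamoDB calls. A hypothetical publication point, timing a repository call and handing the event to Spring so the `@EventListener` above records it:

```java
@Service
public class InstrumentedCustomerService {

    private final CustomerRepository customerRepository;
    private final ApplicationEventPublisher eventPublisher;

    public InstrumentedCustomerService(CustomerRepository customerRepository,
                                       ApplicationEventPublisher eventPublisher) {
        this.customerRepository = customerRepository;
        this.eventPublisher = eventPublisher;
    }

    public void save(Customer customer) {
        long start = System.currentTimeMillis();
        customerRepository.save(customer);

        DynamoDbOperationEvent event = new DynamoDbOperationEvent();
        event.setOperation("PutItem");
        event.setTable("Customers");
        event.setDuration(System.currentTimeMillis() - start);
        eventPublisher.publishEvent(event);
    }
}
```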
@@ -0,0 +1,407 @@
# Testing Strategies for DynamoDB

This document provides comprehensive testing strategies for DynamoDB applications using the AWS SDK for Java 2.x.

## Unit Testing with Mocks

### Mocking the Enhanced Client

```java
@ExtendWith(MockitoExtension.class)
class CustomerServiceTest {

    @Mock
    private DynamoDbEnhancedClient enhancedClient;

    @Mock
    private DynamoDbTable<Customer> customerTable;

    @InjectMocks
    private CustomerService customerService;

    @Test
    void saveCustomer_ShouldReturnSavedCustomer() {
        // Arrange
        Customer customer = new Customer("123", "John Doe", "john@example.com");

        when(enhancedClient.table(anyString(), any(TableSchema.class)))
                .thenReturn(customerTable);
        // putItem returns void, so the mock is already a no-op; no stubbing needed

        // Act
        Customer result = customerService.saveCustomer(customer);

        // Assert
        assertNotNull(result);
        assertEquals("123", result.getCustomerId());
        verify(customerTable).putItem(customer);
    }

    @Test
    void getCustomer_NotFound_ShouldReturnEmpty() {
        // Arrange
        when(enhancedClient.table(anyString(), any(TableSchema.class)))
                .thenReturn(customerTable);

        when(customerTable.getItem(any(Key.class)))
                .thenReturn(null);

        // Act
        Optional<Customer> result = customerService.getCustomer("123");

        // Assert
        assertFalse(result.isPresent());
        verify(customerTable).getItem(any(Key.class));
    }
}
```
### Testing Query Operations

```java
@Test
void queryCustomersByStatus_ShouldReturnMatchingCustomers() {
    // Arrange
    List<Customer> mockCustomers = List.of(
            new Customer("1", "Alice", "alice@example.com"),
            new Customer("2", "Bob", "bob@example.com")
    );

    DynamoDbTable<Customer> mockTable = mock(DynamoDbTable.class);
    DynamoDbIndex<Customer> mockIndex = mock(DynamoDbIndex.class);

    when(enhancedClient.table(anyString(), any(TableSchema.class)))
            .thenReturn(mockTable);
    when(mockTable.index("status-index"))
            .thenReturn(mockIndex);

    // DynamoDbIndex.query(...) returns SdkIterable<Page<Customer>>;
    // a single in-memory page is enough for a unit test
    Page<Customer> page = Page.create(mockCustomers);
    SdkIterable<Page<Customer>> pages = () -> List.of(page).iterator();
    when(mockIndex.query(any(QueryEnhancedRequest.class)))
            .thenReturn(pages);

    // Act
    List<Customer> result = customerService.findByStatus("ACTIVE");

    // Assert
    assertEquals(2, result.size());
    verify(mockIndex).query(any(QueryEnhancedRequest.class));
}
```
## Integration Testing with Testcontainers

### LocalStack Setup

```java
@Testcontainers
@SpringBootTest
@AutoConfigureMockMvc
class DynamoDbIntegrationTest {

    @Container
    static LocalStackContainer localstack = new LocalStackContainer(
            DockerImageName.parse("localstack/localstack:3.0"))
            .withServices(LocalStackContainer.Service.DYNAMODB);

    @DynamicPropertySource
    static void configureProperties(DynamicPropertyRegistry registry) {
        registry.add("aws.region", () -> Region.US_EAST_1.toString());
        registry.add("aws.accessKeyId", () -> localstack.getAccessKey());
        registry.add("aws.secretAccessKey", () -> localstack.getSecretKey());
        registry.add("aws.endpoint",
                () -> localstack.getEndpointOverride(LocalStackContainer.Service.DYNAMODB).toString());
    }

    @Autowired
    private DynamoDbEnhancedClient enhancedClient;

    @BeforeEach
    void setup() {
        createTestTable();
    }

    @Test
    void testCustomerCRUDOperations() {
        // Test create
        Customer customer = new Customer("test-123", "Test User", "test@example.com");
        enhancedClient.table("Customers", TableSchema.fromBean(Customer.class))
                .putItem(customer);

        // Test read
        Customer retrieved = enhancedClient.table("Customers", TableSchema.fromBean(Customer.class))
                .getItem(Key.builder().partitionValue("test-123").build());

        assertNotNull(retrieved);
        assertEquals("Test User", retrieved.getName());

        // Test update
        customer.setPoints(1000);
        enhancedClient.table("Customers", TableSchema.fromBean(Customer.class))
                .putItem(customer);

        // Test delete
        enhancedClient.table("Customers", TableSchema.fromBean(Customer.class))
                .deleteItem(Key.builder().partitionValue("test-123").build());
    }

    private void createTestTable() {
        DynamoDbClient client = DynamoDbClient.builder()
                .region(Region.US_EAST_1)
                .endpointOverride(localstack.getEndpointOverride(LocalStackContainer.Service.DYNAMODB))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create(localstack.getAccessKey(), localstack.getSecretKey())))
                .build();

        CreateTableRequest request = CreateTableRequest.builder()
                .tableName("Customers")
                .keySchema(KeySchemaElement.builder()
                        .attributeName("customerId")
                        .keyType(KeyType.HASH)
                        .build())
                .attributeDefinitions(AttributeDefinition.builder()
                        .attributeName("customerId")
                        .attributeType(ScalarAttributeType.S)
                        .build())
                .provisionedThroughput(ProvisionedThroughput.builder()
                        .readCapacityUnits(5L)
                        .writeCapacityUnits(5L)
                        .build())
                .build();

        try {
            client.createTable(request);
        } catch (ResourceInUseException e) {
            // Table already exists from a previous test; nothing to do
            return;
        }
        waitForTableActive(client, "Customers");
    }

    private void waitForTableActive(DynamoDbClient client, String tableName) {
        // DynamoDbWaiter polls DescribeTable until the table is available
        try (DynamoDbWaiter waiter = client.waiter()) {
            waiter.waitUntilTableExists(r -> r.tableName(tableName));
        }
    }
}
```
### Testcontainers with PostgreSQL

For relational repositories (e.g. a JPA mirror of the same data), Testcontainers works the same way with a PostgreSQL container:

```java
@SpringBootTest
@Testcontainers
@AutoConfigureDataJpa
class CustomerRepositoryTest {

    @Container
    static PostgreSQLContainer<?> postgres = new PostgreSQLContainer<>("postgres:15-alpine")
            .withDatabaseName("testdb")
            .withUsername("test")
            .withPassword("test");

    @DynamicPropertySource
    static void postgresProperties(DynamicPropertyRegistry registry) {
        registry.add("spring.datasource.url", postgres::getJdbcUrl);
        registry.add("spring.datasource.username", postgres::getUsername);
        registry.add("spring.datasource.password", postgres::getPassword);
    }

    @Autowired
    private CustomerRepository customerRepository;

    @Test
    void testRepositoryWithRealDatabase() {
        // Test with a real database
        Customer customer = new Customer("123", "Test User", "test@example.com");
        customerRepository.save(customer);

        Customer retrieved = customerRepository.findById("123").orElse(null);
        assertNotNull(retrieved);
        assertEquals("Test User", retrieved.getName());
    }
}
```
## Performance Testing

### Load Testing with Gatling

```java
class CustomerSimulation extends Simulation {

    HttpProtocolBuilder httpProtocolBuilder = http
            .baseUrl("http://localhost:8080")
            .acceptHeader("application/json");

    ScenarioBuilder scn = scenario("Customer Operations")
            .exec(http("create_customer")
                    .post("/api/customers")
                    .body(StringBody(
                            """{
                                "customerId": "test-123",
                                "name": "Test User",
                                "email": "test@example.com"
                            }"""))
                    .asJson()
                    .check(status().is(201)))
            .exec(http("get_customer")
                    .get("/api/customers/test-123")
                    .check(status().is(200)));

    {
        setUp(
                scn.injectOpen(
                        rampUsersPerSec(10).to(100).during(60),
                        constantUsersPerSec(100).during(120)
                )
        ).protocols(httpProtocolBuilder);
    }
}
```
### Microbenchmark Testing

```java
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS)
@Fork(1)
@State(Scope.Benchmark)
public class DynamoDbPerformanceBenchmark {

    private DynamoDbEnhancedClient enhancedClient;
    private DynamoDbTable<Customer> customerTable;
    private Customer testCustomer;

    @Setup
    public void setup() {
        // Region and credentials are resolved from the environment
        enhancedClient = DynamoDbEnhancedClient.builder()
                .dynamoDbClient(DynamoDbClient.builder().build())
                .build();

        customerTable = enhancedClient.table("Customers", TableSchema.fromBean(Customer.class));
        testCustomer = new Customer("benchmark-123", "Benchmark User", "benchmark@example.com");
    }

    @Benchmark
    public void testPutItem() {
        customerTable.putItem(testCustomer);
    }

    @Benchmark
    public void testGetItem() {
        customerTable.getItem(Key.builder().partitionValue("benchmark-123").build());
    }

    @Benchmark
    public void testScan() {
        // Full table scan; expect this to dominate the other benchmarks
        customerTable.scan().items().stream().collect(Collectors.toList());
    }
}
```
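A minimal sketch of launching the benchmarks from a `main` method, assuming the JMH core and annotation-processor dependencies are on the build path:

```java
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

public class BenchmarkRunner {
    public static void main(String[] args) throws Exception {
        Options options = new OptionsBuilder()
                .include(DynamoDbPerformanceBenchmark.class.getSimpleName())
                .build();
        new Runner(options).run();
    }
}
```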
## Property-Based Testing

### Using jqwik

```java
@Property
@Report(Reporting.GENERATED)
void customerSerializationShouldBeConsistent(
        @ForAll("customers") Customer customer
) {
    // When
    String serialized = serializeCustomer(customer);
    Customer deserialized = deserializeCustomer(serialized);

    // Then
    assertEquals(customer.getCustomerId(), deserialized.getCustomerId());
    assertEquals(customer.getName(), deserialized.getName());
    assertEquals(customer.getEmail(), deserialized.getEmail());
}

@Provide
Arbitrary<Customer> customers() {
    Arbitrary<String> ids = Arbitraries.of("customer-", "user-", "client-")
            .flatMap(prefix -> Arbitraries.integers().between(1000, 9999)
                    .map(n -> prefix + n));
    Arbitrary<String> names = Arbitraries.strings().alpha().ofLength(10);
    Arbitrary<String> emails = names.map(name -> name.toLowerCase() + "@example.com");

    return Combinators.combine(ids, names, emails).as(Customer::new);
}
```
## Test Data Management

### Test Data Factory

```java
@Component
public class TestDataFactory {

    private final DynamoDbEnhancedClient enhancedClient;

    @Autowired
    public TestDataFactory(DynamoDbEnhancedClient enhancedClient) {
        this.enhancedClient = enhancedClient;
    }

    public Customer createTestCustomer(String id) {
        Customer customer = new Customer(
                id != null ? id : UUID.randomUUID().toString(),
                "Test User",
                "test@example.com"
        );
        customer.setPoints(1000);
        customer.setCreatedAt(LocalDateTime.now());

        enhancedClient.table("Customers", TableSchema.fromBean(Customer.class))
                .putItem(customer);

        return customer;
    }

    public void cleanupTestData() {
        // Implementation to clean up test data
    }
}
```
### Test Database Configuration

```java
@TestConfiguration
public class TestDataConfig {

    @Bean
    public TestDataCleaner testDataCleaner(DynamoDbClient dynamoDbClient) {
        return new TestDataCleaner(dynamoDbClient);
    }
}

@Component
public class TestDataCleaner {

    private final DynamoDbClient dynamoDbClient;

    public TestDataCleaner(DynamoDbClient dynamoDbClient) {
        this.dynamoDbClient = dynamoDbClient;
    }

    @EventListener(ApplicationReadyEvent.class)
    public void cleanup() {
        // Clean up leftover test data once the context is ready
    }
}
```
416
skills/aws-java/aws-sdk-java-v2-kms/SKILL.md
Normal file
@@ -0,0 +1,416 @@
---
name: aws-sdk-java-v2-kms
description: AWS Key Management Service (KMS) patterns using AWS SDK for Java 2.x. Use when creating/managing encryption keys, encrypting/decrypting data, generating data keys, digital signing, key rotation, or integrating encryption into Spring Boot applications.
category: aws
tags: [aws, kms, java, sdk, encryption, security]
version: 1.1.0
allowed-tools: Read, Write, Bash, WebFetch
---

# AWS SDK for Java 2.x - AWS KMS (Key Management Service)

## Overview

This skill provides comprehensive patterns for AWS Key Management Service (KMS) using AWS SDK for Java 2.x. Focus on implementing secure encryption solutions with proper key management, envelope encryption, and Spring Boot integration patterns.

## When to Use

Use this skill when:
- Creating and managing symmetric encryption keys for data protection
- Implementing client-side encryption and envelope encryption patterns
- Generating data keys for local data encryption with KMS-managed keys
- Setting up digital signatures and verification with asymmetric keys
- Integrating encryption capabilities into Spring Boot applications
- Implementing secure key lifecycle management
- Setting up key rotation policies and access controls

## Dependencies

### Maven

```xml
<dependency>
    <groupId>software.amazon.awssdk</groupId>
    <artifactId>kms</artifactId>
</dependency>
```

### Gradle

```groovy
implementation 'software.amazon.awssdk:kms:2.x.x'
```
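Omitting the version in the Maven snippet, as above, assumes the AWS SDK BOM is imported so all SDK modules resolve to one consistent version; the version below is illustrative:

```xml
<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>software.amazon.awssdk</groupId>
            <artifactId>bom</artifactId>
            <version>2.25.0</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>
```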
## Client Setup

### Basic Synchronous Client

```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.kms.KmsClient;

KmsClient kmsClient = KmsClient.builder()
        .region(Region.US_EAST_1)
        .build();
```

### Basic Asynchronous Client

```java
import software.amazon.awssdk.services.kms.KmsAsyncClient;

KmsAsyncClient kmsAsyncClient = KmsAsyncClient.builder()
        .region(Region.US_EAST_1)
        .build();
```

### Advanced Client Configuration

```java
KmsClient kmsClient = KmsClient.builder()
        .region(Region.of(System.getenv("AWS_REGION")))
        .credentialsProvider(DefaultCredentialsProvider.create())
        .overrideConfiguration(c -> c.retryPolicy(RetryPolicy.builder()
                .numRetries(3)
                .build()))
        .build();
```
## Basic Key Management

### Create Encryption Key

```java
public String createEncryptionKey(KmsClient kmsClient, String description) {
    CreateKeyRequest request = CreateKeyRequest.builder()
            .description(description)
            .keyUsage(KeyUsageType.ENCRYPT_DECRYPT)
            .build();

    CreateKeyResponse response = kmsClient.createKey(request);
    return response.keyMetadata().keyId();
}
```

### Describe Key

```java
public KeyMetadata getKeyMetadata(KmsClient kmsClient, String keyId) {
    DescribeKeyRequest request = DescribeKeyRequest.builder()
            .keyId(keyId)
            .build();

    return kmsClient.describeKey(request).keyMetadata();
}
```

### Enable/Disable Key

```java
public void toggleKeyState(KmsClient kmsClient, String keyId, boolean enable) {
    if (enable) {
        kmsClient.enableKey(EnableKeyRequest.builder().keyId(keyId).build());
    } else {
        kmsClient.disableKey(DisableKeyRequest.builder().keyId(keyId).build());
    }
}
```
## Basic Encryption and Decryption

### Encrypt Data

```java
public String encryptData(KmsClient kmsClient, String keyId, String plaintext) {
    SdkBytes plaintextBytes = SdkBytes.fromString(plaintext, StandardCharsets.UTF_8);

    EncryptRequest request = EncryptRequest.builder()
            .keyId(keyId)
            .plaintext(plaintextBytes)
            .build();

    EncryptResponse response = kmsClient.encrypt(request);
    return Base64.getEncoder().encodeToString(
            response.ciphertextBlob().asByteArray());
}
```

### Decrypt Data

```java
public String decryptData(KmsClient kmsClient, String ciphertextBase64) {
    byte[] ciphertext = Base64.getDecoder().decode(ciphertextBase64);
    SdkBytes ciphertextBytes = SdkBytes.fromByteArray(ciphertext);

    // For symmetric keys no keyId is needed: KMS reads the key reference
    // embedded in the ciphertext blob
    DecryptRequest request = DecryptRequest.builder()
            .ciphertextBlob(ciphertextBytes)
            .build();

    DecryptResponse response = kmsClient.decrypt(request);
    return response.plaintext().asString(StandardCharsets.UTF_8);
}
```
## Envelope Encryption Pattern

### Generate and Use Data Key

```java
public DataKeyResult encryptWithEnvelope(KmsClient kmsClient, String masterKeyId, byte[] data) {
    // Generate data key
    GenerateDataKeyRequest keyRequest = GenerateDataKeyRequest.builder()
            .keyId(masterKeyId)
            .keySpec(DataKeySpec.AES_256)
            .build();

    GenerateDataKeyResponse keyResponse = kmsClient.generateDataKey(keyRequest);

    // Keep one local reference to the plaintext key: asByteArray() returns a
    // copy, so zeroing a fresh copy would not clear anything
    byte[] plaintextKey = keyResponse.plaintext().asByteArray();
    try {
        // Encrypt data locally with the data key
        byte[] encryptedData = encryptWithAES(data, plaintextKey);

        return new DataKeyResult(
                encryptedData,
                keyResponse.ciphertextBlob().asByteArray());
    } finally {
        // Clear the plaintext key from memory
        Arrays.fill(plaintextKey, (byte) 0);
    }
}

public byte[] decryptWithEnvelope(KmsClient kmsClient,
                                  DataKeyResult encryptedEnvelope) {
    // Decrypt data key
    DecryptRequest keyDecryptRequest = DecryptRequest.builder()
            .ciphertextBlob(SdkBytes.fromByteArray(
                    encryptedEnvelope.encryptedKey()))
            .build();

    DecryptResponse keyDecryptResponse = kmsClient.decrypt(keyDecryptRequest);

    byte[] plaintextKey = keyDecryptResponse.plaintext().asByteArray();
    try {
        // Decrypt data with the decrypted key
        return decryptWithAES(encryptedEnvelope.encryptedData(), plaintextKey);
    } finally {
        // Clear the plaintext key from memory
        Arrays.fill(plaintextKey, (byte) 0);
    }
}
```
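`encryptWithAES`/`decryptWithAES` and the `DataKeyResult` holder are left to the application. A minimal sketch using AES-GCM from the JDK, with the 12-byte IV prepended to the ciphertext; the parameters are illustrative rather than a vetted production scheme:

```java
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import javax.crypto.Cipher;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;

public record DataKeyResult(byte[] encryptedData, byte[] encryptedKey) {}

private static final int GCM_IV_LENGTH = 12;
private static final int GCM_TAG_BITS = 128;

public byte[] encryptWithAES(byte[] plaintext, byte[] key) {
    try {
        byte[] iv = new byte[GCM_IV_LENGTH];
        new SecureRandom().nextBytes(iv);

        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        cipher.init(Cipher.ENCRYPT_MODE,
                new SecretKeySpec(key, "AES"),
                new GCMParameterSpec(GCM_TAG_BITS, iv));

        byte[] ciphertext = cipher.doFinal(plaintext);
        // Prepend the IV so decryption can recover it
        return ByteBuffer.allocate(iv.length + ciphertext.length)
                .put(iv).put(ciphertext).array();
    } catch (GeneralSecurityException e) {
        throw new IllegalStateException("AES encryption failed", e);
    }
}

public byte[] decryptWithAES(byte[] ivAndCiphertext, byte[] key) {
    try {
        ByteBuffer buffer = ByteBuffer.wrap(ivAndCiphertext);
        byte[] iv = new byte[GCM_IV_LENGTH];
        buffer.get(iv);
        byte[] ciphertext = new byte[buffer.remaining()];
        buffer.get(ciphertext);

        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        cipher.init(Cipher.DECRYPT_MODE,
                new SecretKeySpec(key, "AES"),
                new GCMParameterSpec(GCM_TAG_BITS, iv));
        return cipher.doFinal(ciphertext);
    } catch (GeneralSecurityException e) {
        throw new IllegalStateException("AES decryption failed", e);
    }
}
```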
## Digital Signatures

### Create Signing Key and Sign Data

```java
public String createAndSignData(KmsClient kmsClient, String description, String message) {
    // Create signing key
    CreateKeyRequest keyRequest = CreateKeyRequest.builder()
            .description(description)
            .keySpec(KeySpec.RSA_2048)
            .keyUsage(KeyUsageType.SIGN_VERIFY)
            .build();

    CreateKeyResponse keyResponse = kmsClient.createKey(keyRequest);
    String keyId = keyResponse.keyMetadata().keyId();

    // Sign data
    SignRequest signRequest = SignRequest.builder()
            .keyId(keyId)
            .message(SdkBytes.fromString(message, StandardCharsets.UTF_8))
            .signingAlgorithm(SigningAlgorithmSpec.RSASSA_PSS_SHA_256)
            .build();

    SignResponse signResponse = kmsClient.sign(signRequest);
    return Base64.getEncoder().encodeToString(
            signResponse.signature().asByteArray());
}
```

### Verify Signature

```java
public boolean verifySignature(KmsClient kmsClient,
                               String keyId,
                               String message,
                               String signatureBase64) {
    byte[] signature = Base64.getDecoder().decode(signatureBase64);

    VerifyRequest verifyRequest = VerifyRequest.builder()
            .keyId(keyId)
            .message(SdkBytes.fromString(message, StandardCharsets.UTF_8))
            .signature(SdkBytes.fromByteArray(signature))
            .signingAlgorithm(SigningAlgorithmSpec.RSASSA_PSS_SHA_256)
            .build();

    VerifyResponse verifyResponse = kmsClient.verify(verifyRequest);
    return verifyResponse.signatureValid();
}
```
## Spring Boot Integration

### Configuration Class

```java
@Configuration
public class KmsConfiguration {

    @Bean
    public KmsClient kmsClient() {
        return KmsClient.builder()
                .region(Region.US_EAST_1)
                .build();
    }

    @Bean
    public KmsAsyncClient kmsAsyncClient() {
        return KmsAsyncClient.builder()
                .region(Region.US_EAST_1)
                .build();
    }
}
```

### Encryption Service

```java
@Service
@RequiredArgsConstructor
public class KmsEncryptionService {

    private final KmsClient kmsClient;

    @Value("${kms.encryption-key-id}")
    private String keyId;

    public String encrypt(String plaintext) {
        try {
            EncryptRequest request = EncryptRequest.builder()
                    .keyId(keyId)
                    .plaintext(SdkBytes.fromString(plaintext, StandardCharsets.UTF_8))
                    .build();

            EncryptResponse response = kmsClient.encrypt(request);
            return Base64.getEncoder().encodeToString(
                    response.ciphertextBlob().asByteArray());

        } catch (KmsException e) {
            throw new RuntimeException("Encryption failed", e);
        }
    }

    public String decrypt(String ciphertextBase64) {
        try {
            byte[] ciphertext = Base64.getDecoder().decode(ciphertextBase64);

            DecryptRequest request = DecryptRequest.builder()
                    .ciphertextBlob(SdkBytes.fromByteArray(ciphertext))
                    .build();

            DecryptResponse response = kmsClient.decrypt(request);
            return response.plaintext().asString(StandardCharsets.UTF_8);

        } catch (KmsException e) {
            throw new RuntimeException("Decryption failed", e);
        }
    }
}
```
## Examples

### Basic Encryption Example

```java
public class BasicEncryptionExample {
    public static void main(String[] args) {
        KmsClient kmsClient = KmsClient.builder()
                .region(Region.US_EAST_1)
                .build();

        // Create key (helper from "Basic Key Management" above)
        String keyId = createEncryptionKey(kmsClient, "Example encryption key");
        System.out.println("Created key: " + keyId);

        // Encrypt and decrypt (helpers from "Basic Encryption and Decryption")
        String plaintext = "Hello, World!";
        String encrypted = encryptData(kmsClient, keyId, plaintext);
        String decrypted = decryptData(kmsClient, encrypted);

        System.out.println("Original: " + plaintext);
        System.out.println("Decrypted: " + decrypted);
    }
}
```

### Envelope Encryption Example

```java
public class EnvelopeEncryptionExample {
    public static void main(String[] args) {
        KmsClient kmsClient = KmsClient.builder()
                .region(Region.US_EAST_1)
                .build();

        String masterKeyId = "alias/your-master-key";
        String largeData = "This is a large amount of data that needs encryption...";
        byte[] data = largeData.getBytes(StandardCharsets.UTF_8);

        // Encrypt using the envelope pattern (helpers defined above)
        DataKeyResult encryptedEnvelope = encryptWithEnvelope(
                kmsClient, masterKeyId, data);

        // Decrypt
        byte[] decryptedData = decryptWithEnvelope(
                kmsClient, encryptedEnvelope);

        String result = new String(decryptedData, StandardCharsets.UTF_8);
        System.out.println("Decrypted: " + result);
    }
}
```
## Best Practices

### Security
- **Always use envelope encryption for large data** - Encrypt data locally and only encrypt the data key with KMS
- **Use encryption context** - Add contextual information to track and audit usage
- **Never log sensitive data** - Avoid logging plaintext or encryption keys
- **Implement proper key lifecycle** - Enable automatic rotation and set deletion policies
- **Use separate keys for different purposes** - Don't reuse keys across multiple applications

### Performance
- **Cache encrypted data keys** - Reduce KMS API calls by caching data keys
- **Use async operations** - Leverage async clients for non-blocking I/O
- **Reuse client instances** - Don't create new clients for each operation
- **Implement connection pooling** - Configure proper connection pooling settings

### Error Handling
- **Implement retry logic** - Handle throttling exceptions with exponential backoff
- **Check key states** - Verify key is enabled before performing operations
- **Use circuit breakers** - Prevent cascading failures during KMS outages
- **Log errors comprehensively** - Include KMS error codes and context

## References

For detailed implementation patterns, advanced techniques, and comprehensive examples:

- @references/technical-guide.md - Complete technical implementation patterns
- @references/spring-boot-integration.md - Spring Boot integration patterns
- @references/testing.md - Testing strategies and examples
- @references/best-practices.md - Security and operational best practices

## Related Skills

- @aws-sdk-java-v2-core - Core AWS SDK patterns and configuration
- @aws-sdk-java-v2-dynamodb - DynamoDB integration patterns
- @aws-sdk-java-v2-secrets-manager - Secrets management patterns
- @spring-boot-dependency-injection - Spring dependency injection patterns

## External References

- [AWS KMS Developer Guide](https://docs.aws.amazon.com/kms/latest/developerguide/)
- [AWS SDK for Java 2.x Documentation](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/home.html)
- [KMS Best Practices](https://docs.aws.amazon.com/kms/latest/developerguide/best-practices.html)
550
skills/aws-java/aws-sdk-java-v2-kms/references/best-practices.md
Normal file
@@ -0,0 +1,550 @@
# AWS KMS Best Practices

## Security Best Practices

### Key Management

1. **Use Separate Keys for Different Purposes**
   - Create unique keys for different applications or data types
   - Avoid reusing keys across multiple purposes
   - Use aliases instead of raw key IDs for references

```java
// Good: create purpose-specific keys (illustrative pseudo-calls)
String encryptionKey = kms.createKey("Database encryption key");
String signingKey = kms.createSigningKey("Document signing key");

// Bad: use the same key for everything
```

2. **Enable Automatic Key Rotation**
   - Enable automatic key rotation for enhanced security
   - Review rotation schedules based on compliance requirements

```java
public void enableKeyRotation(KmsClient kmsClient, String keyId) {
    EnableKeyRotationRequest request = EnableKeyRotationRequest.builder()
            .keyId(keyId)
            .build();
    kmsClient.enableKeyRotation(request);
}
```

3. **Implement Key Lifecycle Policies**
   - Set key expiration dates based on data retention policies
   - Schedule key deletion when no longer needed
   - Use key policies to enforce lifecycle rules

4. **Use Key Aliases**
   - Always use aliases instead of raw key IDs
   - Create meaningful aliases following naming conventions
   - Regularly review and update aliases

```java
public void createKeyWithAlias(KmsClient kmsClient, String alias, String description) {
    // Create key
    CreateKeyResponse response = kmsClient.createKey(
            CreateKeyRequest.builder()
                    .description(description)
                    .build());

    // Create alias
    CreateAliasRequest aliasRequest = CreateAliasRequest.builder()
            .aliasName(alias)
            .targetKeyId(response.keyMetadata().keyId())
            .build();
    kmsClient.createAlias(aliasRequest);
}
```
### Encryption Security

1. **Never Log Plaintext or Encryption Keys**
   - Avoid logging sensitive data in any form
   - Ensure proper logging configuration to prevent accidental exposure

```java
// Bad: logging secret material
logger.info("Plaintext value: {}", plaintext);

// Good: log only metadata
logger.info("Encryption completed for user: {}", userId);
```

2. **Use Encryption Context**
   - Always include encryption context for additional security
   - Use contextual information to verify data integrity

```java
public Map<String, String> createEncryptionContext(String userId, String dataType) {
    return Map.of(
            "userId", userId,
            "dataType", dataType,
            // Note: the exact same map must be supplied at decrypt time, so a
            // fresh timestamp only works if it is stored with the ciphertext
            "timestamp", Instant.now().toString()
    );
}
```
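A minimal sketch of attaching the context to the actual KMS calls. KMS refuses to decrypt unless the identical context is supplied again, which is why volatile values such as timestamps must be persisted alongside the ciphertext:

```java
public String encryptWithContext(KmsClient kmsClient, String keyId,
                                 String plaintext, Map<String, String> context) {
    EncryptRequest request = EncryptRequest.builder()
            .keyId(keyId)
            .plaintext(SdkBytes.fromString(plaintext, StandardCharsets.UTF_8))
            .encryptionContext(context)
            .build();

    return Base64.getEncoder().encodeToString(
            kmsClient.encrypt(request).ciphertextBlob().asByteArray());
}

public String decryptWithContext(KmsClient kmsClient, String ciphertextBase64,
                                 Map<String, String> context) {
    DecryptRequest request = DecryptRequest.builder()
            .ciphertextBlob(SdkBytes.fromByteArray(
                    Base64.getDecoder().decode(ciphertextBase64)))
            .encryptionContext(context)
            .build();

    return kmsClient.decrypt(request).plaintext().asString(StandardCharsets.UTF_8);
}
```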
3. **Implement Least Privilege IAM Policies**
   - Grant minimal required permissions to KMS keys
   - Use IAM policies to restrict access to specific resources

```json
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {"AWS": "arn:aws:iam::123456789012:role/app-role"},
            "Action": [
                "kms:Encrypt",
                "kms:Decrypt",
                "kms:DescribeKey"
            ],
            "Resource": "arn:aws:kms:us-east-1:123456789012:key/your-key-id",
            "Condition": {
                "StringEquals": {
                    "kms:EncryptionContext:userId": "${aws:userid}"
                }
            }
        }
    ]
}
```

4. **Clear Sensitive Data from Memory**
   - Explicitly clear sensitive data from memory after use
   - Use secure memory management practices

```java
public void secureMemoryExample() {
    byte[] sensitiveKey = new byte[32];
    // ... use the key ...

    // Clear sensitive data
    Arrays.fill(sensitiveKey, (byte) 0);
}
```
## Performance Best Practices

1. **Cache Data Keys for Envelope Encryption**
   - Cache encrypted data keys to avoid repeated KMS calls
   - Use appropriate cache eviction policies
   - Monitor cache hit rates

```java
public class DataKeyCache {
    private final Cache<String, byte[]> keyCache;

    public DataKeyCache() {
        this.keyCache = Caffeine.newBuilder()
                .expireAfterWrite(1, TimeUnit.HOURS)
                .maximumSize(1000)
                .build();
    }

    public byte[] getCachedDataKey(String keyId, KmsClient kmsClient) {
        // Caches the *encrypted* data key per master key id
        return keyCache.get(keyId, k -> {
            GenerateDataKeyResponse response = kmsClient.generateDataKey(
                    GenerateDataKeyRequest.builder()
                            .keyId(keyId)
                            .keySpec(DataKeySpec.AES_256)
                            .build());
            return response.ciphertextBlob().asByteArray();
        });
    }
}
```

2. **Use Async Operations for Non-Blocking I/O**
   - Leverage async clients for parallel processing
   - Use CompletableFuture for chaining operations

```java
public CompletableFuture<Void> processMultipleAsync(List<String> dataItems) {
    List<CompletableFuture<Void>> futures = dataItems.stream()
            .map(item -> CompletableFuture.runAsync(() ->
                    encryptAndStoreItem(item)))
            .collect(Collectors.toList());

    return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
}
```

3. **Implement Connection Pooling**
   - Configure connection pooling for better resource utilization
   - Set appropriate pool sizes based on load

```java
public KmsClient createPooledClient() {
    return KmsClient.builder()
            .region(Region.US_EAST_1)
            .httpClientBuilder(ApacheHttpClient.builder()
                    .maxConnections(100)
                    .connectionTimeToLive(Duration.ofSeconds(30)))
            .build();
}
```

4. **Reuse KMS Client Instances**
   - Create and reuse client instances rather than creating new ones
   - Use dependency injection for client management

```java
@Service
@RequiredArgsConstructor
public class KmsService {
    private final KmsClient kmsClient; // Inject and reuse

    public void performOperation() {
        // Any call goes through the same pooled client instance
        kmsClient.listKeys();
    }
}
```
## Cost Optimization

1. **Use Envelope Encryption for Large Data**
   - Generate data keys for encrypting large datasets
   - Only use KMS for encrypting the data key, not the entire dataset

```java
public class EnvelopeEncryption {
    private final KmsClient kmsClient;

    public byte[] encryptLargeData(byte[] largeData) {
        // Generate data key
        GenerateDataKeyResponse response = kmsClient.generateDataKey(
                GenerateDataKeyRequest.builder()
                        .keyId("master-key-id")
                        .keySpec(DataKeySpec.AES_256)
                        .build());

        byte[] encryptedKey = response.ciphertextBlob().asByteArray();
        byte[] plaintextKey = response.plaintext().asByteArray();

        // Encrypt data with the local key; localEncrypt and combine are
        // application-level helpers (see the envelope pattern in SKILL.md)
        byte[] encryptedData = localEncrypt(largeData, plaintextKey);

        // Return both the encrypted key and the encrypted data
        return combine(encryptedKey, encryptedData);
    }
}
```

2. **Cache Encrypted Data Keys**
   - Cache encrypted data keys to avoid repeated KMS calls
   - Use time-based cache expiration

3. **Monitor API Usage**
   - Track KMS API calls for billing and optimization
   - Set up CloudWatch alarms for unexpected usage

```java
public class KmsUsageMonitor {
    private final MeterRegistry meterRegistry;

    public KmsUsageMonitor(MeterRegistry meterRegistry) {
        this.meterRegistry = meterRegistry;
    }

    public void recordEncryption() {
        meterRegistry.counter("kms.encryption.count").increment();
        meterRegistry.timer("kms.encryption.time").record(() -> {
            // Perform encryption
        });
    }
}
```

4. **Use Data Key Caching Libraries**
   - Implement proper caching strategies
   - Consider using dedicated caching solutions for data keys
## Error Handling Best Practices

1. **Implement Retry Logic for Throttling**
   - Add retry logic for throttling exceptions
   - Use exponential backoff for retries

```java
public class KmsRetryHandler {
    private static final int MAX_RETRIES = 3;
    private static final long INITIAL_DELAY = 1000; // 1 second

    public <T> T executeWithRetry(Supplier<T> operation) {
        int attempt = 0;
        while (attempt < MAX_RETRIES) {
            try {
                return operation.get();
            } catch (KmsException e) {
                if (!isRetryable(e) || attempt == MAX_RETRIES - 1) {
                    throw e;
                }
                attempt++;
                try {
                    Thread.sleep(INITIAL_DELAY * (long) Math.pow(2, attempt));
                } catch (InterruptedException ie) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException("Retry interrupted", ie);
                }
            }
        }
        throw new IllegalStateException("Should not reach here");
    }

    private boolean isRetryable(KmsException e) {
        return "ThrottlingException".equals(e.awsErrorDetails().errorCode());
    }
}
```
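Hypothetical usage, wrapping one of the helpers from SKILL.md so throttled calls are retried with backoff:

```java
KmsRetryHandler retryHandler = new KmsRetryHandler();
String ciphertext = retryHandler.executeWithRetry(
        () -> encryptData(kmsClient, keyId, "sensitive value"));
```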
2. **Handle Key State Errors Gracefully**
   - Check key state before performing operations
   - Handle key states like PendingDeletion, Disabled, etc.

```java
public void performOperationWithKeyStateCheck(KmsClient kmsClient, String keyId) {
    KeyMetadata metadata = kmsClient.describeKey(r -> r.keyId(keyId)).keyMetadata();

    switch (metadata.keyState()) {
        case ENABLED:
            // Perform operation
            break;
        case DISABLED:
            throw new IllegalStateException("Key is disabled");
        case PENDING_DELETION:
            throw new IllegalStateException("Key is scheduled for deletion");
        default:
            throw new IllegalStateException("Unknown key state: " + metadata.keyState());
    }
}
```

3. **Log KMS-Specific Error Codes**
   - Implement comprehensive error logging
   - Map KMS error codes to meaningful application errors

```java
public class KmsErrorHandler {
    public String mapKmsErrorToAppError(KmsException e) {
        String errorCode = e.awsErrorDetails().errorCode();
        switch (errorCode) {
            case "NotFoundException":
                return "Key not found";
            case "DisabledException":
                return "Key is disabled";
            case "AccessDeniedException":
                return "Access denied";
            case "InvalidKeyUsageException":
                return "Invalid key usage";
            default:
                return "KMS error: " + errorCode;
        }
    }
}
```
4. **Implement Circuit Breakers**
   - Use circuit breakers to handle KMS unavailability
   - Prevent cascading failures during outages

```java
// A sketch using Resilience4j (assumed on the classpath)
public class KmsCircuitBreaker {
    private final CircuitBreaker circuitBreaker;

    public KmsCircuitBreaker() {
        // Open the breaker when half of the recent calls fail, then probe
        // again after 30 seconds
        CircuitBreakerConfig config = CircuitBreakerConfig.custom()
                .failureRateThreshold(50)
                .waitDurationInOpenState(Duration.ofSeconds(30))
                .slidingWindowSize(10)
                .permittedNumberOfCallsInHalfOpenState(2)
                .build();
        this.circuitBreaker = CircuitBreaker.of("kmsService", config);
    }

    public <T> T executeWithCircuitBreaker(Callable<T> operation) {
        // Recorded failures trip the breaker; once open, calls fail fast with
        // CallNotPermittedException instead of hitting an unavailable KMS
        try {
            return circuitBreaker.executeCallable(operation);
        } catch (Exception e) {
            throw new RuntimeException("KMS call failed or circuit open", e);
        }
    }
}
```
## Testing Best Practices

1. **Test with Mock KMS Client**
   - Use mock clients for unit tests
   - Verify all expected interactions

```java
@Test
void shouldEncryptWithProperEncryptionContext() {
    // Arrange
    when(kmsClient.encrypt(any(EncryptRequest.class))).thenReturn(...);

    // Act
    String result = encryptionService.encrypt("test", "user123");

    // Assert
    verify(kmsClient).encrypt(argThat(request ->
            request.encryptionContext().containsKey("userId") &&
            request.encryptionContext().get("userId").equals("user123")));
}
```

2. **Test Error Scenarios**
   - Test various error conditions
   - Verify proper error handling and recovery

3. **Performance Testing**
   - Test performance under load
   - Measure latency and throughput

4. **Integration Testing with Local KMS**
   - Test with local KMS when possible
   - Verify integration with real AWS services
## Monitoring and Observability

1. **Implement Comprehensive Logging**
   - Log all KMS operations with appropriate levels
   - Include correlation IDs for tracing

```java
@Aspect
@Component
public class KmsLoggingAspect {
    private static final Logger logger = LoggerFactory.getLogger(KmsLoggingAspect.class);

    @Around("execution(* com.yourcompany.kms..*.*(..))")
    public Object logKmsOperation(ProceedingJoinPoint joinPoint) throws Throwable {
        String operation = joinPoint.getSignature().getName();
        logger.info("Starting KMS operation: {}", operation);

        long startTime = System.currentTimeMillis();
        try {
            Object result = joinPoint.proceed();
            long duration = System.currentTimeMillis() - startTime;
            logger.info("Completed KMS operation: {} in {}ms", operation, duration);
            return result;
        } catch (Exception e) {
            long duration = System.currentTimeMillis() - startTime;
            logger.error("KMS operation {} failed in {}ms: {}", operation, duration, e.getMessage());
            throw e;
        }
    }
}
```

2. **Set Up CloudWatch Alarms**
   - Monitor API call rates
   - Set up alarms for error rates
   - Track key usage patterns

3. **Use Distributed Tracing**
   - Implement tracing for KMS operations
   - Correlate KMS calls with application operations

4. **Monitor Key Usage Metrics**
   - Track key usage patterns
   - Monitor for unusual usage patterns
## Compliance and Auditing

1. **Enable KMS Key Usage Logging**
   - Configure CloudTrail to log KMS operations
   - Enable detailed logging for compliance

2. **Regular Security Audits**
   - Conduct regular audits of KMS key usage
   - Review access policies periodically

3. **Comprehensive Backup Strategy**
   - Implement key backup and recovery procedures
   - Test backup restoration processes

4. **Comprehensive Access Reviews**
   - Regularly review IAM policies for KMS access
   - Remove unnecessary permissions

## Advanced Security Considerations

1. **Multi-Region KMS Keys**
   - Consider multi-region keys for disaster recovery
   - Test failover scenarios

2. **Cross-Account Access**
   - Implement proper cross-account access controls
   - Use resource-based policies for account sharing

3. **Custom Key Stores**
   - Consider custom key stores for enhanced security
   - Implement proper key management in custom stores

4. **External Key Material**
   - Use imported key material for enhanced control
   - Implement proper key rotation for imported keys

## Development Best Practices

1. **Use Dependency Injection**
   - Inject KMS clients rather than creating them directly
   - Use proper configuration management

```java
@Configuration
@ConfigurationProperties(prefix = "aws.kms")
public class KmsProperties {
    private String region = "us-east-1";
    private String encryptionKeyId;
    private int maxRetries = 3;

    // Getters and setters
}
```

2. **Proper Configuration Management**
   - Use environment-specific configurations
   - Secure sensitive configuration values

3. **Version Control and Documentation**
   - Keep KMS-related code well documented
   - Track key usage patterns in version control

4. **Code Reviews**
   - Conduct thorough code reviews for KMS-related code
   - Focus on security and error handling

## Implementation Checklists

### Key Setup Checklist
- [ ] Create appropriate KMS keys for different purposes
- [ ] Enable automatic key rotation
- [ ] Set up key aliases
- [ ] Configure IAM policies with least privilege
- [ ] Set up CloudTrail logging

### Implementation Checklist
- [ ] Use envelope encryption for large data
- [ ] Implement proper error handling
- [ ] Add comprehensive logging
- [ ] Set up monitoring and alarms
- [ ] Write comprehensive tests

### Security Checklist
- [ ] Never log sensitive data
- [ ] Use encryption context
- [ ] Implement proper access controls
- [ ] Clear sensitive data from memory
- [ ] Regularly audit access patterns

By following these best practices, you can ensure that your AWS KMS implementation is secure, performant, cost-effective, and maintainable.
@@ -0,0 +1,504 @@
|
||||
# Spring Boot Integration with AWS KMS

## Configuration

### Basic Configuration

```java
@Configuration
public class KmsConfiguration {

    @Bean
    public KmsClient kmsClient() {
        return KmsClient.builder()
                .region(Region.US_EAST_1)
                .build();
    }

    @Bean
    public KmsAsyncClient kmsAsyncClient() {
        return KmsAsyncClient.builder()
                .region(Region.US_EAST_1)
                .build();
    }
}
```

### Configuration with Custom Settings

```java
@Configuration
@ConfigurationProperties(prefix = "aws.kms")
public class KmsAdvancedConfiguration {

    private Region region = Region.US_EAST_1;
    private String endpoint;
    private Duration timeout = Duration.ofSeconds(10);
    private String accessKeyId;
    private String secretAccessKey;

    @Bean
    public KmsClient kmsClient() {
        KmsClientBuilder builder = KmsClient.builder()
                .region(region)
                .overrideConfiguration(c -> c.retryPolicy(RetryPolicy.builder()
                        .numRetries(3)
                        .build()));

        if (endpoint != null) {
            builder.endpointOverride(URI.create(endpoint));
        }

        // Add credentials if provided
        if (accessKeyId != null && secretAccessKey != null) {
            builder.credentialsProvider(StaticCredentialsProvider.create(
                    AwsBasicCredentials.create(accessKeyId, secretAccessKey)));
        }

        return builder.build();
    }

    // Getters and setters
    public Region getRegion() { return region; }
    public void setRegion(Region region) { this.region = region; }
    public String getEndpoint() { return endpoint; }
    public void setEndpoint(String endpoint) { this.endpoint = endpoint; }
    public Duration getTimeout() { return timeout; }
    public void setTimeout(Duration timeout) { this.timeout = timeout; }
    public String getAccessKeyId() { return accessKeyId; }
    public void setAccessKeyId(String accessKeyId) { this.accessKeyId = accessKeyId; }
    public String getSecretAccessKey() { return secretAccessKey; }
    public void setSecretAccessKey(String secretAccessKey) { this.secretAccessKey = secretAccessKey; }
}
```

### Application Properties

```properties
# AWS KMS Configuration
aws.kms.region=us-east-1
aws.kms.endpoint=
aws.kms.timeout=10s
aws.kms.access-key-id=
aws.kms.secret-access-key=

# KMS Key Configuration
kms.encryption-key-id=alias/your-encryption-key
kms.signing-key-id=alias/your-signing-key
```

## Encryption Service

### Basic Encryption Service

```java
@Service
public class KmsEncryptionService {

    private final KmsClient kmsClient;

    @Value("${kms.encryption-key-id}")
    private String keyId;

    public KmsEncryptionService(KmsClient kmsClient) {
        this.kmsClient = kmsClient;
    }

    public String encrypt(String plaintext) {
        try {
            EncryptRequest request = EncryptRequest.builder()
                    .keyId(keyId)
                    .plaintext(SdkBytes.fromString(plaintext, StandardCharsets.UTF_8))
                    .build();

            EncryptResponse response = kmsClient.encrypt(request);

            // Return Base64-encoded ciphertext
            return Base64.getEncoder()
                    .encodeToString(response.ciphertextBlob().asByteArray());

        } catch (KmsException e) {
            throw new RuntimeException("Encryption failed", e);
        }
    }

    public String decrypt(String ciphertextBase64) {
        try {
            byte[] ciphertext = Base64.getDecoder().decode(ciphertextBase64);

            DecryptRequest request = DecryptRequest.builder()
                    .ciphertextBlob(SdkBytes.fromByteArray(ciphertext))
                    .build();

            DecryptResponse response = kmsClient.decrypt(request);

            return response.plaintext().asString(StandardCharsets.UTF_8);

        } catch (KmsException e) {
            throw new RuntimeException("Decryption failed", e);
        }
    }
}
```

### Secure Data Repository

```java
@Repository
public class SecureDataRepository {

    private final KmsEncryptionService encryptionService;
    private final JdbcTemplate jdbcTemplate;

    public SecureDataRepository(KmsEncryptionService encryptionService,
                                JdbcTemplate jdbcTemplate) {
        this.encryptionService = encryptionService;
        this.jdbcTemplate = jdbcTemplate;
    }

    public void saveSecureData(String id, String sensitiveData) {
        String encryptedData = encryptionService.encrypt(sensitiveData);

        jdbcTemplate.update(
                "INSERT INTO secure_data (id, encrypted_value) VALUES (?, ?)",
                id, encryptedData);
    }

    public String getSecureData(String id) {
        String encryptedData = jdbcTemplate.queryForObject(
                "SELECT encrypted_value FROM secure_data WHERE id = ?",
                String.class, id);

        return encryptionService.decrypt(encryptedData);
    }
}
```

### Advanced Envelope Encryption Service

```java
@Service
public class EnvelopeEncryptionService {

    private static final int GCM_IV_LENGTH = 12;
    private static final int GCM_TAG_LENGTH_BITS = 128;

    private final KmsClient kmsClient;
    private final SecureRandom secureRandom = new SecureRandom();

    @Value("${kms.master-key-id}")
    private String masterKeyId;

    // Cache data keys so repeated operations do not call KMS every time.
    // The cached plaintext key must stay intact while cached because it is
    // reused; it is zeroed by the removal listener when the entry is evicted.
    private final Cache<String, DataKeyPair> keyCache =
            Caffeine.newBuilder()
                    .expireAfterWrite(1, TimeUnit.HOURS)
                    .maximumSize(100)
                    .removalListener((String cacheKey, DataKeyPair pair, RemovalCause cause) -> {
                        if (pair != null) {
                            Arrays.fill(pair.plaintext(), (byte) 0);
                        }
                    })
                    .build();

    public EnvelopeEncryptionService(KmsClient kmsClient) {
        this.kmsClient = kmsClient;
    }

    public EncryptedEnvelope encryptLargeData(byte[] data) {
        // Check the cache for an existing data key
        DataKeyPair dataKeyPair = keyCache.getIfPresent(masterKeyId);

        if (dataKeyPair == null) {
            // Generate a new data key
            GenerateDataKeyResponse dataKeyResponse = kmsClient.generateDataKey(
                    GenerateDataKeyRequest.builder()
                            .keyId(masterKeyId)
                            .keySpec(DataKeySpec.AES_256)
                            .build());

            dataKeyPair = new DataKeyPair(
                    dataKeyResponse.plaintext().asByteArray(),
                    dataKeyResponse.ciphertextBlob().asByteArray());

            keyCache.put(masterKeyId, dataKeyPair);
        }

        try {
            // Encrypt data with the plaintext data key; the key is left intact
            // here because the cache reuses it for later operations
            byte[] encryptedData = encryptWithAES(data, dataKeyPair.plaintext());
            return new EncryptedEnvelope(encryptedData, dataKeyPair.encrypted());
        } catch (Exception e) {
            throw new RuntimeException("Envelope encryption failed", e);
        }
    }

    public byte[] decryptLargeData(EncryptedEnvelope envelope) {
        // Get the data key from the cache, or decrypt it via KMS
        DataKeyPair dataKeyPair = keyCache.getIfPresent(masterKeyId);

        if (dataKeyPair == null || !Arrays.equals(dataKeyPair.encrypted(), envelope.encryptedKey())) {
            DecryptResponse decryptResponse = kmsClient.decrypt(
                    DecryptRequest.builder()
                            .ciphertextBlob(SdkBytes.fromByteArray(envelope.encryptedKey()))
                            .build());

            dataKeyPair = new DataKeyPair(
                    decryptResponse.plaintext().asByteArray(),
                    envelope.encryptedKey());

            // Cache for future use
            keyCache.put(masterKeyId, dataKeyPair);
        }

        try {
            // Decrypt data with the plaintext data key
            return decryptWithAES(envelope.encryptedData(), dataKeyPair.plaintext());
        } catch (Exception e) {
            throw new RuntimeException("Envelope decryption failed", e);
        }
    }

    private byte[] encryptWithAES(byte[] data, byte[] key) throws Exception {
        SecretKeySpec keySpec = new SecretKeySpec(key, "AES");
        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        // GCM requires a unique IV per encryption; prepend it to the ciphertext
        byte[] iv = new byte[GCM_IV_LENGTH];
        secureRandom.nextBytes(iv);
        cipher.init(Cipher.ENCRYPT_MODE, keySpec, new GCMParameterSpec(GCM_TAG_LENGTH_BITS, iv));
        byte[] ciphertext = cipher.doFinal(data);
        return ByteBuffer.allocate(iv.length + ciphertext.length).put(iv).put(ciphertext).array();
    }

    private byte[] decryptWithAES(byte[] data, byte[] key) throws Exception {
        SecretKeySpec keySpec = new SecretKeySpec(key, "AES");
        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        // Recover the IV that was prepended during encryption
        byte[] iv = Arrays.copyOfRange(data, 0, GCM_IV_LENGTH);
        cipher.init(Cipher.DECRYPT_MODE, keySpec, new GCMParameterSpec(GCM_TAG_LENGTH_BITS, iv));
        return cipher.doFinal(data, GCM_IV_LENGTH, data.length - GCM_IV_LENGTH);
    }

    public record DataKeyPair(byte[] plaintext, byte[] encrypted) {}
    public record EncryptedEnvelope(byte[] encryptedData, byte[] encryptedKey) {}
}
```
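
A quick usage sketch, assuming the service is injected; `largeDocumentBytes` is an illustrative payload:

```java
EncryptedEnvelope envelope = envelopeEncryptionService.encryptLargeData(largeDocumentBytes);
byte[] roundTrip = envelopeEncryptionService.decryptLargeData(envelope);
```

Only the envelope needs to be persisted; KMS is called again only when the cached data key has expired or a different encrypted key is encountered.
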
## Data Encryption Interceptor

### SQL Encryption Interceptor

```java
// StatementInterceptor stands in for your JDBC driver's interceptor hook
// (for example, MySQL Connector/J's query interceptor); adapt the method
// signatures to the driver you actually use.
public class KmsDataEncryptInterceptor implements StatementInterceptor {

    private final KmsEncryptionService encryptionService;

    public KmsDataEncryptInterceptor(KmsEncryptionService encryptionService) {
        this.encryptionService = encryptionService;
    }

    @Override
    public ResultSet intercept(ResultSet rs, Statement statement, Connection connection) throws SQLException {
        return new DecryptingResultSetWrapper(rs, encryptionService);
    }

    @Override
    public void interceptAfterExecution(Statement statement) {
        // No-op
    }
}

// Wraps the ResultSet so values in encrypted columns are transparently
// decrypted as they are read
class DecryptingResultSetWrapper implements ResultSet {

    private final ResultSet delegate;
    private final KmsEncryptionService encryptionService;

    public DecryptingResultSetWrapper(ResultSet delegate, KmsEncryptionService encryptionService) {
        this.delegate = delegate;
        this.encryptionService = encryptionService;
    }

    @Override
    public String getString(String columnLabel) throws SQLException {
        String value = delegate.getString(columnLabel);
        if (value == null) return null;

        // Check if this is an encrypted column
        if (isEncryptedColumn(columnLabel)) {
            return encryptionService.decrypt(value);
        }

        return value;
    }

    private boolean isEncryptedColumn(String columnLabel) {
        // Implement logic to identify encrypted columns
        return columnLabel.contains("encrypted") || columnLabel.contains("secure");
    }

    // Delegate other methods to the original ResultSet
    @Override
    public boolean next() throws SQLException {
        return delegate.next();
    }

    // ... other ResultSet method implementations
}
```

## Configuration Profiles

### Development Profile

```properties
# src/main/resources/application-dev.properties
aws.kms.region=us-east-1
kms.encryption-key-id=alias/dev-encryption-key
logging.level.com.yourcompany=DEBUG
```

### Production Profile

```properties
# src/main/resources/application-prod.properties
aws.kms.region=${AWS_REGION:us-east-1}
kms.encryption-key-id=${KMS_ENCRYPTION_KEY_ID:alias/production-encryption-key}
logging.level.com.yourcompany=WARN
spring.cloud.aws.credentials.access-key=${AWS_ACCESS_KEY_ID}
spring.cloud.aws.credentials.secret-key=${AWS_SECRET_ACCESS_KEY}
```

### Test Configuration

```java
@Configuration
@Profile("test")
public class KmsTestConfiguration {

    @Bean
    @Primary
    public KmsClient testKmsClient() {
        // Return a mock or test-specific KMS client
        return mock(KmsClient.class);
    }

    @Bean
    public KmsEncryptionService testKmsEncryptionService() {
        return new KmsEncryptionService(testKmsClient());
    }
}
```

## Health Checks and Monitoring

### KMS Health Indicator

```java
@Component
public class KmsHealthIndicator implements HealthIndicator {

    private final KmsClient kmsClient;
    private final String keyId;

    public KmsHealthIndicator(KmsClient kmsClient,
                              @Value("${kms.encryption-key-id}") String keyId) {
        this.kmsClient = kmsClient;
        this.keyId = keyId;
    }

    @Override
    public Health health() {
        try {
            // Test KMS connectivity by describing the key
            DescribeKeyRequest request = DescribeKeyRequest.builder()
                    .keyId(keyId)
                    .build();

            DescribeKeyResponse response = kmsClient.describeKey(request);

            // Check if the key is in a healthy state
            KeyState keyState = response.keyMetadata().keyState();
            boolean isHealthy = keyState == KeyState.ENABLED;

            if (isHealthy) {
                return Health.up()
                        .withDetail("keyId", keyId)
                        .withDetail("keyState", keyState)
                        .withDetail("keyArn", response.keyMetadata().arn())
                        .build();
            } else {
                return Health.down()
                        .withDetail("keyId", keyId)
                        .withDetail("keyState", keyState)
                        .withDetail("message", "KMS key is not in ENABLED state")
                        .build();
            }

        } catch (KmsException e) {
            return Health.down()
                    .withDetail("keyId", keyId)
                    .withDetail("error", e.awsErrorDetails().errorMessage())
                    .withDetail("errorCode", e.awsErrorDetails().errorCode())
                    .build();
        }
    }
}
```

### Metrics Collection

```java
@Service
public class KmsMetricsCollector {

    private final MeterRegistry meterRegistry;
    private final KmsClient kmsClient;

    private final Counter encryptionCounter;
    private final Counter decryptionCounter;
    private final Timer encryptionTimer;
    private final Timer decryptionTimer;

    public KmsMetricsCollector(MeterRegistry meterRegistry, KmsClient kmsClient) {
        this.meterRegistry = meterRegistry;
        this.kmsClient = kmsClient;

        this.encryptionCounter = Counter.builder("kms.encryption.count")
                .description("Number of encryption operations")
                .register(meterRegistry);

        this.decryptionCounter = Counter.builder("kms.decryption.count")
                .description("Number of decryption operations")
                .register(meterRegistry);

        this.encryptionTimer = Timer.builder("kms.encryption.time")
                .description("Time taken for encryption operations")
                .register(meterRegistry);

        this.decryptionTimer = Timer.builder("kms.decryption.time")
                .description("Time taken for decryption operations")
                .register(meterRegistry);
    }

    public String encryptWithMetrics(String plaintext) {
        encryptionCounter.increment();

        return encryptionTimer.record(() -> {
            try {
                EncryptRequest request = EncryptRequest.builder()
                        .keyId("your-key-id")
                        .plaintext(SdkBytes.fromString(plaintext, StandardCharsets.UTF_8))
                        .build();

                EncryptResponse response = kmsClient.encrypt(request);
                return Base64.getEncoder().encodeToString(
                        response.ciphertextBlob().asByteArray());

            } catch (KmsException e) {
                meterRegistry.counter("kms.encryption.errors")
                        .increment();
                throw e;
            }
        });
    }
}
```
@@ -0,0 +1,639 @@
# AWS KMS Technical Guide

## Key Management Operations

### Create KMS Key

```java
import software.amazon.awssdk.services.kms.model.*;
import java.util.stream.Collectors;

public String createKey(KmsClient kmsClient, String description) {
    try {
        CreateKeyRequest request = CreateKeyRequest.builder()
                .description(description)
                .keyUsage(KeyUsageType.ENCRYPT_DECRYPT)
                .origin(OriginType.AWS_KMS)
                .build();

        CreateKeyResponse response = kmsClient.createKey(request);

        String keyId = response.keyMetadata().keyId();
        System.out.println("Created key: " + keyId);

        return keyId;

    } catch (KmsException e) {
        System.err.println("Error creating key: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```

### Create Key with Custom Key Store

```java
public String createKeyWithCustomStore(KmsClient kmsClient,
                                       String description,
                                       String customKeyStoreId) {
    CreateKeyRequest request = CreateKeyRequest.builder()
            .description(description)
            .keyUsage(KeyUsageType.ENCRYPT_DECRYPT)
            .origin(OriginType.AWS_CLOUDHSM)
            .customKeyStoreId(customKeyStoreId)
            .build();

    CreateKeyResponse response = kmsClient.createKey(request);

    return response.keyMetadata().keyId();
}
```

### List Keys

```java
import java.util.List;

public List<KeyListEntry> listKeys(KmsClient kmsClient) {
    try {
        ListKeysRequest request = ListKeysRequest.builder()
                .limit(100)
                .build();

        ListKeysResponse response = kmsClient.listKeys(request);

        response.keys().forEach(key -> {
            System.out.println("Key ARN: " + key.keyArn());
            System.out.println("Key ID: " + key.keyId());
            System.out.println();
        });

        return response.keys();

    } catch (KmsException e) {
        System.err.println("Error listing keys: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```

### List Keys with Pagination (Async)

```java
import software.amazon.awssdk.services.kms.paginators.ListKeysPublisher;
import java.util.concurrent.CompletableFuture;

public CompletableFuture<Void> listAllKeysAsync(KmsAsyncClient kmsAsyncClient) {
    ListKeysRequest request = ListKeysRequest.builder()
            .limit(15)
            .build();

    ListKeysPublisher keysPublisher = kmsAsyncClient.listKeysPaginator(request);

    return keysPublisher
            .subscribe(r -> r.keys().forEach(key ->
                    System.out.println("Key ARN: " + key.keyArn())))
            .whenComplete((result, exception) -> {
                if (exception != null) {
                    System.err.println("Error: " + exception.getMessage());
                } else {
                    System.out.println("Successfully listed all keys");
                }
            });
}
```

### Describe Key

```java
public KeyMetadata describeKey(KmsClient kmsClient, String keyId) {
    try {
        DescribeKeyRequest request = DescribeKeyRequest.builder()
                .keyId(keyId)
                .build();

        DescribeKeyResponse response = kmsClient.describeKey(request);
        KeyMetadata metadata = response.keyMetadata();

        System.out.println("Key ID: " + metadata.keyId());
        System.out.println("Key ARN: " + metadata.arn());
        System.out.println("Key State: " + metadata.keyState());
        System.out.println("Creation Date: " + metadata.creationDate());
        System.out.println("Enabled: " + metadata.enabled());

        return metadata;

    } catch (KmsException e) {
        System.err.println("Error describing key: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```

### Enable/Disable Key

```java
public void enableKey(KmsClient kmsClient, String keyId) {
    try {
        EnableKeyRequest request = EnableKeyRequest.builder()
                .keyId(keyId)
                .build();

        kmsClient.enableKey(request);
        System.out.println("Key enabled: " + keyId);

    } catch (KmsException e) {
        System.err.println("Error enabling key: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}

public void disableKey(KmsClient kmsClient, String keyId) {
    try {
        DisableKeyRequest request = DisableKeyRequest.builder()
                .keyId(keyId)
                .build();

        kmsClient.disableKey(request);
        System.out.println("Key disabled: " + keyId);

    } catch (KmsException e) {
        System.err.println("Error disabling key: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```

## Encryption and Decryption

### Encrypt Data

```java
import software.amazon.awssdk.core.SdkBytes;
import java.nio.charset.StandardCharsets;

public byte[] encryptData(KmsClient kmsClient, String keyId, String plaintext) {
    try {
        SdkBytes plaintextBytes = SdkBytes.fromString(plaintext, StandardCharsets.UTF_8);

        EncryptRequest request = EncryptRequest.builder()
                .keyId(keyId)
                .plaintext(plaintextBytes)
                .build();

        EncryptResponse response = kmsClient.encrypt(request);

        byte[] encryptedData = response.ciphertextBlob().asByteArray();
        System.out.println("Data encrypted successfully");

        return encryptedData;

    } catch (KmsException e) {
        System.err.println("Error encrypting data: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```

### Decrypt Data

```java
public String decryptData(KmsClient kmsClient, byte[] ciphertext) {
    try {
        SdkBytes ciphertextBytes = SdkBytes.fromByteArray(ciphertext);

        DecryptRequest request = DecryptRequest.builder()
                .ciphertextBlob(ciphertextBytes)
                .build();

        DecryptResponse response = kmsClient.decrypt(request);

        String decryptedText = response.plaintext().asString(StandardCharsets.UTF_8);
        System.out.println("Data decrypted successfully");

        return decryptedText;

    } catch (KmsException e) {
        System.err.println("Error decrypting data: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```

### Encrypt with Encryption Context

```java
import java.util.Map;

public byte[] encryptWithContext(KmsClient kmsClient,
                                 String keyId,
                                 String plaintext,
                                 Map<String, String> encryptionContext) {
    try {
        EncryptRequest request = EncryptRequest.builder()
                .keyId(keyId)
                .plaintext(SdkBytes.fromString(plaintext, StandardCharsets.UTF_8))
                .encryptionContext(encryptionContext)
                .build();

        EncryptResponse response = kmsClient.encrypt(request);

        return response.ciphertextBlob().asByteArray();

    } catch (KmsException e) {
        System.err.println("Error encrypting with context: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```
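
The same encryption context must be supplied again on decrypt, or KMS rejects the request; a minimal counterpart sketch:

```java
public String decryptWithContext(KmsClient kmsClient,
                                 byte[] ciphertext,
                                 Map<String, String> encryptionContext) {
    DecryptRequest request = DecryptRequest.builder()
            .ciphertextBlob(SdkBytes.fromByteArray(ciphertext))
            .encryptionContext(encryptionContext) // must match the context used at encrypt time
            .build();

    return kmsClient.decrypt(request).plaintext().asString(StandardCharsets.UTF_8);
}
```
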
## Data Key Generation (Envelope Encryption)

### Generate Data Key

```java
public record DataKeyPair(byte[] plaintext, byte[] encrypted) {}

public DataKeyPair generateDataKey(KmsClient kmsClient, String keyId) {
    try {
        GenerateDataKeyRequest request = GenerateDataKeyRequest.builder()
                .keyId(keyId)
                .keySpec(DataKeySpec.AES_256)
                .build();

        GenerateDataKeyResponse response = kmsClient.generateDataKey(request);

        byte[] plaintextKey = response.plaintext().asByteArray();
        byte[] encryptedKey = response.ciphertextBlob().asByteArray();

        System.out.println("Data key generated");

        return new DataKeyPair(plaintextKey, encryptedKey);

    } catch (KmsException e) {
        System.err.println("Error generating data key: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```

### Generate Data Key Without Plaintext

```java
public byte[] generateDataKeyWithoutPlaintext(KmsClient kmsClient, String keyId) {
    try {
        GenerateDataKeyWithoutPlaintextRequest request =
                GenerateDataKeyWithoutPlaintextRequest.builder()
                        .keyId(keyId)
                        .keySpec(DataKeySpec.AES_256)
                        .build();

        GenerateDataKeyWithoutPlaintextResponse response =
                kmsClient.generateDataKeyWithoutPlaintext(request);

        return response.ciphertextBlob().asByteArray();

    } catch (KmsException e) {
        System.err.println("Error generating data key: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```
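
The encrypted data key produced above is only turned into a usable key by the component that actually needs to read or write the data; a minimal sketch of that step:

```java
// Decrypt the stored data key just-in-time, when the data must be processed
byte[] plaintextDataKey = kmsClient.decrypt(DecryptRequest.builder()
                .ciphertextBlob(SdkBytes.fromByteArray(encryptedDataKey))
                .build())
        .plaintext()
        .asByteArray();
```
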
## Digital Signing

### Create Signing Key

```java
public String createSigningKey(KmsClient kmsClient, String description) {
    try {
        CreateKeyRequest request = CreateKeyRequest.builder()
                .description(description)
                .keySpec(KeySpec.RSA_2048)
                .keyUsage(KeyUsageType.SIGN_VERIFY)
                .origin(OriginType.AWS_KMS)
                .build();

        CreateKeyResponse response = kmsClient.createKey(request);

        return response.keyMetadata().keyId();

    } catch (KmsException e) {
        System.err.println("Error creating signing key: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```

### Sign Data

```java
public byte[] signData(KmsClient kmsClient, String keyId, String message) {
    try {
        SdkBytes messageBytes = SdkBytes.fromString(message, StandardCharsets.UTF_8);

        SignRequest request = SignRequest.builder()
                .keyId(keyId)
                .message(messageBytes)
                .signingAlgorithm(SigningAlgorithmSpec.RSASSA_PSS_SHA_256)
                .build();

        SignResponse response = kmsClient.sign(request);

        byte[] signature = response.signature().asByteArray();
        System.out.println("Data signed successfully");

        return signature;

    } catch (KmsException e) {
        System.err.println("Error signing data: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```

### Verify Signature

```java
public boolean verifySignature(KmsClient kmsClient,
                               String keyId,
                               String message,
                               byte[] signature) {
    try {
        VerifyRequest request = VerifyRequest.builder()
                .keyId(keyId)
                .message(SdkBytes.fromString(message, StandardCharsets.UTF_8))
                .signature(SdkBytes.fromByteArray(signature))
                .signingAlgorithm(SigningAlgorithmSpec.RSASSA_PSS_SHA_256)
                .build();

        VerifyResponse response = kmsClient.verify(request);

        boolean isValid = response.signatureValid();
        System.out.println("Signature valid: " + isValid);

        return isValid;

    } catch (KmsException e) {
        System.err.println("Error verifying signature: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```

### Sign and Verify (Async)

```java
public CompletableFuture<Boolean> signAndVerifyAsync(KmsAsyncClient kmsAsyncClient,
                                                     String message) {
    // Create signing key
    CreateKeyRequest createKeyRequest = CreateKeyRequest.builder()
            .keySpec(KeySpec.RSA_2048)
            .keyUsage(KeyUsageType.SIGN_VERIFY)
            .origin(OriginType.AWS_KMS)
            .build();

    return kmsAsyncClient.createKey(createKeyRequest)
            .thenCompose(createKeyResponse -> {
                String keyId = createKeyResponse.keyMetadata().keyId();

                SdkBytes messageBytes = SdkBytes.fromString(message, StandardCharsets.UTF_8);
                SignRequest signRequest = SignRequest.builder()
                        .keyId(keyId)
                        .message(messageBytes)
                        .signingAlgorithm(SigningAlgorithmSpec.RSASSA_PSS_SHA_256)
                        .build();

                return kmsAsyncClient.sign(signRequest)
                        .thenCompose(signResponse -> {
                            byte[] signedBytes = signResponse.signature().asByteArray();

                            VerifyRequest verifyRequest = VerifyRequest.builder()
                                    .keyId(keyId)
                                    .message(messageBytes)
                                    .signature(SdkBytes.fromByteArray(signedBytes))
                                    .signingAlgorithm(SigningAlgorithmSpec.RSASSA_PSS_SHA_256)
                                    .build();

                            return kmsAsyncClient.verify(verifyRequest)
                                    .thenApply(VerifyResponse::signatureValid);
                        });
            })
            .exceptionally(throwable -> {
                throw new RuntimeException("Failed to sign or verify", throwable);
            });
}
```

## Key Tagging

### Tag Key

```java
public void tagKey(KmsClient kmsClient, String keyId, Map<String, String> tags) {
    try {
        List<Tag> tagList = tags.entrySet().stream()
                .map(entry -> Tag.builder()
                        .tagKey(entry.getKey())
                        .tagValue(entry.getValue())
                        .build())
                .collect(Collectors.toList());

        TagResourceRequest request = TagResourceRequest.builder()
                .keyId(keyId)
                .tags(tagList)
                .build();

        kmsClient.tagResource(request);
        System.out.println("Key tagged successfully");

    } catch (KmsException e) {
        System.err.println("Error tagging key: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```

### List Tags

```java
public Map<String, String> listTags(KmsClient kmsClient, String keyId) {
    try {
        ListResourceTagsRequest request = ListResourceTagsRequest.builder()
                .keyId(keyId)
                .build();

        ListResourceTagsResponse response = kmsClient.listResourceTags(request);

        return response.tags().stream()
                .collect(Collectors.toMap(Tag::tagKey, Tag::tagValue));

    } catch (KmsException e) {
        System.err.println("Error listing tags: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```

## Advanced Techniques

### Envelope Encryption Service

```java
@Service
public class EnvelopeEncryptionService {

    private static final int GCM_IV_LENGTH = 12;
    private static final int GCM_TAG_LENGTH_BITS = 128;

    private final KmsClient kmsClient;
    private final SecureRandom secureRandom = new SecureRandom();

    @Value("${kms.master-key-id}")
    private String masterKeyId;

    public EnvelopeEncryptionService(KmsClient kmsClient) {
        this.kmsClient = kmsClient;
    }

    public EncryptedEnvelope encryptLargeData(byte[] data) {
        // Generate data key
        GenerateDataKeyResponse dataKeyResponse = kmsClient.generateDataKey(
                GenerateDataKeyRequest.builder()
                        .keyId(masterKeyId)
                        .keySpec(DataKeySpec.AES_256)
                        .build());

        byte[] plaintextKey = dataKeyResponse.plaintext().asByteArray();
        byte[] encryptedKey = dataKeyResponse.ciphertextBlob().asByteArray();

        try {
            // Encrypt data with the plaintext data key
            byte[] encryptedData = encryptWithAES(data, plaintextKey);
            return new EncryptedEnvelope(encryptedData, encryptedKey);
        } catch (Exception e) {
            throw new RuntimeException("Envelope encryption failed", e);
        } finally {
            // Clear the plaintext key from memory on every path
            Arrays.fill(plaintextKey, (byte) 0);
        }
    }

    public byte[] decryptLargeData(EncryptedEnvelope envelope) {
        // Decrypt the data key
        DecryptResponse decryptResponse = kmsClient.decrypt(
                DecryptRequest.builder()
                        .ciphertextBlob(SdkBytes.fromByteArray(envelope.encryptedKey()))
                        .build());

        byte[] plaintextKey = decryptResponse.plaintext().asByteArray();

        try {
            // Decrypt data with the plaintext data key
            return decryptWithAES(envelope.encryptedData(), plaintextKey);
        } catch (Exception e) {
            throw new RuntimeException("Envelope decryption failed", e);
        } finally {
            // Clear the plaintext key from memory on every path
            Arrays.fill(plaintextKey, (byte) 0);
        }
    }

    private byte[] encryptWithAES(byte[] data, byte[] key) throws Exception {
        SecretKeySpec keySpec = new SecretKeySpec(key, "AES");
        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        // GCM requires a unique IV per encryption; prepend it to the ciphertext
        byte[] iv = new byte[GCM_IV_LENGTH];
        secureRandom.nextBytes(iv);
        cipher.init(Cipher.ENCRYPT_MODE, keySpec, new GCMParameterSpec(GCM_TAG_LENGTH_BITS, iv));
        byte[] ciphertext = cipher.doFinal(data);
        return ByteBuffer.allocate(iv.length + ciphertext.length).put(iv).put(ciphertext).array();
    }

    private byte[] decryptWithAES(byte[] data, byte[] key) throws Exception {
        SecretKeySpec keySpec = new SecretKeySpec(key, "AES");
        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        // Recover the IV that was prepended during encryption
        byte[] iv = Arrays.copyOfRange(data, 0, GCM_IV_LENGTH);
        cipher.init(Cipher.DECRYPT_MODE, keySpec, new GCMParameterSpec(GCM_TAG_LENGTH_BITS, iv));
        return cipher.doFinal(data, GCM_IV_LENGTH, data.length - GCM_IV_LENGTH);
    }

    public record EncryptedEnvelope(byte[] encryptedData, byte[] encryptedKey) {}
}
```

### Error Handling Strategies

```java
public class KmsErrorHandler {

    private static final int MAX_RETRIES = 3;
    private static final long RETRY_DELAY_MS = 1000;

    public <T> T executeWithRetry(Supplier<T> operation, String operationName) {
        int attempt = 0;
        KmsException lastException = null;

        while (attempt < MAX_RETRIES) {
            try {
                return operation.get();
            } catch (KmsException e) {
                lastException = e;
                attempt++;

                // Retry only transient throttling/limit errors
                if (isRetryableError(e) && attempt < MAX_RETRIES) {
                    try {
                        Thread.sleep(RETRY_DELAY_MS * attempt); // simple linear backoff
                    } catch (InterruptedException ie) {
                        Thread.currentThread().interrupt();
                        throw new RuntimeException("Retry interrupted", ie);
                    }
                } else {
                    // Non-retryable error or max retries exceeded
                    throw e;
                }
            }
        }

        throw new RuntimeException(String.format("Failed to execute %s after %d attempts", operationName, MAX_RETRIES), lastException);
    }

    public boolean isRetryableError(KmsException e) {
        String errorCode = e.awsErrorDetails().errorCode();
        return "ThrottlingException".equals(errorCode)
                || "TooManyRequestsException".equals(errorCode)
                || "LimitExceededException".equals(errorCode);
    }
}
```
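
A usage sketch, assuming the `encryptData` helper from earlier in this guide and a configured `kmsClient` and `keyId`:

```java
KmsErrorHandler handler = new KmsErrorHandler();

byte[] ciphertext = handler.executeWithRetry(
        () -> encryptData(kmsClient, keyId, "payload"),
        "encrypt");
```
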
### Connection Pooling Configuration

```java
import java.time.Duration;

import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.kms.KmsClient;

public class KmsConnectionPool {

    public static KmsClient createPooledClient() {
        // The SDK's Apache-based HTTP client manages its own connection pool;
        // configure pool limits directly on the builder rather than supplying
        // an external Apache HttpClient instance
        ApacheHttpClient.Builder httpClientBuilder = ApacheHttpClient.builder()
                .maxConnections(100)
                .connectionTimeout(Duration.ofSeconds(5));

        return KmsClient.builder()
                .region(Region.US_EAST_1)
                .httpClientBuilder(httpClientBuilder)
                .build();
    }
}
```
589
skills/aws-java/aws-sdk-java-v2-kms/references/testing.md
Normal file
@@ -0,0 +1,589 @@
# Testing AWS KMS Integration

## Unit Testing with Mocked Client

### Basic Unit Test

```java
import software.amazon.awssdk.awscore.exception.AwsErrorDetails;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kms.KmsClient;
import software.amazon.awssdk.services.kms.model.*;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import java.nio.charset.StandardCharsets;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class KmsEncryptionServiceTest {

    @Mock
    private KmsClient kmsClient;

    @InjectMocks
    private KmsEncryptionService encryptionService;

    @Test
    void shouldEncryptData() {
        // Arrange
        String plaintext = "sensitive data";
        byte[] ciphertext = "encrypted".getBytes();

        when(kmsClient.encrypt(any(EncryptRequest.class)))
                .thenReturn(EncryptResponse.builder()
                        .ciphertextBlob(SdkBytes.fromByteArray(ciphertext))
                        .build());

        // Act
        String result = encryptionService.encrypt(plaintext);

        // Assert
        assertThat(result).isNotEmpty();
        verify(kmsClient).encrypt(any(EncryptRequest.class));
    }

    @Test
    void shouldDecryptData() {
        // Arrange
        String encryptedText = "ciphertext";
        String expectedPlaintext = "sensitive data";

        when(kmsClient.decrypt(any(DecryptRequest.class)))
                .thenReturn(DecryptResponse.builder()
                        .plaintext(SdkBytes.fromString(expectedPlaintext, StandardCharsets.UTF_8))
                        .build());

        // Act
        String result = encryptionService.decrypt(encryptedText);

        // Assert
        assertThat(result).isEqualTo(expectedPlaintext);
        verify(kmsClient).decrypt(any(DecryptRequest.class));
    }

    @Test
    void shouldThrowExceptionOnEncryptionFailure() {
        // Arrange
        when(kmsClient.encrypt(any(EncryptRequest.class)))
                .thenThrow(KmsException.builder()
                        .awsErrorDetails(AwsErrorDetails.builder()
                                .errorCode("KMSDisabledException")
                                .errorMessage("KMS is disabled")
                                .build())
                        .build());

        // Act & Assert
        assertThatThrownBy(() -> encryptionService.encrypt("test"))
                .isInstanceOf(RuntimeException.class)
                .hasMessageContaining("Encryption failed");
    }
}
```

### Parameterized Tests

```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;

@ExtendWith(MockitoExtension.class)
class KmsEncryptionParameterizedTest {

    @Mock
    private KmsClient kmsClient;

    @InjectMocks
    private KmsEncryptionService encryptionService;

    @ParameterizedTest
    @CsvSource({
            "hello, world",
            "12345, 67890",
            "special@chars, normal",
            "very long string with multiple words, another string",
            "'', empty-string-case", // '' is CsvSource's notation for an empty string
            "null test, null test"
    })
    void shouldEncryptAndDecrypt(String plaintext, String testIdentifier) {
        // Arrange
        byte[] ciphertext = "encrypted".getBytes();

        when(kmsClient.encrypt(any(EncryptRequest.class)))
                .thenReturn(EncryptResponse.builder()
                        .ciphertextBlob(SdkBytes.fromByteArray(ciphertext))
                        .build());

        when(kmsClient.decrypt(any(DecryptRequest.class)))
                .thenReturn(DecryptResponse.builder()
                        .plaintext(SdkBytes.fromString(plaintext, StandardCharsets.UTF_8))
                        .build());

        // Act
        String encrypted = encryptionService.encrypt(plaintext);
        String decrypted = encryptionService.decrypt(encrypted);

        // Assert
        assertThat(decrypted).isEqualTo(plaintext);
    }
}
```

## Integration Testing with Testcontainers

### Local KMS Mock Setup

```java
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.testcontainers.containers.localstack.LocalStackContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.kms.KmsClient;
import software.amazon.awssdk.services.kms.model.*;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
import static org.testcontainers.containers.localstack.LocalStackContainer.Service.KMS;

@Testcontainers
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class KmsIntegrationTest {

    @Container
    private static final LocalStackContainer localStack =
            new LocalStackContainer(DockerImageName.parse("localstack/localstack:latest"))
                    .withServices(KMS);

    private KmsClient kmsClient;

    @BeforeAll
    void setup() {
        kmsClient = KmsClient.builder()
                .region(Region.of(localStack.getRegion()))
                .endpointOverride(localStack.getEndpointOverride(KMS))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create(localStack.getAccessKey(), localStack.getSecretKey())))
                .build();
    }

    @Test
    void shouldCreateAndManageKeysWithLocalKms() {
        // createTestKey, describeKey and listKeys are the helper methods
        // shown in the technical guide
        String keyId = createTestKey(kmsClient, "test-key");
        assertThat(keyId).isNotEmpty();

        // Describe the key
        KeyMetadata metadata = describeKey(kmsClient, keyId);
        assertThat(metadata.keyState()).isEqualTo(KeyState.ENABLED);

        // List keys
        List<KeyListEntry> keys = listKeys(kmsClient);
        assertThat(keys).hasSizeGreaterThan(0);
    }
}
```

## Testing with Spring Boot Test Slices

### KmsServiceSlice Test

```java
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.http.MediaType;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.web.servlet.MockMvc;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

// Assumes the application exposes a controller at /api/kms/encrypt that
// delegates to KmsEncryptionService
@SpringBootTest
@AutoConfigureMockMvc
@ActiveProfiles("test")
class KmsControllerIntegrationTest {

    @Autowired
    private MockMvc mockMvc;

    @MockBean
    private KmsEncryptionService kmsEncryptionService;

    @Test
    void shouldEncryptData() throws Exception {
        String plaintext = "test data";
        String encrypted = "encrypted-data";

        when(kmsEncryptionService.encrypt(plaintext)).thenReturn(encrypted);

        mockMvc.perform(post("/api/kms/encrypt")
                        .contentType(MediaType.APPLICATION_JSON)
                        .content("{\"data\":\"" + plaintext + "\"}"))
                .andExpect(status().isOk())
                .andExpect(jsonPath("$.data").value(encrypted));

        verify(kmsEncryptionService).encrypt(plaintext);
    }

    @Test
    void shouldHandleEncryptionErrors() throws Exception {
        when(kmsEncryptionService.encrypt(any()))
                .thenThrow(new RuntimeException("KMS error"));

        mockMvc.perform(post("/api/kms/encrypt")
                        .contentType(MediaType.APPLICATION_JSON)
                        .content("{\"data\":\"test\"}"))
                .andExpect(status().isInternalServerError());
    }
}
```

### Testing with SpringBootTest and Configuration

```java
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Primary;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@TestConfiguration
class KmsTestConfiguration {

    @Bean
    @Primary
    public KmsClient testKmsClient() {
        // Create a mock KMS client for testing
        KmsClient mockClient = mock(KmsClient.class);

        // Mock key creation
        when(mockClient.createKey(any(CreateKeyRequest.class)))
                .thenReturn(CreateKeyResponse.builder()
                        .keyMetadata(KeyMetadata.builder()
                                .keyId("test-key-id")
                                .arn("arn:aws:kms:us-east-1:123456789012:key/test-key-id")
                                .keyState(KeyState.ENABLED)
                                .build())
                        .build());

        // Mock encryption
        when(mockClient.encrypt(any(EncryptRequest.class)))
                .thenReturn(EncryptResponse.builder()
                        .ciphertextBlob(SdkBytes.fromString("encrypted-data", StandardCharsets.UTF_8))
                        .build());

        // Mock decryption
        when(mockClient.decrypt(any(DecryptRequest.class)))
                .thenReturn(DecryptResponse.builder()
                        .plaintext(SdkBytes.fromString("decrypted-data", StandardCharsets.UTF_8))
                        .build());

        return mockClient;
    }
}

@SpringBootTest(classes = {Application.class, KmsTestConfiguration.class})
class KmsServiceWithTestConfigIntegrationTest {

    @Autowired
    private KmsEncryptionService encryptionService;

    @Test
    void shouldUseTestConfiguration() {
        String result = encryptionService.encrypt("test");
        assertThat(result).isNotEmpty();
    }
}
```

## Testing Envelope Encryption

### Envelope Encryption Test

```java
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
class EnvelopeEncryptionServiceTest {

    @Mock
    private KmsClient kmsClient;

    @InjectMocks
    private EnvelopeEncryptionService envelopeEncryptionService;

    @Test
    void shouldEncryptAndDecryptLargeData() {
        // Arrange
        byte[] testData = "large test data".getBytes();
        byte[] encryptedDataKey = "encrypted-data-key".getBytes();
        // The service runs real AES-256-GCM, so the stubbed data key must be 32 bytes
        byte[] dataKey = new byte[32];

        // Mock data key generation
        when(kmsClient.generateDataKey(any(GenerateDataKeyRequest.class)))
                .thenReturn(GenerateDataKeyResponse.builder()
                        .plaintext(SdkBytes.fromByteArray(dataKey))
                        .ciphertextBlob(SdkBytes.fromByteArray(encryptedDataKey))
                        .build());

        // Mock data key decryption
        when(kmsClient.decrypt(any(DecryptRequest.class)))
                .thenReturn(DecryptResponse.builder()
                        .plaintext(SdkBytes.fromByteArray(dataKey))
                        .build());

        // Act
        EncryptedEnvelope encryptedEnvelope = envelopeEncryptionService.encryptLargeData(testData);
        byte[] decryptedData = envelopeEncryptionService.decryptLargeData(encryptedEnvelope);

        // Assert
        assertThat(encryptedEnvelope.encryptedData()).isNotEmpty();
        assertThat(encryptedEnvelope.encryptedKey()).isEqualTo(encryptedDataKey);
        assertThat(decryptedData).isEqualTo(testData);

        // Verify interactions
        verify(kmsClient).generateDataKey(any(GenerateDataKeyRequest.class));
        verify(kmsClient).decrypt(any(DecryptRequest.class));
    }

    @Test
    void shouldClearSensitiveDataFromMemory() {
        // Arrange
        byte[] testData = "test data".getBytes();
        byte[] encryptedDataKey = "encrypted-key".getBytes();
        byte[] dataKey = new byte[32]; // valid AES-256 key length

        when(kmsClient.generateDataKey(any(GenerateDataKeyRequest.class)))
                .thenReturn(GenerateDataKeyResponse.builder()
                        .plaintext(SdkBytes.fromByteArray(dataKey))
                        .ciphertextBlob(SdkBytes.fromByteArray(encryptedDataKey))
                        .build());

        when(kmsClient.decrypt(any(DecryptRequest.class)))
                .thenReturn(DecryptResponse.builder()
                        .plaintext(SdkBytes.fromByteArray(dataKey))
                        .build());

        // Act: round-trip through the service so both code paths run
        EncryptedEnvelope envelope = envelopeEncryptionService.encryptLargeData(testData);
        envelopeEncryptionService.decryptLargeData(envelope);

        // Note: Memory clearing is difficult to test directly.
        // In real tests, you would verify no sensitive data remains in memory traces.
    }
}
```

## Testing Digital Signatures

### Digital Signature Tests

These tests assume a thin `DigitalSignatureService` wrapper around the sign and verify calls from the technical guide, with `verifySignature` throwing a `SecurityException` when KMS reports an invalid signature.

```java
@ExtendWith(MockitoExtension.class)
class DigitalSignatureServiceTest {

    @Mock
    private KmsClient kmsClient;

    @InjectMocks
    private DigitalSignatureService signatureService;

    @Test
    void shouldSignAndVerifyData() {
        // Arrange
        String message = "test message";
        byte[] signature = "signature-data".getBytes();

        when(kmsClient.sign(any(SignRequest.class)))
                .thenReturn(SignResponse.builder()
                        .signature(SdkBytes.fromByteArray(signature))
                        .build());

        when(kmsClient.verify(any(VerifyRequest.class)))
                .thenReturn(VerifyResponse.builder()
                        .signatureValid(true)
                        .build());

        // Act
        byte[] signedSignature = signatureService.signData(message);
        boolean isValid = signatureService.verifySignature(message, signedSignature);

        // Assert
        assertThat(signedSignature).isEqualTo(signature);
        assertThat(isValid).isTrue();
    }

    @Test
    void shouldDetectInvalidSignature() {
        // Arrange
        String message = "test message";
        byte[] signature = "invalid-signature".getBytes();

        when(kmsClient.verify(any(VerifyRequest.class)))
                .thenReturn(VerifyResponse.builder()
                        .signatureValid(false)
                        .build());

        // Act & Assert
        assertThatThrownBy(() ->
                signatureService.verifySignature(message, signature))
                .isInstanceOf(SecurityException.class)
                .hasMessageContaining("Invalid signature");
    }
}
```

## Performance Testing

### Performance Test with JMH

```java
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;

import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.concurrent.TimeUnit;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

// JMH manages its own lifecycle, so the KMS client is stubbed with plain
// Mockito in a @Setup method rather than with Spring test annotations
@State(Scope.Benchmark)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Warmup(iterations = 3, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(1)
public class KmsPerformanceTest {

    private KmsEncryptionService encryptionService;
    private String encryptedPayload;

    @Setup(Level.Trial)
    public void setup() {
        KmsClient kmsClient = mock(KmsClient.class);

        when(kmsClient.encrypt(any(EncryptRequest.class)))
                .thenReturn(EncryptResponse.builder()
                        .ciphertextBlob(SdkBytes.fromString("encrypted", StandardCharsets.UTF_8))
                        .build());

        when(kmsClient.decrypt(any(DecryptRequest.class)))
                .thenReturn(DecryptResponse.builder()
                        .plaintext(SdkBytes.fromString("decrypted", StandardCharsets.UTF_8))
                        .build());

        encryptionService = new KmsEncryptionService(kmsClient);
        // The service expects Base64 input for decryption
        encryptedPayload = Base64.getEncoder()
                .encodeToString("performance-ciphertext".getBytes(StandardCharsets.UTF_8));
    }

    @Benchmark
    public void testEncryptionPerformance(Blackhole bh) {
        bh.consume(encryptionService.encrypt("performance test data with some content"));
    }

    @Benchmark
    public void testDecryptionPerformance(Blackhole bh) {
        bh.consume(encryptionService.decrypt(encryptedPayload));
    }
}
```

## Testing Error Scenarios

### Error Handling Tests

```java
@ExtendWith(MockitoExtension.class)
class KmsErrorHandlingTest {

    @Mock
    private KmsClient kmsClient;

    @InjectMocks
    private KmsEncryptionService encryptionService;

    @Test
    void shouldHandleThrottlingException() {
        // Arrange
        when(kmsClient.encrypt(any(EncryptRequest.class)))
                .thenThrow(KmsException.builder()
                        .awsErrorDetails(AwsErrorDetails.builder()
                                .errorCode("ThrottlingException")
                                .errorMessage("Rate exceeded")
                                .build())
                        .build());

        // Act & Assert: the service wraps the KmsException
        assertThatThrownBy(() -> encryptionService.encrypt("test"))
                .isInstanceOf(RuntimeException.class)
                .hasMessageContaining("Encryption failed")
                .hasCauseInstanceOf(KmsException.class);
    }

    @Test
    void shouldHandleDisabledKey() {
        // Arrange
        when(kmsClient.encrypt(any(EncryptRequest.class)))
                .thenThrow(KmsException.builder()
                        .awsErrorDetails(AwsErrorDetails.builder()
                                .errorCode("DisabledException")
                                .errorMessage("Key is disabled")
                                .build())
                        .build());

        // Act & Assert
        assertThatThrownBy(() -> encryptionService.encrypt("test"))
                .isInstanceOf(RuntimeException.class)
                .hasMessageContaining("Encryption failed")
                .hasCauseInstanceOf(KmsException.class);
    }

    @Test
    void shouldHandleNotFoundException() {
        // Arrange
        when(kmsClient.encrypt(any(EncryptRequest.class)))
                .thenThrow(KmsException.builder()
                        .awsErrorDetails(AwsErrorDetails.builder()
                                .errorCode("NotFoundException")
                                .errorMessage("Key not found")
                                .build())
                        .build());

        // Act & Assert
        assertThatThrownBy(() -> encryptionService.encrypt("test"))
                .isInstanceOf(RuntimeException.class)
                .hasMessageContaining("Encryption failed")
                .hasCauseInstanceOf(KmsException.class);
    }
}
```

## Integration Testing with AWS Local

### Testcontainers KMS Setup

```java
import org.testcontainers.containers.localstack.LocalStackContainer;
import software.amazon.awssdk.services.kms.KmsClient;
import static org.testcontainers.containers.localstack.LocalStackContainer.Service.KMS;

@Testcontainers
@SpringBootTest
class KmsAwsLocalIntegrationTest {

    @Container
    private static final LocalStackContainer localStack =
            new LocalStackContainer(DockerImageName.parse("localstack/localstack:latest"))
                    .withServices(KMS)
                    .withEnv("DEFAULT_REGION", "us-east-1");

    private KmsClient kmsClient;

    @BeforeEach
    void setup() {
        kmsClient = KmsClient.builder()
                // match the region the container was started with
                .region(Region.US_EAST_1)
                .endpointOverride(localStack.getEndpointOverride(KMS))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create(localStack.getAccessKey(), localStack.getSecretKey())))
                .build();
    }

    @Test
    void shouldCreateKeyInLocalKms() {
        // This test creates a real key in the local KMS instance
        CreateKeyRequest request = CreateKeyRequest.builder()
                .description("Test key")
                .keyUsage(KeyUsageType.ENCRYPT_DECRYPT)
                .build();

        CreateKeyResponse response = kmsClient.createKey(request);
        assertThat(response.keyMetadata().keyId()).isNotEmpty();
    }
}
```
508
skills/aws-java/aws-sdk-java-v2-lambda/SKILL.md
Normal file
@@ -0,0 +1,508 @@
---
|
||||
name: aws-sdk-java-v2-lambda
|
||||
description: AWS Lambda patterns using AWS SDK for Java 2.x. Use when invoking Lambda functions, creating/updating functions, managing function configurations, working with Lambda layers, or integrating Lambda with Spring Boot applications.
|
||||
category: aws
|
||||
tags: [aws, lambda, java, sdk, serverless, functions]
|
||||
version: 1.1.0
|
||||
allowed-tools: Read, Write, Bash
|
||||
---
|
||||
|
||||
# AWS SDK for Java 2.x - AWS Lambda
|
||||
|
||||
## When to Use
|
||||
|
||||
Use this skill when:
|
||||
- Invoking Lambda functions programmatically
|
||||
- Creating or updating Lambda functions
|
||||
- Managing Lambda function configurations
|
||||
- Working with Lambda environment variables
|
||||
- Managing Lambda layers and aliases
|
||||
- Implementing asynchronous Lambda invocations
|
||||
- Integrating Lambda with Spring Boot
|
||||
|
||||
## Overview
|
||||
|
||||
AWS Lambda is a serverless compute service that runs your code without provisioning or managing servers, scaling automatically with pay-per-use pricing. Use this skill to implement AWS Lambda operations with the AWS SDK for Java 2.x in applications and services.
|
||||
|
||||
## Dependencies
|
||||
|
||||
```xml
|
||||
<dependency>
|
||||
<groupId>software.amazon.awssdk</groupId>
|
||||
<artifactId>lambda</artifactId>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
## Client Setup
|
||||
|
||||
To use AWS Lambda, create a LambdaClient with the required region configuration:
|
||||
|
||||
```java
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.services.lambda.LambdaClient;
|
||||
|
||||
LambdaClient lambdaClient = LambdaClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
```
|
||||
|
||||
For asynchronous operations, use LambdaAsyncClient:
|
||||
|
||||
```java
|
||||
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
|
||||
|
||||
LambdaAsyncClient asyncLambdaClient = LambdaAsyncClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
```
|
||||
|
||||
## Invoke Lambda Function
|
||||
|
||||
### Synchronous Invocation
|
||||
|
||||
Invoke Lambda functions synchronously to get immediate results:
|
||||
|
||||
```java
|
||||
import software.amazon.awssdk.services.lambda.model.*;
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
|
||||
public String invokeLambda(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
String payload) {
|
||||
InvokeRequest request = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.payload(SdkBytes.fromUtf8String(payload))
|
||||
.build();
|
||||
|
||||
InvokeResponse response = lambdaClient.invoke(request);
|
||||
|
||||
return response.payload().asUtf8String();
|
||||
}
|
||||
```
|
||||
|
||||
### Asynchronous Invocation
|
||||
|
||||
Use asynchronous invocation for fire-and-forget scenarios:
|
||||
|
||||
```java
|
||||
public void invokeLambdaAsync(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
String payload) {
|
||||
InvokeRequest request = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.invocationType(InvocationType.EVENT) // Asynchronous
|
||||
.payload(SdkBytes.fromUtf8String(payload))
|
||||
.build();
|
||||
|
||||
InvokeResponse response = lambdaClient.invoke(request);
|
||||
|
||||
System.out.println("Status: " + response.statusCode());
|
||||
}
|
||||
```
|
||||
|
||||
### Invoke with JSON Objects
|
||||
|
||||
Work with JSON payloads for complex data structures:
|
||||
|
||||
```java
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
public <T> String invokeLambdaWithObject(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
T requestObject) throws Exception {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
String jsonPayload = mapper.writeValueAsString(requestObject);
|
||||
|
||||
InvokeRequest request = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.payload(SdkBytes.fromUtf8String(jsonPayload))
|
||||
.build();
|
||||
|
||||
InvokeResponse response = lambdaClient.invoke(request);
|
||||
|
||||
return response.payload().asUtf8String();
|
||||
}
|
||||
```
|
||||
|
||||
### Parse Typed Responses
|
||||
|
||||
Parse JSON responses into typed objects:
|
||||
|
||||
```java
|
||||
public <T> T invokeLambdaAndParse(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
Object request,
|
||||
Class<T> responseType) throws Exception {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
String jsonPayload = mapper.writeValueAsString(request);
|
||||
|
||||
InvokeRequest invokeRequest = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.payload(SdkBytes.fromUtf8String(jsonPayload))
|
||||
.build();
|
||||
|
||||
InvokeResponse response = lambdaClient.invoke(invokeRequest);
|
||||
|
||||
String responseJson = response.payload().asUtf8String();
|
||||
|
||||
return mapper.readValue(responseJson, responseType);
|
||||
}
|
||||
```
|
||||
|
||||
## Function Management
|
||||
|
||||
### List Functions
|
||||
|
||||
List all Lambda functions for the current account:
|
||||
|
||||
```java
|
||||
public List<FunctionConfiguration> listFunctions(LambdaClient lambdaClient) {
|
||||
ListFunctionsResponse response = lambdaClient.listFunctions();
|
||||
|
||||
return response.functions();
|
||||
}
|
||||
```
|
||||
|
||||
### Get Function Configuration
|
||||
|
||||
Retrieve function configuration and metadata:
|
||||
|
||||
```java
|
||||
public FunctionConfiguration getFunctionConfig(LambdaClient lambdaClient,
|
||||
String functionName) {
|
||||
GetFunctionRequest request = GetFunctionRequest.builder()
|
||||
.functionName(functionName)
|
||||
.build();
|
||||
|
||||
GetFunctionResponse response = lambdaClient.getFunction(request);
|
||||
|
||||
return response.configuration();
|
||||
}
|
||||
```
|
||||
|
||||
### Update Function Code
|
||||
|
||||
Update Lambda function code with new deployment package:
|
||||
|
||||
```java
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
|
||||
public void updateFunctionCode(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
String zipFilePath) throws IOException {
|
||||
byte[] zipBytes = Files.readAllBytes(Paths.get(zipFilePath));
|
||||
|
||||
UpdateFunctionCodeRequest request = UpdateFunctionCodeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.zipFile(SdkBytes.fromByteArray(zipBytes))
|
||||
.publish(true)
|
||||
.build();
|
||||
|
||||
UpdateFunctionCodeResponse response = lambdaClient.updateFunctionCode(request);
|
||||
|
||||
System.out.println("Updated function version: " + response.version());
|
||||
}
|
||||
```
|
||||
|
||||
### Update Function Configuration
|
||||
|
||||
Modify function settings like timeout, memory, and environment variables:
|
||||
|
||||
```java
|
||||
public void updateFunctionConfiguration(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
Map<String, String> environment) {
|
||||
Environment env = Environment.builder()
|
||||
.variables(environment)
|
||||
.build();
|
||||
|
||||
UpdateFunctionConfigurationRequest request = UpdateFunctionConfigurationRequest.builder()
|
||||
.functionName(functionName)
|
||||
.environment(env)
|
||||
.timeout(60)
|
||||
.memorySize(512)
|
||||
.build();
|
||||
|
||||
lambdaClient.updateFunctionConfiguration(request);
|
||||
}
|
||||
```
|
||||
|
||||
### Create Function
|
||||
|
||||
Create new Lambda functions with code and configuration:
|
||||
|
||||
```java
|
||||
public void createFunction(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
String roleArn,
|
||||
String handler,
|
||||
String zipFilePath) throws IOException {
|
||||
byte[] zipBytes = Files.readAllBytes(Paths.get(zipFilePath));
|
||||
|
||||
FunctionCode code = FunctionCode.builder()
|
||||
.zipFile(SdkBytes.fromByteArray(zipBytes))
|
||||
.build();
|
||||
|
||||
CreateFunctionRequest request = CreateFunctionRequest.builder()
|
||||
.functionName(functionName)
|
||||
.runtime(Runtime.JAVA17)
|
||||
.role(roleArn)
|
||||
.handler(handler)
|
||||
.code(code)
|
||||
.timeout(60)
|
||||
.memorySize(512)
|
||||
.build();
|
||||
|
||||
CreateFunctionResponse response = lambdaClient.createFunction(request);
|
||||
|
||||
System.out.println("Function ARN: " + response.functionArn());
|
||||
}
|
||||
```
|
||||
|
||||
### Delete Function
|
||||
|
||||
Remove Lambda functions when no longer needed:
|
||||
|
||||
```java
|
||||
public void deleteFunction(LambdaClient lambdaClient, String functionName) {
|
||||
DeleteFunctionRequest request = DeleteFunctionRequest.builder()
|
||||
.functionName(functionName)
|
||||
.build();
|
||||
|
||||
lambdaClient.deleteFunction(request);
|
||||
}
|
||||
```
|
||||
|
||||
## Spring Boot Integration
|
||||
|
||||
### Configuration
|
||||
|
||||
Configure Lambda clients as Spring beans:
|
||||
|
||||
```java
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
public class LambdaConfiguration {
|
||||
|
||||
@Bean
|
||||
public LambdaClient lambdaClient() {
|
||||
return LambdaClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Lambda Invoker Service
|
||||
|
||||
Create a service for Lambda function invocation:
|
||||
|
||||
```java
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
|
||||
@Service
|
||||
public class LambdaInvokerService {
|
||||
|
||||
private final LambdaClient lambdaClient;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
@Autowired
|
||||
public LambdaInvokerService(LambdaClient lambdaClient, ObjectMapper objectMapper) {
|
||||
this.lambdaClient = lambdaClient;
|
||||
this.objectMapper = objectMapper;
|
||||
}
|
||||
|
||||
public <T, R> R invoke(String functionName, T request, Class<R> responseType) {
|
||||
try {
|
||||
String jsonPayload = objectMapper.writeValueAsString(request);
|
||||
|
||||
InvokeRequest invokeRequest = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.payload(SdkBytes.fromUtf8String(jsonPayload))
|
||||
.build();
|
||||
|
||||
InvokeResponse response = lambdaClient.invoke(invokeRequest);
|
||||
|
||||
if (response.functionError() != null) {
|
||||
throw new LambdaInvocationException(
|
||||
"Lambda function error: " + response.functionError());
|
||||
}
|
||||
|
||||
String responseJson = response.payload().asUtf8String();
|
||||
|
||||
return objectMapper.readValue(responseJson, responseType);
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to invoke Lambda function", e);
|
||||
}
|
||||
}
|
||||
|
||||
public void invokeAsync(String functionName, Object request) {
|
||||
try {
|
||||
String jsonPayload = objectMapper.writeValueAsString(request);
|
||||
|
||||
InvokeRequest invokeRequest = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.invocationType(InvocationType.EVENT)
|
||||
.payload(SdkBytes.fromUtf8String(jsonPayload))
|
||||
.build();
|
||||
|
||||
lambdaClient.invoke(invokeRequest);
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to invoke Lambda function async", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Typed Lambda Client
|
||||
|
||||
Create type-safe interfaces for Lambda services:
|
||||
|
||||
```java
|
||||
public interface OrderProcessor {
|
||||
OrderResponse processOrder(OrderRequest request);
|
||||
}
|
||||
|
||||
@Service
|
||||
public class LambdaOrderProcessor implements OrderProcessor {
|
||||
|
||||
private final LambdaInvokerService lambdaInvoker;
|
||||
|
||||
@Value("${lambda.order-processor.function-name}")
|
||||
private String functionName;
|
||||
|
||||
public LambdaOrderProcessor(LambdaInvokerService lambdaInvoker) {
|
||||
this.lambdaInvoker = lambdaInvoker;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OrderResponse processOrder(OrderRequest request) {
|
||||
return lambdaInvoker.invoke(functionName, request, OrderResponse.class);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
Implement comprehensive error handling for Lambda operations:
|
||||
|
||||
```java
|
||||
public String invokeLambdaSafe(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
String payload) {
|
||||
try {
|
||||
InvokeRequest request = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.payload(SdkBytes.fromUtf8String(payload))
|
||||
.build();
|
||||
|
||||
InvokeResponse response = lambdaClient.invoke(request);
|
||||
|
||||
// Check for function error
|
||||
if (response.functionError() != null) {
|
||||
String errorMessage = response.payload().asUtf8String();
|
||||
throw new RuntimeException("Lambda error: " + errorMessage);
|
||||
}
|
||||
|
||||
// Check status code
|
||||
if (response.statusCode() != 200) {
|
||||
throw new RuntimeException("Lambda invocation failed with status: " +
|
||||
response.statusCode());
|
||||
}
|
||||
|
||||
return response.payload().asUtf8String();
|
||||
|
||||
} catch (LambdaException e) {
|
||||
System.err.println("Lambda error: " + e.awsErrorDetails().errorMessage());
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
public class LambdaInvocationException extends RuntimeException {
|
||||
public LambdaInvocationException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public LambdaInvocationException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
For comprehensive code examples, see the references section:
|
||||
|
||||
- **Basic examples** - Simple invocation patterns and function management
|
||||
- **Spring Boot integration** - Complete Spring Boot configuration and service patterns
|
||||
- **Testing examples** - Unit and integration test patterns
|
||||
- **Advanced patterns** - Complex scenarios and best practices
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Reuse Lambda clients**: Create once and reuse across invocations
|
||||
2. **Set appropriate timeouts**: Match client timeout to Lambda function timeout (see the sketch after this list)
|
||||
3. **Use async invocation**: For fire-and-forget scenarios
|
||||
4. **Handle errors properly**: Check for function errors and status codes
|
||||
5. **Use environment variables**: For function configuration
|
||||
6. **Implement retry logic**: For transient failures (see the sketch after this list)
|
||||
7. **Monitor invocations**: Use CloudWatch metrics
|
||||
8. **Version functions**: Use aliases and versions for production
|
||||
9. **Use VPC**: For accessing resources in private subnets
|
||||
10. **Optimize payload size**: Keep payloads small for better performance
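Practices 2 and 6 can be baked into the client itself. A minimal sketch, assuming the SDK's default retry behavior as a base; the 65-second budget and three retries are illustrative values, not recommendations:

```java
import java.time.Duration;

import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaClient;

LambdaClient tunedClient = LambdaClient.builder()
        .region(Region.US_EAST_1)
        .overrideConfiguration(ClientOverrideConfiguration.builder()
                // Abort the call once the function timeout (60s) plus slack is exceeded
                .apiCallTimeout(Duration.ofSeconds(65))
                // Retry throttling and transient failures a bounded number of times
                .retryPolicy(RetryPolicy.builder().numRetries(3).build())
                .build())
        .build();
```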
|
||||
|
||||
## Testing
|
||||
|
||||
Test Lambda services using mocks and test assertions:
|
||||
|
||||
```java
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.InjectMocks;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class LambdaInvokerServiceTest {
|
||||
|
||||
@Mock
|
||||
private LambdaClient lambdaClient;
|
||||
|
||||
@Mock
|
||||
private ObjectMapper objectMapper;
|
||||
|
||||
@InjectMocks
|
||||
private LambdaInvokerService service;
|
||||
|
||||
@Test
|
||||
void shouldInvokeLambdaSuccessfully() throws Exception {
|
||||
// See references/examples.md for the full arrange/act/assert implementation
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Related Skills
|
||||
|
||||
- @aws-sdk-java-v2-core - Core AWS SDK patterns and client configuration
|
||||
- @spring-boot-dependency-injection - Spring dependency injection best practices
|
||||
- @unit-test-service-layer - Service testing patterns with Mockito
|
||||
- @spring-boot-actuator - Production monitoring and health checks
|
||||
|
||||
## References
|
||||
|
||||
For detailed information and examples, see the following reference files:
|
||||
|
||||
- **[Official Documentation](references/official-documentation.md)** - AWS Lambda concepts, API reference, and official guidance
|
||||
- **[Examples](references/examples.md)** - Complete code examples and integration patterns
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Lambda Examples on GitHub](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/example_code/lambda)
|
||||
- [Lambda API Reference](https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/lambda/package-summary.html)
|
||||
- [AWS Lambda Developer Guide](https://docs.aws.amazon.com/lambda/latest/dg/welcome.html)
|
||||
544
skills/aws-java/aws-sdk-java-v2-lambda/references/examples.md
Normal file
544
skills/aws-java/aws-sdk-java-v2-lambda/references/examples.md
Normal file
@@ -0,0 +1,544 @@
|
||||
# AWS Lambda Java SDK Examples
|
||||
|
||||
## Client Setup
|
||||
|
||||
### Basic Client Configuration
|
||||
```java
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.services.lambda.LambdaClient;
|
||||
|
||||
// Create synchronous client
|
||||
LambdaClient lambdaClient = LambdaClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
|
||||
// Create asynchronous client
|
||||
LambdaAsyncClient asyncLambdaClient = LambdaAsyncClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
```
|
||||
|
||||
### Client with Configuration
|
||||
```java
|
||||
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
|
||||
import software.amazon.awssdk.http.apache.ApacheHttpClient;
|
||||
|
||||
LambdaClient lambdaClient = LambdaClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.credentialsProvider(DefaultCredentialsProvider.create())
|
||||
.httpClientBuilder(ApacheHttpClient.builder())
|
||||
.build();
|
||||
```
|
||||
|
||||
## Function Invocation Examples
|
||||
|
||||
### Synchronous Invocation with String Payload
|
||||
```java
|
||||
import software.amazon.awssdk.services.lambda.model.*;
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
|
||||
public String invokeLambdaSync(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
String payload) {
|
||||
InvokeRequest request = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.payload(SdkBytes.fromUtf8String(payload))
|
||||
.build();
|
||||
|
||||
InvokeResponse response = lambdaClient.invoke(request);
|
||||
|
||||
// Check for function errors
|
||||
if (response.functionError() != null) {
|
||||
throw new RuntimeException("Lambda function error: " +
|
||||
response.payload().asUtf8String());
|
||||
}
|
||||
|
||||
return response.payload().asUtf8String();
|
||||
}
|
||||
```
|
||||
|
||||
### Asynchronous Invocation
|
||||
```java
|
||||
import java.util.concurrent.CompletableFuture;

// Requires the non-blocking client: LambdaAsyncClient.invoke returns a
// CompletableFuture, whereas the synchronous LambdaClient.invoke blocks.
public CompletableFuture<String> invokeLambdaAsync(LambdaAsyncClient lambdaAsyncClient,
                                                   String functionName,
                                                   String payload) {
    InvokeRequest request = InvokeRequest.builder()
            .functionName(functionName)
            .invocationType(InvocationType.EVENT) // Asynchronous
            .payload(SdkBytes.fromUtf8String(payload))
            .build();

    // EVENT invocations are accepted with HTTP 202 and an empty payload
    return lambdaAsyncClient.invoke(request)
            .thenApply(response -> response.payload().asUtf8String());
}
|
||||
```
|
||||
|
||||
### Invocation with JSON Object
|
||||
```java
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
public <T> String invokeLambdaWithObject(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
T requestObject) throws Exception {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
String jsonPayload = mapper.writeValueAsString(requestObject);
|
||||
|
||||
InvokeRequest request = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.payload(SdkBytes.fromUtf8String(jsonPayload))
|
||||
.build();
|
||||
|
||||
InvokeResponse response = lambdaClient.invoke(request);
|
||||
|
||||
return response.payload().asUtf8String();
|
||||
}
|
||||
```
|
||||
|
||||
### Parse Typed Response
|
||||
```java
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
public <T> T invokeLambdaAndParse(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
Object request,
|
||||
Class<T> responseType) throws Exception {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
String jsonPayload = mapper.writeValueAsString(request);
|
||||
|
||||
InvokeRequest invokeRequest = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.payload(SdkBytes.fromUtf8String(jsonPayload))
|
||||
.build();
|
||||
|
||||
InvokeResponse response = lambdaClient.invoke(invokeRequest);
|
||||
String responseJson = response.payload().asUtf8String();
|
||||
|
||||
return mapper.readValue(responseJson, responseType);
|
||||
}
|
||||
```
|
||||
|
||||
## Function Management Examples
|
||||
|
||||
### List Functions
|
||||
```java
|
||||
public List<FunctionConfiguration> listLambdaFunctions(LambdaClient lambdaClient) {
|
||||
ListFunctionsResponse response = lambdaClient.listFunctions();
|
||||
return response.functions();
|
||||
}
|
||||
|
||||
// List all functions, letting the SDK paginator follow nextMarker internally
public List<FunctionConfiguration> listAllFunctions(LambdaClient lambdaClient) {
    return lambdaClient.listFunctionsPaginator().stream()
            .flatMap(page -> page.functions().stream())
            .collect(Collectors.toList());
}
|
||||
```
|
||||
|
||||
### Get Function Configuration
|
||||
```java
|
||||
public FunctionConfiguration getFunctionConfig(LambdaClient lambdaClient,
|
||||
String functionName) {
|
||||
GetFunctionRequest request = GetFunctionRequest.builder()
|
||||
.functionName(functionName)
|
||||
.build();
|
||||
|
||||
GetFunctionResponse response = lambdaClient.getFunction(request);
|
||||
return response.configuration();
|
||||
}
|
||||
```
|
||||
|
||||
### Get Function Code
|
||||
```java
|
||||
public byte[] getFunctionCode(LambdaClient lambdaClient,
|
||||
String functionName) {
|
||||
GetFunctionRequest request = GetFunctionRequest.builder()
|
||||
.functionName(functionName)
|
||||
.build();
|
||||
|
||||
GetFunctionResponse response = lambdaClient.getFunction(request);
|
||||
return response.code().zipFile().asByteArray();
|
||||
}
|
||||
```
|
||||
|
||||
### Update Function Code
|
||||
```java
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
|
||||
public void updateLambdaFunction(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
String zipFilePath) throws IOException {
|
||||
byte[] zipBytes = Files.readAllBytes(Paths.get(zipFilePath));
|
||||
|
||||
UpdateFunctionCodeRequest request = UpdateFunctionCodeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.zipFile(SdkBytes.fromByteArray(zipBytes))
|
||||
.publish(true) // Create new version
|
||||
.build();
|
||||
|
||||
UpdateFunctionCodeResponse response = lambdaClient.updateFunctionCode(request);
|
||||
System.out.println("Updated function version: " + response.version());
|
||||
}
|
||||
```
|
||||
|
||||
### Update Function Configuration
|
||||
```java
|
||||
public void updateFunctionConfig(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
Map<String, String> environment) {
|
||||
Environment env = Environment.builder()
|
||||
.variables(environment)
|
||||
.build();
|
||||
|
||||
UpdateFunctionConfigurationRequest request = UpdateFunctionConfigurationRequest.builder()
|
||||
.functionName(functionName)
|
||||
.environment(env)
|
||||
.timeout(60)
|
||||
.memorySize(512)
|
||||
.build();
|
||||
|
||||
lambdaClient.updateFunctionConfiguration(request);
|
||||
}
|
||||
```
|
||||
|
||||
### Create Function
|
||||
```java
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
|
||||
public void createLambdaFunction(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
String roleArn,
|
||||
String handler,
|
||||
String zipFilePath) throws IOException {
|
||||
byte[] zipBytes = Files.readAllBytes(Paths.get(zipFilePath));
|
||||
|
||||
FunctionCode code = FunctionCode.builder()
|
||||
.zipFile(SdkBytes.fromByteArray(zipBytes))
|
||||
.build();
|
||||
|
||||
CreateFunctionRequest request = CreateFunctionRequest.builder()
|
||||
.functionName(functionName)
|
||||
.runtime(Runtime.JAVA17)
|
||||
.role(roleArn)
|
||||
.handler(handler)
|
||||
.code(code)
|
||||
.timeout(60)
|
||||
.memorySize(512)
|
||||
.environment(Environment.builder()
|
||||
.variables(Map.of("ENV", "production"))
|
||||
.build())
|
||||
.build();
|
||||
|
||||
CreateFunctionResponse response = lambdaClient.createFunction(request);
|
||||
System.out.println("Function ARN: " + response.functionArn());
|
||||
}
|
||||
```
|
||||
|
||||
### Delete Function
|
||||
```java
|
||||
public void deleteLambdaFunction(LambdaClient lambdaClient, String functionName) {
|
||||
DeleteFunctionRequest request = DeleteFunctionRequest.builder()
|
||||
.functionName(functionName)
|
||||
.build();
|
||||
|
||||
lambdaClient.deleteFunction(request);
|
||||
}
|
||||
```
|
||||
|
||||
## Spring Boot Integration Examples
|
||||
|
||||
### Configuration Class
|
||||
```java
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
public class LambdaConfiguration {
|
||||
|
||||
@Bean
|
||||
public LambdaClient lambdaClient() {
|
||||
return LambdaClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public LambdaAsyncClient asyncLambdaClient() {
|
||||
return LambdaAsyncClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Lambda Invoker Service
|
||||
```java
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
|
||||
@Service
|
||||
public class LambdaInvokerService {
|
||||
|
||||
private final LambdaClient lambdaClient;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
@Autowired
|
||||
public LambdaInvokerService(LambdaClient lambdaClient,
|
||||
ObjectMapper objectMapper) {
|
||||
this.lambdaClient = lambdaClient;
|
||||
this.objectMapper = objectMapper;
|
||||
}
|
||||
|
||||
public <T, R> R invokeFunction(String functionName,
|
||||
T request,
|
||||
Class<R> responseType) {
|
||||
try {
|
||||
String jsonPayload = objectMapper.writeValueAsString(request);
|
||||
|
||||
InvokeRequest invokeRequest = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.payload(SdkBytes.fromUtf8String(jsonPayload))
|
||||
.build();
|
||||
|
||||
InvokeResponse response = lambdaClient.invoke(invokeRequest);
|
||||
|
||||
if (response.functionError() != null) {
|
||||
throw new LambdaInvocationException(
|
||||
"Lambda function error: " + response.functionError());
|
||||
}
|
||||
|
||||
String responseJson = response.payload().asUtf8String();
|
||||
return objectMapper.readValue(responseJson, responseType);
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to invoke Lambda function", e);
|
||||
}
|
||||
}
|
||||
|
||||
public void invokeFunctionAsync(String functionName, Object request) {
|
||||
try {
|
||||
String jsonPayload = objectMapper.writeValueAsString(request);
|
||||
|
||||
InvokeRequest invokeRequest = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.invocationType(InvocationType.EVENT)
|
||||
.payload(SdkBytes.fromUtf8String(jsonPayload))
|
||||
.build();
|
||||
|
||||
lambdaClient.invoke(invokeRequest);
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to invoke Lambda function async", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Typed Lambda Client Interface
|
||||
```java
|
||||
public interface OrderProcessor {
|
||||
OrderResponse processOrder(OrderRequest request);
|
||||
CompletableFuture<OrderResponse> processOrderAsync(OrderRequest request);
|
||||
}
|
||||
|
||||
@Service
|
||||
public class LambdaOrderProcessor implements OrderProcessor {
|
||||
|
||||
private final LambdaInvokerService lambdaInvoker;
|
||||
private final LambdaAsyncClient asyncLambdaClient;
|
||||
|
||||
@Value("${lambda.order-processor.function-name}")
|
||||
private String functionName;
|
||||
|
||||
public LambdaOrderProcessor(LambdaInvokerService lambdaInvoker,
|
||||
LambdaAsyncClient asyncLambdaClient) {
|
||||
this.lambdaInvoker = lambdaInvoker;
|
||||
this.asyncLambdaClient = asyncLambdaClient;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OrderResponse processOrder(OrderRequest request) {
|
||||
return lambdaInvoker.invokeFunction(functionName, request, OrderResponse.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletableFuture<OrderResponse> processOrderAsync(OrderRequest request) {
|
||||
// Implement async invocation using async client
|
||||
try {
|
||||
String jsonPayload = new ObjectMapper().writeValueAsString(request);
|
||||
|
||||
InvokeRequest invokeRequest = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.payload(SdkBytes.fromUtf8String(jsonPayload))
|
||||
.build();
|
||||
|
||||
return asyncLambdaClient.invoke(invokeRequest)
|
||||
.thenApply(response -> {
|
||||
try {
|
||||
return new ObjectMapper().readValue(
|
||||
response.payload().asUtf8String(),
|
||||
OrderResponse.class);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to parse response", e);
|
||||
}
|
||||
});
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to invoke Lambda function", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling Examples
|
||||
|
||||
### Comprehensive Error Handling
|
||||
```java
|
||||
public String invokeLambdaWithFullErrorHandling(LambdaClient lambdaClient,
|
||||
String functionName,
|
||||
String payload) {
|
||||
try {
|
||||
InvokeRequest request = InvokeRequest.builder()
|
||||
.functionName(functionName)
|
||||
.payload(SdkBytes.fromUtf8String(payload))
|
||||
.build();
|
||||
|
||||
InvokeResponse response = lambdaClient.invoke(request);
|
||||
|
||||
// Check for function error
|
||||
if (response.functionError() != null) {
|
||||
String errorMessage = response.payload().asUtf8String();
|
||||
throw new LambdaInvocationException(
|
||||
"Lambda function error: " + errorMessage);
|
||||
}
|
||||
|
||||
// Check status code
|
||||
if (response.statusCode() != 200) {
|
||||
throw new LambdaInvocationException(
|
||||
"Lambda invocation failed with status: " + response.statusCode());
|
||||
}
|
||||
|
||||
return response.payload().asUtf8String();
|
||||
|
||||
} catch (LambdaException e) {
|
||||
System.err.println("AWS Lambda error: " + e.awsErrorDetails().errorMessage());
|
||||
throw new LambdaInvocationException(
|
||||
"AWS Lambda service error: " + e.awsErrorDetails().errorMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
public class LambdaInvocationException extends RuntimeException {
|
||||
public LambdaInvocationException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public LambdaInvocationException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Testing Examples
|
||||
|
||||
### Unit Test for Lambda Service
|
||||
```java
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.InjectMocks;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.assertj.core.api.Assertions.*;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class LambdaInvokerServiceTest {
|
||||
|
||||
@Mock
|
||||
private LambdaClient lambdaClient;
|
||||
|
||||
@Mock
|
||||
private ObjectMapper objectMapper;
|
||||
|
||||
@InjectMocks
|
||||
private LambdaInvokerService service;
|
||||
|
||||
@Test
|
||||
void shouldInvokeLambdaSuccessfully() throws Exception {
|
||||
// Given
|
||||
OrderRequest request = new OrderRequest("ORDER-123");
|
||||
OrderResponse expectedResponse = new OrderResponse("SUCCESS");
|
||||
String jsonPayload = "{\"orderId\":\"ORDER-123\"};
|
||||
String jsonResponse = "{\"status\":\"SUCCESS\"};
|
||||
|
||||
when(objectMapper.writeValueAsString(request))
|
||||
.thenReturn(jsonPayload);
|
||||
|
||||
when(lambdaClient.invoke(any(InvokeRequest.class)))
|
||||
.thenReturn(InvokeResponse.builder()
|
||||
.statusCode(200)
|
||||
.payload(SdkBytes.fromUtf8String(jsonResponse))
|
||||
.build());
|
||||
|
||||
when(objectMapper.readValue(jsonResponse, OrderResponse.class))
|
||||
.thenReturn(expectedResponse);
|
||||
|
||||
// When
|
||||
OrderResponse result = service.invokeFunction(
|
||||
"order-processor", request, OrderResponse.class);
|
||||
|
||||
// Then
|
||||
assertThat(result).isEqualTo(expectedResponse);
|
||||
verify(lambdaClient).invoke(any(InvokeRequest.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldHandleFunctionError() throws Exception {
|
||||
// Given
|
||||
OrderRequest request = new OrderRequest("ORDER-123");
|
||||
String jsonPayload = "{\"orderId\":\"ORDER-123\"};
|
||||
String errorResponse = "{\"errorMessage\":\"Invalid input\"};
|
||||
|
||||
when(objectMapper.writeValueAsString(request))
|
||||
.thenReturn(jsonPayload);
|
||||
|
||||
when(lambdaClient.invoke(any(InvokeRequest.class)))
|
||||
.thenReturn(InvokeResponse.builder()
|
||||
.statusCode(200)
|
||||
.functionError("Unhandled")
|
||||
.payload(SdkBytes.fromUtf8String(errorResponse))
|
||||
.build());
|
||||
|
||||
// When & Then
|
||||
assertThatThrownBy(() ->
|
||||
service.invoke("order-processor", request, OrderResponse.class))
|
||||
.isInstanceOf(LambdaInvocationException.class)
|
||||
.hasMessageContaining("Lambda function error");
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Maven Dependencies
|
||||
```xml
|
||||
<!-- AWS SDK for Java v2 Lambda -->
|
||||
<dependency>
|
||||
<groupId>software.amazon.awssdk</groupId>
|
||||
<artifactId>lambda</artifactId>
|
||||
<version>2.36.3</version> <!-- use the latest version available -->
|
||||
</dependency>
|
||||
|
||||
<!-- Jackson for JSON processing -->
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- Spring Boot support -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
</dependency>
|
||||
```
|
||||
@@ -0,0 +1,112 @@
|
||||
# AWS Lambda Official Documentation Reference
|
||||
|
||||
## Overview
|
||||
AWS Lambda is a serverless compute service that runs your code without provisioning or managing servers, scaling automatically with pay-per-use pricing.
|
||||
|
||||
## Common Use Cases
|
||||
- Stream processing: Process real-time data streams for analytics
|
||||
- Web applications: Build scalable web apps that automatically adjust
|
||||
- Mobile backends: Create secure API backends
|
||||
- IoT backends: Handle web, mobile, IoT, and third-party API requests
|
||||
- File processing: Process files automatically when uploaded
|
||||
- Database operations: Respond to database changes and automate data workflows
|
||||
- Scheduled tasks: Run automated operations on a regular schedule
|
||||
|
||||
## How Lambda Works
|
||||
1. You write and organize your code in Lambda functions
|
||||
2. You control security through Lambda permissions using execution roles
|
||||
3. Event sources and AWS services trigger your Lambda functions
|
||||
4. Lambda runs your code with language-specific runtimes
|
||||
|
||||
## Key Features
|
||||
|
||||
### Configuration & Security
|
||||
- Environment variables modify behavior without deployments
|
||||
- Versions safely test new features while maintaining stable production
|
||||
- Lambda layers optimize code reuse across multiple functions
|
||||
- Code signing ensures only approved code reaches production
|
||||
|
||||
### Performance
|
||||
- Concurrency controls manage responsiveness and resource utilization
|
||||
- Lambda SnapStart reduces cold start times to sub-second performance
|
||||
- Response streaming delivers large payloads incrementally
|
||||
- Container images package functions with complex dependencies
|
||||
|
||||
### Integration
|
||||
- VPC networks secure sensitive resources and internal services
|
||||
- File system integration shares persistent data across function invocations
|
||||
- Function URLs create public APIs without additional services
|
||||
- Lambda extensions augment functions with monitoring and operational tools
|
||||
|
||||
## AWS Lambda Java SDK API
|
||||
|
||||
### Key Classes
|
||||
- `LambdaClient` - Synchronous service client
|
||||
- `LambdaAsyncClient` - Asynchronous service client
|
||||
- `LambdaClientBuilder` - Builder for synchronous client
|
||||
- `LambdaAsyncClientBuilder` - Builder for asynchronous client
|
||||
- `LambdaServiceClientConfiguration` - Client settings configuration
|
||||
|
||||
### Related Packages
|
||||
- `software.amazon.awssdk.services.lambda.model` - API models
|
||||
- `software.amazon.awssdk.services.lambda.transform` - Request/response transformations
|
||||
- `software.amazon.awssdk.services.lambda.paginators` - Pagination utilities
|
||||
- `software.amazon.awssdk.services.lambda.waiters` - Waiter utilities
|
||||
|
||||
### Authentication
|
||||
Lambda supports signature version 4 for API authentication.
|
||||
|
||||
### CA Requirements
|
||||
Clients must trust the following certificate authorities:
|
||||
- Amazon Root CA 1
|
||||
- Starfield Services Root Certificate Authority - G2
|
||||
- Starfield Class 2 Certification Authority
|
||||
|
||||
## Core API Operations
|
||||
|
||||
### Function Management Operations
|
||||
- `CreateFunction` - Create new Lambda function
|
||||
- `DeleteFunction` - Delete existing function
|
||||
- `GetFunction` - Retrieve function configuration
|
||||
- `UpdateFunctionCode` - Update function code
|
||||
- `UpdateFunctionConfiguration` - Update function settings
|
||||
- `ListFunctions` - List functions for account
|
||||
|
||||
### Invocation Operations
|
||||
- `Invoke` - Invoke Lambda function synchronously
|
||||
- `Invoke` with `InvocationType.EVENT` - Asynchronous invocation
|
||||
|
||||
### Environment & Configuration
|
||||
- Environment variable management
|
||||
- Function configuration updates
|
||||
- Version and alias management (see the sketch below)
|
||||
- Layer management
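Version and alias management maps to the `PublishVersion` and `CreateAlias` operations. A minimal sketch; the function and alias names are illustrative:

```java
// Freeze the current code and configuration as an immutable, numbered version
PublishVersionResponse published = lambdaClient.publishVersion(
        PublishVersionRequest.builder().functionName("my-function").build());

// Point a stable "prod" alias at that version; callers invoke the alias name
lambdaClient.createAlias(CreateAliasRequest.builder()
        .functionName("my-function")
        .name("prod")
        .functionVersion(published.version())
        .build());
```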
|
||||
|
||||
## Examples Overview
|
||||
The AWS documentation includes examples for:
|
||||
- Basic Lambda function creation and invocation
|
||||
- Function configuration and updates
|
||||
- Environment variable management
|
||||
- Function listing and cleanup
|
||||
- Integration patterns
|
||||
|
||||
## Best Practices from Official Docs
|
||||
- Reuse Lambda clients across invocations
|
||||
- Set appropriate timeouts matching function requirements
|
||||
- Use async invocation for fire-and-forget scenarios
|
||||
- Implement proper error handling for function errors and status codes
|
||||
- Use environment variables for configuration management
|
||||
- Version functions for production stability
|
||||
- Monitor invocations using CloudWatch metrics
|
||||
- Implement retry logic for transient failures
|
||||
- Use VPC integration for private resources
|
||||
- Optimize payload sizes for performance
|
||||
|
||||
## Security Considerations
|
||||
- Use IAM roles with least privilege
|
||||
- Implement proper Lambda permissions
|
||||
- Use environment variables for sensitive data
|
||||
- Enable CloudTrail logging
|
||||
- Monitor security events with CloudWatch
|
||||
- Use code signing for production deployments
|
||||
- Implement proper authentication and authorization
|
||||
310
skills/aws-java/aws-sdk-java-v2-messaging/SKILL.md
Normal file
310
skills/aws-java/aws-sdk-java-v2-messaging/SKILL.md
Normal file
@@ -0,0 +1,310 @@
|
||||
---
|
||||
name: aws-sdk-java-v2-messaging
|
||||
description: Implement AWS messaging patterns using AWS SDK for Java 2.x for SQS queues and SNS topics. Send/receive messages, manage FIFO queues, implement DLQ, publish messages, manage subscriptions, and build pub/sub patterns.
|
||||
category: aws
|
||||
tags: [aws, sqs, sns, java, sdk, messaging, pub-sub, queues, events]
|
||||
version: 1.1.0
|
||||
allowed-tools: Read, Write, Bash, Grep
|
||||
---
|
||||
|
||||
# AWS SDK for Java 2.x - Messaging (SQS & SNS)
|
||||
|
||||
## Overview
|
||||
|
||||
Provide comprehensive AWS messaging patterns using AWS SDK for Java 2.x for both SQS and SNS services. Include client setup, queue management, message operations, subscription management, and Spring Boot integration patterns.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use this skill when working with:
|
||||
- Amazon SQS queues for message queuing
|
||||
- SNS topics for event publishing and notification
|
||||
- FIFO queues and standard queues
|
||||
- Dead Letter Queues (DLQ) for message handling
|
||||
- SNS subscriptions with email, SMS, SQS, Lambda endpoints
|
||||
- Pub/sub messaging patterns and event-driven architectures
|
||||
- Spring Boot integration with AWS messaging services
|
||||
- Testing strategies using LocalStack or Testcontainers
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Dependencies
|
||||
|
||||
```xml
|
||||
<!-- SQS -->
|
||||
<dependency>
|
||||
<groupId>software.amazon.awssdk</groupId>
|
||||
<artifactId>sqs</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- SNS -->
|
||||
<dependency>
|
||||
<groupId>software.amazon.awssdk</groupId>
|
||||
<artifactId>sns</artifactId>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
### Basic Client Setup
|
||||
|
||||
```java
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.services.sqs.SqsClient;
|
||||
import software.amazon.awssdk.services.sns.SnsClient;
|
||||
|
||||
SqsClient sqsClient = SqsClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
|
||||
SnsClient snsClient = SnsClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
### Basic SQS Operations
|
||||
|
||||
#### Create and Send Message
|
||||
```java
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.services.sqs.SqsClient;
|
||||
import software.amazon.awssdk.services.sqs.model.*;
|
||||
|
||||
// Setup SQS client
|
||||
SqsClient sqsClient = SqsClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
|
||||
// Create queue
|
||||
String queueUrl = sqsClient.createQueue(CreateQueueRequest.builder()
|
||||
.queueName("my-queue")
|
||||
.build()).queueUrl();
|
||||
|
||||
// Send message
|
||||
String messageId = sqsClient.sendMessage(SendMessageRequest.builder()
|
||||
.queueUrl(queueUrl)
|
||||
.messageBody("Hello, SQS!")
|
||||
.build()).messageId();
|
||||
```
|
||||
|
||||
#### Receive and Delete Message
|
||||
```java
|
||||
// Receive messages with long polling
|
||||
ReceiveMessageResponse response = sqsClient.receiveMessage(ReceiveMessageRequest.builder()
|
||||
.queueUrl(queueUrl)
|
||||
.maxNumberOfMessages(10)
|
||||
.waitTimeSeconds(20)
|
||||
.build());
|
||||
|
||||
// Process and delete messages
|
||||
response.messages().forEach(message -> {
|
||||
System.out.println("Received: " + message.body());
|
||||
sqsClient.deleteMessage(DeleteMessageRequest.builder()
|
||||
.queueUrl(queueUrl)
|
||||
.receiptHandle(message.receiptHandle())
|
||||
.build());
|
||||
});
|
||||
```
|
||||
|
||||
### Basic SNS Operations
|
||||
|
||||
#### Create Topic and Publish
|
||||
```java
|
||||
import software.amazon.awssdk.services.sns.SnsClient;
|
||||
import software.amazon.awssdk.services.sns.model.*;
|
||||
|
||||
// Setup SNS client
|
||||
SnsClient snsClient = SnsClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
|
||||
// Create topic
|
||||
String topicArn = snsClient.createTopic(CreateTopicRequest.builder()
|
||||
.name("my-topic")
|
||||
.build()).topicArn();
|
||||
|
||||
// Publish message
|
||||
String messageId = snsClient.publish(PublishRequest.builder()
|
||||
.topicArn(topicArn)
|
||||
.subject("Test Notification")
|
||||
.message("Hello, SNS!")
|
||||
.build()).messageId();
|
||||
```
|
||||
|
||||
### Advanced Examples
|
||||
|
||||
#### FIFO Queue Pattern
|
||||
```java
|
||||
// Create FIFO queue
|
||||
Map<QueueAttributeName, String> attributes = Map.of(
|
||||
QueueAttributeName.FIFO_QUEUE, "true",
|
||||
QueueAttributeName.CONTENT_BASED_DEDUPLICATION, "true"
|
||||
);
|
||||
|
||||
String fifoQueueUrl = sqsClient.createQueue(CreateQueueRequest.builder()
|
||||
.queueName("my-queue.fifo")
|
||||
.attributes(attributes)
|
||||
.build()).queueUrl();
|
||||
|
||||
// Send FIFO message with group ID
|
||||
String fifoMessageId = sqsClient.sendMessage(SendMessageRequest.builder()
|
||||
.queueUrl(fifoQueueUrl)
|
||||
.messageBody("Order #12345")
|
||||
.messageGroupId("orders")
|
||||
.messageDeduplicationId(UUID.randomUUID().toString())
|
||||
.build()).messageId();
|
||||
```
|
||||
|
||||
#### SNS to SQS Subscription
|
||||
```java
|
||||
// Create SQS queue for subscription
|
||||
String subscriptionQueueUrl = sqsClient.createQueue(CreateQueueRequest.builder()
|
||||
.queueName("notification-subscriber")
|
||||
.build()).queueUrl();
|
||||
|
||||
// Get queue ARN
|
||||
String queueArn = sqsClient.getQueueAttributes(GetQueueAttributesRequest.builder()
|
||||
.queueUrl(subscriptionQueueUrl)
|
||||
.attributeNames(QueueAttributeName.QUEUE_ARN)
|
||||
.build()).attributes().get(QueueAttributeName.QUEUE_ARN);
|
||||
|
||||
// Subscribe SQS to SNS
|
||||
String subscriptionArn = snsClient.subscribe(SubscribeRequest.builder()
|
||||
.protocol("sqs")
|
||||
.endpoint(queueArn)
|
||||
.topicArn(topicArn)
|
||||
.build()).subscriptionArn();
|
||||
```
|
||||
|
||||
### Spring Boot Integration Example
|
||||
|
||||
```java
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class OrderNotificationService {
|
||||
|
||||
private final SnsClient snsClient;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
@Value("${aws.sns.order-topic-arn}")
|
||||
private String orderTopicArn;
|
||||
|
||||
public void sendOrderNotification(Order order) {
|
||||
try {
|
||||
String jsonMessage = objectMapper.writeValueAsString(order);
|
||||
|
||||
snsClient.publish(PublishRequest.builder()
|
||||
.topicArn(orderTopicArn)
|
||||
.subject("New Order Received")
|
||||
.message(jsonMessage)
|
||||
.messageAttributes(Map.of(
|
||||
"orderType", MessageAttributeValue.builder()
|
||||
.dataType("String")
|
||||
.stringValue(order.getType())
|
||||
.build()))
|
||||
.build());
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to send order notification", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### SQS Best Practices
|
||||
- **Use long polling**: Set `waitTimeSeconds` (up to the 20-second maximum) to reduce empty responses
|
||||
- **Batch operations**: Use `sendMessageBatch` for multiple messages to reduce API calls
|
||||
- **Visibility timeout**: Set appropriately based on message processing time (default 30 seconds)
|
||||
- **Delete messages**: Always delete messages after successful processing
|
||||
- **Handle duplicates**: Implement idempotent processing for retries
|
||||
- **Implement DLQ**: Route failed messages to dead letter queues for analysis (see the sketch after this list)
|
||||
- **Monitor queue depth**: Use CloudWatch alarms for high queue backlog
|
||||
- **Use FIFO queues**: When message order and deduplication are critical
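A minimal sketch of the DLQ practice above; `queueUrl`, `dlqUrl`, and the `maxReceiveCount` of 5 are illustrative assumptions:

```java
// Resolve the ARN of the dead letter queue (assumed to already exist)
String dlqArn = sqsClient.getQueueAttributes(GetQueueAttributesRequest.builder()
        .queueUrl(dlqUrl)
        .attributeNames(QueueAttributeName.QUEUE_ARN)
        .build()).attributes().get(QueueAttributeName.QUEUE_ARN);

// After 5 failed receives, SQS moves the message to the DLQ automatically
sqsClient.setQueueAttributes(SetQueueAttributesRequest.builder()
        .queueUrl(queueUrl)
        .attributes(Map.of(QueueAttributeName.REDRIVE_POLICY,
                "{\"deadLetterTargetArn\":\"" + dlqArn + "\",\"maxReceiveCount\":\"5\"}"))
        .build());
```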
|
||||
|
||||
### SNS Best Practices
|
||||
- **Use filter policies**: Reduce noise by filtering messages at the source (see the sketch after this list)
|
||||
- **Message attributes**: Add metadata for subscription routing decisions
|
||||
- **Retry logic**: Handle transient failures with exponential backoff
|
||||
- **Monitor failed deliveries**: Set up CloudWatch alarms for failed notifications
|
||||
- **Security**: Use IAM policies for access control and data encryption
|
||||
- **FIFO topics**: Use when order and deduplication are critical
|
||||
- **Avoid large payloads**: Keep messages under 256KB for optimal performance
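A sketch of a filter policy applied to an existing subscription; `subscriptionArn` and the `orderType` attribute are assumptions carried over from the examples above:

```java
// Deliver only messages whose "orderType" attribute equals "priority"
snsClient.setSubscriptionAttributes(SetSubscriptionAttributesRequest.builder()
        .subscriptionArn(subscriptionArn)
        .attributeName("FilterPolicy")
        .attributeValue("{\"orderType\": [\"priority\"]}")
        .build());
```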
|
||||
|
||||
### General Guidelines
|
||||
- **Region consistency**: Use the same region for all AWS resources
|
||||
- **Resource naming**: Use consistent naming conventions for queues and topics
|
||||
- **Error handling**: Implement proper exception handling and logging
|
||||
- **Testing**: Use LocalStack for local development and testing
|
||||
- **Documentation**: Document subscription endpoints and message formats
|
||||
|
||||
## Instructions
|
||||
|
||||
### Setup AWS Credentials
|
||||
Configure AWS credentials using environment variables, AWS CLI, or IAM roles:
|
||||
```bash
|
||||
export AWS_ACCESS_KEY_ID=your-access-key
|
||||
export AWS_SECRET_ACCESS_KEY=your-secret-key
|
||||
export AWS_REGION=us-east-1
|
||||
```
|
||||
|
||||
### Configure Clients
|
||||
```java
|
||||
// Basic client configuration
|
||||
SqsClient sqsClient = SqsClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
|
||||
// Advanced client with custom configuration
|
||||
SnsClient snsClient = SnsClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.credentialsProvider(DefaultCredentialsProvider.create())
|
||||
.httpClient(UrlConnectionHttpClient.create())
|
||||
.build();
|
||||
```
|
||||
|
||||
### Implement Message Processing
|
||||
1. **Connect** to SQS/SNS using the AWS SDK clients
|
||||
2. **Create** queues and topics as needed
|
||||
3. **Send/receive** messages with appropriate timeout settings
|
||||
4. **Process** messages in batches for efficiency (batch sending is sketched after this list)
|
||||
5. **Delete** messages after successful processing
|
||||
6. **Handle** failures with proper error handling and retries
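Batch sending, referenced in step 4, is the one operation not shown above. A minimal sketch; entry IDs are arbitrary but must be unique within the batch:

```java
SendMessageBatchResponse batchResponse = sqsClient.sendMessageBatch(
        SendMessageBatchRequest.builder()
                .queueUrl(queueUrl)
                .entries(
                        SendMessageBatchRequestEntry.builder()
                                .id("msg-1").messageBody("first message").build(),
                        SendMessageBatchRequestEntry.builder()
                                .id("msg-2").messageBody("second message").build())
                .build());

// Batches can partially fail, so always inspect the failed entries
batchResponse.failed().forEach(failure ->
        System.err.println("Failed " + failure.id() + ": " + failure.message()));
```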
|
||||
|
||||
### Integrate with Spring Boot
|
||||
1. **Configure** beans for `SqsClient` and `SnsClient` in `@Configuration` classes (sketched after this list)
|
||||
2. **Use** `@Value` to inject queue URLs and topic ARNs from properties
|
||||
3. **Create** service classes with business logic for messaging operations
|
||||
4. **Implement** error handling with `@Retryable` or custom retry logic
|
||||
5. **Test** integration using Testcontainers or LocalStack
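A minimal sketch of step 1; the region is an assumption and would normally come from configuration:

```java
@Configuration
public class MessagingConfiguration {

    @Bean
    public SqsClient sqsClient() {
        return SqsClient.builder().region(Region.US_EAST_1).build();
    }

    @Bean
    public SnsClient snsClient() {
        return SnsClient.builder().region(Region.US_EAST_1).build();
    }
}
```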
|
||||
|
||||
### Monitor and Debug
|
||||
- Use AWS CloudWatch for monitoring queue depth and message metrics
|
||||
- Enable AWS SDK logging for debugging client operations
|
||||
- Implement proper logging for message processing activities
|
||||
- Use AWS X-Ray for distributed tracing in production environments
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
- **Queue does not exist**: Verify queue URL and permissions
|
||||
- **Message not received**: Check visibility timeout and consumer logic
|
||||
- **Permission denied**: Verify IAM policies and credentials
|
||||
- **Connection timeout**: Check network connectivity and region configuration
|
||||
- **Rate limiting**: Implement retry logic with exponential backoff
|
||||
|
||||
### Performance Optimization
|
||||
- Use long polling to reduce empty responses
|
||||
- Batch message operations to minimize API calls
|
||||
- Adjust visibility timeout based on processing time
|
||||
- Implement connection pooling and reuse clients
|
||||
- Use appropriate message sizes to avoid fragmentation
|
||||
|
||||
## Detailed References
|
||||
|
||||
For comprehensive API documentation and advanced patterns, see:
|
||||
|
||||
- [@references/detailed-sqs-operations] - Complete SQS operations reference
|
||||
- [@references/detailed-sns-operations] - Complete SNS operations reference
|
||||
- [@references/spring-boot-integration] - Spring Boot integration patterns
|
||||
- [@references/aws-official-documentation] - Official AWS documentation and best practices
|
||||
@@ -0,0 +1,158 @@
|
||||
# AWS SQS & SNS Official Documentation Reference
|
||||
|
||||
This file contains reference information extracted from official AWS resources for the AWS SDK for Java 2.x messaging patterns.
|
||||
|
||||
## Source Documents
|
||||
- [AWS Java SDK v2 Examples - SQS](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/example_code/sqs)
|
||||
- [AWS Java SDK v2 Examples - SNS](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/example_code/sns)
|
||||
- [AWS SQS Developer Guide](https://docs.aws.amazon.com/sqs/latest/dg/)
|
||||
- [AWS SNS Developer Guide](https://docs.aws.amazon.com/sns/latest/dg/)
|
||||
|
||||
## Amazon SQS Reference
|
||||
|
||||
### Core Operations
|
||||
- **CreateQueue** - Create new SQS queue
|
||||
- **DeleteMessage** - Delete individual message from queue
|
||||
- **ListQueues** - List available queues
|
||||
- **ReceiveMessage** - Receive messages from queue
|
||||
- **SendMessage** - Send message to queue
|
||||
- **SendMessageBatch** - Send multiple messages to queue
|
||||
|
||||
### Advanced Features
|
||||
- **Large Message Handling** - Use S3 for messages larger than 256KB
|
||||
- **Batch Operations** - Process multiple messages efficiently
|
||||
- **Long Polling** - Reduce empty responses with `waitTimeSeconds`
|
||||
- **Visibility Timeout** - Control message visibility during processing (see the sketch after this list)
|
||||
- **Dead Letter Queues (DLQ)** - Handle failed messages
|
||||
- **FIFO Queues** - Ensure message ordering and deduplication
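Of these, the visibility timeout is the one most often adjusted per message. A sketch, assuming a `message` obtained from a receive call as in the SKILL examples:

```java
// Give a slow consumer five more minutes before the message reappears
sqsClient.changeMessageVisibility(ChangeMessageVisibilityRequest.builder()
        .queueUrl(queueUrl)
        .receiptHandle(message.receiptHandle())
        .visibilityTimeout(300)
        .build());
```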
|
||||
|
||||
### Java SDK v2 Key Classes
|
||||
```java
|
||||
// Core clients and models
|
||||
software.amazon.awssdk.services.sqs.SqsClient
|
||||
software.amazon.awssdk.services.sqs.model.*
|
||||
software.amazon.awssdk.services.sqs.model.QueueAttributeName
|
||||
```
|
||||
|
||||
## Amazon SNS Reference
|
||||
|
||||
### Core Operations
|
||||
- **CreateTopic** - Create new SNS topic
|
||||
- **Publish** - Send message to topic
|
||||
- **Subscribe** - Subscribe endpoint to topic
|
||||
- **ListSubscriptions** - List topic subscriptions
|
||||
- **Unsubscribe** - Remove subscription
|
||||
|
||||
### Advanced Features
|
||||
- **Platform Endpoints** - Mobile push notifications
|
||||
- **SMS Publishing** - Send SMS messages (see the sketch after this list)
|
||||
- **FIFO Topics** - Ordered message delivery with deduplication
|
||||
- **Filter Policies** - Filter messages based on attributes
|
||||
- **Message Attributes** - Enrich messages with metadata
|
||||
- **DLQ for Subscriptions** - Handle failed deliveries
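SMS publishing bypasses topics and targets a phone number directly. A sketch; the number is a placeholder:

```java
// Publish directly to an E.164 phone number, no topic or subscription needed
snsClient.publish(PublishRequest.builder()
        .phoneNumber("+15555550100")
        .message("Your order has shipped")
        .build());
```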
|
||||
|
||||
### Java SDK v2 Key Classes
|
||||
```java
|
||||
// Core clients and models
|
||||
software.amazon.awssdk.services.sns.SnsClient
|
||||
software.amazon.awssdk.services.sns.model.*
|
||||
software.amazon.awssdk.services.sns.model.MessageAttributeValue
|
||||
```
|
||||
|
||||
## Best Practices from AWS
|
||||
|
||||
### SQS Best Practices
|
||||
1. **Use Long Polling**: Set `waitTimeSeconds` (1-20 seconds; 20 is the maximum) to reduce empty responses
|
||||
2. **Batch Operations**: Use `SendMessageBatch` for efficiency
|
||||
3. **Visibility Timeout**: Set appropriately based on processing time
|
||||
4. **Handle Duplicates**: Implement idempotent processing for retries
|
||||
5. **Monitor Queue Depth**: Use CloudWatch for monitoring
|
||||
6. **Implement DLQ**: Route failed messages for analysis
|
||||
|
||||
### SNS Best Practices
|
||||
1. **Use Filter Policies**: Reduce noise by filtering messages
|
||||
2. **Message Attributes**: Add metadata for routing decisions
|
||||
3. **Retry Logic**: Handle transient failures gracefully
|
||||
4. **Monitor Failed Deliveries**: Set up CloudWatch alarms
|
||||
5. **Security**: Use IAM policies for access control
|
||||
6. **FIFO Topics**: Use when order and deduplication are critical
|
||||
|
||||
## Error Handling Patterns
|
||||
|
||||
### Common SQS Errors
|
||||
- **QueueDoesNotExistException**: Verify queue URL
|
||||
- **MessageNotInflightException**: Check message visibility
|
||||
- **OverLimitException**: Implement backoff/retry logic
|
||||
- **InvalidAttributeValueException**: Validate queue attributes
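These exception classes live in `software.amazon.awssdk.services.sqs.model`. A catch-order sketch:

```java
try {
    sqsClient.receiveMessage(ReceiveMessageRequest.builder()
            .queueUrl(queueUrl)
            .maxNumberOfMessages(10)
            .build());
} catch (QueueDoesNotExistException e) {
    // Re-resolve the queue URL before retrying
} catch (OverLimitException e) {
    // Too many in-flight messages: back off, then retry
} catch (SqsException e) {
    // Other service errors: inspect e.awsErrorDetails() and log
}
```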
|
||||
|
||||
### Common SNS Errors
|
||||
- **NotFoundException**: Verify topic ARN
|
||||
- **InvalidParameterException**: Validate subscription parameters
|
||||
- **InternalFailureException**: Implement retry logic
|
||||
- **AuthorizationErrorException**: Check IAM permissions
|
||||
|
||||
## Integration Patterns
|
||||
|
||||
### Spring Boot Integration
|
||||
- Use `@Service` classes for business logic
|
||||
- Inject `SqsClient` and `SnsClient` via constructor injection
|
||||
- Configure clients with `@Configuration` beans
|
||||
- Use `@Value` for externalizing configuration
|
||||
|
||||
### Testing Strategies
|
||||
- Use LocalStack for local development
|
||||
- Mock AWS services with Mockito for unit tests
|
||||
- Integrate with Testcontainers for integration tests
|
||||
- Test idempotent operations thoroughly
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### SQS Configuration
|
||||
```java
|
||||
SqsClient sqsClient = SqsClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
```
|
||||
|
||||
### SNS Configuration
|
||||
```java
|
||||
SnsClient snsClient = SnsClient.builder()
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
```
|
||||
|
||||
### Advanced Configuration
|
||||
- Override endpoint for local development
|
||||
- Configure custom credentials provider
|
||||
- Set custom HTTP client
|
||||
- Configure retry policies
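A combined sketch of these options; the LocalStack endpoint and dummy credentials are illustrative assumptions for local development:

```java
SqsClient localSqsClient = SqsClient.builder()
        .region(Region.US_EAST_1)
        // Point at a local emulator instead of the real service
        .endpointOverride(URI.create("http://localhost:4566"))
        .credentialsProvider(StaticCredentialsProvider.create(
                AwsBasicCredentials.create("test", "test")))
        // Bound retries for transient failures
        .overrideConfiguration(ClientOverrideConfiguration.builder()
                .retryPolicy(RetryPolicy.builder().numRetries(3).build())
                .build())
        .build();
```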
|
||||
|
||||
## Monitoring and Observability

### SQS Metrics

- ApproximateNumberOfMessagesVisible
- ApproximateNumberOfMessagesNotVisible
- ApproximateNumberOfMessagesDelayed
- NumberOfMessagesSent
- NumberOfMessagesReceived

### SNS Metrics

- NumberOfMessagesPublished
- NumberOfNotificationsDelivered
- NumberOfNotificationsFailed
- NumberOfNotificationsFilteredOut
## Security Considerations

### IAM Permissions

- Grant least privilege access
- Use IAM roles for EC2/ECS
- Implement resource-based policies
- Use condition keys for fine-grained control

### Data Protection

- Encrypt sensitive data in messages
- Use KMS for message encryption
- Implement message signing
- Secure endpoints with HTTPS
@@ -0,0 +1,179 @@

# Detailed SNS Operations Reference

## Topic Management

### Create Standard Topic

```java
public String createTopic(SnsClient snsClient, String topicName) {
    CreateTopicRequest request = CreateTopicRequest.builder()
            .name(topicName)
            .build();

    CreateTopicResponse response = snsClient.createTopic(request);
    return response.topicArn();
}
```

### Create FIFO Topic

```java
public String createFifoTopic(SnsClient snsClient, String topicName) {
    Map<String, String> attributes = new HashMap<>();
    attributes.put("FifoTopic", "true");
    attributes.put("ContentBasedDeduplication", "true");

    CreateTopicRequest request = CreateTopicRequest.builder()
            .name(topicName + ".fifo")
            .attributes(attributes)
            .build();

    CreateTopicResponse response = snsClient.createTopic(request);
    return response.topicArn();
}
```

### Topic Operations

```java
public List<Topic> listTopics(SnsClient snsClient) {
    return snsClient.listTopics().topics();
}

public String getTopicArn(SnsClient snsClient, String topicName) {
    // createTopic is idempotent: it returns the ARN of an existing topic
    return snsClient.createTopic(CreateTopicRequest.builder()
            .name(topicName)
            .build()).topicArn();
}
```

## Message Publishing

### Publish Basic Message

```java
public String publishMessage(SnsClient snsClient, String topicArn, String message) {
    PublishRequest request = PublishRequest.builder()
            .topicArn(topicArn)
            .message(message)
            .build();

    PublishResponse response = snsClient.publish(request);
    return response.messageId();
}
```

### Publish with Subject

```java
public String publishMessageWithSubject(SnsClient snsClient,
                                        String topicArn,
                                        String subject,
                                        String message) {
    PublishRequest request = PublishRequest.builder()
            .topicArn(topicArn)
            .subject(subject)
            .message(message)
            .build();

    PublishResponse response = snsClient.publish(request);
    return response.messageId();
}
```

### Publish with Attributes

```java
public String publishMessageWithAttributes(SnsClient snsClient,
                                           String topicArn,
                                           String message,
                                           Map<String, String> attributes) {
    Map<String, MessageAttributeValue> messageAttributes = attributes.entrySet().stream()
            .collect(Collectors.toMap(
                    Map.Entry::getKey,
                    e -> MessageAttributeValue.builder()
                            .dataType("String")
                            .stringValue(e.getValue())
                            .build()));

    PublishRequest request = PublishRequest.builder()
            .topicArn(topicArn)
            .message(message)
            .messageAttributes(messageAttributes)
            .build();

    PublishResponse response = snsClient.publish(request);
    return response.messageId();
}
```

### Publish FIFO Message

```java
public String publishFifoMessage(SnsClient snsClient,
                                 String topicArn,
                                 String message,
                                 String messageGroupId) {
    PublishRequest request = PublishRequest.builder()
            .topicArn(topicArn)
            .message(message)
            .messageGroupId(messageGroupId)
            .messageDeduplicationId(UUID.randomUUID().toString())
            .build();

    PublishResponse response = snsClient.publish(request);
    return response.messageId();
}
```

## Subscription Management

### Subscribe Email to Topic

```java
public String subscribeEmail(SnsClient snsClient, String topicArn, String email) {
    SubscribeRequest request = SubscribeRequest.builder()
            .protocol("email")
            .endpoint(email)
            .topicArn(topicArn)
            .build();

    SubscribeResponse response = snsClient.subscribe(request);
    return response.subscriptionArn();
}
```

### Subscribe SQS to Topic

```java
public String subscribeSqs(SnsClient snsClient, String topicArn, String queueArn) {
    SubscribeRequest request = SubscribeRequest.builder()
            .protocol("sqs")
            .endpoint(queueArn)
            .topicArn(topicArn)
            .build();

    SubscribeResponse response = snsClient.subscribe(request);
    return response.subscriptionArn();
}
```

### Subscribe Lambda to Topic

```java
public String subscribeLambda(SnsClient snsClient, String topicArn, String lambdaArn) {
    SubscribeRequest request = SubscribeRequest.builder()
            .protocol("lambda")
            .endpoint(lambdaArn)
            .topicArn(topicArn)
            .build();

    SubscribeResponse response = snsClient.subscribe(request);
    return response.subscriptionArn();
}
```

### Subscription Operations

```java
public List<Subscription> listSubscriptions(SnsClient snsClient, String topicArn) {
    return snsClient.listSubscriptionsByTopic(ListSubscriptionsByTopicRequest.builder()
            .topicArn(topicArn)
            .build()).subscriptions();
}

public void unsubscribe(SnsClient snsClient, String subscriptionArn) {
    snsClient.unsubscribe(UnsubscribeRequest.builder()
            .subscriptionArn(subscriptionArn)
            .build());
}
```
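`listSubscriptionsByTopic` returns one page at a time; topics with many subscriptions need the `NextToken` loop. A hedged sketch using the v2 generated paginator, which handles that loop internally:

```java
public List<Subscription> listAllSubscriptions(SnsClient snsClient, String topicArn) {
    return snsClient.listSubscriptionsByTopicPaginator(
                    ListSubscriptionsByTopicRequest.builder()
                            .topicArn(topicArn)
                            .build())
            .stream()
            .flatMap(page -> page.subscriptions().stream())
            .collect(Collectors.toList());
}
```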
@@ -0,0 +1,199 @@

# Detailed SQS Operations Reference

## Queue Management

### Create Standard Queue

```java
public String createQueue(SqsClient sqsClient, String queueName) {
    CreateQueueRequest request = CreateQueueRequest.builder()
            .queueName(queueName)
            .build();

    CreateQueueResponse response = sqsClient.createQueue(request);
    return response.queueUrl();
}
```

### Create FIFO Queue

```java
public String createFifoQueue(SqsClient sqsClient, String queueName) {
    Map<QueueAttributeName, String> attributes = new HashMap<>();
    attributes.put(QueueAttributeName.FIFO_QUEUE, "true");
    attributes.put(QueueAttributeName.CONTENT_BASED_DEDUPLICATION, "true");

    CreateQueueRequest request = CreateQueueRequest.builder()
            .queueName(queueName + ".fifo")
            .attributes(attributes)
            .build();

    CreateQueueResponse response = sqsClient.createQueue(request);
    return response.queueUrl();
}
```

### Queue Operations

```java
public String getQueueUrl(SqsClient sqsClient, String queueName) {
    return sqsClient.getQueueUrl(GetQueueUrlRequest.builder()
            .queueName(queueName)
            .build()).queueUrl();
}

public List<String> listQueues(SqsClient sqsClient) {
    return sqsClient.listQueues().queueUrls();
}

public void purgeQueue(SqsClient sqsClient, String queueUrl) {
    sqsClient.purgeQueue(PurgeQueueRequest.builder()
            .queueUrl(queueUrl)
            .build());
}
```

## Message Operations

### Send Basic Message

```java
public String sendMessage(SqsClient sqsClient, String queueUrl, String messageBody) {
    SendMessageRequest request = SendMessageRequest.builder()
            .queueUrl(queueUrl)
            .messageBody(messageBody)
            .build();

    SendMessageResponse response = sqsClient.sendMessage(request);
    return response.messageId();
}
```

### Send Message with Attributes

```java
public String sendMessageWithAttributes(SqsClient sqsClient,
                                        String queueUrl,
                                        String messageBody,
                                        Map<String, String> attributes) {
    Map<String, MessageAttributeValue> messageAttributes = attributes.entrySet().stream()
            .collect(Collectors.toMap(
                    Map.Entry::getKey,
                    e -> MessageAttributeValue.builder()
                            .dataType("String")
                            .stringValue(e.getValue())
                            .build()));

    SendMessageRequest request = SendMessageRequest.builder()
            .queueUrl(queueUrl)
            .messageBody(messageBody)
            .messageAttributes(messageAttributes)
            .build();

    SendMessageResponse response = sqsClient.sendMessage(request);
    return response.messageId();
}
```

### Send FIFO Message

```java
public String sendFifoMessage(SqsClient sqsClient,
                              String queueUrl,
                              String messageBody,
                              String messageGroupId) {
    SendMessageRequest request = SendMessageRequest.builder()
            .queueUrl(queueUrl)
            .messageBody(messageBody)
            .messageGroupId(messageGroupId)
            .messageDeduplicationId(UUID.randomUUID().toString())
            .build();

    SendMessageResponse response = sqsClient.sendMessage(request);
    return response.messageId();
}
```

### Send Batch Messages

```java
public void sendBatchMessages(SqsClient sqsClient,
                              String queueUrl,
                              List<String> messages) {
    List<SendMessageBatchRequestEntry> entries = IntStream.range(0, messages.size())
            .mapToObj(i -> SendMessageBatchRequestEntry.builder()
                    .id(String.valueOf(i))
                    .messageBody(messages.get(i))
                    .build())
            .collect(Collectors.toList());

    SendMessageBatchRequest request = SendMessageBatchRequest.builder()
            .queueUrl(queueUrl)
            .entries(entries)
            .build();

    SendMessageBatchResponse response = sqsClient.sendMessageBatch(request);

    System.out.println("Successful: " + response.successful().size());
    System.out.println("Failed: " + response.failed().size());
}
```
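`SendMessageBatch` accepts at most 10 entries per request, so larger lists must be chunked. A minimal sketch that builds on the method above:

```java
public void sendAllMessages(SqsClient sqsClient, String queueUrl, List<String> messages) {
    // Split into chunks of 10, the SQS batch-size limit
    for (int start = 0; start < messages.size(); start += 10) {
        List<String> chunk = messages.subList(start, Math.min(start + 10, messages.size()));
        sendBatchMessages(sqsClient, queueUrl, chunk);
    }
}
```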

## Message Processing

### Receive Messages with Long Polling

```java
public List<Message> receiveMessages(SqsClient sqsClient, String queueUrl) {
    ReceiveMessageRequest request = ReceiveMessageRequest.builder()
            .queueUrl(queueUrl)
            .maxNumberOfMessages(10)
            .waitTimeSeconds(20) // Long polling
            .messageAttributeNames("All")
            .build();

    ReceiveMessageResponse response = sqsClient.receiveMessage(request);
    return response.messages();
}
```

### Delete Message

```java
public void deleteMessage(SqsClient sqsClient, String queueUrl, String receiptHandle) {
    DeleteMessageRequest request = DeleteMessageRequest.builder()
            .queueUrl(queueUrl)
            .receiptHandle(receiptHandle)
            .build();

    sqsClient.deleteMessage(request);
}
```

### Delete Batch Messages

```java
public void deleteBatchMessages(SqsClient sqsClient,
                                String queueUrl,
                                List<Message> messages) {
    List<DeleteMessageBatchRequestEntry> entries = messages.stream()
            .map(msg -> DeleteMessageBatchRequestEntry.builder()
                    .id(msg.messageId())
                    .receiptHandle(msg.receiptHandle())
                    .build())
            .collect(Collectors.toList());

    DeleteMessageBatchRequest request = DeleteMessageBatchRequest.builder()
            .queueUrl(queueUrl)
            .entries(entries)
            .build();

    sqsClient.deleteMessageBatch(request);
}
```

### Change Message Visibility

```java
public void changeMessageVisibility(SqsClient sqsClient,
                                    String queueUrl,
                                    String receiptHandle,
                                    int visibilityTimeout) {
    ChangeMessageVisibilityRequest request = ChangeMessageVisibilityRequest.builder()
            .queueUrl(queueUrl)
            .receiptHandle(receiptHandle)
            .visibilityTimeout(visibilityTimeout)
            .build();

    sqsClient.changeMessageVisibility(request);
}
```
@@ -0,0 +1,292 @@

# Spring Boot Integration Reference

## Configuration

### Basic Bean Configuration

```java
@Configuration
public class MessagingConfiguration {

    @Bean
    public SqsClient sqsClient() {
        return SqsClient.builder()
                .region(Region.US_EAST_1)
                .build();
    }

    @Bean
    public SnsClient snsClient() {
        return SnsClient.builder()
                .region(Region.US_EAST_1)
                .build();
    }
}
```

### Configuration Properties

```yaml
# application.yml
aws:
  sqs:
    queue-url: https://sqs.us-east-1.amazonaws.com/123456789012/my-queue
  sns:
    topic-arn: arn:aws:sns:us-east-1:123456789012:my-topic
```

## Service Layer Integration

### SQS Message Service

```java
@Service
@RequiredArgsConstructor
public class SqsMessageService {

    private final SqsClient sqsClient;
    private final ObjectMapper objectMapper;

    @Value("${aws.sqs.queue-url}")
    private String queueUrl;

    public <T> void sendMessage(T message) {
        try {
            String jsonMessage = objectMapper.writeValueAsString(message);

            SendMessageRequest request = SendMessageRequest.builder()
                    .queueUrl(queueUrl)
                    .messageBody(jsonMessage)
                    .build();

            sqsClient.sendMessage(request);

        } catch (Exception e) {
            throw new RuntimeException("Failed to send SQS message", e);
        }
    }

    public <T> List<T> receiveMessages(Class<T> messageType) {
        ReceiveMessageRequest request = ReceiveMessageRequest.builder()
                .queueUrl(queueUrl)
                .maxNumberOfMessages(10)
                .waitTimeSeconds(20)
                .build();

        ReceiveMessageResponse response = sqsClient.receiveMessage(request);

        return response.messages().stream()
                .map(msg -> {
                    try {
                        return objectMapper.readValue(msg.body(), messageType);
                    } catch (Exception e) {
                        throw new RuntimeException("Failed to parse message", e);
                    }
                })
                .collect(Collectors.toList());
    }

    public void deleteMessage(String receiptHandle) {
        DeleteMessageRequest request = DeleteMessageRequest.builder()
                .queueUrl(queueUrl)
                .receiptHandle(receiptHandle)
                .build();

        sqsClient.deleteMessage(request);
    }
}
```

### SNS Notification Service

```java
@Service
@RequiredArgsConstructor
public class SnsNotificationService {

    private final SnsClient snsClient;
    private final ObjectMapper objectMapper;

    @Value("${aws.sns.topic-arn}")
    private String topicArn;

    public void publishNotification(String subject, Object message) {
        try {
            String jsonMessage = objectMapper.writeValueAsString(message);

            PublishRequest request = PublishRequest.builder()
                    .topicArn(topicArn)
                    .subject(subject)
                    .message(jsonMessage)
                    .build();

            snsClient.publish(request);

        } catch (Exception e) {
            throw new RuntimeException("Failed to publish SNS notification", e);
        }
    }
}
```

## Message Listener Pattern

### Scheduled Polling

```java
@Service
@RequiredArgsConstructor
public class SqsMessageListener {

    private final SqsClient sqsClient;
    private final ObjectMapper objectMapper;

    @Value("${aws.sqs.queue-url}")
    private String queueUrl;

    @Scheduled(fixedDelay = 5000)
    public void pollMessages() {
        ReceiveMessageRequest request = ReceiveMessageRequest.builder()
                .queueUrl(queueUrl)
                .maxNumberOfMessages(10)
                .waitTimeSeconds(20)
                .build();

        ReceiveMessageResponse response = sqsClient.receiveMessage(request);

        response.messages().forEach(this::processMessage);
    }

    private void processMessage(Message message) {
        try {
            // Process message
            System.out.println("Processing: " + message.body());

            // Delete message after successful processing
            deleteMessage(message.receiptHandle());

        } catch (Exception e) {
            // Handle error - message will become visible again
            System.err.println("Failed to process message: " + e.getMessage());
        }
    }

    private void deleteMessage(String receiptHandle) {
        DeleteMessageRequest request = DeleteMessageRequest.builder()
                .queueUrl(queueUrl)
                .receiptHandle(receiptHandle)
                .build();

        sqsClient.deleteMessage(request);
    }
}
```

## Pub/Sub Pattern Integration

### Configuration for Pub/Sub

```java
@Configuration
@RequiredArgsConstructor
public class PubSubConfiguration {

    private final SnsClient snsClient;
    private final SqsClient sqsClient;

    @Bean
    @DependsOn("sqsClient")
    public String setupPubSub() {
        // Create SNS topic
        String topicArn = snsClient.createTopic(CreateTopicRequest.builder()
                .name("order-events")
                .build()).topicArn();

        // Create SQS queue
        String queueUrl = sqsClient.createQueue(CreateQueueRequest.builder()
                .queueName("order-processor")
                .build()).queueUrl();

        // Get queue ARN
        String queueArn = sqsClient.getQueueAttributes(GetQueueAttributesRequest.builder()
                .queueUrl(queueUrl)
                .attributeNames(QueueAttributeName.QUEUE_ARN)
                .build()).attributes().get(QueueAttributeName.QUEUE_ARN);

        // Subscribe SQS to SNS
        snsClient.subscribe(SubscribeRequest.builder()
                .protocol("sqs")
                .endpoint(queueArn)
                .topicArn(topicArn)
                .build());

        return topicArn;
    }
}
```
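One gap in the configuration above: SNS can only deliver into the queue if the queue's resource policy allows it. A hedged sketch of the missing step, reusing the `queueUrl`, `queueArn`, and `topicArn` variables from the bean above:

```java
// Without this policy, SNS deliveries to the queue are rejected
String policy = String.format(
        "{\"Version\":\"2012-10-17\",\"Statement\":[{" +
        "\"Effect\":\"Allow\",\"Principal\":{\"Service\":\"sns.amazonaws.com\"}," +
        "\"Action\":\"sqs:SendMessage\",\"Resource\":\"%s\"," +
        "\"Condition\":{\"ArnEquals\":{\"aws:SourceArn\":\"%s\"}}}]}",
        queueArn, topicArn);

sqsClient.setQueueAttributes(SetQueueAttributesRequest.builder()
        .queueUrl(queueUrl)
        .attributes(Map.of(QueueAttributeName.POLICY, policy))
        .build());
```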

## Error Handling Patterns

### Retry Mechanism

```java
@Service
@RequiredArgsConstructor
public class RetryableSqsService {

    private final SqsClient sqsClient;
    private final RetryTemplate retryTemplate;

    public void sendMessageWithRetry(String queueUrl, String messageBody) {
        retryTemplate.execute(context -> {
            try {
                SendMessageRequest request = SendMessageRequest.builder()
                        .queueUrl(queueUrl)
                        .messageBody(messageBody)
                        .build();

                sqsClient.sendMessage(request);
                return null;
            } catch (Exception e) {
                // RetryableException is an application-defined exception the
                // RetryTemplate is configured to retry on
                throw new RetryableException("Failed to send message", e);
            }
        });
    }
}
```

## Testing Integration

### LocalStack Configuration

```java
@TestConfiguration
public class LocalStackMessagingConfig {

    @Container
    static LocalStackContainer localstack = new LocalStackContainer(
            DockerImageName.parse("localstack/localstack:3.0"))
            .withServices(
                    LocalStackContainer.Service.SQS,
                    LocalStackContainer.Service.SNS
            );

    @Bean
    public SqsClient sqsClient() {
        return SqsClient.builder()
                .region(Region.US_EAST_1)
                .endpointOverride(
                        localstack.getEndpointOverride(LocalStackContainer.Service.SQS))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create(
                                localstack.getAccessKey(),
                                localstack.getSecretKey())))
                .build();
    }

    @Bean
    public SnsClient snsClient() {
        return SnsClient.builder()
                .region(Region.US_EAST_1)
                .endpointOverride(
                        localstack.getEndpointOverride(LocalStackContainer.Service.SNS))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create(
                                localstack.getAccessKey(),
                                localstack.getSecretKey())))
                .build();
    }
}
```
400
skills/aws-java/aws-sdk-java-v2-rds/SKILL.md
Normal file
@@ -0,0 +1,400 @@

---
name: aws-sdk-java-v2-rds
description: AWS RDS (Relational Database Service) management using AWS SDK for Java 2.x. Use when creating, modifying, monitoring, or managing Amazon RDS database instances, snapshots, parameter groups, and configurations.
category: aws
tags: [aws, rds, database, java, sdk, postgresql, mysql, aurora, spring-boot]
version: 1.1.0
allowed-tools: Read, Write, Bash
---

# AWS SDK for Java v2 - RDS Management

This skill provides comprehensive guidance for working with Amazon RDS (Relational Database Service) using the AWS SDK for Java 2.x, covering database instance management, snapshots, parameter groups, and RDS operations.

## When to Use This Skill

Use this skill when:
- Creating and managing RDS database instances (PostgreSQL, MySQL, Aurora, etc.)
- Taking and restoring database snapshots
- Managing DB parameter groups and configurations
- Querying RDS instance metadata and status
- Setting up Multi-AZ deployments
- Configuring automated backups
- Managing security groups for RDS
- Connecting Lambda functions to RDS databases
- Implementing RDS IAM authentication
- Monitoring RDS instances and metrics

## Getting Started

### RDS Client Setup

The `RdsClient` is the main entry point for interacting with Amazon RDS.

**Basic Client Creation:**
```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.rds.RdsClient;

RdsClient rdsClient = RdsClient.builder()
        .region(Region.US_EAST_1)
        .build();

// Use client
describeInstances(rdsClient);

// Always close the client
rdsClient.close();
```

**Client with Custom Configuration:**
```java
import java.time.Duration;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.http.apache.ApacheHttpClient;

RdsClient rdsClient = RdsClient.builder()
        .region(Region.US_WEST_2)
        .credentialsProvider(ProfileCredentialsProvider.create("myprofile"))
        .httpClient(ApacheHttpClient.builder()
                .connectionTimeout(Duration.ofSeconds(30))
                .socketTimeout(Duration.ofSeconds(60))
                .build())
        .build();
```

### Describing DB Instances

Retrieve information about existing RDS instances.

**List All DB Instances:**
```java
public static void describeInstances(RdsClient rdsClient) {
    try {
        DescribeDbInstancesResponse response = rdsClient.describeDBInstances();
        List<DBInstance> instanceList = response.dbInstances();

        for (DBInstance instance : instanceList) {
            System.out.println("Instance ARN: " + instance.dbInstanceArn());
            System.out.println("Engine: " + instance.engine());
            System.out.println("Status: " + instance.dbInstanceStatus());
            System.out.println("Endpoint: " + instance.endpoint().address());
            System.out.println("Port: " + instance.endpoint().port());
            System.out.println("---");
        }
    } catch (RdsException e) {
        System.err.println(e.getMessage());
        System.exit(1);
    }
}
```

## Key Operations

### Creating DB Instances

Create new RDS database instances with various configurations.

**Create Basic DB Instance:**
```java
public static String createDBInstance(RdsClient rdsClient,
                                      String dbInstanceIdentifier,
                                      String dbName,
                                      String masterUsername,
                                      String masterPassword) {
    try {
        CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
                .dbInstanceIdentifier(dbInstanceIdentifier)
                .dbName(dbName)
                .engine("postgres")
                .engineVersion("14.7")
                .dbInstanceClass("db.t3.micro")
                .allocatedStorage(20)
                .masterUsername(masterUsername)
                .masterUserPassword(masterPassword)
                .publiclyAccessible(false)
                .build();

        CreateDbInstanceResponse response = rdsClient.createDBInstance(request);
        System.out.println("Creating DB instance: " + response.dbInstance().dbInstanceArn());

        return response.dbInstance().dbInstanceArn();
    } catch (RdsException e) {
        System.err.println("Error creating instance: " + e.getMessage());
        throw e;
    }
}
```
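`createDBInstance` returns while the instance is still being provisioned. A hedged sketch, assuming the SDK's generated RDS waiter, that blocks until the instance reports `available`:

```java
import software.amazon.awssdk.services.rds.waiters.RdsWaiter;

try (RdsWaiter waiter = rdsClient.waiter()) {
    // Polls describeDBInstances until the instance status is "available"
    waiter.waitUntilDBInstanceAvailable(DescribeDbInstancesRequest.builder()
            .dbInstanceIdentifier(dbInstanceIdentifier)
            .build());
    System.out.println("Instance is available");
}
```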

### Managing DB Parameter Groups

Create and manage custom parameter groups for database configuration.

**Create DB Parameter Group:**
```java
public static void createDBParameterGroup(RdsClient rdsClient,
                                          String groupName,
                                          String description) {
    try {
        CreateDbParameterGroupRequest request = CreateDbParameterGroupRequest.builder()
                .dbParameterGroupName(groupName)
                .dbParameterGroupFamily("postgres15")
                .description(description)
                .build();

        CreateDbParameterGroupResponse response = rdsClient.createDBParameterGroup(request);
        System.out.println("Created parameter group: " + response.dbParameterGroup().dbParameterGroupName());
    } catch (RdsException e) {
        System.err.println("Error creating parameter group: " + e.getMessage());
        throw e;
    }
}
```

### Managing DB Snapshots

Create, restore, and manage database snapshots.

**Create DB Snapshot:**
```java
public static String createDBSnapshot(RdsClient rdsClient,
                                      String dbInstanceIdentifier,
                                      String snapshotIdentifier) {
    try {
        CreateDbSnapshotRequest request = CreateDbSnapshotRequest.builder()
                .dbInstanceIdentifier(dbInstanceIdentifier)
                .dbSnapshotIdentifier(snapshotIdentifier)
                .build();

        CreateDbSnapshotResponse response = rdsClient.createDBSnapshot(request);
        System.out.println("Creating snapshot: " + response.dbSnapshot().dbSnapshotIdentifier());

        return response.dbSnapshot().dbSnapshotArn();
    } catch (RdsException e) {
        System.err.println("Error creating snapshot: " + e.getMessage());
        throw e;
    }
}
```
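Restoring is the inverse operation: a new instance is created from an existing snapshot. A hedged sketch, following the same naming conventions as the methods above:

```java
public static void restoreFromSnapshot(RdsClient rdsClient,
                                       String newInstanceIdentifier,
                                       String snapshotIdentifier) {
    // Creates a brand-new instance from the snapshot
    RestoreDbInstanceFromDbSnapshotRequest request =
            RestoreDbInstanceFromDbSnapshotRequest.builder()
                    .dbInstanceIdentifier(newInstanceIdentifier)
                    .dbSnapshotIdentifier(snapshotIdentifier)
                    .build();

    rdsClient.restoreDBInstanceFromDBSnapshot(request);
}
```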

## Integration Patterns

### Spring Boot Integration

Refer to [references/spring-boot-integration.md](references/spring-boot-integration.md) for complete Spring Boot integration examples including:

- Spring Boot configuration with application properties
- RDS client bean configuration
- Service layer implementation
- REST controller design
- Exception handling
- Testing strategies

### Lambda Integration

Refer to [references/lambda-integration.md](references/lambda-integration.md) for Lambda integration examples including:

- Traditional Lambda + RDS connections
- Lambda with connection pooling
- Using AWS Secrets Manager for credentials
- Lambda with AWS SDK for RDS management
- Security configuration and best practices

## Advanced Operations

### Modifying DB Instances

Update existing RDS instances.

```java
public static void modifyDBInstance(RdsClient rdsClient,
                                    String dbInstanceIdentifier,
                                    String newInstanceClass) {
    try {
        ModifyDbInstanceRequest request = ModifyDbInstanceRequest.builder()
                .dbInstanceIdentifier(dbInstanceIdentifier)
                .dbInstanceClass(newInstanceClass)
                .applyImmediately(false) // Apply during maintenance window
                .build();

        ModifyDbInstanceResponse response = rdsClient.modifyDBInstance(request);
        System.out.println("Modified instance: " + response.dbInstance().dbInstanceIdentifier());
        System.out.println("New class: " + response.dbInstance().dbInstanceClass());
    } catch (RdsException e) {
        System.err.println("Error modifying instance: " + e.getMessage());
        throw e;
    }
}
```

### Deleting DB Instances

Delete RDS instances with optional final snapshot.

```java
public static void deleteDBInstanceWithSnapshot(RdsClient rdsClient,
                                                String dbInstanceIdentifier,
                                                String finalSnapshotIdentifier) {
    try {
        DeleteDbInstanceRequest request = DeleteDbInstanceRequest.builder()
                .dbInstanceIdentifier(dbInstanceIdentifier)
                .skipFinalSnapshot(false)
                .finalDBSnapshotIdentifier(finalSnapshotIdentifier)
                .build();

        DeleteDbInstanceResponse response = rdsClient.deleteDBInstance(request);
        System.out.println("Deleting instance: " + response.dbInstance().dbInstanceIdentifier());
    } catch (RdsException e) {
        System.err.println("Error deleting instance: " + e.getMessage());
        throw e;
    }
}
```
## Best Practices

### Security

**Always use encryption:**
```java
CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
        .storageEncrypted(true)
        .kmsKeyId("arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012")
        .build();
```

**Use VPC security groups:**
```java
CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
        .vpcSecurityGroupIds("sg-12345678")
        .publiclyAccessible(false)
        .build();
```
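**Use IAM database authentication:**

The skill's "When to Use" list mentions RDS IAM authentication; a hedged sketch of generating a short-lived auth token with the v2 `RdsUtilities` helper (hostname, port, and username are illustrative assumptions):

```java
import software.amazon.awssdk.services.rds.RdsUtilities;
import software.amazon.awssdk.services.rds.model.GenerateAuthenticationTokenRequest;

RdsUtilities utilities = rdsClient.utilities();
String authToken = utilities.generateAuthenticationToken(
        GenerateAuthenticationTokenRequest.builder()
                .hostname("mydb.abc123.us-east-1.rds.amazonaws.com") // illustrative endpoint
                .port(5432)
                .username("db_user") // IAM-enabled database user (assumption)
                .build());
// Pass authToken as the password when opening an SSL connection
```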

### High Availability

**Enable Multi-AZ for production:**
```java
CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
        .multiAZ(true)
        .build();
```

### Backups

**Configure automated backups:**
```java
CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
        .backupRetentionPeriod(7)
        .preferredBackupWindow("03:00-04:00")
        .build();
```

### Monitoring

**Enable CloudWatch logs:**
```java
CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
        .enableCloudwatchLogsExports("postgresql", "upgrade")
        .build();
```

### Cost Optimization

**Use appropriate instance class:**
```java
// Development
.dbInstanceClass("db.t3.micro")

// Production
.dbInstanceClass("db.r5.large")
```

### Deletion Protection

**Enable for production databases:**
```java
CreateDbInstanceRequest request = CreateDbInstanceRequest.builder()
        .deletionProtection(true)
        .build();
```

### Resource Management

**Always close clients:**
```java
try (RdsClient rdsClient = RdsClient.builder()
        .region(Region.US_EAST_1)
        .build()) {
    // Use client
} // Automatically closed
```
## Dependencies

### Maven Dependencies

```xml
<dependencies>
    <!-- AWS SDK for RDS -->
    <dependency>
        <groupId>software.amazon.awssdk</groupId>
        <artifactId>rds</artifactId>
        <version>2.20.0</version> <!-- use the latest available version -->
    </dependency>

    <!-- PostgreSQL Driver -->
    <dependency>
        <groupId>org.postgresql</groupId>
        <artifactId>postgresql</artifactId>
        <version>42.6.0</version> <!-- use the version matching your database -->
    </dependency>

    <!-- MySQL Driver -->
    <dependency>
        <groupId>mysql</groupId>
        <artifactId>mysql-connector-java</artifactId>
        <version>8.0.33</version>
    </dependency>
</dependencies>
```

### Gradle Dependencies

```gradle
dependencies {
    // AWS SDK for RDS
    implementation 'software.amazon.awssdk:rds:2.20.0'

    // PostgreSQL Driver
    implementation 'org.postgresql:postgresql:42.6.0'

    // MySQL Driver
    implementation 'mysql:mysql-connector-java:8.0.33'
}
```

## Reference Documentation

For detailed API reference, see:
- [API Reference](references/api-reference.md) - Complete API documentation and data models
- [Spring Boot Integration](references/spring-boot-integration.md) - Spring Boot patterns and examples
- [Lambda Integration](references/lambda-integration.md) - Lambda function patterns and best practices

## Error Handling

See [API Reference](references/api-reference.md#error-handling) for comprehensive error handling patterns including common exceptions, error response structure, and pagination support.

## Performance Considerations

- Use connection pooling for multiple database operations
- Implement retry logic for transient failures
- Monitor CloudWatch metrics for performance optimization
- Use appropriate instance types for workload requirements
- Enable Performance Insights for database optimization

## Support

For support with AWS RDS operations using AWS SDK for Java 2.x:
- AWS Documentation: [Amazon RDS User Guide](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Welcome.html)
- AWS SDK Documentation: [AWS SDK for Java 2.x](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/home.html)
- AWS Support: [AWS Support Center](https://aws.amazon.com/premiumsupport/)
122
skills/aws-java/aws-sdk-java-v2-rds/references/api-reference.md
Normal file
@@ -0,0 +1,122 @@

# AWS RDS API Reference

## Core API Operations

### Describe Operations
- `describeDBInstances` - List database instances
- `describeDBParameterGroups` - List parameter groups
- `describeDBSnapshots` - List database snapshots
- `describeDBSubnetGroups` - List subnet groups

### Instance Management
- `createDBInstance` - Create new database instance
- `modifyDBInstance` - Modify existing instance
- `deleteDBInstance` - Delete database instance

### Parameter Groups
- `createDBParameterGroup` - Create parameter group
- `modifyDBParameterGroup` - Modify parameters
- `deleteDBParameterGroup` - Delete parameter group

### Snapshots
- `createDBSnapshot` - Create database snapshot
- `restoreDBInstanceFromDBSnapshot` - Restore from snapshot
- `deleteDBSnapshot` - Delete snapshot

## Key Data Models

### DBInstance
```java
String dbInstanceIdentifier()   // Instance name
String dbInstanceArn()          // ARN identifier
String engine()                 // Database engine
String engineVersion()          // Engine version
String dbInstanceClass()        // Instance type
Integer allocatedStorage()      // Storage size in GB
Endpoint endpoint()             // Connection endpoint
String dbInstanceStatus()       // Instance status
Boolean multiAZ()               // Multi-AZ enabled
Boolean storageEncrypted()      // Storage encrypted
```

### DBParameter
```java
String parameterName()          // Parameter name
String parameterValue()         // Parameter value
String description()            // Description
ApplyMethod applyMethod()       // Apply method (immediate | pending-reboot)
```

### CreateDbInstanceRequest Builder
```java
CreateDbInstanceRequest.builder()
        .dbInstanceIdentifier(identifier)
        .engine("postgres")             // Database engine
        .engineVersion("15.2")          // Engine version
        .dbInstanceClass("db.t3.micro") // Instance type
        .allocatedStorage(20)           // Storage size
        .masterUsername(username)       // Admin username
        .masterUserPassword(password)   // Admin password
        .publiclyAccessible(false)      // Public access
        .storageEncrypted(true)         // Storage encryption
        .multiAZ(true)                  // High availability
        .backupRetentionPeriod(7)       // Backup retention
        .deletionProtection(true)       // Protection from deletion
        .build()
```

## Error Handling

### Common Exceptions
- `DBInstanceNotFoundFault` - Instance doesn't exist
- `DBSnapshotAlreadyExistsFault` - Snapshot name conflicts
- `InsufficientDBInstanceCapacity` - Instance type unavailable
- `InvalidParameterValueException` - Invalid configuration value
- `StorageQuotaExceeded` - Storage limit reached

### Error Response Structure
```java
try {
    rdsClient.createDBInstance(request);
} catch (RdsException e) {
    // AWS specific error handling
    String errorCode = e.awsErrorDetails().errorCode();
    String errorMessage = e.awsErrorDetails().errorMessage();

    switch (errorCode) {
        case "DBInstanceNotFoundFault":
            // Handle missing instance
            break;
        case "InvalidParameterValueException":
            // Handle invalid parameters
            break;
        default:
            // Generic error handling
    }
}
```

## Pagination Support

### List Instances with Pagination
```java
DescribeDbInstancesRequest request = DescribeDbInstancesRequest.builder()
        .maxResults(100) // Limit results per page
        .build();

String marker = null;
do {
    if (marker != null) {
        request = request.toBuilder()
                .marker(marker)
                .build();
    }

    DescribeDbInstancesResponse response = rdsClient.describeDBInstances(request);
    List<DBInstance> instances = response.dbInstances();

    // Process instances

    marker = response.marker();
} while (marker != null);
```
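Alternatively, the SDK's generated paginator handles the marker loop internally. A hedged sketch; the method names follow the v2 paginator convention:

```java
// Iterates every page transparently and flattens the instances
rdsClient.describeDBInstancesPaginator().dbInstances()
        .forEach(instance -> System.out.println(instance.dbInstanceIdentifier()));
```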
@@ -0,0 +1,382 @@

# AWS Lambda Integration with RDS

## Lambda RDS Connection Patterns

### 1. Traditional Lambda + RDS Connection

```java
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;

public class RdsLambdaHandler implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {

    @Override
    public APIGatewayProxyResponseEvent handleRequest(APIGatewayProxyRequestEvent event, Context context) {
        APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent();

        try {
            // Get environment variables
            String host = System.getenv("ProxyHostName");
            String port = System.getenv("Port");
            String dbName = System.getenv("DBName");
            String username = System.getenv("DBUserName");
            String password = System.getenv("DBPassword");

            // Create connection string
            String connectionString = String.format(
                    "jdbc:mysql://%s:%s/%s?useSSL=true&requireSSL=true",
                    host, port, dbName
            );

            // Execute query
            String sql = "SELECT COUNT(*) FROM users";

            try (Connection connection = DriverManager.getConnection(connectionString, username, password);
                 PreparedStatement statement = connection.prepareStatement(sql);
                 ResultSet resultSet = statement.executeQuery()) {

                if (resultSet.next()) {
                    int count = resultSet.getInt(1);
                    response.setStatusCode(200);
                    response.setBody("{\"count\": " + count + "}");
                }
            }
        } catch (Exception e) {
            response.setStatusCode(500);
            response.setBody("{\"error\": \"" + e.getMessage() + "\"}");
        }

        return response;
    }
}
```

### 2. Lambda with Connection Pooling

```java
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import javax.sql.DataSource;

public class RdsLambdaConfig {

    private static DataSource dataSource;

    public static synchronized DataSource getDataSource() {
        if (dataSource == null) {
            HikariConfig config = new HikariConfig();

            String host = System.getenv("ProxyHostName");
            String port = System.getenv("Port");
            String dbName = System.getenv("DBName");
            String username = System.getenv("DBUserName");
            String password = System.getenv("DBPassword");

            config.setJdbcUrl(String.format("jdbc:mysql://%s:%s/%s", host, port, dbName));
            config.setUsername(username);
            config.setPassword(password);

            // Connection pool settings
            config.setMaximumPoolSize(5);
            config.setMinimumIdle(2);
            config.setIdleTimeout(30000);
            config.setConnectionTimeout(20000);
            config.setMaxLifetime(1800000);

            // MySQL-specific settings
            config.addDataSourceProperty("useSSL", true);
            config.addDataSourceProperty("requireSSL", true);
            config.addDataSourceProperty("serverSslCertificate", "rds-ca-2019");
            config.addDataSourceProperty("connectTimeout", "30");

            dataSource = new HikariDataSource(config);
        }
        return dataSource;
    }
}
```

### 3. Using AWS Secrets Manager for Credentials

```java
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder;
import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest;
import com.amazonaws.services.secretsmanager.model.GetSecretValueResult;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.HashMap;
import java.util.Map;

public class RdsSecretsHelper {

    private static final String SECRET_NAME = "prod/rds/db_credentials";
    private static final String REGION = "us-east-1";

    public static Map<String, String> getRdsCredentials() {
        AWSSecretsManager client = AWSSecretsManagerClientBuilder.standard()
                .withRegion(REGION)
                .build();

        // SDK v1 requests use "with" setters rather than builders
        GetSecretValueRequest request = new GetSecretValueRequest()
                .withSecretId(SECRET_NAME);

        GetSecretValueResult result = client.getSecretValue(request);

        // Parse secret JSON (readValue throws a checked exception)
        try {
            ObjectMapper objectMapper = new ObjectMapper();
            Map<String, Object> secretMap =
                    objectMapper.readValue(result.getSecretString(), HashMap.class);

            Map<String, String> credentials = new HashMap<>();
            secretMap.forEach((key, value) -> credentials.put(key, value.toString()));

            return credentials;
        } catch (Exception e) {
            throw new RuntimeException("Failed to parse secret JSON", e);
        }
    }
}
```

### 4. Lambda with AWS SDK for RDS

```java
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.rds.RdsClient;
import software.amazon.awssdk.services.rds.model.*;

public class RdsManagementLambda implements RequestHandler<ApiRequest, ApiResponse> {

    @Override
    public ApiResponse handleRequest(ApiRequest request, Context context) {
        RdsClient rdsClient = RdsClient.builder()
                .region(Region.US_EAST_1)
                .build();

        try {
            switch (request.getAction()) {
                case "list-instances":
                    return listInstances(rdsClient);
                case "create-snapshot":
                    return createSnapshot(rdsClient, request.getInstanceId(), request.getSnapshotId());
                case "describe-instance":
                    return describeInstance(rdsClient, request.getInstanceId());
                default:
                    return new ApiResponse(400, "Unknown action: " + request.getAction());
            }
        } catch (Exception e) {
            context.getLogger().log("Error: " + e.getMessage());
            return new ApiResponse(500, "Error: " + e.getMessage());
        } finally {
            rdsClient.close();
        }
    }

    private ApiResponse listInstances(RdsClient rdsClient) {
        DescribeDbInstancesResponse response = rdsClient.describeDBInstances();
        return new ApiResponse(200, response.toString());
    }

    private ApiResponse createSnapshot(RdsClient rdsClient, String instanceId, String snapshotId) {
        CreateDbSnapshotRequest request = CreateDbSnapshotRequest.builder()
                .dbInstanceIdentifier(instanceId)
                .dbSnapshotIdentifier(snapshotId)
                .build();

        CreateDbSnapshotResponse response = rdsClient.createDBSnapshot(request);
        return new ApiResponse(200, "Snapshot created: " + response.dbSnapshot().dbSnapshotIdentifier());
    }

    private ApiResponse describeInstance(RdsClient rdsClient, String instanceId) {
        DescribeDbInstancesRequest request = DescribeDbInstancesRequest.builder()
                .dbInstanceIdentifier(instanceId)
                .build();

        DescribeDbInstancesResponse response = rdsClient.describeDBInstances(request);
        return new ApiResponse(200, response.toString());
    }
}

class ApiRequest {
    private String action;
    private String instanceId;
    private String snapshotId;
    // getters and setters
}

class ApiResponse {
    private int statusCode;
    private String body;
    // constructor, getters
}
```

## Best Practices for Lambda + RDS

### 1. Security Configuration

**IAM Role:**
```json
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "rds:*"
            ],
            "Resource": "*"
        }
    ]
}
```

For production, scope the actions and resources down in line with least-privilege access; `rds:*` on `"*"` is shown only for illustration.

**Security Group:**
- Use security groups to restrict access
- Only allow Lambda function IP ranges
- Use VPC endpoints for private connections

### 2. Environment Variables

```bash
# Environment variables for Lambda
DB_HOST=mydb.abc123.us-east-1.rds.amazonaws.com
DB_PORT=5432
DB_NAME=mydatabase
DB_USERNAME=admin
DB_PASSWORD=${DB_PASSWORD}
DB_CONNECTION_STRING=jdbc:postgresql://${DB_HOST}:${DB_PORT}/${DB_NAME}
```

### 3. Error Handling

```java
import com.amazonaws.services.lambda.runtime.LambdaLogger;
import software.amazon.awssdk.services.rds.model.RdsException;

public class LambdaErrorHandler {

    public static void handleRdsError(Exception e, LambdaLogger logger) {
        if (e instanceof RdsException) {
            RdsException rdsException = (RdsException) e;
            logger.log("RDS Error: " + rdsException.awsErrorDetails().errorCode());

            switch (rdsException.awsErrorDetails().errorCode()) {
                case "DBInstanceNotFoundFault":
                    logger.log("Database instance not found");
                    break;
                case "InvalidParameterValueException":
                    logger.log("Invalid parameter provided");
                    break;
                case "InstanceAlreadyExistsFault":
                    logger.log("Instance already exists");
                    break;
                default:
                    logger.log("Unknown RDS error: " + rdsException.getMessage());
            }
        } else {
            logger.log("Non-RDS error: " + e.getMessage());
        }
    }
}
```

### 4. Performance Optimization

**Cold Start Mitigation:**
```java
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;

public class RdsConnectionHelper {
    private static DataSource dataSource;
    private static long lastConnectionTime = 0;
    private static final long CONNECTION_TIMEOUT = 300000; // 5 minutes

    public static Connection getConnection() throws SQLException {
        long currentTime = System.currentTimeMillis();

        if (dataSource == null || (currentTime - lastConnectionTime) > CONNECTION_TIMEOUT) {
            dataSource = createDataSource();
            lastConnectionTime = currentTime;
        }

        return dataSource.getConnection();
    }

    private static DataSource createDataSource() {
        // Connection pool creation (see the HikariCP configuration above)
        return RdsLambdaConfig.getDataSource();
    }
}
```

**Batch Processing:**
```java
public class RdsBatchProcessor {

    // getConnection() and logger are assumed to come from the helpers above
    public void processBatch(List<String> userIds) {
        // Build the IN-clause placeholders before preparing the statement
        String placeholders = userIds.stream()
                .map(id -> "?")
                .collect(Collectors.joining(","));

        String sql = "SELECT * FROM users WHERE user_id IN (" + placeholders + ")";

        try (Connection connection = getConnection();
             PreparedStatement statement = connection.prepareStatement(sql)) {

            // Set parameters
            for (int i = 0; i < userIds.size(); i++) {
                statement.setString(i + 1, userIds.get(i));
            }

            ResultSet resultSet = statement.executeQuery();
            // Process results

        } catch (SQLException e) {
            LambdaErrorHandler.handleRdsError(e, logger);
        }
    }
}
```

### 5. Monitoring and Logging

```java
import com.amazonaws.services.cloudwatch.AmazonCloudWatch;
import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder;
import com.amazonaws.services.cloudwatch.model.MetricDatum;
import com.amazonaws.services.cloudwatch.model.PutMetricDataRequest;
import java.util.Collections;
import java.util.Date;

public class RdsMetricsPublisher {

    private static final String NAMESPACE = "RDS/Lambda";
    private AmazonCloudWatch cloudWatch;

    public RdsMetricsPublisher() {
        this.cloudWatch = AmazonCloudWatchClientBuilder.defaultClient();
    }

    public void publishMetric(String metricName, double value) {
        MetricDatum datum = new MetricDatum()
                .withMetricName(metricName)
                .withUnit("Count")
                .withValue(value)
                .withTimestamp(new Date());

        PutMetricDataRequest request = new PutMetricDataRequest()
                .withNamespace(NAMESPACE)
                .withMetricData(Collections.singletonList(datum));

        cloudWatch.putMetricData(request);
    }
}
```
@@ -0,0 +1,325 @@

# Spring Boot Integration with AWS RDS

## Configuration

### application.properties
```properties
# AWS Configuration
aws.region=us-east-1
aws.rds.instance-identifier=mydb-instance

# RDS Connection (from RDS endpoint)
spring.datasource.url=jdbc:postgresql://mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydatabase
spring.datasource.username=admin
spring.datasource.password=${DB_PASSWORD}
spring.datasource.driver-class-name=org.postgresql.Driver

# JPA Configuration
spring.jpa.hibernate.ddl-auto=validate
spring.jpa.show-sql=false
spring.jpa.properties.hibernate.dialect=org.hibernate.dialect.PostgreSQLDialect

# Connection Pool Configuration
spring.datasource.hikari.maximum-pool-size=10
spring.datasource.hikari.minimum-idle=5
spring.datasource.hikari.idle-timeout=30000
spring.datasource.hikari.connection-timeout=20000
```

### AWS Configuration
```java
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.rds.RdsClient;

@Configuration
public class AwsRdsConfig {

    @Value("${aws.region}")
    private String awsRegion;

    @Bean
    public RdsClient rdsClient() {
        return RdsClient.builder()
                .region(Region.of(awsRegion))
                .build();
    }
}
```

### Service Layer
```java
import org.springframework.stereotype.Service;
import software.amazon.awssdk.services.rds.RdsClient;
import software.amazon.awssdk.services.rds.model.*;
import java.util.List;

@Service
public class RdsService {

    private final RdsClient rdsClient;

    public RdsService(RdsClient rdsClient) {
        this.rdsClient = rdsClient;
    }

    public List<DBInstance> listInstances() {
        DescribeDbInstancesResponse response = rdsClient.describeDBInstances();
        return response.dbInstances();
    }

    public DBInstance getInstanceDetails(String instanceId) {
        DescribeDbInstancesRequest request = DescribeDbInstancesRequest.builder()
                .dbInstanceIdentifier(instanceId)
                .build();

        DescribeDbInstancesResponse response = rdsClient.describeDBInstances(request);
        return response.dbInstances().get(0);
    }

    public String createSnapshot(String instanceId, String snapshotId) {
        CreateDbSnapshotRequest request = CreateDbSnapshotRequest.builder()
                .dbInstanceIdentifier(instanceId)
                .dbSnapshotIdentifier(snapshotId)
                .build();

        CreateDbSnapshotResponse response = rdsClient.createDBSnapshot(request);
        return response.dbSnapshot().dbSnapshotArn();
    }

    public void modifyInstance(String instanceId, String newInstanceClass) {
        ModifyDbInstanceRequest request = ModifyDbInstanceRequest.builder()
                .dbInstanceIdentifier(instanceId)
                .dbInstanceClass(newInstanceClass)
                .applyImmediately(true)
                .build();

        rdsClient.modifyDBInstance(request);
    }
}
```
|
||||
### REST Controller
|
||||
```java
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import software.amazon.awssdk.services.rds.model.DBInstance;
|
||||
import java.util.List;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/rds")
|
||||
public class RdsController {
|
||||
|
||||
private final RdsService rdsService;
|
||||
|
||||
public RdsController(RdsService rdsService) {
|
||||
this.rdsService = rdsService;
|
||||
}
|
||||
|
||||
@GetMapping("/instances")
|
||||
public ResponseEntity<List<DBInstance>> listInstances() {
|
||||
return ResponseEntity.ok(rdsService.listInstances());
|
||||
}
|
||||
|
||||
@GetMapping("/instances/{id}")
|
||||
public ResponseEntity<DBInstance> getInstanceDetails(@PathVariable String id) {
|
||||
return ResponseEntity.ok(rdsService.getInstanceDetails(id));
|
||||
}
|
||||
|
||||
@PostMapping("/snapshots")
|
||||
public ResponseEntity<String> createSnapshot(
|
||||
@RequestParam String instanceId,
|
||||
@RequestParam String snapshotId) {
|
||||
String arn = rdsService.createSnapshot(instanceId, snapshotId);
|
||||
return ResponseEntity.ok(arn);
|
||||
}
|
||||
|
||||
@PutMapping("/instances/{id}")
|
||||
public ResponseEntity<String> modifyInstance(
|
||||
@PathVariable String id,
|
||||
@RequestParam String instanceClass) {
|
||||
rdsService.modifyInstance(id, instanceClass);
|
||||
return ResponseEntity.ok("Instance modified successfully");
|
||||
}
|
||||
}
|
||||
```

### Exception Handling

```java
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.bind.annotation.RestControllerAdvice;
import software.amazon.awssdk.services.rds.model.RdsException;

@RestControllerAdvice
public class RdsExceptionHandler {

    @ExceptionHandler(RdsException.class)
    @ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR)
    public ErrorResponse handleRdsException(RdsException e) {
        return new ErrorResponse(
                "RDS_ERROR",
                e.getMessage(),
                e.awsErrorDetails().errorCode()
        );
    }

    @ExceptionHandler(Exception.class)
    @ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR)
    public ErrorResponse handleGenericException(Exception e) {
        return new ErrorResponse(
                "INTERNAL_ERROR",
                e.getMessage(),
                null
        );
    }
}

class ErrorResponse {
    private String code;
    private String message;
    private String details;

    ErrorResponse(String code, String message, String details) {
        this.code = code;
        this.message = message;
        this.details = details;
    }

    // Getters and setters
}
```

## Testing

### Unit Tests

```java
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.rds.RdsClient;
import software.amazon.awssdk.services.rds.model.*;
import java.util.List;

import static org.mockito.Mockito.*;
import static org.junit.jupiter.api.Assertions.*;

@ExtendWith(MockitoExtension.class)
class RdsServiceTest {

    @Mock
    private RdsClient rdsClient;

    @Test
    void listInstances_shouldReturnInstances() {
        // Arrange
        DescribeDbInstancesResponse response = DescribeDbInstancesResponse.builder()
                .dbInstances(List.of(createTestInstance()))
                .build();

        when(rdsClient.describeDBInstances()).thenReturn(response);

        RdsService service = new RdsService(rdsClient);

        // Act
        List<DBInstance> result = service.listInstances();

        // Assert
        assertEquals(1, result.size());
        verify(rdsClient).describeDBInstances();
    }

    private DBInstance createTestInstance() {
        return DBInstance.builder()
                .dbInstanceIdentifier("test-instance")
                .engine("postgres")
                .dbInstanceStatus("available")
                .build();
    }
}
```

### Integration Tests

```java
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.ActiveProfiles;
import software.amazon.awssdk.services.rds.model.DBInstance;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

@SpringBootTest
@ActiveProfiles("test")
class RdsServiceIntegrationTest {

    @Autowired
    private RdsService rdsService;

    @Test
    void listInstances_integrationTest() {
        // This test requires actual AWS credentials and RDS instances.
        // It should only run with proper test configuration.
        assumeTrue(false, "Integration test disabled");

        List<DBInstance> instances = rdsService.listInstances();
        assertNotNull(instances);
    }
}
```

## Best Practices

### 1. Configuration Management

- Use Spring profiles for different environments (see the sketch below)
- Externalize sensitive configuration (passwords, keys)
- Use Spring Cloud Config for multi-environment management
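
A sketch of matching profile files: `aws.region` is the property injected in the configuration class above; everything else is illustrative:

```properties
# application-dev.properties
aws.region=us-east-1

# application-prod.properties -- resolve from the environment rather than hardcoding
aws.region=${AWS_REGION}
```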

### 2. Connection Pooling

Note that HikariCP pools JDBC connections to the database itself; it is separate from the `RdsClient`, which only calls the RDS management API.

```properties
# HikariCP Configuration
spring.datasource.hikari.maximum-pool-size=20
spring.datasource.hikari.minimum-idle=10
spring.datasource.hikari.idle-timeout=600000
spring.datasource.hikari.connection-timeout=30000
spring.datasource.hikari.connection-test-query=SELECT 1
```

### 3. Retry Logic

```java
import org.springframework.retry.annotation.Retryable;
import org.springframework.retry.annotation.Backoff;
import org.springframework.stereotype.Service;
import software.amazon.awssdk.services.rds.RdsClient;
import software.amazon.awssdk.services.rds.model.DBInstance;
import software.amazon.awssdk.services.rds.model.RdsException;
import java.util.List;

@Service
public class RdsServiceWithRetry {

    private final RdsClient rdsClient;

    public RdsServiceWithRetry(RdsClient rdsClient) {
        this.rdsClient = rdsClient;
    }

    @Retryable(value = { RdsException.class },
               maxAttempts = 3,
               backoff = @Backoff(delay = 1000))
    public List<DBInstance> listInstancesWithRetry() {
        return rdsClient.describeDBInstances().dbInstances();
    }
}
```
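
Note that `@Retryable` is inert until Spring Retry is switched on. A minimal sketch of the required configuration, assuming the `spring-retry` and `spring-aspects` dependencies are on the classpath:

```java
import org.springframework.context.annotation.Configuration;
import org.springframework.retry.annotation.EnableRetry;

// Activates the proxying that makes @Retryable methods actually retry
@Configuration
@EnableRetry
public class RetryConfig {
}
```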

### 4. Monitoring

```java
import org.springframework.boot.actuate.health.Health;
import org.springframework.boot.actuate.health.HealthIndicator;
import org.springframework.stereotype.Component;
import software.amazon.awssdk.services.rds.RdsClient;

@Component
public class RdsHealthIndicator implements HealthIndicator {

    private final RdsClient rdsClient;

    public RdsHealthIndicator(RdsClient rdsClient) {
        this.rdsClient = rdsClient;
    }

    @Override
    public Health health() {
        try {
            rdsClient.describeDBInstances();
            return Health.up()
                    .withDetail("service", "RDS")
                    .build();
        } catch (Exception e) {
            return Health.down()
                    .withDetail("error", e.getMessage())
                    .build();
        }
    }
}
```
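
The indicator is served by the Actuator health endpoint once details are exposed; the standard Actuator properties (values illustrative):

```properties
management.endpoints.web.exposure.include=health
management.endpoint.health.show-details=always
```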

691
skills/aws-java/aws-sdk-java-v2-s3/SKILL.md
Normal file
@@ -0,0 +1,691 @@
---
name: aws-sdk-java-v2-s3
description: Amazon S3 patterns and examples using AWS SDK for Java 2.x. Use when working with S3 buckets, uploading/downloading objects, multipart uploads, presigned URLs, S3 Transfer Manager, object operations, or S3-specific configurations.
category: aws
tags: [aws, s3, java, sdk, storage, objects, transfer-manager, presigned-urls]
version: 1.1.0
allowed-tools: Read, Write, Bash
---

# AWS SDK for Java 2.x - Amazon S3

## When to Use

Use this skill when:
- Creating, listing, or deleting S3 buckets with proper configuration
- Uploading or downloading objects from S3 with metadata and encryption
- Working with multipart uploads for large files (>100MB) with error handling
- Generating presigned URLs for temporary access to S3 objects
- Copying or moving objects between S3 buckets with metadata preservation
- Setting object metadata, storage classes, and access controls
- Implementing S3 Transfer Manager for optimized file transfers
- Integrating S3 with Spring Boot applications for cloud storage
- Setting up S3 event notifications for object lifecycle management
- Managing bucket policies, CORS configuration, and access controls
- Implementing retry mechanisms and error handling for S3 operations
- Testing S3 integrations with LocalStack for development environments

## Dependencies

```xml
<!-- Use the latest stable versions of these artifacts -->
<dependency>
    <groupId>software.amazon.awssdk</groupId>
    <artifactId>s3</artifactId>
    <version>2.20.0</version>
</dependency>

<!-- For S3 Transfer Manager -->
<dependency>
    <groupId>software.amazon.awssdk</groupId>
    <artifactId>s3-transfer-manager</artifactId>
    <version>2.20.0</version>
</dependency>

<!-- For async operations -->
<dependency>
    <groupId>software.amazon.awssdk</groupId>
    <artifactId>netty-nio-client</artifactId>
    <version>2.20.0</version>
</dependency>
```
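
To avoid repeating the version on every artifact, the SDK's Maven BOM can manage versions centrally; a minimal sketch (version illustrative):

```xml
<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>software.amazon.awssdk</groupId>
            <artifactId>bom</artifactId>
            <version>2.20.0</version> <!-- use the latest stable version -->
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>
```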

## Client Setup

### Basic Synchronous Client

```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;

S3Client s3Client = S3Client.builder()
        .region(Region.US_EAST_1)
        .build();
```

### Basic Asynchronous Client

```java
import software.amazon.awssdk.services.s3.S3AsyncClient;

S3AsyncClient s3AsyncClient = S3AsyncClient.builder()
        .region(Region.US_EAST_1)
        .build();
```

### Configured Client with Retry Logic

```java
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.core.retry.backoff.FullJitterBackoffStrategy;
import java.time.Duration;

S3Client s3Client = S3Client.builder()
        .region(Region.US_EAST_1)
        .httpClientBuilder(ApacheHttpClient.builder()
                .maxConnections(200)
                .connectionTimeout(Duration.ofSeconds(5)))
        .overrideConfiguration(b -> b
                .apiCallTimeout(Duration.ofSeconds(60))
                .apiCallAttemptTimeout(Duration.ofSeconds(30))
                .retryPolicy(RetryPolicy.builder()
                        .numRetries(3)
                        .backoffStrategy(FullJitterBackoffStrategy.builder()
                                .baseDelay(Duration.ofSeconds(1))
                                .maxBackoffTime(Duration.ofSeconds(30))
                                .build())
                        .build()))
        .build();
```

## Basic Bucket Operations

### Create Bucket

```java
import software.amazon.awssdk.services.s3.model.*;

public void createBucket(S3Client s3Client, String bucketName) {
    try {
        CreateBucketRequest request = CreateBucketRequest.builder()
                .bucket(bucketName)
                .build();

        s3Client.createBucket(request);

        // Wait until bucket is ready
        HeadBucketRequest waitRequest = HeadBucketRequest.builder()
                .bucket(bucketName)
                .build();

        s3Client.waiter().waitUntilBucketExists(waitRequest);
        System.out.println("Bucket created successfully: " + bucketName);

    } catch (S3Exception e) {
        System.err.println("Error creating bucket: " + e.awsErrorDetails().errorMessage());
        throw e;
    }
}
```

### List All Buckets

```java
import java.util.List;
import java.util.stream.Collectors;

public List<String> listAllBuckets(S3Client s3Client) {
    ListBucketsResponse response = s3Client.listBuckets();

    return response.buckets().stream()
            .map(Bucket::name)
            .collect(Collectors.toList());
}
```

### Check if Bucket Exists

```java
public boolean bucketExists(S3Client s3Client, String bucketName) {
    try {
        HeadBucketRequest request = HeadBucketRequest.builder()
                .bucket(bucketName)
                .build();

        s3Client.headBucket(request);
        return true;

    } catch (NoSuchBucketException e) {
        return false;
    }
}
```

## Basic Object Operations

### Upload File to S3

```java
import software.amazon.awssdk.core.sync.RequestBody;
import java.nio.file.Paths;

public void uploadFile(S3Client s3Client, String bucketName, String key, String filePath) {
    PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build();

    s3Client.putObject(request, RequestBody.fromFile(Paths.get(filePath)));
    System.out.println("File uploaded: " + key);
}
```

### Download File from S3

```java
import java.nio.file.Paths;

public void downloadFile(S3Client s3Client, String bucketName, String key, String destPath) {
    GetObjectRequest request = GetObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build();

    s3Client.getObject(request, Paths.get(destPath));
    System.out.println("File downloaded: " + destPath);
}
```

### Get Object Metadata

```java
import java.util.Map;

public Map<String, String> getObjectMetadata(S3Client s3Client, String bucketName, String key) {
    HeadObjectRequest request = HeadObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build();

    HeadObjectResponse response = s3Client.headObject(request);
    return response.metadata();
}
```

## Advanced Object Operations

### Upload with Metadata and Encryption

```java
public void uploadWithMetadata(S3Client s3Client, String bucketName, String key,
                               String filePath, Map<String, String> metadata) {
    PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .metadata(metadata)
            .contentType("application/pdf")
            .serverSideEncryption(ServerSideEncryption.AES256)
            .storageClass(StorageClass.STANDARD_IA)
            .build();

    PutObjectResponse response = s3Client.putObject(request,
            RequestBody.fromFile(Paths.get(filePath)));

    System.out.println("Upload completed. ETag: " + response.eTag());
}
```

### Copy Object Between Buckets

```java
public void copyObject(S3Client s3Client, String sourceBucket, String sourceKey,
                       String destBucket, String destKey) {
    CopyObjectRequest request = CopyObjectRequest.builder()
            .sourceBucket(sourceBucket)
            .sourceKey(sourceKey)
            .destinationBucket(destBucket)
            .destinationKey(destKey)
            .build();

    s3Client.copyObject(request);
    System.out.println("Object copied: " + sourceKey + " -> " + destKey);
}
```

### Delete Multiple Objects

```java
public void deleteMultipleObjects(S3Client s3Client, String bucketName, List<String> keys) {
    List<ObjectIdentifier> objectIds = keys.stream()
            .map(key -> ObjectIdentifier.builder().key(key).build())
            .collect(Collectors.toList());

    Delete delete = Delete.builder()
            .objects(objectIds)
            .build();

    DeleteObjectsRequest request = DeleteObjectsRequest.builder()
            .bucket(bucketName)
            .delete(delete)
            .build();

    DeleteObjectsResponse response = s3Client.deleteObjects(request);

    response.deleted().forEach(deleted ->
            System.out.println("Deleted: " + deleted.key()));

    response.errors().forEach(error ->
            System.err.println("Failed to delete " + error.key() + ": " + error.message()));
}
```

## Presigned URLs

### Generate Download URL

```java
import software.amazon.awssdk.services.s3.presigner.S3Presigner;
import software.amazon.awssdk.services.s3.presigner.model.*;
import java.time.Duration;

public String generateDownloadUrl(String bucketName, String key) {
    try (S3Presigner presigner = S3Presigner.builder()
            .region(Region.US_EAST_1)
            .build()) {

        GetObjectRequest getObjectRequest = GetObjectRequest.builder()
                .bucket(bucketName)
                .key(key)
                .build();

        GetObjectPresignRequest presignRequest = GetObjectPresignRequest.builder()
                .signatureDuration(Duration.ofMinutes(10))
                .getObjectRequest(getObjectRequest)
                .build();

        PresignedGetObjectRequest presignedRequest = presigner.presignGetObject(presignRequest);

        return presignedRequest.url().toString();
    }
}
```

### Generate Upload URL

```java
public String generateUploadUrl(String bucketName, String key) {
    try (S3Presigner presigner = S3Presigner.create()) {

        PutObjectRequest putObjectRequest = PutObjectRequest.builder()
                .bucket(bucketName)
                .key(key)
                .build();

        PutObjectPresignRequest presignRequest = PutObjectPresignRequest.builder()
                .signatureDuration(Duration.ofMinutes(5))
                .putObjectRequest(putObjectRequest)
                .build();

        PresignedPutObjectRequest presignedRequest = presigner.presignPutObject(presignRequest);

        return presignedRequest.url().toString();
    }
}
```
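
The caller can then upload straight to the URL with any HTTP client and no AWS credentials; a minimal sketch using `java.net.http` (file name illustrative):

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.file.Path;

public int uploadViaPresignedUrl(String presignedUrl, Path file) throws Exception {
    HttpClient http = HttpClient.newHttpClient();
    HttpRequest put = HttpRequest.newBuilder()
            .uri(URI.create(presignedUrl))
            .PUT(HttpRequest.BodyPublishers.ofFile(file)) // raw PUT of the file bytes
            .build();

    // 200 indicates S3 accepted the object
    HttpResponse<Void> response = http.send(put, HttpResponse.BodyHandlers.discarding());
    return response.statusCode();
}
```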

## S3 Transfer Manager

### Upload with Transfer Manager

```java
import software.amazon.awssdk.transfer.s3.*;
import software.amazon.awssdk.transfer.s3.model.*;
import software.amazon.awssdk.transfer.s3.progress.LoggingTransferListener;

public void uploadWithTransferManager(String bucketName, String key, String filePath) {
    try (S3TransferManager transferManager = S3TransferManager.create()) {

        UploadFileRequest uploadRequest = UploadFileRequest.builder()
                .putObjectRequest(req -> req
                        .bucket(bucketName)
                        .key(key))
                .source(Paths.get(filePath))
                // Logs transfer progress; implement TransferListener for custom reporting
                .addTransferListener(LoggingTransferListener.create())
                .build();

        FileUpload upload = transferManager.uploadFile(uploadRequest);

        CompletedFileUpload result = upload.completionFuture().join();

        System.out.println("Upload complete. ETag: " + result.response().eTag());
    }
}
```

### Download with Transfer Manager

```java
public void downloadWithTransferManager(String bucketName, String key, String destPath) {
    try (S3TransferManager transferManager = S3TransferManager.create()) {

        DownloadFileRequest downloadRequest = DownloadFileRequest.builder()
                .getObjectRequest(req -> req
                        .bucket(bucketName)
                        .key(key))
                .destination(Paths.get(destPath))
                .build();

        FileDownload download = transferManager.downloadFile(downloadRequest);

        CompletedFileDownload result = download.completionFuture().join();

        System.out.println("Download complete. Size: " + result.response().contentLength());
    }
}
```

## Spring Boot Integration

### Configuration Properties

```java
import org.springframework.boot.context.properties.ConfigurationProperties;

@ConfigurationProperties(prefix = "aws.s3")
public class S3Properties {
    private String accessKey;
    private String secretKey;
    private String region = "us-east-1";
    private String endpoint;
    private String defaultBucket;
    private boolean asyncEnabled = false;
    private boolean transferManagerEnabled = true;

    // Getters and setters
    public String getAccessKey() { return accessKey; }
    public void setAccessKey(String accessKey) { this.accessKey = accessKey; }
    // ... other getters and setters
}
```

### S3 Configuration Class

```java
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.transfer.s3.S3TransferManager;
import java.net.URI;

@Configuration
public class S3Configuration {

    private final S3Properties properties;

    public S3Configuration(S3Properties properties) {
        this.properties = properties;
    }

    @Bean
    public S3Client s3Client() {
        S3Client.Builder builder = S3Client.builder()
                .region(Region.of(properties.getRegion()));

        if (properties.getAccessKey() != null && properties.getSecretKey() != null) {
            builder.credentialsProvider(StaticCredentialsProvider.create(
                    AwsBasicCredentials.create(
                            properties.getAccessKey(),
                            properties.getSecretKey())));
        }

        if (properties.getEndpoint() != null) {
            builder.endpointOverride(URI.create(properties.getEndpoint()));
        }

        return builder.build();
    }

    @Bean
    public S3AsyncClient s3AsyncClient() {
        S3AsyncClient.Builder builder = S3AsyncClient.builder()
                .region(Region.of(properties.getRegion()));

        if (properties.getAccessKey() != null && properties.getSecretKey() != null) {
            builder.credentialsProvider(StaticCredentialsProvider.create(
                    AwsBasicCredentials.create(
                            properties.getAccessKey(),
                            properties.getSecretKey())));
        }

        if (properties.getEndpoint() != null) {
            builder.endpointOverride(URI.create(properties.getEndpoint()));
        }

        return builder.build();
    }

    @Bean
    public S3TransferManager s3TransferManager() {
        // The Transfer Manager is built on the asynchronous client
        return S3TransferManager.builder()
                .s3Client(s3AsyncClient())
                .build();
    }
}
```

### S3 Service

```java
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;
import software.amazon.awssdk.services.s3.presigner.S3Presigner;
import software.amazon.awssdk.services.s3.presigner.model.GetObjectPresignRequest;
import software.amazon.awssdk.transfer.s3.S3TransferManager;
import java.io.IOException;
import java.nio.file.*;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.CompletableFuture;

@Service
@RequiredArgsConstructor
public class S3Service {

    private final S3Client s3Client;
    private final S3AsyncClient s3AsyncClient;
    private final S3TransferManager transferManager;
    private final S3Properties properties;

    public CompletableFuture<Void> uploadFileAsync(String key, Path file) {
        PutObjectRequest request = PutObjectRequest.builder()
                .bucket(properties.getDefaultBucket())
                .key(key)
                .build();

        return CompletableFuture.runAsync(() -> {
            s3Client.putObject(request, RequestBody.fromFile(file));
        });
    }

    public CompletableFuture<byte[]> downloadFileAsync(String key) {
        GetObjectRequest request = GetObjectRequest.builder()
                .bucket(properties.getDefaultBucket())
                .key(key)
                .build();

        return CompletableFuture.supplyAsync(() -> {
            try (ResponseInputStream<GetObjectResponse> response = s3Client.getObject(request)) {
                return response.readAllBytes();
            } catch (IOException e) {
                throw new RuntimeException("Failed to read S3 object", e);
            }
        });
    }

    public CompletableFuture<String> generatePresignedUrl(String key, Duration duration) {
        return CompletableFuture.supplyAsync(() -> {
            try (S3Presigner presigner = S3Presigner.builder()
                    .region(Region.of(properties.getRegion()))
                    .build()) {

                GetObjectRequest getRequest = GetObjectRequest.builder()
                        .bucket(properties.getDefaultBucket())
                        .key(key)
                        .build();

                GetObjectPresignRequest presignRequest = GetObjectPresignRequest.builder()
                        .signatureDuration(duration)
                        .getObjectRequest(getRequest)
                        .build();

                return presigner.presignGetObject(presignRequest).url().toString();
            }
        });
    }

    public Flux<S3Object> listObjects(String prefix) {
        ListObjectsV2Request request = ListObjectsV2Request.builder()
                .bucket(properties.getDefaultBucket())
                .prefix(prefix)
                .build();

        return Flux.create(sink -> {
            s3Client.listObjectsV2Paginator(request)
                    .contents()
                    .forEach(sink::next);
            sink.complete();
        });
    }
}
```

## Examples

### Basic File Upload Example

```java
public class S3UploadExample {
    public static void main(String[] args) {
        // Initialize client
        S3Client s3Client = S3Client.builder()
                .region(Region.US_EAST_1)
                .build();

        String bucketName = "my-example-bucket";
        String filePath = "document.pdf";
        String key = "uploads/document.pdf";

        // Create bucket if it doesn't exist
        // (bucketExists, createBucket, uploadWithMetadata, and generateDownloadUrl
        //  are the helper methods defined in the sections above)
        if (!bucketExists(s3Client, bucketName)) {
            createBucket(s3Client, bucketName);
        }

        // Upload file
        Map<String, String> metadata = Map.of(
                "author", "John Doe",
                "content-type", "application/pdf",
                "upload-date", java.time.LocalDate.now().toString()
        );

        uploadWithMetadata(s3Client, bucketName, key, filePath, metadata);

        // Generate presigned URL
        String downloadUrl = generateDownloadUrl(bucketName, key);
        System.out.println("Download URL: " + downloadUrl);

        // Close client
        s3Client.close();
    }
}
```

### Batch File Processing Example

```java
import software.amazon.awssdk.services.s3.S3Client;
import java.io.IOException;
import java.nio.file.*;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.*;

public class S3BatchProcessing {
    public void processDirectoryUpload(S3Client s3Client, String bucketName, String directoryPath) {
        try (Stream<Path> paths = Files.walk(Paths.get(directoryPath))) {
            List<CompletableFuture<Void>> futures = paths
                    .filter(Files::isRegularFile)
                    .map(path -> {
                        // The object key is relative to the bucket; it should not repeat the bucket name
                        String key = path.getFileName().toString();
                        return CompletableFuture.runAsync(() -> {
                            uploadFile(s3Client, bucketName, key, path.toString());
                        });
                    })
                    .collect(Collectors.toList());

            // Wait for all uploads to complete
            CompletableFuture.allOf(
                    futures.toArray(new CompletableFuture[0])
            ).join();

            System.out.println("All files uploaded successfully");
        } catch (IOException e) {
            throw new RuntimeException("Failed to process directory", e);
        }
    }
}
```

## Best Practices

### Performance Optimization

1. **Use S3 Transfer Manager**: Automatically handles multipart uploads, parallel transfers, and progress tracking for files >100MB
2. **Reuse S3 clients**: Clients are thread-safe and should be reused throughout the application lifecycle
3. **Enable async operations**: Use S3AsyncClient for I/O-bound operations to improve throughput
4. **Configure proper timeouts**: Set appropriate timeouts for large file operations
5. **Use connection pooling**: Configure the HTTP client for optimal connection management

### Security Considerations

1. **Use temporary credentials**: Always use IAM roles or AWS STS for short-lived access tokens
2. **Enable server-side encryption**: Use AES-256 or AWS KMS for sensitive data
3. **Implement access controls**: Use bucket policies and IAM roles instead of access keys in production (see the sketch below)
4. **Validate object metadata**: Sanitize user-provided metadata to prevent header injection
5. **Use presigned URLs**: Avoid exposing credentials by using temporary access URLs
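
For items 1 and 3, the default provider chain picks up IAM roles, environment variables, and profiles without embedding keys; a minimal sketch:

```java
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;

// Resolves credentials from env vars, system properties, profile files,
// or the IAM role attached to the instance/container, in that order
S3Client s3 = S3Client.builder()
        .region(Region.US_EAST_1)
        .credentialsProvider(DefaultCredentialsProvider.create())
        .build();
```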

### Error Handling

1. **Implement retry logic**: Network operations should have exponential backoff retry strategies
2. **Handle throttling**: Implement proper handling of throttling responses (HTTP 503 Slow Down)
3. **Validate object existence**: Check if objects exist before operations that require them
4. **Clean up failed operations**: Abort multipart uploads that fail (see the sketch below)
5. **Log appropriately**: Log successful operations and errors for monitoring
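
For item 4, stale multipart uploads keep billable parts around until they are aborted; a minimal cleanup sketch (a real job would filter by initiation age):

```java
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;

public void abortStaleMultipartUploads(S3Client s3Client, String bucketName) {
    ListMultipartUploadsResponse uploads = s3Client.listMultipartUploads(
            ListMultipartUploadsRequest.builder().bucket(bucketName).build());

    for (MultipartUpload upload : uploads.uploads()) {
        // Abort each in-progress upload so its parts stop accruing storage charges
        s3Client.abortMultipartUpload(AbortMultipartUploadRequest.builder()
                .bucket(bucketName)
                .key(upload.key())
                .uploadId(upload.uploadId())
                .build());
    }
}
```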

### Cost Optimization

1. **Use appropriate storage classes**: Choose STANDARD, STANDARD_IA, or INTELLIGENT_TIERING based on access patterns
2. **Implement lifecycle policies**: Automatically transition or expire objects (see the sketch below)
3. **Enable object versioning**: For important data that needs retention
4. **Monitor usage**: Track data transfer and storage costs
5. **Minimize API calls**: Use batch operations when possible
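
For item 2, lifecycle rules can be set from the SDK as well as the console; a minimal sketch that transitions objects under a prefix to STANDARD_IA after 30 days and expires them after 365 (prefix and durations illustrative):

```java
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;

public void applyLifecyclePolicy(S3Client s3Client, String bucketName) {
    LifecycleRule rule = LifecycleRule.builder()
            .id("archive-then-expire")
            .filter(f -> f.prefix("logs/"))
            .status(ExpirationStatus.ENABLED)
            .transitions(Transition.builder()
                    .days(30)
                    .storageClass(TransitionStorageClass.STANDARD_IA)
                    .build())
            .expiration(LifecycleExpiration.builder().days(365).build())
            .build();

    s3Client.putBucketLifecycleConfiguration(PutBucketLifecycleConfigurationRequest.builder()
            .bucket(bucketName)
            .lifecycleConfiguration(BucketLifecycleConfiguration.builder()
                    .rules(rule)
                    .build())
            .build());
}
```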

## Constraints and Limitations

- **File size limits**: Single PUT operations are limited to 5GB; use multipart uploads for larger files (see the sketch below)
- **Batch operations**: Maximum 1000 objects per DeleteObjects operation
- **Metadata size**: User-defined metadata is limited to 2KB
- **Concurrent transfers**: Transfer Manager handles up to 100 concurrent transfers by default
- **Region consistency**: Cross-region operations may incur additional costs and latency
- **Consistency model**: S3 now provides strong read-after-write consistency, so newly uploaded objects are immediately visible
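
For files above the single-PUT limit, the Transfer Manager shown earlier is the easiest path; when you need direct control, the low-level multipart API looks roughly like this (a condensed sketch; error handling and part-size tuning omitted):

```java
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;
import java.io.RandomAccessFile;
import java.util.ArrayList;
import java.util.List;

public void multipartUpload(S3Client s3Client, String bucketName, String key, String filePath)
        throws Exception {
    // 1. Start the upload and remember its id
    String uploadId = s3Client.createMultipartUpload(
            CreateMultipartUploadRequest.builder().bucket(bucketName).key(key).build()).uploadId();

    List<CompletedPart> parts = new ArrayList<>();
    long partSize = 8 * 1024 * 1024; // 8 MB parts (the minimum is 5 MB except for the last part)

    try (RandomAccessFile file = new RandomAccessFile(filePath, "r")) {
        long fileSize = file.length();
        int partNumber = 1;
        for (long offset = 0; offset < fileSize; offset += partSize, partNumber++) {
            int size = (int) Math.min(partSize, fileSize - offset);
            byte[] buffer = new byte[size];
            file.seek(offset);
            file.readFully(buffer);

            // 2. Upload each part and collect its ETag
            UploadPartResponse part = s3Client.uploadPart(
                    UploadPartRequest.builder()
                            .bucket(bucketName).key(key)
                            .uploadId(uploadId).partNumber(partNumber)
                            .build(),
                    RequestBody.fromBytes(buffer));
            parts.add(CompletedPart.builder().partNumber(partNumber).eTag(part.eTag()).build());
        }
    }

    // 3. Stitch the parts together into the final object
    s3Client.completeMultipartUpload(CompleteMultipartUploadRequest.builder()
            .bucket(bucketName).key(key).uploadId(uploadId)
            .multipartUpload(CompletedMultipartUpload.builder().parts(parts).build())
            .build());
}
```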

## References

For more detailed information, see:
- [AWS S3 Object Operations Reference](./references/s3-object-operations.md)
- [S3 Transfer Manager Patterns](./references/s3-transfer-patterns.md)
- [Spring Boot Integration Guide](./references/s3-spring-boot-integration.md)
- [AWS S3 Developer Guide](https://docs.aws.amazon.com/AmazonS3/latest/userguide/)
- [AWS SDK for Java 2.x S3 API](https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/s3/package-summary.html)

## Related Skills

- `aws-sdk-java-v2-core` - Core AWS SDK patterns and configuration
- `spring-boot-dependency-injection` - Spring dependency injection patterns
- `unit-test-service-layer` - Testing service layer patterns
- `unit-test-wiremock-rest-api` - Testing external API integrations

@@ -0,0 +1,371 @@
# S3 Object Operations Reference

## Detailed Object Operations

### Advanced Upload Patterns

#### Streaming Upload with Progress Monitoring

```java
public void uploadWithProgress(S3Client s3Client, String bucketName, String key,
                               String filePath) {
    PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build();

    // RequestBody is not closeable; the SDK manages the underlying stream.
    // The plain synchronous client exposes no progress callbacks; for real
    // progress monitoring use the S3 Transfer Manager with a TransferListener.
    s3Client.putObject(request, RequestBody.fromFile(Paths.get(filePath)));
}
```

#### Conditional Upload

```java
// Conditional writes on PutObject (ifMatch/ifNoneMatch) require a recent SDK
// release; older versions only support conditions on GET and COPY requests.
public void conditionalUpload(S3Client s3Client, String bucketName, String key,
                              String filePath, String expectedETag) {
    PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .ifMatch(expectedETag)
            .build();

    s3Client.putObject(request, RequestBody.fromFile(Paths.get(filePath)));
}
```

### Advanced Download Patterns

#### Range Requests for Large Files

```java
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;

public void downloadInChunks(S3Client s3Client, String bucketName, String key,
                             String destPath, int chunkSizeMB) throws IOException {
    long fileSize = getFileSize(s3Client, bucketName, key);
    long chunkSize = chunkSizeMB * 1024L * 1024L;

    try (OutputStream os = new FileOutputStream(destPath)) {
        for (long start = 0; start < fileSize; start += chunkSize) {
            long end = Math.min(start + chunkSize - 1, fileSize - 1);

            GetObjectRequest request = GetObjectRequest.builder()
                    .bucket(bucketName)
                    .key(key)
                    .range("bytes=" + start + "-" + end)
                    .build();

            try (ResponseInputStream<GetObjectResponse> response =
                         s3Client.getObject(request)) {
                response.transferTo(os);
            }
        }
    }
}

private long getFileSize(S3Client s3Client, String bucketName, String key) {
    return s3Client.headObject(HeadObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build()).contentLength();
}
```

### Metadata Management

#### Setting and Retrieving Object Metadata

```java
public void setObjectMetadata(S3Client s3Client, String bucketName, String key,
                              Map<String, String> metadata) {
    // Note: this writes a new zero-byte object carrying the metadata. To change
    // the metadata of an existing object, copy it onto itself with
    // MetadataDirective.REPLACE (see "Server-Side Copy with Metadata" below).
    PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .metadata(metadata)
            .build();

    s3Client.putObject(request, RequestBody.empty());
}

public Map<String, String> getObjectMetadata(S3Client s3Client,
                                             String bucketName, String key) {
    HeadObjectRequest request = HeadObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build();

    HeadObjectResponse response = s3Client.headObject(request);
    return response.metadata();
}
```

### Storage Classes and Lifecycle

#### Managing Different Storage Classes

```java
public void uploadWithStorageClass(S3Client s3Client, String bucketName, String key,
                                   String filePath, StorageClass storageClass) {
    PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .storageClass(storageClass)
            .build();

    s3Client.putObject(request, RequestBody.fromFile(Paths.get(filePath)));
}

// Storage class options:
// STANDARD - Default storage class
// STANDARD_IA - Infrequent Access
// ONEZONE_IA - Single-zone infrequent access
// INTELLIGENT_TIERING - Automatically optimizes storage
// GLACIER - Archive storage
// DEEP_ARCHIVE - Long-term archive storage
```

### Object Tagging

#### Adding and Managing Tags

```java
public void addTags(S3Client s3Client, String bucketName, String key,
                    Map<String, String> tags) {
    Tagging tagging = Tagging.builder()
            .tagSet(tags.entrySet().stream()
                    .map(entry -> Tag.builder()
                            .key(entry.getKey())
                            .value(entry.getValue())
                            .build())
                    .collect(Collectors.toList()))
            .build();

    PutObjectTaggingRequest request = PutObjectTaggingRequest.builder()
            .bucket(bucketName)
            .key(key)
            .tagging(tagging)
            .build();

    s3Client.putObjectTagging(request);
}

public Map<String, String> getTags(S3Client s3Client, String bucketName, String key) {
    GetObjectTaggingRequest request = GetObjectTaggingRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build();

    GetObjectTaggingResponse response = s3Client.getObjectTagging(request);

    return response.tagSet().stream()
            .collect(Collectors.toMap(Tag::key, Tag::value));
}
```

### Advanced Copy Operations

#### Server-Side Copy with Metadata

```java
public void copyWithMetadata(S3Client s3Client, String sourceBucket, String sourceKey,
                             String destBucket, String destKey,
                             Map<String, String> metadata) {
    CopyObjectRequest request = CopyObjectRequest.builder()
            .sourceBucket(sourceBucket)
            .sourceKey(sourceKey)
            .destinationBucket(destBucket)
            .destinationKey(destKey)
            .metadata(metadata)
            .metadataDirective(MetadataDirective.REPLACE)
            .build();

    s3Client.copyObject(request);
}
```

## Error Handling Patterns

### Retry Mechanisms

```java
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.core.retry.backoff.FixedDelayBackoffStrategy;
import software.amazon.awssdk.core.retry.conditions.RetryCondition;
import java.time.Duration;

public S3Client createS3ClientWithRetry() {
    return S3Client.builder()
            .overrideConfiguration(b -> b
                    .retryPolicy(RetryPolicy.builder()
                            .numRetries(3)
                            .backoffStrategy(FixedDelayBackoffStrategy.create(
                                    Duration.ofSeconds(1)))
                            .retryCondition(RetryCondition.defaultRetryCondition())
                            .build()))
            .build();
}
```

### Throttling Handling

```java
public void handleThrottling(S3Client s3Client, String bucketName, String key) {
    try {
        PutObjectRequest request = PutObjectRequest.builder()
                .bucket(bucketName)
                .key(key)
                .build();

        s3Client.putObject(request, RequestBody.fromString("test"));

    } catch (S3Exception e) {
        // S3 signals throttling with 503 Slow Down (occasionally 429);
        // the SDK flags both via isThrottlingException()
        if (e.isThrottlingException()) {
            try {
                Thread.sleep(1000);
                // Retry logic here
            } catch (InterruptedException ie) {
                Thread.currentThread().interrupt();
            }
        }
        throw e;
    }
}
```
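
A fixed one-second sleep recovers poorly under sustained throttling; a minimal exponential-backoff loop looks like this (attempt counts and delays illustrative):

```java
public void putWithBackoff(S3Client s3Client, PutObjectRequest request, RequestBody body)
        throws InterruptedException {
    long delayMillis = 500;
    for (int attempt = 1; ; attempt++) {
        try {
            s3Client.putObject(request, body);
            return;
        } catch (S3Exception e) {
            if (!e.isThrottlingException() || attempt >= 5) {
                throw e; // non-retryable, or out of attempts
            }
            Thread.sleep(delayMillis);
            delayMillis *= 2; // 0.5s, 1s, 2s, 4s ...
        }
    }
}
```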

## Performance Optimization

### Batch Operations

#### Batch Delete Objects

```java
public void batchDeleteObjects(S3Client s3Client, String bucketName,
                               List<String> keys) {
    int batchSize = 1000; // S3 limit for batch operations
    int totalBatches = (int) Math.ceil((double) keys.size() / batchSize);

    for (int i = 0; i < totalBatches; i++) {
        List<String> batchKeys = keys.subList(
                i * batchSize,
                Math.min((i + 1) * batchSize, keys.size()));

        List<ObjectIdentifier> objectIdentifiers = batchKeys.stream()
                .map(key -> ObjectIdentifier.builder().key(key).build())
                .collect(Collectors.toList());

        Delete delete = Delete.builder()
                .objects(objectIdentifiers)
                .build();

        DeleteObjectsRequest request = DeleteObjectsRequest.builder()
                .bucket(bucketName)
                .delete(delete)
                .build();

        s3Client.deleteObjects(request);
    }
}
```

### Parallel Uploads

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;

public void parallelUploads(S3Client s3Client, String bucketName,
                            List<String> keys, ExecutorService executor) {
    List<CompletableFuture<Void>> futures = new ArrayList<>();

    for (String key : keys) {
        CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
            PutObjectRequest request = PutObjectRequest.builder()
                    .bucket(bucketName)
                    .key(key)
                    .build();

            s3Client.putObject(request, RequestBody.fromString("data"));
        }, executor);

        futures.add(future);
    }

    CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
}
```

## Security Considerations

### Access Control

#### Setting Object ACLs

```java
public void setObjectAcl(S3Client s3Client, String bucketName, String key,
                         ObjectCannedACL acl) {
    PutObjectAclRequest request = PutObjectAclRequest.builder()
            .bucket(bucketName)
            .key(key)
            .acl(acl)
            .build();

    s3Client.putObjectAcl(request);
}

// ObjectCannedACL options:
// PRIVATE, PUBLIC_READ, PUBLIC_READ_WRITE, AUTHENTICATED_READ,
// AWS_EXEC_READ, BUCKET_OWNER_READ, BUCKET_OWNER_FULL_CONTROL
```

#### Encryption

```java
public void encryptedUpload(S3Client s3Client, String bucketName, String key,
                            String filePath, String kmsKeyId) {
    PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .serverSideEncryption(ServerSideEncryption.AWS_KMS)
            .ssekmsKeyId(kmsKeyId)
            .build();

    s3Client.putObject(request, RequestBody.fromFile(Paths.get(filePath)));
}
```

## Monitoring and Logging

#### Upload Completion Events

```java
public void uploadWithMonitoring(S3Client s3Client, String bucketName, String key,
                                 String filePath) {
    PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build();

    // putObject returns the response directly; there is no Response<> wrapper
    PutObjectResponse response = s3Client.putObject(request,
            RequestBody.fromFile(Paths.get(filePath)));

    System.out.println("Upload completed with ETag: " + response.eTag());
}
```

## Integration Patterns

### Event Notifications

```java
public void setupEventNotifications(S3Client s3Client, String bucketName) {
    NotificationConfiguration configuration = NotificationConfiguration.builder()
            .topicConfigurations(TopicConfiguration.builder()
                    .topicArn("arn:aws:sns:us-east-1:123456789012:my-topic")
                    .events(Event.S3_OBJECT_CREATED_PUT, Event.S3_OBJECT_CREATED_POST)
                    .build())
            .build();

    PutBucketNotificationConfigurationRequest request =
            PutBucketNotificationConfigurationRequest.builder()
                    .bucket(bucketName)
                    .notificationConfiguration(configuration)
                    .build();

    s3Client.putBucketNotificationConfiguration(request);
}
```

@@ -0,0 +1,668 @@
# S3 Spring Boot Integration Reference

## Advanced Spring Boot Configuration

### Multi-Environment Configuration

```java
import lombok.Data;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.presigner.S3Presigner;
import software.amazon.awssdk.transfer.s3.S3TransferManager;
import java.net.URI;

@Configuration
@EnableConfigurationProperties(S3Properties.class)
public class S3Configuration {

    private final S3Properties properties;

    public S3Configuration(S3Properties properties) {
        this.properties = properties;
    }

    @Bean
    @ConditionalOnProperty(name = "s3.async-enabled", havingValue = "true")
    public S3AsyncClient s3AsyncClient() {
        S3AsyncClient.Builder builder = S3AsyncClient.builder()
                .region(Region.of(properties.getRegion()))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create(
                                properties.getAccessKey(),
                                properties.getSecretKey())));

        // Only override the endpoint when one is configured (e.g. LocalStack in dev)
        if (properties.getEndpoint() != null) {
            builder.endpointOverride(URI.create(properties.getEndpoint()));
        }
        return builder.build();
    }

    @Bean
    @ConditionalOnProperty(name = "s3.sync-enabled", havingValue = "true", matchIfMissing = true)
    public S3Client s3Client() {
        S3Client.Builder builder = S3Client.builder()
                .region(Region.of(properties.getRegion()))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create(
                                properties.getAccessKey(),
                                properties.getSecretKey())));

        if (properties.getEndpoint() != null) {
            builder.endpointOverride(URI.create(properties.getEndpoint()));
        }
        return builder.build();
    }

    @Bean
    @ConditionalOnProperty(name = "s3.transfer-manager-enabled", havingValue = "true")
    public S3TransferManager s3TransferManager() {
        // The Transfer Manager is built on the asynchronous client
        return S3TransferManager.builder()
                .s3Client(s3AsyncClient())
                .build();
    }

    @Bean
    @ConditionalOnProperty(name = "s3.presigner-enabled", havingValue = "true")
    public S3Presigner s3Presigner() {
        return S3Presigner.builder()
                .region(Region.of(properties.getRegion()))
                .build();
    }
}

@ConfigurationProperties(prefix = "s3")
@Data
public class S3Properties {
    private String accessKey;
    private String secretKey;
    private String region = "us-east-1";
    private String endpoint = null;
    private boolean syncEnabled = true;
    private boolean asyncEnabled = false;
    private boolean transferManagerEnabled = false;
    private boolean presignerEnabled = false;
    private int maxConnections = 100;
    private int connectionTimeout = 5000;
    private int socketTimeout = 30000;
    private String defaultBucket;
}
```

### Profile-Specific Configuration

```properties
# application-dev.properties
s3.access-key=${AWS_ACCESS_KEY}
s3.secret-key=${AWS_SECRET_KEY}
s3.region=us-east-1
s3.endpoint=http://localhost:4566
s3.async-enabled=true
s3.transfer-manager-enabled=true

# application-prod.properties
s3.access-key=${AWS_ACCESS_KEY}
s3.secret-key=${AWS_SECRET_KEY}
s3.region=us-east-1
s3.async-enabled=true
s3.presigner-enabled=true
```

## Advanced Service Patterns

### Generic S3 Service Template

```java
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.core.ResponseBytes;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;
import software.amazon.awssdk.transfer.s3.S3TransferManager;
import software.amazon.awssdk.transfer.s3.model.DownloadFileRequest;
import software.amazon.awssdk.transfer.s3.model.UploadFileRequest;
import java.io.IOException;
import java.nio.file.*;
import java.util.*;
import java.util.stream.Collectors;

@Service
@RequiredArgsConstructor
public class S3Service {

    private final S3Client s3Client;
    private final S3AsyncClient s3AsyncClient;
    private final S3TransferManager transferManager;
    private final S3Properties s3Properties;

    // Basic Operations
    public Mono<Void> uploadObjectAsync(String key, byte[] data) {
        PutObjectRequest request = PutObjectRequest.builder()
                .bucket(s3Properties.getDefaultBucket())
                .key(key)
                .build();

        // The async client returns a CompletableFuture directly and takes an AsyncRequestBody
        return Mono.fromFuture(() -> s3AsyncClient.putObject(request,
                        AsyncRequestBody.fromBytes(data)))
                .then();
    }

    public Mono<byte[]> downloadObjectAsync(String key) {
        GetObjectRequest request = GetObjectRequest.builder()
                .bucket(s3Properties.getDefaultBucket())
                .key(key)
                .build();

        // Async getObject needs a response transformer; toBytes() buffers the object
        return Mono.fromFuture(() -> s3AsyncClient.getObject(request,
                        AsyncResponseTransformer.toBytes()))
                .map(ResponseBytes::asByteArray);
    }

    // Advanced Operations
    public Mono<UploadResult> uploadWithMetadata(String key,
                                                 Path file,
                                                 Map<String, String> metadata) {
        PutObjectRequest request = PutObjectRequest.builder()
                .bucket(s3Properties.getDefaultBucket())
                .key(key)
                .metadata(metadata)
                .contentType(getContentType(file))
                .build();

        return Mono.fromFuture(() -> s3AsyncClient.putObject(request,
                        AsyncRequestBody.fromFile(file)))
                .map(response -> new UploadResult(key, response.eTag()));
    }

    public Flux<S3Object> listObjectsWithPrefix(String prefix) {
        ListObjectsV2Request request = ListObjectsV2Request.builder()
                .bucket(s3Properties.getDefaultBucket())
                .prefix(prefix)
                .build();

        return Flux.create(sink -> {
            s3Client.listObjectsV2Paginator(request)
                    .contents()
                    .forEach(sink::next);
            sink.complete();
        });
    }

    public Mono<Void> batchDelete(List<String> keys) {
        List<ObjectIdentifier> objectIdentifiers = keys.stream()
                .map(key -> ObjectIdentifier.builder().key(key).build())
                .collect(Collectors.toList());

        Delete delete = Delete.builder()
                .objects(objectIdentifiers)
                .build();

        DeleteObjectsRequest request = DeleteObjectsRequest.builder()
                .bucket(s3Properties.getDefaultBucket())
                .delete(delete)
                .build();

        return Mono.fromFuture(() -> s3AsyncClient.deleteObjects(request)).then();
    }

    // Transfer Manager Operations
    public Mono<UploadResult> uploadWithTransferManager(String key, Path file) {
        UploadFileRequest request = UploadFileRequest.builder()
                .putObjectRequest(req -> req
                        .bucket(s3Properties.getDefaultBucket())
                        .key(key))
                .source(file)
                .build();

        return Mono.fromFuture(() -> transferManager.uploadFile(request)
                        .completionFuture())
                .map(result -> new UploadResult(key, result.response().eTag()));
    }

    public Mono<DownloadResult> downloadWithTransferManager(String key, Path destination) {
        DownloadFileRequest request = DownloadFileRequest.builder()
                .getObjectRequest(req -> req
                        .bucket(s3Properties.getDefaultBucket())
                        .key(key))
                .destination(destination)
                .build();

        return Mono.fromFuture(() -> transferManager.downloadFile(request)
                        .completionFuture())
                .map(result -> new DownloadResult(destination, result.response().contentLength()));
    }

    // Utility Methods
    public String getDefaultBucket() {
        return s3Properties.getDefaultBucket();
    }

    private String getContentType(Path file) {
        try {
            return Files.probeContentType(file);
        } catch (IOException e) {
            return "application/octet-stream";
        }
    }

    // Records for Results
    public record UploadResult(String key, String eTag) {}
    public record DownloadResult(Path path, long size) {}
}
```

### Event-Driven S3 Operations

```java
import lombok.RequiredArgsConstructor;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.presigner.S3Presigner;
import software.amazon.awssdk.services.s3.presigner.model.GetObjectPresignRequest;
import java.nio.file.Path;
import java.time.Duration;

@Service
@RequiredArgsConstructor
public class S3EventService {

    private final S3Service s3Service;
    private final ApplicationEventPublisher eventPublisher;

    public Mono<S3Service.UploadResult> uploadAndPublishEvent(String key, Path file) {
        return s3Service.uploadWithTransferManager(key, file)
                .doOnSuccess(result -> {
                    eventPublisher.publishEvent(new S3UploadEvent(key, result.eTag()));
                })
                .doOnError(error -> {
                    eventPublisher.publishEvent(new S3UploadFailedEvent(key, error.getMessage()));
                });
    }

    public Mono<String> generatePresignedUrl(String key) {
        // Presigning is a local signing operation; the object does not need to be fetched first
        return Mono.fromCallable(() -> {
            try (S3Presigner presigner = S3Presigner.create()) {
                GetObjectRequest request = GetObjectRequest.builder()
                        .bucket(s3Service.getDefaultBucket())
                        .key(key)
                        .build();

                GetObjectPresignRequest presignRequest = GetObjectPresignRequest.builder()
                        .signatureDuration(Duration.ofMinutes(10))
                        .getObjectRequest(request)
                        .build();

                return presigner.presignGetObject(presignRequest)
                        .url()
                        .toString();
            }
        });
    }
}

// Event Classes
class S3UploadEvent extends ApplicationEvent {
    private final String key;
    private final String eTag;

    public S3UploadEvent(String key, String eTag) {
        super(key);
        this.key = key;
        this.eTag = eTag;
    }

    public String getKey() { return key; }
    public String getETag() { return eTag; }
}

class S3UploadFailedEvent extends ApplicationEvent {
    private final String key;
    private final String errorMessage;

    public S3UploadFailedEvent(String key, String errorMessage) {
        super(key);
        this.key = key;
        this.errorMessage = errorMessage;
    }

    public String getKey() { return key; }
    public String getErrorMessage() { return errorMessage; }
}
```

### Retry and Error Handling

```java
import lombok.RequiredArgsConstructor;
import org.springframework.retry.annotation.Backoff;
import org.springframework.retry.annotation.Recover;
import org.springframework.retry.annotation.Retryable;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;
import java.nio.file.Path;

@Service
@RequiredArgsConstructor
public class ResilientS3Service {

    private final S3Client s3Client;
    private final S3AsyncClient s3AsyncClient;

    // Note: @Retryable re-invokes the method when it throws. With a lazy Mono body
    // the S3 call happens on subscription, so for fully reactive pipelines prefer
    // Mono.retryWhen(...); the pattern below works when the Mono is subscribed
    // (e.g. blocked) within the same advised call.
    @Retryable(value = {S3Exception.class, SdkClientException.class},
               maxAttempts = 3,
               backoff = @Backoff(delay = 1000, multiplier = 2))
    public Mono<PutObjectResponse> uploadWithRetry(String key, Path file) {
        return Mono.fromCallable(() -> {
            PutObjectRequest request = PutObjectRequest.builder()
                    .bucket("my-bucket")
                    .key(key)
                    .build();

            return s3Client.putObject(request, RequestBody.fromFile(file));
        });
    }

    @Recover
    public Mono<PutObjectResponse> uploadRecover(S3Exception e, String key, Path file) {
        // Log the failure and potentially send notification
        System.err.println("Upload failed after retries: " + e.getMessage());
        return Mono.error(new S3UploadException("Upload failed after retries", e));
    }

    @Retryable(value = {S3Exception.class},
               maxAttempts = 5,
               backoff = @Backoff(delay = 2000, multiplier = 2))
    public Mono<Void> copyObjectWithRetry(String sourceKey, String destinationKey) {
        CopyObjectRequest request = CopyObjectRequest.builder()
                .sourceBucket("source-bucket")
                .sourceKey(sourceKey)
                .destinationBucket("destination-bucket")
                .destinationKey(destinationKey)
                .build();

        // The async client returns a CompletableFuture directly
        return Mono.fromFuture(() -> s3AsyncClient.copyObject(request)).then();
    }
}

public class S3UploadException extends RuntimeException {
    public S3UploadException(String message, Throwable cause) {
        super(message, cause);
    }
}
```

## Testing Integration

### Test Configuration with LocalStack

```java
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.test.context.ActiveProfiles;
import org.testcontainers.containers.localstack.LocalStackContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;

@Testcontainers
@ActiveProfiles("test")
@TestConfiguration
public class S3TestConfig {

    @Container
    static LocalStackContainer localstack = new LocalStackContainer(
            DockerImageName.parse("localstack/localstack:3.0"))
            .withServices(LocalStackContainer.Service.S3)
            .withEnv("DEFAULT_REGION", "us-east-1");

    @Bean
    public S3Client testS3Client() {
        return S3Client.builder()
                .region(Region.US_EAST_1)
                .endpointOverride(localstack.getEndpointOverride(LocalStackContainer.Service.S3))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create(
                                localstack.getAccessKey(),
                                localstack.getSecretKey())))
                .build();
    }

    @Bean
    public S3AsyncClient testS3AsyncClient() {
        return S3AsyncClient.builder()
                .region(Region.US_EAST_1)
                .endpointOverride(localstack.getEndpointOverride(LocalStackContainer.Service.S3))
                .credentialsProvider(StaticCredentialsProvider.create(
                        AwsBasicCredentials.create(
                                localstack.getAccessKey(),
                                localstack.getSecretKey())))
                .build();
    }
}
```

### Unit Testing with Mocks

```java
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.test.StepVerifier;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;
import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable;
import software.amazon.awssdk.transfer.s3.S3TransferManager;
import java.util.concurrent.CompletableFuture;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
class S3ServiceTest {

    @Mock
    private S3Client s3Client;

    @Mock
    private S3AsyncClient s3AsyncClient; // uploadObjectAsync goes through the async client

    @Mock
    private S3TransferManager transferManager;

    @Mock
    private S3Properties s3Properties;

    @InjectMocks
    private S3Service s3Service;

    @Test
    void uploadObjectAsync_ShouldComplete() {
        // Arrange
        String key = "test-key";
        byte[] data = "test-content".getBytes();

        PutObjectResponse response = PutObjectResponse.builder()
                .eTag("12345")
                .build();

        when(s3Properties.getDefaultBucket()).thenReturn("test-bucket");
        when(s3AsyncClient.putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class)))
                .thenReturn(CompletableFuture.completedFuture(response));

        // Act & Assert: StepVerifier subscribes and propagates assertion failures
        StepVerifier.create(s3Service.uploadObjectAsync(key, data))
                .verifyComplete();
    }

    @Test
    void listObjectsWithPrefix_ShouldReturnObjectList() {
        // Arrange
        String prefix = "documents/";
        S3Object object1 = S3Object.builder().key("documents/file1.txt").build();
        S3Object object2 = S3Object.builder().key("documents/file2.txt").build();

        ListObjectsV2Response response = ListObjectsV2Response.builder()
                .contents(object1, object2)
                .build();

        // The service iterates a paginator, which lazily calls listObjectsV2 on the client
        when(s3Properties.getDefaultBucket()).thenReturn("test-bucket");
        when(s3Client.listObjectsV2(any(ListObjectsV2Request.class)))
                .thenReturn(response);
        when(s3Client.listObjectsV2Paginator(any(ListObjectsV2Request.class)))
                .thenAnswer(inv -> new ListObjectsV2Iterable(s3Client, inv.getArgument(0)));

        // Act & Assert
        StepVerifier.create(s3Service.listObjectsWithPrefix(prefix).collectList())
                .assertNext(objects -> {
                    assertEquals(2, objects.size());
                    assertTrue(objects.stream().allMatch(obj -> obj.key().startsWith(prefix)));
                })
                .verifyComplete();
    }
}
```
|
||||
|
||||
### Integration Testing

```java
import org.junit.jupiter.api.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.ActiveProfiles;
import software.amazon.awssdk.services.s3.model.*;
import reactor.test.StepVerifier;
import java.io.IOException;
import java.nio.file.*;
import java.util.Arrays;
import java.util.Map;

@SpringBootTest
@ActiveProfiles("test")
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
class S3IntegrationTest {

    @Autowired
    private S3Service s3Service;

    private static final String TEST_BUCKET = "test-bucket";
    private static final String TEST_FILE = "test-document.txt";

    @BeforeAll
    static void setup() throws Exception {
        // Create test file
        Files.write(Paths.get(TEST_FILE), "Test content".getBytes());
    }

    @Test
    @Order(1)
    void uploadFile_ShouldSucceed() {
        // Act & Assert
        s3Service.uploadWithMetadata(TEST_FILE, Paths.get(TEST_FILE),
                        Map.of("author", "test", "type", "document"))
                .as(StepVerifier::create)
                .expectNextMatches(result ->
                        result.key().equals(TEST_FILE) && result.eTag() != null)
                .verifyComplete();
    }

    @Test
    @Order(2)
    void downloadFile_ShouldReturnContent() {
        // Act & Assert: compare array contents; expectNext would use
        // reference equality for byte[] and always fail
        s3Service.downloadObjectAsync(TEST_FILE)
                .as(StepVerifier::create)
                .expectNextMatches(bytes -> Arrays.equals(bytes, "Test content".getBytes()))
                .verifyComplete();
    }

    @Test
    @Order(3)
    void listObjects_ShouldReturnFiles() {
        // Act & Assert
        s3Service.listObjectsWithPrefix("")
                .as(StepVerifier::create)
                .expectNextCount(1)
                .verifyComplete();
    }

    @AfterAll
    static void cleanup() {
        try {
            Files.deleteIfExists(Paths.get(TEST_FILE));
        } catch (IOException e) {
            // Ignore
        }
    }
}
```

## Advanced Configuration Patterns

### Environment-Specific Configuration

```java
import org.springframework.boot.autoconfigure.condition.*;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import software.amazon.awssdk.auth.credentials.*;
import software.amazon.awssdk.regions.Region;

@Configuration
public class EnvironmentAwareS3Config {

    @Bean
    @ConditionalOnMissingBean
    public AwsCredentialsProvider awsCredentialsProvider(S3Properties properties) {
        // Fall back to the default provider chain when no static keys are configured
        if (properties.getAccessKey() != null && properties.getSecretKey() != null) {
            return StaticCredentialsProvider.create(
                    AwsBasicCredentials.create(
                            properties.getAccessKey(),
                            properties.getSecretKey()));
        }
        return DefaultCredentialsProvider.create();
    }

    @Bean
    @ConditionalOnMissingBean
    @ConditionalOnProperty(name = "s3.region")
    public Region region(S3Properties properties) {
        return Region.of(properties.getRegion());
    }

    @Bean
    @ConditionalOnMissingBean
    @ConditionalOnProperty(name = "s3.endpoint")
    public String endpoint(S3Properties properties) {
        return properties.getEndpoint();
    }
}
```

### Multi-Bucket Support

```java
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import software.amazon.awssdk.services.s3.S3Client;
import reactor.core.publisher.Mono;
import java.nio.file.Path;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Service
public class MultiBucketS3Service {

    // ConcurrentHashMap keeps computeIfAbsent safe in a shared singleton bean
    private final Map<String, S3Client> bucketClients = new ConcurrentHashMap<>();
    private final S3Client defaultS3Client;

    @Autowired
    public MultiBucketS3Service(S3Client defaultS3Client) {
        this.defaultS3Client = defaultS3Client;
    }

    public S3Client getClientForBucket(String bucketName) {
        // Reuse the default client's region and credentials for per-bucket clients
        return bucketClients.computeIfAbsent(bucketName, name ->
                S3Client.builder()
                        .region(defaultS3Client.serviceClientConfiguration().region())
                        .credentialsProvider(defaultS3Client.serviceClientConfiguration().credentialsProvider())
                        .build());
    }

    public Mono<UploadResult> uploadToBucket(String bucketName, String key, Path file) {
        S3Client client = getClientForBucket(bucketName);
        // Upload implementation using the specific client
        return Mono.empty(); // Implementation
    }
}
```
@@ -0,0 +1,473 @@
# S3 Transfer Patterns Reference

## S3 Transfer Manager Advanced Patterns

### Configuration and Optimization

#### Custom Transfer Manager Configuration

```java
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.transfer.s3.S3TransferManager;

public S3TransferManager createOptimizedTransferManager() {
    // The Transfer Manager is built on an async client; the CRT-based client
    // provides the best multipart and connection handling.
    S3AsyncClient s3AsyncClient = S3AsyncClient.crtBuilder()
            .maxConcurrency(200)
            .build();

    return S3TransferManager.builder()
            .s3Client(s3AsyncClient)
            .build();
}
```

#### Parallel Upload Configuration

```java
import software.amazon.awssdk.transfer.s3.S3TransferManager;
import software.amazon.awssdk.transfer.s3.model.FileUpload;
import software.amazon.awssdk.transfer.s3.model.UploadFileRequest;
import software.amazon.awssdk.transfer.s3.progress.LoggingTransferListener;
import java.nio.file.Paths;

public void configureParallelUploads() {
    S3TransferManager transferManager = S3TransferManager.create();

    FileUpload upload = transferManager.uploadFile(
            UploadFileRequest.builder()
                    .putObjectRequest(req -> req
                            .bucket("my-bucket")
                            .key("large-file.bin"))
                    .source(Paths.get("large-file.bin"))
                    // Track upload progress via a transfer listener
                    .addTransferListener(LoggingTransferListener.create())
                    .build());

    // Handle completion
    upload.completionFuture().thenAccept(result ->
            System.out.println("Upload completed with ETag: " +
                    result.response().eTag()));
}
```

### Advanced Upload Patterns

#### Multipart Upload with Progress Monitoring

```java
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;
import java.io.File;
import java.io.FileInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public void multipartUploadWithProgress(S3Client s3Client, String bucketName,
                                        String key, String filePath) {
    int partSize = 5 * 1024 * 1024; // 5 MB parts (the S3 minimum for all but the last part)
    File file = new File(filePath);

    CreateMultipartUploadRequest createRequest = CreateMultipartUploadRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build();

    CreateMultipartUploadResponse createResponse = s3Client.createMultipartUpload(createRequest);
    String uploadId = createResponse.uploadId();

    List<CompletedPart> completedParts = new ArrayList<>();
    long uploadedBytes = 0;
    long totalBytes = file.length();

    try (FileInputStream fis = new FileInputStream(file)) {
        byte[] buffer = new byte[partSize];
        int partNumber = 1;

        while (true) {
            int bytesRead = fis.read(buffer);
            if (bytesRead == -1) break;

            byte[] partData = Arrays.copyOf(buffer, bytesRead);

            UploadPartRequest uploadRequest = UploadPartRequest.builder()
                    .bucket(bucketName)
                    .key(key)
                    .uploadId(uploadId)
                    .partNumber(partNumber)
                    .build();

            UploadPartResponse uploadResponse = s3Client.uploadPart(
                    uploadRequest, RequestBody.fromBytes(partData));

            completedParts.add(CompletedPart.builder()
                    .partNumber(partNumber)
                    .eTag(uploadResponse.eTag())
                    .build());

            uploadedBytes += bytesRead;
            partNumber++;

            // Log progress
            double progress = (double) uploadedBytes / totalBytes * 100;
            System.out.printf("Upload progress: %.2f%%%n", progress);
        }

        CompleteMultipartUploadRequest completeRequest =
                CompleteMultipartUploadRequest.builder()
                        .bucket(bucketName)
                        .key(key)
                        .uploadId(uploadId)
                        .multipartUpload(CompletedMultipartUpload.builder()
                                .parts(completedParts)
                                .build())
                        .build();

        s3Client.completeMultipartUpload(completeRequest);

    } catch (Exception e) {
        // Abort on failure so S3 doesn't keep billing for orphaned parts
        AbortMultipartUploadRequest abortRequest =
                AbortMultipartUploadRequest.builder()
                        .bucket(bucketName)
                        .key(key)
                        .uploadId(uploadId)
                        .build();

        s3Client.abortMultipartUpload(abortRequest);
        throw new RuntimeException("Multipart upload failed", e);
    }
}
```

#### Resume Interrupted Uploads

```java
import java.util.List;
import java.util.stream.Collectors;

public void resumeUpload(S3Client s3Client, String bucketName, String key,
                         String filePath, String existingUploadId) {
    ListMultipartUploadsRequest listRequest = ListMultipartUploadsRequest.builder()
            .bucket(bucketName)
            .prefix(key)
            .build();

    ListMultipartUploadsResponse listResponse = s3Client.listMultipartUploads(listRequest);

    // Check if upload already exists
    boolean uploadExists = listResponse.uploads().stream()
            .anyMatch(upload -> upload.key().equals(key) &&
                    upload.uploadId().equals(existingUploadId));

    if (uploadExists) {
        // Resume existing upload
        continueExistingUpload(s3Client, bucketName, key, existingUploadId, filePath);
    } else {
        // Start new upload
        multipartUploadWithProgress(s3Client, bucketName, key, filePath);
    }
}

private void continueExistingUpload(S3Client s3Client, String bucketName,
                                    String key, String uploadId, String filePath) {
    // List already uploaded parts
    ListPartsRequest listPartsRequest = ListPartsRequest.builder()
            .bucket(bucketName)
            .key(key)
            .uploadId(uploadId)
            .build();

    ListPartsResponse listPartsResponse = s3Client.listParts(listPartsRequest);

    List<CompletedPart> completedParts = listPartsResponse.parts().stream()
            .map(part -> CompletedPart.builder()
                    .partNumber(part.partNumber())
                    .eTag(part.eTag())
                    .build())
            .collect(Collectors.toList());

    // Upload remaining parts (see the sketch below)
}
```

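The remaining-parts step left open above can be filled in along these lines. This is a minimal sketch, assuming the same 5 MB part size as the interrupted upload; parts already reported by `listParts` are skipped, and the helper name `uploadRemainingParts` is ours, not an SDK method.

```java
import software.amazon.awssdk.core.sync.RequestBody;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;

private void uploadRemainingParts(S3Client s3Client, String bucketName, String key,
                                  String uploadId, String filePath,
                                  List<CompletedPart> completedParts) throws IOException {
    int partSize = 5 * 1024 * 1024; // must match the part size of the interrupted upload
    Set<Integer> alreadyUploaded = completedParts.stream()
            .map(CompletedPart::partNumber)
            .collect(Collectors.toSet());

    try (FileInputStream fis = new FileInputStream(filePath)) {
        byte[] buffer = new byte[partSize];
        int partNumber = 1;
        int bytesRead;
        while ((bytesRead = fis.read(buffer)) != -1) {
            if (!alreadyUploaded.contains(partNumber)) {
                UploadPartResponse response = s3Client.uploadPart(
                        UploadPartRequest.builder()
                                .bucket(bucketName)
                                .key(key)
                                .uploadId(uploadId)
                                .partNumber(partNumber)
                                .build(),
                        RequestBody.fromBytes(Arrays.copyOf(buffer, bytesRead)));
                completedParts.add(CompletedPart.builder()
                        .partNumber(partNumber)
                        .eTag(response.eTag())
                        .build());
            }
            partNumber++;
        }
    }

    // CompleteMultipartUpload requires parts in ascending part-number order
    completedParts.sort(Comparator.comparingInt(CompletedPart::partNumber));

    s3Client.completeMultipartUpload(CompleteMultipartUploadRequest.builder()
            .bucket(bucketName)
            .key(key)
            .uploadId(uploadId)
            .multipartUpload(CompletedMultipartUpload.builder()
                    .parts(completedParts)
                    .build())
            .build());
}
```
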
### Advanced Download Patterns

#### Partial File Download

```java
import software.amazon.awssdk.core.ResponseInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;

public void downloadPartialFile(S3Client s3Client, String bucketName, String key,
                                String destPath, long startByte, long endByte) throws IOException {
    GetObjectRequest request = GetObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            // HTTP Range header selects the byte window to download
            .range("bytes=" + startByte + "-" + endByte)
            .build();

    try (ResponseInputStream<GetObjectResponse> response = s3Client.getObject(request);
         OutputStream outputStream = new FileOutputStream(destPath)) {

        response.transferTo(outputStream);
        System.out.println("Partial download completed: " +
                (endByte - startByte + 1) + " bytes");
    }
}
```

#### Parallel Downloads

```java
import java.io.*;
import java.util.*;
import java.util.concurrent.*;

public void parallelDownloads(S3Client s3Client, String bucketName,
                              String key, String destPath, int chunkCount) throws IOException {
    long fileSize = getFileSize(s3Client, bucketName, key);
    long chunkSize = fileSize / chunkCount;

    ExecutorService executor = Executors.newFixedThreadPool(chunkCount);
    List<Future<Void>> futures = new ArrayList<>();

    for (int i = 0; i < chunkCount; i++) {
        long start = i * chunkSize;
        long end = (i == chunkCount - 1) ? fileSize - 1 : start + chunkSize - 1;
        int chunkIndex = i; // effectively final copy for use in the lambda

        Future<Void> future = executor.submit(() -> {
            downloadPartialFile(s3Client, bucketName, key,
                    destPath + ".part" + chunkIndex, start, end);
            return null;
        });

        futures.add(future);
    }

    // Wait for all downloads to complete
    for (Future<Void> future : futures) {
        try {
            future.get();
        } catch (InterruptedException | ExecutionException e) {
            throw new RuntimeException("Download failed", e);
        }
    }

    // Combine chunks
    combineChunks(destPath, chunkCount);

    executor.shutdown();
}

private long getFileSize(S3Client s3Client, String bucketName, String key) {
    // HEAD request returns object metadata, including Content-Length
    return s3Client.headObject(b -> b.bucket(bucketName).key(key)).contentLength();
}

private void combineChunks(String baseName, int chunkCount) throws IOException {
    try (OutputStream outputStream = new FileOutputStream(baseName)) {
        for (int i = 0; i < chunkCount; i++) {
            String chunkFile = baseName + ".part" + i;
            try (InputStream inputStream = new FileInputStream(chunkFile)) {
                inputStream.transferTo(outputStream);
            }
            new File(chunkFile).delete();
        }
    }
}
```

### Error Handling and Retry

#### Upload with Exponential Backoff

```java
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.core.retry.backoff.FullJitterBackoffStrategy;
import software.amazon.awssdk.core.retry.conditions.RetryCondition;
import software.amazon.awssdk.core.sync.RequestBody;
import java.nio.file.Paths;
import java.time.Duration;

public void resilientUpload(String bucketName, String key, String filePath) {
    PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build();

    // Configure a retry policy with exponential backoff and jitter.
    // The default retry condition already retries throttling and 5xx errors.
    S3Client retryS3Client = S3Client.builder()
            .overrideConfiguration(b -> b
                    .retryPolicy(RetryPolicy.builder()
                            .numRetries(5)
                            .backoffStrategy(FullJitterBackoffStrategy.builder()
                                    .baseDelay(Duration.ofSeconds(1))
                                    .maxBackoffTime(Duration.ofSeconds(30))
                                    .build())
                            .retryCondition(RetryCondition.defaultRetryCondition())
                            .build()))
            .build();

    retryS3Client.putObject(request, RequestBody.fromFile(Paths.get(filePath)));
}
```

#### Upload with Checkpoint

```java
import java.io.File;
import java.io.IOException;
import java.nio.file.*;

public void uploadWithCheckpoint(S3Client s3Client, String bucketName,
                                 String key, String filePath) {
    String checkpointFile = filePath + ".checkpoint";
    Path checkpointPath = Paths.get(checkpointFile);

    long startPos = 0;
    if (Files.exists(checkpointPath)) {
        // Read checkpoint
        try {
            startPos = Long.parseLong(Files.readString(checkpointPath));
        } catch (IOException e) {
            throw new RuntimeException("Failed to read checkpoint", e);
        }
    }

    if (startPos > 0) {
        // Resume upload
        continueUploadFromCheckpoint(s3Client, bucketName, key, filePath, startPos);
    } else {
        // Start new upload
        startNewUpload(s3Client, bucketName, key, filePath);
    }

    // Update checkpoint
    long endPos = new File(filePath).length();
    try {
        Files.writeString(checkpointPath, String.valueOf(endPos));
    } catch (IOException e) {
        throw new RuntimeException("Failed to write checkpoint", e);
    }
}

private void continueUploadFromCheckpoint(S3Client s3Client, String bucketName,
                                          String key, String filePath, long startPos) {
    // Implement resume logic
}

private void startNewUpload(S3Client s3Client, String bucketName,
                            String key, String filePath) {
    // Implement initial upload logic
}
```

### Performance Tuning

#### Buffer Configuration

```java
public S3Client configureLargeBuffer() {
    return S3Client.builder()
            .overrideConfiguration(b -> b
                    .apiCallAttemptTimeout(Duration.ofMinutes(5))
                    .apiCallTimeout(Duration.ofMinutes(10)))
            .build();
}

public S3TransferManager configureHighThroughput() {
    // Multipart thresholds are configured on the CRT-based async client,
    // which the Transfer Manager is built on
    S3AsyncClient crtClient = S3AsyncClient.crtBuilder()
            .thresholdInBytes(8L * 1024 * 1024)        // switch to multipart above 8 MB
            .minimumPartSizeInBytes(10L * 1024 * 1024) // 10 MB parts
            .build();

    return S3TransferManager.builder()
            .s3Client(crtClient)
            .build();
}
```

#### Network Optimization

```java
public S3Client createOptimizedS3Client() {
    // Pass the HTTP client builder (not a built client) so the SDK manages its lifecycle
    return S3Client.builder()
            .httpClientBuilder(ApacheHttpClient.builder()
                    .maxConnections(200)
                    .socketTimeout(Duration.ofSeconds(30))
                    .connectionTimeout(Duration.ofSeconds(5))
                    .connectionAcquisitionTimeout(Duration.ofSeconds(30)))
            .region(Region.US_EAST_1)
            .build();
}
```

### Monitoring and Metrics

#### Upload Progress Tracking

```java
import software.amazon.awssdk.transfer.s3.S3TransferManager;
import software.amazon.awssdk.transfer.s3.model.CompletedFileUpload;
import software.amazon.awssdk.transfer.s3.model.FileUpload;
import software.amazon.awssdk.transfer.s3.model.UploadFileRequest;
import software.amazon.awssdk.transfer.s3.progress.TransferListener;
import java.nio.file.Paths;

public void uploadWithProgressTracking(S3TransferManager transferManager,
                                       String bucketName, String key, String filePath) {
    // Progress callbacks are exposed through the Transfer Manager's
    // TransferListener; the plain synchronous client has no progress hook.
    FileUpload upload = transferManager.uploadFile(UploadFileRequest.builder()
            .putObjectRequest(req -> req.bucket(bucketName).key(key))
            .source(Paths.get(filePath))
            .addTransferListener(new TransferListener() {
                @Override
                public void bytesTransferred(Context.BytesTransferred context) {
                    System.out.println("Transferred: " +
                            context.progressSnapshot().transferredBytes() + " bytes");
                    context.progressSnapshot().ratioTransferred().ifPresent(ratio ->
                            System.out.printf("Progress: %.2f%%%n", ratio * 100));
                }
            })
            .build());

    CompletedFileUpload result = upload.completionFuture().join();
    System.out.println("Upload complete. ETag: " + result.response().eTag());
}
```

#### Throughput Measurement

```java
public void measureUploadThroughput(S3Client s3Client, String bucketName,
                                    String key, String filePath) {
    long startTime = System.currentTimeMillis();
    long fileSize = new File(filePath).length();

    PutObjectRequest request = PutObjectRequest.builder()
            .bucket(bucketName)
            .key(key)
            .build();

    s3Client.putObject(request, RequestBody.fromFile(Paths.get(filePath)));

    long endTime = System.currentTimeMillis();
    long duration = endTime - startTime;
    double throughput = (fileSize * 1000.0) / duration / (1024 * 1024); // MB/s

    System.out.printf("Upload throughput: %.2f MB/s%n", throughput);
}
```

## Testing and Validation

### Upload Validation

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;

public void validateUpload(S3Client s3Client, String bucketName, String key,
                           String localFilePath) throws IOException {
    // Download file from S3
    byte[] s3Content = downloadObject(s3Client, bucketName, key);

    // Read local file
    byte[] localContent = Files.readAllBytes(Paths.get(localFilePath));

    // Verify file size first (cheap check before the full comparison)
    if (s3Content.length != localContent.length) {
        throw new RuntimeException("Upload validation failed: size mismatch");
    }

    // Validate content matches
    if (!Arrays.equals(s3Content, localContent)) {
        throw new RuntimeException("Upload validation failed: content mismatch");
    }

    System.out.println("Upload validation successful");
}

private byte[] downloadObject(S3Client s3Client, String bucketName, String key) {
    return s3Client.getObjectAsBytes(b -> b.bucket(bucketName).key(key)).asByteArray();
}
```

342
skills/aws-java/aws-sdk-java-v2-secrets-manager/SKILL.md
Normal file
@@ -0,0 +1,342 @@
---
name: aws-sdk-java-v2-secrets-manager
description: AWS Secrets Manager patterns using AWS SDK for Java 2.x. Use when storing/retrieving secrets (passwords, API keys, tokens), rotating secrets automatically, managing database credentials, or integrating secret management into Spring Boot applications.
category: aws
tags: [aws, secrets-manager, java, sdk, security, credentials, spring-boot]
version: 1.1.0
allowed-tools: Read, Write, Glob, Bash
---

# AWS SDK for Java 2.x - AWS Secrets Manager

## When to Use

Use this skill when:
- Storing and retrieving application secrets programmatically
- Managing database credentials securely without hardcoding
- Implementing automatic secret rotation with Lambda functions
- Integrating AWS Secrets Manager with Spring Boot applications
- Setting up secret caching for improved performance
- Creating secure configuration management systems
- Working with multi-region secret deployments
- Implementing audit logging for secret access

## Dependencies

### Maven
```xml
<dependency>
    <groupId>software.amazon.awssdk</groupId>
    <artifactId>secretsmanager</artifactId>
</dependency>

<!-- For secret caching (recommended for production) -->
<dependency>
    <groupId>com.amazonaws.secretsmanager</groupId>
    <artifactId>aws-secretsmanager-caching-java</artifactId>
    <version>2.0.0</version> <!-- use the SDK v2 compatible version -->
</dependency>
```

### Gradle
```gradle
implementation 'software.amazon.awssdk:secretsmanager'
implementation 'com.amazonaws.secretsmanager:aws-secretsmanager-caching-java:2.0.0'
```

## Quick Start

### Basic Client Setup
```java
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;

SecretsManagerClient secretsClient = SecretsManagerClient.builder()
        .region(Region.US_EAST_1)
        .build();
```

### Store a Secret
```java
import software.amazon.awssdk.services.secretsmanager.model.*;

public String createSecret(String secretName, String secretValue) {
    CreateSecretRequest request = CreateSecretRequest.builder()
            .name(secretName)
            .secretString(secretValue)
            .build();

    CreateSecretResponse response = secretsClient.createSecret(request);
    return response.arn();
}
```

### Retrieve a Secret
```java
public String getSecretValue(String secretName) {
    GetSecretValueRequest request = GetSecretValueRequest.builder()
            .secretId(secretName)
            .build();

    GetSecretValueResponse response = secretsClient.getSecretValue(request);
    return response.secretString();
}
```

## Core Operations

### Secret Management
- Create secrets with `createSecret()`
- Retrieve secrets with `getSecretValue()`
- Update secrets with `updateSecret()` (see the sketch after this list)
- Delete secrets with `deleteSecret()`
- List secrets with `listSecrets()`
- Restore deleted secrets with `restoreSecret()`

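A minimal sketch of the update and delete calls above, reusing the `secretsClient` from the Quick Start; the secret name is a placeholder:

```java
// Update an existing secret's value (creates a new version staged AWSCURRENT)
secretsClient.updateSecret(UpdateSecretRequest.builder()
        .secretId("prod/database/credentials")
        .secretString("{\"username\":\"admin\",\"password\":\"new-value\"}")
        .build());

// Schedule deletion with a 7-day recovery window; restoreSecret() can undo it
secretsClient.deleteSecret(DeleteSecretRequest.builder()
        .secretId("prod/database/credentials")
        .recoveryWindowInDays(7L)
        .build());
```
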
### Secret Versioning
- Access specific versions by `versionId`
- Access versions by stage (e.g., "AWSCURRENT", "AWSPENDING"), as shown below
- Automatically manage version history

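For example, fetching the staged version during a rotation might look like this (sketch, using the same `secretsClient`):

```java
// AWSPENDING holds the new value while a rotation is in progress
GetSecretValueResponse pending = secretsClient.getSecretValue(
        GetSecretValueRequest.builder()
                .secretId("prod/database/credentials")
                .versionStage("AWSPENDING")
                .build());
```
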
### Secret Rotation
- Configure automatic rotation schedules
- Lambda-based rotation functions
- Immediate rotation with `rotateSecret()` (see the sketch after this list)

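A hedged sketch of configuring rotation; the Lambda ARN is a placeholder, and the rotation function itself must implement the Secrets Manager rotation contract:

```java
secretsClient.rotateSecret(RotateSecretRequest.builder()
        .secretId("prod/database/credentials")
        // Placeholder ARN: the Lambda must implement the rotation steps
        // (createSecret, setSecret, testSecret, finishSecret)
        .rotationLambdaARN("arn:aws:lambda:us-east-1:123456789012:function:rotate-db-secret")
        .rotationRules(RotationRulesType.builder()
                .automaticallyAfterDays(30L)
                .build())
        .build());
```
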
## Caching for Performance

### Setup Cache
```java
import com.amazonaws.secretsmanager.caching.SecretCache;

public class CachedSecrets {
    private final SecretCache cache;

    public CachedSecrets(SecretsManagerClient secretsClient) {
        this.cache = new SecretCache(secretsClient);
    }

    public String getCachedSecret(String secretName) {
        return cache.getSecretString(secretName);
    }
}
```

### Cache Configuration
```java
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;

SecretCacheConfiguration config = new SecretCacheConfiguration()
        .withMaxCacheSize(1000)
        .withCacheItemTTL(3600000); // 1 hour
```

## Spring Boot Integration

### Configuration
```java
@Configuration
public class SecretsManagerConfiguration {

    @Value("${aws.secrets.region}")
    private String region;

    @Bean
    public SecretsManagerClient secretsManagerClient() {
        return SecretsManagerClient.builder()
                .region(Region.of(region))
                .build();
    }

    @Bean
    public SecretCache secretCache(SecretsManagerClient secretsClient) {
        return new SecretCache(secretsClient);
    }
}
```

### Service Layer
```java
@Service
public class SecretsService {

    private final SecretCache cache;
    private final ObjectMapper objectMapper;

    public SecretsService(SecretCache cache, ObjectMapper objectMapper) {
        this.cache = cache;
        this.objectMapper = objectMapper;
    }

    public <T> T getSecretAsObject(String secretName, Class<T> type) {
        try {
            String secretJson = cache.getSecretString(secretName);
            return objectMapper.readValue(secretJson, type);
        } catch (Exception e) {
            throw new RuntimeException("Failed to parse secret: " + secretName, e);
        }
    }
}
```

### Database Configuration
```java
@Configuration
public class DatabaseConfiguration {

    @Bean
    public DataSource dataSource(SecretsService secretsService) {
        Map<String, String> credentials = secretsService.getSecretAsMap(
                "prod/database/credentials");

        HikariConfig config = new HikariConfig();
        config.setJdbcUrl(credentials.get("url"));
        config.setUsername(credentials.get("username"));
        config.setPassword(credentials.get("password"));

        return new HikariDataSource(config);
    }
}
```

## Examples

### Database Credentials Structure
```json
{
  "engine": "postgres",
  "host": "mydb.us-east-1.rds.amazonaws.com",
  "port": 5432,
  "username": "admin",
  "password": "MySecurePassword123!",
  "dbname": "mydatabase",
  "url": "jdbc:postgresql://mydb.us-east-1.rds.amazonaws.com:5432/mydatabase"
}
```

### API Keys Structure
```json
{
  "api_key": "abcd1234-5678-90ef-ghij-klmnopqrstuv",
  "api_secret": "MySecretKey123!",
  "api_token": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
```

## Common Patterns

### Error Handling
```java
try {
    String secret = secretsClient.getSecretValue(request).secretString();
} catch (SecretsManagerException e) {
    if (e.awsErrorDetails().errorCode().equals("ResourceNotFoundException")) {
        // Handle missing secret
    }
    throw e;
}
```

### Batch Operations
```java
List<String> secretNames = List.of("secret1", "secret2", "secret3");
Map<String, String> secrets = secretNames.stream()
        .collect(Collectors.toMap(
                Function.identity(),
                name -> cache.getSecretString(name)
        ));
```

## Best Practices

1. **Secret Management**:
   - Use descriptive secret names with hierarchical structure
   - Implement versioning and rotation
   - Add tags for organization and billing

2. **Caching**:
   - Always use caching in production environments
   - Configure appropriate TTL values based on secret sensitivity
   - Monitor cache hit rates

3. **Security**:
   - Never log secret values
   - Use KMS encryption for sensitive secrets
   - Implement least privilege IAM policies
   - Enable CloudTrail logging

4. **Performance**:
   - Reuse SecretsManagerClient instances
   - Use async operations when appropriate (see the sketch after this list)
   - Monitor API throttling limits

5. **Spring Boot Integration**:
   - Use `@Value` annotations for secret names
   - Implement proper exception handling
   - Use configuration properties for secret names

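A brief sketch of the async option mentioned above; `SecretsManagerAsyncClient` is the SDK's non-blocking counterpart to the synchronous client:

```java
import software.amazon.awssdk.services.secretsmanager.SecretsManagerAsyncClient;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
import java.util.concurrent.CompletableFuture;

SecretsManagerAsyncClient asyncClient = SecretsManagerAsyncClient.create();

// The call returns immediately; consume the secret when the future completes
CompletableFuture<String> secret = asyncClient.getSecretValue(
                GetSecretValueRequest.builder().secretId("prod/external-api/keys").build())
        .thenApply(response -> response.secretString());
```
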
## Testing Strategies

### Unit Testing
```java
@ExtendWith(MockitoExtension.class)
class SecretsServiceTest {

    @Mock
    private SecretCache cache;

    @InjectMocks
    private SecretsService secretsService;

    @Test
    void shouldGetSecret() {
        when(cache.getSecretString("test-secret")).thenReturn("secret-value");

        String result = secretsService.getSecret("test-secret");

        assertEquals("secret-value", result);
    }
}
```

### Integration Testing
```java
@SpringBootTest(classes = TestSecretsConfiguration.class)
class SecretsManagerIntegrationTest {

    @Autowired
    private SecretsService secretsService;

    @Test
    void shouldRetrieveSecret() {
        String secret = secretsService.getSecret("test-secret");
        assertNotNull(secret);
    }
}
```

## Troubleshooting

### Common Issues
- **Access Denied**: Check IAM permissions
- **Resource Not Found**: Verify secret name and region
- **Decryption Failure**: Ensure KMS key permissions
- **Throttling**: Implement retry logic and backoff (see the sketch below)

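One way to harden a client against throttling is to raise the SDK's retry ceiling; a minimal sketch, starting from the default retry policy:

```java
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;

SecretsManagerClient client = SecretsManagerClient.builder()
        .overrideConfiguration(b -> b.retryPolicy(
                // The default policy already backs off on throttling errors;
                // this simply allows more attempts before giving up
                RetryPolicy.defaultRetryPolicy().toBuilder()
                        .numRetries(8)
                        .build()))
        .build();
```
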
### Debug Commands
```bash
# Check secret exists
aws secretsmanager describe-secret --secret-id my-secret

# List all secrets
aws secretsmanager list-secrets

# Get secret value (CLI)
aws secretsmanager get-secret-value --secret-id my-secret
```

## References

For detailed information and advanced patterns, see:

- [API Reference](./references/api-reference.md) - Complete API documentation
- [Caching Guide](./references/caching-guide.md) - Performance optimization strategies
- [Spring Boot Integration](./references/spring-boot-integration.md) - Complete Spring integration patterns

## Related Skills

- `aws-sdk-java-v2-core` - Core AWS SDK patterns and best practices
- `aws-sdk-java-v2-kms` - KMS encryption and key management
- `spring-boot-dependency-injection` - Spring dependency injection patterns
@@ -0,0 +1,38 @@
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;

@Configuration
public class {{ConfigClass}} {

    @Value("${aws.secrets.region}")
    private String region;

    @Value("${aws.accessKeyId}")
    private String accessKey;

    @Value("${aws.secretKey}")
    private String secretKey;

    @Bean
    public SecretsManagerClient secretsManagerClient() {
        return SecretsManagerClient.builder()
                .region(Region.of(region))
                .credentialsProvider(StaticCredentialsProvider.create(
                        // Inject the key pair via @Value; passing the literal
                        // "${...}" placeholder strings would never resolve
                        AwsBasicCredentials.create(accessKey, secretKey)))
                .build();
    }

    @Bean
    public SecretCache secretCache(SecretsManagerClient secretsClient) {
        SecretCacheConfiguration config = new SecretCacheConfiguration()
                .withMaxCacheSize(100)
                .withCacheItemTTL(3600000); // 1 hour

        return new SecretCache(secretsClient, config);
    }
}
@@ -0,0 +1,126 @@
# AWS Secrets Manager API Reference

## Overview
AWS Secrets Manager is a service for storing, managing, and retrieving secrets (API version 2017-10-17).

## Core Classes

### SecretsManagerClient
- **Purpose**: Synchronous client for AWS Secrets Manager
- **Location**: `software.amazon.awssdk.services.secretsmanager.SecretsManagerClient`
- **Builder**: `SecretsManagerClient.builder()`

### SecretsManagerAsyncClient
- **Purpose**: Asynchronous client for AWS Secrets Manager
- **Location**: `software.amazon.awssdk.services.secretsmanager.SecretsManagerAsyncClient`
- **Builder**: `SecretsManagerAsyncClient.builder()`

## Configuration Classes

### SecretsManagerClientBuilder
- Methods:
  - `region(Region region)` - Set AWS region
  - `credentialsProvider(AwsCredentialsProvider credentialsProvider)` - Set credentials
  - `build()` - Create client instance

### SecretsManagerServiceClientConfiguration
- Service client settings and configuration

## Request Types

### CreateSecretRequest
- **Fields**:
  - `name(String name)` - Secret name (required)
  - `secretString(String secretString)` - Secret value
  - `secretBinary(SdkBytes secretBinary)` - Binary secret value
  - `description(String description)` - Secret description
  - `kmsKeyId(String kmsKeyId)` - KMS key for encryption
  - `tags(List<Tag> tags)` - Tags for organization

### GetSecretValueRequest
- **Fields**:
  - `secretId(String secretId)` - Secret name or ARN
  - `versionId(String versionId)` - Specific version ID
  - `versionStage(String versionStage)` - Version stage (e.g., "AWSCURRENT")

### UpdateSecretRequest
- **Fields**:
  - `secretId(String secretId)` - Secret name or ARN
  - `secretString(String secretString)` - New secret value
  - `secretBinary(SdkBytes secretBinary)` - New binary secret value
  - `kmsKeyId(String kmsKeyId)` - KMS key for encryption

### DeleteSecretRequest
- **Fields**:
  - `secretId(String secretId)` - Secret name or ARN
  - `recoveryWindowInDays(Long recoveryWindowInDays)` - Recovery period
  - `forceDeleteWithoutRecovery(Boolean forceDeleteWithoutRecovery)` - Immediate deletion

### RotateSecretRequest
- **Fields**:
  - `secretId(String secretId)` - Secret name or ARN
  - `rotationLambdaARN(String rotationLambdaArn)` - Lambda ARN for rotation
  - `rotationRules(RotationRulesType rotationRules)` - Rotation configuration
  - `rotateImmediately(Boolean rotateImmediately)` - Rotate as soon as the request is made

## Response Types

### CreateSecretResponse
- **Fields**:
  - `arn()` - Secret ARN
  - `name()` - Secret name
  - `versionId()` - Version ID

### GetSecretValueResponse
- **Fields**:
  - `arn()` - Secret ARN
  - `name()` - Secret name
  - `versionId()` - Version ID
  - `secretString()` - Secret value as string
  - `secretBinary()` - Secret value as binary
  - `versionStages()` - Version stages

### UpdateSecretResponse
- **Fields**:
  - `arn()` - Secret ARN
  - `name()` - Secret name
  - `versionId()` - New version ID

### DeleteSecretResponse
- **Fields**:
  - `arn()` - Secret ARN
  - `name()` - Secret name
  - `deletionDate()` - Deletion date/time

### RotateSecretResponse
- **Fields**:
  - `arn()` - Secret ARN
  - `name()` - Secret name
  - `versionId()` - New version ID

## Paginated Operations

### ListSecretsRequest
- **Fields**:
  - `maxResults(Integer maxResults)` - Maximum results per page
  - `nextToken(String nextToken)` - Token for next page
  - `filters(List<Filter> filters)` - Filter criteria

### ListSecretsResponse
- **Fields**:
  - `secretList()` - List of secrets
  - `nextToken()` - Token for next page

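Rather than handling `nextToken` manually, the SDK's generated paginator can iterate all pages; a short sketch:

```java
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.model.ListSecretsRequest;

SecretsManagerClient secretsClient = SecretsManagerClient.create();

// listSecretsPaginator fetches subsequent pages lazily as you iterate
secretsClient.listSecretsPaginator(ListSecretsRequest.builder()
                .maxResults(50)
                .build())
        .secretList()
        .forEach(entry -> System.out.println(entry.name()));
```
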
## Error Handling

### SecretsManagerException
- Common error codes:
  - `ResourceNotFoundException` - Secret not found
  - `InvalidParameterException` - Invalid parameters
  - `MalformedPolicyDocumentException` - Invalid policy document
  - `InternalServiceErrorException` - Internal service error
  - `InvalidRequestException` - Invalid request
  - `DecryptionFailure` - Decryption failed
  - `ResourceExistsException` - Resource already exists
  - `ResourceConflictException` - Resource conflict
  - `ValidationException` - Validation failed
@@ -0,0 +1,304 @@
# AWS Secrets Manager Caching Guide

## Overview
The AWS Secrets Manager Java caching client enables in-process caching of secrets for Java applications, reducing API calls and improving performance.

## Prerequisites
- Java 8+ development environment
- AWS account with Secrets Manager access
- Appropriate IAM permissions

## Installation

### Maven Dependency
```xml
<dependency>
    <groupId>com.amazonaws.secretsmanager</groupId>
    <artifactId>aws-secretsmanager-caching-java</artifactId>
    <version>2.0.0</version> <!-- use the latest version compatible with SDK v2 -->
</dependency>
```

### Gradle Dependency
```gradle
implementation 'com.amazonaws.secretsmanager:aws-secretsmanager-caching-java:2.0.0'
```

## Basic Usage

### Simple Cache Setup
```java
import com.amazonaws.secretsmanager.caching.SecretCache;

public class SimpleCacheExample {
    private final SecretCache cache = new SecretCache();

    public String getSecret(String secretId) {
        return cache.getSecretString(secretId);
    }
}
```

### Cache with Custom SecretsManagerClient
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;

public class ClientAwareCacheExample {
    private final SecretCache cache;

    public ClientAwareCacheExample(SecretsManagerClient secretsClient) {
        this.cache = new SecretCache(secretsClient);
    }

    public String getSecret(String secretId) {
        return cache.getSecretString(secretId);
    }
}
```

## Cache Configuration

### SecretCacheConfiguration
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;

public class ConfiguredCacheExample {
    private final SecretCache cache;

    public ConfiguredCacheExample(SecretsManagerClient secretsClient) {
        SecretCacheConfiguration config = new SecretCacheConfiguration()
                .withMaxCacheSize(1000)      // Maximum number of cached secrets
                .withCacheItemTTL(3600000);  // 1 hour TTL in milliseconds

        this.cache = new SecretCache(secretsClient, config);
    }
}
```

### Configuration Options

| Property | Type | Default | Description |
|----------|------|---------|-------------|
| `maxCacheSize` | Integer | 1000 | Maximum number of cached secrets |
| `cacheItemTTL` | Long | 300000 (5 min) | Cache item TTL in milliseconds |
| `cacheSizeEvictionPercentage` | Integer | 10 | Percentage of items to evict when cache is full |

## Advanced Caching Patterns

### Multi-Layer Cache
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import java.util.concurrent.ConcurrentHashMap;

public class MultiLayerCache {
    private final SecretCache secretsManagerCache;
    // Note: entries in this simple local layer never expire; add a TTL if
    // the underlying secrets can change while the application runs
    private final ConcurrentHashMap<String, String> localCache;

    public MultiLayerCache(SecretsManagerClient secretsClient) {
        this.secretsManagerCache = new SecretCache(secretsClient);
        this.localCache = new ConcurrentHashMap<>();
    }

    public String getSecret(String secretId) {
        // Check local cache first
        String cached = localCache.get(secretId);
        if (cached != null) {
            return cached;
        }

        // Get from Secrets Manager cache
        String secret = secretsManagerCache.getSecretString(secretId);
        if (secret != null) {
            localCache.put(secretId, secret);
        }

        return secret;
    }
}
```

### Cache Statistics
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import java.util.concurrent.atomic.AtomicLong;

// Note: SecretCache does not expose hit/miss statistics publicly, so this
// wrapper tracks its own counters around cache lookups.
public class CacheStatsExample {
    private final SecretCache cache;
    private final AtomicLong lookupCount = new AtomicLong();
    private final AtomicLong failureCount = new AtomicLong();

    public CacheStatsExample(SecretCache cache) {
        this.cache = cache;
    }

    public String getSecret(String secretId) {
        lookupCount.incrementAndGet();
        try {
            return cache.getSecretString(secretId);
        } catch (RuntimeException e) {
            failureCount.incrementAndGet();
            throw e;
        }
    }

    public void printStats() {
        System.out.println("Lookups: " + lookupCount.get() +
                ", Failures: " + failureCount.get());
    }
}
```

## Error Handling and Cache Management

### Cache Refresh Strategy
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

public class CacheRefreshManager {
    private final SecretsManagerClient secretsClient;
    private final ScheduledExecutorService scheduler;
    private volatile SecretCache cache;

    public CacheRefreshManager(SecretsManagerClient secretsClient) {
        this.secretsClient = secretsClient;
        this.cache = new SecretCache(secretsClient);
        this.scheduler = Executors.newScheduledThreadPool(1);
    }

    public void startRefreshSchedule() {
        // Rebuild the cache every hour
        scheduler.scheduleAtFixedRate(this::refreshCache, 1, 1, TimeUnit.HOURS);
    }

    private void refreshCache() {
        // SecretCache has no public refresh method; entries are re-fetched
        // automatically when their TTL expires. Recreating the cache forces
        // an immediate, full refresh.
        System.out.println("Refreshing cache...");
        cache.close();
        cache = new SecretCache(secretsClient);
    }

    public void shutdown() {
        scheduler.shutdown();
    }
}
```

### Fallback Mechanism
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;

public class FallbackCacheExample {
    private final SecretCache cache;
    private final SecretsManagerClient fallbackClient;

    public FallbackCacheExample(SecretsManagerClient primaryClient, SecretsManagerClient fallbackClient) {
        this.cache = new SecretCache(primaryClient);
        this.fallbackClient = fallbackClient;
    }

    public String getSecretWithFallback(String secretId) {
        try {
            // Try cached value first
            return cache.getSecretString(secretId);
        } catch (Exception e) {
            // Fall back to a direct API call
            return getSecretDirect(secretId);
        }
    }

    private String getSecretDirect(String secretId) {
        GetSecretValueRequest request = GetSecretValueRequest.builder()
                .secretId(secretId)
                .build();

        return fallbackClient.getSecretValue(request).secretString();
    }
}
```

## Performance Optimization

### Batch Secret Retrieval
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class BatchSecretRetrieval {
    private final SecretCache cache;

    public BatchSecretRetrieval(SecretCache cache) {
        this.cache = cache;
    }

    public List<String> getMultipleSecrets(List<String> secretIds) {
        List<String> results = new ArrayList<>();

        for (String secretId : secretIds) {
            String secret = cache.getSecretString(secretId);
            results.add(secret != null ? secret : "NOT_FOUND");
        }

        return results;
    }

    public Map<String, String> getSecretsAsMap(List<String> secretIds) {
        Map<String, String> secretMap = new HashMap<>();

        for (String secretId : secretIds) {
            String secret = cache.getSecretString(secretId);
            if (secret != null) {
                secretMap.put(secretId, secret);
            }
        }

        return secretMap;
    }
}
```

## Monitoring and Debugging

### Cache Monitoring
```java
import com.amazonaws.secretsmanager.caching.SecretCache;

public class CacheMonitor {

    public void monitorCachePerformance() {
        // SecretCache does not expose hit/miss counters or its current size.
        // Monitor cache effectiveness indirectly: track GetSecretValue call
        // volume in CloudWatch (fewer API calls means more cache hits), or
        // wrap the cache with your own counters as in the statistics example.
    }

    public void printCacheContents() {
        // Note: SecretCache doesn't provide direct access to all cached items.
        // This is a security feature to prevent accidental exposure of secrets.
        System.out.println("Cache contents are protected and cannot be directly inspected");
    }
}
```

## Best Practices

1. **Cache Size Configuration**:
   - Adjust `maxCacheSize` based on available memory
   - Monitor memory usage and adjust accordingly
   - Consider using heap analysis tools

2. **TTL Configuration**:
   - Balance between performance and freshness
   - Shorter TTL for frequently changing secrets
   - Longer TTL for stable secrets

3. **Error Handling**:
   - Implement fallback mechanisms
   - Handle cache misses gracefully
   - Log errors without exposing sensitive information

4. **Security Considerations**:
   - Never log secret values
   - Use appropriate IAM permissions
   - Consider encryption at rest for cached data

5. **Memory Management**:
   - Monitor memory usage
   - Consider cache eviction strategies
   - Implement proper cleanup in shutdown hooks (see the sketch below)
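
For the shutdown-hook point above, a minimal sketch; assuming `SecretCache`'s `close()` releases its resources, registering it as a JVM shutdown hook handles cleanup:

```java
import com.amazonaws.secretsmanager.caching.SecretCache;

SecretCache cache = new SecretCache();

// Close the cache when the JVM shuts down
Runtime.getRuntime().addShutdownHook(new Thread(cache::close));
```
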
@@ -0,0 +1,535 @@
# AWS Secrets Manager Spring Boot Integration

## Overview
Integrate AWS Secrets Manager with Spring Boot applications using the caching library for optimal performance and security.

## Dependencies

### Required Dependencies
```xml
<!-- AWS Secrets Manager -->
<dependency>
    <groupId>software.amazon.awssdk</groupId>
    <artifactId>secretsmanager</artifactId>
</dependency>

<!-- AWS Secrets Manager Caching -->
<dependency>
    <groupId>com.amazonaws.secretsmanager</groupId>
    <artifactId>aws-secretsmanager-caching-java</artifactId>
    <version>2.0.0</version> <!-- use the latest version compatible with SDK v2 -->
</dependency>

<!-- Spring Boot Starter -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>

<!-- Jackson for JSON processing -->
<dependency>
    <groupId>com.fasterxml.jackson.core</groupId>
    <artifactId>jackson-databind</artifactId>
</dependency>

<!-- Connection Pooling -->
<dependency>
    <groupId>com.zaxxer</groupId>
    <artifactId>HikariCP</artifactId>
</dependency>
```

## Configuration Properties

### application.yml
```yaml
spring:
  application:
    name: aws-secrets-manager-app
  datasource:
    url: jdbc:postgresql://localhost:5432/mydb
    username: ${db.username}
    password: ${db.password}
    hikari:
      maximum-pool-size: 10
      minimum-idle: 5

aws:
  secrets:
    region: us-east-1
    # Secret names for different environments
    database-credentials: prod/database/credentials
    api-keys: prod/external-api/keys
    redis-config: prod/redis/config

app:
  external-api:
    secret-name: prod/external/credentials
    base-url: https://api.example.com
```

## Core Components

### SecretsManager Configuration
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;

@Configuration
public class SecretsManagerConfiguration {

    @Value("${aws.secrets.region}")
    private String region;

    @Bean
    public SecretsManagerClient secretsManagerClient() {
        return SecretsManagerClient.builder()
                .region(Region.of(region))
                .build();
    }

    @Bean
    public SecretCache secretCache(SecretsManagerClient secretsClient) {
        SecretCacheConfiguration config = new SecretCacheConfiguration()
                .withMaxCacheSize(100)
                .withCacheItemTTL(3600000); // 1 hour

        return new SecretCache(secretsClient, config);
    }
}
```

### Secrets Service
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.stereotype.Service;
import java.util.Map;

@Service
public class SecretsService {

    private final SecretCache secretCache;
    private final ObjectMapper objectMapper;

    public SecretsService(SecretCache secretCache, ObjectMapper objectMapper) {
        this.secretCache = secretCache;
        this.objectMapper = objectMapper;
    }

    /**
     * Get secret as string
     */
    public String getSecret(String secretName) {
        try {
            return secretCache.getSecretString(secretName);
        } catch (Exception e) {
            throw new RuntimeException("Failed to retrieve secret: " + secretName, e);
        }
    }

    /**
     * Get secret as object of specified type
     */
    public <T> T getSecretAsObject(String secretName, Class<T> type) {
        try {
            String secretJson = secretCache.getSecretString(secretName);
            return objectMapper.readValue(secretJson, type);
        } catch (Exception e) {
            throw new RuntimeException("Failed to parse secret: " + secretName, e);
        }
    }

    /**
     * Get secret as Map
     */
    public Map<String, String> getSecretAsMap(String secretName) {
        try {
            String secretJson = secretCache.getSecretString(secretName);
            return objectMapper.readValue(secretJson,
                    new TypeReference<Map<String, String>>() {});
        } catch (Exception e) {
            throw new RuntimeException("Failed to parse secret map: " + secretName, e);
        }
    }

    /**
     * Get secret with fallback
     */
    public String getSecretWithFallback(String secretName, String defaultValue) {
        try {
            String secret = secretCache.getSecretString(secretName);
            return secret != null ? secret : defaultValue;
        } catch (Exception e) {
            return defaultValue;
        }
    }
}
```

## Database Configuration Integration

### Dynamic DataSource Configuration
```java
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import javax.sql.DataSource;
import java.util.Map;

@Configuration
public class DatabaseConfiguration {

    private final SecretsService secretsService;

    @Value("${aws.secrets.database-credentials}")
    private String dbSecretName;

    public DatabaseConfiguration(SecretsService secretsService) {
        this.secretsService = secretsService;
    }

    @Bean
    public DataSource dataSource() {
        Map<String, String> credentials = secretsService.getSecretAsMap(dbSecretName);

        HikariConfig config = new HikariConfig();
        config.setJdbcUrl(credentials.get("url"));
        config.setUsername(credentials.get("username"));
        config.setPassword(credentials.get("password"));
        config.setMaximumPoolSize(10);
        config.setMinimumIdle(5);
        config.setConnectionTimeout(30000);
        config.setIdleTimeout(600000);
        config.setMaxLifetime(1800000);
        config.setLeakDetectionThreshold(15000);

        return new HikariDataSource(config);
    }
}
```

### Configuration Properties with Secrets
```java
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

@Component
public class AppProperties {

    private final SecretsService secretsService;

    @Value("${app.external-api.secret-name}")
    private String apiSecretName;

    @Value("${app.external-api.base-url}")
    private String externalApiBaseUrl;

    private String apiKey;

    public AppProperties(SecretsService secretsService) {
        this.secretsService = secretsService;
    }

    public String getExternalApiSecretName() {
        return apiSecretName;
    }

    public String getApiKey() {
        // Lazily resolve the key from Secrets Manager on first access
        if (apiKey == null) {
            apiKey = secretsService.getSecret(apiSecretName);
        }
        return apiKey;
    }

    public String getExternalApiBaseUrl() {
        return externalApiBaseUrl;
    }
}
```

## Property Source Integration
|
||||
|
||||
### Custom Property Source
|
||||
```java
|
||||
import org.springframework.core.env.Environment;
|
||||
import org.springframework.core.env.PropertySource;
|
||||
import org.springframework.stereotype.Component;
|
||||
import javax.annotation.PostConstruct;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@Component
|
||||
public class SecretsManagerPropertySource extends PropertySource<SecretsService> {
|
||||
|
||||
public static final String SECRETS_MANAGER_PROPERTY_SOURCE_NAME = "secretsManagerPropertySource";
|
||||
|
||||
private final SecretsService secretsService;
|
||||
private final Environment environment;
|
||||
|
||||
public SecretsManagerPropertySource(SecretsService secretsService, Environment environment) {
|
||||
super(SECRETS_MANAGER_PROPERTY_SOURCE_NAME, secretsService);
|
||||
this.secretsService = secretsService;
|
||||
this.environment = environment;
|
||||
}
|
||||
|
||||
@PostConstruct
|
||||
public void loadSecrets() {
|
||||
// Load secrets specified in application.yml
|
||||
String secretPrefix = "aws.secrets.";
|
||||
environment.getPropertyNames().forEach(propertyName -> {
|
||||
if (propertyName.startsWith(secretPrefix)) {
|
||||
String secretName = environment.getProperty(propertyName);
|
||||
String secretValue = secretsService.getSecret(secretName);
|
||||
if (secretValue != null) {
|
||||
// Add to property source (note: this is simplified)
|
||||
// In practice, you'd need to work with PropertySources
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getProperty(String name) {
|
||||
if (name.startsWith("aws.secret.")) {
|
||||
String secretName = name.substring("aws.secret.".length());
|
||||
return secretsService.getSecret(secretName);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
```
|
||||
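
To actually expose loaded secrets to the rest of the `Environment`, one option is to register a `MapPropertySource` with highest precedence. A minimal sketch, assuming it runs before the context finishes refreshing (for example from an `EnvironmentPostProcessor`); the key and secret name are illustrative:

```java
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.MapPropertySource;

import java.util.HashMap;
import java.util.Map;

public final class SecretsPropertySourceInstaller {

    private SecretsPropertySourceInstaller() {
    }

    public static void install(ConfigurableEnvironment environment, SecretsService secretsService) {
        Map<String, Object> resolved = new HashMap<>();
        // "app.api-key" and "prod/api-key" are illustrative placeholders
        resolved.put("app.api-key", secretsService.getSecret("prod/api-key"));

        // addFirst gives secrets precedence over file-based configuration
        environment.getPropertySources().addFirst(
            new MapPropertySource("loadedSecrets", resolved));
    }
}
```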

## API Integration

### REST Client with Secrets
```java
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;

import java.util.Map;

@Service
public class ExternalApiClient {

    private final SecretsService secretsService;
    private final RestTemplate restTemplate;
    private final AppProperties appProperties;

    public ExternalApiClient(SecretsService secretsService,
                             RestTemplate restTemplate,
                             AppProperties appProperties) {
        this.secretsService = secretsService;
        this.restTemplate = restTemplate;
        this.appProperties = appProperties;
    }

    public String callExternalApi(String endpoint) {
        Map<String, String> apiCredentials = secretsService.getSecretAsMap(
            appProperties.getExternalApiSecretName());

        HttpHeaders headers = new HttpHeaders();
        headers.set("Authorization", "Bearer " + apiCredentials.get("api_token"));
        headers.set("X-API-Key", apiCredentials.get("api_key"));
        headers.set("Content-Type", "application/json");

        HttpEntity<String> entity = new HttpEntity<>(headers);

        ResponseEntity<String> response = restTemplate.exchange(
            endpoint,
            HttpMethod.GET,
            entity,
            String.class);

        return response.getBody();
    }
}
```

### Configuration for REST Template
```java
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;

@Configuration
public class RestTemplateConfiguration {

    @Bean
    public RestTemplate restTemplate() {
        return new RestTemplate();
    }
}
```

## Security Configuration

### Security Setup
```java
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.web.SecurityFilterChain;

@Configuration
@EnableWebSecurity
public class SecurityConfiguration {

    @Bean
    public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
        // Lambda DSL; the chained and() style is deprecated in Spring Security 6.1+
        http
            .authorizeHttpRequests(auth -> auth
                .requestMatchers("/api/secrets/**").hasRole("ADMIN")
                .anyRequest().permitAll()
            )
            .httpBasic(Customizer.withDefaults())
            .csrf(csrf -> csrf.disable());

        return http.build();
    }
}
```

## Testing Configuration

### Test Configuration
```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Primary;
import org.springframework.mock.env.MockEnvironment;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;

import static org.mockito.Mockito.*;

@TestConfiguration
public class TestSecretsConfiguration {

    @Bean
    @Primary
    public SecretsManagerClient secretsManagerClient() {
        SecretsManagerClient mockClient = mock(SecretsManagerClient.class);

        // Mock successful secret retrieval; the explicit request type
        // disambiguates the overloaded getSecretValue methods
        when(mockClient.getSecretValue(any(GetSecretValueRequest.class)))
            .thenReturn(GetSecretValueResponse.builder()
                .secretString("{\"username\":\"test\",\"password\":\"testpass\"}")
                .build());

        return mockClient;
    }

    @Bean
    @Primary
    public SecretCache secretCache(SecretsManagerClient mockClient) {
        SecretCache mockCache = mock(SecretCache.class);
        when(mockCache.getSecretString(anyString()))
            .thenReturn("{\"username\":\"test\",\"password\":\"testpass\"}");
        return mockCache;
    }

    @Bean
    public MockEnvironment mockEnvironment() {
        MockEnvironment env = new MockEnvironment();
        env.setProperty("aws.secrets.region", "us-east-1");
        env.setProperty("aws.secrets.database-credentials", "test-db-credentials");
        return env;
    }
}
```

### Unit Tests
```java
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;

import java.util.Map;

import static org.mockito.Mockito.*;
import static org.junit.jupiter.api.Assertions.*;

@ExtendWith(MockitoExtension.class)
class SecretsServiceTest {

    @Mock
    private SecretCache secretCache;

    // A real ObjectMapper so JSON parsing actually runs in the map test
    @Spy
    private ObjectMapper objectMapper = new ObjectMapper();

    @InjectMocks
    private SecretsService secretsService;

    @Test
    void shouldGetSecret() {
        String secretName = "test-secret";
        String expectedValue = "secret-value";

        when(secretCache.getSecretString(secretName))
            .thenReturn(expectedValue);

        String result = secretsService.getSecret(secretName);

        assertEquals(expectedValue, result);
        verify(secretCache).getSecretString(secretName);
    }

    @Test
    void shouldGetSecretAsMap() throws Exception {
        String secretName = "test-secret";
        String secretJson = "{\"key\":\"value\"}";
        Map<String, String> expectedMap = Map.of("key", "value");

        when(secretCache.getSecretString(secretName))
            .thenReturn(secretJson);

        Map<String, String> result = secretsService.getSecretAsMap(secretName);

        assertEquals(expectedMap, result);
    }
}
```

## Best Practices

1. **Environment-Specific Configuration**:
   - Use different secret names for development, staging, and production
   - Implement proper environment variable management
   - Use Spring profiles for environment-specific configurations

2. **Security Considerations**:
   - Never log secret values
   - Use appropriate IAM roles and policies
   - Enable encryption in transit and at rest
   - Implement proper access controls

3. **Performance Optimization**:
   - Use caching for frequently accessed secrets
   - Configure appropriate TTL values (see the sketch after this list)
   - Monitor cache hit rates and adjust accordingly
   - Use connection pooling for database connections

4. **Error Handling**:
   - Implement fallback mechanisms for critical secrets
   - Handle partial secret retrieval gracefully
   - Provide meaningful error messages without exposing sensitive information
   - Implement circuit breakers for external API calls

5. **Monitoring and Logging**:
   - Monitor secret retrieval performance
   - Track cache hit/miss ratios
   - Log secret access patterns (without values)
   - Set up alerts for abnormal secret access patterns
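
A minimal sketch of tuning cache TTL and size, assuming the aws-secretsmanager-caching-java library's `SecretCacheConfiguration` builder (method names are that library's, double-check against the version you use):

```java
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;

public class TunedSecretCacheFactory {

    public static SecretCache create() {
        SecretCacheConfiguration config = new SecretCacheConfiguration()
            // Refresh cached items after 5 minutes instead of the default 1 hour
            .withCacheItemTTL(5 * 60 * 1000L)
            // Bound memory use for applications with many secrets
            .withMaxCacheSize(100);

        return new SecretCache(config);
    }
}
```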
336
skills/junit-test/unit-test-application-events/SKILL.md
Normal file
@@ -0,0 +1,336 @@
---
name: unit-test-application-events
description: Testing Spring application events (ApplicationEvent) with @EventListener and ApplicationEventPublisher. Test event publishing, listening, and async event handling in Spring Boot applications. Use when validating event-driven workflows in your Spring Boot services.
category: testing
tags: [junit-5, application-events, event-driven, listeners, publishers]
version: 1.0.1
---

# Unit Testing Application Events

Test Spring ApplicationEvent publishers and event listeners using JUnit 5. Verify event publishing, listener execution, and event propagation without full context startup.

## When to Use This Skill

Use this skill when:
- Testing ApplicationEventPublisher event publishing
- Testing @EventListener method invocation
- Verifying event listener logic and side effects
- Testing event propagation through listeners
- Want fast event-driven architecture tests
- Testing both synchronous and asynchronous event handling

## Setup: Event Testing

### Maven
```xml
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
    <groupId>org.junit.jupiter</groupId>
    <artifactId>junit-jupiter</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.mockito</groupId>
    <artifactId>mockito-core</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.assertj</groupId>
    <artifactId>assertj-core</artifactId>
    <scope>test</scope>
</dependency>
```

### Gradle
```kotlin
dependencies {
    implementation("org.springframework.boot:spring-boot-starter")
    testImplementation("org.junit.jupiter:junit-jupiter")
    testImplementation("org.mockito:mockito-core")
    testImplementation("org.assertj:assertj-core")
}
```

## Basic Pattern: Event Publishing and Listening

### Custom Event and Publisher

```java
// Custom application event
public class UserCreatedEvent extends ApplicationEvent {
    private final User user;

    public UserCreatedEvent(Object source, User user) {
        super(source);
        this.user = user;
    }

    public User getUser() {
        return user;
    }
}

// Service that publishes events
@Service
public class UserService {

    private final ApplicationEventPublisher eventPublisher;
    private final UserRepository userRepository;

    public UserService(ApplicationEventPublisher eventPublisher, UserRepository userRepository) {
        this.eventPublisher = eventPublisher;
        this.userRepository = userRepository;
    }

    public User createUser(String name, String email) {
        User user = new User(name, email);
        User savedUser = userRepository.save(user);

        eventPublisher.publishEvent(new UserCreatedEvent(this, savedUser));

        return savedUser;
    }
}

// Unit test
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
class UserServiceEventTest {

    @Mock
    private ApplicationEventPublisher eventPublisher;

    @Mock
    private UserRepository userRepository;

    @InjectMocks
    private UserService userService;

    @Test
    void shouldPublishUserCreatedEvent() {
        User newUser = new User(1L, "Alice", "alice@example.com");
        when(userRepository.save(any(User.class))).thenReturn(newUser);

        ArgumentCaptor<UserCreatedEvent> eventCaptor = ArgumentCaptor.forClass(UserCreatedEvent.class);

        userService.createUser("Alice", "alice@example.com");

        verify(eventPublisher).publishEvent(eventCaptor.capture());

        UserCreatedEvent capturedEvent = eventCaptor.getValue();
        assertThat(capturedEvent.getUser()).isEqualTo(newUser);
    }
}
```

## Testing Event Listeners

### @EventListener Annotation

```java
// Event listener
@Component
public class UserEventListener {

    private final EmailService emailService;

    public UserEventListener(EmailService emailService) {
        this.emailService = emailService;
    }

    @EventListener
    public void onUserCreated(UserCreatedEvent event) {
        User user = event.getUser();
        try {
            emailService.sendWelcomeEmail(user.getEmail());
        } catch (Exception e) {
            // Swallow the failure so a broken email service never
            // propagates back to the publisher; log it in real code
        }
    }
}

// Unit test for listener
class UserEventListenerTest {

    @Test
    void shouldSendWelcomeEmailWhenUserCreated() {
        EmailService emailService = mock(EmailService.class);
        UserEventListener listener = new UserEventListener(emailService);

        User newUser = new User(1L, "Alice", "alice@example.com");
        UserCreatedEvent event = new UserCreatedEvent(this, newUser);

        listener.onUserCreated(event);

        verify(emailService).sendWelcomeEmail("alice@example.com");
    }

    @Test
    void shouldNotThrowExceptionWhenEmailServiceFails() {
        EmailService emailService = mock(EmailService.class);
        doThrow(new RuntimeException("Email service down"))
            .when(emailService).sendWelcomeEmail(any());

        UserEventListener listener = new UserEventListener(emailService);
        User newUser = new User(1L, "Alice", "alice@example.com");
        UserCreatedEvent event = new UserCreatedEvent(this, newUser);

        // The listener catches the failure itself, so nothing propagates
        assertThatCode(() -> listener.onUserCreated(event))
            .doesNotThrowAnyException();
    }
}
```

## Testing Multiple Listeners

### Event Propagation

```java
class UserCreatedEvent extends ApplicationEvent {
    private final User user;
    private final List<String> notifications = new ArrayList<>();

    public UserCreatedEvent(Object source, User user) {
        super(source);
        this.user = user;
    }

    public void addNotification(String notification) {
        notifications.add(notification);
    }

    public List<String> getNotifications() {
        return notifications;
    }
}

class MultiListenerTest {

    @Test
    void shouldNotifyMultipleListenersSequentially() {
        EmailService emailService = mock(EmailService.class);
        NotificationService notificationService = mock(NotificationService.class);
        AnalyticsService analyticsService = mock(AnalyticsService.class);

        // Each listener class wraps exactly one downstream service
        EmailEventListener emailListener = new EmailEventListener(emailService);
        NotificationEventListener notificationListener = new NotificationEventListener(notificationService);
        AnalyticsEventListener analyticsListener = new AnalyticsEventListener(analyticsService);

        User user = new User(1L, "Alice", "alice@example.com");
        UserCreatedEvent event = new UserCreatedEvent(this, user);

        emailListener.onUserCreated(event);
        notificationListener.onUserCreated(event);
        analyticsListener.onUserCreated(event);

        verify(emailService).sendWelcomeEmail(any());
        verify(notificationService).notify(any());
        verify(analyticsService).track(any());
    }
}
```

## Testing Conditional Event Listeners

### @EventListener with Condition

```java
@Component
public class ConditionalEventListener {

    @EventListener(condition = "#event.user.age > 18")
    public void onAdultUserCreated(UserCreatedEvent event) {
        // Handle adult user
    }
}

class ConditionalListenerTest {

    // The SpEL condition is evaluated by Spring's event infrastructure, not
    // by the method itself, so publish through a real context to exercise it.
    // (Assumes User exposes an age property for the condition to read.)

    @Test
    void shouldProcessEventWhenConditionMatches() {
        try (AnnotationConfigApplicationContext context =
                 new AnnotationConfigApplicationContext(ConditionalEventListener.class)) {
            context.publishEvent(new UserCreatedEvent(this, new User(1L, "Alice", "alice@example.com", 30)));
            // Verify the listener's side effect for an adult user here
        }
    }

    @Test
    void shouldSkipEventWhenConditionDoesNotMatch() {
        try (AnnotationConfigApplicationContext context =
                 new AnnotationConfigApplicationContext(ConditionalEventListener.class)) {
            context.publishEvent(new UserCreatedEvent(this, new User(2L, "Bob", "bob@example.com", 15)));
            // Verify the listener was not invoked for a minor here
        }
    }
}
```

## Testing Async Event Listeners

### @Async with @EventListener

```java
@Component
public class AsyncEventListener {

    private final SlowService slowService;

    public AsyncEventListener(SlowService slowService) {
        this.slowService = slowService;
    }

    @EventListener
    @Async
    public void onUserCreatedAsync(UserCreatedEvent event) {
        slowService.processUser(event.getUser());
    }
}

class AsyncEventListenerTest {

    @Test
    void shouldProcessEventListenerLogic() {
        SlowService slowService = mock(SlowService.class);
        AsyncEventListener listener = new AsyncEventListener(slowService);

        User user = new User(1L, "Alice", "alice@example.com");
        UserCreatedEvent event = new UserCreatedEvent(this, user);

        // Calling the method directly runs synchronously: @Async only takes
        // effect behind a Spring proxy with @EnableAsync. That makes the
        // listener's logic easy to verify in a plain unit test.
        listener.onUserCreatedAsync(event);

        verify(slowService).processUser(user);
    }
}
```

## Best Practices

- **Mock ApplicationEventPublisher** in unit tests
- **Capture published events** using ArgumentCaptor
- **Test listener side effects** explicitly
- **Test error handling** in listeners
- **Keep event listeners focused** on single responsibility
- **Verify event data integrity** when capturing
- **Test both sync and async** event processing

## Common Pitfalls

- Testing actual event publishing without mocking the publisher
- Not verifying listener invocation
- Not capturing event details
- Testing listener registration instead of logic
- Not handling listener exceptions

## Troubleshooting

**Event not being captured**: Verify the ArgumentCaptor type matches the event class.

**Listener not invoked**: Ensure the event is actually published and the listener is registered.

**Async listener timing issues**: Prefer Awaitility over Thread.sleep() to wait for completion (see the sketch below).
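
A minimal Awaitility sketch for asserting on an async listener's side effect, assuming `org.awaitility:awaitility` is on the test classpath:

```java
import org.awaitility.Awaitility;

import java.time.Duration;

import static org.mockito.Mockito.verify;

class AsyncAssertionExample {

    void awaitAsyncListener(SlowService slowService, User user) {
        // Poll until the mocked collaborator is invoked, failing after 2 seconds
        Awaitility.await()
            .atMost(Duration.ofSeconds(2))
            .untilAsserted(() -> verify(slowService).processUser(user));
    }
}
```
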
## References

- [Spring ApplicationEvent](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/context/ApplicationEvent.html)
- [Spring ApplicationEventPublisher](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/context/ApplicationEventPublisher.html)
- [@EventListener Documentation](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/context/event/EventListener.html)
476
skills/junit-test/unit-test-bean-validation/SKILL.md
Normal file
@@ -0,0 +1,476 @@
---
name: unit-test-bean-validation
description: Unit testing Jakarta Bean Validation (@Valid, @NotNull, @Min, @Max, etc.) with custom validators and constraint violations. Test validation logic without Spring context. Use when ensuring data integrity and validation rules are correct.
category: testing
tags: [junit-5, validation, bean-validation, jakarta-validation, constraints]
version: 1.0.1
---

# Unit Testing Bean Validation and Custom Validators

Test validation annotations and custom validator implementations using JUnit 5. Verify constraint violations, error messages, and validation logic in isolation.

## When to Use This Skill

Use this skill when:
- Testing Jakarta Bean Validation (@NotNull, @Email, @Min, etc.)
- Testing custom @Constraint validators
- Verifying constraint violation error messages
- Testing cross-field validation logic
- Want fast validation tests without Spring context
- Testing complex validation scenarios and edge cases

## Setup: Bean Validation

### Maven
```xml
<dependency>
    <groupId>jakarta.validation</groupId>
    <artifactId>jakarta.validation-api</artifactId>
</dependency>
<dependency>
    <groupId>org.hibernate.validator</groupId>
    <artifactId>hibernate-validator</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.junit.jupiter</groupId>
    <artifactId>junit-jupiter</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.assertj</groupId>
    <artifactId>assertj-core</artifactId>
    <scope>test</scope>
</dependency>
```

### Gradle
```kotlin
dependencies {
    implementation("jakarta.validation:jakarta.validation-api")
    testImplementation("org.hibernate.validator:hibernate-validator")
    testImplementation("org.junit.jupiter:junit-jupiter")
    testImplementation("org.assertj:assertj-core")
}
```

## Basic Pattern: Testing Validation Constraints

### Setup Validator

```java
import jakarta.validation.ConstraintViolation;
import jakarta.validation.Validation;
import jakarta.validation.Validator;
import jakarta.validation.ValidatorFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.Set;

import static org.assertj.core.api.Assertions.*;

class UserValidationTest {

    private Validator validator;

    @BeforeEach
    void setUp() {
        ValidatorFactory factory = Validation.buildDefaultValidatorFactory();
        validator = factory.getValidator();
    }

    @Test
    void shouldPassValidationWithValidUser() {
        User user = new User("Alice", "alice@example.com", 25);

        Set<ConstraintViolation<User>> violations = validator.validate(user);

        assertThat(violations).isEmpty();
    }

    @Test
    void shouldFailValidationWhenNameIsNull() {
        User user = new User(null, "alice@example.com", 25);

        Set<ConstraintViolation<User>> violations = validator.validate(user);

        assertThat(violations)
            .hasSize(1)
            .extracting(ConstraintViolation::getMessage)
            .contains("must not be blank");
    }
}
```

## Testing Individual Constraint Annotations

### Test @NotNull, @NotBlank, @Email

```java
class UserDtoTest {

    private Validator validator;

    @BeforeEach
    void setUp() {
        validator = Validation.buildDefaultValidatorFactory().getValidator();
    }

    @Test
    void shouldFailWhenEmailIsInvalid() {
        UserDto dto = new UserDto("Alice", "invalid-email");

        Set<ConstraintViolation<UserDto>> violations = validator.validate(dto);

        assertThat(violations)
            .extracting(ConstraintViolation::getPropertyPath)
            .extracting(Path::toString)
            .contains("email");
        assertThat(violations)
            .extracting(ConstraintViolation::getMessage)
            .contains("must be a valid email address");
    }

    @Test
    void shouldFailWhenNameIsBlank() {
        UserDto dto = new UserDto("   ", "alice@example.com");

        Set<ConstraintViolation<UserDto>> violations = validator.validate(dto);

        assertThat(violations)
            .extracting(ConstraintViolation::getPropertyPath)
            .extracting(Path::toString)
            .contains("name");
    }

    @Test
    void shouldFailWhenAgeIsNegative() {
        UserDto dto = new UserDto("Alice", "alice@example.com", -5);

        Set<ConstraintViolation<UserDto>> violations = validator.validate(dto);

        assertThat(violations)
            .extracting(ConstraintViolation::getMessage)
            .contains("must be greater than or equal to 0");
    }

    @Test
    void shouldPassWhenAllConstraintsSatisfied() {
        UserDto dto = new UserDto("Alice", "alice@example.com", 25);

        Set<ConstraintViolation<UserDto>> violations = validator.validate(dto);

        assertThat(violations).isEmpty();
    }
}
```

## Testing @Min, @Max, @Size Constraints

```java
class ProductDtoTest {

    private Validator validator;

    @BeforeEach
    void setUp() {
        validator = Validation.buildDefaultValidatorFactory().getValidator();
    }

    @Test
    void shouldFailWhenPriceIsBelowMinimum() {
        ProductDto product = new ProductDto("Laptop", -100.0);

        Set<ConstraintViolation<ProductDto>> violations = validator.validate(product);

        assertThat(violations)
            .extracting(ConstraintViolation::getMessage)
            .contains("must be greater than 0");
    }

    @Test
    void shouldFailWhenQuantityExceedsMaximum() {
        ProductDto product = new ProductDto("Laptop", 1000.0, 999999);

        Set<ConstraintViolation<ProductDto>> violations = validator.validate(product);

        assertThat(violations)
            .extracting(ConstraintViolation::getMessage)
            .contains("must be less than or equal to 10000");
    }

    @Test
    void shouldFailWhenDescriptionTooLong() {
        String longDescription = "x".repeat(1001);
        ProductDto product = new ProductDto("Laptop", 1000.0, longDescription);

        Set<ConstraintViolation<ProductDto>> violations = validator.validate(product);

        assertThat(violations)
            .extracting(ConstraintViolation::getMessage)
            .contains("size must be between 0 and 1000");
    }
}
```

## Testing Custom Validators

### Create and Test Custom Constraint

```java
// Custom constraint annotation
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@Constraint(validatedBy = PhoneNumberValidator.class)
public @interface ValidPhoneNumber {
    String message() default "invalid phone number format";
    Class<?>[] groups() default {};
    Class<? extends Payload>[] payload() default {};
}

// Custom validator implementation
public class PhoneNumberValidator implements ConstraintValidator<ValidPhoneNumber, String> {
    private static final String PHONE_PATTERN = "^\\d{3}-\\d{3}-\\d{4}$";

    @Override
    public boolean isValid(String value, ConstraintValidatorContext context) {
        if (value == null) return true; // null values handled by @NotNull
        return value.matches(PHONE_PATTERN);
    }
}

// Unit test for custom validator
class PhoneNumberValidatorTest {

    private Validator validator;

    @BeforeEach
    void setUp() {
        validator = Validation.buildDefaultValidatorFactory().getValidator();
    }

    @Test
    void shouldAcceptValidPhoneNumber() {
        Contact contact = new Contact("Alice", "555-123-4567");

        Set<ConstraintViolation<Contact>> violations = validator.validate(contact);

        assertThat(violations).isEmpty();
    }

    @Test
    void shouldRejectInvalidPhoneNumberFormat() {
        Contact contact = new Contact("Alice", "5551234567"); // No dashes

        Set<ConstraintViolation<Contact>> violations = validator.validate(contact);

        assertThat(violations)
            .extracting(ConstraintViolation::getMessage)
            .contains("invalid phone number format");
    }

    @Test
    void shouldRejectPhoneNumberWithLetters() {
        Contact contact = new Contact("Alice", "ABC-DEF-GHIJ");

        Set<ConstraintViolation<Contact>> violations = validator.validate(contact);

        assertThat(violations).isNotEmpty();
    }

    @Test
    void shouldAllowNullPhoneNumber() {
        Contact contact = new Contact("Alice", null);

        Set<ConstraintViolation<Contact>> violations = validator.validate(contact);

        assertThat(violations).isEmpty();
    }
}
```

## Testing Cross-Field Validation

### Custom Multi-Field Constraint

```java
// Custom constraint for cross-field validation
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Constraint(validatedBy = PasswordMatchValidator.class)
public @interface PasswordsMatch {
    String message() default "passwords do not match";
    Class<?>[] groups() default {};
    Class<? extends Payload>[] payload() default {};
}

// Validator implementation
public class PasswordMatchValidator implements ConstraintValidator<PasswordsMatch, ChangePasswordRequest> {
    @Override
    public boolean isValid(ChangePasswordRequest value, ConstraintValidatorContext context) {
        if (value == null) return true;
        return value.getNewPassword().equals(value.getConfirmPassword());
    }
}

// Unit test
class PasswordValidationTest {

    private Validator validator;

    @BeforeEach
    void setUp() {
        validator = Validation.buildDefaultValidatorFactory().getValidator();
    }

    @Test
    void shouldPassWhenPasswordsMatch() {
        ChangePasswordRequest request = new ChangePasswordRequest("oldPass", "newPass123", "newPass123");

        Set<ConstraintViolation<ChangePasswordRequest>> violations = validator.validate(request);

        assertThat(violations).isEmpty();
    }

    @Test
    void shouldFailWhenPasswordsDoNotMatch() {
        ChangePasswordRequest request = new ChangePasswordRequest("oldPass", "newPass123", "differentPass");

        Set<ConstraintViolation<ChangePasswordRequest>> violations = validator.validate(request);

        assertThat(violations)
            .extracting(ConstraintViolation::getMessage)
            .contains("passwords do not match");
    }
}
```

## Testing Validation Groups

### Conditional Validation

```java
// Validation groups are plain marker interfaces; they need no annotations
public interface CreateValidation {}

public interface UpdateValidation {}

class UserDto {

    @NotNull(groups = {CreateValidation.class})
    private String name;

    @Min(value = 1, groups = {CreateValidation.class, UpdateValidation.class})
    private int age;

    UserDto(String name, int age) {
        this.name = name;
        this.age = age;
    }
}

class ValidationGroupsTest {

    private Validator validator;

    @BeforeEach
    void setUp() {
        validator = Validation.buildDefaultValidatorFactory().getValidator();
    }

    @Test
    void shouldRequireNameOnlyDuringCreation() {
        UserDto user = new UserDto(null, 25);

        Set<ConstraintViolation<UserDto>> violations = validator.validate(user, CreateValidation.class);

        assertThat(violations)
            .extracting(ConstraintViolation::getPropertyPath)
            .extracting(Path::toString)
            .contains("name");
    }

    @Test
    void shouldAllowNullNameDuringUpdate() {
        UserDto user = new UserDto(null, 25);

        Set<ConstraintViolation<UserDto>> violations = validator.validate(user, UpdateValidation.class);

        assertThat(violations).isEmpty();
    }
}
```

## Testing Parameterized Validation Scenarios

```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

class EmailValidationTest {

    private Validator validator;

    @BeforeEach
    void setUp() {
        validator = Validation.buildDefaultValidatorFactory().getValidator();
    }

    @ParameterizedTest
    @ValueSource(strings = {
        "user@example.com",
        "john.doe+tag@example.co.uk",
        "admin123@subdomain.example.com"
    })
    void shouldAcceptValidEmails(String email) {
        UserDto user = new UserDto("Alice", email);

        Set<ConstraintViolation<UserDto>> violations = validator.validate(user);

        assertThat(violations).isEmpty();
    }

    @ParameterizedTest
    @ValueSource(strings = {
        "invalid-email",
        "user@",
        "@example.com",
        "user name@example.com"
    })
    void shouldRejectInvalidEmails(String email) {
        UserDto user = new UserDto("Alice", email);

        Set<ConstraintViolation<UserDto>> violations = validator.validate(user);

        assertThat(violations).isNotEmpty();
    }
}
```

## Best Practices

- **Validate at unit test level** before testing service/controller layers
- **Test both valid and invalid cases** for every constraint
- **Use custom validators** for business-specific validation rules
- **Test error messages** to ensure they're user-friendly
- **Test edge cases**: null, empty string, whitespace-only strings
- **Use validation groups** for conditional validation rules
- **Keep validator logic simple** - complex validation belongs in service tests

## Common Pitfalls

- Forgetting to test null values
- Not extracting violation details (message, property, constraint type)
- Testing validation at service/controller level instead of unit tests
- Creating overly complex custom validators
- Not documenting constraint purposes in error messages

## Troubleshooting

**ValidatorFactory not found**: Ensure `jakarta.validation-api` and `hibernate-validator` are on the classpath.

**Custom validator not invoked**: Verify `@Constraint(validatedBy = YourValidator.class)` is correctly specified.

**Null handling confusion**: By default, only `@NotNull` rejects null; other constraints ignore null values, so combine `@NotNull` with them for mandatory fields (see the sketch below).
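
A minimal sketch of that null-handling rule: `@Size` alone accepts null, so pair it with `@NotNull` for a mandatory field:

```java
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;

class RegistrationForm {

    // @Size ignores null, so a null username would pass without @NotNull
    @NotNull
    @Size(min = 3, max = 20)
    private String username;

    // Optional field: null passes, but a non-null value must satisfy @Size
    @Size(max = 200)
    private String bio;
}
```
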
## References

- [Jakarta Bean Validation Spec](https://jakarta.ee/specifications/bean-validation/)
- [Hibernate Validator Documentation](https://hibernate.org/validator/)
- [Custom Constraints](https://docs.jboss.org/hibernate/stable/validator/reference/en-US/html_single/#validator-customconstraints)
453
skills/junit-test/unit-test-boundary-conditions/SKILL.md
Normal file
@@ -0,0 +1,453 @@
---
name: unit-test-boundary-conditions
description: Edge case and boundary testing patterns for unit tests. Testing minimum/maximum values, null cases, empty collections, and numeric precision. Pure JUnit 5 unit tests. Use when ensuring code handles limits and special cases correctly.
category: testing
tags: [junit-5, boundary-testing, edge-cases, parameterized-test]
version: 1.0.1
---

# Unit Testing Boundary Conditions and Edge Cases

Test boundary conditions, edge cases, and limit values systematically. Verify code behavior at limits, with null/empty inputs, and overflow scenarios.

## When to Use This Skill

Use this skill when:
- Testing minimum and maximum values
- Testing null and empty inputs
- Testing whitespace-only strings
- Testing overflow/underflow scenarios
- Testing collections with zero/one/many items
- Verifying behavior at API boundaries
- Want comprehensive edge case coverage

## Setup: Boundary Testing

### Maven
```xml
<dependency>
    <groupId>org.junit.jupiter</groupId>
    <artifactId>junit-jupiter</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.junit.jupiter</groupId>
    <artifactId>junit-jupiter-params</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.assertj</groupId>
    <artifactId>assertj-core</artifactId>
    <scope>test</scope>
</dependency>
```

### Gradle
```kotlin
dependencies {
    testImplementation("org.junit.jupiter:junit-jupiter")
    testImplementation("org.junit.jupiter:junit-jupiter-params")
    testImplementation("org.assertj:assertj-core")
}
```

## Numeric Boundary Testing

### Integer Limits

```java
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import static org.assertj.core.api.Assertions.*;

class IntegerBoundaryTest {

    @ParameterizedTest
    @ValueSource(ints = {Integer.MIN_VALUE, Integer.MIN_VALUE + 1, 0, Integer.MAX_VALUE - 1, Integer.MAX_VALUE})
    void shouldHandleIntegerBoundaries(int value) {
        // Boundary values must survive a round trip through string parsing
        assertThat(Integer.parseInt(Integer.toString(value))).isEqualTo(value);
    }

    @Test
    void shouldHandleIntegerOverflow() {
        // Math.addExact throws instead of silently wrapping around
        assertThatThrownBy(() -> Math.addExact(Integer.MAX_VALUE, 1))
            .isInstanceOf(ArithmeticException.class);
    }

    @Test
    void shouldHandleIntegerUnderflow() {
        assertThatThrownBy(() -> Math.subtractExact(Integer.MIN_VALUE, 1))
            .isInstanceOf(ArithmeticException.class);
    }

    @Test
    void shouldHandleZero() {
        int result = MathUtils.divide(0, 5);
        assertThat(result).isZero();

        assertThatThrownBy(() -> MathUtils.divide(5, 0))
            .isInstanceOf(ArithmeticException.class);
    }
}
```

## String Boundary Testing

### Null, Empty, and Whitespace

```java
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

class StringBoundaryTest {

    @ParameterizedTest
    @ValueSource(strings = {"", " ", "   ", "\t", "\n"})
    void shouldConsiderEmptyAndWhitespaceAsInvalid(String input) {
        boolean result = StringUtils.isNotBlank(input);
        assertThat(result).isFalse();
    }

    @Test
    void shouldHandleNullString() {
        String result = StringUtils.trim(null);
        assertThat(result).isNull();
    }

    @Test
    void shouldHandleSingleCharacter() {
        String result = StringUtils.capitalize("a");
        assertThat(result).isEqualTo("A");

        String result2 = StringUtils.trim("x");
        assertThat(result2).isEqualTo("x");
    }

    @Test
    void shouldHandleVeryLongString() {
        String longString = "x".repeat(1000000);

        assertThat(longString.length()).isEqualTo(1000000);
        assertThat(StringUtils.isNotBlank(longString)).isTrue();
    }

    @Test
    void shouldHandleSpecialCharacters() {
        String special = "!@#$%^&*()_+-={}[]|\\:;<>?,./";

        // 28 characters: the escaped backslash counts as one
        assertThat(StringUtils.length(special)).isEqualTo(28);
    }
}
```

## Collection Boundary Testing

### Empty, Single, and Large Collections

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class CollectionBoundaryTest {

    @Test
    void shouldHandleEmptyList() {
        List<String> empty = List.of();

        assertThat(empty).isEmpty();
        assertThat(CollectionUtils.first(empty)).isNull();
        assertThat(CollectionUtils.count(empty)).isZero();
    }

    @Test
    void shouldHandleSingleElementList() {
        List<String> single = List.of("only");

        assertThat(single).hasSize(1);
        assertThat(CollectionUtils.first(single)).isEqualTo("only");
        assertThat(CollectionUtils.last(single)).isEqualTo("only");
    }

    @Test
    void shouldHandleLargeList() {
        List<Integer> large = new ArrayList<>();
        for (int i = 0; i < 100000; i++) {
            large.add(i);
        }

        assertThat(large).hasSize(100000);
        assertThat(CollectionUtils.first(large)).isZero();
        assertThat(CollectionUtils.last(large)).isEqualTo(99999);
    }

    @Test
    void shouldHandleNullInCollection() {
        // List.of rejects nulls, so use Arrays.asList for a null-containing list
        List<String> withNull = new ArrayList<>(Arrays.asList("a", null, "c"));

        assertThat(withNull).contains((String) null);
        assertThat(CollectionUtils.filterNonNull(withNull)).hasSize(2);
    }

    @Test
    void shouldHandleDuplicatesInCollection() {
        List<Integer> duplicates = List.of(1, 1, 2, 2, 3, 3);

        assertThat(duplicates).hasSize(6);
        Set<Integer> unique = new HashSet<>(duplicates);
        assertThat(unique).hasSize(3);
    }
}
```

## Floating Point Boundary Testing

### Precision and Special Values

```java
class FloatingPointBoundaryTest {

    @Test
    void shouldHandleFloatingPointPrecision() {
        double result = 0.1 + 0.2;

        // Floating point comparison needs tolerance
        assertThat(result).isCloseTo(0.3, within(0.0001));
    }

    @Test
    void shouldHandleSpecialFloatingPointValues() {
        assertThat(Double.POSITIVE_INFINITY).isGreaterThan(Double.MAX_VALUE);
        // Double.MIN_VALUE is the smallest POSITIVE double; the most negative
        // finite value is -Double.MAX_VALUE
        assertThat(Double.NEGATIVE_INFINITY).isLessThan(-Double.MAX_VALUE);
        // Primitive NaN is never equal to anything, including itself
        assertThat(Double.NaN == Double.NaN).isFalse();
        assertThat(Double.isNaN(Double.NaN)).isTrue();
    }

    @Test
    void shouldHandleVerySmallAndLargeNumbers() {
        double tiny = Double.MIN_VALUE;
        double huge = Double.MAX_VALUE;

        assertThat(tiny).isGreaterThan(0);
        assertThat(huge).isPositive();
    }

    @Test
    void shouldHandleZeroInDivision() {
        double result = 1.0 / 0.0;

        assertThat(result).isEqualTo(Double.POSITIVE_INFINITY);

        double result2 = -1.0 / 0.0;
        assertThat(result2).isEqualTo(Double.NEGATIVE_INFINITY);

        double result3 = 0.0 / 0.0;
        assertThat(result3).isNaN();
    }
}
```

## Date/Time Boundary Testing

### Min/Max Dates and Edge Cases

```java
import java.time.DateTimeException;
import java.time.LocalDate;
import java.time.LocalTime;

class DateTimeBoundaryTest {

    @Test
    void shouldHandleMinAndMaxDates() {
        LocalDate min = LocalDate.MIN;
        LocalDate max = LocalDate.MAX;

        assertThat(min).isBefore(max);
        assertThat(DateUtils.isValid(min)).isTrue();
        assertThat(DateUtils.isValid(max)).isTrue();
    }

    @Test
    void shouldHandleLeapYearBoundary() {
        LocalDate leapDay = LocalDate.of(2024, 2, 29);

        assertThat(leapDay).isNotNull();
        assertThat(LocalDate.of(2024, 2, 29)).isEqualTo(leapDay);
    }

    @Test
    void shouldHandleInvalidDateInNonLeapYear() {
        assertThatThrownBy(() -> LocalDate.of(2023, 2, 29))
            .isInstanceOf(DateTimeException.class);
    }

    @Test
    void shouldHandleYearBoundaries() {
        LocalDate newYear = LocalDate.of(2024, 1, 1);
        LocalDate lastDay = LocalDate.of(2024, 12, 31);

        assertThat(newYear).isBefore(lastDay);
    }

    @Test
    void shouldHandleMidnightBoundary() {
        LocalTime midnight = LocalTime.MIDNIGHT; // 00:00, the start of the day
        LocalTime almostMidnight = LocalTime.of(23, 59, 59);

        // 23:59:59 compares AFTER 00:00 within a single day
        assertThat(almostMidnight).isAfter(midnight);
    }
}
```

## Array Index Boundary Testing

### First, Last, and Out of Bounds

```java
class ArrayBoundaryTest {

    @Test
    void shouldHandleFirstElementAccess() {
        int[] array = {1, 2, 3, 4, 5};

        assertThat(array[0]).isEqualTo(1);
    }

    @Test
    void shouldHandleLastElementAccess() {
        int[] array = {1, 2, 3, 4, 5};

        assertThat(array[array.length - 1]).isEqualTo(5);
    }

    @Test
    void shouldThrowOnNegativeIndex() {
        int[] array = {1, 2, 3};

        assertThatThrownBy(() -> {
            int value = array[-1];
        }).isInstanceOf(ArrayIndexOutOfBoundsException.class);
    }

    @Test
    void shouldThrowOnOutOfBoundsIndex() {
        int[] array = {1, 2, 3};

        assertThatThrownBy(() -> {
            int value = array[10];
        }).isInstanceOf(ArrayIndexOutOfBoundsException.class);
    }

    @Test
    void shouldHandleEmptyArray() {
        int[] empty = {};

        assertThat(empty.length).isZero();
        assertThatThrownBy(() -> {
            int value = empty[0];
        }).isInstanceOf(ArrayIndexOutOfBoundsException.class);
    }
}
```

## Concurrent and Thread Boundary Testing

### Null and Race Conditions

```java
import java.util.List;
import java.util.concurrent.*;

class ConcurrentBoundaryTest {

    @Test
    void shouldHandleNullInConcurrentMap() {
        ConcurrentHashMap<String, String> map = new ConcurrentHashMap<>();

        map.put("key", "value");
        assertThat(map.get("nonexistent")).isNull();
    }

    @Test
    void shouldHandleConcurrentModification() {
        List<Integer> list = new CopyOnWriteArrayList<>(List.of(1, 2, 3, 4, 5));

        // Iterates over a snapshot, so this does not throw
        // ConcurrentModificationException
        for (int num : list) {
            if (num == 3) {
                list.add(6);
            }
        }

        assertThat(list).hasSize(6);
    }

    @Test
    void shouldHandleEmptyBlockingQueue() {
        BlockingQueue<String> queue = new LinkedBlockingQueue<>();

        // poll() returns null immediately on an empty queue (no blocking)
        assertThat(queue.poll()).isNull();
    }
}
```

## Parameterized Boundary Testing

### Multiple Boundary Cases

```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;

class ParameterizedBoundaryTest {

    @ParameterizedTest
    @CsvSource(nullValues = "null", value = {
        "null, false",  // null (mapped to a real null via nullValues)
        "'', false",    // empty
        "' ', false",   // whitespace
        "a, true",      // single char
        "abc, true"     // normal
    })
    void shouldValidateStringBoundaries(String input, boolean expected) {
        boolean result = StringValidator.isValid(input);
        assertThat(result).isEqualTo(expected);
    }

    @ParameterizedTest
    @ValueSource(ints = {Integer.MIN_VALUE, -1, 0, 1, Integer.MAX_VALUE})
    void shouldHandleNumericBoundaries(int value) {
        // Double negation returns the original value for every int, including
        // MIN_VALUE, where negation overflows back to MIN_VALUE itself
        assertThat(-(-value)).isEqualTo(value);
    }
}
```

## Best Practices

- **Test explicitly at boundaries** - don't rely on random testing
- **Test null and empty separately** from valid inputs
- **Use parameterized tests** for multiple boundary cases
- **Test both sides of boundaries** (just below, at, just above) - see the sketch after this list
- **Verify error messages** are helpful for invalid boundaries
- **Document why** specific boundaries matter
- **Test overflow/underflow** for numeric operations
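
A minimal sketch of the below/at/above pattern, using a hypothetical `DiscountPolicy` with a 100-item bulk threshold:

```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import static org.assertj.core.api.Assertions.assertThat;

class DiscountBoundaryTest {

    // DiscountPolicy is illustrative: bulk discount applies from 100 items up
    @ParameterizedTest
    @CsvSource({
        "99, false",   // just below the boundary
        "100, true",   // at the boundary
        "101, true"    // just above the boundary
    })
    void shouldApplyBulkDiscountAtThreshold(int quantity, boolean expected) {
        assertThat(DiscountPolicy.qualifiesForBulkDiscount(quantity)).isEqualTo(expected);
    }
}
```
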
## Common Pitfalls

- Testing only the "happy path" without boundary cases
- Forgetting null/empty cases
- Not testing floating point precision
- Not testing collection boundaries (empty, single, many)
- Not testing string boundaries (null, empty, whitespace)

## Troubleshooting

**Floating point comparison fails**: Use `isCloseTo(expected, within(tolerance))`.

**Collection boundaries unclear**: List cases explicitly: empty (0), single (1), many (>1).

**Date boundary confusing**: Use `LocalDate.MIN` and `LocalDate.MAX` for clear boundaries.

## References

- [Integer.MIN_VALUE/MAX_VALUE](https://docs.oracle.com/javase/8/docs/api/java/lang/Integer.html)
- [Double.MIN_VALUE/MAX_VALUE](https://docs.oracle.com/javase/8/docs/api/java/lang/Double.html)
- [AssertJ Floating Point Assertions](https://assertj.github.io/assertj-core-features-highlight.html#assertions-on-numbers)
- [Boundary Value Analysis](https://en.wikipedia.org/wiki/Boundary-value_analysis)
401
skills/junit-test/unit-test-caching/SKILL.md
Normal file
@@ -0,0 +1,401 @@
---
name: unit-test-caching
description: Unit tests for caching behavior using Spring Cache annotations (@Cacheable, @CachePut, @CacheEvict). Use when validating cache configuration and cache hit/miss scenarios.
category: testing
tags: [junit-5, caching, cacheable, cache-evict, cache-put]
version: 1.0.1
---

# Unit Testing Spring Caching

Test Spring caching annotations (@Cacheable, @CacheEvict, @CachePut) without full Spring context. Verify cache behavior, hits/misses, and invalidation strategies.

## When to Use This Skill

Use this skill when:
- Testing @Cacheable method caching
- Testing @CacheEvict cache invalidation
- Testing @CachePut cache updates
- Verifying cache key generation
- Testing conditional caching
- Want fast caching tests without Redis or cache infrastructure

## Setup: Caching Testing

### Maven
```xml
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-cache</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-test</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.mockito</groupId>
    <artifactId>mockito-core</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.assertj</groupId>
    <artifactId>assertj-core</artifactId>
    <scope>test</scope>
</dependency>
```

### Gradle
```kotlin
dependencies {
    implementation("org.springframework.boot:spring-boot-starter-cache")
    testImplementation("org.springframework.boot:spring-boot-starter-test")
    testImplementation("org.mockito:mockito-core")
    testImplementation("org.assertj:assertj-core")
}
```

## Basic Pattern: Testing @Cacheable

### Cache Hit and Miss Behavior

```java
// Service with caching
@Service
public class UserService {

    private final UserRepository userRepository;

    public UserService(UserRepository userRepository) {
        this.userRepository = userRepository;
    }

    @Cacheable("users")
    public User getUserById(Long id) {
        return userRepository.findById(id).orElse(null);
    }
}

// Test caching behavior
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.cache.CacheManager;
import org.springframework.cache.annotation.EnableCaching;
import org.springframework.cache.concurrent.ConcurrentMapCacheManager;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.Optional;

import static org.mockito.Mockito.*;
import static org.assertj.core.api.Assertions.*;

@Configuration
@EnableCaching
class CacheTestConfig {
    @Bean
    public CacheManager cacheManager() {
        // No-arg constructor creates caches (users, products, orders) on demand
        return new ConcurrentMapCacheManager();
    }
}

class UserServiceCachingTest {

    private UserRepository userRepository;
    private UserService userService;
    private AnnotationConfigApplicationContext context;

    @BeforeEach
    void setUp() {
        // @Cacheable only works through a Spring proxy, so build a minimal
        // caching context instead of instantiating the service with `new`
        userRepository = mock(UserRepository.class);
        context = new AnnotationConfigApplicationContext();
        context.register(CacheTestConfig.class);
        context.registerBean(UserRepository.class, () -> userRepository);
        context.registerBean(UserService.class);
        context.refresh();
        userService = context.getBean(UserService.class);
    }

    @AfterEach
    void tearDown() {
        context.close();
    }

    @Test
    void shouldCacheUserAfterFirstCall() {
        User user = new User(1L, "Alice");
        when(userRepository.findById(1L)).thenReturn(Optional.of(user));

        User firstCall = userService.getUserById(1L);
        User secondCall = userService.getUserById(1L);

        assertThat(firstCall).isEqualTo(secondCall);
        verify(userRepository, times(1)).findById(1L); // Called only once due to cache
    }

    @Test
    void shouldReturnCachedValueOnSecondCall() {
        User user = new User(1L, "Alice");
        when(userRepository.findById(1L)).thenReturn(Optional.of(user));

        userService.getUserById(1L); // First call - hits database
        User cachedResult = userService.getUserById(1L); // Second call - hits cache

        assertThat(cachedResult).isEqualTo(user);
        verify(userRepository, times(1)).findById(1L);
    }
}
```
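
You can also assert on the cache contents directly by pulling the CacheManager from the same context. A small sketch extending `UserServiceCachingTest` above; `Cache.get(key, type)` is Spring's typed lookup:

```java
@Test
void shouldStoreUserUnderItsIdKey() {
    User user = new User(1L, "Alice");
    when(userRepository.findById(1L)).thenReturn(Optional.of(user));

    userService.getUserById(1L);

    // Inspect the backing cache directly via the context's CacheManager
    CacheManager cacheManager = context.getBean(CacheManager.class);
    User cached = cacheManager.getCache("users").get(1L, User.class);
    assertThat(cached).isEqualTo(user);
}
```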
|
||||
## Testing @CacheEvict
|
||||
|
||||
### Cache Invalidation
|
||||
|
||||
```java
|
||||
@Service
|
||||
public class ProductService {
|
||||
|
||||
private final ProductRepository productRepository;
|
||||
|
||||
public ProductService(ProductRepository productRepository) {
|
||||
this.productRepository = productRepository;
|
||||
}
|
||||
|
||||
@Cacheable("products")
|
||||
public Product getProductById(Long id) {
|
||||
return productRepository.findById(id).orElse(null);
|
||||
}
|
||||
|
||||
@CacheEvict("products")
|
||||
public void deleteProduct(Long id) {
|
||||
productRepository.deleteById(id);
|
||||
}
|
||||
|
||||
@CacheEvict(value = "products", allEntries = true)
|
||||
public void clearAllProducts() {
|
||||
// Clear entire cache
|
||||
}
|
||||
}
|
||||
|
||||
class ProductCacheEvictTest {
|
||||
|
||||
private ProductRepository productRepository;
|
||||
private ProductService productService;
|
||||
private CacheManager cacheManager;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
productRepository = mock(ProductRepository.class);
|
||||
cacheManager = new ConcurrentMapCacheManager("products");
|
||||
productService = new ProductService(productRepository);
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldEvictProductFromCacheWhenDeleted() {
|
||||
Product product = new Product(1L, "Laptop", 999.99);
|
||||
when(productRepository.findById(1L)).thenReturn(Optional.of(product));
|
||||
|
||||
productService.getProductById(1L); // Cache the product
|
||||
|
||||
productService.deleteProduct(1L); // Evict from cache
|
||||
|
||||
User cachedAfterEvict = userService.getUserById(1L);
|
||||
|
||||
// After eviction, repository should be called again
|
||||
verify(productRepository, times(2)).findById(1L);
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldClearAllEntriesFromCache() {
|
||||
Product product1 = new Product(1L, "Laptop", 999.99);
|
||||
Product product2 = new Product(2L, "Mouse", 29.99);
|
||||
when(productRepository.findById(1L)).thenReturn(Optional.of(product1));
|
||||
when(productRepository.findById(2L)).thenReturn(Optional.of(product2));
|
||||
|
||||
productService.getProductById(1L);
|
||||
productService.getProductById(2L);
|
||||
|
||||
productService.clearAllProducts(); // Clear all cache entries
|
||||
|
||||
productService.getProductById(1L);
|
||||
productService.getProductById(2L);
|
||||
|
||||
// Repository called twice for each product
|
||||
verify(productRepository, times(2)).findById(1L);
|
||||
verify(productRepository, times(2)).findById(2L);
|
||||
}
|
||||
}
|
||||
```

## Testing @CachePut

### Cache Update

```java
@Service
public class OrderService {

    private final OrderRepository orderRepository;

    public OrderService(OrderRepository orderRepository) {
        this.orderRepository = orderRepository;
    }

    @Cacheable("orders")
    public Order getOrder(Long id) {
        return orderRepository.findById(id).orElse(null);
    }

    @CachePut(value = "orders", key = "#order.id")
    public Order updateOrder(Order order) {
        return orderRepository.save(order);
    }
}

@Configuration
@EnableCaching
class OrderCacheTestConfig {

    @Bean
    public CacheManager cacheManager() {
        return new ConcurrentMapCacheManager("orders");
    }

    @Bean
    public OrderRepository orderRepository() {
        return mock(OrderRepository.class);
    }

    @Bean
    public OrderService orderService(OrderRepository orderRepository) {
        return new OrderService(orderRepository);
    }
}

class OrderCachePutTest {

    private AnnotationConfigApplicationContext context;
    private OrderRepository orderRepository;
    private OrderService orderService;

    @BeforeEach
    void setUp() {
        context = new AnnotationConfigApplicationContext(OrderCacheTestConfig.class);
        orderRepository = context.getBean(OrderRepository.class);
        orderService = context.getBean(OrderService.class);
    }

    @AfterEach
    void tearDown() {
        context.close();
    }

    @Test
    void shouldUpdateCacheWhenOrderIsUpdated() {
        Order originalOrder = new Order(1L, "Pending", 100.0);
        Order updatedOrder = new Order(1L, "Shipped", 100.0);

        when(orderRepository.findById(1L)).thenReturn(Optional.of(originalOrder));
        when(orderRepository.save(updatedOrder)).thenReturn(updatedOrder);

        orderService.getOrder(1L);
        Order result = orderService.updateOrder(updatedOrder);

        assertThat(result.getStatus()).isEqualTo("Shipped");

        // Next call should return updated version from cache
        Order cachedOrder = orderService.getOrder(1L);
        assertThat(cachedOrder.getStatus()).isEqualTo("Shipped");
    }
}
```

## Testing Conditional Caching

### Cache with Conditions

```java
@Service
public class DataService {

    private final DataRepository dataRepository;
    private final UserRepository userRepository;

    public DataService(DataRepository dataRepository, UserRepository userRepository) {
        this.dataRepository = dataRepository;
        this.userRepository = userRepository;
    }

    @Cacheable(value = "data", unless = "#result == null")
    public Data getData(Long id) {
        return dataRepository.findById(id).orElse(null);
    }

    @Cacheable(value = "users", condition = "#id > 0")
    public User getUser(Long id) {
        return userRepository.findById(id).orElse(null);
    }
}

@Configuration
@EnableCaching
class ConditionalCacheTestConfig {

    @Bean
    public CacheManager cacheManager() {
        return new ConcurrentMapCacheManager("data", "users");
    }

    @Bean
    public DataRepository dataRepository() {
        return mock(DataRepository.class);
    }

    @Bean
    public UserRepository userRepository() {
        return mock(UserRepository.class);
    }

    @Bean
    public DataService dataService(DataRepository dataRepository, UserRepository userRepository) {
        return new DataService(dataRepository, userRepository);
    }
}

class ConditionalCachingTest {

    private AnnotationConfigApplicationContext context;
    private DataRepository dataRepository;
    private UserRepository userRepository;
    private DataService service;

    @BeforeEach
    void setUp() {
        context = new AnnotationConfigApplicationContext(ConditionalCacheTestConfig.class);
        dataRepository = context.getBean(DataRepository.class);
        userRepository = context.getBean(UserRepository.class);
        service = context.getBean(DataService.class);
    }

    @AfterEach
    void tearDown() {
        context.close();
    }

    @Test
    void shouldNotCacheNullResults() {
        when(dataRepository.findById(999L)).thenReturn(Optional.empty());

        service.getData(999L);
        service.getData(999L);

        // Should call repository twice because null results are not cached
        verify(dataRepository, times(2)).findById(999L);
    }

    @Test
    void shouldNotCacheWhenConditionIsFalse() {
        User user = new User(1L, "Alice");
        when(userRepository.findById(-1L)).thenReturn(Optional.of(user));

        service.getUser(-1L);
        service.getUser(-1L);

        // Should call repository twice because id <= 0 doesn't match the condition
        verify(userRepository, times(2)).findById(-1L);
    }
}
```

## Testing Cache Keys

### Verify Cache Key Generation

```java
@Service
public class InventoryService {

    private final InventoryRepository inventoryRepository;

    public InventoryService(InventoryRepository inventoryRepository) {
        this.inventoryRepository = inventoryRepository;
    }

    @Cacheable(value = "inventory", key = "#productId + '-' + #warehouseId")
    public InventoryItem getInventory(Long productId, Long warehouseId) {
        return inventoryRepository.findByProductAndWarehouse(productId, warehouseId);
    }
}

@Configuration
@EnableCaching
class InventoryCacheTestConfig {

    @Bean
    public CacheManager cacheManager() {
        return new ConcurrentMapCacheManager("inventory");
    }

    @Bean
    public InventoryRepository inventoryRepository() {
        return mock(InventoryRepository.class);
    }

    @Bean
    public InventoryService inventoryService(InventoryRepository inventoryRepository) {
        return new InventoryService(inventoryRepository);
    }
}

class CacheKeyTest {

    @Test
    void shouldGenerateCorrectCacheKey() {
        AnnotationConfigApplicationContext context =
                new AnnotationConfigApplicationContext(InventoryCacheTestConfig.class);
        InventoryRepository repository = context.getBean(InventoryRepository.class);
        InventoryService service = context.getBean(InventoryService.class);

        InventoryItem item = new InventoryItem(1L, 1L, 100);
        when(repository.findByProductAndWarehouse(1L, 1L)).thenReturn(item);

        service.getInventory(1L, 1L); // Cache: "1-1"
        service.getInventory(1L, 1L); // Hit cache: "1-1"
        service.getInventory(2L, 1L); // Miss cache: "2-1"

        verify(repository, times(2)).findByProductAndWarehouse(any(), any());

        context.close();
    }
}
```

## Best Practices

- **Use an in-memory CacheManager** for unit tests
- **Verify repository calls** to confirm cache hits/misses
- **Test both positive and negative** cache scenarios
- **Test cache invalidation** thoroughly
- **Test conditional caching** with various conditions
- **Keep cache configuration simple** in tests
- **Mock dependencies** that services use

## Common Pitfalls

- Testing actual cache infrastructure instead of caching logic
- Not verifying repository call counts
- Forgetting to test cache eviction
- Not testing conditional caching
- Not resetting cache state between tests (see the sketch below)
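
One way to avoid stale entries leaking between tests is to clear every cache before each test; a minimal sketch, assuming a `cacheManager` field like the one used in the tests above:

```java
@BeforeEach
void resetCaches() {
    // Clear every named cache so one test's entries cannot affect the next.
    cacheManager.getCacheNames()
        .forEach(name -> cacheManager.getCache(name).clear());
}
```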

## Troubleshooting

**Cache not working in tests**: Ensure `@EnableCaching` is in the test configuration and that calls go through the Spring proxy, not a directly instantiated service.

**Wrong cache key generated**: Use SpEL syntax correctly in `@Cacheable(key = "...")`.

**Cache not evicting**: Verify the `@CacheEvict` key matches the stored key exactly.

## References

- [Spring Caching Documentation](https://docs.spring.io/spring-framework/docs/current/reference/html/integration.html#cache)
- [Spring Cache Abstractions](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/cache/annotation/Cacheable.html)
- [SpEL in Caching](https://docs.spring.io/spring-framework/docs/current/reference/html/core.html#expressions)

skills/junit-test/unit-test-config-properties/SKILL.md

---
name: unit-test-config-properties
description: Unit tests for @ConfigurationProperties classes using ApplicationContextRunner. Use when validating application configuration binding and validation.
category: testing
tags: [junit-5, configuration-properties, spring-profiles, property-binding]
version: 1.0.1
---

# Unit Testing Configuration Properties and Profiles

Test @ConfigurationProperties bindings, environment-specific configurations, and property validation using JUnit 5. Verify configuration loading without full Spring context startup.

## When to Use This Skill

Use this skill when:
- Testing @ConfigurationProperties property binding
- Testing property name mapping and type conversions
- Verifying configuration validation
- Testing environment-specific configurations
- Testing nested property structures
- You want fast configuration tests without starting the full Spring context

## Setup: Configuration Testing

### Maven

```xml
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-configuration-processor</artifactId>
    <scope>provided</scope>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-test</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.junit.jupiter</groupId>
    <artifactId>junit-jupiter</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.assertj</groupId>
    <artifactId>assertj-core</artifactId>
    <scope>test</scope>
</dependency>
```

### Gradle

```kotlin
dependencies {
    annotationProcessor("org.springframework.boot:spring-boot-configuration-processor")
    testImplementation("org.springframework.boot:spring-boot-starter-test")
    testImplementation("org.junit.jupiter:junit-jupiter")
    testImplementation("org.assertj:assertj-core")
}
```

## Basic Pattern: Testing ConfigurationProperties

### Simple Property Binding

```java
// Configuration properties class
@ConfigurationProperties(prefix = "app.security")
@Data
public class SecurityProperties {
    private String jwtSecret;
    private long jwtExpirationMs;
    private int maxLoginAttempts;
    private boolean enableTwoFactor;
}

// Unit test. Registering the class via @EnableConfigurationProperties
// (rather than as a plain bean) is what activates property binding.
import org.junit.jupiter.api.Test;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import static org.assertj.core.api.Assertions.*;

class SecurityPropertiesTest {

    @EnableConfigurationProperties(SecurityProperties.class)
    static class Config {
    }

    @Test
    void shouldBindPropertiesFromEnvironment() {
        new ApplicationContextRunner()
            .withPropertyValues(
                "app.security.jwtSecret=my-secret-key",
                "app.security.jwtExpirationMs=3600000",
                "app.security.maxLoginAttempts=5",
                "app.security.enableTwoFactor=true"
            )
            .withUserConfiguration(Config.class)
            .run(context -> {
                SecurityProperties props = context.getBean(SecurityProperties.class);

                assertThat(props.getJwtSecret()).isEqualTo("my-secret-key");
                assertThat(props.getJwtExpirationMs()).isEqualTo(3600000L);
                assertThat(props.getMaxLoginAttempts()).isEqualTo(5);
                assertThat(props.isEnableTwoFactor()).isTrue();
            });
    }

    @Test
    void shouldUseDefaultValuesWhenPropertiesNotProvided() {
        new ApplicationContextRunner()
            .withPropertyValues("app.security.jwtSecret=key")
            .withUserConfiguration(Config.class)
            .run(context -> {
                SecurityProperties props = context.getBean(SecurityProperties.class);

                assertThat(props.getJwtSecret()).isEqualTo("key");
                assertThat(props.getMaxLoginAttempts()).isZero();
            });
    }
}
```
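
Relaxed binding means the same field accepts several equivalent property spellings; a quick sketch reusing the `Config` class from the test above to show that kebab-case binds to the camelCase field:

```java
@Test
void shouldBindKebabCasePropertyNames() {
    new ApplicationContextRunner()
        // "jwt-secret" is the canonical spelling of the jwtSecret field.
        .withPropertyValues("app.security.jwt-secret=kebab-key")
        .withUserConfiguration(Config.class)
        .run(context -> {
            SecurityProperties props = context.getBean(SecurityProperties.class);
            assertThat(props.getJwtSecret()).isEqualTo("kebab-key");
        });
}
```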

## Testing Nested Configuration Properties

### Complex Property Structure

```java
@ConfigurationProperties(prefix = "app.database")
@Data
public class DatabaseProperties {
    private String url;
    private String username;
    private Pool pool = new Pool();
    private List<Replica> replicas = new ArrayList<>();

    @Data
    public static class Pool {
        private int maxSize = 10;
        private int minIdle = 5;
        private long connectionTimeout = 30000;
    }

    @Data
    public static class Replica {
        private String name;
        private String url;
        private int priority;
    }
}

class NestedPropertiesTest {

    @EnableConfigurationProperties(DatabaseProperties.class)
    static class Config {
    }

    @Test
    void shouldBindNestedProperties() {
        new ApplicationContextRunner()
            .withPropertyValues(
                "app.database.url=jdbc:mysql://localhost/db",
                "app.database.username=admin",
                "app.database.pool.maxSize=20",
                "app.database.pool.minIdle=10",
                "app.database.pool.connectionTimeout=60000"
            )
            .withUserConfiguration(Config.class)
            .run(context -> {
                DatabaseProperties props = context.getBean(DatabaseProperties.class);

                assertThat(props.getUrl()).isEqualTo("jdbc:mysql://localhost/db");
                assertThat(props.getPool().getMaxSize()).isEqualTo(20);
                assertThat(props.getPool().getConnectionTimeout()).isEqualTo(60000L);
            });
    }

    @Test
    void shouldBindListOfReplicas() {
        new ApplicationContextRunner()
            .withPropertyValues(
                "app.database.replicas[0].name=replica-1",
                "app.database.replicas[0].url=jdbc:mysql://replica1/db",
                "app.database.replicas[0].priority=1",
                "app.database.replicas[1].name=replica-2",
                "app.database.replicas[1].url=jdbc:mysql://replica2/db",
                "app.database.replicas[1].priority=2"
            )
            .withUserConfiguration(Config.class)
            .run(context -> {
                DatabaseProperties props = context.getBean(DatabaseProperties.class);

                assertThat(props.getReplicas()).hasSize(2);
                assertThat(props.getReplicas().get(0).getName()).isEqualTo("replica-1");
                assertThat(props.getReplicas().get(1).getPriority()).isEqualTo(2);
            });
    }
}
```

## Testing Property Validation

### Validate Configuration with Constraints

```java
@ConfigurationProperties(prefix = "app.server")
@Data
@Validated
public class ServerProperties {
    @NotBlank
    private String host;

    @Min(1)
    @Max(65535)
    private int port = 8080;

    @Positive
    private int threadPoolSize;

    @Email
    private String adminEmail;
}

class ConfigurationValidationTest {

    @EnableConfigurationProperties(ServerProperties.class)
    static class Config {
    }

    @Test
    void shouldFailValidationWhenHostIsBlank() {
        new ApplicationContextRunner()
            .withPropertyValues(
                "app.server.host=",
                "app.server.port=8080",
                "app.server.threadPoolSize=10"
            )
            .withUserConfiguration(Config.class)
            .run(context -> {
                assertThat(context).hasFailed()
                    .getFailure()
                    .hasMessageContaining("host");
            });
    }

    @Test
    void shouldFailValidationWhenPortOutOfRange() {
        new ApplicationContextRunner()
            .withPropertyValues(
                "app.server.host=localhost",
                "app.server.port=99999",
                "app.server.threadPoolSize=10"
            )
            .withUserConfiguration(Config.class)
            .run(context -> {
                assertThat(context).hasFailed();
            });
    }

    @Test
    void shouldPassValidationWithValidConfiguration() {
        new ApplicationContextRunner()
            .withPropertyValues(
                "app.server.host=localhost",
                "app.server.port=8080",
                "app.server.threadPoolSize=10",
                "app.server.adminEmail=admin@example.com"
            )
            .withUserConfiguration(Config.class)
            .run(context -> {
                assertThat(context).hasNotFailed();
                ServerProperties props = context.getBean(ServerProperties.class);
                assertThat(props.getHost()).isEqualTo("localhost");
            });
    }
}
```

## Testing Profile-Specific Configurations

### Environment-Specific Properties

```java
@Configuration
@Profile("prod")
class ProductionConfiguration {

    @Bean
    public SecurityProperties securityProperties() {
        SecurityProperties props = new SecurityProperties();
        props.setEnableTwoFactor(true);
        props.setMaxLoginAttempts(3);
        return props;
    }
}

@Configuration
@Profile("dev")
class DevelopmentConfiguration {

    @Bean
    public SecurityProperties securityProperties() {
        SecurityProperties props = new SecurityProperties();
        props.setEnableTwoFactor(false);
        props.setMaxLoginAttempts(999);
        return props;
    }
}

class ProfileBasedConfigurationTest {

    @Test
    void shouldLoadProductionConfiguration() {
        new ApplicationContextRunner()
            .withPropertyValues("spring.profiles.active=prod")
            .withUserConfiguration(ProductionConfiguration.class)
            .run(context -> {
                SecurityProperties props = context.getBean(SecurityProperties.class);

                assertThat(props.isEnableTwoFactor()).isTrue();
                assertThat(props.getMaxLoginAttempts()).isEqualTo(3);
            });
    }

    @Test
    void shouldLoadDevelopmentConfiguration() {
        new ApplicationContextRunner()
            .withPropertyValues("spring.profiles.active=dev")
            .withUserConfiguration(DevelopmentConfiguration.class)
            .run(context -> {
                SecurityProperties props = context.getBean(SecurityProperties.class);

                assertThat(props.isEnableTwoFactor()).isFalse();
                assertThat(props.getMaxLoginAttempts()).isEqualTo(999);
            });
    }
}
```
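
The negative case is just as useful: when a different profile is active, the profile-guarded configuration should contribute no beans. A short sketch using `doesNotHaveBean`:

```java
@Test
void shouldNotLoadProductionBeansWhenDevProfileActive() {
    new ApplicationContextRunner()
        .withPropertyValues("spring.profiles.active=dev")
        .withUserConfiguration(ProductionConfiguration.class)
        // @Profile("prod") keeps the bean out of the dev context entirely.
        .run(context -> assertThat(context).doesNotHaveBean(SecurityProperties.class));
}
```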

## Testing Type Conversion

### Property Type Binding

```java
@ConfigurationProperties(prefix = "app.features")
@Data
public class FeatureProperties {
    private Duration cacheExpiry = Duration.ofMinutes(10);
    private DataSize maxUploadSize = DataSize.ofMegabytes(100);
    private List<String> enabledFeatures;
    private Map<String, String> featureFlags;
    private Charset fileEncoding = StandardCharsets.UTF_8;
}

class TypeConversionTest {

    @EnableConfigurationProperties(FeatureProperties.class)
    static class Config {
    }

    @Test
    void shouldConvertStringToDuration() {
        new ApplicationContextRunner()
            .withPropertyValues("app.features.cacheExpiry=30s")
            .withUserConfiguration(Config.class)
            .run(context -> {
                FeatureProperties props = context.getBean(FeatureProperties.class);

                assertThat(props.getCacheExpiry()).isEqualTo(Duration.ofSeconds(30));
            });
    }

    @Test
    void shouldConvertStringToDataSize() {
        new ApplicationContextRunner()
            .withPropertyValues("app.features.maxUploadSize=50MB")
            .withUserConfiguration(Config.class)
            .run(context -> {
                FeatureProperties props = context.getBean(FeatureProperties.class);

                assertThat(props.getMaxUploadSize()).isEqualTo(DataSize.ofMegabytes(50));
            });
    }

    @Test
    void shouldConvertCommaDelimitedListToList() {
        new ApplicationContextRunner()
            .withPropertyValues("app.features.enabledFeatures=feature1,feature2,feature3")
            .withUserConfiguration(Config.class)
            .run(context -> {
                FeatureProperties props = context.getBean(FeatureProperties.class);

                assertThat(props.getEnabledFeatures())
                    .containsExactly("feature1", "feature2", "feature3");
            });
    }
}
```
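
The `featureFlags` map declared above binds the same way, with each key after the prefix becoming a map entry; a minimal sketch reusing the `Config` registration from `TypeConversionTest`:

```java
@Test
void shouldBindMapOfFeatureFlags() {
    new ApplicationContextRunner()
        .withPropertyValues(
            "app.features.featureFlags.search=on",
            "app.features.featureFlags.export=off"
        )
        .withUserConfiguration(Config.class)
        .run(context -> {
            FeatureProperties props = context.getBean(FeatureProperties.class);

            assertThat(props.getFeatureFlags())
                .containsEntry("search", "on")
                .containsEntry("export", "off");
        });
}
```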

## Testing Property Binding with Default Values

### Verify Default Configuration

```java
@ConfigurationProperties(prefix = "app.cache")
@Data
public class CacheProperties {
    private long ttlSeconds = 300;
    private int maxSize = 1000;
    private boolean enabled = true;
    private String cacheType = "IN_MEMORY";
}

class DefaultValuesTest {

    @EnableConfigurationProperties(CacheProperties.class)
    static class Config {
    }

    @Test
    void shouldUseDefaultValuesWhenNotSpecified() {
        new ApplicationContextRunner()
            .withUserConfiguration(Config.class)
            .run(context -> {
                CacheProperties props = context.getBean(CacheProperties.class);

                assertThat(props.getTtlSeconds()).isEqualTo(300L);
                assertThat(props.getMaxSize()).isEqualTo(1000);
                assertThat(props.isEnabled()).isTrue();
                assertThat(props.getCacheType()).isEqualTo("IN_MEMORY");
            });
    }

    @Test
    void shouldOverrideDefaultValuesWithProvidedProperties() {
        new ApplicationContextRunner()
            .withPropertyValues(
                "app.cache.ttlSeconds=600",
                "app.cache.cacheType=REDIS"
            )
            .withUserConfiguration(Config.class)
            .run(context -> {
                CacheProperties props = context.getBean(CacheProperties.class);

                assertThat(props.getTtlSeconds()).isEqualTo(600L);
                assertThat(props.getCacheType()).isEqualTo("REDIS");
                assertThat(props.getMaxSize()).isEqualTo(1000); // Default unchanged
            });
    }
}
```

## Best Practices

- **Test all property bindings** including nested structures
- **Test validation constraints** thoroughly
- **Test both default and custom values**
- **Use ApplicationContextRunner** for lightweight, fast-starting context tests
- **Test profile-specific configurations** separately
- **Verify type conversions** work correctly
- **Test edge cases** (empty strings, null values, type mismatches)

## Common Pitfalls

- Not testing validation constraints
- Forgetting to test default values
- Not testing nested property structures
- Testing with the wrong property prefix
- Not handling type conversion properly

## Troubleshooting

**Properties not binding**: Verify the prefix and property names match exactly (including kebab-case to camelCase conversion), and that the class is registered via `@EnableConfigurationProperties` rather than as a plain bean.

**Validation not triggered**: Ensure `@Validated` is present and validation dependencies are on the classpath.

**ApplicationContextRunner not found**: Verify `spring-boot-starter-test` is in the test dependencies.

## References

- [Spring Boot ConfigurationProperties](https://docs.spring.io/spring-boot/docs/current/reference/html/configuration-metadata.html)
- [ApplicationContextRunner Testing](https://docs.spring.io/spring-boot/docs/current/api/org/springframework/boot/test/context/runner/ApplicationContextRunner.html)
- [Spring Profiles](https://docs.spring.io/spring-boot/docs/current/reference/html/features.html#features.profiles)

skills/junit-test/unit-test-controller-layer/SKILL.md

---
name: unit-test-controller-layer
description: Unit tests for REST controllers using MockMvc and @WebMvcTest. Test request/response mapping, validation, and exception handling. Use when testing web layer endpoints in isolation.
category: testing
tags: [junit-5, mockito, unit-testing, controller, rest, mockmvc]
version: 1.0.1
---

# Unit Testing REST Controllers with MockMvc

Test @RestController and @Controller classes by mocking service dependencies and verifying HTTP responses, status codes, and serialization. Use MockMvc for lightweight controller testing without loading the full Spring context.

## When to Use This Skill

Use this skill when:
- Testing REST controller request/response handling
- Verifying HTTP status codes and response formats
- Testing request parameter binding and validation
- Mocking the service layer for isolated controller tests
- Testing content negotiation and response headers
- You want fast controller tests without integration-test overhead

## Setup: MockMvc + Mockito

### Maven

```xml
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-test</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.mockito</groupId>
    <artifactId>mockito-core</artifactId>
    <scope>test</scope>
</dependency>
```

### Gradle

```kotlin
dependencies {
    testImplementation("org.springframework.boot:spring-boot-starter-test")
    testImplementation("org.mockito:mockito-core")
}
```

## Basic Pattern: Testing GET Endpoint

### Simple GET Endpoint Test

```java
import java.util.List;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;

import static org.mockito.Mockito.*;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*;

@ExtendWith(MockitoExtension.class)
class UserControllerTest {

    @Mock
    private UserService userService;

    @InjectMocks
    private UserController userController;

    private MockMvc mockMvc;

    @BeforeEach
    void setUp() {
        mockMvc = MockMvcBuilders.standaloneSetup(userController).build();
    }

    @Test
    void shouldReturnAllUsers() throws Exception {
        List<UserDto> users = List.of(
            new UserDto(1L, "Alice"),
            new UserDto(2L, "Bob")
        );
        when(userService.getAllUsers()).thenReturn(users);

        mockMvc.perform(get("/api/users"))
            .andExpect(status().isOk())
            .andExpect(jsonPath("$").isArray())
            .andExpect(jsonPath("$[0].id").value(1))
            .andExpect(jsonPath("$[0].name").value("Alice"))
            .andExpect(jsonPath("$[1].id").value(2));

        verify(userService, times(1)).getAllUsers();
    }

    @Test
    void shouldReturnUserById() throws Exception {
        UserDto user = new UserDto(1L, "Alice");
        when(userService.getUserById(1L)).thenReturn(user);

        mockMvc.perform(get("/api/users/1"))
            .andExpect(status().isOk())
            .andExpect(jsonPath("$.id").value(1))
            .andExpect(jsonPath("$.name").value("Alice"));

        verify(userService).getUserById(1L);
    }
}
```
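
The standalone setup above keeps tests fast; when filters, message converters, or controller advice registered in the web slice matter, the same test can run under `@WebMvcTest` instead. A sketch, assuming the same `UserController` and `UserDto` and a Spring Boot version where `@MockBean` is available:

```java
@WebMvcTest(UserController.class)
class UserControllerWebMvcTest {

    @Autowired
    private MockMvc mockMvc;

    // Replaces the real service bean in the web slice with a Mockito mock.
    @MockBean
    private UserService userService;

    @Test
    void shouldReturnUserById() throws Exception {
        when(userService.getUserById(1L)).thenReturn(new UserDto(1L, "Alice"));

        mockMvc.perform(get("/api/users/1"))
            .andExpect(status().isOk())
            .andExpect(jsonPath("$.name").value("Alice"));
    }
}
```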

## Testing POST Endpoint

### Create Resource with Request Body

```java
@Test
void shouldCreateUserAndReturn201() throws Exception {
    UserDto createdUser = new UserDto(1L, "Alice", "alice@example.com");

    when(userService.createUser(any(UserCreateRequest.class)))
        .thenReturn(createdUser);

    mockMvc.perform(post("/api/users")
            .contentType("application/json")
            .content("{\"name\":\"Alice\",\"email\":\"alice@example.com\"}"))
        .andExpect(status().isCreated())
        .andExpect(jsonPath("$.id").value(1))
        .andExpect(jsonPath("$.name").value("Alice"))
        .andExpect(jsonPath("$.email").value("alice@example.com"));

    verify(userService).createUser(any(UserCreateRequest.class));
}
```

## Testing Error Scenarios

### Handle 404 Not Found

```java
@Test
void shouldReturn404WhenUserNotFound() throws Exception {
    // Assumes an exception handler that maps UserNotFoundException to 404
    // is registered alongside the controller.
    when(userService.getUserById(999L))
        .thenThrow(new UserNotFoundException("User not found"));

    mockMvc.perform(get("/api/users/999"))
        .andExpect(status().isNotFound())
        .andExpect(jsonPath("$.error").value("User not found"));

    verify(userService).getUserById(999L);
}
```

### Handle 400 Bad Request

```java
@Test
void shouldReturn400WhenRequestBodyInvalid() throws Exception {
    mockMvc.perform(post("/api/users")
            .contentType("application/json")
            .content("{\"name\":\"\"}")) // Empty name
        .andExpect(status().isBadRequest())
        .andExpect(jsonPath("$.errors").isArray());
}
```

## Testing PUT/PATCH Endpoints

### Update Resource

```java
@Test
void shouldUpdateUserAndReturn200() throws Exception {
    UserDto updatedUser = new UserDto(1L, "Alice Updated");

    when(userService.updateUser(eq(1L), any(UserUpdateRequest.class)))
        .thenReturn(updatedUser);

    mockMvc.perform(put("/api/users/1")
            .contentType("application/json")
            .content("{\"name\":\"Alice Updated\"}"))
        .andExpect(status().isOk())
        .andExpect(jsonPath("$.id").value(1))
        .andExpect(jsonPath("$.name").value("Alice Updated"));

    verify(userService).updateUser(eq(1L), any(UserUpdateRequest.class));
}
```

## Testing DELETE Endpoint

### Delete Resource

```java
@Test
void shouldDeleteUserAndReturn204() throws Exception {
    doNothing().when(userService).deleteUser(1L);

    mockMvc.perform(delete("/api/users/1"))
        .andExpect(status().isNoContent());

    verify(userService).deleteUser(1L);
}

@Test
void shouldReturn404WhenDeletingNonExistentUser() throws Exception {
    doThrow(new UserNotFoundException("User not found"))
        .when(userService).deleteUser(999L);

    mockMvc.perform(delete("/api/users/999"))
        .andExpect(status().isNotFound());
}
```

## Testing Request Parameters

### Query Parameters

```java
@Test
void shouldFilterUsersByName() throws Exception {
    List<UserDto> users = List.of(new UserDto(1L, "Alice"));
    when(userService.searchUsers("Alice")).thenReturn(users);

    mockMvc.perform(get("/api/users/search").param("name", "Alice"))
        .andExpect(status().isOk())
        .andExpect(jsonPath("$").isArray())
        .andExpect(jsonPath("$[0].name").value("Alice"));

    verify(userService).searchUsers("Alice");
}
```

### Path Variables

```java
@Test
void shouldGetUserByIdFromPath() throws Exception {
    UserDto user = new UserDto(123L, "Alice");
    when(userService.getUserById(123L)).thenReturn(user);

    mockMvc.perform(get("/api/users/{id}", 123L))
        .andExpect(status().isOk())
        .andExpect(jsonPath("$.id").value(123));
}
```

## Testing Response Headers

### Verify Response Headers

```java
@Test
void shouldReturnCustomHeaders() throws Exception {
    when(userService.getAllUsers()).thenReturn(List.of());

    mockMvc.perform(get("/api/users"))
        .andExpect(status().isOk())
        .andExpect(header().exists("X-Total-Count"))
        .andExpect(header().string("X-Total-Count", "0"))
        .andExpect(header().string("Content-Type", containsString("application/json")));
}
```

## Testing Request Headers

### Send Request Headers

```java
@Test
void shouldRequireAuthorizationHeader() throws Exception {
    // Standalone MockMvc has no security filters, so this assumes the
    // controller (or an interceptor registered in setUp) checks the header.
    mockMvc.perform(get("/api/users"))
        .andExpect(status().isUnauthorized());

    mockMvc.perform(get("/api/users")
            .header("Authorization", "Bearer token123"))
        .andExpect(status().isOk());
}
```

## Content Negotiation

### Test Different Accept Headers

```java
@Test
void shouldReturnJsonWhenAcceptHeaderIsJson() throws Exception {
    UserDto user = new UserDto(1L, "Alice");
    when(userService.getUserById(1L)).thenReturn(user);

    mockMvc.perform(get("/api/users/1")
            .accept("application/json"))
        .andExpect(status().isOk())
        .andExpect(content().contentType("application/json"));
}
```

## Advanced: Testing Multiple Status Codes

```java
@Test
void shouldReturnDifferentStatusCodesForDifferentScenarios() throws Exception {
    // Successful response
    when(userService.getUserById(1L)).thenReturn(new UserDto(1L, "Alice"));
    mockMvc.perform(get("/api/users/1"))
        .andExpect(status().isOk());

    // Not found
    when(userService.getUserById(999L))
        .thenThrow(new UserNotFoundException("Not found"));
    mockMvc.perform(get("/api/users/999"))
        .andExpect(status().isNotFound());

    // Unauthorized
    mockMvc.perform(get("/api/admin/users"))
        .andExpect(status().isUnauthorized());
}
```

## Best Practices

- **Use standalone setup** when testing a single controller: `MockMvcBuilders.standaloneSetup()`
- **Mock the service layer** - controllers should focus on HTTP handling
- **Test happy path and error paths** thoroughly
- **Verify service method calls** to ensure the controller delegates correctly
- **Use content() matchers** for response body validation
- **Keep tests focused** on one endpoint behavior per test
- **Use JsonPath** for fluent JSON response assertions

## Common Pitfalls

- **Testing business logic in the controller**: Move it to service tests
- **Not mocking the service layer**: Always mock service dependencies
- **Testing framework behavior**: Focus on your code, not Spring code
- **Hardcoding URLs**: Use MockMvcRequestBuilders helpers
- **Not verifying mock interactions**: Always verify the service was called correctly

## Troubleshooting

**Content type mismatch**: Ensure `contentType()` matches the controller's `@PostMapping(consumes = ...)` or use the default.

**JsonPath not matching**: Use `mockMvc.perform(...).andDo(print())` to see the actual response content.

**Status code assertions fail**: Check the controller's `@RequestMapping`/`@PostMapping` status codes and error handling.

## References

- [Spring MockMvc Documentation](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/test/web/servlet/MockMvc.html)
- [JsonPath for REST Assertions](https://goessner.net/articles/JsonPath/)
- [Spring Testing Best Practices](https://docs.spring.io/spring-boot/docs/current/reference/html/features.html#features.testing)

skills/junit-test/unit-test-exception-handler/SKILL.md

---
name: unit-test-exception-handler
description: Unit tests for @ExceptionHandler and @ControllerAdvice for global exception handling. Use when validating error response formatting and HTTP status codes.
category: testing
tags: [junit-5, exception-handler, controller-advice, error-handling, mockmvc]
version: 1.0.1
---

# Unit Testing ExceptionHandler and ControllerAdvice

Test exception handlers and global exception handling logic using MockMvc. Verify error response formatting, HTTP status codes, and exception-to-response mapping.

## When to Use This Skill

Use this skill when:
- Testing @ExceptionHandler methods in @ControllerAdvice
- Testing exception-to-error-response transformations
- Verifying HTTP status codes for different exception types
- Testing error message formatting and localization
- You want fast exception handler tests without full integration tests

## Setup: Exception Handler Testing

### Maven

```xml
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-test</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.assertj</groupId>
    <artifactId>assertj-core</artifactId>
    <scope>test</scope>
</dependency>
```

### Gradle

```kotlin
dependencies {
    implementation("org.springframework.boot:spring-boot-starter-web")
    testImplementation("org.springframework.boot:spring-boot-starter-test")
    testImplementation("org.assertj:assertj-core")
}
```

## Basic Pattern: Global Exception Handler

### Create Exception Handler

```java
// Global exception handler. @RestControllerAdvice combines @ControllerAdvice
// with @ResponseBody, so the returned error objects are serialized to JSON.
@RestControllerAdvice
public class GlobalExceptionHandler {

    @ExceptionHandler(ResourceNotFoundException.class)
    @ResponseStatus(HttpStatus.NOT_FOUND)
    public ErrorResponse handleResourceNotFound(ResourceNotFoundException ex) {
        return new ErrorResponse(
            HttpStatus.NOT_FOUND.value(),
            "Resource not found",
            ex.getMessage()
        );
    }

    @ExceptionHandler(ValidationException.class)
    @ResponseStatus(HttpStatus.BAD_REQUEST)
    public ErrorResponse handleValidationException(ValidationException ex) {
        return new ErrorResponse(
            HttpStatus.BAD_REQUEST.value(),
            "Validation failed",
            ex.getMessage()
        );
    }
}

// Error response DTO
public record ErrorResponse(
    int status,
    String error,
    String message
) {}
```

### Unit Test Exception Handler

```java
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;

import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*;

@ExtendWith(MockitoExtension.class)
class GlobalExceptionHandlerTest {

    @InjectMocks
    private GlobalExceptionHandler exceptionHandler;

    private MockMvc mockMvc;

    @BeforeEach
    void setUp() {
        mockMvc = MockMvcBuilders
            .standaloneSetup(new TestController())
            .setControllerAdvice(exceptionHandler)
            .build();
    }

    @Test
    void shouldReturnNotFoundWhenResourceNotFoundException() throws Exception {
        mockMvc.perform(get("/api/users/999"))
            .andExpect(status().isNotFound())
            .andExpect(jsonPath("$.status").value(404))
            .andExpect(jsonPath("$.error").value("Resource not found"))
            .andExpect(jsonPath("$.message").value("User not found"));
    }

    @Test
    void shouldReturnBadRequestWhenValidationException() throws Exception {
        mockMvc.perform(post("/api/users")
                .contentType("application/json")
                .content("{\"name\":\"\"}"))
            .andExpect(status().isBadRequest())
            .andExpect(jsonPath("$.status").value(400))
            .andExpect(jsonPath("$.error").value("Validation failed"));
    }
}

// Test controller that throws exceptions
@RestController
@RequestMapping("/api")
class TestController {

    @GetMapping("/users/{id}")
    public User getUser(@PathVariable Long id) {
        throw new ResourceNotFoundException("User not found");
    }

    @PostMapping("/users")
    public User createUser(@RequestBody UserCreateRequest request) {
        throw new ValidationException("Name must not be blank");
    }
}
```
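
Because the advice class is a plain Spring bean, the mapping logic can also be unit tested by calling the handler method directly, without MockMvc; a minimal sketch using the handler and `ErrorResponse` record above:

```java
@Test
void shouldMapResourceNotFoundToErrorResponse() {
    GlobalExceptionHandler handler = new GlobalExceptionHandler();

    // Invoke the @ExceptionHandler method as an ordinary method call.
    ErrorResponse response = handler.handleResourceNotFound(
        new ResourceNotFoundException("User not found"));

    assertThat(response.status()).isEqualTo(404);
    assertThat(response.error()).isEqualTo("Resource not found");
    assertThat(response.message()).isEqualTo("User not found");
}
```

This checks only the exception-to-response transformation; the MockMvc test above is still needed to confirm the handler is actually picked up for requests.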

## Testing Multiple Exception Types

### Handle Various Exception Types

```java
@RestControllerAdvice
public class GlobalExceptionHandler {

    @ExceptionHandler(ResourceNotFoundException.class)
    @ResponseStatus(HttpStatus.NOT_FOUND)
    public ErrorResponse handleResourceNotFound(ResourceNotFoundException ex) {
        return new ErrorResponse(404, "Not found", ex.getMessage());
    }

    @ExceptionHandler(DuplicateResourceException.class)
    @ResponseStatus(HttpStatus.CONFLICT)
    public ErrorResponse handleDuplicateResource(DuplicateResourceException ex) {
        return new ErrorResponse(409, "Conflict", ex.getMessage());
    }

    @ExceptionHandler(UnauthorizedException.class)
    @ResponseStatus(HttpStatus.UNAUTHORIZED)
    public ErrorResponse handleUnauthorized(UnauthorizedException ex) {
        return new ErrorResponse(401, "Unauthorized", ex.getMessage());
    }

    @ExceptionHandler(AccessDeniedException.class)
    @ResponseStatus(HttpStatus.FORBIDDEN)
    public ErrorResponse handleAccessDenied(AccessDeniedException ex) {
        return new ErrorResponse(403, "Forbidden", ex.getMessage());
    }

    @ExceptionHandler(Exception.class)
    @ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR)
    public ErrorResponse handleGenericException(Exception ex) {
        return new ErrorResponse(500, "Internal server error", "An unexpected error occurred");
    }
}

class MultiExceptionHandlerTest {

    private MockMvc mockMvc;
    private GlobalExceptionHandler handler;

    @BeforeEach
    void setUp() {
        handler = new GlobalExceptionHandler();
        // TestController endpoints throw the corresponding exceptions (omitted here).
        mockMvc = MockMvcBuilders
            .standaloneSetup(new TestController())
            .setControllerAdvice(handler)
            .build();
    }

    @Test
    void shouldReturn404ForNotFound() throws Exception {
        mockMvc.perform(get("/api/users/999"))
            .andExpect(status().isNotFound())
            .andExpect(jsonPath("$.status").value(404));
    }

    @Test
    void shouldReturn409ForDuplicate() throws Exception {
        mockMvc.perform(post("/api/users")
                .contentType("application/json")
                .content("{\"email\":\"existing@example.com\"}"))
            .andExpect(status().isConflict())
            .andExpect(jsonPath("$.status").value(409));
    }

    @Test
    void shouldReturn401ForUnauthorized() throws Exception {
        mockMvc.perform(get("/api/admin/dashboard"))
            .andExpect(status().isUnauthorized())
            .andExpect(jsonPath("$.status").value(401));
    }

    @Test
    void shouldReturn403ForAccessDenied() throws Exception {
        mockMvc.perform(get("/api/admin/users"))
            .andExpect(status().isForbidden())
            .andExpect(jsonPath("$.status").value(403));
    }

    @Test
    void shouldReturn500ForGenericException() throws Exception {
        mockMvc.perform(get("/api/error"))
            .andExpect(status().isInternalServerError())
            .andExpect(jsonPath("$.status").value(500));
    }
}
```

## Testing Error Response Structure

### Verify Error Response Format

```java
@RestControllerAdvice
public class GlobalExceptionHandler {

    @ExceptionHandler(BadRequestException.class)
    public ResponseEntity<ErrorDetails> handleBadRequest(BadRequestException ex) {
        ErrorDetails details = new ErrorDetails(
            System.currentTimeMillis(),
            HttpStatus.BAD_REQUEST.value(),
            "Bad Request",
            ex.getMessage(),
            new Date()
        );
        // The ResponseEntity carries the status, so no @ResponseStatus is needed.
        return new ResponseEntity<>(details, HttpStatus.BAD_REQUEST);
    }
}

class ErrorResponseStructureTest {

    private MockMvc mockMvc;

    @BeforeEach
    void setUp() {
        mockMvc = MockMvcBuilders
            .standaloneSetup(new TestController())
            .setControllerAdvice(new GlobalExceptionHandler())
            .build();
    }

    @Test
    void shouldIncludeTimestampInErrorResponse() throws Exception {
        mockMvc.perform(post("/api/data")
                .contentType("application/json")
                .content("{}"))
            .andExpect(status().isBadRequest())
            .andExpect(jsonPath("$.timestamp").exists())
            .andExpect(jsonPath("$.status").value(400))
            .andExpect(jsonPath("$.error").value("Bad Request"))
            .andExpect(jsonPath("$.message").exists())
            .andExpect(jsonPath("$.date").exists());
    }

    @Test
    void shouldIncludeAllRequiredErrorFields() throws Exception {
        MvcResult result = mockMvc.perform(get("/api/invalid"))
            .andExpect(status().isBadRequest())
            .andReturn();

        String response = result.getResponse().getContentAsString();

        assertThat(response).contains("timestamp");
        assertThat(response).contains("status");
        assertThat(response).contains("error");
        assertThat(response).contains("message");
    }
}
```

## Testing Validation Error Handling

### Handle MethodArgumentNotValidException

```java
@RestControllerAdvice
public class GlobalExceptionHandler {

    @ExceptionHandler(MethodArgumentNotValidException.class)
    @ResponseStatus(HttpStatus.BAD_REQUEST)
    public ValidationErrorResponse handleValidationException(
            MethodArgumentNotValidException ex) {

        Map<String, String> errors = new HashMap<>();
        ex.getBindingResult().getFieldErrors().forEach(error ->
            errors.put(error.getField(), error.getDefaultMessage())
        );

        return new ValidationErrorResponse(
            HttpStatus.BAD_REQUEST.value(),
            "Validation failed",
            errors
        );
    }
}

class ValidationExceptionHandlerTest {

    private MockMvc mockMvc;

    @BeforeEach
    void setUp() {
        mockMvc = MockMvcBuilders
            .standaloneSetup(new UserController())
            .setControllerAdvice(new GlobalExceptionHandler())
            .build();
    }

    @Test
    void shouldReturnValidationErrorsForInvalidInput() throws Exception {
        mockMvc.perform(post("/api/users")
                .contentType("application/json")
                .content("{\"name\":\"\",\"age\":-5}"))
            .andExpect(status().isBadRequest())
            .andExpect(jsonPath("$.status").value(400))
            .andExpect(jsonPath("$.errors.name").exists())
            .andExpect(jsonPath("$.errors.age").exists());
    }

    @Test
    void shouldIncludeErrorMessageForEachField() throws Exception {
        mockMvc.perform(post("/api/users")
                .contentType("application/json")
                .content("{\"name\":\"\",\"email\":\"invalid\"}"))
            .andExpect(status().isBadRequest())
            .andExpect(jsonPath("$.errors.name").value("must not be blank"))
            .andExpect(jsonPath("$.errors.email").value("must be valid email"));
    }
}
```

## Testing Exception Handler with Custom Logic

### Exception Handler with Context

```java
@RestControllerAdvice
public class GlobalExceptionHandler {

    private final MessageService messageService;
    private final LoggingService loggingService;

    public GlobalExceptionHandler(MessageService messageService, LoggingService loggingService) {
        this.messageService = messageService;
        this.loggingService = loggingService;
    }

    @ExceptionHandler(BusinessException.class)
    @ResponseStatus(HttpStatus.BAD_REQUEST)
    public ErrorResponse handleBusinessException(BusinessException ex, HttpServletRequest request) {
        loggingService.logException(ex, request.getRequestURI());

        String localizedMessage = messageService.getMessage(ex.getErrorCode());
        return new ErrorResponse(
            HttpStatus.BAD_REQUEST.value(),
            "Business error",
            localizedMessage
        );
    }
}

class ExceptionHandlerWithContextTest {

    private MockMvc mockMvc;
    private GlobalExceptionHandler handler;
    private MessageService messageService;
    private LoggingService loggingService;

    @BeforeEach
    void setUp() {
        messageService = mock(MessageService.class);
        loggingService = mock(LoggingService.class);
        handler = new GlobalExceptionHandler(messageService, loggingService);

        mockMvc = MockMvcBuilders
            .standaloneSetup(new TestController())
            .setControllerAdvice(handler)
            .build();
    }

    @Test
    void shouldLocalizeErrorMessage() throws Exception {
        when(messageService.getMessage("USER_NOT_FOUND"))
            .thenReturn("L'utilisateur n'a pas été trouvé");

        mockMvc.perform(get("/api/users/999"))
            .andExpect(status().isBadRequest())
            .andExpect(jsonPath("$.message").value("L'utilisateur n'a pas été trouvé"));

        verify(messageService).getMessage("USER_NOT_FOUND");
    }

    @Test
    void shouldLogExceptionOccurrence() throws Exception {
        mockMvc.perform(get("/api/users/999"))
            .andExpect(status().isBadRequest());

        verify(loggingService).logException(any(BusinessException.class), anyString());
    }
}
```

## Best Practices

- **Test all exception handlers** with real exception throws
- **Verify HTTP status codes** for each exception type
- **Test error response structure** to ensure consistency
- **Verify logging** is triggered appropriately
- **Use small test controllers** to throw exceptions in tests
- **Test both happy and error paths**
- **Keep error messages user-friendly** and consistent

## Common Pitfalls

- Not testing the full request path (use MockMvc with a controller)
- Forgetting to include the `@ControllerAdvice` in the MockMvc setup
- Not verifying all required fields in the error response
- Testing handler logic instead of exception handling behavior
- Not testing edge cases (null exceptions, unusual messages)

## Troubleshooting

**Exception handler not invoked**: Ensure the controller is registered with MockMvc and actually throws the exception.

**JsonPath matchers not matching**: Use `.andDo(print())` to see the actual response structure.

**Status code mismatch**: Verify the `@ResponseStatus` annotation on the handler method.

## References

- [Spring ControllerAdvice Documentation](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/web/bind/annotation/ControllerAdvice.html)
- [Spring ExceptionHandler](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/web/bind/annotation/ExceptionHandler.html)
- [MockMvc Testing](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/test/web/servlet/MockMvc.html)

skills/junit-test/unit-test-json-serialization/SKILL.md

---
name: unit-test-json-serialization
description: Unit tests for JSON serialization/deserialization with Jackson and @JsonTest. Use when validating JSON mapping, custom serializers, and date format handling.
category: testing
tags: [junit-5, json-test, jackson, serialization, deserialization]
version: 1.0.1
---

# Unit Testing JSON Serialization with @JsonTest

Test JSON serialization and deserialization of POJOs using Spring's @JsonTest. Verify Jackson configuration, custom serializers, and JSON mapping accuracy.

## When to Use This Skill

Use this skill when:
- Testing JSON serialization of DTOs
- Testing JSON deserialization to objects
- Testing custom Jackson serializers/deserializers
- Verifying JSON field names and formats
- Testing null handling in JSON
- You want fast JSON mapping tests without a full Spring context

## Setup: JSON Testing

### Maven

```xml
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-json</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-test</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>com.fasterxml.jackson.core</groupId>
    <artifactId>jackson-databind</artifactId>
</dependency>
```

### Gradle

```kotlin
dependencies {
    implementation("org.springframework.boot:spring-boot-starter-json")
    implementation("com.fasterxml.jackson.core:jackson-databind")
    testImplementation("org.springframework.boot:spring-boot-starter-test")
}
```

## Basic Pattern: @JsonTest

### Test JSON Serialization

```java
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.json.JsonTest;
import org.springframework.boot.test.json.JacksonTester;
import org.springframework.boot.test.json.JsonContent;

import static org.assertj.core.api.Assertions.*;

@JsonTest
class UserDtoJsonTest {

    @Autowired
    private JacksonTester<UserDto> json;

    @Test
    void shouldSerializeUserToJson() throws Exception {
        UserDto user = new UserDto(1L, "Alice", "alice@example.com", 25);

        JsonContent<UserDto> result = json.write(user);

        assertThat(result).extractingJsonPathNumberValue("$.id").isEqualTo(1);
        assertThat(result).extractingJsonPathStringValue("$.name").isEqualTo("Alice");
        assertThat(result).extractingJsonPathStringValue("$.email").isEqualTo("alice@example.com");
        assertThat(result).extractingJsonPathNumberValue("$.age").isEqualTo(25);
    }

    @Test
    void shouldDeserializeJsonToUser() throws Exception {
        String jsonContent = "{\"id\":1,\"name\":\"Alice\",\"email\":\"alice@example.com\",\"age\":25}";

        UserDto user = json.parse(jsonContent).getObject();

        assertThat(user)
            .isNotNull()
            .hasFieldOrPropertyWithValue("id", 1L)
            .hasFieldOrPropertyWithValue("name", "Alice")
            .hasFieldOrPropertyWithValue("email", "alice@example.com")
            .hasFieldOrPropertyWithValue("age", 25);
    }

    @Test
    void shouldHandleNullFields() throws Exception {
        String jsonContent = "{\"id\":1,\"name\":null,\"email\":\"alice@example.com\",\"age\":null}";

        UserDto user = json.parse(jsonContent).getObject();

        assertThat(user.getName()).isNull();
        assertThat(user.getAge()).isNull();
    }
}
```
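
For whole-document comparisons, `JacksonTester` assertions can also match against an expected JSON string (or a classpath resource), ignoring formatting differences; a short sketch:

```java
@Test
void shouldMatchExpectedJsonDocument() throws Exception {
    UserDto user = new UserDto(1L, "Alice", "alice@example.com", 25);

    // Compares the documents structurally, not character-by-character.
    assertThat(json.write(user)).isEqualToJson(
        "{\"id\":1,\"name\":\"Alice\",\"email\":\"alice@example.com\",\"age\":25}");
}
```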

## Testing Custom JSON Properties

### @JsonProperty and @JsonIgnore

```java
public class Order {
    @JsonProperty("order_id")
    private Long id;

    @JsonProperty("total_amount")
    private BigDecimal amount;

    @JsonIgnore
    private String internalNote;

    private LocalDateTime createdAt;
}

@JsonTest
class OrderJsonTest {

    @Autowired
    private JacksonTester<Order> json;

    @Test
    void shouldMapJsonPropertyNames() throws Exception {
        String jsonContent = "{\"order_id\":123,\"total_amount\":99.99,\"createdAt\":\"2024-01-15T10:30:00\"}";

        Order order = json.parse(jsonContent).getObject();

        assertThat(order.getId()).isEqualTo(123L);
        assertThat(order.getAmount()).isEqualByComparingTo(new BigDecimal("99.99"));
    }

    @Test
    void shouldIgnoreJsonIgnoreAnnotatedFields() throws Exception {
        Order order = new Order(123L, new BigDecimal("99.99"));
        order.setInternalNote("Secret note");

        JsonContent<Order> result = json.write(order);

        assertThat(result.getJson()).doesNotContain("internalNote");
    }
}
```
|
||||
|
||||
## Testing List Deserialization

### JSON Arrays

```java
@JsonTest
class UserListJsonTest {

    @Autowired
    private JacksonTester<List<UserDto>> json;

    @Test
    void shouldDeserializeUserList() throws Exception {
        String jsonArray = "[{\"id\":1,\"name\":\"Alice\"},{\"id\":2,\"name\":\"Bob\"}]";

        List<UserDto> users = json.parseObject(jsonArray);

        assertThat(users)
            .hasSize(2)
            .extracting(UserDto::getName)
            .containsExactly("Alice", "Bob");
    }

    @Test
    void shouldSerializeUserListToJson() throws Exception {
        List<UserDto> users = List.of(
            new UserDto(1L, "Alice"),
            new UserDto(2L, "Bob")
        );

        JsonContent<List<UserDto>> result = json.write(users);

        assertThat(result.getJson()).contains("Alice").contains("Bob");
    }
}
```

## Testing Nested Objects

### Complex JSON Structures

```java
public class Product {
    private Long id;
    private String name;
    private Category category;
    private List<Review> reviews;
}

public class Category {
    private Long id;
    private String name;
}

public class Review {
    private String reviewer;
    private int rating;
    private String comment;
}

@JsonTest
class ProductJsonTest {

    @Autowired
    private JacksonTester<Product> json;

    @Test
    void shouldSerializeNestedObjects() throws Exception {
        Category category = new Category(1L, "Electronics");
        Product product = new Product(1L, "Laptop", category);

        JsonContent<Product> result = json.write(product);

        assertThat(result).extractingJsonPathNumberValue("$.id").isEqualTo(1);
        assertThat(result).extractingJsonPathStringValue("$.name").isEqualTo("Laptop");
        assertThat(result).extractingJsonPathNumberValue("$.category.id").isEqualTo(1);
        assertThat(result).extractingJsonPathStringValue("$.category.name").isEqualTo("Electronics");
    }

    @Test
    void shouldDeserializeNestedObjects() throws Exception {
        String jsonContent = "{\"id\":1,\"name\":\"Laptop\",\"category\":{\"id\":1,\"name\":\"Electronics\"}}";

        Product product = json.parse(jsonContent).getObject();

        assertThat(product.getCategory())
            .isNotNull()
            .hasFieldOrPropertyWithValue("name", "Electronics");
    }

    @Test
    void shouldHandleListOfNestedObjects() throws Exception {
        String jsonContent = "{\"id\":1,\"name\":\"Laptop\",\"reviews\":[{\"reviewer\":\"John\",\"rating\":5},{\"reviewer\":\"Jane\",\"rating\":4}]}";

        Product product = json.parse(jsonContent).getObject();

        assertThat(product.getReviews())
            .hasSize(2)
            .extracting(Review::getRating)
            .containsExactly(5, 4);
    }
}
```

## Testing Date/Time Formatting

### LocalDateTime and Other Temporal Types

```java
@JsonTest
class DateTimeJsonTest {

    @Autowired
    private JacksonTester<Event> json;

    @Test
    void shouldFormatDateTimeCorrectly() throws Exception {
        LocalDateTime dateTime = LocalDateTime.of(2024, 1, 15, 10, 30, 0);
        Event event = new Event("Conference", dateTime);

        JsonContent<Event> result = json.write(event);

        assertThat(result).extractingJsonPathStringValue("$.scheduledAt").isEqualTo("2024-01-15T10:30:00");
    }

    @Test
    void shouldDeserializeDateTimeFromJson() throws Exception {
        String jsonContent = "{\"name\":\"Conference\",\"scheduledAt\":\"2024-01-15T10:30:00\"}";

        Event event = json.parse(jsonContent).getObject();

        assertThat(event.getScheduledAt())
            .isEqualTo(LocalDateTime.of(2024, 1, 15, 10, 30, 0));
    }
}
```

## Testing Custom Serializers

### Custom JsonSerializer Implementation

```java
public class CustomMoneySerializer extends JsonSerializer<BigDecimal> {
    @Override
    public void serialize(BigDecimal value, JsonGenerator gen, SerializerProvider serializers) throws IOException {
        if (value == null) {
            gen.writeNull();
        } else {
            gen.writeString(String.format("$%.2f", value));
        }
    }
}

public class Price {
    @JsonSerialize(using = CustomMoneySerializer.class)
    private BigDecimal amount;
}

@JsonTest
class CustomSerializerTest {

    @Autowired
    private JacksonTester<Price> json;

    @Test
    void shouldUseCustomSerializer() throws Exception {
        Price price = new Price(new BigDecimal("99.99"));

        JsonContent<Price> result = json.write(price);

        assertThat(result).extractingJsonPathStringValue("$.amount").isEqualTo("$99.99");
    }
}
```

## Testing Polymorphic Deserialization

### Type Information in JSON

```java
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
@JsonSubTypes({
    @JsonSubTypes.Type(value = CreditCard.class, name = "credit_card"),
    @JsonSubTypes.Type(value = PayPal.class, name = "paypal")
})
public abstract class PaymentMethod {
    private String id;
}

@JsonTest
class PolymorphicJsonTest {

    @Autowired
    private JacksonTester<PaymentMethod> json;

    @Test
    void shouldDeserializeCreditCard() throws Exception {
        String jsonContent = "{\"type\":\"credit_card\",\"id\":\"card123\",\"cardNumber\":\"****1234\"}";

        PaymentMethod method = json.parse(jsonContent).getObject();

        assertThat(method).isInstanceOf(CreditCard.class);
    }

    @Test
    void shouldDeserializePayPal() throws Exception {
        String jsonContent = "{\"type\":\"paypal\",\"id\":\"pp123\",\"email\":\"user@paypal.com\"}";

        PaymentMethod method = json.parse(jsonContent).getObject();

        assertThat(method).isInstanceOf(PayPal.class);
    }
}
```

## Best Practices

- **Use @JsonTest** for focused JSON testing
- **Test both serialization and deserialization**
- **Test null handling** and missing fields
- **Test nested and complex structures**
- **Verify field name mapping** with @JsonProperty
- **Test date/time formatting** thoroughly
- **Test edge cases** (empty strings, empty collections)

## Common Pitfalls

- Not testing null values
- Not testing nested objects
- Forgetting to test field name mappings
- Not verifying JSON property presence/absence
- Not testing deserialization of invalid JSON (see the sketch below)

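A minimal sketch of covering that last pitfall: `JacksonTester.parse` surfaces Jackson's parsing exceptions, so malformed input can be asserted directly. The `json` field follows the `UserDto` examples above.

```java
@Test
void shouldRejectMalformedJson() {
    // Missing closing brace makes this structurally invalid JSON
    String malformed = "{\"id\":1,\"name\":\"Alice\"";

    assertThatThrownBy(() -> json.parse(malformed))
        .isInstanceOf(com.fasterxml.jackson.core.JsonProcessingException.class);
}
```
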
## Troubleshooting

**JacksonTester not available**: Ensure the class is annotated with `@JsonTest`.

**Field name doesn't match**: Check the @JsonProperty annotation and Jackson configuration.

**DateTime parsing fails**: Verify the date format matches Jackson's expected format (see the sketch below).

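When the default format does not match your payload, pinning the pattern with `@JsonFormat` is one way to make it explicit; a sketch, assuming an `Event` DTO like the one above:

```java
public class Event {
    private String name;

    // Pin the expected wire format so parsing does not depend on global defaults
    @JsonFormat(shape = JsonFormat.Shape.STRING, pattern = "yyyy-MM-dd'T'HH:mm:ss")
    private LocalDateTime scheduledAt;
}
```
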
## References

- [Spring @JsonTest Documentation](https://docs.spring.io/spring-boot/docs/current/api/org/springframework/boot/test/autoconfigure/json/JsonTest.html)
- [Jackson ObjectMapper](https://fasterxml.github.io/jackson-databind/javadoc/2.15/com/fasterxml/jackson/databind/ObjectMapper.html)
- [JSON Annotations](https://fasterxml.github.io/jackson-annotations/javadoc/2.15/)

434
skills/junit-test/unit-test-mapper-converter/SKILL.md
Normal file
@@ -0,0 +1,434 @@

---
name: unit-test-mapper-converter
description: Unit tests for mappers and converters (MapStruct, custom mappers). Test object transformation logic in isolation. Use when ensuring correct data transformation between DTOs and domain objects.
category: testing
tags: [junit-5, mapstruct, mapper, dto, entity, converter]
version: 1.0.1
---

# Unit Testing Mappers and Converters

Test MapStruct mappers and custom converter classes. Verify field mapping accuracy, null handling, type conversions, and nested object transformations.

## When to Use This Skill

Use this skill when:
- Testing MapStruct mapper implementations
- Testing custom entity-to-DTO converters
- Testing nested object mapping
- Verifying null handling in mappers
- Testing type conversions and transformations
- You want comprehensive mapping test coverage before integration tests

## Setup: Testing Mappers

### Maven
```xml
<dependency>
    <groupId>org.mapstruct</groupId>
    <artifactId>mapstruct</artifactId>
    <version>1.5.5.Final</version>
</dependency>
<dependency>
    <groupId>org.junit.jupiter</groupId>
    <artifactId>junit-jupiter</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.assertj</groupId>
    <artifactId>assertj-core</artifactId>
    <scope>test</scope>
</dependency>
```

MapStruct also needs its annotation processor (`org.mapstruct:mapstruct-processor`) so the mapper implementations get generated; in Maven this is typically configured via `maven-compiler-plugin`'s `annotationProcessorPaths`.

### Gradle
```kotlin
dependencies {
    implementation("org.mapstruct:mapstruct:1.5.5.Final")
    annotationProcessor("org.mapstruct:mapstruct-processor:1.5.5.Final")
    testImplementation("org.junit.jupiter:junit-jupiter")
    testImplementation("org.assertj:assertj-core")
}
```

## Basic Pattern: Testing MapStruct Mapper

### Simple Entity to DTO Mapping

```java
// Mapper interface
@Mapper(componentModel = "spring")
public interface UserMapper {
    UserDto toDto(User user);
    User toEntity(UserDto dto);
    List<UserDto> toDtos(List<User> users);
}

// Unit test
import org.junit.jupiter.api.Test;
import org.mapstruct.factory.Mappers;

import static org.assertj.core.api.Assertions.*;

class UserMapperTest {

    private final UserMapper userMapper = Mappers.getMapper(UserMapper.class);

    @Test
    void shouldMapUserToDto() {
        User user = new User(1L, "Alice", "alice@example.com", 25);

        UserDto dto = userMapper.toDto(user);

        assertThat(dto)
            .isNotNull()
            .extracting("id", "name", "email", "age")
            .containsExactly(1L, "Alice", "alice@example.com", 25);
    }

    @Test
    void shouldMapDtoToEntity() {
        UserDto dto = new UserDto(1L, "Alice", "alice@example.com", 25);

        User user = userMapper.toEntity(dto);

        assertThat(user)
            .isNotNull()
            .hasFieldOrPropertyWithValue("id", 1L)
            .hasFieldOrPropertyWithValue("name", "Alice");
    }

    @Test
    void shouldMapListOfUsers() {
        List<User> users = List.of(
            new User(1L, "Alice", "alice@example.com", 25),
            new User(2L, "Bob", "bob@example.com", 30)
        );

        List<UserDto> dtos = userMapper.toDtos(users);

        assertThat(dtos)
            .hasSize(2)
            .extracting(UserDto::getName)
            .containsExactly("Alice", "Bob");
    }

    @Test
    void shouldHandleNullEntity() {
        UserDto dto = userMapper.toDto(null);

        assertThat(dto).isNull();
    }
}
```

## Testing Nested Object Mapping

### Map Complex Hierarchies

```java
// Entities with nesting
class User {
    private Long id;
    private String name;
    private Address address;
    private List<Phone> phones;
}

// Mapper with nested mapping
@Mapper(componentModel = "spring")
public interface UserMapper {
    UserDto toDto(User user);
    User toEntity(UserDto dto);
}

// Unit test for nested objects
class NestedObjectMapperTest {

    private final UserMapper userMapper = Mappers.getMapper(UserMapper.class);

    @Test
    void shouldMapNestedAddress() {
        Address address = new Address("123 Main St", "New York", "NY", "10001");
        User user = new User(1L, "Alice", address);

        UserDto dto = userMapper.toDto(user);

        assertThat(dto.getAddress())
            .isNotNull()
            .hasFieldOrPropertyWithValue("street", "123 Main St")
            .hasFieldOrPropertyWithValue("city", "New York");
    }

    @Test
    void shouldMapListOfNestedPhones() {
        List<Phone> phones = List.of(
            new Phone("123-456-7890", "MOBILE"),
            new Phone("987-654-3210", "HOME")
        );
        User user = new User(1L, "Alice", null, phones);

        UserDto dto = userMapper.toDto(user);

        assertThat(dto.getPhones())
            .hasSize(2)
            .extracting(PhoneDto::getNumber)
            .containsExactly("123-456-7890", "987-654-3210");
    }

    @Test
    void shouldHandleNullNestedObjects() {
        User user = new User(1L, "Alice", null);

        UserDto dto = userMapper.toDto(user);

        assertThat(dto.getAddress()).isNull();
    }
}
```

## Testing Custom Mapping Methods

### Mapper with @Mapping Annotations

```java
@Mapper(componentModel = "spring")
public interface ProductMapper {
    @Mapping(source = "name", target = "productName")
    @Mapping(source = "price", target = "salePrice")
    @Mapping(target = "discount", expression = "java(product.getPrice() * 0.1)")
    ProductDto toDto(Product product);

    @Mapping(source = "productName", target = "name")
    @Mapping(source = "salePrice", target = "price")
    Product toEntity(ProductDto dto);
}

class CustomMappingTest {

    private final ProductMapper mapper = Mappers.getMapper(ProductMapper.class);

    @Test
    void shouldMapFieldsWithCustomNames() {
        Product product = new Product(1L, "Laptop", 999.99);

        ProductDto dto = mapper.toDto(product);

        assertThat(dto)
            .hasFieldOrPropertyWithValue("productName", "Laptop")
            .hasFieldOrPropertyWithValue("salePrice", 999.99);
    }

    @Test
    void shouldCalculateDiscountFromExpression() {
        Product product = new Product(1L, "Laptop", 100.0);

        ProductDto dto = mapper.toDto(product);

        assertThat(dto.getDiscount()).isEqualTo(10.0);
    }

    @Test
    void shouldReverseMapCustomFields() {
        ProductDto dto = new ProductDto(1L, "Laptop", 999.99);

        Product product = mapper.toEntity(dto);

        assertThat(product)
            .hasFieldOrPropertyWithValue("name", "Laptop")
            .hasFieldOrPropertyWithValue("price", 999.99);
    }
}
```

## Testing Enum Mapping

### Map Enums Between Entity and DTO

```java
// Enums with different representations
enum UserStatus { ACTIVE, INACTIVE, SUSPENDED }
enum UserStatusDto { ENABLED, DISABLED, LOCKED }

@Mapper(componentModel = "spring")
public interface UserMapper {
    @ValueMapping(source = "ACTIVE", target = "ENABLED")
    @ValueMapping(source = "INACTIVE", target = "DISABLED")
    @ValueMapping(source = "SUSPENDED", target = "LOCKED")
    UserStatusDto toStatusDto(UserStatus status);
}

class EnumMapperTest {

    private final UserMapper mapper = Mappers.getMapper(UserMapper.class);

    @Test
    void shouldMapActiveToEnabled() {
        UserStatusDto dto = mapper.toStatusDto(UserStatus.ACTIVE);
        assertThat(dto).isEqualTo(UserStatusDto.ENABLED);
    }

    @Test
    void shouldMapSuspendedToLocked() {
        UserStatusDto dto = mapper.toStatusDto(UserStatus.SUSPENDED);
        assertThat(dto).isEqualTo(UserStatusDto.LOCKED);
    }
}
```

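When only a few constants need explicit rules, `MappingConstants.ANY_REMAINING` can catch the rest with a default; a sketch using the enums above:

```java
import org.mapstruct.MappingConstants;

@Mapper(componentModel = "spring")
public interface UserStatusMapper {
    @ValueMapping(source = "ACTIVE", target = "ENABLED")
    @ValueMapping(source = "SUSPENDED", target = "LOCKED")
    // Any source constant without an explicit rule maps to DISABLED
    @ValueMapping(source = MappingConstants.ANY_REMAINING, target = "DISABLED")
    UserStatusDto toStatusDto(UserStatus status);
}
```
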
## Testing Custom Type Conversions

### Non-MapStruct Custom Converter

```java
// Custom converter class
public class DateFormatter {
    private static final DateTimeFormatter FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd");

    public static String format(LocalDate date) {
        return date != null ? date.format(FORMATTER) : null;
    }

    public static LocalDate parse(String dateString) {
        return dateString != null ? LocalDate.parse(dateString, FORMATTER) : null;
    }
}

// Unit test
class DateFormatterTest {

    @Test
    void shouldFormatLocalDateToString() {
        LocalDate date = LocalDate.of(2024, 1, 15);

        String result = DateFormatter.format(date);

        assertThat(result).isEqualTo("2024-01-15");
    }

    @Test
    void shouldParseStringToLocalDate() {
        String dateString = "2024-01-15";

        LocalDate result = DateFormatter.parse(dateString);

        assertThat(result).isEqualTo(LocalDate.of(2024, 1, 15));
    }

    @Test
    void shouldHandleNullInFormat() {
        String result = DateFormatter.format(null);
        assertThat(result).isNull();
    }

    @Test
    void shouldHandleInvalidDateFormat() {
        assertThatThrownBy(() -> DateFormatter.parse("invalid-date"))
            .isInstanceOf(DateTimeParseException.class);
    }
}
```

## Testing Bidirectional Mapping

### Entity ↔ DTO Round Trip

```java
class BidirectionalMapperTest {

    private final UserMapper mapper = Mappers.getMapper(UserMapper.class);

    @Test
    void shouldMaintainDataInRoundTrip() {
        User original = new User(1L, "Alice", "alice@example.com", 25);

        UserDto dto = mapper.toDto(original);
        User restored = mapper.toEntity(dto);

        assertThat(restored)
            .hasFieldOrPropertyWithValue("id", original.getId())
            .hasFieldOrPropertyWithValue("name", original.getName())
            .hasFieldOrPropertyWithValue("email", original.getEmail())
            .hasFieldOrPropertyWithValue("age", original.getAge());
    }

    @Test
    void shouldPreserveAllFieldsInBothDirections() {
        Address address = new Address("123 Main", "NYC", "NY", "10001");
        User user = new User(1L, "Alice", "alice@example.com", 25, address);

        UserDto dto = mapper.toDto(user);
        User restored = mapper.toEntity(dto);

        assertThat(restored).usingRecursiveComparison().isEqualTo(user);
    }
}
```

## Testing Partial Mapping

### Update Existing Entity from DTO

```java
@Mapper(componentModel = "spring")
public interface UserMapper {
    void updateEntity(@MappingTarget User entity, UserDto dto);
}

class PartialMapperTest {

    private final UserMapper mapper = Mappers.getMapper(UserMapper.class);

    @Test
    void shouldUpdateExistingEntity() {
        User existing = new User(1L, "Alice", "alice@old.com", 25);
        UserDto dto = new UserDto(1L, "Alice", "alice@new.com", 26);

        mapper.updateEntity(existing, dto);

        assertThat(existing)
            .hasFieldOrPropertyWithValue("email", "alice@new.com")
            .hasFieldOrPropertyWithValue("age", 26);
    }

    @Test
    void shouldNotUpdateFieldsNotInDto() {
        User existing = new User(1L, "Alice", "alice@example.com", 25);
        UserDto dto = new UserDto(1L, "Bob", null, 0);

        mapper.updateEntity(existing, dto);

        // Assuming null-aware mapping is configured (see the sketch below)
        assertThat(existing.getEmail()).isEqualTo("alice@example.com");
    }
}
```

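The null-aware behavior the last test assumes is not MapStruct's default; one way to configure it is `NullValuePropertyMappingStrategy.IGNORE`, sketched here:

```java
import org.mapstruct.Mapper;
import org.mapstruct.MappingTarget;
import org.mapstruct.NullValuePropertyMappingStrategy;

// With IGNORE, null properties in the DTO leave the target's current values untouched
@Mapper(componentModel = "spring",
        nullValuePropertyMappingStrategy = NullValuePropertyMappingStrategy.IGNORE)
public interface UserMapper {
    void updateEntity(@MappingTarget User entity, UserDto dto);
}
```
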
## Best Practices

- **Test all mapper methods** comprehensively
- **Verify null handling** for every nullable field
- **Test nested objects** independently and together
- **Use recursive comparison** for complex nested structures
- **Test bidirectional mapping** to catch asymmetries
- **Keep mapper tests simple and focused** on transformation correctness
- **Use Mappers.getMapper()** for non-Spring standalone tests

## Common Pitfalls

- Not testing null input cases
- Not verifying nested object mappings
- Assuming bidirectional mapping is symmetric
- Not testing edge cases (empty collections, etc.)
- Tight coupling of mapper tests to MapStruct internals

## Troubleshooting

**Null pointer exceptions during mapping**: Check `nullValuePropertyMappingStrategy` and `nullValueCheckStrategy` in `@Mapper`.

**Enum mapping not working**: Verify `@ValueMapping` annotations correctly map source to target values.

**Nested mapping produces null**: Ensure nested mapper interfaces are also registered on the parent mapper (see the sketch below).

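A sketch of registering a nested mapper via `uses`, assuming an `AddressMapper` for the nested `Address` type from the earlier examples:

```java
@Mapper(componentModel = "spring")
public interface AddressMapper {
    AddressDto toDto(Address address);
}

// Without 'uses', MapStruct may not know how to convert Address -> AddressDto
@Mapper(componentModel = "spring", uses = AddressMapper.class)
public interface UserMapper {
    UserDto toDto(User user);
}
```
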
## References

- [MapStruct Official Documentation](https://mapstruct.org/)
- [MapStruct Mapping Strategies](https://mapstruct.org/documentation/stable/reference/html/)
- [JUnit 5 Best Practices](https://junit.org/junit5/docs/current/user-guide/)

374
skills/junit-test/unit-test-parameterized/SKILL.md
Normal file
@@ -0,0 +1,374 @@

---
name: unit-test-parameterized
description: Parameterized testing patterns with @ParameterizedTest, @ValueSource, @CsvSource. Run single test method with multiple input combinations. Use when testing multiple scenarios with similar logic.
category: testing
tags: [junit-5, parameterized-test, value-source, csv-source, method-source]
version: 1.0.1
---

# Parameterized Unit Tests with JUnit 5

Write efficient parameterized unit tests that run the same test logic with multiple input values. Reduce test duplication and improve test coverage using @ParameterizedTest.

## When to Use This Skill

Use this skill when:
- Testing methods with multiple valid inputs
- Testing boundary values systematically
- Testing multiple invalid inputs for error cases
- You want to reduce test duplication
- Testing multiple scenarios with similar assertions
- You need a data-driven testing approach

## Setup: Parameterized Testing

### Maven
```xml
<dependency>
    <groupId>org.junit.jupiter</groupId>
    <artifactId>junit-jupiter</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.assertj</groupId>
    <artifactId>assertj-core</artifactId>
    <scope>test</scope>
</dependency>
```

### Gradle
```kotlin
dependencies {
    testImplementation("org.junit.jupiter:junit-jupiter")
    testImplementation("org.assertj:assertj-core")
}
```

## Basic Pattern: @ValueSource

### Simple Value Testing

```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import static org.assertj.core.api.Assertions.*;

class StringUtilsTest {

    @ParameterizedTest
    @ValueSource(strings = {"hello", "world", "test"})
    void shouldCapitalizeAllStrings(String input) {
        String result = StringUtils.capitalize(input);
        assertThat(result).startsWith(input.substring(0, 1).toUpperCase());
    }

    @ParameterizedTest
    @ValueSource(ints = {1, 2, 3, 4, 5})
    void shouldBePositive(int number) {
        assertThat(number).isPositive();
    }

    @ParameterizedTest
    @ValueSource(booleans = {true, false})
    void shouldHandleBothBooleanValues(boolean value) {
        assertThat(value).isNotNull();
    }
}
```

## @MethodSource for Complex Data

### Factory Method Data Source

```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.*;

class CalculatorTest {

    static Stream<Arguments> additionTestCases() {
        return Stream.of(
            Arguments.of(1, 2, 3),
            Arguments.of(0, 0, 0),
            Arguments.of(-1, 1, 0),
            Arguments.of(100, 200, 300),
            Arguments.of(-5, -10, -15)
        );
    }

    @ParameterizedTest
    @MethodSource("additionTestCases")
    void shouldAddNumbersCorrectly(int a, int b, int expected) {
        int result = Calculator.add(a, b);
        assertThat(result).isEqualTo(expected);
    }
}
```

## @CsvSource for Tabular Data

### CSV-Based Test Data

```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;

import static org.assertj.core.api.Assertions.*;

class UserValidationTest {

    @ParameterizedTest
    @CsvSource({
        "alice@example.com, true",
        "bob@gmail.com, true",
        "invalid-email, false",
        "user@, false",
        "@example.com, false",
        "user name@example.com, false"
    })
    void shouldValidateEmailAddresses(String email, boolean expected) {
        boolean result = UserValidator.isValidEmail(email);
        assertThat(result).isEqualTo(expected);
    }

    @ParameterizedTest
    @CsvSource({
        "123-456-7890, true",
        "555-123-4567, true",
        "1234567890, false",
        "123-45-6789, false",
        "abc-def-ghij, false"
    })
    void shouldValidatePhoneNumbers(String phone, boolean expected) {
        boolean result = PhoneValidator.isValid(phone);
        assertThat(result).isEqualTo(expected);
    }
}
```

## @CsvFileSource for External Data

### CSV File-Based Testing

```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvFileSource;

import static org.assertj.core.api.Assertions.*;

class PriceCalculationTest {

    @ParameterizedTest
    @CsvFileSource(resources = "/test-data/prices.csv", numLinesToSkip = 1)
    void shouldCalculateTotalPrice(String product, double price, int quantity, double expected) {
        double total = PriceCalculator.calculateTotal(price, quantity);
        assertThat(total).isEqualTo(expected);
    }
}

// test-data/prices.csv:
// product,price,quantity,expected
// Laptop,999.99,1,999.99
// Mouse,29.99,3,89.97
// Keyboard,79.99,2,159.98
```

## @EnumSource for Enum Testing

### Enum-Based Test Data

```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

import static org.assertj.core.api.Assertions.*;

enum Status { ACTIVE, INACTIVE, PENDING, DELETED }

class StatusHandlerTest {

    @ParameterizedTest
    @EnumSource(Status.class)
    void shouldHandleAllStatuses(Status status) {
        assertThat(status).isNotNull();
    }

    @ParameterizedTest
    @EnumSource(value = Status.class, names = {"ACTIVE", "INACTIVE"})
    void shouldHandleSpecificStatuses(Status status) {
        assertThat(status).isIn(Status.ACTIVE, Status.INACTIVE);
    }

    @ParameterizedTest
    @EnumSource(value = Status.class, mode = EnumSource.Mode.EXCLUDE, names = {"DELETED"})
    void shouldHandleStatusesExcludingDeleted(Status status) {
        assertThat(status).isNotEqualTo(Status.DELETED);
    }
}
```

## Custom Display Names

### Readable Test Output

```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;

import static org.assertj.core.api.Assertions.*;

class DiscountCalculationTest {

    @ParameterizedTest(name = "Discount of {0}% should be calculated correctly")
    @ValueSource(ints = {5, 10, 15, 20})
    void shouldApplyDiscount(int discountPercent) {
        double originalPrice = 100.0;
        double discounted = DiscountCalculator.apply(originalPrice, discountPercent);
        double expected = originalPrice * (1 - discountPercent / 100.0);

        assertThat(discounted).isEqualTo(expected);
    }

    @ParameterizedTest(name = "User role {0} should have {1} permissions")
    @CsvSource({
        "ADMIN, 100",
        "MANAGER, 50",
        "USER, 10"
    })
    void shouldHaveCorrectPermissions(String role, int expectedPermissions) {
        User user = new User(role);
        assertThat(user.getPermissionCount()).isEqualTo(expectedPermissions);
    }
}
```

## Combining Multiple Sources

### ArgumentsProvider for Complex Scenarios

```java
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.*;

class RangeValidatorArgumentProvider implements ArgumentsProvider {
    @Override
    public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
        return Stream.of(
            Arguments.of(0, 0, 100, true),    // Min boundary
            Arguments.of(100, 0, 100, true),  // Max boundary
            Arguments.of(50, 0, 100, true),   // Middle value
            Arguments.of(-1, 0, 100, false),  // Below range
            Arguments.of(101, 0, 100, false)  // Above range
        );
    }
}

class RangeValidatorTest {

    @ParameterizedTest
    @ArgumentsSource(RangeValidatorArgumentProvider.class)
    void shouldValidateRangeCorrectly(int value, int min, int max, boolean expected) {
        boolean result = RangeValidator.isInRange(value, min, max);
        assertThat(result).isEqualTo(expected);
    }
}
```

## Testing Edge Cases with Parameters

### Boundary Value Analysis

```java
class BoundaryValueTest {

    @ParameterizedTest
    @ValueSource(ints = {
        Integer.MIN_VALUE,      // Absolute minimum
        Integer.MIN_VALUE + 1,  // Just above minimum
        -1,                     // Negative boundary
        0,                      // Zero boundary
        1,                      // Just above zero
        Integer.MAX_VALUE - 1,  // Just below maximum
        Integer.MAX_VALUE       // Absolute maximum
    })
    void shouldHandleAllBoundaryValues(int value) {
        int incremented = MathUtils.increment(value);
        assertThat(incremented).isNotLessThan(value);
    }

    @ParameterizedTest
    @CsvSource({
        ", false",     // null
        "'', false",   // empty
        "' ', false",  // whitespace only
        "a, true",     // single character
        "abc, true"    // normal
    })
    void shouldValidateStrings(String input, boolean expected) {
        boolean result = StringValidator.isValid(input);
        assertThat(result).isEqualTo(expected);
    }
}
```

## Repeat Tests

### Run Same Test Multiple Times

```java
import org.junit.jupiter.api.RepeatedTest;
import java.util.concurrent.atomic.AtomicInteger;

import static org.assertj.core.api.Assertions.*;

class ConcurrencyTest {

    @RepeatedTest(100)
    void shouldHandleConcurrentAccess() {
        // Test that might reveal race conditions if run multiple times
        AtomicInteger counter = new AtomicInteger(0);
        counter.incrementAndGet();
        assertThat(counter.get()).isEqualTo(1);
    }
}
```

## Best Practices

- **Use @ParameterizedTest** to reduce test duplication
- **Use descriptive display names** with `(name = "...")`
- **Test boundary values** systematically
- **Keep test logic simple** - focus on a single assertion
- **Organize test data logically** - group similar scenarios
- **Use @MethodSource** for complex test data
- **Use @CsvSource** for tabular test data
- **Document expected behavior** in test names

## Common Patterns

**Testing error conditions** (note: `@ValueSource` cannot hold `null`, so the null and empty cases come from `@NullAndEmptySource`):
```java
@ParameterizedTest
@NullAndEmptySource
@ValueSource(strings = {" "})
void shouldThrowExceptionForInvalidInput(String input) {
    assertThatThrownBy(() -> Parser.parse(input))
        .isInstanceOf(IllegalArgumentException.class);
}
```

**Testing multiple valid inputs**:
```java
@ParameterizedTest
@ValueSource(ints = {1, 2, 3, 5, 8, 13})
void shouldBeInFibonacciSequence(int number) {
    assertThat(FibonacciChecker.isFibonacci(number)).isTrue();
}
```

## Troubleshooting

**Parameter not matching**: Verify the number and types of parameters match the test method signature.

**Display name not showing**: Check the parameter syntax in `name = "..."`.

**CSV parsing error**: Ensure the CSV format is correct and quote strings containing commas (see the sketch below).

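A minimal sketch of quoting a CSV cell that contains a comma (the validator class here is hypothetical):

```java
@ParameterizedTest
@CsvSource({
    // Single quotes keep the embedded comma inside one cell
    "'Doe, Jane', true",
    "Jane Doe, true"
})
void shouldAcceptNamesWithCommas(String name, boolean expected) {
    // NameValidator is an illustrative placeholder
    assertThat(NameValidator.isValid(name)).isEqualTo(expected);
}
```
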
## References

- [JUnit 5 Parameterized Tests](https://junit.org/junit5/docs/current/user-guide/#writing-tests-parameterized-tests)
- [@ParameterizedTest Documentation](https://junit.org/junit5/docs/current/api/org.junit.jupiter.params/org/junit/jupiter/params/ParameterizedTest.html)
- [Boundary Value Analysis](https://en.wikipedia.org/wiki/Boundary-value_analysis)

434
skills/junit-test/unit-test-scheduled-async/SKILL.md
Normal file
@@ -0,0 +1,434 @@

---
name: unit-test-scheduled-async
description: Unit tests for scheduled and async tasks using @Scheduled and @Async. Mock task execution and timing. Use when validating asynchronous operations and scheduling behavior.
category: testing
tags: [junit-5, scheduled, async, concurrency, completablefuture]
version: 1.0.1
---

# Unit Testing @Scheduled and @Async Methods

Test scheduled tasks and async methods using JUnit 5 without running the actual scheduler. Verify execution logic, timing, and asynchronous behavior.

## When to Use This Skill

Use this skill when:
- Testing @Scheduled method logic
- Testing @Async method behavior
- Verifying CompletableFuture results
- Testing async error handling
- You want fast tests without actual scheduling
- Testing background task logic in isolation

## Setup: Async/Scheduled Testing

### Maven
```xml
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
    <groupId>org.junit.jupiter</groupId>
    <artifactId>junit-jupiter</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.awaitility</groupId>
    <artifactId>awaitility</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.assertj</groupId>
    <artifactId>assertj-core</artifactId>
    <scope>test</scope>
</dependency>
```

### Gradle
```kotlin
dependencies {
    implementation("org.springframework.boot:spring-boot-starter")
    testImplementation("org.junit.jupiter:junit-jupiter")
    testImplementation("org.awaitility:awaitility")
    testImplementation("org.assertj:assertj-core")
}
```

## Testing @Async Methods

### Basic Async Testing with CompletableFuture

```java
// Service with async methods
@Service
public class EmailService {

    @Async
    public CompletableFuture<Boolean> sendEmailAsync(String to, String subject) {
        return CompletableFuture.supplyAsync(() -> {
            // Simulate email sending
            System.out.println("Sending email to " + to);
            return true;
        });
    }

    @Async
    public void notifyUser(String userId) {
        System.out.println("Notifying user: " + userId);
    }
}

// Unit test
import org.junit.jupiter.api.Test;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;

import static org.assertj.core.api.Assertions.*;

class EmailServiceAsyncTest {

    @Test
    void shouldReturnCompletedFutureWhenSendingEmail() throws Exception {
        EmailService service = new EmailService();

        CompletableFuture<Boolean> result = service.sendEmailAsync("test@example.com", "Hello");

        Boolean success = result.get(); // Wait for completion
        assertThat(success).isTrue();
    }

    @Test
    void shouldCompleteWithinTimeout() {
        EmailService service = new EmailService();

        CompletableFuture<Boolean> result = service.sendEmailAsync("test@example.com", "Hello");

        // succeedsWithin avoids a race against the background thread
        assertThat(result)
            .succeedsWithin(Duration.ofSeconds(1))
            .isEqualTo(true);
    }
}
```

## Testing Async with Mocked Dependencies

### Async Service with Dependencies

```java
@Service
public class UserNotificationService {

    private final EmailService emailService;
    private final SmsService smsService;

    public UserNotificationService(EmailService emailService, SmsService smsService) {
        this.emailService = emailService;
        this.smsService = smsService;
    }

    @Async
    public CompletableFuture<String> notifyUserAsync(String userId) {
        return CompletableFuture.supplyAsync(() -> {
            emailService.send(userId);
            smsService.send(userId);
            return "Notification sent";
        });
    }
}

// Unit test
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
class UserNotificationServiceAsyncTest {

    @Mock
    private EmailService emailService;

    @Mock
    private SmsService smsService;

    @InjectMocks
    private UserNotificationService notificationService;

    @Test
    void shouldNotifyUserAsynchronously() throws Exception {
        CompletableFuture<String> result = notificationService.notifyUserAsync("user123");

        String message = result.get();
        assertThat(message).isEqualTo("Notification sent");

        verify(emailService).send("user123");
        verify(smsService).send("user123");
    }

    @Test
    void shouldHandleAsyncExceptionGracefully() {
        doThrow(new RuntimeException("Email service failed"))
            .when(emailService).send(any());

        CompletableFuture<String> result = notificationService.notifyUserAsync("user123");

        assertThatThrownBy(result::get)
            .isInstanceOf(ExecutionException.class)
            .hasCauseInstanceOf(RuntimeException.class);
    }
}
```

## Testing @Scheduled Methods

### Mock Task Execution

```java
// Scheduled task
@Component
public class DataRefreshTask {

    private final DataRepository dataRepository;

    public DataRefreshTask(DataRepository dataRepository) {
        this.dataRepository = dataRepository;
    }

    @Scheduled(fixedDelay = 60000)
    public void refreshCache() {
        List<Data> data = dataRepository.findAll();
        // Update cache
    }

    @Scheduled(cron = "0 0 * * * *") // Every hour
    public void cleanupOldData() {
        dataRepository.deleteOldData(LocalDateTime.now().minusDays(30));
    }
}

// Unit test - test logic without actual scheduling
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
class DataRefreshTaskTest {

    @Mock
    private DataRepository dataRepository;

    @InjectMocks
    private DataRefreshTask dataRefreshTask;

    @Test
    void shouldRefreshCacheFromRepository() {
        List<Data> expectedData = List.of(new Data(1L, "item1"));
        when(dataRepository.findAll()).thenReturn(expectedData);

        dataRefreshTask.refreshCache(); // Call the method directly

        verify(dataRepository).findAll();
    }

    @Test
    void shouldCleanupOldData() {
        dataRefreshTask.cleanupOldData();

        verify(dataRepository).deleteOldData(any(LocalDateTime.class));
    }
}
```

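The schedule itself can also be checked in a unit test: Spring's `CronExpression` parses the cron string and computes the next fire time, so a typo in the expression fails fast. A small sketch:

```java
import org.springframework.scheduling.support.CronExpression;
import java.time.LocalDateTime;

import static org.assertj.core.api.Assertions.*;

class CronExpressionTest {

    @Test
    void hourlyCronShouldFireAtTopOfNextHour() {
        // Same expression as the cleanupOldData task above
        CronExpression cron = CronExpression.parse("0 0 * * * *");

        LocalDateTime next = cron.next(LocalDateTime.of(2024, 1, 15, 10, 30));

        assertThat(next).isEqualTo(LocalDateTime.of(2024, 1, 15, 11, 0));
    }
}
```
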
## Testing Async with Awaitility

### Wait for Async Completion

```java
import org.awaitility.Awaitility;
import org.awaitility.core.ConditionTimeoutException;
import org.junit.jupiter.api.Test;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import static org.assertj.core.api.Assertions.*;

@Service
public class BackgroundWorker {

    private final AtomicInteger processedCount = new AtomicInteger(0);

    @Async
    public void processItems(List<String> items) {
        items.forEach(item -> {
            // Process item
            processedCount.incrementAndGet();
        });
    }

    public int getProcessedCount() {
        return processedCount.get();
    }
}

class AwaitilityAsyncTest {

    @Test
    void shouldProcessAllItemsAsynchronously() {
        // Note: on a plain instance @Async is not proxied, so this runs synchronously;
        // Awaitility becomes essential once a real async executor is in play
        BackgroundWorker worker = new BackgroundWorker();
        List<String> items = List.of("item1", "item2", "item3");

        worker.processItems(items);

        // Wait for the async operation to complete (up to 5 seconds)
        Awaitility.await()
            .atMost(Duration.ofSeconds(5))
            .pollInterval(Duration.ofMillis(100))
            .untilAsserted(() -> {
                assertThat(worker.getProcessedCount()).isEqualTo(3);
            });
    }

    @Test
    void shouldTimeoutWhenProcessingTakesTooLong() {
        BackgroundWorker worker = new BackgroundWorker();
        List<String> items = List.of("item1", "item2", "item3");

        worker.processItems(items);

        assertThatThrownBy(() ->
            Awaitility.await()
                .atMost(Duration.ofMillis(100))
                .until(() -> worker.getProcessedCount() == 10)
        ).isInstanceOf(ConditionTimeoutException.class);
    }
}
```

## Testing Async Error Handling

### Handle Exceptions in Async Methods

```java
@Service
public class DataProcessingService {

    @Async
    public CompletableFuture<Boolean> processDataAsync(String data) {
        return CompletableFuture.supplyAsync(() -> {
            if (data == null || data.isEmpty()) {
                throw new IllegalArgumentException("Data cannot be empty");
            }
            // Process data
            return true;
        });
    }

    @Async
    public CompletableFuture<String> safeFetchData(String id) {
        return CompletableFuture.supplyAsync(() -> {
            try {
                return fetchData(id); // fetchData implementation omitted
            } catch (Exception e) {
                return "Error: " + e.getMessage();
            }
        });
    }
}

class AsyncErrorHandlingTest {

    @Test
    void shouldPropagateExceptionFromAsyncMethod() {
        DataProcessingService service = new DataProcessingService();

        CompletableFuture<Boolean> result = service.processDataAsync(null);

        assertThatThrownBy(result::get)
            .isInstanceOf(ExecutionException.class)
            .hasCauseInstanceOf(IllegalArgumentException.class)
            .hasMessageContaining("Data cannot be empty");
    }

    @Test
    void shouldHandleExceptionGracefullyWithFallback() throws Exception {
        DataProcessingService service = new DataProcessingService();

        CompletableFuture<String> result = service.safeFetchData("invalid");

        String message = result.get();
        assertThat(message).startsWith("Error:");
    }
}
```

## Testing Scheduled Task Timing

### Test Schedule Configuration

```java
@Component
public class HealthCheckTask {

    private final HealthCheckService healthCheckService;
    private int executionCount = 0;

    public HealthCheckTask(HealthCheckService healthCheckService) {
        this.healthCheckService = healthCheckService;
    }

    @Scheduled(fixedRate = 5000) // Every 5 seconds
    public void checkHealth() {
        executionCount++;
        healthCheckService.check();
    }

    public int getExecutionCount() {
        return executionCount;
    }
}

class ScheduledTaskTimingTest {

    @Test
    void shouldExecuteTaskMultipleTimes() {
        HealthCheckService mockService = mock(HealthCheckService.class);
        HealthCheckTask task = new HealthCheckTask(mockService);

        // Execute manually multiple times
        task.checkHealth();
        task.checkHealth();
        task.checkHealth();

        assertThat(task.getExecutionCount()).isEqualTo(3);
        verify(mockService, times(3)).check();
    }
}
```

## Best Practices

- **Test async method logic directly** without the Spring async executor
- **Use CompletableFuture.get()** to wait for results in tests
- **Mock dependencies** that async methods use
- **Test error paths** for async operations
- **Use Awaitility** when testing actual async behavior is needed
- **Mock scheduled tasks** by calling methods directly in tests
- **Verify task execution count** for testing scheduling logic

## Common Pitfalls

- Testing with the actual @Async executor (use direct method calls instead)
- Not waiting for CompletableFuture completion in tests
- Forgetting to test exception handling in async methods
- Not mocking dependencies that async methods call
- Trying to test actual scheduling timing (test the logic instead)

## Troubleshooting

**CompletableFuture hangs in test**: Ensure methods complete, or set a timeout with `.get(timeout, unit)` (see the sketch below).

**Async method not executing**: Call the method directly instead of relying on @Async in tests.

**Awaitility timeout**: Increase the timeout duration or reduce the polling interval.

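A sketch of the timeout variant, which fails the test with a `TimeoutException` instead of hanging it (reusing the `EmailService` from above):

```java
import java.util.concurrent.TimeUnit;

@Test
void shouldFailFastInsteadOfHanging() {
    EmailService service = new EmailService();
    CompletableFuture<Boolean> result = service.sendEmailAsync("test@example.com", "Hello");

    // get(timeout, unit) throws TimeoutException if the future never completes
    assertThatCode(() -> result.get(2, TimeUnit.SECONDS))
        .doesNotThrowAnyException();
}
```
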
## References

- [Spring @Async Documentation](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/scheduling/annotation/Async.html)
- [Spring @Scheduled Documentation](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/scheduling/annotation/Scheduled.html)
- [Awaitility Testing Library](https://github.com/awaitility/awaitility)
- [CompletableFuture API](https://docs.oracle.com/javase/8/docs/api/java/util/concurrent/CompletableFuture.html)

476
skills/junit-test/unit-test-security-authorization/SKILL.md
Normal file
@@ -0,0 +1,476 @@

---
name: unit-test-security-authorization
description: Unit tests for Spring Security with @PreAuthorize, @Secured, @RolesAllowed. Test role-based access control and authorization policies. Use when validating security configurations and access control logic.
category: testing
tags: [junit-5, spring-security, authorization, roles, preauthorize, mockmvc]
version: 1.0.1
---

# Unit Testing Security and Authorization

Test Spring Security authorization logic using @PreAuthorize, @Secured, and custom permission evaluators. Verify access control decisions without the full security infrastructure.

## When to Use This Skill

Use this skill when:
- Testing @PreAuthorize and @Secured method-level security
- Testing role-based access control (RBAC)
- Testing custom permission evaluators
- Verifying access-denied scenarios
- Testing authorization with authenticated principals
- You want fast authorization tests without a full Spring Security context

## Setup: Security Testing

### Maven
```xml
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-security</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-test</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.springframework.security</groupId>
    <artifactId>spring-security-test</artifactId>
    <scope>test</scope>
</dependency>
```

### Gradle
```kotlin
dependencies {
    implementation("org.springframework.boot:spring-boot-starter-security")
    testImplementation("org.springframework.boot:spring-boot-starter-test")
    testImplementation("org.springframework.security:spring-security-test")
}
```

## Basic Pattern: Testing @PreAuthorize

### Simple Role-Based Access Control

```java
// Service with security annotations
@Service
public class UserService {

    @PreAuthorize("hasRole('ADMIN')")
    public void deleteUser(Long userId) {
        // delete logic
    }

    @PreAuthorize("hasRole('USER')")
    public User getCurrentUser() {
        // get user logic
    }

    @PreAuthorize("hasAnyRole('ADMIN', 'MANAGER')")
    public List<User> listAllUsers() {
        // list logic
    }
}

// Unit test
import org.junit.jupiter.api.Test;
import org.springframework.security.test.context.support.WithMockUser;

import static org.assertj.core.api.Assertions.*;

class UserServiceSecurityTest {

    // NOTE: @PreAuthorize is only enforced on a Spring-managed proxy;
    // see the configuration sketch after this block

    @Test
    @WithMockUser(roles = "ADMIN")
    void shouldAllowAdminToDeleteUser() {
        UserService service = new UserService();

        assertThatCode(() -> service.deleteUser(1L))
            .doesNotThrowAnyException();
    }

    @Test
    @WithMockUser(roles = "USER")
    void shouldDenyUserFromDeletingUser() {
        UserService service = new UserService();

        assertThatThrownBy(() -> service.deleteUser(1L))
            .isInstanceOf(AccessDeniedException.class);
    }

    @Test
    @WithMockUser(roles = "ADMIN")
    void shouldAllowAdminAndManagerToListUsers() {
        UserService service = new UserService();

        assertThatCode(() -> service.listAllUsers())
            .doesNotThrowAnyException();
    }

    @Test
    void shouldDenyAnonymousUserAccess() {
        UserService service = new UserService();

        assertThatThrownBy(() -> service.deleteUser(1L))
            .isInstanceOf(AccessDeniedException.class);
    }
}
```

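One caveat: security annotations are enforced through Spring AOP proxies, so a service created with `new` is not actually guarded. A minimal sketch of wiring the service as a proxied bean for these tests (the configuration class name is illustrative):

```java
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity;
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;

import static org.assertj.core.api.Assertions.*;

@SpringJUnitConfig(UserServiceMethodSecurityTest.Config.class)
class UserServiceMethodSecurityTest {

    @Configuration
    @EnableMethodSecurity // turns on @PreAuthorize enforcement via proxies
    static class Config {
        @Bean
        UserService userService() {
            return new UserService();
        }
    }

    @Autowired
    private UserService service; // proxied bean: annotations are enforced

    @Test
    @WithMockUser(roles = "USER")
    void shouldDenyUserFromDeletingUser() {
        assertThatThrownBy(() -> service.deleteUser(1L))
            .isInstanceOf(AccessDeniedException.class);
    }
}
```
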
## Testing @Secured Annotation

### Legacy Security Configuration

```java
@Service
public class OrderService {

    @Secured("ROLE_ADMIN")
    public Order approveOrder(Long orderId) {
        // approval logic
    }

    @Secured({"ROLE_ADMIN", "ROLE_MANAGER"})
    public List<Order> getOrders() {
        // get orders
    }
}

class OrderSecurityTest {

    @Test
    @WithMockUser(roles = "ADMIN")
    void shouldAllowAdminToApproveOrder() {
        OrderService service = new OrderService();

        assertThatCode(() -> service.approveOrder(1L))
            .doesNotThrowAnyException();
    }

    @Test
    @WithMockUser(roles = "USER")
    void shouldDenyUserFromApprovingOrder() {
        OrderService service = new OrderService();

        assertThatThrownBy(() -> service.approveOrder(1L))
            .isInstanceOf(AccessDeniedException.class);
    }
}
```

## Testing Controller Security with MockMvc

### Secure REST Endpoints

```java
@RestController
@RequestMapping("/api/admin")
public class AdminController {

    @GetMapping("/users")
    @PreAuthorize("hasRole('ADMIN')")
    public List<UserDto> listAllUsers() {
        // logic
    }

    @DeleteMapping("/users/{id}")
    @PreAuthorize("hasRole('ADMIN')")
    public void deleteUser(@PathVariable Long id) {
        // delete logic
    }
}

// Testing with MockMvc
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.web.servlet.WebMvcTest;
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.test.web.servlet.MockMvc;

import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.*;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*;

// Spring Security's MockMvc support needs a WebApplicationContext, so a sliced
// @WebMvcTest (which applies the security filter chain automatically) is used
// here rather than MockMvcBuilders.standaloneSetup
@WebMvcTest(AdminController.class)
class AdminControllerSecurityTest {

    @Autowired
    private MockMvc mockMvc;

    @Test
    @WithMockUser(roles = "ADMIN")
    void shouldAllowAdminToListUsers() throws Exception {
        mockMvc.perform(get("/api/admin/users"))
            .andExpect(status().isOk());
    }

    @Test
    @WithMockUser(roles = "USER")
    void shouldDenyUserFromListingUsers() throws Exception {
        mockMvc.perform(get("/api/admin/users"))
            .andExpect(status().isForbidden());
    }

    @Test
    void shouldDenyAnonymousAccessToAdminEndpoint() throws Exception {
        mockMvc.perform(get("/api/admin/users"))
            .andExpect(status().isUnauthorized());
    }

    @Test
    @WithMockUser(roles = "ADMIN")
    void shouldAllowAdminToDeleteUser() throws Exception {
        // with(csrf()) is required for mutating requests under the default config
        mockMvc.perform(delete("/api/admin/users/1").with(csrf()))
            .andExpect(status().isOk());
    }
}
```

## Testing Expression-Based Authorization

### Complex Permission Expressions

```java
@Service
public class DocumentService {

    @PreAuthorize("hasRole('ADMIN') or authentication.principal.username == #owner")
    public Document getDocument(String owner, Long docId) {
        // get document
    }

    @PreAuthorize("hasPermission(#docId, 'Document', 'WRITE')")
    public void updateDocument(Long docId, String content) {
        // update logic
    }

    @PreAuthorize("#userId == authentication.principal.id")
    public UserProfile getUserProfile(Long userId) {
        // get profile
    }
}

class ExpressionBasedSecurityTest {

    @Test
    @WithMockUser(username = "alice", roles = "ADMIN")
    void shouldAllowAdminToAccessAnyDocument() {
        DocumentService service = new DocumentService();

        assertThatCode(() -> service.getDocument("bob", 1L))
            .doesNotThrowAnyException();
    }

    @Test
    @WithMockUser(username = "alice")
    void shouldAllowOwnerToAccessOwnDocument() {
        DocumentService service = new DocumentService();

        assertThatCode(() -> service.getDocument("alice", 1L))
            .doesNotThrowAnyException();
    }

    @Test
    @WithMockUser(username = "alice")
    void shouldDenyUserAccessToOtherUserDocument() {
        DocumentService service = new DocumentService();

        assertThatThrownBy(() -> service.getDocument("bob", 1L))
            .isInstanceOf(AccessDeniedException.class);
    }

    // @WithMockUser has no 'id' attribute and its principal exposes no 'id'
    // property, so "authentication.principal.id" needs a custom principal;
    // @WithAppUser is a custom annotation sketched after this block
    @Test
    @WithAppUser(username = "alice", id = 1L)
    void shouldAllowUserToAccessOwnProfile() {
        DocumentService service = new DocumentService();

        assertThatCode(() -> service.getUserProfile(1L))
            .doesNotThrowAnyException();
    }

    @Test
    @WithAppUser(username = "alice", id = 1L)
    void shouldDenyUserAccessToOtherProfile() {
        DocumentService service = new DocumentService();

        assertThatThrownBy(() -> service.getUserProfile(999L))
            .isInstanceOf(AccessDeniedException.class);
    }
}
```

## Testing Custom Permission Evaluator

### Create and Test Custom Permission Logic

```java
// Custom permission evaluator
@Component
public class DocumentPermissionEvaluator implements PermissionEvaluator {

    private final DocumentRepository documentRepository;

    public DocumentPermissionEvaluator(DocumentRepository documentRepository) {
        this.documentRepository = documentRepository;
    }

    @Override
    public boolean hasPermission(Authentication authentication, Object targetDomainObject, Object permission) {
        if (authentication == null) return false;

        Document document = (Document) targetDomainObject;
        String username = authentication.getName();

        return document.getOwner().getUsername().equals(username) ||
                userHasRole(authentication, "ADMIN");
    }

    @Override
    public boolean hasPermission(Authentication authentication, Serializable targetId, String targetType, Object permission) {
        if (authentication == null) return false;
        if (!"Document".equals(targetType)) return false;

        Document document = documentRepository.findById((Long) targetId).orElse(null);
        if (document == null) return false;

        return hasPermission(authentication, document, permission);
    }

    private boolean userHasRole(Authentication authentication, String role) {
        return authentication.getAuthorities().stream()
                .anyMatch(auth -> auth.getAuthority().equals("ROLE_" + role));
    }
}

// Unit test for custom evaluator
class DocumentPermissionEvaluatorTest {

    private DocumentPermissionEvaluator evaluator;
    private DocumentRepository documentRepository;
    private Authentication adminAuth;
    private Authentication userAuth;
    private Document document;

    @BeforeEach
    void setUp() {
        documentRepository = mock(DocumentRepository.class);
        evaluator = new DocumentPermissionEvaluator(documentRepository);

        document = new Document(1L, "Test Doc", new User("alice"));

        adminAuth = new UsernamePasswordAuthenticationToken(
                "admin",
                null,
                List.of(new SimpleGrantedAuthority("ROLE_ADMIN"))
        );

        userAuth = new UsernamePasswordAuthenticationToken(
                "alice",
                null,
                List.of(new SimpleGrantedAuthority("ROLE_USER"))
        );
    }

    @Test
    void shouldGrantPermissionToDocumentOwner() {
        boolean hasPermission = evaluator.hasPermission(userAuth, document, "WRITE");

        assertThat(hasPermission).isTrue();
    }

    @Test
    void shouldDenyPermissionToNonOwner() {
        Authentication otherUserAuth = new UsernamePasswordAuthenticationToken(
                "bob",
                null,
                List.of(new SimpleGrantedAuthority("ROLE_USER"))
        );

        boolean hasPermission = evaluator.hasPermission(otherUserAuth, document, "WRITE");

        assertThat(hasPermission).isFalse();
    }

    @Test
    void shouldGrantPermissionToAdmin() {
        boolean hasPermission = evaluator.hasPermission(adminAuth, document, "WRITE");

        assertThat(hasPermission).isTrue();
    }

    @Test
    void shouldDenyNullAuthentication() {
        boolean hasPermission = evaluator.hasPermission(null, document, "WRITE");

        assertThat(hasPermission).isFalse();
    }
}
```

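For `hasPermission(...)` expressions to reach a custom evaluator like this one, it must be registered with the method-security expression handler, a common omission (see Troubleshooting below). A minimal configuration sketch, assuming Spring Security 6 and the `DocumentPermissionEvaluator` above:

```java
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.access.expression.method.DefaultMethodSecurityExpressionHandler;
import org.springframework.security.access.expression.method.MethodSecurityExpressionHandler;
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity;

@Configuration
@EnableMethodSecurity
class PermissionEvaluatorConfig {

    // Expose an expression handler that delegates hasPermission(...) to the custom evaluator
    @Bean
    MethodSecurityExpressionHandler expressionHandler(DocumentPermissionEvaluator evaluator) {
        DefaultMethodSecurityExpressionHandler handler = new DefaultMethodSecurityExpressionHandler();
        handler.setPermissionEvaluator(evaluator);
        return handler;
    }
}
```
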
## Testing Multiple Roles

### Parameterized Role Testing

```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

class RoleBasedAccessTest {

    private AdminService service;

    @BeforeEach
    void setUp() {
        service = new AdminService();
    }

    @AfterEach
    void tearDown() {
        SecurityContextHolder.clearContext();
    }

    // @WithMockUser cannot take the role from a test parameter (its attributes are
    // compile-time constants), so the authentication is built by hand for each run.
    @ParameterizedTest
    @ValueSource(strings = {"ADMIN", "SUPER_ADMIN", "SYSTEM"})
    void shouldAllowPrivilegedRolesToDeleteUser(String role) {
        authenticateWithRole(role);

        assertThatCode(() -> service.deleteUser(1L))
                .doesNotThrowAnyException();
    }

    @ParameterizedTest
    @ValueSource(strings = {"USER", "GUEST", "READONLY"})
    void shouldDenyUnprivilegedRolesToDeleteUser(String role) {
        authenticateWithRole(role);

        assertThatThrownBy(() -> service.deleteUser(1L))
                .isInstanceOf(AccessDeniedException.class);
    }

    private void authenticateWithRole(String role) {
        SecurityContextHolder.getContext().setAuthentication(
                new UsernamePasswordAuthenticationToken(
                        "testuser", null,
                        List.of(new SimpleGrantedAuthority("ROLE_" + role))));
    }
}
```

## Best Practices

- **Use @WithMockUser** for setting authenticated user context
- **Test both allow and deny cases** for each security rule
- **Test with different roles** to verify role-based decisions
- **Test expression-based security** comprehensively
- **Mock external dependencies** (permission evaluators, etc.)
- **Test anonymous access separately** from authenticated access
- **Enable method-level security in the test context** with `@EnableMethodSecurity` (Spring Security 6) or `@EnableGlobalMethodSecurity(prePostEnabled = true)` (5.x) - see the sketch below

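One caveat the examples above gloss over: services created with `new` bypass Spring's security proxies, so `@PreAuthorize` is only enforced when the service under test is a Spring-managed bean in a context with method security enabled. A minimal sketch (class names are illustrative):

```java
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity;
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;

@SpringJUnitConfig(OrderSecurityIT.Config.class)
class OrderSecurityIT {

    @Configuration
    @EnableMethodSecurity // Spring Security 6; use @EnableGlobalMethodSecurity(prePostEnabled = true) on 5.x
    static class Config {

        @Bean
        OrderService orderService() {
            return new OrderService(); // proxied by Spring, so @PreAuthorize is enforced
        }
    }

    @Autowired
    private OrderService orderService; // inject the proxied bean instead of calling new
}
```
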
## Common Pitfalls

- Forgetting to enable method security in test configuration
- Not testing both allow and deny scenarios
- Testing framework code instead of authorization logic
- Not handling null authentication in tests
- Mixing authentication and authorization tests unnecessarily

## Troubleshooting

**AccessDeniedException not thrown**: Ensure `@EnableGlobalMethodSecurity(prePostEnabled = true)` (or `@EnableMethodSecurity` on Spring Security 6) is configured, and that the service under test is a Spring-managed bean rather than created with `new`.

**@WithMockUser not working**: Verify the Spring Security test dependencies are on the classpath.

**Custom PermissionEvaluator not invoked**: Check `@EnableGlobalMethodSecurity(securedEnabled = true, prePostEnabled = true)` and register the evaluator with the method-security expression handler (see the configuration sketch above).

## References

- [Spring Security Method Security](https://docs.spring.io/spring-security/site/docs/current/reference/html5/#jc-method)
- [Spring Security Testing](https://docs.spring.io/spring-security/site/docs/current/reference/html5/#test)
- [@WithMockUser Documentation](https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/test/context/support/WithMockUser.html)

329
skills/junit-test/unit-test-service-layer/SKILL.md
Normal file
@@ -0,0 +1,329 @@
---
name: unit-test-service-layer
description: Unit tests for service layer with Mockito. Test business logic in isolation by mocking dependencies. Use when validating service behaviors and business logic without database or external services.
category: testing
tags: [junit-5, mockito, unit-testing, service-layer, business-logic]
version: 1.0.1
---

# Unit Testing Service Layer with Mockito

Test @Service annotated classes by mocking all injected dependencies. Focus on business logic validation without starting the Spring container.

## When to Use This Skill

Use this skill when:
- Testing business logic in @Service classes
- Mocking repository and external client dependencies
- Verifying service interactions with mocked collaborators
- Testing complex workflows and orchestration logic
- Wanting fast, isolated unit tests (no database, no API calls)
- Testing error handling and edge cases in services

## Setup with Mockito and JUnit 5

### Maven

Versions are omitted below on the assumption that they are managed by a BOM (for example, Spring Boot's dependency management).

```xml
<dependency>
    <groupId>org.junit.jupiter</groupId>
    <artifactId>junit-jupiter</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.mockito</groupId>
    <artifactId>mockito-core</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.mockito</groupId>
    <artifactId>mockito-junit-jupiter</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.assertj</groupId>
    <artifactId>assertj-core</artifactId>
    <scope>test</scope>
</dependency>
```

### Gradle
```kotlin
dependencies {
    testImplementation("org.junit.jupiter:junit-jupiter")
    testImplementation("org.mockito:mockito-core")
    testImplementation("org.mockito:mockito-junit-jupiter")
    testImplementation("org.assertj:assertj-core")
}
```

## Basic Pattern: Service with Mocked Dependencies

### Single Dependency

```java
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import static org.mockito.Mockito.*;
import static org.assertj.core.api.Assertions.*;

@ExtendWith(MockitoExtension.class)
class UserServiceTest {

    @Mock
    private UserRepository userRepository;

    @InjectMocks
    private UserService userService;

    @Test
    void shouldReturnAllUsers() {
        // Arrange
        List<User> expectedUsers = List.of(
                new User(1L, "Alice"),
                new User(2L, "Bob")
        );
        when(userRepository.findAll()).thenReturn(expectedUsers);

        // Act
        List<User> result = userService.getAllUsers();

        // Assert
        assertThat(result).hasSize(2);
        assertThat(result).containsExactly(
                new User(1L, "Alice"),
                new User(2L, "Bob")
        );
        verify(userRepository, times(1)).findAll();
    }
}
```

### Multiple Dependencies

```java
@ExtendWith(MockitoExtension.class)
class UserEnrichmentServiceTest {

    @Mock
    private UserRepository userRepository;

    @Mock
    private EmailService emailService;

    @Mock
    private AnalyticsClient analyticsClient;

    @InjectMocks
    private UserEnrichmentService enrichmentService;

    @Test
    void shouldCreateUserAndSendWelcomeEmail() {
        User newUser = new User(1L, "Alice", "alice@example.com");
        when(userRepository.save(any(User.class))).thenReturn(newUser);
        doNothing().when(emailService).sendWelcomeEmail(newUser.getEmail());

        User result = enrichmentService.registerNewUser("Alice", "alice@example.com");

        assertThat(result.getId()).isEqualTo(1L);
        assertThat(result.getName()).isEqualTo("Alice");

        verify(userRepository).save(any(User.class));
        verify(emailService).sendWelcomeEmail("alice@example.com");
        verify(analyticsClient, never()).trackUserRegistration(any());
    }
}
```

## Testing Exception Handling

### Service Throws Expected Exception

```java
@Test
void shouldThrowExceptionWhenUserNotFound() {
    when(userRepository.findById(999L))
            .thenThrow(new UserNotFoundException("User not found"));

    assertThatThrownBy(() -> userService.getUserDetails(999L))
            .isInstanceOf(UserNotFoundException.class)
            .hasMessageContaining("User not found");

    verify(userRepository).findById(999L);
}

@Test
void shouldRethrowRepositoryException() {
    // Spring's DataAccessException is abstract, so throw a concrete subclass
    when(userRepository.findAll())
            .thenThrow(new DataRetrievalFailureException("Database connection failed"));

    assertThatThrownBy(() -> userService.getAllUsers())
            .isInstanceOf(DataAccessException.class)
            .hasMessageContaining("Database connection failed");
}
```

## Testing Complex Workflows

### Multiple Service Method Calls

```java
@Test
void shouldTransferMoneyBetweenAccounts() {
    Account fromAccount = new Account(1L, 1000.0);
    Account toAccount = new Account(2L, 500.0);

    when(accountRepository.findById(1L)).thenReturn(Optional.of(fromAccount));
    when(accountRepository.findById(2L)).thenReturn(Optional.of(toAccount));
    when(accountRepository.save(any(Account.class)))
            .thenAnswer(invocation -> invocation.getArgument(0));

    moneyTransferService.transfer(1L, 2L, 200.0);

    // Verify both accounts were updated
    verify(accountRepository, times(2)).save(any(Account.class));
    assertThat(fromAccount.getBalance()).isEqualTo(800.0);
    assertThat(toAccount.getBalance()).isEqualTo(700.0);
}
```

## Argument Capturing and Verification

### Capture Arguments Passed to Mock

```java
import org.mockito.ArgumentCaptor;

@Test
void shouldCaptureUserDataWhenSaving() {
    ArgumentCaptor<User> userCaptor = ArgumentCaptor.forClass(User.class);
    when(userRepository.save(any(User.class)))
            .thenAnswer(invocation -> invocation.getArgument(0));

    userService.createUser("Alice", "alice@example.com");

    verify(userRepository).save(userCaptor.capture());
    User capturedUser = userCaptor.getValue();

    assertThat(capturedUser.getName()).isEqualTo("Alice");
    assertThat(capturedUser.getEmail()).isEqualTo("alice@example.com");
}

@Test
void shouldCaptureMultipleArgumentsAcrossMultipleCalls() {
    ArgumentCaptor<User> userCaptor = ArgumentCaptor.forClass(User.class);
    when(userRepository.save(any(User.class)))
            .thenAnswer(invocation -> invocation.getArgument(0));

    userService.createUser("Alice", "alice@example.com");
    userService.createUser("Bob", "bob@example.com");

    verify(userRepository, times(2)).save(userCaptor.capture());

    List<User> capturedUsers = userCaptor.getAllValues();
    assertThat(capturedUsers).hasSize(2);
    assertThat(capturedUsers.get(0).getName()).isEqualTo("Alice");
    assertThat(capturedUsers.get(1).getName()).isEqualTo("Bob");
}
```

## Verification Patterns

### Verify Call Order and Frequency

```java
import org.mockito.InOrder;

@Test
void shouldCallMethodsInCorrectOrder() {
    InOrder inOrder = inOrder(userRepository, emailService);

    userService.registerNewUser("Alice", "alice@example.com");

    inOrder.verify(userRepository).save(any(User.class));
    inOrder.verify(emailService).sendWelcomeEmail(any());
}

@Test
void shouldCallMethodExactlyOnce() {
    userService.getUserDetails(1L);

    verify(userRepository, times(1)).findById(1L);
    verify(userRepository, never()).findAll();
}
```

## Testing Async/Reactive Services

### Service with CompletableFuture

```java
@Test
void shouldReturnCompletableFutureWhenFetchingAsyncData() {
    List<User> users = List.of(new User(1L, "Alice"));
    when(userRepository.findAllAsync())
            .thenReturn(CompletableFuture.completedFuture(users));

    CompletableFuture<List<User>> result = userService.getAllUsersAsync();

    assertThat(result).isCompletedWithValue(users);
}
```

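For Project Reactor return types, the usual tool is `StepVerifier` from the `reactor-test` artifact. A sketch, assuming a hypothetical `findAllReactive()`/`getAllUsersReactive()` pair on the repository and service:

```java
import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;

@Test
void shouldStreamUsersReactively() {
    when(userRepository.findAllReactive())
            .thenReturn(Flux.just(new User(1L, "Alice"), new User(2L, "Bob")));

    // StepVerifier subscribes and asserts the emitted sequence
    StepVerifier.create(userService.getAllUsersReactive())
            .expectNextCount(2)
            .verifyComplete();
}
```
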
## Best Practices

- **Use @ExtendWith(MockitoExtension.class)** for JUnit 5 integration
- **Prefer constructing the service manually** (constructor injection) over relying on @InjectMocks reflection when possible
- **Mock only direct dependencies** of the service under test
- **Verify interactions** to ensure correct collaboration
- **Use descriptive variable names**: `expectedUser`, `actualUser`, `captor`
- **Test one behavior per test method** - keep tests focused
- **Avoid testing framework code** - focus on business logic

## Common Patterns

**Partial Mock with Spy**:
```java
@Mock
private UserRepository userRepository;

@Spy
@InjectMocks
private UserService userService; // Real instance, but individual methods can be stubbed

@Test
void shouldUseRealMethodButMockDependency() {
    when(userRepository.findById(any())).thenReturn(Optional.of(new User()));
    // Calls real userService methods while userRepository stays mocked
}
```

**Constructor Injection for Testing**:
```java
// In your service (production code)
public class UserService {
    private final UserRepository userRepository;

    public UserService(UserRepository userRepository) {
        this.userRepository = userRepository;
    }
}

// In your test - mocks can be injected directly
@Test
void test() {
    UserRepository mockRepo = mock(UserRepository.class);
    UserService service = new UserService(mockRepo);
}
```

## Troubleshooting

**UnfinishedStubbingException**: Ensure all `when()` calls are completed with `thenReturn()`, `thenThrow()`, or `thenAnswer()`.

**UnnecessaryStubbingException**: Remove unused stub definitions, or relax strictness with `@MockitoSettings(strictness = Strictness.LENIENT)` (or `lenient().when(...)` per stub) if the unused stubs are intentional.

**NullPointerException in test**: Verify `@InjectMocks` correctly injects all mocked dependencies into the service constructor.

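For intentional shared stubs, relaxing strictness looks like this (a sketch; the shared stub in `setUp()` would otherwise trip strict stubbing in tests that never call `findAll()`):

```java
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;

@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT) // suppresses UnnecessaryStubbingException
class LenientStubbingTest {

    @Mock
    private UserRepository userRepository;

    @InjectMocks
    private UserService userService;

    @BeforeEach
    void setUp() {
        // Shared stub: not every test below uses it
        when(userRepository.findAll()).thenReturn(List.of());
    }

    @Test
    void shouldReturnEmptyListByDefault() {
        assertThat(userService.getAllUsers()).isEmpty();
    }
}
```
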
## References

- [Mockito Documentation](https://javadoc.io/doc/org.mockito/mockito-core/latest/org/mockito/Mockito.html)
- [JUnit 5 User Guide](https://junit.org/junit5/docs/current/user-guide/)
- [AssertJ Assertions](https://assertj.github.io/assertj-core-features-highlight.html)

389
skills/junit-test/unit-test-utility-methods/SKILL.md
Normal file
@@ -0,0 +1,389 @@
---
name: unit-test-utility-methods
description: Unit tests for utility/helper classes and static methods. Test pure functions and helper logic. Use when validating utility code correctness.
category: testing
tags: [junit-5, unit-testing, utility, static-methods, pure-functions]
version: 1.0.1
---

# Unit Testing Utility Classes and Static Methods

Test static utility methods using JUnit 5. Focus on pure functions without side effects, edge cases, and boundary conditions.

## When to Use This Skill

Use this skill when:
- Testing utility classes with static helper methods
- Testing pure functions with no state or side effects
- Testing string manipulation and formatting utilities
- Testing calculation and conversion utilities
- Testing collections and array utilities
- Wanting simple, fast tests without mocking complexity
- Testing data transformation and validation helpers

## Basic Pattern: Static Utility Testing

### Simple String Utility

```java
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.*;

class StringUtilsTest {

    @Test
    void shouldCapitalizeFirstLetter() {
        String result = StringUtils.capitalize("hello");
        assertThat(result).isEqualTo("Hello");
    }

    @Test
    void shouldHandleEmptyString() {
        String result = StringUtils.capitalize("");
        assertThat(result).isEmpty();
    }

    @Test
    void shouldHandleNullInput() {
        String result = StringUtils.capitalize(null);
        assertThat(result).isNull();
    }

    @Test
    void shouldHandleSingleCharacter() {
        String result = StringUtils.capitalize("a");
        assertThat(result).isEqualTo("A");
    }

    @Test
    void shouldNotChangePascalCase() {
        String result = StringUtils.capitalize("Hello");
        assertThat(result).isEqualTo("Hello");
    }
}
```

## Testing Null Handling

### Null-Safe Utility Methods

```java
class NullSafeUtilsTest {

    @Test
    void shouldReturnDefaultValueWhenNull() {
        Object result = NullSafeUtils.getOrDefault(null, "default");
        assertThat(result).isEqualTo("default");
    }

    @Test
    void shouldReturnValueWhenNotNull() {
        Object result = NullSafeUtils.getOrDefault("value", "default");
        assertThat(result).isEqualTo("value");
    }

    @Test
    void shouldReturnFalseWhenStringIsNull() {
        boolean result = NullSafeUtils.isNotBlank(null);
        assertThat(result).isFalse();
    }

    @Test
    void shouldReturnTrueWhenStringHasContent() {
        boolean result = NullSafeUtils.isNotBlank(" text ");
        assertThat(result).isTrue();
    }
}
```

## Testing Calculations and Conversions

### Math Utilities

```java
class MathUtilsTest {

    @Test
    void shouldCalculatePercentage() {
        double result = MathUtils.percentage(25, 100);
        assertThat(result).isEqualTo(25.0);
    }

    @Test
    void shouldHandleZeroDivisor() {
        double result = MathUtils.percentage(50, 0);
        assertThat(result).isZero();
    }

    @Test
    void shouldRoundToTwoDecimalPlaces() {
        double result = MathUtils.round(3.14159, 2);
        assertThat(result).isEqualTo(3.14);
    }

    @Test
    void shouldHandleNegativeNumbers() {
        int result = MathUtils.absoluteValue(-42);
        assertThat(result).isEqualTo(42);
    }
}
```

## Testing Collection Utilities

### List/Set/Map Operations

```java
class CollectionUtilsTest {

    @Test
    void shouldFilterList() {
        List<Integer> numbers = List.of(1, 2, 3, 4, 5);
        List<Integer> evenNumbers = CollectionUtils.filter(numbers, n -> n % 2 == 0);
        assertThat(evenNumbers).containsExactly(2, 4);
    }

    @Test
    void shouldReturnEmptyListWhenNoMatches() {
        List<Integer> numbers = List.of(1, 3, 5);
        List<Integer> evenNumbers = CollectionUtils.filter(numbers, n -> n % 2 == 0);
        assertThat(evenNumbers).isEmpty();
    }

    @Test
    void shouldHandleNullList() {
        List<Integer> result = CollectionUtils.filter(null, n -> true);
        assertThat(result).isEmpty();
    }

    @Test
    void shouldJoinStringsWithSeparator() {
        String result = CollectionUtils.join(List.of("a", "b", "c"), "-");
        assertThat(result).isEqualTo("a-b-c");
    }

    @Test
    void shouldHandleEmptyList() {
        String result = CollectionUtils.join(List.of(), "-");
        assertThat(result).isEmpty();
    }

    @Test
    void shouldDeduplicateList() {
        List<String> input = List.of("apple", "banana", "apple", "cherry", "banana");
        Set<String> unique = CollectionUtils.deduplicate(input);
        assertThat(unique).containsExactlyInAnyOrder("apple", "banana", "cherry");
    }
}
```

## Testing String Transformations

### Format and Parse Utilities

```java
class FormatUtilsTest {

    @Test
    void shouldFormatCurrencyWithSymbol() {
        String result = FormatUtils.formatCurrency(1234.56);
        assertThat(result).isEqualTo("$1,234.56");
    }

    @Test
    void shouldHandleNegativeCurrency() {
        String result = FormatUtils.formatCurrency(-100.00);
        assertThat(result).isEqualTo("-$100.00");
    }

    @Test
    void shouldParsePhoneNumber() {
        String result = FormatUtils.parsePhoneNumber("5551234567");
        assertThat(result).isEqualTo("(555) 123-4567");
    }

    @Test
    void shouldFormatDate() {
        LocalDate date = LocalDate.of(2024, 1, 15);
        String result = FormatUtils.formatDate(date, "yyyy-MM-dd");
        assertThat(result).isEqualTo("2024-01-15");
    }

    @Test
    void shouldSluggifyString() {
        String result = FormatUtils.sluggify("Hello World! 123");
        assertThat(result).isEqualTo("hello-world-123");
    }
}
```

## Testing Data Validation

### Validator Utilities

```java
class ValidatorUtilsTest {

    @Test
    void shouldValidateEmailFormat() {
        boolean valid = ValidatorUtils.isValidEmail("user@example.com");
        assertThat(valid).isTrue();

        boolean invalid = ValidatorUtils.isValidEmail("invalid-email");
        assertThat(invalid).isFalse();
    }

    @Test
    void shouldValidatePhoneNumber() {
        boolean valid = ValidatorUtils.isValidPhone("555-123-4567");
        assertThat(valid).isTrue();

        boolean invalid = ValidatorUtils.isValidPhone("12345");
        assertThat(invalid).isFalse();
    }

    @Test
    void shouldValidateUrlFormat() {
        boolean valid = ValidatorUtils.isValidUrl("https://example.com");
        assertThat(valid).isTrue();

        boolean invalid = ValidatorUtils.isValidUrl("not a url");
        assertThat(invalid).isFalse();
    }

    @Test
    void shouldValidateCreditCardNumber() {
        boolean valid = ValidatorUtils.isValidCreditCard("4532015112830366");
        assertThat(valid).isTrue();

        boolean invalid = ValidatorUtils.isValidCreditCard("1234567890123456");
        assertThat(invalid).isFalse();
    }
}
```

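The credit-card case presumably relies on the Luhn checksum, which accepts the first test number above and rejects the second. A possible implementation of such a validator (an assumption; the real `ValidatorUtils` is not shown):

```java
static boolean isValidCreditCard(String number) {
    int sum = 0;
    boolean doubleIt = false;
    // Walk the digits right to left, doubling every second one (Luhn algorithm)
    for (int i = number.length() - 1; i >= 0; i--) {
        int d = Character.getNumericValue(number.charAt(i));
        if (doubleIt) {
            d *= 2;
            if (d > 9) d -= 9; // same as summing the two digits of the product
        }
        sum += d;
        doubleIt = !doubleIt;
    }
    return sum % 10 == 0;
}
```
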
## Testing Parameterized Scenarios

### Multiple Test Cases with @ParameterizedTest

```java
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.junit.jupiter.params.provider.CsvSource;

class StringUtilsParametrizedTest {

    @ParameterizedTest
    @ValueSource(strings = {"", " ", "null", "undefined"})
    void shouldConsiderFalsyValuesAsEmpty(String input) {
        // "null" and "undefined" are literal strings this utility treats as empty sentinels
        boolean result = StringUtils.isEmpty(input);
        assertThat(result).isTrue();
    }

    @ParameterizedTest
    @CsvSource({
            "hello,HELLO",
            "world,WORLD",
            "javaScript,JAVASCRIPT",
            "123ABC,123ABC"
    })
    void shouldConvertToUpperCase(String input, String expected) {
        String result = StringUtils.toUpperCase(input);
        assertThat(result).isEqualTo(expected);
    }
}
```

## Testing with Mockito for External Dependencies

### Utility with Dependency (Rare Case)

```java
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class DateUtilsTest {

    @Mock
    private Clock clock;

    @Test
    void shouldGetCurrentDateFromClock() {
        Instant fixedTime = Instant.parse("2024-01-15T10:30:00Z");
        when(clock.instant()).thenReturn(fixedTime);
        when(clock.getZone()).thenReturn(ZoneOffset.UTC); // LocalDate.now(clock) also consults the zone

        LocalDate result = DateUtils.today(clock);

        assertThat(result).isEqualTo(LocalDate.of(2024, 1, 15));
    }
}
```

For `java.time`, `Clock.fixed(fixedTime, ZoneOffset.UTC)` is usually simpler and safer than a Mockito mock.

## Edge Cases and Boundary Testing

```java
class MathUtilsEdgeCaseTest {

    @Test
    void shouldHandleMaxIntegerValue() {
        // Assumes a saturating implementation that clamps at the boundary instead of overflowing
        int result = MathUtils.increment(Integer.MAX_VALUE);
        assertThat(result).isEqualTo(Integer.MAX_VALUE);
    }

    @Test
    void shouldHandleMinIntegerValue() {
        int result = MathUtils.decrement(Integer.MIN_VALUE);
        assertThat(result).isEqualTo(Integer.MIN_VALUE);
    }

    @Test
    void shouldHandleVeryLargeNumbers() {
        BigDecimal result = MathUtils.add(
                new BigDecimal("999999999999.99"),
                new BigDecimal("0.01")
        );
        assertThat(result).isEqualTo(new BigDecimal("1000000000000.00"));
    }

    @Test
    void shouldHandleFloatingPointPrecision() {
        double result = MathUtils.multiply(0.1, 0.2);
        assertThat(result).isCloseTo(0.02, within(0.0001));
    }
}
```

## Best Practices

- **Test pure functions exclusively** - no side effects or state
- **Cover happy path and edge cases** - null, empty, extreme values
- **Use descriptive test names** - clearly state what's being tested
- **Keep tests simple and short** - utility tests should be quick to understand
- **Use @ParameterizedTest** for testing multiple similar scenarios
- **Avoid mocking when not needed** - only mock external dependencies
- **Test boundary conditions** - min/max values, empty collections, null inputs

## Common Pitfalls

- Testing framework behavior instead of utility logic
- Over-mocking when pure functions need no mocks
- Not testing null/empty edge cases
- Not testing negative numbers and extreme values
- Test methods that are too large - split complex scenarios

## Troubleshooting

**Floating point precision issues**: Use `isCloseTo()` with a delta instead of exact equality.

**Null handling inconsistency**: Decide whether the utility returns null or throws an exception, then test that contract consistently (see the sketch below).

**Complex utility logic is hard to test**: Consider refactoring it into smaller, independently testable units.

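As an illustration of the null-handling point, the two common contracts can be pinned down by tests (the `trimToNull` and `requireNonBlank` helpers are hypothetical):

```java
class NullContractTest {

    @Test
    void trimToNullReturnsNullForBlankInput() {
        // Null-returning contract: blank input collapses to null
        assertThat(StringUtils.trimToNull("   ")).isNull();
    }

    @Test
    void requireNonBlankThrowsForBlankInput() {
        // Fail-fast contract: blank input is rejected outright
        assertThatThrownBy(() -> StringUtils.requireNonBlank("   "))
                .isInstanceOf(IllegalArgumentException.class);
    }
}
```
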
## References

- [JUnit 5 Parameterized Tests](https://junit.org/junit5/docs/current/user-guide/#writing-tests-parameterized-tests)
- [AssertJ Assertions](https://assertj.github.io/assertj-core-features-highlight.html)
- [Testing Edge Cases and Boundaries](https://www.baeldung.com/testing-properties-methods-using-mockito)

170
skills/junit-test/unit-test-wiremock-rest-api/SKILL.md
Normal file
@@ -0,0 +1,170 @@
---
name: unit-test-wiremock-rest-api
description: Unit tests for external REST APIs using WireMock to mock HTTP endpoints. Use when testing service integrations with external APIs.
category: testing
tags: [junit-5, wiremock, unit-testing, rest-api, mocking, http-stubbing]
version: 1.0.1
---

# Unit Testing REST APIs with WireMock

Use WireMock to test interactions with third-party REST APIs without making real network calls. This skill focuses on pure unit tests (no Spring context) that stub HTTP responses and verify requests.

## When to Use This Skill

Use this skill when:
- Testing services that call external REST APIs
- Needing to stub HTTP responses for predictable test behavior
- Testing error scenarios (timeouts, 500 errors, malformed responses)
- Verifying request details (headers, query params, request body)
- Integrating with third-party services (payment gateways, weather APIs, etc.)
- Testing without network dependencies or rate limits
- Building unit tests that run fast in CI/CD pipelines

## Core Dependencies

### Maven
```xml
<dependency>
    <groupId>org.wiremock</groupId>
    <artifactId>wiremock</artifactId>
    <version>3.4.1</version>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.junit.jupiter</groupId>
    <artifactId>junit-jupiter</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.assertj</groupId>
    <artifactId>assertj-core</artifactId>
    <scope>test</scope>
</dependency>
```

### Gradle
```kotlin
dependencies {
    testImplementation("org.wiremock:wiremock:3.4.1")
    testImplementation("org.junit.jupiter:junit-jupiter")
    testImplementation("org.assertj:assertj-core")
}
```

## Basic Pattern: Stubbing and Verifying

### Simple Stub with WireMock Extension

```java
import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import static com.github.tomakehurst.wiremock.client.WireMock.*;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.assertj.core.api.Assertions.assertThat;

class ExternalWeatherServiceTest {

    @RegisterExtension
    static WireMockExtension wireMock = WireMockExtension.newInstance()
            .options(wireMockConfig().dynamicPort())
            .build();

    @Test
    void shouldFetchWeatherDataFromExternalApi() {
        wireMock.stubFor(get(urlEqualTo("/weather?city=London"))
                .withHeader("Accept", containing("application/json"))
                .willReturn(aResponse()
                        .withStatus(200)
                        .withHeader("Content-Type", "application/json")
                        .withBody("{\"city\":\"London\",\"temperature\":15,\"condition\":\"Cloudy\"}")));

        String baseUrl = wireMock.getRuntimeInfo().getHttpBaseUrl();
        WeatherApiClient client = new WeatherApiClient(baseUrl);
        WeatherData weather = client.getWeather("London");

        assertThat(weather.getCity()).isEqualTo("London");
        assertThat(weather.getTemperature()).isEqualTo(15);

        wireMock.verify(getRequestedFor(urlEqualTo("/weather?city=London"))
                .withHeader("Accept", containing("application/json")));
    }
}
```

## Testing Error Scenarios

### Test 4xx and 5xx Responses

```java
@Test
void shouldHandleNotFoundError() {
    wireMock.stubFor(get(urlEqualTo("/api/users/999"))
            .willReturn(aResponse()
                    .withStatus(404)
                    .withBody("{\"error\":\"User not found\"}")));

    ApiClient client = new ApiClient(wireMock.getRuntimeInfo().getHttpBaseUrl());

    assertThatThrownBy(() -> client.getUser(999))
            .isInstanceOf(UserNotFoundException.class)
            .hasMessageContaining("User not found");
}

@Test
void shouldThrowOnServerError() {
    wireMock.stubFor(get(urlEqualTo("/api/data"))
            .willReturn(aResponse()
                    .withStatus(500)
                    .withBody("{\"error\":\"Internal server error\"}")));

    ApiClient client = new ApiClient(wireMock.getRuntimeInfo().getHttpBaseUrl());

    assertThatThrownBy(() -> client.fetchData())
            .isInstanceOf(ServerErrorException.class);

    // If the client retries before failing, assert the attempt count as well:
    // wireMock.verify(moreThanOrExactly(2), getRequestedFor(urlEqualTo("/api/data")));
}
```

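The "When to Use" list also mentions timeouts and malformed responses; WireMock simulates both with fixed delays and faults. A sketch (the `ApiClient` methods and exception types are assumptions):

```java
import com.github.tomakehurst.wiremock.http.Fault;

@Test
void shouldTimeOutOnSlowResponse() {
    // Delay the response longer than the client's (assumed) read timeout
    wireMock.stubFor(get(urlEqualTo("/api/slow"))
            .willReturn(aResponse()
                    .withStatus(200)
                    .withFixedDelay(5_000)));

    ApiClient client = new ApiClient(wireMock.getRuntimeInfo().getHttpBaseUrl());

    assertThatThrownBy(() -> client.fetchSlowData())
            .hasRootCauseInstanceOf(java.net.SocketTimeoutException.class);
}

@Test
void shouldSurfaceConnectionFaults() {
    // Simulate the connection being dropped mid-response
    wireMock.stubFor(get(urlEqualTo("/api/flaky"))
            .willReturn(aResponse().withFault(Fault.CONNECTION_RESET_BY_PEER)));

    ApiClient client = new ApiClient(wireMock.getRuntimeInfo().getHttpBaseUrl());

    assertThatThrownBy(() -> client.fetchFlakyData())
            .isInstanceOf(RuntimeException.class); // exact type depends on the HTTP client
}
```
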
## Request Verification

### Verify Request Details and Payload

```java
@Test
void shouldVerifyRequestBody() {
    wireMock.stubFor(post(urlEqualTo("/api/users"))
            .willReturn(aResponse()
                    .withStatus(201)
                    .withBody("{\"id\":123,\"name\":\"Alice\"}")));

    ApiClient client = new ApiClient(wireMock.getRuntimeInfo().getHttpBaseUrl());
    UserResponse response = client.createUser("Alice");

    assertThat(response.getId()).isEqualTo(123);

    wireMock.verify(postRequestedFor(urlEqualTo("/api/users"))
            .withRequestBody(matchingJsonPath("$.name", equalTo("Alice")))
            .withHeader("Content-Type", containing("application/json")));
}
```

## Best Practices

- **Use a dynamic port** to avoid port conflicts in parallel test execution
- **Verify requests** to ensure correct API usage
- **Test error scenarios** thoroughly
- **Keep stubs focused** - one concern per test
- **Reset WireMock** between tests automatically via `@RegisterExtension`
- **Never call real APIs** - always stub third-party endpoints

## Troubleshooting

**WireMock not intercepting requests**: Ensure your HTTP client uses the stubbed URL from `wireMock.getRuntimeInfo().getHttpBaseUrl()`.

**Port conflicts**: Always use `wireMockConfig().dynamicPort()` to let WireMock choose an available port.

## References

- [WireMock Official Documentation](https://wiremock.org/)
- [WireMock Stubs and Mocking](https://wiremock.org/docs/stubbing/)
- [JUnit 5 Extensions](https://junit.org/junit5/docs/current/user-guide/#extensions)

145
skills/langchain4j/langchain4j-ai-services-patterns/SKILL.md
Normal file
@@ -0,0 +1,145 @@
---
name: langchain4j-ai-services-patterns
description: Build declarative AI Services with LangChain4j using interface-based patterns, annotations, memory management, tools integration, and advanced application patterns. Use when implementing type-safe AI-powered features with minimal boilerplate code in Java applications.
category: ai-development
tags: [langchain4j, ai-services, annotations, declarative, tools, memory, function-calling, llm, java]
version: 1.1.0
allowed-tools: Read, Write, Bash
---

# LangChain4j AI Services Patterns

This skill provides guidance for building declarative AI Services with LangChain4j using interface-based patterns, annotations for system and user messages, memory management, tools integration, and advanced AI application patterns that abstract away low-level LLM interactions.

## When to Use

Use this skill when:
- Building declarative AI-powered interfaces with minimal boilerplate code
- Creating type-safe AI services with Java interfaces and annotations
- Implementing conversational AI systems with memory management
- Designing AI services that can call external tools and functions
- Building multi-agent systems with specialized AI components
- Creating AI services with different personas and behaviors
- Implementing RAG (Retrieval-Augmented Generation) patterns declaratively
- Building production AI applications with proper error handling and validation
- Creating AI services that return structured data types (enums, POJOs, lists)
- Implementing streaming AI responses with reactive patterns

## Overview

LangChain4j AI Services allow you to define AI-powered functionality using plain Java interfaces with annotations, eliminating the need for manual prompt construction and response parsing. This pattern provides type-safe, declarative AI capabilities with minimal boilerplate code.

## Quick Start

### Basic AI Service Definition

```java
interface Assistant {
    String chat(String userMessage);
}

// Create an instance - LangChain4j generates the implementation
Assistant assistant = AiServices.create(Assistant.class, chatModel);

// Use the service
String response = assistant.chat("Hello, how are you?");
```

### System Message and Templates

```java
interface CustomerSupportBot {
    @SystemMessage("You are a helpful customer support agent for TechCorp")
    String handleInquiry(String customerMessage);

    @UserMessage("Analyze sentiment: {{it}}")
    String analyzeSentiment(String feedback);
}

CustomerSupportBot bot = AiServices.create(CustomerSupportBot.class, chatModel);
```

### Memory Management

```java
interface MultiUserAssistant {
    String chat(@MemoryId String userId, String userMessage);
}

MultiUserAssistant assistant = AiServices.builder(MultiUserAssistant.class)
        .chatModel(model)
        .chatMemoryProvider(userId -> MessageWindowChatMemory.withMaxMessages(10))
        .build();
```

### Tool Integration

```java
class Calculator {
    @Tool("Add two numbers")
    double add(double a, double b) { return a + b; }
}

interface MathGenius {
    String ask(String question);
}

MathGenius mathGenius = AiServices.builder(MathGenius.class)
        .chatModel(model)
        .tools(new Calculator())
        .build();
```

## Examples

See [examples.md](references/examples.md) for comprehensive practical examples including:
- Basic chat interfaces
- Stateful assistants with memory
- Multi-user scenarios
- Structured output extraction
- Tool calling and function execution
- Streaming responses
- Error handling
- RAG integration
- Production patterns

## API Reference

Complete API documentation, annotations, interfaces, and configuration patterns are available in [references.md](references/references.md).

## Best Practices

1. **Use type-safe interfaces** instead of string-based prompts
2. **Implement proper memory management** with appropriate limits
3. **Design clear tool descriptions** with parameter documentation
4. **Handle errors gracefully** with custom error handlers
5. **Use structured output** for predictable responses
6. **Implement validation** for user inputs
7. **Monitor performance** for production deployments

## Dependencies

```xml
<!-- Maven -->
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j</artifactId>
    <version>1.8.0</version>
</dependency>
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-open-ai</artifactId>
    <version>1.8.0</version>
</dependency>
```

```gradle
// Gradle
implementation 'dev.langchain4j:langchain4j:1.8.0'
implementation 'dev.langchain4j:langchain4j-open-ai:1.8.0'
```

## References

- [LangChain4j Documentation](https://langchain4j.com/docs/)
- [LangChain4j AI Services - API References](references/references.md)
- [LangChain4j AI Services - Practical Examples](references/examples.md)

@@ -0,0 +1,534 @@

# LangChain4j AI Services - Practical Examples

This document provides practical, production-ready examples for LangChain4j AI Services patterns.

## 1. Basic Chat Interface

**Scenario**: Simple conversational interface without memory.

```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.model.openai.OpenAiChatModel;

interface SimpleChat {
    String chat(String userMessage);
}

public class BasicChatExample {
    public static void main(String[] args) {
        var chatModel = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .modelName("gpt-4o-mini")
                .temperature(0.7)
                .build();

        var chat = AiServices.builder(SimpleChat.class)
                .chatModel(chatModel)
                .build();

        String response = chat.chat("What is Spring Boot?");
        System.out.println(response);
    }
}
```

## 2. Stateful Assistant with Memory

**Scenario**: Multi-turn conversation with 10-message history.

```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.openai.OpenAiChatModel;

interface ConversationalAssistant {
    String chat(String userMessage);
}

public class StatefulAssistantExample {
    public static void main(String[] args) {
        var chatModel = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .modelName("gpt-4o-mini")
                .build();

        var assistant = AiServices.builder(ConversationalAssistant.class)
                .chatModel(chatModel)
                .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
                .build();

        // Multi-turn conversation
        System.out.println(assistant.chat("My name is Alice"));
        System.out.println(assistant.chat("What is my name?")); // Remembers: "Your name is Alice"
        System.out.println(assistant.chat("What year was Spring Boot released?")); // Answers: "2014"
        System.out.println(assistant.chat("Tell me more about it")); // Context aware
    }
}
```

## 3. Multi-User Memory with @MemoryId

**Scenario**: Separate conversation history per user.

```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.openai.OpenAiChatModel;

interface MultiUserAssistant {
    String chat(@MemoryId int userId, String userMessage);
}

public class MultiUserMemoryExample {
    public static void main(String[] args) {
        var chatModel = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .modelName("gpt-4o-mini")
                .build();

        var assistant = AiServices.builder(MultiUserAssistant.class)
                .chatModel(chatModel)
                .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(20))
                .build();

        // User 1 conversation
        System.out.println(assistant.chat(1, "I like Java"));
        System.out.println(assistant.chat(1, "What language do I prefer?")); // Java

        // User 2 conversation - separate memory
        System.out.println(assistant.chat(2, "I prefer Python"));
        System.out.println(assistant.chat(2, "What language do I prefer?")); // Python

        // User 1 - still remembers Java
        System.out.println(assistant.chat(1, "What about me?")); // Java
    }
}
```

## 4. System Message & Template Variables

**Scenario**: Configurable system prompt with dynamic template variables.

```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.V;
import dev.langchain4j.model.openai.OpenAiChatModel;

interface TemplatedAssistant {

    // Every parameter of a multi-parameter method needs an annotation such as @V or @UserMessage
    @SystemMessage("You are a {{role}} expert. Be concise and professional.")
    String chat(@V("role") String role, @UserMessage String userMessage);

    @SystemMessage("You are a helpful assistant. Translate to {{language}}")
    @UserMessage("Translate this: {{text}}")
    String translate(@V("text") String text, @V("language") String language);
}

public class TemplatedAssistantExample {
    public static void main(String[] args) {
        var chatModel = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .modelName("gpt-4o-mini")
                .temperature(0.3)
                .build();

        var assistant = AiServices.create(TemplatedAssistant.class, chatModel);

        // Dynamic role
        System.out.println(assistant.chat("Java", "Explain dependency injection"));
        System.out.println(assistant.chat("DevOps", "Explain Docker containers"));

        // Translation with template
        System.out.println(assistant.translate("Hello, how are you?", "Spanish"));
        System.out.println(assistant.translate("Good morning", "French"));
    }
}
```

## 5. Structured Output Extraction

**Scenario**: Extract structured data (POJO, enum, list) from LLM responses.

```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.V;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.model.openai.OpenAiChatModel;
import java.util.List;

enum Sentiment {
    POSITIVE, NEGATIVE, NEUTRAL
}

class ContactInfo {
    @Description("Person's full name")
    String fullName;

    @Description("Email address")
    String email;

    @Description("Phone number with country code")
    String phone;
}

interface DataExtractor {

    // @V binds each parameter to its {{...}} template variable
    @UserMessage("Analyze sentiment: {{text}}")
    Sentiment extractSentiment(@V("text") String text);

    @UserMessage("Extract contact from: {{text}}")
    ContactInfo extractContact(@V("text") String text);

    @UserMessage("List all technologies in: {{text}}")
    List<String> extractTechnologies(@V("text") String text);

    @UserMessage("Count items in: {{text}}")
    int countItems(@V("text") String text);
}

public class StructuredOutputExample {
    public static void main(String[] args) {
        var chatModel = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .modelName("gpt-4o-mini")
                .responseFormat("json_object")
                .build();

        var extractor = AiServices.create(DataExtractor.class, chatModel);

        // Enum extraction
        Sentiment sentiment = extractor.extractSentiment("This product is amazing!");
        System.out.println("Sentiment: " + sentiment); // POSITIVE

        // POJO extraction
        ContactInfo contact = extractor.extractContact(
                "John Smith, john@example.com, +1-555-1234");
        System.out.println("Name: " + contact.fullName);
        System.out.println("Email: " + contact.email);

        // List extraction
        List<String> techs = extractor.extractTechnologies(
                "We use Java, Spring Boot, PostgreSQL, and Docker");
        System.out.println("Technologies: " + techs); // [Java, Spring Boot, PostgreSQL, Docker]

        // Primitive type
        int count = extractor.countItems("I have 3 apples, 5 oranges, and 2 bananas");
        System.out.println("Total items: " + count); // 10
    }
}
```

## 6. Tool Calling / Function Calling

**Scenario**: LLM calls Java methods to solve problems.

```java
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.P;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.openai.OpenAiChatModel;
import java.time.LocalDate;

class Calculator {
    @Tool("Add two numbers")
    int add(@P("first number") int a, @P("second number") int b) {
        return a + b;
    }

    @Tool("Multiply two numbers")
    int multiply(@P("first") int a, @P("second") int b) {
        return a * b;
    }
}

class WeatherService {
    @Tool("Get weather for a city")
    String getWeather(@P("city name") String city) {
        // Simulate API call
        return "Weather in " + city + ": 22°C, Sunny";
    }
}

class DateService {
    @Tool("Get current date")
    String getCurrentDate() {
        return LocalDate.now().toString();
    }
}

interface ToolUsingAssistant {
    String chat(String userMessage);
}

public class ToolCallingExample {
    public static void main(String[] args) {
        var chatModel = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .modelName("gpt-4o-mini")
                .temperature(0.0)
                .build();

        var assistant = AiServices.builder(ToolUsingAssistant.class)
                .chatModel(chatModel)
                .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
                .tools(new Calculator(), new WeatherService(), new DateService())
                .build();

        // LLM calls tools automatically
        System.out.println(assistant.chat("What is 25 + 37?"));
        // Uses Calculator.add() → "25 + 37 equals 62"

        System.out.println(assistant.chat("What's the weather in Paris?"));
        // Uses WeatherService.getWeather() → "Weather in Paris: 22°C, Sunny"

        System.out.println(assistant.chat("Calculate (5 + 3) * 4"));
        // Uses add() and multiply() → "Result is 32"

        System.out.println(assistant.chat("What's today's date?"));
        // Uses getCurrentDate() → Shows current date
    }
}
```

## 7. Streaming Responses

**Scenario**: Real-time token-by-token streaming for UI responsiveness.

```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;

interface StreamingAssistant {
    TokenStream streamChat(String userMessage);
}

public class StreamingExample {
    public static void main(String[] args) {
        var streamingModel = OpenAiStreamingChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .modelName("gpt-4o-mini")
                .temperature(0.7)
                .build();

        var assistant = AiServices.builder(StreamingAssistant.class)
                .streamingChatModel(streamingModel)
                .build();

        // Stream the response token by token (onPartialResponse is the 1.x API;
        // the older onNext was removed)
        assistant.streamChat("Tell me a short story about a robot")
                .onPartialResponse(System.out::print) // print each partial token
                .onCompleteResponse(response -> {
                    System.out.println("\n--- Complete ---");
                    System.out.println("Tokens used: " + response.tokenUsage().totalTokenCount());
                })
                .onError(error -> System.err.println("Error: " + error.getMessage()))
                .start();

        // Crude wait for completion; a CountDownLatch released in
        // onCompleteResponse would be more robust.
        try {
            Thread.sleep(5000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
```

## 8. System Persona with Context

**Scenario**: Different assistants with distinct personalities and knowledge domains.

```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.model.openai.OpenAiChatModel;

interface JavaExpert {
    @SystemMessage("""
            You are a Java expert with 15+ years experience.
            Focus on best practices, performance, and clean code.
            Provide code examples when relevant.
            """)
    String answer(String question);
}

interface SecurityExpert {
    @SystemMessage("""
            You are a cybersecurity expert specializing in application security.
            Always consider OWASP principles and threat modeling.
            Provide practical security recommendations.
            """)
    String answer(String question);
}

interface DevOpsExpert {
    @SystemMessage("""
            You are a DevOps engineer with expertise in cloud deployment,
            CI/CD pipelines, containerization, and infrastructure as code.
            """)
    String answer(String question);
}

public class PersonaExample {
    public static void main(String[] args) {
        var chatModel = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .modelName("gpt-4o-mini")
                .temperature(0.5)
                .build();

        var javaExpert = AiServices.create(JavaExpert.class, chatModel);
        var securityExpert = AiServices.create(SecurityExpert.class, chatModel);
        var devopsExpert = AiServices.create(DevOpsExpert.class, chatModel);

        var question = "How should I handle database connections?";

        System.out.println("=== Java Expert ===");
        System.out.println(javaExpert.answer(question));

        System.out.println("\n=== Security Expert ===");
        System.out.println(securityExpert.answer(question));

        System.out.println("\n=== DevOps Expert ===");
        System.out.println(devopsExpert.answer(question));
    }
}
```

## 9. Error Handling & Tool Execution Errors

**Scenario**: Graceful handling of tool failures and LLM errors.

```java
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.model.openai.OpenAiChatModel;

class DataAccessService {
    @Tool("Query database for user")
    String queryUser(String userId) {
        // Simulate potential error
        if (!userId.matches("\\d+")) {
            throw new IllegalArgumentException("Invalid user ID format");
        }
        return "User " + userId + ": John Doe";
    }

    @Tool("Update user email")
    String updateEmail(String userId, String email) {
        if (!email.contains("@")) {
            throw new IllegalArgumentException("Invalid email format");
        }
        return "Updated email for user " + userId;
    }
}

interface ResilientAssistant {
    String execute(String command);
}

public class ErrorHandlingExample {
    public static void main(String[] args) {
        var chatModel = OpenAiChatModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .modelName("gpt-4o-mini")
            .build();

        var assistant = AiServices.builder(ResilientAssistant.class)
            .chatModel(chatModel)
            .tools(new DataAccessService())
            .toolExecutionErrorHandler((request, exception) -> {
                System.err.println("Tool error: " + exception.getMessage());
                return "Error: " + exception.getMessage();
            })
            .build();

        // Will handle tool errors gracefully
        System.out.println(assistant.execute("Get details for user abc"));
        System.out.println(assistant.execute("Update user 123 with invalid-email"));
    }
}
```

## 10. RAG Integration with AI Services

**Scenario**: AI Service with content retrieval for knowledge-based Q&A.

```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel;

interface KnowledgeBaseAssistant {
    String askAbout(String question);
}

public class RAGIntegrationExample {
    public static void main(String[] args) {
        // Setup embedding store
        var embeddingStore = new InMemoryEmbeddingStore<TextSegment>();

        // Setup models
        var embeddingModel = OpenAiEmbeddingModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .modelName("text-embedding-3-small")
            .build();

        var chatModel = OpenAiChatModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .modelName("gpt-4o-mini")
            .build();

        // Ingest documents
        var ingestor = EmbeddingStoreIngestor.builder()
            .embeddingModel(embeddingModel)
            .embeddingStore(embeddingStore)
            .build();

        ingestor.ingest(Document.from("Spring Boot is a framework for building Java applications."));
        ingestor.ingest(Document.from("Spring Data JPA simplifies database access."));

        // Create retriever
        var contentRetriever = EmbeddingStoreContentRetriever.builder()
            .embeddingStore(embeddingStore)
            .embeddingModel(embeddingModel)
            .maxResults(3)
            .minScore(0.7)
            .build();

        // Create AI Service with RAG
        var assistant = AiServices.builder(KnowledgeBaseAssistant.class)
            .chatModel(chatModel)
            .contentRetriever(contentRetriever)
            .build();

        String answer = assistant.askAbout("What is Spring Boot?");
        System.out.println(answer);
    }
}
```

## Best Practices Summary

1. **Always use @SystemMessage** for consistent behavior across different messages
2. **Set temperature to 0** for deterministic tasks (extraction, calculations)
3. **Use MessageWindowChatMemory** for conversation history management
4. **Implement error handling** for tool failures
5. **Use structured output** when you need typed responses
6. **Stream long responses** for better UX
7. **Use @MemoryId** for multi-user scenarios
8. **Use template variables** for dynamic system prompts
9. **Write tool descriptions** that are clear and actionable
10. **Always validate** tool parameters before execution (several of these practices are combined in the sketch below)

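As a quick illustration, here is a minimal sketch that combines practices 1, 2, 3, and 5 in one service. The `Invoice` type, prompt text, and class names are illustrative, not part of any LangChain4j API:

```java
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;

// Hypothetical typed result for structured output
class Invoice {
    String customer;
    double total;
}

interface InvoiceExtractor {
    @SystemMessage("You extract invoice data and return only the requested fields.")
    @UserMessage("Extract the invoice from: {{it}}")
    Invoice extract(String text);
}

public class BestPracticesSketch {
    public static void main(String[] args) {
        var chatModel = OpenAiChatModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .modelName("gpt-4o-mini")
            .temperature(0.0) // deterministic extraction (practice 2)
            .build();

        var extractor = AiServices.builder(InvoiceExtractor.class)
            .chatModel(chatModel)
            .chatMemory(MessageWindowChatMemory.withMaxMessages(10)) // practice 3
            .build();

        Invoice invoice = extractor.extract("Invoice for ACME Corp, total 1250.00 USD");
        System.out.println(invoice.customer + " owes " + invoice.total);
    }
}
```
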
@@ -0,0 +1,433 @@
# LangChain4j AI Services - API References

Complete API reference for LangChain4j AI Services patterns.

## Core Interfaces and Classes

### AiServices Builder

**Purpose**: Creates implementations of custom Java interfaces backed by LLM capabilities.

```java
public class AiServices {

    static <T> AiServicesBuilder<T> builder(Class<T> aiService)
    // Create a builder for an AI service interface

    static <T> T create(Class<T> aiService, ChatModel chatModel)
    // Quick creation with just a chat model
}

// Methods available on the returned builder:
AiServices.builder(MyAiService.class)
    .chatModel(chatModel)                    // Required for sync
    .streamingChatModel(streamingChatModel)  // Required for streaming
    .chatMemory(chatMemory)                  // Single shared memory
    .chatMemoryProvider(chatMemoryProvider)  // Per-user memory
    .tools(tool1, tool2)                     // Register tool objects
    .toolProvider(toolProvider)              // Dynamic tool selection
    .contentRetriever(contentRetriever)      // For RAG
    .retrievalAugmentor(retrievalAugmentor)  // Advanced RAG
    .moderationModel(moderationModel)        // Content moderation
    .build();                                // Build the implementation
```

### Core Annotations

**@SystemMessage**: Define the system prompt for the AI service.
```java
@SystemMessage("You are a helpful Java developer")
String chat(String userMessage);

// Template variables
@SystemMessage("You are a {{expertise}} expert")
String explain(@V("expertise") String domain, String question);
```

**@UserMessage**: Define the user message template.
```java
@UserMessage("Translate to {{language}}: {{text}}")
String translate(@V("language") String lang, @V("text") String text);

// With method parameters matching the template
@UserMessage("Summarize: {{it}}")
String summarize(String text); // {{it}} refers to the single parameter
```

**@MemoryId**: Create a separate memory context per identifier.
```java
interface MultiUserChat {
    String chat(@MemoryId String userId, String message);
    String chat(@MemoryId int sessionId, String message);
}
```

**@V**: Map a method parameter to a template variable.
```java
@UserMessage("Write {{type}} code for {{language}}")
String writeCode(@V("type") String codeType, @V("language") String lang);
```

### ChatMemory Implementations

**MessageWindowChatMemory**: Keeps the last N messages.
```java
ChatMemory memory = MessageWindowChatMemory.withMaxMessages(10);
// Or with the explicit builder
ChatMemory memory = MessageWindowChatMemory.builder()
    .maxMessages(10)
    .build();
```

**ChatMemoryProvider**: Factory for creating per-user memory.
```java
ChatMemoryProvider provider = memoryId ->
    MessageWindowChatMemory.withMaxMessages(20);
```
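
Putting `@MemoryId` and `ChatMemoryProvider` together, a minimal per-user wiring sketch (the interface name is illustrative):

```java
interface MultiUserAssistant {
    String chat(@MemoryId String userId, @UserMessage String message);
}

var assistant = AiServices.builder(MultiUserAssistant.class)
    .chatModel(chatModel)
    .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(20))
    .build();

assistant.chat("alice", "My name is Alice.");
assistant.chat("bob", "What is my name?"); // Bob's memory does not contain Alice's message
```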

### Tool Integration

**@Tool**: Mark methods that the LLM can call.
```java
@Tool("Calculate sum of two numbers")
int add(@P("first number") int a, @P("second number") int b) {
    return a + b;
}
```

**@P**: Parameter description for the LLM.
```java
@Tool("Search documents")
List<Document> search(
    @P("search query") String query,
    @P("max results") int limit
) { ... }
```

**ToolProvider**: Dynamic tool selection based on context.
```java
interface DynamicToolAssistant {
    String execute(String command);
}

// A ToolProvider receives a ToolProviderRequest and returns a ToolProviderResult
// pairing each ToolSpecification with a ToolExecutor:
ToolProvider provider = request -> {
    if (!request.userMessage().singleText().contains("calculate")) {
        return ToolProviderResult.builder().build(); // offer no tools
    }
    ToolSpecification spec = ToolSpecification.builder()
        .name("add").description("Add two numbers").build();
    ToolExecutor executor = (toolRequest, memoryId) -> "result";
    return ToolProviderResult.builder().add(spec, executor).build();
};
```

### Structured Output

**@Description**: Annotate output fields for extraction.
```java
class Person {
    @Description("Person's full name")
    String name;

    @Description("Age in years")
    int age;
}

interface Extractor {
    @UserMessage("Extract person from: {{it}}")
    Person extract(String text);
}
```
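
Usage is then a single typed call; LangChain4j instructs the model to produce the annotated fields and maps the reply back onto the class:

```java
Extractor extractor = AiServices.create(Extractor.class, chatModel);
Person person = extractor.extract("John Doe turned 42 last week");
System.out.println(person.name + " is " + person.age); // John Doe is 42
```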

### Error Handling

**ToolExecutionErrorHandler**: Handle tool execution failures.
```java
.toolExecutionErrorHandler((request, exception) -> {
    logger.error("Tool failed: " + request.name(), exception);
    return "Tool execution failed: " + exception.getMessage();
})
```

**ToolArgumentsErrorHandler**: Handle malformed tool arguments.
```java
.toolArgumentsErrorHandler((request, exception) -> {
    logger.warn("Invalid arguments for " + request.name());
    return "Please provide valid arguments";
})
```

## Streaming APIs

### TokenStream

**Purpose**: Handle streaming LLM responses token-by-token.

```java
interface StreamingAssistant {
    TokenStream streamChat(String message);
}

TokenStream stream = assistant.streamChat("Tell me a story");

stream
    .onNext(token -> {
        // Process each token
        System.out.print(token);
    })
    .onCompleteResponse(response -> {
        // Full response available
        System.out.println("\nTokens used: " + response.tokenUsage());
    })
    .onError(error -> {
        System.err.println("Error: " + error);
    })
    .onToolExecuted(toolExecution -> {
        System.out.println("Tool: " + toolExecution.request().name());
    })
    .onRetrieved(contents -> {
        // RAG content retrieved
        contents.forEach(c -> System.out.println(c.textSegment()));
    })
    .start();
```

### StreamingChatResponseHandler

**Purpose**: Callback-based streaming without TokenStream.

```java
streamingModel.chat(request, new StreamingChatResponseHandler() {
    @Override
    public void onPartialResponse(String partialResponse) {
        System.out.print(partialResponse);
    }

    @Override
    public void onCompleteResponse(ChatResponse response) {
        System.out.println("\nComplete!");
    }

    @Override
    public void onError(Throwable error) {
        error.printStackTrace();
    }
});
```

## Content Retrieval

### ContentRetriever Interface

**Purpose**: Fetch relevant content for RAG.

```java
interface ContentRetriever {
    List<Content> retrieve(Query query);
}
```

### EmbeddingStoreContentRetriever

```java
ContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
    .embeddingStore(embeddingStore)
    .embeddingModel(embeddingModel)
    .maxResults(5)                           // Default max results
    .minScore(0.7)                           // Similarity threshold
    .dynamicMaxResults(query -> 10)          // Query-dependent
    .dynamicMinScore(query -> 0.8)           // Query-dependent
    .filter(new IsEqualTo("userId", "123"))  // Metadata filter
    .dynamicFilter(query -> {...})           // Dynamic filter
    .build();
```

### RetrievalAugmentor

**Purpose**: Advanced RAG pipeline with query transformation and re-ranking.

```java
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
    .queryTransformer(new CompressingQueryTransformer(chatModel))
    .contentRetriever(contentRetriever)
    .contentAggregator(ReRankingContentAggregator.builder()
        .scoringModel(scoringModel)
        .minScore(0.8)
        .build())
    .build();

// Use with an AI Service
var assistant = AiServices.builder(Assistant.class)
    .chatModel(chatModel)
    .retrievalAugmentor(augmentor)
    .build();
```

## Request/Response Models

### ChatRequest

**Purpose**: Build complex chat requests with multiple messages.

```java
ChatRequest request = ChatRequest.builder()
    .messages(
        SystemMessage.from("You are helpful"),
        UserMessage.from("What is AI?"),
        AiMessage.from("AI is...")
    )
    .temperature(0.7)
    .maxTokens(500)
    .topP(0.95)
    .build();

ChatResponse response = chatModel.chat(request);
```

### ChatResponse

**Purpose**: Access chat model responses and metadata.

```java
String content = response.aiMessage().text();
TokenUsage usage = response.tokenUsage();

System.out.println("Tokens: " + usage.totalTokenCount());
System.out.println("Prompt tokens: " + usage.inputTokenCount());
System.out.println("Completion tokens: " + usage.outputTokenCount());
System.out.println("Finish reason: " + response.finishReason());
```

## Query and Content

### Query

**Purpose**: Represent a user query in a retrieval context.

```java
// Query object contains:
String text()                 // The query text
Metadata metadata()           // Query metadata (e.g., userId)
Object metadata(String key)   // Get a metadata value
Object metadata(String key, Object defaultValue)
```

### Content

**Purpose**: Retrieved content with metadata.

```java
String textSegment()          // Retrieved text
double score()                // Relevance score
Metadata metadata()           // Content metadata (e.g., source)
Map<String, Object> source()  // Original source data
```

## Message Types

### SystemMessage
```java
SystemMessage message = SystemMessage.from("You are a code reviewer");
```

### UserMessage
```java
UserMessage message = UserMessage.from("Review this code");
// With images
UserMessage message = UserMessage.from(
    TextContent.from("Analyze this"),
    ImageContent.from("http://...", "image/png")
);
```

### AiMessage
```java
AiMessage message = AiMessage.from("Here's my analysis");
// With tool calls (the model requests a tool execution)
AiMessage message = AiMessage.from(
    ToolExecutionRequest.builder()
        .name("calculator")
        .arguments("{\"a\": 2, \"b\": 3}")
        .build()
);
```

## Configuration Patterns

### Chat Model Configuration

```java
ChatModel model = OpenAiChatModel.builder()
    .apiKey(System.getenv("OPENAI_API_KEY"))
    .modelName("gpt-4o-mini")  // Model selection
    .temperature(0.7)          // Creativity (0-2)
    .topP(0.95)                // Diversity (0-1)
    .maxTokens(2000)           // Max generation
    .frequencyPenalty(0.0)     // Reduce repetition
    .presencePenalty(0.0)      // Encourage new topics
    .seed(42)                  // Reproducibility
    .logRequests(true)         // Debug logging
    .logResponses(true)        // Debug logging
    .build();
```

### Embedding Model Configuration

```java
EmbeddingModel embedder = OpenAiEmbeddingModel.builder()
    .apiKey(System.getenv("OPENAI_API_KEY"))
    .modelName("text-embedding-3-small")
    .dimensions(512) // Custom dimensions
    .build();
```

## Best Practices for API Usage

1. **Type Safety**: Define typed interfaces so errors surface at compile time
2. **Separation of Concerns**: Use different interfaces for different domains
3. **Error Handling**: Always implement error handlers for tools
4. **Memory Management**: Choose the appropriate memory implementation for the use case
5. **Determinism**: Use temperature=0 for deterministic tasks
6. **Testing**: Mock ChatModel for unit tests (see the sketch below)
7. **Logging**: Enable request/response logging in development
8. **Rate Limiting**: Implement backoff strategies for API calls
9. **Caching**: Cache responses for frequently asked questions
10. **Monitoring**: Track token usage for cost management

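A minimal unit-test sketch for practice 6, assuming Mockito and JUnit 5 on the classpath. The `Assistant` interface and stubbed reply are illustrative, and the exact request type the service sends can vary across LangChain4j versions:

```java
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@Test
void assistantReturnsModelReply() {
    ChatModel chatModel = mock(ChatModel.class);
    when(chatModel.chat(any(ChatRequest.class)))
        .thenReturn(ChatResponse.builder()
            .aiMessage(AiMessage.from("stubbed reply"))
            .build());

    Assistant assistant = AiServices.create(Assistant.class, chatModel);

    assertEquals("stubbed reply", assistant.chat("hello"));
}
```
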
## Common Patterns

### Factory Pattern for Multiple Assistants
```java
public class AssistantFactory {
    static JavaExpert createJavaExpert() {
        return AiServices.create(JavaExpert.class, chatModel);
    }

    static PythonExpert createPythonExpert() {
        return AiServices.create(PythonExpert.class, chatModel);
    }
}
```

### Decorator Pattern for Enhanced Functionality
```java
public class LoggingAssistant implements Assistant {
    private final Assistant delegate;

    public LoggingAssistant(Assistant delegate) {
        this.delegate = delegate;
    }

    public String chat(String message) {
        logger.info("User: " + message);
        String response = delegate.chat(message);
        logger.info("Assistant: " + response);
        return response;
    }
}
```

### Builder Pattern for Complex Configurations
```java
var assistant = AiServices.builder(ComplexAssistant.class)
    .chatModel(getChatModel())
    .chatMemory(getMemory())
    .tools(getTool1(), getTool2())
    .contentRetriever(getRetriever())
    .build();
```

## Resources

- [LangChain4j Documentation](https://docs.langchain4j.dev)
- [OpenAI API Reference](https://platform.openai.com/docs)
- [LangChain4j GitHub](https://github.com/langchain4j/langchain4j)
- [LangChain4j Examples](https://github.com/langchain4j/langchain4j-examples)
393
skills/langchain4j/langchain4j-mcp-server-patterns/SKILL.md
Normal file
@@ -0,0 +1,393 @@
---
name: langchain4j-mcp-server-patterns
description: Model Context Protocol (MCP) server implementation patterns with LangChain4j. Use when building MCP servers to extend AI capabilities with custom tools, resources, and prompt templates.
category: ai-integration
tags: [langchain4j, mcp, model-context-protocol, tools, resources, prompts, ai-services, java, spring-boot, enterprise]
version: 1.1.0
allowed-tools: Read, Write, Bash, WebFetch
---

# LangChain4j MCP Server Implementation Patterns

Implement Model Context Protocol (MCP) servers with LangChain4j to extend AI capabilities with standardized tools, resources, and prompt templates.

## When to Use

Use this skill when building:
- AI applications requiring external tool integration
- Enterprise MCP servers with multi-domain support (GitHub, databases, APIs)
- Dynamic tool providers with context-aware filtering
- Resource-based data access systems for AI models
- Prompt template servers for standardized AI interactions
- Scalable AI agents with resilient tool execution
- Multi-modal AI applications with diverse data sources
- Spring Boot applications with MCP integration
- Production-ready MCP servers with security and monitoring

## Quick Start

### Basic MCP Server

Create a simple MCP server with one tool:

```java
MCPServer server = MCPServer.builder()
    .server(new StdioServer.Builder())
    .addToolProvider(new SimpleWeatherToolProvider())
    .build();

server.start();
```

### Spring Boot Integration

Configure the MCP server in Spring Boot:

```java
@Bean
public MCPSpringConfig mcpServer(List<ToolProvider> tools) {
    return MCPSpringConfig.builder()
        .tools(tools)
        .server(new StdioServer.Builder())
        .build();
}
```

## Core Concepts

### MCP Architecture

MCP standardizes AI application connections:
- **Tools**: Executable functions (database queries, API calls)
- **Resources**: Data sources (files, schemas, documentation)
- **Prompts**: Pre-configured templates for tasks
- **Transport**: Communication layer (stdio, HTTP, WebSocket)

```
AI Application ←→ MCP Client ←→ Transport ←→ MCP Server ←→ External Service
```

### Key Components

- **MCPServer**: Main server instance with configuration
- **ToolProvider**: Tool specification and execution interface
- **ResourceListProvider/ResourceReadHandler**: Resource access
- **PromptListProvider/PromptGetHandler**: Template management
- **Transport**: Communication mechanisms (stdio, HTTP)

## Implementation Patterns

### Tool Provider Pattern

Create tools with proper schema validation:

```java
class WeatherToolProvider implements ToolProvider {

    @Override
    public List<ToolSpecification> listTools() {
        return List.of(ToolSpecification.builder()
            .name("get_weather")
            .description("Get weather for a city")
            .inputSchema(Map.of(
                "type", "object",
                "properties", Map.of(
                    "city", Map.of("type", "string", "description", "City name")
                ),
                "required", List.of("city")
            ))
            .build());
    }

    @Override
    public String executeTool(String name, String arguments) {
        // Parse arguments and execute tool logic
        return "Weather data result";
    }
}
```

### Resource Provider Pattern

Provide static and dynamic resources:

```java
class CompanyResourceProvider
        implements ResourceListProvider, ResourceReadHandler {

    @Override
    public List<McpResource> listResources() {
        return List.of(
            McpResource.builder()
                .uri("policies")
                .name("Company Policies")
                .mimeType("text/plain")
                .build()
        );
    }

    @Override
    public String readResource(String uri) {
        return loadResourceContent(uri);
    }
}
```

### Prompt Template Pattern

Create reusable prompt templates:

```java
class PromptTemplateProvider
        implements PromptListProvider, PromptGetHandler {

    @Override
    public List<Prompt> listPrompts() {
        return List.of(
            Prompt.builder()
                .name("code-review")
                .description("Review code for quality")
                .build()
        );
    }

    @Override
    public String getPrompt(String name, Map<String, String> args) {
        return applyTemplate(name, args);
    }
}
```

## Transport Configuration

### Stdio Transport

Local process communication:

```java
McpTransport transport = new StdioMcpTransport.Builder()
    .command(List.of("npm", "exec", "@modelcontextprotocol/server-everything"))
    .logEvents(true)
    .build();
```

### HTTP Transport

Remote server communication:

```java
McpTransport transport = new HttpMcpTransport.Builder()
    .sseUrl("http://localhost:3001/sse")
    .logRequests(true)
    .logResponses(true)
    .build();
```

## Client Integration

### MCP Client Setup

Connect to MCP servers:

```java
McpClient client = new DefaultMcpClient.Builder()
    .key("my-client")
    .transport(transport)
    .cacheToolList(true)
    .build();

// List available tools
List<ToolSpecification> tools = client.listTools();
```

### Tool Provider Integration

Bridge MCP servers to LangChain4j AI services:

```java
McpToolProvider provider = McpToolProvider.builder()
    .mcpClients(mcpClient)
    .failIfOneServerFails(false)
    .filter((client, tool) -> filterByPermissions(tool))
    .build();

// Integrate with an AI service
AIAssistant assistant = AiServices.builder(AIAssistant.class)
    .chatModel(chatModel)
    .toolProvider(provider)
    .build();
```

## Security & Best Practices
|
||||
|
||||
### Tool Security
|
||||
|
||||
Implement secure tool filtering:
|
||||
|
||||
```java
|
||||
McpToolProvider secureProvider = McpToolProvider.builder()
|
||||
.mcpClients(mcpClient)
|
||||
.filter((client, tool) -> {
|
||||
if (tool.name().startsWith("admin_") && !isAdmin()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.build();
|
||||
```
|
||||
|
||||
### Resource Security
|
||||
|
||||
Apply access controls to resources:
|
||||
|
||||
```java
|
||||
public boolean canAccessResource(String uri, User user) {
|
||||
return resourceService.hasAccess(uri, user);
|
||||
}
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
|
||||
Implement robust error handling:
|
||||
|
||||
```java
|
||||
try {
|
||||
String result = mcpClient.executeTool(request);
|
||||
} catch (McpException e) {
|
||||
log.error("MCP execution failed: {}", e.getMessage());
|
||||
return fallbackResult();
|
||||
}
|
||||
```
|
||||
|
||||
## Advanced Patterns

### Multi-Server Configuration

Configure multiple MCP servers:

```java
@Bean
public List<McpClient> mcpClients(List<ServerConfig> configs) {
    return configs.stream()
        .map(this::createMcpClient)
        .collect(Collectors.toList());
}

@Bean
public McpToolProvider multiServerProvider(List<McpClient> clients) {
    return McpToolProvider.builder()
        .mcpClients(clients)
        .failIfOneServerFails(false)
        .build();
}
```

### Dynamic Tool Discovery

Runtime tool filtering based on context:

```java
McpToolProvider contextualProvider = McpToolProvider.builder()
    .mcpClients(clients)
    .filter((client, tool) -> isToolAllowed(user, tool, context))
    .build();
```

### Health Monitoring

Monitor MCP server health:

```java
@Component
public class McpHealthChecker {

    @Scheduled(fixedRate = 30000) // 30 seconds
    public void checkServers() {
        mcpClients.forEach(client -> {
            try {
                client.listTools();
                markHealthy(client.key());
            } catch (Exception e) {
                markUnhealthy(client.key(), e.getMessage());
            }
        });
    }
}
```

## Configuration

### Application Properties

Configure MCP servers in application.yml:

```yaml
mcp:
  servers:
    github:
      type: docker
      command: ["/usr/local/bin/docker", "run", "-e", "GITHUB_TOKEN", "-i", "mcp/github"]
      log-events: true
    database:
      type: stdio
      command: ["/usr/bin/npm", "exec", "@modelcontextprotocol/server-sqlite"]
      log-events: false
```
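
The Spring configuration below assumes a matching `@ConfigurationProperties` binding class. A minimal sketch, with field names mirroring the YAML keys above (getters and setters elided):

```java
@ConfigurationProperties(prefix = "mcp")
public class McpProperties {

    private boolean failIfOneServerFails;
    private Map<String, ServerConfig> servers = new HashMap<>();

    public static class ServerConfig {
        private String type;          // "stdio" or "docker"
        private List<String> command; // process command line
        private boolean logEvents;
        // getters and setters...
    }
    // getters and setters...
}
```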

### Spring Boot Configuration

Configure MCP with Spring Boot:

```java
@Configuration
@EnableConfigurationProperties(McpProperties.class)
public class McpConfiguration {

    @Bean
    public MCPServer mcpServer(List<ToolProvider> providers) {
        return MCPServer.builder()
            .server(new StdioServer.Builder())
            .addToolProvider(providers)
            .enableLogging(true)
            .build();
    }
}
```

## Examples

Refer to [examples.md](./references/examples.md) for comprehensive implementation examples including:
- Basic MCP server setup
- Multi-tool enterprise servers
- Resource and prompt providers
- Spring Boot integration
- Error handling patterns
- Security implementations

## API Reference

Complete API documentation is available in [api-reference.md](./references/api-reference.md) covering:
- Core MCP classes and interfaces
- Transport configuration
- Client and server patterns
- Error handling strategies
- Configuration management
- Testing and validation

## Best Practices

1. **Resource Management**: Always close MCP clients properly using try-with-resources (see the sketch below)
2. **Error Handling**: Implement graceful degradation when servers fail
3. **Security**: Use tool filtering and resource access controls
4. **Performance**: Enable caching and optimize tool execution
5. **Monitoring**: Implement health checks and observability
6. **Testing**: Create comprehensive test suites with mocks
7. **Documentation**: Document tools, resources, and prompts clearly
8. **Configuration**: Use structured configuration for maintainability

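A minimal sketch for practice 1; this assumes `McpClient` exposes `close()` via `AutoCloseable`, so try-with-resources shuts the underlying server process down even on failure:

```java
McpTransport transport = new StdioMcpTransport.Builder()
    .command(List.of("npm", "exec", "@modelcontextprotocol/server-everything"))
    .build();

try (McpClient client = new DefaultMcpClient.Builder()
        .key("short-lived-client")
        .transport(transport)
        .build()) {
    client.listTools().forEach(tool -> System.out.println(tool.name()));
}
```
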
## References

- [LangChain4j Documentation](https://langchain4j.com/docs/)
- [Model Context Protocol Specification](https://modelcontextprotocol.org/)
- [API Reference](./references/api-reference.md)
- [Examples](./references/examples.md)
@@ -0,0 +1,315 @@
package com.example.mcp;

import dev.langchain4j.mcp.*;
import dev.langchain4j.mcp.transport.*;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;

import java.util.List;

// Helper imports
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Map;

/**
 * Template for creating MCP servers with LangChain4j.
 *
 * This template provides a starting point for building MCP servers with:
 * - Tool providers
 * - Resource providers
 * - Prompt providers
 * - Spring Boot integration
 * - Configuration management
 */
@SpringBootApplication
public class MCPServerTemplate {

    public static void main(String[] args) {
        SpringApplication.run(MCPServerTemplate.class, args);
    }

    /**
     * Configure and build the main MCP server instance.
     */
    @Bean
    public MCPServer mcpServer(
            List<ToolProvider> toolProviders,
            List<ResourceListProvider> resourceProviders,
            List<PromptListProvider> promptProviders) {

        return MCPServer.builder()
            .server(new StdioServer.Builder())
            .addToolProvider(toolProviders)
            .addResourceProvider(resourceProviders)
            .addPromptProvider(promptProviders)
            .enableLogging(true)
            .build();
    }

    /**
     * Configure MCP clients for connecting to external MCP servers.
     */
    @Bean
    public McpClient mcpClient() {
        StdioMcpTransport transport = new StdioMcpTransport.Builder()
            .command(List.of("npm", "exec", "@modelcontextprotocol/server-everything@0.6.2"))
            .logEvents(true)
            .build();

        return new DefaultMcpClient.Builder()
            .key("template-client")
            .transport(transport)
            .cacheToolList(true)
            .build();
    }

    /**
     * Configure the MCP tool provider for AI services integration.
     */
    @Bean
    public McpToolProvider mcpToolProvider(McpClient mcpClient) {
        return McpToolProvider.builder()
            .mcpClients(mcpClient)
            .failIfOneServerFails(false)
            .build();
    }
}

/**
 * Example tool provider implementing a simple calculator.
 */
class CalculatorToolProvider implements ToolProvider {

    @Override
    public List<ToolSpecification> listTools() {
        return List.of(
            ToolSpecification.builder()
                .name("add")
                .description("Add two numbers")
                .inputSchema(Map.of(
                    "type", "object",
                    "properties", Map.of(
                        "a", Map.of("type", "number", "description", "First number"),
                        "b", Map.of("type", "number", "description", "Second number")
                    ),
                    "required", List.of("a", "b")
                ))
                .build(),
            ToolSpecification.builder()
                .name("multiply")
                .description("Multiply two numbers")
                .inputSchema(Map.of(
                    "type", "object",
                    "properties", Map.of(
                        "a", Map.of("type", "number", "description", "First number"),
                        "b", Map.of("type", "number", "description", "Second number")
                    ),
                    "required", List.of("a", "b")
                ))
                .build()
        );
    }

    @Override
    public String executeTool(String name, String arguments) {
        try {
            // Parse JSON arguments
            ObjectMapper mapper = new ObjectMapper();
            JsonNode argsNode = mapper.readTree(arguments);
            double a = argsNode.get("a").asDouble();
            double b = argsNode.get("b").asDouble();

            switch (name) {
                case "add":
                    return String.valueOf(a + b);
                case "multiply":
                    return String.valueOf(a * b);
                default:
                    throw new UnsupportedOperationException("Unknown tool: " + name);
            }
        } catch (Exception e) {
            return "Error executing tool: " + e.getMessage();
        }
    }
}

/**
 * Example resource provider for static company information.
 */
class CompanyResourceProvider implements ResourceListProvider, ResourceReadHandler {

    @Override
    public List<McpResource> listResources() {
        return List.of(
            McpResource.builder()
                .uri("company-info")
                .name("Company Information")
                .description("Basic company details and contact information")
                .mimeType("text/plain")
                .build(),
            McpResource.builder()
                .uri("policies")
                .name("Company Policies")
                .description("Company policies and procedures")
                .mimeType("text/markdown")
                .build()
        );
    }

    @Override
    public String readResource(String uri) {
        switch (uri) {
            case "company-info":
                return loadCompanyInfo();
            case "policies":
                return loadPolicies();
            default:
                throw new ResourceNotFoundException("Resource not found: " + uri);
        }
    }

    private String loadCompanyInfo() {
        return """
            Company Information:
            ===================

            Name: Example Corporation
            Founded: 2020
            Industry: Technology
            Employees: 100+

            Contact:
            - Email: info@example.com
            - Phone: +1-555-0123
            - Website: https://example.com

            Mission: To deliver innovative AI solutions
            """;
    }

    private String loadPolicies() {
        return """
            Company Policies:
            =================

            1. Code of Conduct
               - Treat all team members with respect
               - Maintain professional communication
               - Report any concerns to management

            2. Security Policy
               - Use strong passwords
               - Enable 2FA when available
               - Report security incidents immediately

            3. Work Environment
               - Flexible working hours
               - Remote work options
               - Support for continuous learning
            """;
    }
}

/**
 * Example prompt template provider for common AI tasks.
 */
class PromptTemplateProvider implements PromptListProvider, PromptGetHandler {

    @Override
    public List<Prompt> listPrompts() {
        return List.of(
            Prompt.builder()
                .name("code-review")
                .description("Review code for quality, security, and best practices")
                .build(),
            Prompt.builder()
                .name("documentation-generation")
                .description("Generate technical documentation from code")
                .build(),
            Prompt.builder()
                .name("bug-analysis")
                .description("Analyze and explain potential bugs in code")
                .build()
        );
    }

    @Override
    public String getPrompt(String name, Map<String, String> arguments) {
        switch (name) {
            case "code-review":
                return createCodeReviewPrompt(arguments);
            case "documentation-generation":
                return createDocumentationPrompt(arguments);
            case "bug-analysis":
                return createBugAnalysisPrompt(arguments);
            default:
                throw new PromptNotFoundException("Prompt not found: " + name);
        }
    }

    private String createCodeReviewPrompt(Map<String, String> args) {
        String code = args.getOrDefault("code", "");
        String language = args.getOrDefault("language", "unknown");

        return String.format("""
            Review the following %s code for quality, security, and best practices:

            ```%s
            %s
            ```

            Please analyze:
            1. Code quality and readability
            2. Security vulnerabilities
            3. Performance optimizations
            4. Best practices compliance
            5. Error handling

            Provide specific recommendations for improvements.
            """, language, language, code);
    }

    private String createDocumentationPrompt(Map<String, String> args) {
        String code = args.getOrDefault("code", "");
        String component = args.getOrDefault("component", "function");

        return String.format("""
            Generate comprehensive documentation for the following %s:

            ```%s
            %s
            ```

            Include:
            1. Function/method signatures
            2. Parameters and return values
            3. Purpose and usage examples
            4. Dependencies and requirements
            5. Error conditions and handling
            """, component, "java", code);
    }

    private String createBugAnalysisPrompt(Map<String, String> args) {
        String code = args.getOrDefault("code", "");

        return String.format("""
            Analyze the following code for potential bugs and issues:

            ```java
            %s
            ```

            Look for:
            1. Null pointer exceptions
            2. Logic errors
            3. Resource leaks
            4. Race conditions
            5. Edge cases
            6. Type mismatches

            Explain each issue found and suggest fixes.
            """, code);
    }
}
@@ -0,0 +1,435 @@
# LangChain4j MCP Server API Reference

This document provides comprehensive API documentation for implementing MCP servers with LangChain4j.

## Core MCP Classes

### McpClient Interface

Primary interface for communicating with MCP servers.

**Key Methods:**
```java
// Tool Management
List<ToolSpecification> listTools();
String executeTool(ToolExecutionRequest request);

// Resource Management
List<McpResource> listResources();
String getResource(String uri);
List<McpResourceTemplate> listResourceTemplates();

// Prompt Management
List<Prompt> listPrompts();
String getPrompt(String name);

// Lifecycle Management
void close();
```

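A typical round trip with these methods, assuming a configured transport (the tool name and arguments echo the examples later in this document):

```java
McpClient client = new DefaultMcpClient.Builder()
    .key("reference-client")
    .transport(transport)
    .build();

try {
    // Discover what the server offers
    client.listTools().forEach(tool ->
        System.out.println(tool.name() + ": " + tool.description()));

    // Invoke one tool with JSON-encoded arguments
    String result = client.executeTool(ToolExecutionRequest.builder()
        .name("database_query")
        .arguments("{\"sql\": \"SELECT 1\"}")
        .build());
    System.out.println(result);
} finally {
    client.close();
}
```
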
### DefaultMcpClient.Builder

Builder for creating MCP clients with configuration options.

**Configuration Methods:**
```java
McpClient client = new DefaultMcpClient.Builder()
    .key("unique-client-id")    // Unique identifier
    .transport(transport)       // Transport mechanism
    .cacheToolList(true)        // Enable tool caching
    .logMessageHandler(handler) // Custom logging
    .build();
```

### McpToolProvider.Builder

Builder for creating tool providers that bridge MCP servers to LangChain4j AI services.

**Configuration Methods:**
```java
McpToolProvider provider = McpToolProvider.builder()
    .mcpClients(client1, client2)      // Add MCP clients
    .failIfOneServerFails(false)       // Configure failure handling
    .filterToolNames("tool1", "tool2") // Filter by names
    .filter((client, tool) -> logic)   // Custom filtering
    .build();
```

## Transport Configuration

### StdioMcpTransport.Builder

For local process communication with npm packages or Docker containers.

```java
McpTransport transport = new StdioMcpTransport.Builder()
    .command(List.of("npm", "exec", "@modelcontextprotocol/server-everything@0.6.2"))
    .logEvents(true)
    .build();
```

### HttpMcpTransport.Builder

For HTTP-based communication with remote MCP servers.

```java
McpTransport transport = new HttpMcpTransport.Builder()
    .sseUrl("http://localhost:3001/sse")
    .logRequests(true)
    .logResponses(true)
    .build();
```

### StreamableHttpMcpTransport.Builder

For streamable HTTP transport with enhanced performance.

```java
McpTransport transport = new StreamableHttpMcpTransport.Builder()
    .url("http://localhost:3001/mcp")
    .logRequests(true)
    .logResponses(true)
    .build();
```

## AI Service Integration

### AiServices.builder()

Create AI services integrated with MCP tool providers.

**Integration Methods:**
```java
AIAssistant assistant = AiServices.builder(AIAssistant.class)
    .chatModel(chatModel)
    .toolProvider(toolProvider)
    .chatMemoryProvider(memoryProvider)
    .build();
```

## Error Handling and Management

### Exception Handling

Handle MCP-specific exceptions gracefully:

```java
try {
    String result = mcpClient.executeTool(request);
} catch (McpException e) {
    log.error("MCP execution failed: {}", e.getMessage());
    // Implement fallback logic
}
```

### Retry and Resilience

Implement retry logic for unreliable MCP servers:

```java
RetryTemplate retryTemplate = RetryTemplate.builder()
    .maxAttempts(3)
    .exponentialBackoff(1000, 2, 10000)
    .build();

String result = retryTemplate.execute(context ->
    mcpClient.executeTool(request));
```

## Configuration Properties

### Application Configuration

```yaml
mcp:
  fail-if-one-server-fails: false
  cache-tools: true
  servers:
    github:
      type: docker
      command: ["/usr/local/bin/docker", "run", "-e", "GITHUB_TOKEN", "-i", "mcp/github"]
      log-events: true
    database:
      type: stdio
      command: ["/usr/bin/npm", "exec", "@modelcontextprotocol/server-sqlite"]
      log-events: false
```

### Spring Boot Configuration

```java
@Configuration
@EnableConfigurationProperties(McpProperties.class)
public class McpConfiguration {

    @Bean
    public List<McpClient> mcpClients(McpProperties properties) {
        return properties.getServers().entrySet().stream()
            .map(entry -> createMcpClient(entry.getKey(), entry.getValue()))
            .collect(Collectors.toList());
    }

    @Bean
    public McpToolProvider mcpToolProvider(List<McpClient> mcpClients, McpProperties properties) {
        return McpToolProvider.builder()
            .mcpClients(mcpClients)
            .failIfOneServerFails(properties.isFailIfOneServerFails())
            .build();
    }
}
```

## Tool Specification and Execution

### Tool Specification

Define tools with a proper schema:

```java
ToolSpecification toolSpec = ToolSpecification.builder()
    .name("database_query")
    .description("Execute SQL queries against the database")
    .inputSchema(Map.of(
        "type", "object",
        "properties", Map.of(
            "sql", Map.of(
                "type", "string",
                "description", "SQL query to execute"
            )
        )
    ))
    .build();
```

### Tool Execution

Execute tools with structured requests:

```java
ToolExecutionRequest request = ToolExecutionRequest.builder()
    .name("database_query")
    .arguments("{\"sql\": \"SELECT * FROM users LIMIT 10\"}")
    .build();

String result = mcpClient.executeTool(request);
```

## Resource Handling

### Resource Access

Access and utilize MCP resources:

```java
// List available resources
List<McpResource> resources = mcpClient.listResources();

// Get specific resource content
String content = mcpClient.getResource("resource://schema/database");

// Work with resource templates
List<McpResourceTemplate> templates = mcpClient.listResourceTemplates();
```

### Resources as Tools

Convert MCP resources to tools automatically:

```java
DefaultMcpResourcesAsToolsPresenter presenter =
    new DefaultMcpResourcesAsToolsPresenter();
mcpToolProvider.provideTools(presenter);

// Adds 'list_resources' and 'get_resource' tools automatically
```

## Security and Filtering

### Tool Filtering

Implement security-conscious tool filtering:

```java
McpToolProvider secureProvider = McpToolProvider.builder()
    .mcpClients(mcpClient)
    .filter((client, tool) -> {
        // Check user permissions
        if (tool.name().startsWith("admin_") && !currentUser.hasRole("ADMIN")) {
            return false;
        }
        return true;
    })
    .build();
```

### Resource Security

Apply security controls to resource access:

```java
public boolean canAccessResource(String uri, User user) {
    if (uri.contains("sensitive/") && !user.hasRole("ADMIN")) {
        return false;
    }
    return true;
}
```

## Performance Optimization

### Caching Strategy

Implement intelligent caching:

```java
// Enable tool caching for performance
McpClient client = new DefaultMcpClient.Builder()
    .transport(transport)
    .cacheToolList(true)
    .build();

// Periodic cache refresh
@Scheduled(fixedRate = 300000) // 5 minutes
public void refreshToolCache() {
    mcpClients.forEach(client -> {
        try {
            client.invalidateCache();
            client.listTools(); // Preload cache
        } catch (Exception e) {
            log.warn("Cache refresh failed: {}", e.getMessage());
        }
    });
}
```

### Connection Pooling

Optimize connection management:

```java
@Bean
public Executor mcpExecutor() {
    return Executors.newFixedThreadPool(10); // Dedicated thread pool
}
```

## Testing and Validation

### Mock Configuration

Setup for testing:

```java
@TestConfiguration
public class MockMcpConfiguration {

    @Bean
    @Primary
    public McpClient mockMcpClient() {
        McpClient mock = Mockito.mock(McpClient.class);

        when(mock.listTools()).thenReturn(List.of(
            ToolSpecification.builder()
                .name("test_tool")
                .description("Test tool")
                .build()
        ));

        when(mock.executeTool(any(ToolExecutionRequest.class)))
            .thenReturn("Mock result");

        return mock;
    }
}
```

### Integration Testing

Test MCP integrations:

```java
@SpringBootTest
class McpIntegrationTest {

    @Autowired
    private AIAssistant assistant;

    @Test
    void shouldExecuteToolsSuccessfully() {
        String response = assistant.chat("Execute test tool");

        assertThat(response).contains("Mock result");
    }
}
```

## Monitoring and Observability

### Health Checks

Monitor MCP server health:

```java
@Component
public class McpHealthChecker {

    @EventListener
    @Async
    public void checkHealth(ApplicationReadyEvent event) {
        mcpClients.forEach(client -> {
            try {
                client.listTools(); // Simple health check
                healthRegistry.markHealthy(client.key());
            } catch (Exception e) {
                healthRegistry.markUnhealthy(client.key(), e.getMessage());
            }
        });
    }
}
```

### Metrics Collection

Collect execution metrics:

```java
@Bean
public Counter toolExecutionCounter(MeterRegistry meterRegistry) {
    return meterRegistry.counter("mcp.tool.execution", "type", "total");
}

@Bean
public Timer toolExecutionTimer(MeterRegistry meterRegistry) {
    return meterRegistry.timer("mcp.tool.execution.time");
}
```

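Recording through these meters then wraps each tool call; `Timer.record` with a supplier is standard Micrometer:

```java
String result = toolExecutionTimer.record(() -> mcpClient.executeTool(request));
toolExecutionCounter.increment();
```
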
## Migration and Versioning

### Version Compatibility

Handle version compatibility:

```java
public class VersionedMcpClient {

    public boolean isCompatible(String serverVersion) {
        return semanticVersionChecker.isCompatible(
            REQUIRED_MCP_VERSION, serverVersion);
    }

    public McpClient createClient(ServerConfig config) {
        if (!isCompatible(config.getVersion())) {
            throw new IncompatibleVersionException(
                "Server version " + config.getVersion() +
                " is not compatible with required " + REQUIRED_MCP_VERSION);
        }

        return new DefaultMcpClient.Builder()
            .transport(createTransport(config))
            .build();
    }
}
```

This API reference provides the complete foundation for implementing MCP servers and clients with LangChain4j, covering all major aspects from basic setup to advanced enterprise patterns.
@@ -0,0 +1,592 @@
# LangChain4j MCP Server Implementation Examples

This document provides comprehensive, production-ready examples for implementing MCP servers with LangChain4j.

## Basic MCP Server Setup

### Simple MCP Server Implementation

Create a basic MCP server with single-tool functionality:

```java
import dev.langchain4j.mcp.MCPServer;
import dev.langchain4j.mcp.ToolProvider;
import dev.langchain4j.mcp.server.StdioServer;

public class BasicMcpServer {

    public static void main(String[] args) {
        MCPServer server = MCPServer.builder()
            .server(new StdioServer.Builder())
            .addToolProvider(new SimpleWeatherToolProvider())
            .build();

        // Start the server
        server.start();
    }
}

class SimpleWeatherToolProvider implements ToolProvider {

    @Override
    public List<ToolSpecification> listTools() {
        return List.of(ToolSpecification.builder()
            .name("get_weather")
            .description("Get weather information for a city")
            .inputSchema(Map.of(
                "type", "object",
                "properties", Map.of(
                    "city", Map.of(
                        "type", "string",
                        "description", "City name to get weather for"
                    )
                ),
                "required", List.of("city")
            ))
            .build());
    }

    @Override
    public String executeTool(String name, String arguments) {
        if ("get_weather".equals(name)) {
            JsonObject args = JsonParser.parseString(arguments).getAsJsonObject();
            String city = args.get("city").getAsString();

            // Simulate weather API call
            return String.format("Weather in %s: Sunny, 22°C", city);
        }
        throw new UnsupportedOperationException("Unknown tool: " + name);
    }
}
```

### Spring Boot MCP Server Integration

Integrate the MCP server with a Spring Boot application:

```java
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;

@SpringBootApplication
public class McpSpringBootApplication {

    public static void main(String[] args) {
        SpringApplication.run(McpSpringBootApplication.class, args);
    }

    @Bean
    public MCPServer mcpServer() {
        return MCPServer.builder()
            .server(new StdioServer.Builder())
            .addToolProvider(new DatabaseToolProvider())
            .addToolProvider(new FileToolProvider())
            .build();
    }
}

@Component
class DatabaseToolProvider implements ToolProvider {

    @Autowired
    private JdbcTemplate jdbcTemplate;

    @Override
    public List<ToolSpecification> listTools() {
        return List.of(ToolSpecification.builder()
            .name("query_database")
            .description("Execute SQL queries against the database")
            .inputSchema(Map.of(
                "type", "object",
                "properties", Map.of(
                    "sql", Map.of(
                        "type", "string",
                        "description", "SQL query to execute"
                    )
                ),
                "required", List.of("sql")
            ))
            .build());
    }

    @Override
    public String executeTool(String name, String arguments) {
        if ("query_database".equals(name)) {
            JsonObject args = JsonParser.parseString(arguments).getAsJsonObject();
            String sql = args.get("sql").getAsString();

            // Execute the database query
            return executeDatabaseQuery(sql);
        }
        throw new UnsupportedOperationException("Unknown tool: " + name);
    }

    private String executeDatabaseQuery(String sql) {
        // Implementation using Spring's JdbcTemplate
        try {
            return jdbcTemplate.queryForObject(sql, String.class);
        } catch (Exception e) {
            return "Error executing query: " + e.getMessage();
        }
    }
}
```

## Multi-Tool MCP Server
|
||||
|
||||
### Enterprise MCP Server with Multiple Tools
|
||||
|
||||
Create a comprehensive MCP server with multiple tool providers:
|
||||
|
||||
```java
@Component
public class EnterpriseMcpServer {

    @Bean
    public MCPServer enterpriseMcpServer(
            GitHubToolProvider githubToolProvider,
            DatabaseToolProvider databaseToolProvider,
            FileToolProvider fileToolProvider,
            EmailToolProvider emailToolProvider) {

        return MCPServer.builder()
            .server(new StdioServer.Builder())
            .addToolProvider(githubToolProvider)
            .addToolProvider(databaseToolProvider)
            .addToolProvider(fileToolProvider)
            .addToolProvider(emailToolProvider)
            .enableLogging(true)
            .setLogHandler(new CustomLogHandler())
            .build();
    }
}

@Component
class GitHubToolProvider implements ToolProvider {

    // Application-specific GitHub API client (assumed to exist and be injectable)
    private final GitHubService githubService;

    GitHubToolProvider(GitHubService githubService) {
        this.githubService = githubService;
    }

    @Override
    public List<ToolSpecification> listTools() {
        return List.of(
            ToolSpecification.builder()
                .name("get_issue")
                .description("Get GitHub issue details")
                .inputSchema(Map.of(
                    "type", "object",
                    "properties", Map.of(
                        "owner", Map.of(
                            "type", "string",
                            "description", "Repository owner"
                        ),
                        "repo", Map.of(
                            "type", "string",
                            "description", "Repository name"
                        ),
                        "issue_number", Map.of(
                            "type", "integer",
                            "description", "Issue number"
                        )
                    ),
                    "required", List.of("owner", "repo", "issue_number")
                ))
                .build(),
            ToolSpecification.builder()
                .name("list_issues")
                .description("List GitHub issues for a repository")
                .inputSchema(Map.of(
                    "type", "object",
                    "properties", Map.of(
                        "owner", Map.of(
                            "type", "string",
                            "description", "Repository owner"
                        ),
                        "repo", Map.of(
                            "type", "string",
                            "description", "Repository name"
                        ),
                        "state", Map.of(
                            "type", "string",
                            "description", "Issue state: open, closed, all",
                            "enum", List.of("open", "closed", "all")
                        )
                    ),
                    "required", List.of("owner", "repo")
                ))
                .build()
        );
    }

    @Override
    public String executeTool(String name, String arguments) {
        switch (name) {
            case "get_issue":
                return getIssueDetails(arguments);
            case "list_issues":
                return listRepositoryIssues(arguments);
            default:
                throw new UnsupportedOperationException("Unknown tool: " + name);
        }
    }

    private String getIssueDetails(String arguments) {
        JsonObject args = JsonParser.parseString(arguments).getAsJsonObject();
        String owner = args.get("owner").getAsString();
        String repo = args.get("repo").getAsString();
        int issueNumber = args.get("issue_number").getAsInt();

        // Call GitHub API
        GitHubIssue issue = githubService.getIssue(owner, repo, issueNumber);
        return "Issue #" + issueNumber + ": " + issue.getTitle() +
            "\nState: " + issue.getState() +
            "\nCreated: " + issue.getCreatedAt();
    }

    private String listRepositoryIssues(String arguments) {
        JsonObject args = JsonParser.parseString(arguments).getAsJsonObject();
        String owner = args.get("owner").getAsString();
        String repo = args.get("repo").getAsString();
        String state = args.has("state") ? args.get("state").getAsString() : "open";

        List<GitHubIssue> issues = githubService.listIssues(owner, repo, state);

        return issues.stream()
            .map(issue -> "#%d: %s (%s)".formatted(issue.getNumber(), issue.getTitle(), issue.getState()))
            .collect(Collectors.joining("\n"));
    }
}
```

## Resource Provider Implementation

### Static Resource Provider

Provide static resources for context enhancement:

```java
@Component
class StaticResourceProvider implements ResourceListProvider, ResourceReadHandler {

    private final Map<String, String> resources = new HashMap<>();

    public StaticResourceProvider() {
        // Initialize with static resources
        resources.put("company-policies", loadCompanyPolicies());
        resources.put("api-documentation", loadApiDocumentation());
        resources.put("best-practices", loadBestPractices());
    }

    @Override
    public List<McpResource> listResources() {
        return resources.keySet().stream()
            .map(uri -> McpResource.builder()
                .uri(uri)
                .name(uri.replace("-", " "))
                .description("Documentation resource")
                .mimeType("text/plain")
                .build())
            .collect(Collectors.toList());
    }

    @Override
    public String readResource(String uri) {
        if (!resources.containsKey(uri)) {
            throw new ResourceNotFoundException("Resource not found: " + uri);
        }
        return resources.get(uri);
    }

    private String loadCompanyPolicies() {
        // Load company policies from file or database
        return "Company Policies:\n1. Work hours: 9-5\n2. Security compliance\n3. Data privacy";
    }

    private String loadApiDocumentation() {
        // Load API documentation
        return "API Documentation:\nGET /api/users - List users\nPOST /api/users - Create user";
    }

    private String loadBestPractices() {
        // Load best-practices content (stub added so the constructor above compiles)
        return "Best Practices:\n1. Review code before merging\n2. Automate tests in CI";
    }
}
```

### Dynamic Resource Provider

Create dynamic resources that update in real-time:

```java
@Component
class DynamicResourceProvider implements ResourceListProvider, ResourceReadHandler {

    @Autowired
    private MetricsService metricsService;

    @Override
    public List<McpResource> listResources() {
        return List.of(
            McpResource.builder()
                .uri("system-metrics")
                .name("System Metrics")
                .description("Real-time system performance metrics")
                .mimeType("application/json")
                .build(),
            McpResource.builder()
                .uri("user-analytics")
                .name("User Analytics")
                .description("User behavior and usage statistics")
                .mimeType("application/json")
                .build()
        );
    }

    @Override
    public String readResource(String uri) {
        switch (uri) {
            case "system-metrics":
                return metricsService.getCurrentSystemMetrics();
            case "user-analytics":
                return metricsService.getUserAnalytics();
            default:
                throw new ResourceNotFoundException("Resource not found: " + uri);
        }
    }
}
```

## Prompt Template Provider

### Prompt Template Server

Create prompt templates for common AI tasks:

```java
@Component
class PromptTemplateProvider implements PromptListProvider, PromptGetHandler {

    private final Map<String, PromptTemplate> templates = new HashMap<>();

    public PromptTemplateProvider() {
        templates.put("code-review", PromptTemplate.builder()
            .name("Code Review")
            .description("Review code for quality, security, and best practices")
            .template("Review the following code for:\n" +
                "1. Code quality and readability\n" +
                "2. Security vulnerabilities\n" +
                "3. Performance optimizations\n" +
                "4. Best practices compliance\n\n" +
                "Code:\n" +
                "```{code}```\n\n" +
                "Provide a detailed analysis with specific recommendations.")
            .build());

        templates.put("documentation-generation", PromptTemplate.builder()
            .name("Documentation Generator")
            .description("Generate technical documentation from code")
            .template("Generate comprehensive documentation for the following code:\n" +
                "{code}\n\n" +
                "Include:\n" +
                "1. Function/method signatures\n" +
                "2. Parameters and return values\n" +
                "3. Purpose and usage examples\n" +
                "4. Dependencies and requirements")
            .build());
    }

    @Override
    public List<Prompt> listPrompts() {
        // Expose the map key as the prompt name so getPrompt(name, ...) lookups match
        return templates.entrySet().stream()
            .map(entry -> Prompt.builder()
                .name(entry.getKey())
                .description(entry.getValue().getDescription())
                .build())
            .collect(Collectors.toList());
    }

    @Override
    public String getPrompt(String name, Map<String, String> arguments) {
        PromptTemplate template = templates.get(name);
        if (template == null) {
            throw new PromptNotFoundException("Prompt not found: " + name);
        }

        // Replace template variables
        String content = template.getTemplate();
        for (Map.Entry<String, String> entry : arguments.entrySet()) {
            content = content.replace("{" + entry.getKey() + "}", entry.getValue());
        }

        return content;
    }
}
```
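
A quick usage sketch for the provider above — `getPrompt` fills the `{code}` placeholder from the arguments map, and the names used here match the templates registered in the constructor:

```java
// Render the "code-review" template with a concrete code snippet
PromptTemplateProvider provider = new PromptTemplateProvider();
String prompt = provider.getPrompt("code-review",
    Map.of("code", "public int add(int a, int b) { return a + b; }"));
System.out.println(prompt);
```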

## Error Handling and Validation

### Robust Error Handling

Implement comprehensive error handling and validation:

```java
@Component
class RobustToolProvider implements ToolProvider {

    // Application-specific collaborators (assumed to exist and be injectable)
    private final UserRepository userRepository;
    private final UserDataService userDataService;

    RobustToolProvider(UserRepository userRepository, UserDataService userDataService) {
        this.userRepository = userRepository;
        this.userDataService = userDataService;
    }

    @Override
    public List<ToolSpecification> listTools() {
        return List.of(ToolSpecification.builder()
            .name("secure_data_access")
            .description("Access sensitive data with proper validation")
            .inputSchema(Map.of(
                "type", "object",
                "properties", Map.of(
                    "data_type", Map.of(
                        "type", "string",
                        "description", "Type of data to access",
                        "enum", List.of("user_data", "system_data", "admin_data")
                    ),
                    "user_id", Map.of(
                        "type", "string",
                        "description", "User ID requesting access"
                    )
                ),
                "required", List.of("data_type", "user_id")
            ))
            .build());
    }

    @Override
    public String executeTool(String name, String arguments) {
        if ("secure_data_access".equals(name)) {
            try {
                JsonObject args = JsonParser.parseString(arguments).getAsJsonObject();
                String dataType = args.get("data_type").getAsString();
                String userId = args.get("user_id").getAsString();

                // Validate user permissions
                if (!hasPermission(userId, dataType)) {
                    return "Access denied: Insufficient permissions";
                }

                // Get data securely
                return getSecureData(dataType, userId);

            } catch (JsonParseException e) {
                return "Invalid JSON format: " + e.getMessage();
            } catch (Exception e) {
                return "Error accessing data: " + e.getMessage();
            }
        }
        throw new UnsupportedOperationException("Unknown tool: " + name);
    }

    private boolean hasPermission(String userId, String dataType) {
        // Implement permission checking
        if ("admin_data".equals(dataType)) {
            return userRepository.isAdmin(userId);
        }
        return true;
    }

    private String getSecureData(String dataType, String userId) {
        // Implement secure data retrieval
        if ("user_data".equals(dataType)) {
            return userDataService.getUserData(userId);
        }
        return "Data not available";
    }
}
```

## Advanced Server Configuration

### Multi-Transport Server Configuration

Configure the MCP server with multiple transport options:

```java
@Configuration
public class AdvancedMcpConfiguration {

    @Bean
    public MCPServer advancedMcpServer(
            List<ToolProvider> toolProviders,
            List<ResourceListProvider> resourceProviders,
            List<PromptListProvider> promptProviders) {

        return MCPServer.builder()
            .server(new StdioServer.Builder())
            .addToolProvider(toolProviders)
            .addResourceProvider(resourceProviders)
            .addPromptProvider(promptProviders)
            .enableLogging(true)
            .setLogHandler(new StructuredLogHandler())
            .enableHealthChecks(true)
            .setHealthCheckInterval(30) // seconds
            .setMaxConcurrentRequests(100)
            .setRequestTimeout(30) // seconds
            .build();
    }

    @Bean
    public HttpMcpTransport httpTransport() {
        return new HttpMcpTransport.Builder()
            .sseUrl("http://localhost:8080/mcp/sse")
            .logRequests(true)
            .logResponses(true)
            .setCorsEnabled(true)
            .setAllowedOrigins(List.of("http://localhost:3000"))
            .build();
    }
}
```

## Client Integration Patterns

### Multi-Client MCP Integration

Integrate with multiple MCP servers for comprehensive functionality:

```java
@Service
public class MultiMcpIntegrationService {

    private static final Logger log = LoggerFactory.getLogger(MultiMcpIntegrationService.class);

    // Assumed AI service contract; @MemoryId keys per-user chat memory
    interface AIAssistant {
        String chat(@MemoryId String userId, @UserMessage String query);
    }

    private final List<McpClient> mcpClients;
    private final ChatModel chatModel;
    private final McpToolProvider toolProvider;
    // Per-user chat memory (window size is illustrative)
    private final ChatMemoryProvider memoryProvider =
        memoryId -> MessageWindowChatMemory.withMaxMessages(10);

    public MultiMcpIntegrationService(List<McpClient> mcpClients, ChatModel chatModel) {
        this.mcpClients = mcpClients;
        this.chatModel = chatModel;

        // Create tool provider with multiple MCP clients
        this.toolProvider = McpToolProvider.builder()
            .mcpClients(mcpClients)
            .failIfOneServerFails(false) // Continue with available servers
            .filter((client, tool) -> {
                // Implement cross-server tool filtering
                return !tool.name().startsWith("deprecated_");
            })
            .build();
    }

    public String processUserQuery(String userId, String query) {
        // Create AI service with multiple MCP integrations
        AIAssistant assistant = AiServices.builder(AIAssistant.class)
            .chatModel(chatModel)
            .toolProvider(toolProvider)
            .chatMemoryProvider(memoryProvider)
            .build();

        return assistant.chat(userId, query);
    }

    public List<ToolSpecification> getAvailableTools() {
        return mcpClients.stream()
            .flatMap(client -> {
                try {
                    return client.listTools().stream();
                } catch (Exception e) {
                    log.warn("Failed to list tools from client {}: {}", client.key(), e.getMessage());
                    return Stream.empty();
                }
            })
            .distinct()
            .collect(Collectors.toList());
    }
}
```

These comprehensive examples provide a solid foundation for implementing MCP servers with LangChain4j, covering everything from basic setup to advanced enterprise patterns.

@@ -0,0 +1,349 @@

---
name: langchain4j-rag-implementation-patterns
description: Implement Retrieval-Augmented Generation (RAG) systems with LangChain4j. Build document ingestion pipelines, embedding stores, vector search strategies, and knowledge-enhanced AI applications. Use when creating question-answering systems over document collections or AI assistants with external knowledge bases.
allowed-tools: Read, Write, Bash
category: ai-development
tags: [langchain4j, rag, retrieval-augmented-generation, embedding, vector-search, document-ingestion, java]
version: 1.1.0
---

# LangChain4j RAG Implementation Patterns

## When to Use This Skill

Use this skill when:
- Building knowledge-based AI applications requiring external document access
- Implementing question-answering systems over large document collections
- Creating AI assistants with access to company knowledge bases
- Building semantic search capabilities for document repositories
- Implementing chat systems that reference specific information sources
- Creating AI applications requiring source attribution
- Building domain-specific AI systems with curated knowledge
- Implementing hybrid search combining vector similarity with traditional search
- Creating AI applications requiring real-time document updates
- Building multi-modal RAG systems with text, images, and other content types

## Overview

Implement complete Retrieval-Augmented Generation (RAG) systems with LangChain4j. RAG enhances language models by providing relevant context from external knowledge sources, improving accuracy and reducing hallucinations.

## Instructions

### Initialize RAG Project

Create a new Spring Boot project with required dependencies:

**pom.xml**:
```xml
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-spring-boot-starter</artifactId>
    <version>1.8.0</version>
</dependency>
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-open-ai</artifactId>
    <version>1.8.0</version>
</dependency>
```
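
If you use an external vector store, add its integration module as well — artifact names follow the `langchain4j-<store>` pattern. The Pinecone example below is illustrative; verify the exact artifact and version against the LangChain4j BOM:

```xml
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-pinecone</artifactId>
    <!-- align with your LangChain4j version/BOM -->
    <version>${langchain4j.version}</version>
</dependency>
```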

### Setup Document Ingestion

Configure document loading and processing:

```java
@Configuration
public class RAGConfiguration {

    @Bean
    public EmbeddingModel embeddingModel() {
        return OpenAiEmbeddingModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .modelName("text-embedding-3-small")
            .build();
    }

    @Bean
    public EmbeddingStore<TextSegment> embeddingStore() {
        return new InMemoryEmbeddingStore<>();
    }
}
```

Create a document ingestion service:

```java
@Service
@RequiredArgsConstructor
public class DocumentIngestionService {

    private final EmbeddingModel embeddingModel;
    private final EmbeddingStore<TextSegment> embeddingStore;

    public void ingestDocument(String filePath, Map<String, Object> metadata) {
        Document document = FileSystemDocumentLoader.loadDocument(filePath);
        document.metadata().putAll(metadata);

        DocumentSplitter splitter = DocumentSplitters.recursive(
            500, 50, new OpenAiTokenCountEstimator("text-embedding-3-small")
        );

        List<TextSegment> segments = splitter.split(document);
        List<Embedding> embeddings = embeddingModel.embedAll(segments).content();
        embeddingStore.addAll(embeddings, segments);
    }
}
```
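
An example call into the service above (the path and metadata keys are hypothetical):

```java
// Ingest a document and tag it for later metadata filtering
documentIngestionService.ingestDocument(
    "docs/employee-handbook.pdf",
    Map.of("docType", "policy", "department", "HR"));
```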

### Configure Content Retrieval

Set up content retrieval with filtering:

```java
@Configuration
public class ContentRetrieverConfiguration {

    @Bean
    public ContentRetriever contentRetriever(
            EmbeddingStore<TextSegment> embeddingStore,
            EmbeddingModel embeddingModel) {

        return EmbeddingStoreContentRetriever.builder()
            .embeddingStore(embeddingStore)
            .embeddingModel(embeddingModel)
            .maxResults(5)
            .minScore(0.7)
            .build();
    }
}
```

### Create RAG-Enabled AI Service

Define an AI service with context retrieval:

```java
interface KnowledgeAssistant {
    @SystemMessage("""
        You are a knowledgeable assistant with access to a comprehensive knowledge base.

        When answering questions:
        1. Use the provided context from the knowledge base
        2. If information is not in the context, clearly state this
        3. Provide accurate, helpful responses
        4. When possible, reference specific sources
        5. If the context is insufficient, ask for clarification
        """)
    String answerQuestion(String question);
}

@Service
public class KnowledgeService {

    private final KnowledgeAssistant assistant;

    // The assistant is built by hand here, so Lombok's @RequiredArgsConstructor is not needed
    public KnowledgeService(ChatModel chatModel, ContentRetriever contentRetriever) {
        this.assistant = AiServices.builder(KnowledgeAssistant.class)
            .chatModel(chatModel)
            .contentRetriever(contentRetriever)
            .build();
    }

    public String answerQuestion(String question) {
        return assistant.answerQuestion(question);
    }
}
```

## Examples

### Basic Document Processing

```java
public class BasicRAGExample {
    public static void main(String[] args) {
        var embeddingStore = new InMemoryEmbeddingStore<TextSegment>();

        var embeddingModel = OpenAiEmbeddingModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .modelName("text-embedding-3-small")
            .build();

        var ingestor = EmbeddingStoreIngestor.builder()
            .embeddingModel(embeddingModel)
            .embeddingStore(embeddingStore)
            .build();

        ingestor.ingest(Document.from("Spring Boot is a framework for building Java applications with minimal configuration."));

        var retriever = EmbeddingStoreContentRetriever.builder()
            .embeddingStore(embeddingStore)
            .embeddingModel(embeddingModel)
            .build();
    }
}
```

### Multi-Domain Assistant

```java
interface MultiDomainAssistant {
    @SystemMessage("""
        You are an expert assistant with access to multiple knowledge domains:
        - Technical documentation
        - Company policies
        - Product information
        - Customer support guides

        Tailor your response based on the type of question and available context.
        Always indicate which domain the information comes from.
        """)
    String answerQuestion(@MemoryId String userId, String question);
}
```

### Hierarchical RAG

```java
@Service
@RequiredArgsConstructor
public class HierarchicalRAGService {

    private final EmbeddingStore<TextSegment> chunkStore;
    private final EmbeddingStore<TextSegment> summaryStore;
    private final EmbeddingModel embeddingModel;

    public String performHierarchicalRetrieval(String query) {
        // First pass: find the most relevant document summaries
        List<EmbeddingMatch<TextSegment>> summaryMatches = searchSummaries(query);
        List<TextSegment> relevantChunks = new ArrayList<>();

        // Second pass: search chunks only within the matched documents
        for (EmbeddingMatch<TextSegment> summaryMatch : summaryMatches) {
            String documentId = summaryMatch.embedded().metadata().getString("documentId");
            List<EmbeddingMatch<TextSegment>> chunkMatches = searchChunksInDocument(query, documentId);
            chunkMatches.stream()
                .map(EmbeddingMatch::embedded)
                .forEach(relevantChunks::add);
        }

        // searchSummaries, searchChunksInDocument, and generateResponseWithChunks are
        // application-specific helpers (see the sketch below for one of them)
        return generateResponseWithChunks(query, relevantChunks);
    }
}
```
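
A minimal sketch of what `searchSummaries` could look like, assuming summaries were ingested into `summaryStore` with a `documentId` metadata entry:

```java
private List<EmbeddingMatch<TextSegment>> searchSummaries(String query) {
    // Embed the query and search the summary-level store
    Embedding queryEmbedding = embeddingModel.embed(query).content();
    EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
        .queryEmbedding(queryEmbedding)
        .maxResults(3)
        .minScore(0.7)
        .build();
    return summaryStore.search(request).matches();
}
```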

## Best Practices

### Document Segmentation

- Use recursive splitting with 500-1000 token chunks for most applications
- Maintain 20-50 token overlap between chunks for context preservation
- Consider document structure (headings, paragraphs) when splitting
- Use token-aware splitters for optimal embedding generation

### Metadata Strategy

- Include rich metadata for filtering and attribution (see the sketch below):
  - User and tenant identifiers for multi-tenancy
  - Document type and category classification
  - Creation and modification timestamps
  - Version and author information
  - Confidentiality and access level tags
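
A minimal sketch of stamping this kind of metadata at ingestion time via a document transformer (the field names and values are illustrative):

```java
var ingestor = EmbeddingStoreIngestor.builder()
    .documentTransformer(doc -> {
        // Illustrative fields for multi-tenant filtering and attribution
        doc.metadata().put("tenantId", "acme");
        doc.metadata().put("docType", "policy");
        doc.metadata().put("modifiedAt", java.time.Instant.now().toString());
        doc.metadata().put("accessLevel", "internal");
        return doc;
    })
    .embeddingModel(embeddingModel)
    .embeddingStore(embeddingStore)
    .build();
```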

### Query Processing

- Implement query preprocessing and cleaning
- Consider query expansion for better recall (see the sketch below)
- Apply dynamic filtering based on user context
- Use re-ranking for improved result quality
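
A sketch of query expansion wired into the retrieval pipeline. This assumes LangChain4j's `ExpandingQueryTransformer`, which uses a chat model to generate several query variations — verify the builder's property names against your LangChain4j version:

```java
// Expand each user query into several variations before retrieval
var augmentor = DefaultRetrievalAugmentor.builder()
    .queryTransformer(ExpandingQueryTransformer.builder()
        .chatModel(chatModel)   // assumed builder property name
        .n(3)                   // number of query variations (assumed)
        .build())
    .contentRetriever(contentRetriever)
    .build();
```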

### Performance Optimization

- Cache embeddings for repeated queries (see the sketch below)
- Use batch embedding generation for bulk operations
- Implement pagination for large result sets
- Consider asynchronous processing for long operations
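
A minimal sketch of an embedding cache in front of any `EmbeddingModel`; the wrapper class is illustrative, not part of LangChain4j:

```java
// Caches query embeddings so repeated questions skip the embedding API call
class CachingEmbeddingModel {

    private final EmbeddingModel delegate;
    private final java.util.concurrent.ConcurrentHashMap<String, Embedding> cache =
        new java.util.concurrent.ConcurrentHashMap<>();

    CachingEmbeddingModel(EmbeddingModel delegate) {
        this.delegate = delegate;
    }

    Embedding embed(String text) {
        return cache.computeIfAbsent(text, t -> delegate.embed(t).content());
    }
}
```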

## Common Patterns

### Simple RAG Pipeline

```java
@RequiredArgsConstructor
@Service
public class SimpleRAGPipeline {

    private final EmbeddingModel embeddingModel;
    private final EmbeddingStore<TextSegment> embeddingStore;
    private final ChatModel chatModel;

    public String answerQuestion(String question) {
        Embedding queryEmbedding = embeddingModel.embed(question).content();
        EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
            .queryEmbedding(queryEmbedding)
            .maxResults(3)
            .build();

        List<TextSegment> segments = embeddingStore.search(request).matches().stream()
            .map(EmbeddingMatch::embedded)
            .collect(Collectors.toList());

        String context = segments.stream()
            .map(TextSegment::text)
            .collect(Collectors.joining("\n\n"));

        // ChatModel exposes chat(String) in recent LangChain4j versions
        // (generate(..) was the pre-1.0 API)
        return chatModel.chat(context + "\n\nQuestion: " + question + "\nAnswer:");
    }
}
```

### Hybrid Search (Vector + Keyword)

```java
@Service
@RequiredArgsConstructor
public class HybridSearchService {

    private final EmbeddingStore<TextSegment> vectorStore;
    private final FullTextSearchEngine keywordEngine;
    private final EmbeddingModel embeddingModel;

    public List<Content> hybridSearch(String query, int maxResults) {
        // Vector search
        List<Content> vectorResults = performVectorSearch(query, maxResults);

        // Keyword search
        List<Content> keywordResults = performKeywordSearch(query, maxResults);

        // Combine and re-rank using the RRF algorithm (see the sketch below)
        return combineResults(vectorResults, keywordResults, maxResults);
    }
}
```
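
A sketch of `combineResults` using Reciprocal Rank Fusion (RRF): each document's fused score is the sum of 1/(k + rank) over the result lists it appears in, with k ≈ 60 by convention. Keying results by segment text is illustrative; a stable document ID would be better in practice:

```java
private List<Content> combineResults(List<Content> vectorResults,
                                     List<Content> keywordResults,
                                     int maxResults) {
    final int k = 60; // standard RRF damping constant
    Map<String, Double> scores = new LinkedHashMap<>();
    Map<String, Content> byKey = new LinkedHashMap<>();

    for (List<Content> results : List.of(vectorResults, keywordResults)) {
        for (int rank = 0; rank < results.size(); rank++) {
            Content content = results.get(rank);
            String key = content.textSegment().text(); // illustrative key
            byKey.putIfAbsent(key, content);
            scores.merge(key, 1.0 / (k + rank + 1), Double::sum);
        }
    }

    // Sort by fused score, highest first, and keep the top results
    return scores.entrySet().stream()
        .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
        .limit(maxResults)
        .map(e -> byKey.get(e.getKey()))
        .collect(Collectors.toList());
}
```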

## Troubleshooting

### Common Issues

**Poor Retrieval Results**
- Check document chunk size and overlap settings
- Verify embedding model compatibility
- Ensure metadata filters are not too restrictive
- Consider adding a re-ranking step

**Slow Performance**
- Use cached embeddings for frequent queries
- Optimize database indexing for vector stores
- Implement pagination for large datasets
- Consider async processing for bulk operations

**High Memory Usage**
- Use disk-based embedding stores for large datasets
- Implement proper pagination and filtering
- Clean up unused embeddings periodically
- Monitor and optimize chunk sizes

## References

- [API Reference](references/references.md) - Complete API documentation and interfaces
- [Examples](references/examples.md) - Production-ready examples and patterns
- [Official LangChain4j Documentation](https://docs.langchain4j.dev/)

@@ -0,0 +1,482 @@

# LangChain4j RAG Implementation - Practical Examples

Production-ready examples for implementing Retrieval-Augmented Generation (RAG) systems with LangChain4j.

## 1. Simple In-Memory RAG

**Scenario**: Quick RAG setup with documents in memory for development/testing.

```java
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;

interface DocumentAssistant {
    String answer(String question);
}

public class SimpleRagExample {
    public static void main(String[] args) {
        // Setup
        var embeddingStore = new InMemoryEmbeddingStore<TextSegment>();

        var embeddingModel = OpenAiEmbeddingModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .modelName("text-embedding-3-small")
            .build();

        var chatModel = OpenAiChatModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .modelName("gpt-4o-mini")
            .build();

        // Ingest documents
        var ingestor = EmbeddingStoreIngestor.builder()
            .embeddingModel(embeddingModel)
            .embeddingStore(embeddingStore)
            .build();

        ingestor.ingest(Document.from("Spring Boot is a framework for building Java applications with minimal configuration."));
        ingestor.ingest(Document.from("Spring Data JPA provides data access abstraction using repositories."));
        ingestor.ingest(Document.from("Spring Cloud enables building distributed systems and microservices."));

        // Create retriever and AI service
        var contentRetriever = EmbeddingStoreContentRetriever.builder()
            .embeddingStore(embeddingStore)
            .embeddingModel(embeddingModel)
            .maxResults(3)
            .minScore(0.7)
            .build();

        var assistant = AiServices.builder(DocumentAssistant.class)
            .chatModel(chatModel)
            .contentRetriever(contentRetriever)
            .build();

        // Query with RAG
        System.out.println(assistant.answer("What is Spring Boot?"));
        System.out.println(assistant.answer("What does Spring Data JPA do?"));
    }
}
```

## 2. Vector Database RAG (Pinecone)

**Scenario**: Production RAG with persistent vector database.

```java
import dev.langchain4j.store.embedding.pinecone.PineconeEmbeddingStore;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.Metadata;
// Supporting imports assumed by the snippet below
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import java.time.LocalDate;

public class PineconeRagExample {
    public static void main(String[] args) {
        // Production vector store
        var embeddingStore = PineconeEmbeddingStore.builder()
            .apiKey(System.getenv("PINECONE_API_KEY"))
            .index("docs-index")
            .namespace("production")
            .build();

        var embeddingModel = OpenAiEmbeddingModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .build();

        // Ingest with metadata
        var ingestor = EmbeddingStoreIngestor.builder()
            .documentTransformer(doc -> {
                doc.metadata().put("source", "documentation");
                doc.metadata().put("date", LocalDate.now().toString());
                return doc;
            })
            .documentSplitter(DocumentSplitters.recursive(1000, 200))
            .embeddingModel(embeddingModel)
            .embeddingStore(embeddingStore)
            .build();

        ingestor.ingest(Document.from("Your large document..."));

        // Retrieve with filters
        var retriever = EmbeddingStoreContentRetriever.builder()
            .embeddingStore(embeddingStore)
            .embeddingModel(embeddingModel)
            .maxResults(5)
            .dynamicFilter(query ->
                new IsEqualTo("source", "documentation")
            )
            .build();
    }
}
```

## 3. Document Loading and Splitting

**Scenario**: Load documents from various sources and split intelligently.

```java
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import dev.langchain4j.data.document.splitter.DocumentByParagraphSplitter;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.openai.OpenAiTokenCountEstimator;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

public class DocumentProcessingExample {
    public static void main(String[] args) {
        // Load all documents from a directory
        Path docPath = Paths.get("documents");
        List<Document> documents = FileSystemDocumentLoader.loadDocuments(docPath);

        // Smart recursive splitting with token counting
        DocumentSplitter splitter = DocumentSplitters.recursive(
            500, // Max tokens per segment
            50,  // Overlap tokens
            new OpenAiTokenCountEstimator("gpt-4o-mini")
        );

        // Process documents
        for (Document doc : documents) {
            List<TextSegment> segments = splitter.split(doc);
            System.out.println("Document split into " + segments.size() + " segments");

            segments.forEach(segment -> {
                System.out.println("Text: " + segment.text());
                System.out.println("Metadata: " + segment.metadata());
            });
        }

        // Alternative: Character-based splitting (no token estimator)
        DocumentSplitter charSplitter = DocumentSplitters.recursive(
            1000, // Max characters
            100   // Overlap characters
        );

        // Alternative: Paragraph-based splitting via the dedicated splitter class
        DocumentSplitter paraSplitter = new DocumentByParagraphSplitter(500, 50);
    }
}
```

## 4. Metadata Filtering in RAG

**Scenario**: Search with complex metadata filters for multi-tenant RAG.

```java
import dev.langchain4j.store.embedding.filter.comparison.*;
import dev.langchain4j.store.embedding.filter.logical.*;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;

public class MetadataFilteringExample {
    public static void main(String[] args) {
        // Note: each filter(...) call below illustrates one option. On a real
        // builder, configure a single filter (combine conditions with And/Or);
        // later calls replace earlier ones.
        var retriever = EmbeddingStoreContentRetriever.builder()
            .embeddingStore(embeddingStore)
            .embeddingModel(embeddingModel)

            // Single filter: user isolation
            .filter(new IsEqualTo("userId", "user123"))

            // Complex AND filter
            .filter(new And(
                new IsEqualTo("department", "engineering"),
                new IsEqualTo("status", "active")
            ))

            // OR filter: multiple categories
            .filter(new Or(
                new IsEqualTo("category", "tutorial"),
                new IsEqualTo("category", "guide")
            ))

            // NOT filter: exclude deprecated
            .filter(new Not(
                new IsEqualTo("deprecated", "true")
            ))

            // Numeric filters
            .filter(new IsGreaterThan("relevance", 0.8))
            .filter(new IsLessThanOrEqualTo("createdDaysAgo", 30))

            // Multiple conditions, resolved per query
            .dynamicFilter(query -> {
                String userId = extractUserFromQuery(query);
                return new And(
                    new IsEqualTo("userId", userId),
                    new IsGreaterThan("score", 0.7)
                );
            })

            .build();
    }

    private static String extractUserFromQuery(Object query) {
        // Extract user context
        return "user123";
    }
}
```

## 5. Document Transformation Pipeline

**Scenario**: Transform documents with custom metadata before ingestion.

```java
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
import java.time.LocalDate;
import java.util.Map;

public class DocumentTransformationExample {
    public static void main(String[] args) {
        var ingestor = EmbeddingStoreIngestor.builder()

            // Add metadata to each document
            .documentTransformer(doc -> {
                doc.metadata().put("ingested_date", LocalDate.now().toString());
                doc.metadata().put("source_system", "internal");
                doc.metadata().put("version", "1.0");
                return doc;
            })

            // Split documents intelligently
            .documentSplitter(DocumentSplitters.recursive(500, 50))

            // Transform each segment (e.g., add filename)
            .textSegmentTransformer(segment -> {
                String fileName = segment.metadata().getString("file_name");
                if (fileName == null) {
                    fileName = "unknown";
                }
                String enrichedText = "File: " + fileName + "\n" + segment.text();
                return TextSegment.from(enrichedText, segment.metadata());
            })

            .embeddingModel(embeddingModel)
            .embeddingStore(embeddingStore)
            .build();

        // Sample document; the file_name entry feeds the segment transformer above
        Document document = Document.from("Example content",
            Metadata.from(Map.of("file_name", "example.txt")));

        // Ingest with tracking
        IngestionResult result = ingestor.ingest(document);
        System.out.println("Tokens ingested: " + result.tokenUsage().totalTokenCount());
    }
}
```

## 6. Hybrid Search (Vector + Full-Text)

**Scenario**: Combine semantic search with keyword search for better recall.

```java
import dev.langchain4j.store.embedding.neo4j.Neo4jEmbeddingStore;

public class HybridSearchExample {
    public static void main(String[] args) {
        // Configure Neo4j for hybrid search
        var embeddingStore = Neo4jEmbeddingStore.builder()
            .withBasicAuth("bolt://localhost:7687", "neo4j", "password")
            .dimension(1536)

            // Enable full-text search
            .fullTextIndexName("documents_fulltext")
            .autoCreateFullText(true)

            // Query for full-text context
            .fullTextQuery("Spring OR Boot")

            .build();

        var retriever = EmbeddingStoreContentRetriever.builder()
            .embeddingStore(embeddingStore)
            .embeddingModel(embeddingModel)
            .maxResults(5)
            .build();

        // Search combines both vector similarity and full-text keywords
    }
}
```

## 7. Advanced RAG with Query Transformation

**Scenario**: Transform user queries before retrieval for better results.

```java
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.query.transformer.CompressingQueryTransformer;
import dev.langchain4j.rag.content.aggregator.ReRankingContentAggregator;
import dev.langchain4j.model.cohere.CohereScoringModel;

public class AdvancedRagExample {
    public static void main(String[] args) {
        // Scoring model for re-ranking
        var scoringModel = CohereScoringModel.builder()
            .apiKey(System.getenv("COHERE_API_KEY"))
            .build();

        // Advanced retrieval augmentor
        var augmentor = DefaultRetrievalAugmentor.builder()

            // Transform query for better context
            .queryTransformer(new CompressingQueryTransformer(chatModel))

            // Retrieve relevant content
            .contentRetriever(EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(embeddingModel)
                .maxResults(10)
                .minScore(0.6)
                .build())

            // Re-rank results by relevance
            .contentAggregator(ReRankingContentAggregator.builder()
                .scoringModel(scoringModel)
                .minScore(0.8)
                .build())

            .build();

        // Use with an AI Service (QuestionAnswering is an application-defined interface)
        var assistant = AiServices.builder(QuestionAnswering.class)
            .chatModel(chatModel)
            .retrievalAugmentor(augmentor)
            .build();
    }
}
```

## 8. Multi-User RAG with Isolation

**Scenario**: Per-user vector stores for data isolation.

```java
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import java.util.HashMap;
import java.util.Map;

public class MultiUserRagExample {
    private final Map<String, EmbeddingStore<TextSegment>> userStores = new HashMap<>();

    // Shared models, assumed to be initialized elsewhere (e.g., injected)
    private EmbeddingModel embeddingModel;
    private ChatModel chatModel;

    public void ingestForUser(String userId, Document document) {
        var store = userStores.computeIfAbsent(userId,
            k -> new InMemoryEmbeddingStore<>());

        var ingestor = EmbeddingStoreIngestor.builder()
            .embeddingModel(embeddingModel)
            .embeddingStore(store)
            .build();

        ingestor.ingest(document);
    }

    public String askQuestion(String userId, String question) {
        var store = userStores.get(userId);
        if (store == null) {
            return "No documents have been ingested for user " + userId;
        }

        var retriever = EmbeddingStoreContentRetriever.builder()
            .embeddingStore(store)
            .embeddingModel(embeddingModel)
            .maxResults(3)
            .build();

        var assistant = AiServices.builder(QuestionAnswering.class)
            .chatModel(chatModel)
            .contentRetriever(retriever)
            .build();

        return assistant.answer(question);
    }
}
```

## 9. Streaming RAG with Content Access

**Scenario**: Stream RAG responses while accessing retrieved content.

```java
import dev.langchain4j.service.TokenStream;

interface StreamingRagAssistant {
    TokenStream streamAnswer(String question);
}

public class StreamingRagExample {
    public static void main(String[] args) {
        var assistant = AiServices.builder(StreamingRagAssistant.class)
            .streamingChatModel(streamingModel)
            .contentRetriever(contentRetriever)
            .build();

        assistant.streamAnswer("What is Spring Boot?")
            .onRetrieved(contents -> {
                System.out.println("=== Retrieved Content ===");
                contents.forEach(content ->
                    System.out.println("Score: " + content.score() +
                        ", Text: " + content.textSegment().text()));
            })
            // onPartialResponse is the 1.x name for per-token callbacks (formerly onNext)
            .onPartialResponse(token -> System.out.print(token))
            .onCompleteResponse(response ->
                System.out.println("\n=== Complete ==="))
            .onError(error -> System.err.println("Error: " + error))
            .start();

        try {
            Thread.sleep(5000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
```

## 10. Batch Document Ingestion

**Scenario**: Efficiently ingest large document collections.

```java
import dev.langchain4j.data.document.Document;
import java.util.List;
import java.util.ArrayList;

public class BatchIngestionExample {
    public static void main(String[] args) {
        var ingestor = EmbeddingStoreIngestor.builder()
            .embeddingModel(embeddingModel)
            .embeddingStore(embeddingStore)
            .documentSplitter(DocumentSplitters.recursive(500, 50))
            .build();

        // Load batch of documents
        List<Document> documents = new ArrayList<>();
        for (int i = 1; i <= 100; i++) {
            documents.add(Document.from("Content " + i));
        }

        // Ingest all at once
        IngestionResult result = ingestor.ingest(documents);

        System.out.println("Documents ingested: " + documents.size());
        System.out.println("Total tokens: " + result.tokenUsage().totalTokenCount());

        // Track progress
        long tokensPerDoc = result.tokenUsage().totalTokenCount() / documents.size();
        System.out.println("Average tokens per document: " + tokensPerDoc);
    }
}
```

## Performance Considerations

1. **Batch Processing**: Ingest documents in batches to optimize embedding API calls
2. **Document Splitting**: Use recursive splitting for better semantic chunks
3. **Metadata**: Add minimal metadata to reduce embedding overhead
4. **Vector DB**: Choose appropriate vector DB based on scale (in-memory for dev, Pinecone/Weaviate for prod)
5. **Similarity Threshold**: Adjust minScore based on use case (0.7-0.85 typical)
6. **Max Results**: Return top 3-5 results unless specific needs require more
7. **Caching**: Cache frequently retrieved content to reduce API calls
8. **Async Ingestion**: Use async ingestion for large datasets (see the sketch below)
9. **Monitoring**: Track token usage and retrieval quality metrics
10. **Testing**: Use in-memory store for unit tests, external DB for integration tests
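
A minimal sketch of asynchronous batch ingestion using `CompletableFuture`; the batch size and thread-pool sizing are illustrative:

```java
// Split the corpus into batches and ingest them off the calling thread
ExecutorService executor = Executors.newFixedThreadPool(4);

List<CompletableFuture<IngestionResult>> futures = new ArrayList<>();
int batchSize = 25; // illustrative
for (int i = 0; i < documents.size(); i += batchSize) {
    List<Document> batch = documents.subList(i, Math.min(i + batchSize, documents.size()));
    futures.add(CompletableFuture.supplyAsync(() -> ingestor.ingest(batch), executor));
}

// Wait for all batches to finish, then release the pool
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
executor.shutdown();
```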

@@ -0,0 +1,506 @@

# LangChain4j RAG Implementation - API References

Complete API reference for implementing RAG systems with LangChain4j.

## Document Loading

### Document Loaders

**FileSystemDocumentLoader**: Load from filesystem.
```java
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import java.nio.file.Path;

List<Document> documents = FileSystemDocumentLoader.loadDocuments("documents");
Document single = FileSystemDocumentLoader.loadDocument("document.pdf");
```

**ClassPathDocumentLoader**: Load from classpath resources.
```java
List<Document> resources = ClassPathDocumentLoader.loadDocuments("documents");
```

**UrlDocumentLoader**: Load from web URLs.
```java
Document webDoc = UrlDocumentLoader.load("https://example.com/doc.html");
```

## Document Splitting

### DocumentSplitter Interface

```java
interface DocumentSplitter {
    List<TextSegment> split(Document document);
    List<TextSegment> splitAll(Collection<Document> documents);
}
```

### DocumentSplitters Factory

**Recursive Split**: Smart recursive splitting by paragraphs, sentences, words.
```java
DocumentSplitter splitter = DocumentSplitters.recursive(
    500, // Max segment size (tokens or characters)
    50   // Overlap size
);

// With token counting
DocumentSplitter splitter = DocumentSplitters.recursive(
    500,
    50,
    new OpenAiTokenCountEstimator("gpt-4o-mini")
);
```

The paragraph, sentence, and line variants are standalone splitter classes rather than `DocumentSplitters` factory methods:

**Paragraph Split**: Split by paragraphs.
```java
DocumentSplitter splitter = new DocumentByParagraphSplitter(500, 50);
```

**Sentence Split**: Split by sentences.
```java
DocumentSplitter splitter = new DocumentBySentenceSplitter(500, 50);
```

**Line Split**: Split by lines.
```java
DocumentSplitter splitter = new DocumentByLineSplitter(500, 50);
```

## Embedding Models

### EmbeddingModel Interface

```java
public interface EmbeddingModel {
    // Embed single text
    Response<Embedding> embed(String text);
    Response<Embedding> embed(TextSegment textSegment);

    // Batch embedding
    Response<List<Embedding>> embedAll(List<TextSegment> textSegments);

    // Model dimension
    int dimension();
}
```

### OpenAI Embedding Model

```java
EmbeddingModel model = OpenAiEmbeddingModel.builder()
    .apiKey(System.getenv("OPENAI_API_KEY"))
    .modelName("text-embedding-3-small") // or text-embedding-3-large
    .dimensions(512) // Optional: reduce dimensions
    .timeout(Duration.ofSeconds(30))
    .logRequests(true)
    .logResponses(true)
    .build();
```

### Other Embedding Models

```java
// Google Vertex AI
EmbeddingModel google = VertexAiEmbeddingModel.builder()
    .project("PROJECT_ID")
    .location("us-central1")
    .modelName("textembedding-gecko")
    .build();

// Ollama (local)
EmbeddingModel ollama = OllamaEmbeddingModel.builder()
    .baseUrl("http://localhost:11434")
    .modelName("all-minilm")
    .build();

// AllMiniLmL6V2 (offline)
EmbeddingModel offline = new AllMiniLmL6V2EmbeddingModel();
```

## Vector Stores (EmbeddingStore)

### EmbeddingStore Interface

```java
public interface EmbeddingStore<Embedded> {
    // Add embeddings
    String add(Embedding embedding);
    String add(String id, Embedding embedding);
    String add(Embedding embedding, Embedded embedded);
    List<String> addAll(List<Embedding> embeddings);
    List<String> addAll(List<Embedding> embeddings, List<Embedded> embeddeds);
    List<String> addAll(List<String> ids, List<Embedding> embeddings, List<Embedded> embeddeds);

    // Search embeddings
    EmbeddingSearchResult<Embedded> search(EmbeddingSearchRequest request);

    // Remove embeddings
    void remove(String id);
    void removeAll(Collection<String> ids);
    void removeAll(Filter filter);
    void removeAll();
}
```

### In-Memory Store

```java
EmbeddingStore<TextSegment> store = new InMemoryEmbeddingStore<>();

// Merge stores
InMemoryEmbeddingStore<TextSegment> merged = InMemoryEmbeddingStore.merge(
    store1, store2, store3
);
```

### Pinecone

```java
EmbeddingStore<TextSegment> store = PineconeEmbeddingStore.builder()
    .apiKey(System.getenv("PINECONE_API_KEY"))
    .index("my-index")
    .namespace("production")
    .environment("gcp-starter") // or "aws-us-east-1"
    .build();
```

### Weaviate

```java
EmbeddingStore<TextSegment> store = WeaviateEmbeddingStore.builder()
    .host("localhost")
    .port(8080)
    .scheme("http")
    .collectionName("Documents")
    .build();
```

### Qdrant

```java
EmbeddingStore<TextSegment> store = QdrantEmbeddingStore.builder()
    .host("localhost")
    .port(6333)
    .collectionName("documents")
    .build();
```

### Chroma

```java
EmbeddingStore<TextSegment> store = ChromaEmbeddingStore.builder()
    .baseUrl("http://localhost:8000")
    .collectionName("my-collection")
    .build();
```

### Neo4j

```java
EmbeddingStore<TextSegment> store = Neo4jEmbeddingStore.builder()
    .withBasicAuth("bolt://localhost:7687", "neo4j", "password")
    .dimension(1536)
    .label("Document")
    .build();
```

### MongoDB Atlas

```java
EmbeddingStore<TextSegment> store = MongoDbEmbeddingStore.builder()
    .databaseName("search")
    .collectionName("documents")
    .indexName("vector_index")
    .createIndex(true)
    .fromClient(mongoClient)
    .build();
```

### PostgreSQL (pgvector)

```java
EmbeddingStore<TextSegment> store = PgVectorEmbeddingStore.builder()
    .host("localhost")
    .port(5432)
    .database("embeddings")
    .user("postgres")
    .password("password")
    .table("embeddings")
    .createTableIfNotExists(true)
    .build();
```

### Milvus

```java
EmbeddingStore<TextSegment> store = MilvusEmbeddingStore.builder()
    .host("localhost")
    .port(19530)
    .collectionName("documents")
    .dimension(1536)
    .build();
```

## Document Ingestion

### EmbeddingStoreIngestor

```java
public class EmbeddingStoreIngestor {
    public static Builder builder();

    public IngestionResult ingest(Document document);
    public IngestionResult ingest(Document... documents);
    public IngestionResult ingest(Collection<Document> documents);
}
```

### Building an Ingestor

```java
EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()

    // Document transformation
    .documentTransformer(doc -> {
        doc.metadata().put("source", "manual");
        return doc;
    })

    // Document splitting strategy
    .documentSplitter(DocumentSplitters.recursive(500, 50))

    // Text segment transformation
    .textSegmentTransformer(segment -> {
        String enhanced = "Category: Spring\n" + segment.text();
        return TextSegment.from(enhanced, segment.metadata());
    })

    // Embedding model (required)
    .embeddingModel(embeddingModel)

    // Embedding store (required)
    .embeddingStore(embeddingStore)

    .build();
```

### IngestionResult

```java
IngestionResult result = ingestor.ingest(documents);

// Access results
TokenUsage usage = result.tokenUsage();
long totalTokens = usage.totalTokenCount();
long inputTokens = usage.inputTokenCount();
```

## Content Retrieval

### EmbeddingSearchRequest

```java
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
    .queryEmbedding(embedding) // Required
    .maxResults(5)             // Default: 3
    .minScore(0.7)             // Threshold 0-1
    .filter(new IsEqualTo("category", "tutorial"))
    .build();
```

### EmbeddingSearchResult

```java
EmbeddingSearchResult<TextSegment> result = store.search(request);
List<EmbeddingMatch<TextSegment>> matches = result.matches();

for (EmbeddingMatch<TextSegment> match : matches) {
    double score = match.score();           // Relevance 0-1
    TextSegment segment = match.embedded(); // Retrieved content
    String id = match.embeddingId();        // Store ID
}
```

### ContentRetriever Interface

```java
public interface ContentRetriever {
    List<Content> retrieve(Query query);
}
```

### EmbeddingStoreContentRetriever

```java
ContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
    .embeddingStore(embeddingStore)
    .embeddingModel(embeddingModel)

    // Static configuration
    .maxResults(5)
    .minScore(0.7)

    // Dynamic configuration per query
    .dynamicMaxResults(query -> 10)
    .dynamicMinScore(query -> 0.8)
    .dynamicFilter(query ->
        new IsEqualTo("userId", extractUserId(query))
    )

    .build();
```

## Advanced RAG

### RetrievalAugmentor

```java
public interface RetrievalAugmentor {
    AugmentationResult augment(AugmentationRequest augmentationRequest);
}
```

### DefaultRetrievalAugmentor

```java
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()

    // Query transformation
    .queryTransformer(new CompressingQueryTransformer(chatModel))

    // Content retrieval
    .contentRetriever(contentRetriever)

    // Content aggregation and re-ranking
    .contentAggregator(ReRankingContentAggregator.builder()
        .scoringModel(scoringModel)
        .minScore(0.8)
        .build())

    // Parallelization
    .executor(customExecutor)

    .build();
```

### Use with AI Services

```java
Assistant assistant = AiServices.builder(Assistant.class)
    .chatModel(chatModel)
    .retrievalAugmentor(augmentor)
    .build();
```

## Metadata and Filtering

### Metadata Object

```java
// Create from map
Metadata meta = Metadata.from(Map.of(
    "userId", "user123",
    "category", "tutorial",
    "score", 0.95
));

// Add entries
meta.put("status", "active");
meta.put("version", 2);

// Retrieve entries
String userId = meta.getString("userId");
int version = meta.getInteger("version");
double score = meta.getDouble("score");

// Check existence
boolean has = meta.containsKey("userId");

// Remove entry
meta.remove("userId");

// Merge
Metadata other = Metadata.from(Map.of("source", "db"));
meta.merge(other);
```

### Filter Operations

```java
import java.util.Arrays;

import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.*;
import dev.langchain4j.store.embedding.filter.logical.*;

// Equality
Filter isActive = new IsEqualTo("status", "active");
Filter notDeprecated = new IsNotEqualTo("deprecated", "true");

// Comparison
Filter highScore = new IsGreaterThan("score", 0.8);
Filter recent = new IsLessThanOrEqualTo("daysOld", 30);
Filter highPriority = new IsGreaterThanOrEqualTo("priority", 5);
Filter lowErrorRate = new IsLessThan("errorRate", 0.01);

// Membership
Filter inCategory = new IsIn("category", Arrays.asList("tech", "guide"));
Filter notArchived = new IsNotIn("status", Arrays.asList("archived"));

// String operations
Filter mentionsSpring = new ContainsString("content", "Spring");

// Logical operations
Filter userWithHighScore = new And(
        new IsEqualTo("userId", "123"),
        new IsGreaterThan("score", 0.7)
);

Filter docOrGuide = new Or(
        new IsEqualTo("type", "doc"),
        new IsEqualTo("type", "guide")
);

Filter current = new Not(new IsEqualTo("archived", "true"));
```
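
A sketch of applying a filter at search time, assuming a query embedding has already been computed and the store supports metadata filtering; the same filters also plug into `EmbeddingStoreContentRetriever` via `.filter(...)`:

```java
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
        .queryEmbedding(queryEmbedding)
        .filter(new And(
                new IsEqualTo("userId", "123"),
                new IsGreaterThan("score", 0.7)))
        .maxResults(5)
        .build();

EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
```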

## TextSegment

### Creating TextSegments

```java
// Text only
TextSegment segment = TextSegment.from("This is the content");

// With metadata
Metadata metadata = Metadata.from(Map.of("source", "docs"));
TextSegment withMetadata = TextSegment.from("Content", metadata);

// Accessing
String text = segment.text();
Metadata meta = withMetadata.metadata();
```

## Best Practices

1. **Chunk Size**: Use 300-500 tokens per chunk for optimal balance
2. **Overlap**: Use 10-50 token overlap for semantic continuity
3. **Metadata**: Include source and timestamp for traceability
4. **Batch Processing**: Ingest documents in batches when possible
5. **Similarity Threshold**: Adjust minScore (0.7-0.85) based on precision/recall needs
6. **Vector DB Selection**: In-memory for dev/test, Pinecone/Qdrant for production
7. **Filtering**: Pre-filter by metadata to reduce the search space
8. **Re-ranking**: Use scoring models for better relevance in production
9. **Monitoring**: Track retrieval quality metrics
10. **Testing**: Use small in-memory stores for unit tests
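
A minimal ingestion sketch tying the first four practices together; `embeddingModel` and `embeddingStore` are assumed to be configured elsewhere, the documents path is hypothetical, and the recursive splitter sizes are character-based unless a tokenizer is supplied:

```java
List<Document> documents = FileSystemDocumentLoader.loadDocuments("/path/to/docs");

EmbeddingStoreIngestor.builder()
        .documentSplitter(DocumentSplitters.recursive(400, 40)) // segment size, overlap
        .embeddingModel(embeddingModel)
        .embeddingStore(embeddingStore)
        .build()
        .ingest(documents); // batch ingestion
```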

## Performance Tips

- Use recursive splitting for semantic coherence
- Enable batch processing for large datasets
- Use dynamic max results based on query complexity
- Cache embeddings for frequently accessed content
- Implement async ingestion for large document collections
- Monitor token usage for cost optimization
- Use appropriate vector DB indexes at scale

130
skills/langchain4j/langchain4j-spring-boot-integration/SKILL.md
Normal file
@@ -0,0 +1,130 @@
---
name: langchain4j-spring-boot-integration
description: Integration patterns for LangChain4j with Spring Boot. Auto-configuration, dependency injection, and Spring ecosystem integration. Use when embedding LangChain4j into Spring Boot applications.
category: ai-development
tags: [langchain4j, spring-boot, ai, llm, rag, chatbot, integration, configuration, java]
version: 1.1.0
allowed-tools: Read, Write, Bash, Grep
---

# LangChain4j Spring Boot Integration

Integrate LangChain4j into Spring Boot applications with auto-configuration, declarative AI Services, chat models, embedding stores, and production-ready patterns for building AI-powered applications.

## When to Use

Use this skill when:
- Integrating LangChain4j into existing Spring Boot applications
- Building AI-powered microservices with Spring Boot
- Setting up auto-configuration for AI models and services
- Creating declarative AI Services with Spring dependency injection
- Configuring multiple AI providers (OpenAI, Azure, Ollama, etc.)
- Implementing RAG systems with Spring Boot
- Setting up observability and monitoring for AI components
- Building production-ready AI applications with Spring Boot

## Overview

LangChain4j Spring Boot integration provides declarative AI Services through Spring Boot starters, enabling automatic configuration of AI components based on properties. The integration combines Spring dependency injection with LangChain4j's AI capabilities, letting developers build AI-powered applications from annotated, interface-based definitions.

## Core Concepts

Basic setup of LangChain4j with Spring Boot:

**Add Dependencies:**
```xml
<!-- Core LangChain4j -->
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-spring-boot-starter</artifactId>
    <version>1.8.0</version> <!-- use the latest version -->
</dependency>

<!-- OpenAI Integration -->
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-open-ai-spring-boot-starter</artifactId>
    <version>1.8.0</version>
</dependency>
```

**Configure Properties:**
```properties
# application.properties
langchain4j.open-ai.chat-model.api-key=${OPENAI_API_KEY}
langchain4j.open-ai.chat-model.model-name=gpt-4o-mini
langchain4j.open-ai.chat-model.temperature=0.7
```

**Create Declarative AI Service:**
```java
@AiService
interface CustomerSupportAssistant {

    @SystemMessage("You are a helpful customer support agent for TechCorp.")
    String handleInquiry(String customerMessage);
}
```
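
The generated implementation is registered as a Spring bean and can be injected like any other component; a usage sketch (the controller and endpoint are illustrative):

```java
@RestController
class SupportController {

    private final CustomerSupportAssistant assistant;

    SupportController(CustomerSupportAssistant assistant) {
        this.assistant = assistant;
    }

    @PostMapping("/support")
    String chat(@RequestBody String message) {
        return assistant.handleInquiry(message);
    }
}
```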

## Configuration

Configure LangChain4j in Spring Boot through:

**Property-Based Configuration:** Configure AI models through application properties for different providers.

**Manual Bean Configuration:** For advanced configurations, define beans manually using @Configuration (see the sketch below).

**Multiple Providers:** Support multiple AI providers with explicit wiring when needed.
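
A minimal manual-configuration sketch, assuming the OpenAI module is on the classpath; the model name and settings are illustrative:

```java
@Configuration
class ChatModelConfig {

    @Bean
    ChatModel chatModel() {
        return OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY")) // keep keys out of source
                .modelName("gpt-4o-mini")
                .temperature(0.7)
                .build();
    }
}
```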

## Declarative AI Services

Define AI services as annotated interfaces:

**Basic AI Service:** Create interfaces with the @AiService annotation and define methods with message templates.

**Streaming AI Service:** Stream responses using Project Reactor (see the sketch below).

**Explicit Wiring:** Specify which model to use with @AiService(wiringMode = EXPLICIT, chatModel = "modelBeanName").
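
A streaming sketch, assuming the langchain4j-reactor module is on the classpath so service methods can return a Flux:

```java
@AiService
interface StreamingAssistant {

    @SystemMessage("You are a concise assistant.")
    Flux<String> chat(String userMessage);
}
```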

## RAG Implementation

Build RAG systems on Spring Boot with:

**Embedding Stores:** Configure various embedding stores (PostgreSQL/pgvector, Neo4j, Pinecone, etc.).

**Document Ingestion:** Implement document processing and embedding generation.

**Content Retrieval:** Set up content retrieval mechanisms for knowledge augmentation (see the sketch below).
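
A minimal retrieval sketch, assuming EmbeddingModel and EmbeddingStore<TextSegment> beans are already configured (for example by a store starter); an @AiService can then pick up the ContentRetriever bean automatically:

```java
@Configuration
class RagConfig {

    @Bean
    ContentRetriever contentRetriever(EmbeddingModel embeddingModel,
                                      EmbeddingStore<TextSegment> embeddingStore) {
        return EmbeddingStoreContentRetriever.builder()
                .embeddingModel(embeddingModel)
                .embeddingStore(embeddingStore)
                .maxResults(5)
                .minScore(0.7)
                .build();
    }
}
```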

## Tool Integration

Expose application logic to the model as tools:

**Spring Component Tools:** Define tools as Spring components with @Tool annotations (see the sketch below).

**Database Access Tools:** Create tools for database operations and business logic.

**Tool Registration:** Tools are registered automatically with AI services.
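
A tool sketch; BookingService and its method are hypothetical stand-ins for your own business logic:

```java
@Component
class BookingTools {

    private final BookingService bookingService; // hypothetical business service

    BookingTools(BookingService bookingService) {
        this.bookingService = bookingService;
    }

    @Tool("Returns the status of the booking with the given id")
    String bookingStatus(String bookingId) {
        return bookingService.statusOf(bookingId); // hypothetical method
    }
}
```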

## Examples

For implementation patterns, refer to the comprehensive examples in [references/examples.md](references/examples.md).

## Best Practices

For production-ready AI applications:

- **Use Property-Based Configuration:** Prefer external configuration over hardcoded values
- **Implement Proper Error Handling:** Graceful degradation and meaningful error responses
- **Use Profiles for Different Environments:** Separate configurations for development, testing, and production
- **Implement Proper Logging:** Debug AI service calls and monitor performance
- **Secure API Keys:** Use environment variables and never commit keys to version control
- **Handle Failures:** Implement retry mechanisms and fallback strategies
- **Monitor Performance:** Add metrics and health checks for observability

## References

For detailed API references, advanced configurations, and additional patterns, refer to:

- [API Reference](references/references.md) - Complete API reference and configurations
- [Examples](references/examples.md) - Comprehensive implementation examples
- [Configuration Guide](references/configuration.md) - Deep dive into configuration options