# Model Serving Patterns Skill
## When to Use This Skill
Use this skill when:
- Deploying ML models to production environments
- Building model serving APIs for real-time inference
- Optimizing model serving for throughput and latency
- Containerizing models for consistent deployment
- Implementing request batching for efficiency
- Choosing between serving frameworks and protocols
**When NOT to use:** Notebook prototyping, training jobs, or single-prediction scripts where serving infrastructure is premature.
## Core Principle
**Serving infrastructure is not one-size-fits-all. Pattern selection is context-dependent.**
Without proper serving infrastructure:
- model.pkl in repo (manual dependency hell)
- Wrong protocol choice (gRPC for simple REST use cases)
- No batching (1 req/sec instead of 100 req/sec)
- Not containerized (works on my machine syndrome)
- Static batching when dynamic needed (underutilized GPU)
**Formula:** Right framework (FastAPI vs TorchServe vs gRPC vs ONNX) + Request batching (dynamic > static) + Containerization (Docker + model) + Clear selection criteria = Production-ready serving.
## Serving Framework Decision Tree
```
┌────────────────────────────────────────┐
│   What's your primary requirement?     │
└──────────────────┬─────────────────────┘
           ┌───────┴────────┐
           ▼                ▼
      Flexibility    Batteries Included
           │                │
           ▼                ▼
        FastAPI         TorchServe
        (Custom)        (PyTorch)
           │                │
           │        ┌───────┴────────┐
           │        ▼                ▼
           │   Low Latency    Cross-Framework
           │        │                │
           │        ▼                ▼
           │      gRPC         ONNX Runtime
           │        │                │
           └────────┴───────┬────────┘
                            ▼
              ┌───────────────────────┐
              │ Add Request Batching  │
              │   Dynamic > Static    │
              └───────────┬───────────┘
                          ▼
              ┌───────────────────────┐
              │  Containerize with    │
              │ Docker + Dependencies │
              └───────────────────────┘
```
## Part 1: FastAPI for Custom Serving
**When to use:** Need flexibility, custom preprocessing, or non-standard workflows.
**Advantages:** Full control, easy debugging, Python ecosystem integration.
**Disadvantages:** Manual optimization, no built-in model management.
### Basic FastAPI Serving
```python
# serve_fastapi.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import torch
import numpy as np
from typing import List, Optional
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="Model Serving API", version="1.0.0")
class PredictionRequest(BaseModel):
"""Request schema with validation."""
inputs: List[List[float]] = Field(..., description="Input features as 2D array")
return_probabilities: bool = Field(False, description="Return class probabilities")
class Config:
schema_extra = {
"example": {
"inputs": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
"return_probabilities": True
}
}
class PredictionResponse(BaseModel):
"""Response schema."""
predictions: List[int]
probabilities: Optional[List[List[float]]] = None
latency_ms: float
class ModelServer:
"""
Model server with lazy loading and caching.
WHY: Load model once at startup, reuse across requests.
WHY: Avoids 5-10 second model loading per request.
"""
def __init__(self, model_path: str):
self.model_path = model_path
self.model = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model(self):
"""Load model on first request (lazy loading)."""
if self.model is None:
logger.info(f"Loading model from {self.model_path}...")
self.model = torch.load(self.model_path, map_location=self.device)
self.model.eval() # WHY: Disable dropout, batchnorm for inference
logger.info("Model loaded successfully")
def predict(self, inputs: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
Run inference.
Args:
inputs: Input array (batch_size, features)
Returns:
(predictions, probabilities)
"""
self.load_model()
# Convert to tensor
x = torch.tensor(inputs, dtype=torch.float32).to(self.device)
# WHY: torch.no_grad() disables gradient computation for inference
# WHY: Reduces memory usage by 50% and speeds up by 2×
with torch.no_grad():
logits = self.model(x)
probabilities = torch.softmax(logits, dim=1)
predictions = torch.argmax(probabilities, dim=1)
return predictions.cpu().numpy(), probabilities.cpu().numpy()
# Global model server instance
model_server = ModelServer(model_path="model.pth")
@app.on_event("startup")
async def startup_event():
"""Load model at startup for faster first request."""
model_server.load_model()
logger.info("Server startup complete")
@app.get("/health")
async def health_check():
"""Health check endpoint for load balancers."""
return {
"status": "healthy",
"model_loaded": model_server.model is not None,
"device": str(model_server.device)
}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
"""
Prediction endpoint with validation and error handling.
WHY: Pydantic validates inputs automatically.
WHY: Returns 422 for invalid inputs, not 500.
"""
import time
start_time = time.time()
try:
inputs = np.array(request.inputs)
# Validate shape
if inputs.ndim != 2:
raise HTTPException(
status_code=400,
detail=f"Expected 2D array, got {inputs.ndim}D"
)
predictions, probabilities = model_server.predict(inputs)
latency_ms = (time.time() - start_time) * 1000
response = PredictionResponse(
predictions=predictions.tolist(),
probabilities=probabilities.tolist() if request.return_probabilities else None,
latency_ms=latency_ms
)
logger.info(f"Predicted {len(predictions)} samples in {latency_ms:.2f}ms")
return response
    except HTTPException:
        # WHY: Re-raise client errors (400/422) untouched instead of converting them to 500
        raise
    except Exception as e:
        logger.error(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Run with: uvicorn serve_fastapi:app --host 0.0.0.0 --port 8000 --workers 4
```
**Performance characteristics:**
| Metric | Value | Notes |
|--------|-------|-------|
| Cold start | 5-10s | Model loading time |
| Warm latency | 10-50ms | Per request |
| Throughput | 100-500 req/sec | Single worker |
| Memory | 2-8GB | Model + runtime |
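A minimal client sketch for the endpoint above (assumes the server is running locally on port 8000 and that `requests` is installed):

```python
# client_fastapi.py
import requests

payload = {
    "inputs": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    "return_probabilities": True,
}

# POST to the /predict endpoint defined in serve_fastapi.py
response = requests.post("http://localhost:8000/predict", json=payload, timeout=10)
response.raise_for_status()

result = response.json()
print(result["predictions"])   # class indices, one per input row
print(result["latency_ms"])    # server-side inference latency
```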
### Advanced: Async FastAPI with Background Tasks
```python
# serve_fastapi_async.py
from fastapi import FastAPI
from asyncio import Queue, create_task, sleep
import asyncio
from typing import Dict, List
import uuid

import numpy as np

# Reuse the model server and request schema from the basic example above
from serve_fastapi import ModelServer, PredictionRequest, model_server
app = FastAPI()
class AsyncBatchPredictor:
"""
Async batch predictor with request queuing.
WHY: Collect multiple requests, predict as batch.
WHY: GPU utilization: 20% (1 req) → 80% (batch of 32).
"""
def __init__(self, model_server: ModelServer, batch_size: int = 32, wait_ms: int = 10):
self.model_server = model_server
self.batch_size = batch_size
self.wait_ms = wait_ms
self.queue: Queue = Queue()
self.pending_requests: Dict[str, asyncio.Future] = {}
async def start(self):
"""Start background batch processing loop."""
create_task(self._batch_processing_loop())
async def _batch_processing_loop(self):
"""
Continuously collect and process batches.
WHY: Wait for batch_size OR timeout, then process.
WHY: Balances throughput (large batch) and latency (timeout).
"""
while True:
batch_requests = []
batch_ids = []
# Collect batch
deadline = asyncio.get_event_loop().time() + (self.wait_ms / 1000)
while len(batch_requests) < self.batch_size:
timeout = max(0, deadline - asyncio.get_event_loop().time())
try:
request_id, inputs = await asyncio.wait_for(
self.queue.get(),
timeout=timeout
)
batch_requests.append(inputs)
batch_ids.append(request_id)
except asyncio.TimeoutError:
break # Timeout reached, process what we have
if not batch_requests:
await sleep(0.001) # Brief sleep before next iteration
continue
# Process batch
batch_array = np.array(batch_requests)
predictions, probabilities = self.model_server.predict(batch_array)
# Return results to waiting requests
for i, request_id in enumerate(batch_ids):
future = self.pending_requests.pop(request_id)
future.set_result((predictions[i], probabilities[i]))
async def predict_async(self, inputs: List[float]) -> tuple[int, np.ndarray]:
"""
Add request to queue and await result.
WHY: Returns immediately if batch ready, waits if not.
WHY: Client doesn't know about batching (transparent).
"""
request_id = str(uuid.uuid4())
future = asyncio.Future()
self.pending_requests[request_id] = future
await self.queue.put((request_id, inputs))
# Wait for batch processing to complete
prediction, probability = await future
return prediction, probability
# Global async predictor
async_predictor = None
@app.on_event("startup")
async def startup():
global async_predictor
model_server.load_model()
async_predictor = AsyncBatchPredictor(model_server, batch_size=32, wait_ms=10)
await async_predictor.start()
@app.post("/predict_async")
async def predict_async(request: PredictionRequest):
"""
Async prediction with automatic batching.
WHY: 10× better GPU utilization than synchronous.
WHY: Same latency, much higher throughput.
"""
# Single input for simplicity (extend for batch)
inputs = request.inputs[0]
prediction, probability = await async_predictor.predict_async(inputs)
return {
"prediction": int(prediction),
"probability": probability.tolist()
}
```
**Performance improvement:**
| Approach | Throughput | GPU Utilization | Latency P95 |
|----------|-----------|-----------------|-------------|
| Synchronous | 100 req/sec | 20% | 15ms |
| Async batching | 1000 req/sec | 80% | 25ms |
| Improvement | **10×** | **4×** | +10ms |
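To see the batching effect end to end, here is a rough load-generation sketch (assumes the async endpoint above is running on localhost:8000; `httpx` is an extra dependency not used elsewhere in this skill):

```python
# load_test_async.py
import asyncio
import time

import httpx

async def fire_requests(n: int = 256) -> None:
    payload = {"inputs": [[1.0, 2.0, 3.0]]}
    async with httpx.AsyncClient(base_url="http://localhost:8000", timeout=30) as client:
        start = time.time()
        # Launch all requests concurrently so the server can form large batches
        responses = await asyncio.gather(
            *(client.post("/predict_async", json=payload) for _ in range(n))
        )
        elapsed = time.time() - start
        ok = sum(1 for r in responses if r.status_code == 200)
        print(f"{ok}/{n} succeeded in {elapsed:.2f}s ({n / elapsed:.0f} req/sec)")

if __name__ == "__main__":
    asyncio.run(fire_requests())
```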
## Part 2: TorchServe for PyTorch Models
**When to use:** PyTorch models, want batteries-included solution with monitoring, metrics, and model management.
**Advantages:** Built-in batching, model versioning, A/B testing, metrics.
**Disadvantages:** PyTorch-only, less flexibility, steeper learning curve.
### Creating a TorchServe Handler
```python
# handler.py
import torch
import torch.nn.functional as F
from ts.torch_handler.base_handler import BaseHandler
import json
import logging
logger = logging.getLogger(__name__)
class CustomClassifierHandler(BaseHandler):
"""
Custom TorchServe handler with preprocessing and batching.
WHY: TorchServe provides: model versioning, A/B testing, metrics, monitoring.
WHY: Built-in dynamic batching (no custom code needed).
"""
def initialize(self, context):
"""
Initialize handler (called once at startup).
Args:
context: TorchServe context with model artifacts
"""
self.manifest = context.manifest
properties = context.system_properties
# Set device
self.device = torch.device(
"cuda:" + str(properties.get("gpu_id"))
if torch.cuda.is_available()
else "cpu"
)
# Load model
model_dir = properties.get("model_dir")
serialized_file = self.manifest["model"]["serializedFile"]
model_path = f"{model_dir}/{serialized_file}"
self.model = torch.jit.load(model_path, map_location=self.device)
self.model.eval()
logger.info(f"Model loaded successfully on {self.device}")
# WHY: Initialize preprocessing parameters
self.mean = torch.tensor([0.485, 0.456, 0.406]).to(self.device)
self.std = torch.tensor([0.229, 0.224, 0.225]).to(self.device)
self.initialized = True
def preprocess(self, data):
"""
Preprocess input data.
Args:
data: List of input requests
Returns:
Preprocessed tensor batch
WHY: TorchServe batches requests automatically.
WHY: This method receives multiple requests at once.
"""
inputs = []
for row in data:
# Get input from request (JSON or binary)
input_data = row.get("data") or row.get("body")
            # Parse and convert
            if isinstance(input_data, (bytes, bytearray)):
                input_data = input_data.decode("utf-8")
            if isinstance(input_data, str):
                # WHY: json.loads, not eval -- never eval untrusted request payloads
                input_data = json.loads(input_data)
            # Convert to tensor
            tensor = torch.tensor(input_data, dtype=torch.float32)
# Normalize
tensor = (tensor - self.mean) / self.std
inputs.append(tensor)
# Stack into batch
batch = torch.stack(inputs).to(self.device)
return batch
def inference(self, batch):
"""
Run inference on batch.
Args:
batch: Preprocessed batch tensor
Returns:
Model output
WHY: torch.no_grad() for inference (faster, less memory).
"""
with torch.no_grad():
output = self.model(batch)
return output
def postprocess(self, inference_output):
"""
Postprocess inference output.
Args:
inference_output: Raw model output
Returns:
List of predictions (one per request in batch)
WHY: Convert tensors to JSON-serializable format.
WHY: Return predictions in same order as inputs.
"""
# Apply softmax
probabilities = F.softmax(inference_output, dim=1)
# Get predictions
predictions = torch.argmax(probabilities, dim=1)
# Convert to list (one entry per request)
results = []
for i in range(len(predictions)):
results.append({
"prediction": predictions[i].item(),
"probabilities": probabilities[i].tolist()
})
return results
```
### TorchServe Configuration
```yaml
# model_config.yaml
# WHY: Configuration controls batching, workers, timeouts
# WHY: Tune these for your latency/throughput requirements
minWorkers: 2 # WHY: Minimum workers (always ready)
maxWorkers: 4 # WHY: Maximum workers (scale up under load)
batchSize: 32 # WHY: Maximum batch size (GPU utilization)
maxBatchDelay: 10 # WHY: Max wait time for batch (ms)
# WHY: Trade-off: larger batch (better GPU util) vs latency
responseTimeout: 120 # WHY: Request timeout (seconds)
# WHY: Prevent hung requests
# Device assignment
deviceType: "gpu" # WHY: Use GPU if available
deviceIds: [0] # WHY: Specific GPU ID
# Metrics
metrics:
enable: true
prometheus: true # WHY: Export to Prometheus for monitoring
```
### Packaging and Serving
```bash
# Package model for TorchServe
# WHY: .mar file contains model + handler + config (portable)
torch-model-archiver \
--model-name classifier \
--version 1.0 \
--serialized-file model.pt \
--handler handler.py \
--extra-files "model_config.yaml" \
--export-path model_store/
# Start TorchServe
# WHY: Serves on 8080 (inference), 8081 (management), 8082 (metrics)
torchserve \
--start \
--ncs \
--model-store model_store \
--models classifier.mar \
--ts-config config.properties
# Register model (if not auto-loaded)
curl -X POST "http://localhost:8081/models?url=classifier.mar&batch_size=32&max_batch_delay=10"
# Make prediction
curl -X POST "http://localhost:8080/predictions/classifier" \
-H "Content-Type: application/json" \
-d '{"data": [[1.0, 2.0, 3.0]]}'
# Get metrics (for monitoring)
curl http://localhost:8082/metrics
# Unregister model (for updates)
curl -X DELETE "http://localhost:8081/models/classifier"
```
**TorchServe advantages:**
| Feature | Built-in? | Notes |
|---------|-----------|-------|
| Dynamic batching | ✓ | Automatic, configurable |
| Model versioning | ✓ | A/B testing support |
| Metrics/monitoring | ✓ | Prometheus integration |
| Multi-model serving | ✓ | Multiple models per server |
| GPU management | ✓ | Automatic device assignment |
| Custom preprocessing | ✓ | Via handler |
## Part 3: gRPC for Low-Latency Serving
**When to use:** Low latency critical (< 10ms), internal services, microservices architecture.
**Advantages:** 3-5× faster than REST, binary protocol, streaming support.
**Disadvantages:** More complex, requires proto definitions, harder debugging.
### Protocol Definition
```protobuf
// model_service.proto
syntax = "proto3";
package modelserving;
// WHY: Define service contract in .proto file
// WHY: Code generation for multiple languages (Python, Go, Java, etc.)
service ModelService {
// Unary RPC (one request, one response)
rpc Predict (PredictRequest) returns (PredictResponse);
// Server streaming (one request, stream responses)
rpc PredictStream (PredictRequest) returns (stream PredictResponse);
// Bidirectional streaming (stream requests and responses)
rpc PredictBidi (stream PredictRequest) returns (stream PredictResponse);
}
message PredictRequest {
// WHY: Repeated = array/list
repeated float features = 1; // WHY: Input features
bool return_probabilities = 2;
}
message PredictResponse {
int32 prediction = 1;
repeated float probabilities = 2;
float latency_ms = 3;
}
// Health check service (for load balancers)
service Health {
rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
}
message HealthCheckRequest {
string service = 1;
}
message HealthCheckResponse {
enum ServingStatus {
UNKNOWN = 0;
SERVING = 1;
NOT_SERVING = 2;
}
ServingStatus status = 1;
}
```
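The server below imports `model_service_pb2` / `model_service_pb2_grpc`, which are generated from the proto file. A small generation sketch using `grpcio-tools` (assumes `model_service.proto` sits in the current directory):

```python
# generate_stubs.py
from grpc_tools import protoc

# Generates model_service_pb2.py and model_service_pb2_grpc.py next to the proto
exit_code = protoc.main([
    "grpc_tools.protoc",    # argv[0] placeholder expected by protoc.main
    "-I.",                  # proto include path
    "--python_out=.",       # messages (model_service_pb2.py)
    "--grpc_python_out=.",  # service stubs (model_service_pb2_grpc.py)
    "model_service.proto",
])
assert exit_code == 0, "protoc failed; check the proto path and syntax"
```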
### gRPC Server Implementation
```python
# serve_grpc.py
import grpc
from concurrent import futures
import time
import logging
import torch
import numpy as np
# Generated from proto file (run: python -m grpc_tools.protoc ...)
import model_service_pb2
import model_service_pb2_grpc
logger = logging.getLogger(__name__)
class ModelServicer(model_service_pb2_grpc.ModelServiceServicer):
"""
gRPC service implementation.
WHY: gRPC is 3-5× faster than REST (binary protocol, HTTP/2).
WHY: Use for low-latency internal services (< 10ms target).
"""
def __init__(self, model_path: str):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = torch.load(model_path, map_location=self.device)
self.model.eval()
logger.info(f"Model loaded on {self.device}")
def Predict(self, request, context):
"""
Unary RPC prediction.
WHY: Fastest for single predictions.
WHY: 3-5ms latency vs 10-15ms for REST.
"""
start_time = time.time()
try:
# Convert proto repeated field to numpy
features = np.array(request.features, dtype=np.float32)
# Reshape for model
x = torch.tensor(features).unsqueeze(0).to(self.device)
# Inference
with torch.no_grad():
logits = self.model(x)
probs = torch.softmax(logits, dim=1)
pred = torch.argmax(probs, dim=1)
latency_ms = (time.time() - start_time) * 1000
# Build response
response = model_service_pb2.PredictResponse(
prediction=int(pred.item()),
latency_ms=latency_ms
)
# WHY: Only include probabilities if requested (reduce bandwidth)
if request.return_probabilities:
response.probabilities.extend(probs[0].cpu().tolist())
return response
except Exception as e:
logger.error(f"Prediction error: {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(e))
return model_service_pb2.PredictResponse()
def PredictStream(self, request, context):
"""
Server streaming RPC.
WHY: Send multiple predictions over one connection.
WHY: Lower overhead for batch processing.
"""
# Stream multiple predictions (example: time series)
for i in range(10): # Simulate 10 predictions
response = self.Predict(request, context)
yield response
time.sleep(0.01) # Simulate processing delay
def PredictBidi(self, request_iterator, context):
"""
Bidirectional streaming RPC.
WHY: Real-time inference (send request, get response immediately).
WHY: Lowest latency for streaming use cases.
"""
for request in request_iterator:
response = self.Predict(request, context)
yield response
class HealthServicer(model_service_pb2_grpc.HealthServicer):
"""Health check service for load balancers."""
def Check(self, request, context):
# WHY: Load balancers need health checks to route traffic
return model_service_pb2.HealthCheckResponse(
status=model_service_pb2.HealthCheckResponse.SERVING
)
def serve():
"""
Start gRPC server.
WHY: ThreadPoolExecutor for concurrent request handling.
WHY: max_workers controls concurrency (tune based on CPU cores).
"""
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=10),
options=[
# WHY: These options optimize for low latency
('grpc.max_send_message_length', 10 * 1024 * 1024), # 10MB
('grpc.max_receive_message_length', 10 * 1024 * 1024),
('grpc.so_reuseport', 1), # WHY: Allows multiple servers on same port
('grpc.use_local_subchannel_pool', 1) # WHY: Better connection reuse
]
)
# Add services
model_service_pb2_grpc.add_ModelServiceServicer_to_server(
ModelServicer("model.pth"), server
)
model_service_pb2_grpc.add_HealthServicer_to_server(
HealthServicer(), server
)
# Bind to port
server.add_insecure_port('[::]:50051')
server.start()
logger.info("gRPC server started on port 50051")
server.wait_for_termination()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
serve()
```
### gRPC Client
```python
# client_grpc.py
import grpc
import model_service_pb2
import model_service_pb2_grpc
import time
def benchmark_grpc_vs_rest():
"""
Benchmark gRPC vs REST latency.
WHY: gRPC is faster, but how much faster?
"""
# gRPC client
channel = grpc.insecure_channel('localhost:50051')
stub = model_service_pb2_grpc.ModelServiceStub(channel)
# Warm up
request = model_service_pb2.PredictRequest(
features=[1.0, 2.0, 3.0],
return_probabilities=True
)
for _ in range(10):
stub.Predict(request)
# Benchmark
iterations = 1000
start = time.time()
for _ in range(iterations):
response = stub.Predict(request)
grpc_latency = ((time.time() - start) / iterations) * 1000
print(f"gRPC average latency: {grpc_latency:.2f}ms")
# Compare with REST (FastAPI)
import requests
rest_url = "http://localhost:8000/predict"
# Warm up
for _ in range(10):
requests.post(rest_url, json={"inputs": [[1.0, 2.0, 3.0]]})
# Benchmark
start = time.time()
for _ in range(iterations):
requests.post(rest_url, json={"inputs": [[1.0, 2.0, 3.0]]})
rest_latency = ((time.time() - start) / iterations) * 1000
print(f"REST average latency: {rest_latency:.2f}ms")
print(f"gRPC is {rest_latency/grpc_latency:.1f}× faster")
# Typical results:
# gRPC: 3-5ms
# REST: 10-15ms
# gRPC is 3-5× faster
if __name__ == "__main__":
benchmark_grpc_vs_rest()
```
**gRPC vs REST comparison:**
| Metric | gRPC | REST | Advantage |
|--------|------|------|-----------|
| Latency | 3-5ms | 10-15ms | **gRPC 3× faster** |
| Throughput | 10k req/sec | 3k req/sec | **gRPC 3× higher** |
| Payload size | Binary (smaller) | JSON (larger) | gRPC 30-50% smaller |
| Debugging | Harder | Easier | REST |
| Browser support | No (requires proxy) | Yes | REST |
| Streaming | Native | Complex (SSE/WebSocket) | gRPC |
## Part 4: ONNX Runtime for Cross-Framework Serving
**When to use:** Need cross-framework support (PyTorch, TensorFlow, etc.), want maximum performance, or deploying to edge devices.
**Advantages:** Framework-agnostic, highly optimized, smaller deployment size.
**Disadvantages:** Not all models convert easily, limited debugging.
### Converting PyTorch to ONNX
```python
# convert_to_onnx.py
import torch
import torch.onnx
def convert_pytorch_to_onnx(model_path: str, output_path: str):
"""
Convert PyTorch model to ONNX format.
WHY: ONNX is framework-agnostic (portable).
WHY: ONNX Runtime is 2-3× faster than native PyTorch inference.
WHY: Smaller deployment size (no PyTorch dependency).
"""
# Load PyTorch model
model = torch.load(model_path)
model.eval()
# Create dummy input (for tracing)
dummy_input = torch.randn(1, 3, 224, 224) # Example: image
# Export to ONNX
torch.onnx.export(
model,
dummy_input,
output_path,
export_params=True, # WHY: Include model weights
        opset_version=17,         # WHY: A widely supported ONNX opset (match your runtime version)
do_constant_folding=True, # WHY: Optimize constants at export time
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {0: 'batch_size'}, # WHY: Support variable batch size
'output': {0: 'batch_size'}
}
)
print(f"Model exported to {output_path}")
# Verify ONNX model
import onnx
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
print("ONNX model validation successful")
# Example usage
convert_pytorch_to_onnx("model.pth", "model.onnx")
```
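After conversion it is worth confirming that the exported graph matches the original model numerically. A quick parity-check sketch (paths follow the example above):

```python
# verify_onnx_parity.py
import numpy as np
import onnxruntime as ort
import torch

x = np.random.randn(1, 3, 224, 224).astype(np.float32)

# Reference output from the original PyTorch model
model = torch.load("model.pth")
model.eval()
with torch.no_grad():
    ref = model(torch.from_numpy(x)).numpy()

# Output from the exported ONNX graph
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
out = session.run(None, {session.get_inputs()[0].name: x})[0]

# WHY: conversion bugs usually surface as large numeric drift, not crashes
assert np.allclose(ref, out, rtol=1e-3, atol=1e-5), "ONNX output diverges from PyTorch"
print("ONNX and PyTorch outputs match within tolerance")
```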
### ONNX Runtime Serving
```python
# serve_onnx.py
import onnxruntime as ort
import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import logging
logger = logging.getLogger(__name__)
app = FastAPI()

class PredictionRequest(BaseModel):
    """Request schema (same shape as the FastAPI example above)."""
    inputs: List[List[float]]
class ONNXModelServer:
"""
ONNX Runtime server with optimizations.
WHY: ONNX Runtime is 2-3× faster than PyTorch inference.
WHY: Smaller memory footprint (no PyTorch/TensorFlow).
WHY: Cross-platform (Windows, Linux, Mac, mobile, edge).
"""
def __init__(self, model_path: str):
self.model_path = model_path
self.session = None
def load_model(self):
"""Load ONNX model with optimizations."""
if self.session is None:
# Set execution providers (GPU > CPU)
# WHY: Tries GPU first, falls back to CPU
providers = [
'CUDAExecutionProvider', # NVIDIA GPU
'CPUExecutionProvider' # CPU fallback
]
# Session options for optimization
sess_options = ort.SessionOptions()
# WHY: Enable graph optimizations (fuse ops, constant folding)
sess_options.graph_optimization_level = (
ort.GraphOptimizationLevel.ORT_ENABLE_ALL
)
# WHY: Intra-op parallelism (parallel ops within graph)
sess_options.intra_op_num_threads = 4
# WHY: Inter-op parallelism (parallel independent subgraphs)
sess_options.inter_op_num_threads = 2
# WHY: Enable memory pattern optimization
sess_options.enable_mem_pattern = True
# WHY: Enable CPU memory arena (reduces allocation overhead)
sess_options.enable_cpu_mem_arena = True
self.session = ort.InferenceSession(
self.model_path,
sess_options=sess_options,
providers=providers
)
# Get input/output metadata
self.input_name = self.session.get_inputs()[0].name
self.output_name = self.session.get_outputs()[0].name
logger.info(f"ONNX model loaded: {self.model_path}")
logger.info(f"Execution provider: {self.session.get_providers()[0]}")
def predict(self, inputs: np.ndarray) -> np.ndarray:
"""
Run ONNX inference.
WHY: ONNX Runtime automatically optimizes:
- Operator fusion (combine multiple ops)
- Constant folding (compute constants at load time)
- Memory reuse (reduce allocations)
"""
self.load_model()
# Run inference
outputs = self.session.run(
[self.output_name],
{self.input_name: inputs.astype(np.float32)}
)
return outputs[0]
def benchmark_vs_pytorch(self, num_iterations: int = 1000):
"""Compare ONNX vs PyTorch inference speed."""
import time
import torch
dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
# Warm up
for _ in range(10):
self.predict(dummy_input)
# Benchmark ONNX
start = time.time()
for _ in range(num_iterations):
self.predict(dummy_input)
onnx_time = (time.time() - start) / num_iterations * 1000
# Benchmark PyTorch
pytorch_model = torch.load(self.model_path.replace('.onnx', '.pth'))
pytorch_model.eval()
dummy_tensor = torch.from_numpy(dummy_input)
# Warm up
with torch.no_grad():
for _ in range(10):
pytorch_model(dummy_tensor)
# Benchmark
start = time.time()
with torch.no_grad():
for _ in range(num_iterations):
pytorch_model(dummy_tensor)
pytorch_time = (time.time() - start) / num_iterations * 1000
print(f"ONNX Runtime: {onnx_time:.2f}ms")
print(f"PyTorch: {pytorch_time:.2f}ms")
print(f"ONNX is {pytorch_time/onnx_time:.1f}× faster")
# Typical results:
# ONNX: 5-8ms
# PyTorch: 12-20ms
# ONNX is 2-3× faster
# Global server
onnx_server = ONNXModelServer("model.onnx")
@app.on_event("startup")
async def startup():
onnx_server.load_model()
@app.post("/predict")
async def predict(request: PredictionRequest):
"""ONNX prediction endpoint."""
inputs = np.array(request.inputs, dtype=np.float32)
outputs = onnx_server.predict(inputs)
return {
"predictions": outputs.tolist()
}
```
**ONNX Runtime advantages:**
| Feature | Benefit | Measurement |
|---------|---------|-------------|
| Speed | Optimized operators | 2-3× faster than native |
| Size | No framework dependency | 10-50MB vs 500MB+ (PyTorch) |
| Portability | Framework-agnostic | PyTorch/TF/etc → ONNX |
| Edge deployment | Lightweight runtime | Mobile, IoT, embedded |
## Part 5: Request Batching Patterns
**Core principle:** Batch requests for GPU efficiency.
**Why batching matters:**
- GPU utilization: 20% (single request) → 80% (batch of 32)
- Throughput: 100 req/sec (unbatched) → 1000 req/sec (batched)
- Cost: 10× reduction in GPU cost per request
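A rough micro-benchmark illustrating where the gain comes from (a throwaway linear model; absolute numbers depend entirely on your hardware and model):

```python
# batching_benefit.py
import time

import torch

model = torch.nn.Linear(512, 10).eval()
x_single = torch.randn(1, 512)
x_batch = torch.randn(32, 512)

with torch.no_grad():
    for _ in range(10):        # warm up
        model(x_single)

    start = time.time()        # 32 separate calls
    for _ in range(32):
        model(x_single)
    unbatched = time.time() - start

    start = time.time()        # one batched call
    model(x_batch)
    batched = time.time() - start

print(f"32 single calls: {unbatched * 1000:.2f}ms")
print(f"1 batch of 32:   {batched * 1000:.2f}ms")
# Per-call overhead (Python dispatch, kernel launches, memory transfers) is paid
# once per batch, which is where the throughput and GPU-utilization gains come from.
```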
### Dynamic Batching (Adaptive)
```python
# dynamic_batching.py
import asyncio
from asyncio import Queue, Lock
from typing import List, Tuple
import numpy as np
import time
import logging
logger = logging.getLogger(__name__)
class DynamicBatcher:
"""
Dynamic batching with adaptive timeout.
WHY: Static batching waits for full batch (high latency at low load).
WHY: Dynamic batching adapts: full batch OR timeout (balanced).
Key parameters:
- max_batch_size: Maximum batch size (GPU memory limit)
- max_wait_ms: Maximum wait time (latency target)
Trade-off:
- Larger batch → better GPU utilization, higher throughput
- Shorter timeout → lower latency, worse GPU utilization
"""
def __init__(
self,
model_server,
max_batch_size: int = 32,
max_wait_ms: int = 10
):
self.model_server = model_server
self.max_batch_size = max_batch_size
self.max_wait_ms = max_wait_ms
self.request_queue: Queue = Queue()
self.batch_lock = Lock()
self.stats = {
"total_requests": 0,
"total_batches": 0,
"avg_batch_size": 0,
"gpu_utilization": 0
}
async def start(self):
"""Start batch processing loop."""
asyncio.create_task(self._batch_loop())
async def _batch_loop(self):
"""
Main batching loop.
Algorithm:
1. Wait for first request
2. Start timeout timer
3. Collect requests until:
- Batch full (max_batch_size reached)
- OR timeout expired (max_wait_ms)
4. Process batch
5. Return results to waiting requests
"""
while True:
batch = []
futures = []
# Wait for first request (no timeout)
request_data, future = await self.request_queue.get()
batch.append(request_data)
futures.append(future)
# Start deadline timer
deadline = asyncio.get_event_loop().time() + (self.max_wait_ms / 1000)
# Collect additional requests until batch full or timeout
while len(batch) < self.max_batch_size:
remaining_time = max(0, deadline - asyncio.get_event_loop().time())
try:
request_data, future = await asyncio.wait_for(
self.request_queue.get(),
timeout=remaining_time
)
batch.append(request_data)
futures.append(future)
except asyncio.TimeoutError:
# Timeout: process what we have
break
# Process batch
await self._process_batch(batch, futures)
async def _process_batch(
self,
batch: List[np.ndarray],
futures: List[asyncio.Future]
):
"""Process batch and return results."""
batch_size = len(batch)
# Convert to batch array
batch_array = np.array(batch)
# Run inference
start_time = time.time()
predictions, probabilities = self.model_server.predict(batch_array)
inference_time = (time.time() - start_time) * 1000
# Update stats
self.stats["total_requests"] += batch_size
self.stats["total_batches"] += 1
self.stats["avg_batch_size"] = (
self.stats["total_requests"] / self.stats["total_batches"]
)
self.stats["gpu_utilization"] = (
self.stats["avg_batch_size"] / self.max_batch_size * 100
)
logger.info(
f"Processed batch: size={batch_size}, "
f"inference_time={inference_time:.2f}ms, "
f"avg_batch_size={self.stats['avg_batch_size']:.1f}, "
f"gpu_util={self.stats['gpu_utilization']:.1f}%"
)
# Return results to waiting requests
for i, future in enumerate(futures):
if not future.done():
future.set_result((predictions[i], probabilities[i]))
async def predict(self, inputs: np.ndarray) -> Tuple[int, np.ndarray]:
"""
Add request to batch queue.
WHY: Transparent batching (caller doesn't see batching).
WHY: Returns when batch processed (might wait for other requests).
"""
future = asyncio.Future()
await self.request_queue.put((inputs, future))
# Wait for batch to be processed
prediction, probability = await future
return prediction, probability
def get_stats(self):
"""Get batching statistics."""
return self.stats
# Example usage with load simulation
async def simulate_load():
"""
Simulate varying load to demonstrate dynamic batching.
WHY: Shows how batcher adapts to load:
- High load: Fills batches quickly (high GPU util)
- Low load: Processes smaller batches (low latency)
"""
from serve_fastapi import ModelServer
model_server = ModelServer("model.pth")
model_server.load_model()
batcher = DynamicBatcher(
model_server,
max_batch_size=32,
max_wait_ms=10
)
await batcher.start()
# High load (32 concurrent requests)
print("Simulating HIGH LOAD (32 concurrent)...")
tasks = []
for i in range(32):
inputs = np.random.randn(10)
task = asyncio.create_task(batcher.predict(inputs))
tasks.append(task)
results = await asyncio.gather(*tasks)
print(f"High load results: {len(results)} predictions")
print(f"Stats: {batcher.get_stats()}")
# Expected: avg_batch_size ≈ 32, gpu_util ≈ 100%
await asyncio.sleep(0.1) # Reset
# Low load (1 request at a time)
print("\nSimulating LOW LOAD (1 at a time)...")
for i in range(10):
inputs = np.random.randn(10)
result = await batcher.predict(inputs)
await asyncio.sleep(0.02) # 20ms between requests
print(f"Stats: {batcher.get_stats()}")
# Expected: avg_batch_size ≈ 1-2, gpu_util ≈ 5-10%
# WHY: Timeout expires before batch fills (low latency maintained)
if __name__ == "__main__":
asyncio.run(simulate_load())
```
**Batching performance:**
| Load | Batch Size | GPU Util | Latency | Throughput |
|------|-----------|----------|---------|------------|
| High (100 req/sec) | 28-32 | 90% | 12ms | 1000 req/sec |
| Medium (20 req/sec) | 8-12 | 35% | 11ms | 200 req/sec |
| Low (5 req/sec) | 1-2 | 10% | 10ms | 50 req/sec |
**Key insight:** Dynamic batching adapts to load while maintaining the latency target.
## Part 6: Containerization
**Why containerize:** "Works on my machine" → "Works everywhere"
**Benefits:**
- Reproducible builds (same dependencies, versions)
- Isolated environment (no conflicts)
- Portable deployment (dev, staging, prod identical)
- Easy scaling (K8s, Docker Swarm)
### Multi-Stage Docker Build
```dockerfile
# Dockerfile
# WHY: Multi-stage build reduces image size by 50-80%
# WHY: Build stage has compilers, runtime stage only has runtime deps
# ==================== Stage 1: Build ====================
FROM python:3.11-slim as builder
# WHY: Install build dependencies (needed for compilation)
RUN apt-get update && apt-get install -y \
gcc \
g++ \
&& rm -rf /var/lib/apt/lists/*
# WHY: Create virtual environment in builder stage
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# WHY: Copy only requirements first (layer caching)
# WHY: If requirements.txt unchanged, this layer is cached
COPY requirements.txt .
# WHY: Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt
# ==================== Stage 2: Runtime ====================
FROM python:3.11-slim
# WHY: Copy only virtual environment from builder (not build tools)
COPY --from=builder /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# WHY: Set working directory
WORKDIR /app
# WHY: Copy application code
COPY serve_fastapi.py .
COPY model.pth .
# WHY: Non-root user for security
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
USER appuser
# WHY: Expose port (documentation, not enforcement)
EXPOSE 8000
# WHY: Health check for container orchestration
# WHY: python:3.11-slim does not ship curl, so probe with the Python stdlib
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
# WHY: Run with uvicorn (production ASGI server)
CMD ["uvicorn", "serve_fastapi:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
```
### Docker Compose for Multi-Service
```yaml
# docker-compose.yml
# WHY: Docker Compose for local development and testing
# WHY: Defines multiple services (API, model, monitoring)
version: '3.8'
services:
# Model serving API
model-api:
build:
context: .
dockerfile: Dockerfile
ports:
- "8000:8000"
environment:
# WHY: Environment variables for configuration
- MODEL_PATH=/app/model.pth
- LOG_LEVEL=INFO
volumes:
# WHY: Mount model directory (for updates without rebuild)
- ./models:/app/models:ro
deploy:
resources:
# WHY: Limit resources to prevent resource exhaustion
limits:
cpus: '2'
memory: 4G
reservations:
# WHY: Reserve GPU (requires nvidia-docker)
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
# Redis for caching
redis:
image: redis:7-alpine
ports:
- "6379:6379"
volumes:
- redis-data:/data
command: redis-server --appendonly yes
# Prometheus for metrics
prometheus:
image: prom/prometheus:latest
ports:
- "9090:9090"
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
- prometheus-data:/prometheus
command:
- '--config.file=/etc/prometheus/prometheus.yml'
# Grafana for visualization
grafana:
image: grafana/grafana:latest
ports:
- "3000:3000"
environment:
- GF_SECURITY_ADMIN_PASSWORD=admin
volumes:
- grafana-data:/var/lib/grafana
volumes:
redis-data:
prometheus-data:
grafana-data:
```
### Build and Deploy
```bash
# Build image
# WHY: Tag with version for rollback capability
docker build -t model-api:1.0.0 .
# Run container
docker run -d \
--name model-api \
-p 8000:8000 \
--gpus all \
model-api:1.0.0
# Check logs
docker logs -f model-api
# Test API
curl http://localhost:8000/health
# Start all services with docker-compose
docker-compose up -d
# Scale API service (multiple instances)
# WHY: A load balancer distributes traffic across instances
# NOTE: Drop the fixed "8000:8000" host port mapping (or put a reverse proxy in front)
#       before scaling, otherwise the replicas will conflict on the host port
docker-compose up -d --scale model-api=3
# View logs
docker-compose logs -f model-api
# Stop all services
docker-compose down
```
**Container image sizes:**
| Stage | Size | Contents |
|-------|------|----------|
| Full build | 2.5 GB | Python + build tools + deps + model |
| Multi-stage | 800 MB | Python + runtime deps + model |
| Optimized | 400 MB | Minimal Python + deps + model |
| Savings | **84%** | From 2.5 GB → 400 MB |
## Part 7: Framework Selection Guide
### Decision Matrix
```python
# framework_selector.py
from enum import Enum
from typing import List
class Requirement(Enum):
FLEXIBILITY = "flexibility" # Custom preprocessing, business logic
BATTERIES_INCLUDED = "batteries" # Minimal setup, built-in features
LOW_LATENCY = "low_latency" # < 10ms target
CROSS_FRAMEWORK = "cross_framework" # PyTorch + TensorFlow support
EDGE_DEPLOYMENT = "edge" # Mobile, IoT, embedded
EASE_OF_DEBUG = "debug" # Development experience
HIGH_THROUGHPUT = "throughput" # > 1000 req/sec
class Framework(Enum):
FASTAPI = "fastapi"
TORCHSERVE = "torchserve"
GRPC = "grpc"
ONNX = "onnx"
# Framework capabilities (0-5 scale)
FRAMEWORK_SCORES = {
Framework.FASTAPI: {
Requirement.FLEXIBILITY: 5, # Full control
Requirement.BATTERIES_INCLUDED: 2, # Manual implementation
Requirement.LOW_LATENCY: 3, # 10-20ms
Requirement.CROSS_FRAMEWORK: 4, # Any Python model
Requirement.EDGE_DEPLOYMENT: 2, # Heavyweight
Requirement.EASE_OF_DEBUG: 5, # Excellent debugging
Requirement.HIGH_THROUGHPUT: 3 # 100-500 req/sec
},
Framework.TORCHSERVE: {
Requirement.FLEXIBILITY: 3, # Customizable via handlers
Requirement.BATTERIES_INCLUDED: 5, # Everything built-in
Requirement.LOW_LATENCY: 4, # 5-15ms
Requirement.CROSS_FRAMEWORK: 1, # PyTorch only
Requirement.EDGE_DEPLOYMENT: 2, # Heavyweight
Requirement.EASE_OF_DEBUG: 3, # Learning curve
Requirement.HIGH_THROUGHPUT: 5 # 1000+ req/sec with batching
},
Framework.GRPC: {
Requirement.FLEXIBILITY: 4, # Binary protocol, custom logic
Requirement.BATTERIES_INCLUDED: 2, # Manual implementation
Requirement.LOW_LATENCY: 5, # 3-8ms
Requirement.CROSS_FRAMEWORK: 4, # Any model
Requirement.EDGE_DEPLOYMENT: 3, # Moderate size
Requirement.EASE_OF_DEBUG: 2, # Binary protocol harder
Requirement.HIGH_THROUGHPUT: 5 # 1000+ req/sec
},
Framework.ONNX: {
Requirement.FLEXIBILITY: 3, # Limited to ONNX ops
Requirement.BATTERIES_INCLUDED: 3, # Runtime provided
Requirement.LOW_LATENCY: 5, # 2-6ms (optimized)
Requirement.CROSS_FRAMEWORK: 5, # Any framework → ONNX
Requirement.EDGE_DEPLOYMENT: 5, # Lightweight runtime
Requirement.EASE_OF_DEBUG: 2, # Conversion can be tricky
Requirement.HIGH_THROUGHPUT: 4 # 500-1000 req/sec
}
}
def select_framework(
requirements: List[Requirement],
weights: List[float] = None
) -> Framework:
"""
Select best framework based on requirements.
Args:
requirements: List of requirements
weights: Importance weight for each requirement (0-1)
Returns:
Best framework
"""
if weights is None:
weights = [1.0] * len(requirements)
scores = {}
for framework in Framework:
score = 0
for req, weight in zip(requirements, weights):
score += FRAMEWORK_SCORES[framework][req] * weight
scores[framework] = score
best_framework = max(scores, key=scores.get)
print(f"\nFramework Selection:")
print(f"Requirements: {[r.value for r in requirements]}")
print(f"\nScores:")
for framework, score in sorted(scores.items(), key=lambda x: x[1], reverse=True):
print(f" {framework.value}: {score:.1f}")
return best_framework
# Example use cases
print("=" * 60)
print("Use Case 1: Prototyping with flexibility")
print("=" * 60)
selected = select_framework([
Requirement.FLEXIBILITY,
Requirement.EASE_OF_DEBUG
])
print(f"\nRecommendation: {selected.value}")
# Expected: FASTAPI
print("\n" + "=" * 60)
print("Use Case 2: Production PyTorch with minimal setup")
print("=" * 60)
selected = select_framework([
Requirement.BATTERIES_INCLUDED,
Requirement.HIGH_THROUGHPUT
])
print(f"\nRecommendation: {selected.value}")
# Expected: TORCHSERVE
print("\n" + "=" * 60)
print("Use Case 3: Low-latency microservice")
print("=" * 60)
selected = select_framework([
Requirement.LOW_LATENCY,
Requirement.HIGH_THROUGHPUT
])
print(f"\nRecommendation: {selected.value}")
# Expected: GRPC or ONNX
print("\n" + "=" * 60)
print("Use Case 4: Edge deployment (mobile/IoT)")
print("=" * 60)
selected = select_framework([
Requirement.EDGE_DEPLOYMENT,
Requirement.CROSS_FRAMEWORK,
Requirement.LOW_LATENCY
])
print(f"\nRecommendation: {selected.value}")
# Expected: ONNX
print("\n" + "=" * 60)
print("Use Case 5: Multi-framework ML platform")
print("=" * 60)
selected = select_framework([
Requirement.CROSS_FRAMEWORK,
Requirement.HIGH_THROUGHPUT,
Requirement.BATTERIES_INCLUDED
])
print(f"\nRecommendation: {selected.value}")
# Expected: ONNX or TORCHSERVE (depending on weights)
```
### Quick Reference Guide
| Scenario | Framework | Why |
|----------|-----------|-----|
| **Prototyping** | FastAPI | Fast iteration, easy debugging |
| **PyTorch production** | TorchServe | Built-in batching, metrics, management |
| **Internal microservices** | gRPC | Lowest latency, high throughput |
| **Multi-framework** | ONNX Runtime | Framework-agnostic, optimized |
| **Edge/mobile** | ONNX Runtime | Lightweight, cross-platform |
| **Custom preprocessing** | FastAPI | Full flexibility |
| **High throughput batch** | TorchServe + batching | Dynamic batching built-in |
| **Real-time streaming** | gRPC | Bidirectional streaming |
## Summary
**Model serving is pattern matching, not one-size-fits-all.**
**Core patterns:**
1. **FastAPI:** Flexibility, custom logic, easy debugging
2. **TorchServe:** PyTorch batteries-included, built-in batching
3. **gRPC:** Low latency (3-5ms), high throughput, microservices
4. **ONNX Runtime:** Cross-framework, optimized, edge deployment
5. **Dynamic batching:** Adaptive batch size, balances latency and throughput
6. **Containerization:** Reproducible, portable, scalable
**Selection checklist:**
- ✓ Identify primary requirement (flexibility, latency, throughput, etc.)
- ✓ Match requirement to framework strengths
- ✓ Consider deployment environment (cloud, edge, on-prem)
- ✓ Evaluate trade-offs (development speed vs performance)
- ✓ Implement batching if GPU-based (10× better utilization)
- ✓ Containerize for reproducibility
- ✓ Monitor metrics (latency, throughput, GPU util)
- ✓ Iterate based on production data
**Anti-patterns to avoid:**
- ✗ model.pkl in repo (dependency hell)
- ✗ gRPC for simple REST use cases (over-engineering)
- ✗ No batching with GPU (wasted 80% capacity)
- ✗ Not containerized (deployment inconsistency)
- ✗ Static batching (poor latency at low load)
Production-ready model serving requires matching infrastructure pattern to requirements.