# Model Serving Patterns Skill ## When to Use This Skill Use this skill when: - Deploying ML models to production environments - Building model serving APIs for real-time inference - Optimizing model serving for throughput and latency - Containerizing models for consistent deployment - Implementing request batching for efficiency - Choosing between serving frameworks and protocols **When NOT to use:** Notebook prototyping, training jobs, or single-prediction scripts where serving infrastructure is premature. ## Core Principle **Serving infrastructure is not one-size-fits-all. Pattern selection is context-dependent.** Without proper serving infrastructure: - model.pkl in repo (manual dependency hell) - Wrong protocol choice (gRPC for simple REST use cases) - No batching (1 req/sec instead of 100 req/sec) - Not containerized (works on my machine syndrome) - Static batching when dynamic needed (underutilized GPU) **Formula:** Right framework (FastAPI vs TorchServe vs gRPC vs ONNX) + Request batching (dynamic > static) + Containerization (Docker + model) + Clear selection criteria = Production-ready serving. ## Serving Framework Decision Tree ``` ┌────────────────────────────────────────┐ │ What's your primary requirement? │ └──────────────┬─────────────────────────┘ │ ┌───────┴───────┐ ▼ ▼ Flexibility Batteries Included │ │ ▼ ▼ FastAPI TorchServe (Custom) (PyTorch) │ │ │ ┌───────┴───────┐ │ ▼ ▼ │ Low Latency Cross-Framework │ │ │ │ ▼ ▼ │ gRPC ONNX Runtime │ │ │ └───────┴───────────────┘ │ ▼ ┌───────────────────────┐ │ Add Request Batching │ │ Dynamic > Static │ └───────────┬────────────┘ │ ▼ ┌───────────────────────┐ │ Containerize with │ │ Docker + Dependencies│ └────────────────────────┘ ``` ## Part 1: FastAPI for Custom Serving **When to use:** Need flexibility, custom preprocessing, or non-standard workflows. **Advantages:** Full control, easy debugging, Python ecosystem integration. **Disadvantages:** Manual optimization, no built-in model management. ### Basic FastAPI Serving ```python # serve_fastapi.py from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field import torch import numpy as np from typing import List, Optional import logging # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) app = FastAPI(title="Model Serving API", version="1.0.0") class PredictionRequest(BaseModel): """Request schema with validation.""" inputs: List[List[float]] = Field(..., description="Input features as 2D array") return_probabilities: bool = Field(False, description="Return class probabilities") class Config: schema_extra = { "example": { "inputs": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], "return_probabilities": True } } class PredictionResponse(BaseModel): """Response schema.""" predictions: List[int] probabilities: Optional[List[List[float]]] = None latency_ms: float class ModelServer: """ Model server with lazy loading and caching. WHY: Load model once at startup, reuse across requests. WHY: Avoids 5-10 second model loading per request. 
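    Example (illustrative standalone use, outside the FastAPI app):

        server = ModelServer("model.pth")
        preds, probs = server.predict(np.array([[1.0, 2.0, 3.0]], dtype=np.float32))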
""" def __init__(self, model_path: str): self.model_path = model_path self.model = None self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def load_model(self): """Load model on first request (lazy loading).""" if self.model is None: logger.info(f"Loading model from {self.model_path}...") self.model = torch.load(self.model_path, map_location=self.device) self.model.eval() # WHY: Disable dropout, batchnorm for inference logger.info("Model loaded successfully") def predict(self, inputs: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ Run inference. Args: inputs: Input array (batch_size, features) Returns: (predictions, probabilities) """ self.load_model() # Convert to tensor x = torch.tensor(inputs, dtype=torch.float32).to(self.device) # WHY: torch.no_grad() disables gradient computation for inference # WHY: Reduces memory usage by 50% and speeds up by 2× with torch.no_grad(): logits = self.model(x) probabilities = torch.softmax(logits, dim=1) predictions = torch.argmax(probabilities, dim=1) return predictions.cpu().numpy(), probabilities.cpu().numpy() # Global model server instance model_server = ModelServer(model_path="model.pth") @app.on_event("startup") async def startup_event(): """Load model at startup for faster first request.""" model_server.load_model() logger.info("Server startup complete") @app.get("/health") async def health_check(): """Health check endpoint for load balancers.""" return { "status": "healthy", "model_loaded": model_server.model is not None, "device": str(model_server.device) } @app.post("/predict", response_model=PredictionResponse) async def predict(request: PredictionRequest): """ Prediction endpoint with validation and error handling. WHY: Pydantic validates inputs automatically. WHY: Returns 422 for invalid inputs, not 500. """ import time start_time = time.time() try: inputs = np.array(request.inputs) # Validate shape if inputs.ndim != 2: raise HTTPException( status_code=400, detail=f"Expected 2D array, got {inputs.ndim}D" ) predictions, probabilities = model_server.predict(inputs) latency_ms = (time.time() - start_time) * 1000 response = PredictionResponse( predictions=predictions.tolist(), probabilities=probabilities.tolist() if request.return_probabilities else None, latency_ms=latency_ms ) logger.info(f"Predicted {len(predictions)} samples in {latency_ms:.2f}ms") return response except Exception as e: logger.error(f"Prediction error: {e}") raise HTTPException(status_code=500, detail=str(e)) # Run with: uvicorn serve_fastapi:app --host 0.0.0.0 --port 8000 --workers 4 ``` **Performance characteristics:** | Metric | Value | Notes | |--------|-------|-------| | Cold start | 5-10s | Model loading time | | Warm latency | 10-50ms | Per request | | Throughput | 100-500 req/sec | Single worker | | Memory | 2-8GB | Model + runtime | ### Advanced: Async FastAPI with Background Tasks ```python # serve_fastapi_async.py from fastapi import FastAPI, BackgroundTasks from asyncio import Queue, create_task, sleep import asyncio from typing import Dict import uuid app = FastAPI() class AsyncBatchPredictor: """ Async batch predictor with request queuing. WHY: Collect multiple requests, predict as batch. WHY: GPU utilization: 20% (1 req) → 80% (batch of 32). 
""" def __init__(self, model_server: ModelServer, batch_size: int = 32, wait_ms: int = 10): self.model_server = model_server self.batch_size = batch_size self.wait_ms = wait_ms self.queue: Queue = Queue() self.pending_requests: Dict[str, asyncio.Future] = {} async def start(self): """Start background batch processing loop.""" create_task(self._batch_processing_loop()) async def _batch_processing_loop(self): """ Continuously collect and process batches. WHY: Wait for batch_size OR timeout, then process. WHY: Balances throughput (large batch) and latency (timeout). """ while True: batch_requests = [] batch_ids = [] # Collect batch deadline = asyncio.get_event_loop().time() + (self.wait_ms / 1000) while len(batch_requests) < self.batch_size: timeout = max(0, deadline - asyncio.get_event_loop().time()) try: request_id, inputs = await asyncio.wait_for( self.queue.get(), timeout=timeout ) batch_requests.append(inputs) batch_ids.append(request_id) except asyncio.TimeoutError: break # Timeout reached, process what we have if not batch_requests: await sleep(0.001) # Brief sleep before next iteration continue # Process batch batch_array = np.array(batch_requests) predictions, probabilities = self.model_server.predict(batch_array) # Return results to waiting requests for i, request_id in enumerate(batch_ids): future = self.pending_requests.pop(request_id) future.set_result((predictions[i], probabilities[i])) async def predict_async(self, inputs: List[float]) -> tuple[int, np.ndarray]: """ Add request to queue and await result. WHY: Returns immediately if batch ready, waits if not. WHY: Client doesn't know about batching (transparent). """ request_id = str(uuid.uuid4()) future = asyncio.Future() self.pending_requests[request_id] = future await self.queue.put((request_id, inputs)) # Wait for batch processing to complete prediction, probability = await future return prediction, probability # Global async predictor async_predictor = None @app.on_event("startup") async def startup(): global async_predictor model_server.load_model() async_predictor = AsyncBatchPredictor(model_server, batch_size=32, wait_ms=10) await async_predictor.start() @app.post("/predict_async") async def predict_async(request: PredictionRequest): """ Async prediction with automatic batching. WHY: 10× better GPU utilization than synchronous. WHY: Same latency, much higher throughput. """ # Single input for simplicity (extend for batch) inputs = request.inputs[0] prediction, probability = await async_predictor.predict_async(inputs) return { "prediction": int(prediction), "probability": probability.tolist() } ``` **Performance improvement:** | Approach | Throughput | GPU Utilization | Latency P95 | |----------|-----------|-----------------|-------------| | Synchronous | 100 req/sec | 20% | 15ms | | Async batching | 1000 req/sec | 80% | 25ms | | Improvement | **10×** | **4×** | +10ms | ## Part 2: TorchServe for PyTorch Models **When to use:** PyTorch models, want batteries-included solution with monitoring, metrics, and model management. **Advantages:** Built-in batching, model versioning, A/B testing, metrics. **Disadvantages:** PyTorch-only, less flexibility, steeper learning curve. ### Creating a TorchServe Handler ```python # handler.py import torch import torch.nn.functional as F from ts.torch_handler.base_handler import BaseHandler import logging logger = logging.getLogger(__name__) class CustomClassifierHandler(BaseHandler): """ Custom TorchServe handler with preprocessing and batching. 
WHY: TorchServe provides: model versioning, A/B testing, metrics, monitoring. WHY: Built-in dynamic batching (no custom code needed). """ def initialize(self, context): """ Initialize handler (called once at startup). Args: context: TorchServe context with model artifacts """ self.manifest = context.manifest properties = context.system_properties # Set device self.device = torch.device( "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu" ) # Load model model_dir = properties.get("model_dir") serialized_file = self.manifest["model"]["serializedFile"] model_path = f"{model_dir}/{serialized_file}" self.model = torch.jit.load(model_path, map_location=self.device) self.model.eval() logger.info(f"Model loaded successfully on {self.device}") # WHY: Initialize preprocessing parameters self.mean = torch.tensor([0.485, 0.456, 0.406]).to(self.device) self.std = torch.tensor([0.229, 0.224, 0.225]).to(self.device) self.initialized = True def preprocess(self, data): """ Preprocess input data. Args: data: List of input requests Returns: Preprocessed tensor batch WHY: TorchServe batches requests automatically. WHY: This method receives multiple requests at once. """ inputs = [] for row in data: # Get input from request (JSON or binary) input_data = row.get("data") or row.get("body") # Parse and convert if isinstance(input_data, (bytes, bytearray)): input_data = input_data.decode("utf-8") # Convert to tensor tensor = torch.tensor(eval(input_data), dtype=torch.float32) # Normalize tensor = (tensor - self.mean) / self.std inputs.append(tensor) # Stack into batch batch = torch.stack(inputs).to(self.device) return batch def inference(self, batch): """ Run inference on batch. Args: batch: Preprocessed batch tensor Returns: Model output WHY: torch.no_grad() for inference (faster, less memory). """ with torch.no_grad(): output = self.model(batch) return output def postprocess(self, inference_output): """ Postprocess inference output. Args: inference_output: Raw model output Returns: List of predictions (one per request in batch) WHY: Convert tensors to JSON-serializable format. WHY: Return predictions in same order as inputs. 
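        Example return value for a batch of two requests (illustrative values):
            [{"prediction": 3, "probabilities": [0.01, 0.02, ...]},
             {"prediction": 7, "probabilities": [0.10, 0.05, ...]}]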
""" # Apply softmax probabilities = F.softmax(inference_output, dim=1) # Get predictions predictions = torch.argmax(probabilities, dim=1) # Convert to list (one entry per request) results = [] for i in range(len(predictions)): results.append({ "prediction": predictions[i].item(), "probabilities": probabilities[i].tolist() }) return results ``` ### TorchServe Configuration ```python # model_config.yaml # WHY: Configuration controls batching, workers, timeouts # WHY: Tune these for your latency/throughput requirements minWorkers: 2 # WHY: Minimum workers (always ready) maxWorkers: 4 # WHY: Maximum workers (scale up under load) batchSize: 32 # WHY: Maximum batch size (GPU utilization) maxBatchDelay: 10 # WHY: Max wait time for batch (ms) # WHY: Trade-off: larger batch (better GPU util) vs latency responseTimeout: 120 # WHY: Request timeout (seconds) # WHY: Prevent hung requests # Device assignment deviceType: "gpu" # WHY: Use GPU if available deviceIds: [0] # WHY: Specific GPU ID # Metrics metrics: enable: true prometheus: true # WHY: Export to Prometheus for monitoring ``` ### Packaging and Serving ```bash # Package model for TorchServe # WHY: .mar file contains model + handler + config (portable) torch-model-archiver \ --model-name classifier \ --version 1.0 \ --serialized-file model.pt \ --handler handler.py \ --extra-files "model_config.yaml" \ --export-path model_store/ # Start TorchServe # WHY: Serves on 8080 (inference), 8081 (management), 8082 (metrics) torchserve \ --start \ --ncs \ --model-store model_store \ --models classifier.mar \ --ts-config config.properties # Register model (if not auto-loaded) curl -X POST "http://localhost:8081/models?url=classifier.mar&batch_size=32&max_batch_delay=10" # Make prediction curl -X POST "http://localhost:8080/predictions/classifier" \ -H "Content-Type: application/json" \ -d '{"data": [[1.0, 2.0, 3.0]]}' # Get metrics (for monitoring) curl http://localhost:8082/metrics # Unregister model (for updates) curl -X DELETE "http://localhost:8081/models/classifier" ``` **TorchServe advantages:** | Feature | Built-in? | Notes | |---------|-----------|-------| | Dynamic batching | ✓ | Automatic, configurable | | Model versioning | ✓ | A/B testing support | | Metrics/monitoring | ✓ | Prometheus integration | | Multi-model serving | ✓ | Multiple models per server | | GPU management | ✓ | Automatic device assignment | | Custom preprocessing | ✓ | Via handler | ## Part 3: gRPC for Low-Latency Serving **When to use:** Low latency critical (< 10ms), internal services, microservices architecture. **Advantages:** 3-5× faster than REST, binary protocol, streaming support. **Disadvantages:** More complex, requires proto definitions, harder debugging. ### Protocol Definition ```protobuf // model_service.proto syntax = "proto3"; package modelserving; // WHY: Define service contract in .proto file // WHY: Code generation for multiple languages (Python, Go, Java, etc.) 
service ModelService { // Unary RPC (one request, one response) rpc Predict (PredictRequest) returns (PredictResponse); // Server streaming (one request, stream responses) rpc PredictStream (PredictRequest) returns (stream PredictResponse); // Bidirectional streaming (stream requests and responses) rpc PredictBidi (stream PredictRequest) returns (stream PredictResponse); } message PredictRequest { // WHY: Repeated = array/list repeated float features = 1; // WHY: Input features bool return_probabilities = 2; } message PredictResponse { int32 prediction = 1; repeated float probabilities = 2; float latency_ms = 3; } // Health check service (for load balancers) service Health { rpc Check (HealthCheckRequest) returns (HealthCheckResponse); } message HealthCheckRequest { string service = 1; } message HealthCheckResponse { enum ServingStatus { UNKNOWN = 0; SERVING = 1; NOT_SERVING = 2; } ServingStatus status = 1; } ``` ### gRPC Server Implementation ```python # serve_grpc.py import grpc from concurrent import futures import time import logging import torch import numpy as np # Generated from proto file (run: python -m grpc_tools.protoc ...) import model_service_pb2 import model_service_pb2_grpc logger = logging.getLogger(__name__) class ModelServicer(model_service_pb2_grpc.ModelServiceServicer): """ gRPC service implementation. WHY: gRPC is 3-5× faster than REST (binary protocol, HTTP/2). WHY: Use for low-latency internal services (< 10ms target). """ def __init__(self, model_path: str): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = torch.load(model_path, map_location=self.device) self.model.eval() logger.info(f"Model loaded on {self.device}") def Predict(self, request, context): """ Unary RPC prediction. WHY: Fastest for single predictions. WHY: 3-5ms latency vs 10-15ms for REST. """ start_time = time.time() try: # Convert proto repeated field to numpy features = np.array(request.features, dtype=np.float32) # Reshape for model x = torch.tensor(features).unsqueeze(0).to(self.device) # Inference with torch.no_grad(): logits = self.model(x) probs = torch.softmax(logits, dim=1) pred = torch.argmax(probs, dim=1) latency_ms = (time.time() - start_time) * 1000 # Build response response = model_service_pb2.PredictResponse( prediction=int(pred.item()), latency_ms=latency_ms ) # WHY: Only include probabilities if requested (reduce bandwidth) if request.return_probabilities: response.probabilities.extend(probs[0].cpu().tolist()) return response except Exception as e: logger.error(f"Prediction error: {e}") context.set_code(grpc.StatusCode.INTERNAL) context.set_details(str(e)) return model_service_pb2.PredictResponse() def PredictStream(self, request, context): """ Server streaming RPC. WHY: Send multiple predictions over one connection. WHY: Lower overhead for batch processing. """ # Stream multiple predictions (example: time series) for i in range(10): # Simulate 10 predictions response = self.Predict(request, context) yield response time.sleep(0.01) # Simulate processing delay def PredictBidi(self, request_iterator, context): """ Bidirectional streaming RPC. WHY: Real-time inference (send request, get response immediately). WHY: Lowest latency for streaming use cases. 
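        Example client usage (illustrative; request_a and request_b are
        hypothetical PredictRequest messages):
            responses = stub.PredictBidi(iter([request_a, request_b]))
            for response in responses:
                print(response.prediction)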
""" for request in request_iterator: response = self.Predict(request, context) yield response class HealthServicer(model_service_pb2_grpc.HealthServicer): """Health check service for load balancers.""" def Check(self, request, context): # WHY: Load balancers need health checks to route traffic return model_service_pb2.HealthCheckResponse( status=model_service_pb2.HealthCheckResponse.SERVING ) def serve(): """ Start gRPC server. WHY: ThreadPoolExecutor for concurrent request handling. WHY: max_workers controls concurrency (tune based on CPU cores). """ server = grpc.server( futures.ThreadPoolExecutor(max_workers=10), options=[ # WHY: These options optimize for low latency ('grpc.max_send_message_length', 10 * 1024 * 1024), # 10MB ('grpc.max_receive_message_length', 10 * 1024 * 1024), ('grpc.so_reuseport', 1), # WHY: Allows multiple servers on same port ('grpc.use_local_subchannel_pool', 1) # WHY: Better connection reuse ] ) # Add services model_service_pb2_grpc.add_ModelServiceServicer_to_server( ModelServicer("model.pth"), server ) model_service_pb2_grpc.add_HealthServicer_to_server( HealthServicer(), server ) # Bind to port server.add_insecure_port('[::]:50051') server.start() logger.info("gRPC server started on port 50051") server.wait_for_termination() if __name__ == "__main__": logging.basicConfig(level=logging.INFO) serve() ``` ### gRPC Client ```python # client_grpc.py import grpc import model_service_pb2 import model_service_pb2_grpc import time def benchmark_grpc_vs_rest(): """ Benchmark gRPC vs REST latency. WHY: gRPC is faster, but how much faster? """ # gRPC client channel = grpc.insecure_channel('localhost:50051') stub = model_service_pb2_grpc.ModelServiceStub(channel) # Warm up request = model_service_pb2.PredictRequest( features=[1.0, 2.0, 3.0], return_probabilities=True ) for _ in range(10): stub.Predict(request) # Benchmark iterations = 1000 start = time.time() for _ in range(iterations): response = stub.Predict(request) grpc_latency = ((time.time() - start) / iterations) * 1000 print(f"gRPC average latency: {grpc_latency:.2f}ms") # Compare with REST (FastAPI) import requests rest_url = "http://localhost:8000/predict" # Warm up for _ in range(10): requests.post(rest_url, json={"inputs": [[1.0, 2.0, 3.0]]}) # Benchmark start = time.time() for _ in range(iterations): requests.post(rest_url, json={"inputs": [[1.0, 2.0, 3.0]]}) rest_latency = ((time.time() - start) / iterations) * 1000 print(f"REST average latency: {rest_latency:.2f}ms") print(f"gRPC is {rest_latency/grpc_latency:.1f}× faster") # Typical results: # gRPC: 3-5ms # REST: 10-15ms # gRPC is 3-5× faster if __name__ == "__main__": benchmark_grpc_vs_rest() ``` **gRPC vs REST comparison:** | Metric | gRPC | REST | Advantage | |--------|------|------|-----------| | Latency | 3-5ms | 10-15ms | **gRPC 3× faster** | | Throughput | 10k req/sec | 3k req/sec | **gRPC 3× higher** | | Payload size | Binary (smaller) | JSON (larger) | gRPC 30-50% smaller | | Debugging | Harder | Easier | REST | | Browser support | No (requires proxy) | Yes | REST | | Streaming | Native | Complex (SSE/WebSocket) | gRPC | ## Part 4: ONNX Runtime for Cross-Framework Serving **When to use:** Need cross-framework support (PyTorch, TensorFlow, etc.), want maximum performance, or deploying to edge devices. **Advantages:** Framework-agnostic, highly optimized, smaller deployment size. **Disadvantages:** Not all models convert easily, limited debugging. 
### Converting PyTorch to ONNX ```python # convert_to_onnx.py import torch import torch.onnx def convert_pytorch_to_onnx(model_path: str, output_path: str): """ Convert PyTorch model to ONNX format. WHY: ONNX is framework-agnostic (portable). WHY: ONNX Runtime is 2-3× faster than native PyTorch inference. WHY: Smaller deployment size (no PyTorch dependency). """ # Load PyTorch model model = torch.load(model_path) model.eval() # Create dummy input (for tracing) dummy_input = torch.randn(1, 3, 224, 224) # Example: image # Export to ONNX torch.onnx.export( model, dummy_input, output_path, export_params=True, # WHY: Include model weights opset_version=17, # WHY: Latest stable ONNX opset do_constant_folding=True, # WHY: Optimize constants at export time input_names=['input'], output_names=['output'], dynamic_axes={ 'input': {0: 'batch_size'}, # WHY: Support variable batch size 'output': {0: 'batch_size'} } ) print(f"Model exported to {output_path}") # Verify ONNX model import onnx onnx_model = onnx.load(output_path) onnx.checker.check_model(onnx_model) print("ONNX model validation successful") # Example usage convert_pytorch_to_onnx("model.pth", "model.onnx") ``` ### ONNX Runtime Serving ```python # serve_onnx.py import onnxruntime as ort import numpy as np from fastapi import FastAPI from pydantic import BaseModel from typing import List import logging logger = logging.getLogger(__name__) app = FastAPI() class ONNXModelServer: """ ONNX Runtime server with optimizations. WHY: ONNX Runtime is 2-3× faster than PyTorch inference. WHY: Smaller memory footprint (no PyTorch/TensorFlow). WHY: Cross-platform (Windows, Linux, Mac, mobile, edge). """ def __init__(self, model_path: str): self.model_path = model_path self.session = None def load_model(self): """Load ONNX model with optimizations.""" if self.session is None: # Set execution providers (GPU > CPU) # WHY: Tries GPU first, falls back to CPU providers = [ 'CUDAExecutionProvider', # NVIDIA GPU 'CPUExecutionProvider' # CPU fallback ] # Session options for optimization sess_options = ort.SessionOptions() # WHY: Enable graph optimizations (fuse ops, constant folding) sess_options.graph_optimization_level = ( ort.GraphOptimizationLevel.ORT_ENABLE_ALL ) # WHY: Intra-op parallelism (parallel ops within graph) sess_options.intra_op_num_threads = 4 # WHY: Inter-op parallelism (parallel independent subgraphs) sess_options.inter_op_num_threads = 2 # WHY: Enable memory pattern optimization sess_options.enable_mem_pattern = True # WHY: Enable CPU memory arena (reduces allocation overhead) sess_options.enable_cpu_mem_arena = True self.session = ort.InferenceSession( self.model_path, sess_options=sess_options, providers=providers ) # Get input/output metadata self.input_name = self.session.get_inputs()[0].name self.output_name = self.session.get_outputs()[0].name logger.info(f"ONNX model loaded: {self.model_path}") logger.info(f"Execution provider: {self.session.get_providers()[0]}") def predict(self, inputs: np.ndarray) -> np.ndarray: """ Run ONNX inference. 
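        Args:
            inputs: float32 array shaped like the model's exported input
                (the batch dimension may vary when exported with dynamic axes)

        Returns:
            First model output as a numpy array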
WHY: ONNX Runtime automatically optimizes: - Operator fusion (combine multiple ops) - Constant folding (compute constants at load time) - Memory reuse (reduce allocations) """ self.load_model() # Run inference outputs = self.session.run( [self.output_name], {self.input_name: inputs.astype(np.float32)} ) return outputs[0] def benchmark_vs_pytorch(self, num_iterations: int = 1000): """Compare ONNX vs PyTorch inference speed.""" import time import torch dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32) # Warm up for _ in range(10): self.predict(dummy_input) # Benchmark ONNX start = time.time() for _ in range(num_iterations): self.predict(dummy_input) onnx_time = (time.time() - start) / num_iterations * 1000 # Benchmark PyTorch pytorch_model = torch.load(self.model_path.replace('.onnx', '.pth')) pytorch_model.eval() dummy_tensor = torch.from_numpy(dummy_input) # Warm up with torch.no_grad(): for _ in range(10): pytorch_model(dummy_tensor) # Benchmark start = time.time() with torch.no_grad(): for _ in range(num_iterations): pytorch_model(dummy_tensor) pytorch_time = (time.time() - start) / num_iterations * 1000 print(f"ONNX Runtime: {onnx_time:.2f}ms") print(f"PyTorch: {pytorch_time:.2f}ms") print(f"ONNX is {pytorch_time/onnx_time:.1f}× faster") # Typical results: # ONNX: 5-8ms # PyTorch: 12-20ms # ONNX is 2-3× faster # Global server onnx_server = ONNXModelServer("model.onnx") @app.on_event("startup") async def startup(): onnx_server.load_model() @app.post("/predict") async def predict(request: PredictionRequest): """ONNX prediction endpoint.""" inputs = np.array(request.inputs, dtype=np.float32) outputs = onnx_server.predict(inputs) return { "predictions": outputs.tolist() } ``` **ONNX Runtime advantages:** | Feature | Benefit | Measurement | |---------|---------|-------------| | Speed | Optimized operators | 2-3× faster than native | | Size | No framework dependency | 10-50MB vs 500MB+ (PyTorch) | | Portability | Framework-agnostic | PyTorch/TF/etc → ONNX | | Edge deployment | Lightweight runtime | Mobile, IoT, embedded | ## Part 5: Request Batching Patterns **Core principle:** Batch requests for GPU efficiency. **Why batching matters:** - GPU utilization: 20% (single request) → 80% (batch of 32) - Throughput: 100 req/sec (unbatched) → 1000 req/sec (batched) - Cost: 10× reduction in GPU cost per request ### Dynamic Batching (Adaptive) ```python # dynamic_batching.py import asyncio from asyncio import Queue, Lock from typing import List, Tuple import numpy as np import time import logging logger = logging.getLogger(__name__) class DynamicBatcher: """ Dynamic batching with adaptive timeout. WHY: Static batching waits for full batch (high latency at low load). WHY: Dynamic batching adapts: full batch OR timeout (balanced). Key parameters: - max_batch_size: Maximum batch size (GPU memory limit) - max_wait_ms: Maximum wait time (latency target) Trade-off: - Larger batch → better GPU utilization, higher throughput - Shorter timeout → lower latency, worse GPU utilization """ def __init__( self, model_server, max_batch_size: int = 32, max_wait_ms: int = 10 ): self.model_server = model_server self.max_batch_size = max_batch_size self.max_wait_ms = max_wait_ms self.request_queue: Queue = Queue() self.batch_lock = Lock() self.stats = { "total_requests": 0, "total_batches": 0, "avg_batch_size": 0, "gpu_utilization": 0 } async def start(self): """Start batch processing loop.""" asyncio.create_task(self._batch_loop()) async def _batch_loop(self): """ Main batching loop. Algorithm: 1. 
Wait for first request 2. Start timeout timer 3. Collect requests until: - Batch full (max_batch_size reached) - OR timeout expired (max_wait_ms) 4. Process batch 5. Return results to waiting requests """ while True: batch = [] futures = [] # Wait for first request (no timeout) request_data, future = await self.request_queue.get() batch.append(request_data) futures.append(future) # Start deadline timer deadline = asyncio.get_event_loop().time() + (self.max_wait_ms / 1000) # Collect additional requests until batch full or timeout while len(batch) < self.max_batch_size: remaining_time = max(0, deadline - asyncio.get_event_loop().time()) try: request_data, future = await asyncio.wait_for( self.request_queue.get(), timeout=remaining_time ) batch.append(request_data) futures.append(future) except asyncio.TimeoutError: # Timeout: process what we have break # Process batch await self._process_batch(batch, futures) async def _process_batch( self, batch: List[np.ndarray], futures: List[asyncio.Future] ): """Process batch and return results.""" batch_size = len(batch) # Convert to batch array batch_array = np.array(batch) # Run inference start_time = time.time() predictions, probabilities = self.model_server.predict(batch_array) inference_time = (time.time() - start_time) * 1000 # Update stats self.stats["total_requests"] += batch_size self.stats["total_batches"] += 1 self.stats["avg_batch_size"] = ( self.stats["total_requests"] / self.stats["total_batches"] ) self.stats["gpu_utilization"] = ( self.stats["avg_batch_size"] / self.max_batch_size * 100 ) logger.info( f"Processed batch: size={batch_size}, " f"inference_time={inference_time:.2f}ms, " f"avg_batch_size={self.stats['avg_batch_size']:.1f}, " f"gpu_util={self.stats['gpu_utilization']:.1f}%" ) # Return results to waiting requests for i, future in enumerate(futures): if not future.done(): future.set_result((predictions[i], probabilities[i])) async def predict(self, inputs: np.ndarray) -> Tuple[int, np.ndarray]: """ Add request to batch queue. WHY: Transparent batching (caller doesn't see batching). WHY: Returns when batch processed (might wait for other requests). """ future = asyncio.Future() await self.request_queue.put((inputs, future)) # Wait for batch to be processed prediction, probability = await future return prediction, probability def get_stats(self): """Get batching statistics.""" return self.stats # Example usage with load simulation async def simulate_load(): """ Simulate varying load to demonstrate dynamic batching. 
WHY: Shows how batcher adapts to load: - High load: Fills batches quickly (high GPU util) - Low load: Processes smaller batches (low latency) """ from serve_fastapi import ModelServer model_server = ModelServer("model.pth") model_server.load_model() batcher = DynamicBatcher( model_server, max_batch_size=32, max_wait_ms=10 ) await batcher.start() # High load (32 concurrent requests) print("Simulating HIGH LOAD (32 concurrent)...") tasks = [] for i in range(32): inputs = np.random.randn(10) task = asyncio.create_task(batcher.predict(inputs)) tasks.append(task) results = await asyncio.gather(*tasks) print(f"High load results: {len(results)} predictions") print(f"Stats: {batcher.get_stats()}") # Expected: avg_batch_size ≈ 32, gpu_util ≈ 100% await asyncio.sleep(0.1) # Reset # Low load (1 request at a time) print("\nSimulating LOW LOAD (1 at a time)...") for i in range(10): inputs = np.random.randn(10) result = await batcher.predict(inputs) await asyncio.sleep(0.02) # 20ms between requests print(f"Stats: {batcher.get_stats()}") # Expected: avg_batch_size ≈ 1-2, gpu_util ≈ 5-10% # WHY: Timeout expires before batch fills (low latency maintained) if __name__ == "__main__": asyncio.run(simulate_load()) ``` **Batching performance:** | Load | Batch Size | GPU Util | Latency | Throughput | |------|-----------|----------|---------|------------| | High (100 req/sec) | 28-32 | 90% | 12ms | 1000 req/sec | | Medium (20 req/sec) | 8-12 | 35% | 11ms | 200 req/sec | | Low (5 req/sec) | 1-2 | 10% | 10ms | 50 req/sec | **Key insight:** Dynamic batching adapts to load while maintaining latency target. ## Part 6: Containerization **Why containerize:** "Works on my machine" → "Works everywhere" **Benefits:** - Reproducible builds (same dependencies, versions) - Isolated environment (no conflicts) - Portable deployment (dev, staging, prod identical) - Easy scaling (K8s, Docker Swarm) ### Multi-Stage Docker Build ```dockerfile # Dockerfile # WHY: Multi-stage build reduces image size by 50-80% # WHY: Build stage has compilers, runtime stage only has runtime deps # ==================== Stage 1: Build ==================== FROM python:3.11-slim as builder # WHY: Install build dependencies (needed for compilation) RUN apt-get update && apt-get install -y \ gcc \ g++ \ && rm -rf /var/lib/apt/lists/* # WHY: Create virtual environment in builder stage RUN python -m venv /opt/venv ENV PATH="/opt/venv/bin:$PATH" # WHY: Copy only requirements first (layer caching) # WHY: If requirements.txt unchanged, this layer is cached COPY requirements.txt . # WHY: Install Python dependencies RUN pip install --no-cache-dir -r requirements.txt # ==================== Stage 2: Runtime ==================== FROM python:3.11-slim # WHY: Copy only virtual environment from builder (not build tools) COPY --from=builder /opt/venv /opt/venv ENV PATH="/opt/venv/bin:$PATH" # WHY: Set working directory WORKDIR /app # WHY: Copy application code COPY serve_fastapi.py . COPY model.pth . 
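# WHY: curl is used by the HEALTHCHECK below but is not included in the slim base image
RUN apt-get update && apt-get install -y --no-install-recommends curl \
    && rm -rf /var/lib/apt/lists/*

# WHY: Unbuffered output so logs show up immediately in `docker logs`
ENV PYTHONUNBUFFERED=1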
# WHY: Non-root user for security RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app USER appuser # WHY: Expose port (documentation, not enforcement) EXPOSE 8000 # WHY: Health check for container orchestration HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ CMD curl -f http://localhost:8000/health || exit 1 # WHY: Run with uvicorn (production ASGI server) CMD ["uvicorn", "serve_fastapi:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"] ``` ### Docker Compose for Multi-Service ```yaml # docker-compose.yml # WHY: Docker Compose for local development and testing # WHY: Defines multiple services (API, model, monitoring) version: '3.8' services: # Model serving API model-api: build: context: . dockerfile: Dockerfile ports: - "8000:8000" environment: # WHY: Environment variables for configuration - MODEL_PATH=/app/model.pth - LOG_LEVEL=INFO volumes: # WHY: Mount model directory (for updates without rebuild) - ./models:/app/models:ro deploy: resources: # WHY: Limit resources to prevent resource exhaustion limits: cpus: '2' memory: 4G reservations: # WHY: Reserve GPU (requires nvidia-docker) devices: - driver: nvidia count: 1 capabilities: [gpu] healthcheck: test: ["CMD", "curl", "-f", "http://localhost:8000/health"] interval: 30s timeout: 10s retries: 3 # Redis for caching redis: image: redis:7-alpine ports: - "6379:6379" volumes: - redis-data:/data command: redis-server --appendonly yes # Prometheus for metrics prometheus: image: prom/prometheus:latest ports: - "9090:9090" volumes: - ./prometheus.yml:/etc/prometheus/prometheus.yml - prometheus-data:/prometheus command: - '--config.file=/etc/prometheus/prometheus.yml' # Grafana for visualization grafana: image: grafana/grafana:latest ports: - "3000:3000" environment: - GF_SECURITY_ADMIN_PASSWORD=admin volumes: - grafana-data:/var/lib/grafana volumes: redis-data: prometheus-data: grafana-data: ``` ### Build and Deploy ```bash # Build image # WHY: Tag with version for rollback capability docker build -t model-api:1.0.0 . 
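# Inspect the final image size (compare with the image-size table later in this part)
docker images model-api:1.0.0

# NOTE: `--gpus all` below assumes the NVIDIA Container Toolkit is installed on the host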
# Run container docker run -d \ --name model-api \ -p 8000:8000 \ --gpus all \ model-api:1.0.0 # Check logs docker logs -f model-api # Test API curl http://localhost:8000/health # Start all services with docker-compose docker-compose up -d # Scale API service (multiple instances) # WHY: Load balancer distributes traffic across instances docker-compose up -d --scale model-api=3 # View logs docker-compose logs -f model-api # Stop all services docker-compose down ``` **Container image sizes:** | Stage | Size | Contents | |-------|------|----------| | Full build | 2.5 GB | Python + build tools + deps + model | | Multi-stage | 800 MB | Python + runtime deps + model | | Optimized | 400 MB | Minimal Python + deps + model | | Savings | **84%** | From 2.5 GB → 400 MB | ## Part 7: Framework Selection Guide ### Decision Matrix ```python # framework_selector.py from enum import Enum from typing import List class Requirement(Enum): FLEXIBILITY = "flexibility" # Custom preprocessing, business logic BATTERIES_INCLUDED = "batteries" # Minimal setup, built-in features LOW_LATENCY = "low_latency" # < 10ms target CROSS_FRAMEWORK = "cross_framework" # PyTorch + TensorFlow support EDGE_DEPLOYMENT = "edge" # Mobile, IoT, embedded EASE_OF_DEBUG = "debug" # Development experience HIGH_THROUGHPUT = "throughput" # > 1000 req/sec class Framework(Enum): FASTAPI = "fastapi" TORCHSERVE = "torchserve" GRPC = "grpc" ONNX = "onnx" # Framework capabilities (0-5 scale) FRAMEWORK_SCORES = { Framework.FASTAPI: { Requirement.FLEXIBILITY: 5, # Full control Requirement.BATTERIES_INCLUDED: 2, # Manual implementation Requirement.LOW_LATENCY: 3, # 10-20ms Requirement.CROSS_FRAMEWORK: 4, # Any Python model Requirement.EDGE_DEPLOYMENT: 2, # Heavyweight Requirement.EASE_OF_DEBUG: 5, # Excellent debugging Requirement.HIGH_THROUGHPUT: 3 # 100-500 req/sec }, Framework.TORCHSERVE: { Requirement.FLEXIBILITY: 3, # Customizable via handlers Requirement.BATTERIES_INCLUDED: 5, # Everything built-in Requirement.LOW_LATENCY: 4, # 5-15ms Requirement.CROSS_FRAMEWORK: 1, # PyTorch only Requirement.EDGE_DEPLOYMENT: 2, # Heavyweight Requirement.EASE_OF_DEBUG: 3, # Learning curve Requirement.HIGH_THROUGHPUT: 5 # 1000+ req/sec with batching }, Framework.GRPC: { Requirement.FLEXIBILITY: 4, # Binary protocol, custom logic Requirement.BATTERIES_INCLUDED: 2, # Manual implementation Requirement.LOW_LATENCY: 5, # 3-8ms Requirement.CROSS_FRAMEWORK: 4, # Any model Requirement.EDGE_DEPLOYMENT: 3, # Moderate size Requirement.EASE_OF_DEBUG: 2, # Binary protocol harder Requirement.HIGH_THROUGHPUT: 5 # 1000+ req/sec }, Framework.ONNX: { Requirement.FLEXIBILITY: 3, # Limited to ONNX ops Requirement.BATTERIES_INCLUDED: 3, # Runtime provided Requirement.LOW_LATENCY: 5, # 2-6ms (optimized) Requirement.CROSS_FRAMEWORK: 5, # Any framework → ONNX Requirement.EDGE_DEPLOYMENT: 5, # Lightweight runtime Requirement.EASE_OF_DEBUG: 2, # Conversion can be tricky Requirement.HIGH_THROUGHPUT: 4 # 500-1000 req/sec } } def select_framework( requirements: List[Requirement], weights: List[float] = None ) -> Framework: """ Select best framework based on requirements. 
Args: requirements: List of requirements weights: Importance weight for each requirement (0-1) Returns: Best framework """ if weights is None: weights = [1.0] * len(requirements) scores = {} for framework in Framework: score = 0 for req, weight in zip(requirements, weights): score += FRAMEWORK_SCORES[framework][req] * weight scores[framework] = score best_framework = max(scores, key=scores.get) print(f"\nFramework Selection:") print(f"Requirements: {[r.value for r in requirements]}") print(f"\nScores:") for framework, score in sorted(scores.items(), key=lambda x: x[1], reverse=True): print(f" {framework.value}: {score:.1f}") return best_framework # Example use cases print("=" * 60) print("Use Case 1: Prototyping with flexibility") print("=" * 60) selected = select_framework([ Requirement.FLEXIBILITY, Requirement.EASE_OF_DEBUG ]) print(f"\nRecommendation: {selected.value}") # Expected: FASTAPI print("\n" + "=" * 60) print("Use Case 2: Production PyTorch with minimal setup") print("=" * 60) selected = select_framework([ Requirement.BATTERIES_INCLUDED, Requirement.HIGH_THROUGHPUT ]) print(f"\nRecommendation: {selected.value}") # Expected: TORCHSERVE print("\n" + "=" * 60) print("Use Case 3: Low-latency microservice") print("=" * 60) selected = select_framework([ Requirement.LOW_LATENCY, Requirement.HIGH_THROUGHPUT ]) print(f"\nRecommendation: {selected.value}") # Expected: GRPC or ONNX print("\n" + "=" * 60) print("Use Case 4: Edge deployment (mobile/IoT)") print("=" * 60) selected = select_framework([ Requirement.EDGE_DEPLOYMENT, Requirement.CROSS_FRAMEWORK, Requirement.LOW_LATENCY ]) print(f"\nRecommendation: {selected.value}") # Expected: ONNX print("\n" + "=" * 60) print("Use Case 5: Multi-framework ML platform") print("=" * 60) selected = select_framework([ Requirement.CROSS_FRAMEWORK, Requirement.HIGH_THROUGHPUT, Requirement.BATTERIES_INCLUDED ]) print(f"\nRecommendation: {selected.value}") # Expected: ONNX or TORCHSERVE (depending on weights) ``` ### Quick Reference Guide | Scenario | Framework | Why | |----------|-----------|-----| | **Prototyping** | FastAPI | Fast iteration, easy debugging | | **PyTorch production** | TorchServe | Built-in batching, metrics, management | | **Internal microservices** | gRPC | Lowest latency, high throughput | | **Multi-framework** | ONNX Runtime | Framework-agnostic, optimized | | **Edge/mobile** | ONNX Runtime | Lightweight, cross-platform | | **Custom preprocessing** | FastAPI | Full flexibility | | **High throughput batch** | TorchServe + batching | Dynamic batching built-in | | **Real-time streaming** | gRPC | Bidirectional streaming | ## Summary **Model serving is pattern matching, not one-size-fits-all.** **Core patterns:** 1. **FastAPI:** Flexibility, custom logic, easy debugging 2. **TorchServe:** PyTorch batteries-included, built-in batching 3. **gRPC:** Low latency (3-5ms), high throughput, microservices 4. **ONNX Runtime:** Cross-framework, optimized, edge deployment 5. **Dynamic batching:** Adaptive batch size, balances latency and throughput 6. **Containerization:** Reproducible, portable, scalable **Selection checklist:** - ✓ Identify primary requirement (flexibility, latency, throughput, etc.) 
- ✓ Match requirement to framework strengths
- ✓ Consider deployment environment (cloud, edge, on-prem)
- ✓ Evaluate trade-offs (development speed vs performance)
- ✓ Implement batching if GPU-based (10× better utilization)
- ✓ Containerize for reproducibility
- ✓ Monitor metrics (latency, throughput, GPU utilization)
- ✓ Iterate based on production data

**Anti-patterns to avoid:**

- ✗ model.pkl in repo (dependency hell)
- ✗ gRPC for simple REST use cases (over-engineering)
- ✗ No batching with GPU (wastes 80% of capacity)
- ✗ Not containerized (deployment inconsistency)
- ✗ Static batching (poor latency at low load)

Production-ready model serving requires matching the infrastructure pattern to the requirements.