Jared AI Hub

FastAPI for ML Engineers: Building Production APIs

Author: Jared Chung

Introduction

You've trained a great model. Now what? Most of a model's value comes from deploying it where it can serve predictions. FastAPI has become the go-to framework for ML APIs: it is fast, modern, and designed with async support that suits I/O-heavy ML workloads.

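This guide covers what you need to build production-ready ML APIs with FastAPI: project structure, inference patterns, model loading, validation, error handling, health checks, rate limiting, deployment, and testing.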

[Figure: FastAPI ML API Architecture]

Why FastAPI for ML?

Feature             | Benefit for ML
--------------------|--------------------------------------------------
Async native        | Handle concurrent requests while model processes
Automatic docs      | OpenAPI docs generated from type hints
Pydantic validation | Validate inputs before hitting your model
Streaming support   | Stream LLM tokens in real-time
Background tasks    | Long-running inference without blocking
Easy testing        | Built-in test client

Project Structure

A well-organized ML API project:

ml_api/
├── app/
│   ├── __init__.py
│   ├── main.py              # FastAPI app
│   ├── config.py            # Settings
│   ├── models/
│   │   ├── __init__.py
│   │   ├── schemas.py       # Pydantic models
│   │   └── ml_models.py     # ML model loading
│   ├── routers/
│   │   ├── __init__.py
│   │   ├── predictions.py   # Prediction endpoints
│   │   └── health.py        # Health checks
│   ├── services/
│   │   ├── __init__.py
│   │   └── inference.py     # Business logic
│   └── middleware/
│       ├── __init__.py
│       └── logging.py       # Request logging
├── tests/
├── Dockerfile
├── requirements.txt
└── pyproject.toml

Basic ML API

Let's start with a simple image classification API:

# app/main.py
from fastapi import FastAPI
from contextlib import asynccontextmanager
from app.models.ml_models import load_model
from app.routers import predictions

ml_models = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load model on startup
    ml_models["classifier"] = load_model("resnet50")
    yield
    # Cleanup on shutdown
    ml_models.clear()

app = FastAPI(
    title="ML Prediction API",
    version="1.0.0",
    lifespan=lifespan
)

# Register the prediction endpoints; other routers are added the same way
app.include_router(predictions.router)

# app/models/schemas.py
from pydantic import BaseModel, Field
from typing import List

class PredictionRequest(BaseModel):
    image_url: str = Field(..., description="URL of image to classify")

class Prediction(BaseModel):
    label: str
    confidence: float = Field(..., ge=0, le=1)

class PredictionResponse(BaseModel):
    predictions: List[Prediction]
    model_version: str
    inference_time_ms: float

# app/routers/predictions.py
from fastapi import APIRouter, HTTPException
from app.models.schemas import PredictionRequest, PredictionResponse
from app.services.inference import classify_image
import time

router = APIRouter(prefix="/api/v1", tags=["predictions"])

@router.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    start = time.perf_counter()

    try:
        predictions = await classify_image(request.image_url)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    inference_time = (time.perf_counter() - start) * 1000

    return PredictionResponse(
        predictions=predictions,
        model_version="1.0.0",
        inference_time_ms=round(inference_time, 2)
    )

Handling Long-Running Inference

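ML inference can be slow. A synchronous model call inside an async endpoint blocks the event loop, so every other request waits until it finishes. The patterns in this section keep the server responsive while a prediction runs.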
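One way to keep the event loop free is to push the blocking call onto a worker thread. The following is a minimal sketch, not the article's own implementation: slow_predict is a hypothetical stand-in for a real model call, and the 0.2 second sleep is an assumed latency.

```python
import asyncio
import time

def slow_predict(x: int) -> int:
    """Hypothetical stand-in for a blocking model call."""
    time.sleep(0.2)  # simulates inference latency
    return x * 2

async def predict_offloaded(x: int) -> int:
    # Run the blocking call in a worker thread so the event loop stays free
    return await asyncio.to_thread(slow_predict, x)

async def main():
    start = time.perf_counter()
    # Four requests overlap instead of running back to back
    results = await asyncio.gather(*(predict_offloaded(i) for i in range(4)))
    return list(results), time.perf_counter() - start

results, elapsed = asyncio.run(main())
```

Threads only give real parallelism when the underlying library releases the GIL during computation, so treat this as a way to protect the event loop rather than as a throughput guarantee.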

Background Tasks

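For non-blocking inference, return a job ID immediately and let the client poll for the result:

from typing import Optional
from uuid import uuid4

from fastapi import BackgroundTasks
from pydantic import BaseModel

# In-memory store (use Redis in production); it lives in one process,
# so it is not shared between multiple workers
job_store = {}

class JobStatus(BaseModel):
    job_id: str
    status: str  # pending, processing, completed, failed
    result: Optional[dict] = None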

@router.post("/predict/async")
async def predict_async(
    request: PredictionRequest,
    background_tasks: BackgroundTasks
):
    job_id = str(uuid4())
    job_store[job_id] = {"status": "pending", "result": None}

    background_tasks.add_task(
        run_inference_job,
        job_id,
        request
    )

    return {"job_id": job_id, "status_url": f"/api/v1/jobs/{job_id}"}

async def run_inference_job(job_id: str, request: PredictionRequest):
    job_store[job_id]["status"] = "processing"
    try:
        result = await classify_image(request.image_url)
        # Wrap the output in a dict so it matches JobStatus.result
        job_store[job_id] = {"status": "completed", "result": {"predictions": result}}
    except Exception as e:
        job_store[job_id] = {"status": "failed", "result": {"error": str(e)}}

@router.get("/jobs/{job_id}", response_model=JobStatus)
async def get_job_status(job_id: str):
    if job_id not in job_store:
        raise HTTPException(status_code=404, detail="Job not found")
    return JobStatus(job_id=job_id, **job_store[job_id])
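A plain dict used as a job store grows without bound, because finished jobs are never removed. As a rough sketch of one way to handle that before moving to Redis, the hypothetical JobStore class below (its name, methods, and TTL default are all assumptions for illustration) adds a lock and drops entries that have not been updated within a time-to-live.

```python
import threading
import time
from typing import Any, Dict, Optional

class JobStore:
    """Hypothetical in-memory job store with expiry; not a Redis replacement."""

    def __init__(self, ttl_seconds: float = 3600.0):
        self._ttl = ttl_seconds
        self._jobs: Dict[str, Dict[str, Any]] = {}
        self._lock = threading.Lock()

    def create(self, job_id: str) -> None:
        self.update(job_id, "pending")

    def update(self, job_id: str, status: str, result: Any = None) -> None:
        with self._lock:
            self._jobs[job_id] = {
                "status": status,
                "result": result,
                "updated": time.monotonic(),
            }

    def get(self, job_id: str) -> Optional[Dict[str, Any]]:
        with self._lock:
            self._prune()
            job = self._jobs.get(job_id)
            if job is None:
                return None
            return {"status": job["status"], "result": job["result"]}

    def _prune(self) -> None:
        # Drop jobs whose last update is older than the TTL
        cutoff = time.monotonic() - self._ttl
        expired = [k for k, v in self._jobs.items() if v["updated"] < cutoff]
        for key in expired:
            del self._jobs[key]
```

Like the dict it replaces, this store is local to one process, so it still does not work across multiple workers.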

Streaming Responses (LLMs)

For LLM token streaming:

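from fastapi.responses import StreamingResponse

# Assumes ChatRequest (a Pydantic model with a message field) and llm
# (a client exposing an async stream() generator) are defined elsewhere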

@router.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    async def generate():
        async for token in llm.stream(request.message):
            yield f"data: {token}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream"
    )

For real-world LLM streaming with proper error handling:

from openai import AsyncOpenAI
import json

client = AsyncOpenAI()

@router.post("/chat")
async def chat(request: ChatRequest):
    async def event_generator():
        try:
            stream = await client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": request.message}],
                stream=True
            )

            async for chunk in stream:
                # Some chunks carry no choices or an empty delta; skip them
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    yield f"data: {json.dumps({'content': content})}\n\n"

            yield f"data: {json.dumps({'done': True})}\n\n"

        except Exception as e:
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )
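The JSON encoding in the handler above matters for framing: a Server-Sent Events message ends at a blank line, so a raw token containing a newline would split the event. A small sketch of the round trip, using hypothetical helper names sse_event and parse_sse:

```python
import json
from typing import Iterable, Iterator

def sse_event(payload: dict) -> str:
    """Frame one JSON payload as a Server-Sent Events message."""
    return f"data: {json.dumps(payload)}\n\n"

def parse_sse(lines: Iterable[str]) -> Iterator[dict]:
    """Decode data lines from an SSE text stream back into payloads."""
    prefix = "data: "
    for line in lines:
        if line.startswith(prefix):
            yield json.loads(line[len(prefix):])

# The second token contains a newline, which JSON escapes as \n
tokens = ["Hello", ",\n", "world"]
wire = "".join(sse_event({"content": t}) for t in tokens)
decoded = [event["content"] for event in parse_sse(wire.splitlines())]
```

Because json.dumps escapes the newline, each event stays on a single data line and the client recovers the tokens unchanged.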

Model Loading Patterns

Lazy Loading

Load models only when first requested:

from fastapi import Depends

class ModelManager:
    def __init__(self):
        self._models = {}

    def get_model(self, model_name: str):
        if model_name not in self._models:
            self._models[model_name] = self._load_model(model_name)
        return self._models[model_name]

    def _load_model(self, model_name: str):
        # Load based on name
        if model_name == "classifier":
            return load_classifier()
        elif model_name == "embedder":
            return load_embedder()
        raise ValueError(f"Unknown model: {model_name}")

model_manager = ModelManager()

# In endpoint
def get_model_manager():
    return model_manager

@router.post("/embed")
async def embed(
    request: EmbedRequest,
    manager: ModelManager = Depends(get_model_manager)
):
    model = manager.get_model("embedder")
    return model.embed(request.text)
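Lazy loading has a race to watch for: if endpoints are declared with plain def, FastAPI runs them in a thread pool, and two first requests can both start loading the same model. The sketch below shows one possible guard using a lock with a second check. LazyModelCache and load_embedder are hypothetical names, and the sleep stands in for reading weights from disk.

```python
import threading
import time
from typing import Any, Callable, Dict

class LazyModelCache:
    """Load each model at most once, even under concurrent first requests."""

    def __init__(self, loaders: Dict[str, Callable[[], Any]]):
        self._loaders = loaders
        self._models: Dict[str, Any] = {}
        self._lock = threading.Lock()

    def get(self, name: str) -> Any:
        if name not in self._loaders:
            raise ValueError(f"Unknown model: {name}")
        if name not in self._models:          # fast path once loaded
            with self._lock:
                if name not in self._models:  # re-check after taking the lock
                    self._models[name] = self._loaders[name]()
        return self._models[name]

load_calls = []

def load_embedder() -> str:
    """Hypothetical slow loader standing in for real weight loading."""
    time.sleep(0.05)
    load_calls.append("embedder")
    return "embedder-model"

cache = LazyModelCache({"embedder": load_embedder})
threads = [threading.Thread(target=cache.get, args=("embedder",)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Eight threads request the model at once, yet the loader runs a single time.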

Multiple Model Versions

Support A/B testing with model versions:

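from typing import Optional

from fastapi import HTTPException, Query

class ModelRegistry:
    def __init__(self):
        # Both versions are loaded eagerly when the registry is created
        self.models = {
            "v1": load_model_v1(),
            "v2": load_model_v2(),
        }
        self.default_version = "v2"

    def predict(self, input_data, version: Optional[str] = None):
        version = version or self.default_version
        model = self.models.get(version)
        if model is None:
            raise ValueError(f"Model version {version} not found")
        return model.predict(input_data)

registry = ModelRegistry()

@router.post("/predict")
async def predict(
    request: PredictionRequest,
    model_version: Optional[str] = Query(default=None)
):
    try:
        return registry.predict(request.image_url, model_version)
    except ValueError as e:
        # An unknown version is a client error, not a server failure
        raise HTTPException(status_code=400, detail=str(e))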
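The endpoint above lets the caller choose a version. For an actual A/B test, the server usually decides which version each user sees and keeps that choice stable. One common approach, sketched here under the assumption that requests carry some user identifier, is to hash the ID into a bucket. The function name assign_version and the 20 percent default are illustrative.

```python
import hashlib

def assign_version(user_id: str, rollout_percent: int = 20) -> str:
    """Map a user ID to "v2" for roughly rollout_percent of users, else "v1"."""
    digest = hashlib.sha256(user_id.encode("utf-8")).digest()
    # The first four bytes give a stable integer, reduced to a bucket 0..99
    bucket = int.from_bytes(digest[:4], "big") % 100
    return "v2" if bucket < rollout_percent else "v1"

# The same user always lands in the same bucket
first = assign_version("user-42")
second = assign_version("user-42")
```

Hashing avoids storing assignments, and raising rollout_percent only moves users from v1 to v2, never the other way.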

Request Validation

Pydantic makes input validation easy:

# Pydantic v2 syntax, matching the pydantic_settings package used for config
from pydantic import BaseModel, Field, field_validator
from typing import List, Optional

class TextGenerationRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=4096)
    max_tokens: int = Field(default=256, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0, le=2)
    # For lists, max_length limits the number of items
    stop_sequences: Optional[List[str]] = Field(default=None, max_length=4)

    @field_validator("prompt")
    @classmethod
    def prompt_not_empty(cls, v):
        if not v.strip():
            raise ValueError("Prompt cannot be empty or whitespace")
        return v.strip()

    @field_validator("stop_sequences")
    @classmethod
    def stop_sequences_valid(cls, v):
        # Check every item; the field itself may be None
        if v is not None and any(len(s) > 20 for s in v):
            raise ValueError("Stop sequence too long")
        return v
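To see what a rejected request looks like before it reaches the model, the sketch below validates a payload directly and collects the names of the failing fields. GenerationParams is a reduced, hypothetical version of the request model, and validation_errors is a helper introduced only for this illustration.

```python
from pydantic import BaseModel, Field, ValidationError

class GenerationParams(BaseModel):
    """Reduced, hypothetical request model for illustration."""
    prompt: str = Field(..., min_length=1, max_length=4096)
    temperature: float = Field(default=0.7, ge=0, le=2)

def validation_errors(payload: dict) -> list:
    """Return the sorted names of fields that fail validation."""
    try:
        GenerationParams(**payload)
    except ValidationError as exc:
        # Each error records the location of the offending field
        return sorted(str(err["loc"][0]) for err in exc.errors())
    return []

ok = validation_errors({"prompt": "Write a haiku"})
bad = validation_errors({"prompt": "", "temperature": 3})
```

FastAPI runs the same validation on the request body and turns a failure into a 422 response listing these fields.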

Error Handling

Consistent error responses:

from fastapi import Request
from fastapi.responses import JSONResponse

class MLException(Exception):
    def __init__(self, message: str, error_code: str, status_code: int = 500):
        # Pass the message to Exception so str(exc) and tracebacks show it
        super().__init__(message)
        self.message = message
        self.error_code = error_code
        self.status_code = status_code

class ModelNotLoadedError(MLException):
    def __init__(self, model_name: str):
        super().__init__(
            message=f"Model {model_name} is not loaded",
            error_code="MODEL_NOT_LOADED",
            status_code=503
        )

class InvalidInputError(MLException):
    def __init__(self, detail: str):
        super().__init__(
            message=detail,
            error_code="INVALID_INPUT",
            status_code=400
        )

@app.exception_handler(MLException)
async def ml_exception_handler(request: Request, exc: MLException):
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.error_code,
            "message": exc.message,
            # Report only the path, not the full URL with host and query
            "path": request.url.path
        }
    )

Health Checks

Essential for production deployments:

from enum import Enum

import torch
from pydantic import BaseModel

class HealthStatus(str, Enum):
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"

class HealthResponse(BaseModel):
    status: HealthStatus
    model_loaded: bool
    gpu_available: bool
    version: str

@router.get("/health", response_model=HealthResponse)
async def health_check():
    model_loaded = "classifier" in ml_models
    gpu_available = torch.cuda.is_available()

    if model_loaded and gpu_available:
        status = HealthStatus.HEALTHY
    elif model_loaded:
        status = HealthStatus.DEGRADED
    else:
        status = HealthStatus.UNHEALTHY

    return HealthResponse(
        status=status,
        model_loaded=model_loaded,
        gpu_available=gpu_available,
        version=app.version
    )

# Kubernetes-style probes
@router.get("/health/live")
async def liveness():
    return {"status": "alive"}

@router.get("/health/ready")
async def readiness():
    if "classifier" not in ml_models:
        raise HTTPException(status_code=503, detail="Model not ready")
    return {"status": "ready"}

Rate Limiting

Protect your API from abuse:

from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

@router.post("/predict")
@limiter.limit("10/minute")
async def predict(request: Request, data: PredictionRequest):
    # Rate limited to 10 requests per minute per IP
    return await run_prediction(data)
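To make the meaning of a limit such as 10/minute concrete, here is a small sliding-window counter keyed by client address. It is an illustrative sketch only and does not describe how slowapi is implemented. SlidingWindowLimiter is a hypothetical class, and the now parameter exists so the logic can be exercised without waiting.

```python
import time
from collections import defaultdict, deque
from typing import Deque, Dict, Optional

class SlidingWindowLimiter:
    """Allow at most `limit` requests per key within `window_seconds`."""

    def __init__(self, limit: int, window_seconds: float):
        self.limit = limit
        self.window = window_seconds
        self._hits: Dict[str, Deque[float]] = defaultdict(deque)

    def allow(self, key: str, now: Optional[float] = None) -> bool:
        now = time.monotonic() if now is None else now
        hits = self._hits[key]
        # Forget requests that have left the window
        while hits and now - hits[0] >= self.window:
            hits.popleft()
        if len(hits) >= self.limit:
            return False
        hits.append(now)
        return True
```

State kept this way is local to one process, so a multi-worker deployment needs shared storage for accurate limits.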

Middleware

Add observability:

import time
import logging
from uuid import uuid4

from fastapi import Request

logger = logging.getLogger(__name__)

@app.middleware("http")
async def logging_middleware(request: Request, call_next):
    request_id = str(uuid4())
    start_time = time.perf_counter()

    # Add request ID to state
    request.state.request_id = request_id

    response = await call_next(request)

    duration = time.perf_counter() - start_time

    logger.info(
        "Request completed",
        extra={
            "request_id": request_id,
            "method": request.method,
            "path": request.url.path,
            "status_code": response.status_code,
            "duration_ms": round(duration * 1000, 2)
        }
    )

    response.headers["X-Request-ID"] = request_id
    return response
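Python's default log formatter prints only the message, so the fields passed through extra never appear in the output unless a formatter reads them. A minimal sketch of a JSON formatter that does, assuming the field names used by the middleware; the JsonFormatter class and the logger name are hypothetical.

```python
import io
import json
import logging

class JsonFormatter(logging.Formatter):
    """Render selected extra fields alongside the message as one JSON line."""

    FIELDS = ("request_id", "method", "path", "status_code", "duration_ms")

    def format(self, record: logging.LogRecord) -> str:
        payload = {"level": record.levelname, "message": record.getMessage()}
        for field in self.FIELDS:
            # Values passed via extra become attributes on the record
            if hasattr(record, field):
                payload[field] = getattr(record, field)
        return json.dumps(payload)

# Write to an in-memory buffer here; production would use stdout
buffer = io.StringIO()
handler = logging.StreamHandler(buffer)
handler.setFormatter(JsonFormatter())

demo_logger = logging.getLogger("ml_api.requests")
demo_logger.setLevel(logging.INFO)
demo_logger.addHandler(handler)
demo_logger.propagate = False

demo_logger.info(
    "Request completed",
    extra={"request_id": "abc123", "method": "POST",
           "path": "/api/v1/predict", "status_code": 200, "duration_ms": 12.5},
)
logged = json.loads(buffer.getvalue())
```

One JSON object per line is easy for log aggregators to index by request ID or status code.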

Deployment

Docker

FROM python:3.11-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application
COPY app/ app/

# Non-root user
RUN useradd -m appuser && chown -R appuser:appuser /app
USER appuser

# Run with uvicorn
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

Production Configuration

# app/config.py
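from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    # Narrow protected_namespaces so the model_path field does not
    # trigger Pydantic's "model_" namespace warning
    model_config = SettingsConfigDict(
        env_file=".env",
        protected_namespaces=("settings_",),
    )

    app_name: str = "ML API"
    debug: bool = False
    model_path: str = "/models/classifier.pt"
    max_batch_size: int = 32
    request_timeout: int = 30
    workers: int = 4

settings = Settings()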
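The settings define a request_timeout but the article does not show where it is applied. One plausible use, sketched here with hypothetical names (InferenceTimeout, run_with_timeout, fake_inference), is to wrap the awaited inference call in asyncio.wait_for and convert a timeout into an error the endpoint can map to a 504 response.

```python
import asyncio

class InferenceTimeout(Exception):
    """Hypothetical error an endpoint could translate into a 504 response."""

async def run_with_timeout(coro, timeout_seconds: float):
    try:
        return await asyncio.wait_for(coro, timeout=timeout_seconds)
    except asyncio.TimeoutError:
        raise InferenceTimeout(f"Inference exceeded {timeout_seconds}s") from None

async def fake_inference(delay: float) -> str:
    """Stand-in for an awaitable model call."""
    await asyncio.sleep(delay)
    return "ok"

async def demo():
    fast = await run_with_timeout(fake_inference(0.01), timeout_seconds=0.5)
    try:
        await run_with_timeout(fake_inference(0.5), timeout_seconds=0.05)
        slow = "finished"
    except InferenceTimeout:
        slow = "timed out"
    return fast, slow

fast_result, slow_result = asyncio.run(demo())
```

wait_for cancels the awaiting task, but it cannot interrupt blocking work already running in a thread, so the timeout bounds the response time rather than the compute.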

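Run with Gunicorn for production. Each worker is a separate process that loads its own copy of the model, so size the worker count to the available memory: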

gunicorn app.main:app \
    --workers 4 \
    --worker-class uvicorn.workers.UvicornWorker \
    --bind 0.0.0.0:8000 \
    --timeout 120 \
    --keep-alive 5

Testing

from fastapi.testclient import TestClient
from app.main import app
import pytest

@pytest.fixture(scope="module")
def client():
    # The context manager runs the lifespan handler, so the model is loaded
    with TestClient(app) as c:
        yield c

def test_health_check(client):
    response = client.get("/health")
    assert response.status_code == 200
    # "degraded" is expected on machines without a GPU
    assert response.json()["status"] in ("healthy", "degraded")

def test_prediction(client):
    response = client.post(
        "/api/v1/predict",
        json={"image_url": "https://example.com/cat.jpg"}
    )
    assert response.status_code == 200
    assert "predictions" in response.json()

def test_invalid_input(client):
    # A missing required field fails Pydantic validation
    response = client.post(
        "/api/v1/predict",
        json={}
    )
    assert response.status_code == 422  # Validation error

Conclusion

FastAPI covers most of what a production ML API needs: validation, streaming, background work, and testing. Start simple, add complexity as needed, and always prioritize reliability and observability. Your users don't care how clever your API is; they care that it works, every time.