Deep Learning Frameworks

Purpose

Deep learning frameworks provide:

  • Automatic differentiation — compute gradients of any computation without manual derivation
  • Hardware abstraction — run the same code on CPU, GPU, or TPU
  • Optimized primitives — BLAS, cuDNN, and XLA kernels for tensor operations
  • Ecosystem — dataloaders, model zoos, deployment tooling

The three dominant frameworks are PyTorch, TensorFlow/Keras, and JAX.

Architecture

All frameworks build on the same core abstraction: a tensor (n-dimensional array on a device) + a computation graph that records operations for gradient computation via reverse-mode autodiff.
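To make the abstraction concrete, here is a minimal PyTorch sketch of reverse-mode autodiff: operations on a tensor with requires_grad=True are recorded in a graph, and .backward() traverses that graph in reverse to populate gradients.

```python
import torch

# y = x^2 + 3x, so dy/dx = 2x + 3 = 7 at x = 2.
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2 + 3 * x       # each op is recorded in the graph
y.backward()             # reverse-mode pass populates x.grad
print(x.grad)            # tensor(7.)
```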

|                     | PyTorch                                | TensorFlow 2 / Keras            | JAX                          |
|---------------------|----------------------------------------|---------------------------------|------------------------------|
| Graph style         | Dynamic (eager by default)             | Eager + @tf.function static     | Functional, JIT via jit      |
| Autodiff            | autograd via .backward()               | GradientTape                    | jax.grad, jax.value_and_grad |
| Primary abstraction | nn.Module                              | keras.Model                     | Pure functions + pytrees     |
| Parallelism         | DataParallel, DistributedDataParallel  | MirroredStrategy                | jax.pmap, jax.vmap           |
| Deployment          | TorchScript, ONNX                      | SavedModel, TFLite, TF Serving  | XLA-compiled HLO             |

Dynamic vs. static graphs: PyTorch builds the graph on the fly during each forward pass, so models can be debugged with standard Python tools (pdb, print). TensorFlow’s @tf.function and JAX’s jit instead trace and compile a static graph for performance, at the cost of tracing overhead and constraints on Python control flow: branches on traced values must go through framework ops rather than Python if statements.
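A minimal JAX sketch of that tracing constraint (the function here is illustrative): under jit the function sees abstract tracers, not concrete values, so a Python-level `if x > 0:` would fail, while jnp.where compiles fine.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Data-dependent branching via a traceable op, not Python `if`.
    return jnp.where(x > 0, x, -x)

fast_f = jax.jit(f)                      # traced once, then runs compiled
out = fast_f(jnp.array([-2.0, 3.0]))
print(out)                               # [2. 3.]
```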

Implementation Notes

PyTorch — core patterns

Define a model:

import torch
import torch.nn as nn
 
class MLP(nn.Module):
    def __init__(self, in_dim, hidden, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )
 
    def forward(self, x):
        return self.net(x)

Standard training loop:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MLP(784, 256, 10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
 
for epoch in range(num_epochs):
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()          # populate .grad on all parameters
        optimizer.step()         # update parameters

Key APIs:

| Task                        | API                                                              |
|-----------------------------|------------------------------------------------------------------|
| Move to GPU                 | tensor.to(device) / model.to(device)                             |
| Disable gradient tracking   | torch.no_grad() context manager                                  |
| Save / load checkpoint      | torch.save(model.state_dict(), path) / model.load_state_dict(...) |
| Freeze layers               | param.requires_grad = False                                      |
| Inspect parameter count     | sum(p.numel() for p in model.parameters())                       |
| Data pipeline               | Dataset + DataLoader with num_workers                            |
| Compile for speed           | torch.compile(model) (PyTorch 2+, TorchInductor with Triton kernels) |
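A minimal sketch combining several of the APIs above, checkpointing via state_dict and inference under torch.no_grad(); an in-memory buffer stands in for a checkpoint path.

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Save only the state_dict (recommended over pickling the whole module).
buf = io.BytesIO()                    # stand-in for a file path
torch.save(model.state_dict(), buf)
buf.seek(0)

restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buf))
restored.eval()                       # switch dropout/batchnorm to eval mode

with torch.no_grad():                 # skip graph construction at inference
    out = restored(torch.randn(1, 4))
print(out.shape)                      # torch.Size([1, 2])
```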

Gradient utilities:

# Gradient clipping (important for RNNs)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
 
# Manual gradient inspection
for name, param in model.named_parameters():
    if param.grad is not None:
        print(name, param.grad.norm())

TensorFlow / Keras — core patterns

model = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(10),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(train_ds, epochs=10, validation_data=val_ds)

Use GradientTape for custom training loops:

with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss = loss_fn(y, logits)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
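An end-to-end sketch of such a custom step: wrapping the tape logic in @tf.function compiles it to a static graph after a one-time trace. The model, optimizer, and shapes here are illustrative.

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
optimizer = tf.keras.optimizers.Adam(1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function                          # trace once, then run the static graph
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

x = tf.random.normal((8, 4))
y = tf.zeros((8,), dtype=tf.int32)    # dummy labels
loss = train_step(x, y)
```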

JAX — core patterns

JAX is functional: models are plain Python functions; state is explicit.

import jax
import jax.numpy as jnp
 
def loss_fn(params, x, y):
    logits = model.apply(params, x)   # Flax / Haiku style
    return cross_entropy(logits, y)
 
grad_fn = jax.value_and_grad(loss_fn)
loss, grads = grad_fn(params, x, y)
params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

Use jax.jit for compilation, jax.vmap for batching, jax.pmap for multi-device.
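A minimal sketch of these transforms composing (the loss function and data are illustrative): jax.vmap maps a per-example function over the batch axis without manual reshaping, and jax.jit compiles the batched result.

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, x, y):
    # Scalar squared error for a single example.
    pred = jnp.dot(w, x)
    return (pred - y) ** 2

w = jnp.ones(3)
xs = jnp.arange(6.0).reshape(2, 3)    # batch of 2 examples
ys = jnp.array([3.0, 12.0])

# Map over the leading axis of xs and ys; broadcast w (in_axes=None).
batched = jax.vmap(per_example_loss, in_axes=(None, 0, 0))
losses = jax.jit(batched)(w, xs, ys)
print(losses)                         # [0. 0.]
```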

Trade-offs

| Criterion                 | PyTorch                         | TensorFlow 2                           | JAX                                      |
|---------------------------|---------------------------------|----------------------------------------|------------------------------------------|
| Ease of debugging         | ✅ Native Python                | ⚠️ Graph tracing errors can be opaque  | ⚠️ Functional style has a learning curve |
| Research flexibility      | ✅ Industry standard            | ✅ Good                                | ✅ Best for custom autodiff              |
| Production deployment     | ✅ TorchScript/ONNX             | ✅ TF Serving, TFLite                  | ⚠️ Less mature tooling                   |
| Performance               | torch.compile competitive       | ✅ XLA strong                          | ✅ XLA best-in-class on TPU              |
| Ecosystem / model zoos    | ✅ Largest (HuggingFace, timm)  | ✅ Large                               | ⚠️ Smaller but growing (Flax, Optax)     |
| Multi-device parallelism  | ✅ DDP mature                   | ✅ Strategy API                        | pmap elegant                             |

Framework selection heuristic:

  • Default to PyTorch — largest research community, HuggingFace ecosystem, strong production story
  • Use TensorFlow/Keras when deploying to mobile/edge (TFLite) or when TF Serving is already in production
  • Use JAX for research requiring custom autodiff, TPU-heavy workloads, or functional programming style (e.g., DeepMind, Google Brain work)
