Concept: Why call model(), not model.forward()
Core Intuition
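- Calling model(x) invokes nn.Module.__call__, a wrapper around forward() that runs registered hooks (forward pre-hooks, forward hooks, and full backward hooks), dispatches to the compiled version after model.compile(), and gives debugging, profiling, and distributed tooling a single entry point to intercept.
- Calling model.forward(x) directly bypasses all of this infrastructure.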
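A minimal check of the difference, assuming PyTorch is installed: a forward hook registered on an nn.Linear layer fires when the layer is called as layer(x), but not when layer.forward(x) is called directly. The hook function count_calls and the calls list are illustrative names, not part of PyTorch.

```python
import torch
from torch import nn

layer = nn.Linear(3, 2)
calls = []  # one entry per time the hook fires (illustrative)

def count_calls(module, inputs, output):
    # Forward hook signature: (module, positional inputs tuple, output)
    calls.append(output.shape)

handle = layer.register_forward_hook(count_calls)
x = torch.randn(4, 3)

layer(x)          # goes through nn.Module.__call__ -> hook fires
layer.forward(x)  # bypasses __call__ -> hook is silently skipped

print(len(calls))  # 1
handle.remove()    # detach the hook when it is no longer needed
```

Nothing warns you about the skipped hook: the direct call returns a perfectly normal tensor, which is what makes this mistake hard to spot.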
Analogy
Think of __call__ as the interface designed for the user, and forward() as the implementation detail meant for the developer. Using model(x) ensures the entire ecosystem of PyTorch features remains fully functional.
Component of
Insights
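- Never call model.forward(x) directly in user code; always use model(x).
- Hooks will not fire: any forward pre-hooks, forward hooks, or full backward hooks registered on the module are silently skipped, breaking debugging, logging, and feature-extraction tools built on them.
- Other wrapper logic is skipped: anything else __call__ does around forward() is bypassed too, for example dispatching to the compiled version after model.compile(). Train/eval mode is not affected: model.train() and model.eval() set the training flag, which layers such as Dropout and BatchNorm read inside their own forward().
- Distributed issues: distributed wrappers hook into the module call path. With DDP, the forward pass must go through the DDP-wrapped model so the wrapper can prepare gradient synchronization between GPUs for the backward pass, and distributed APIs that register their logic as module hooks are skipped outright by a direct forward() call.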
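Two of these points can be checked directly, assuming PyTorch is installed. The sketch below shows that a full backward hook (registered with register_full_backward_hook) only fires when the forward pass went through __call__, and, as a contrast, that Dropout's eval-mode behavior does not depend on __call__, because the training flag is read inside forward(). The names on_backward, seen_grads, and same_in_eval are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
seen_grads = []  # one entry per backward-hook invocation

def on_backward(module, grad_input, grad_output):
    # Full backward hook: receives gradients w.r.t. module inputs and outputs
    seen_grads.append(grad_output[0].shape)

model.register_full_backward_hook(on_backward)
x = torch.randn(2, 4, requires_grad=True)

model(x).sum().backward()          # __call__ wires up the hook -> it fires
model.forward(x).sum().backward()  # no wiring -> the hook never fires
print(len(seen_grads))  # 1

# The train/eval switch is NOT handled by __call__: Dropout reads
# self.training inside its own forward(), so both paths agree in eval mode.
model.eval()
with torch.no_grad():
    same_in_eval = torch.equal(model(x), model.forward(x))
print(same_in_eval)  # True
```

Note that gradients still reach x on both backward passes: autograd records the tensor operations inside forward() either way, and only the module-level hook is lost.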
Connections
- PyTorch Hooks todo (forward_pre_hooks, forward_hooks)
- Autograd and Computation Graph todo
- DistributedDataParallel (DDP) todo
- model.train() and model.eval() todo
- torch.compile todo
Implementation Notes
Simple example to demonstrate __call__ vs. forward (illustrative plain class, not actual nn.Module):
```python
class MyModel:
    def __init__(self, name):
        self.name = name

    def forward(self, x):
        """
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        """
        return x * 2

    def __call__(self, x):
        """
        The entry point for module execution.
        Handles hooks, pre-processing, and post-processing logic
        before and after invoking the forward pass.
        """
        print(f"Running pre-forward hooks for {self.name}")
        # Invoke the core implementation: forward().
        result = self.forward(x)
        print(f"Running post-forward hooks for {self.name}")
        return result


# Simple usage example
model = MyModel("TestModel")
output = model(10)  # This triggers __call__ internally
```

nn.Module.__call__ in torch/nn/modules/module.py performs several steps before and after invoking forward():
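- forward_pre_hooks: executed before forward(); useful for logging, modifying inputs, or inspecting data shapes.
- forward_hooks: executed after forward(); commonly used to capture activations, perform feature extraction, or debug layers.
- Backward hooks: if full backward hooks are registered, __call__ wraps the module's inputs and outputs so those hooks fire during backpropagation. The autograd graph itself is recorded by the tensor operations inside forward(), so gradients are computed either way; only the module-level hooks are lost.
- Compiled dispatch: after model.compile(), __call__ routes to the compiled version of the module; a direct forward() call runs the uncompiled Python code.
- Profiling and tracing: torch.jit.trace goes through this path to record module scopes in the traced graph, and module-level profilers and loggers typically observe module boundaries through this wrapper.

Two behaviors sometimes attributed to __call__ actually live elsewhere:
- Train/eval mode: the training flag is set by model.train() and model.eval() and read inside forward() by layers such as BatchNorm (running statistics) and Dropout, so it behaves the same whether or not __call__ is used.
- Distributed communication: __call__ does not itself issue primitives such as AllReduce. DDP performs gradient AllReduce during the backward pass through hooks set up by its wrapper, which is why the forward pass should go through the DDP-wrapped model.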
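The ordering of the two hook stages can be sketched in plain Python by extending the illustrative class above with hook registries. This is a simplified emulation, not PyTorch's code: MiniModule and Doubler are hypothetical names, and the real nn.Module.__call__ handles more cases (keyword arguments, global hooks, backward hooks). It mimics the documented behavior that a forward pre-hook may return replacement inputs and a forward hook may return a replacement output.

```python
from typing import Any, Callable, List


class MiniModule:
    """Hypothetical, simplified stand-in for nn.Module's call path."""

    def __init__(self) -> None:
        self._pre_hooks: List[Callable] = []
        self._post_hooks: List[Callable] = []

    def register_forward_pre_hook(self, hook: Callable) -> None:
        self._pre_hooks.append(hook)

    def register_forward_hook(self, hook: Callable) -> None:
        self._post_hooks.append(hook)

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError

    def __call__(self, *args: Any) -> Any:
        # 1. Pre-hooks: a non-None return value replaces the inputs
        #    (a single value is wrapped into a tuple).
        for hook in self._pre_hooks:
            new_args = hook(self, args)
            if new_args is not None:
                args = new_args if isinstance(new_args, tuple) else (new_args,)
        # 2. The user-defined computation.
        result = self.forward(*args)
        # 3. Forward hooks: a non-None return value replaces the output.
        for hook in self._post_hooks:
            new_result = hook(self, args, result)
            if new_result is not None:
                result = new_result
        return result


class Doubler(MiniModule):
    def forward(self, x: int) -> int:
        return x * 2


m = Doubler()
m.register_forward_pre_hook(lambda mod, args: (args[0] + 1,))  # edit the input
m.register_forward_hook(lambda mod, args, out: out * 10)       # edit the output

print(m(3))          # (3 + 1) * 2 * 10 = 80
print(m.forward(3))  # hooks skipped: 3 * 2 = 6
```

Returning None from a hook leaves the value unchanged, which is why both loops only replace args or result on a non-None return. The gap between 80 and 6 shows how a direct forward() call can silently change results when hooks transform data rather than just observe it.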