Chapter 7: Initialization
This is what our folder structure currently looks like. In this chapter we will work inside babygrad/init.py.
project/
├─ .venv/
├─ babygrad/
│ ├─ __init__.py
│ ├─ init.py
│ ├─ ops.py
│ ├─ tensor.py
│ ├─ nn.py
│ └─ optim.py
├─ examples/
│ └─ simple_mnist.py
└─ tests/
Up until now, we've been initializing our weights randomly before training. This seems perfectly reasonable: after all, gradient descent is supposed to find good weights by the end of training.
So why should we even care about how we initialize the weights? Won't training just fix any poor initial choices anyway?
This is a critical question. The answer is that while training finds the destination, initialization determines the starting point of the journey. A poor start can make that journey painfully slow or unstable, or even stop it before it begins.
If we ignore initialization, we risk running into one of three crippling problems:
The Symmetry Problem (Identical Weights): If all weights start with the same value, every neuron in a layer computes the same output and receives the same gradient, so the neurons stay identical throughout training and learn the exact same feature. It's like having a team where every member is a perfect clone doing the same job, which defeats the purpose of having a team in the first place.
The Vanishing Gradient Problem (Weights Too Small): When weights are consistently too small, the signal (and its gradient) shrinks as it passes through each layer. By the time the gradient reaches the early layers, it is so faint that they receive almost no meaningful updates, and the model barely learns.
The Exploding Gradient Problem (Weights Too Large): The opposite occurs when weights are too large. The signal grows exponentially with each layer until it becomes massive. This causes huge, unstable updates during training that make the loss diverge, often showing up as NaN values in the output.
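Here is a small experiment that makes the vanishing and exploding cases concrete. It is a framework-free sketch in plain Python, not babygrad code; the helper name signal_rms_after_layers, the width, the depth, and the three weight scales are arbitrary choices for illustration. There is no activation function, so only the weight scale matters.

```python
import math
import random


def signal_rms_after_layers(weight_std, width=100, depth=10, seed=0):
    """Push one random input through `depth` random linear layers (no activation)
    and return the root-mean-square (RMS) of the final output."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    for _ in range(depth):
        # A fresh width x width weight matrix per layer, entries ~ N(0, weight_std^2).
        w = [[rng.gauss(0.0, weight_std) for _ in range(width)] for _ in range(width)]
        x = [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]
    return math.sqrt(sum(v * v for v in x) / width)


width = 100
for label, std in [("too small", 0.01), ("too large", 1.0), ("balanced", 1 / math.sqrt(width))]:
    print(f"{label:>9}: std={std:.3f} -> output RMS {signal_rms_after_layers(std, width):.3e}")
```

Each layer multiplies the size of the signal by roughly weight_std * sqrt(width), so the first run shrinks by about 10x per layer, the second grows by about 10x per layer, and only std = 1 / sqrt(width) keeps the output on the order of 1. Gradients flowing backward through the same layers are scaled in the same way.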
What is a Uniform Distribution?
A distribution where every value between a minimum value (a) and a maximum value (b) is equally likely: the probability density is constant on [a, b] and zero outside it. We write it as U(a, b).
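What is a Normal Distribution?
A distribution shaped like a "bell curve" where values are centered around an average (the mean). Values become progressively less likely the further they are from the mean. The width or "spread" of the curve is controlled by the standard deviation (std). We write N(mean, std^2) for a normal distribution with that mean and standard deviation; the second argument is the variance, std^2.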
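Both definitions can be checked empirically with Python's standard random module, where random.uniform(a, b) samples from U(a, b) and random.gauss(mu, sigma) samples from a normal distribution with mean mu and standard deviation sigma. This is only an illustration in plain Python; in babygrad you would use Tensor.rand and Tensor.randn instead. The bounds, seed, and sample count are arbitrary.

```python
import random
import statistics

rng = random.Random(42)  # fixed seed so the run is reproducible
n_samples = 10_000

# Uniform U(a, b): every value in [a, b] is equally likely, nothing falls outside.
a, b = -2.0, 2.0
uniform_samples = [rng.uniform(a, b) for _ in range(n_samples)]

# Normal N(mean, std^2): values cluster around the mean, spread set by std.
mean, std = 0.0, 1.0
normal_samples = [rng.gauss(mean, std) for _ in range(n_samples)]

print("uniform min/max:", min(uniform_samples), max(uniform_samples))
# The variance of U(a, b) is (b - a)^2 / 12, here 16 / 12.
print("uniform variance:", statistics.pvariance(uniform_samples))
print("normal mean/std:", statistics.fmean(normal_samples), statistics.pstdev(normal_samples))

# About 68% of normal samples land within one std of the mean.
within_one_std = sum(abs(x - mean) <= std for x in normal_samples) / n_samples
print("fraction within 1 std:", within_one_std)
```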
So what shall we do?
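We need to choose an initial scale for the weights that is neither too big nor too small, so that the signal and its gradient keep roughly the same size as they pass from layer to layer.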
But how?
7.1 Xavier
7.1.1 Xavier Uniform
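W ~ U(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out)))
Xavier (also called Glorot) initialization aims to keep the variance of the activations roughly the same from layer to layer in the forward pass, and the variance of the gradients roughly the same in the backward pass. Using both fan_in and fan_out is a compromise between these two goals. In the exercise below, the bound is also multiplied by an optional gain factor, which defaults to 1.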
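Where does the 6 come from? Taking the Xavier target variance 2 / (fan_in + fan_out) (the same one the normal version uses), a short derivation based on the variance of U(-a, a):

```latex
\operatorname{Var}(W) = \frac{1}{2a}\int_{-a}^{a} w^{2}\,dw = \frac{a^{2}}{3},
\qquad
\frac{a^{2}}{3} = \frac{2}{\text{fan\_in} + \text{fan\_out}}
\;\Longrightarrow\;
a = \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}
```

Multiplying a by gain multiplies the variance by gain^2.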
What are fan_in and fan_out?
These terms refer to the number of input and output connections to a neuron in a layer. For a standard Linear (fully-connected) layer, the definitions are straightforward:
fan_in: The number of input features to the layer.
fan_out: The number of output features from the layer.
For example, in a layer defined as Linear(in_features=784, out_features=256), the fan_in is 784 and the fan_out is 256.
File: babygrad/init.py
Exercise 7.1: Implement the xavier_uniform function.
import math

from .tensor import Tensor  # assumption: Tensor is defined in babygrad/tensor.py


def xavier_uniform(fan_in: int, fan_out: int, gain: float = 1.0, **kwargs):
    """
    The weights are drawn from a uniform distribution U[-a, a], where
    a = gain * sqrt(6 / (fan_in + fan_out)).
    """
    # Compute `a`, then sample from U(-a, a) with Tensor.rand and return the Tensor.
    # Your code here
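To check your understanding, here is a framework-free sketch of the same formula. It is not the babygrad solution: the hypothetical xavier_uniform_list returns a plain nested Python list instead of a Tensor, and laying it out as fan_in rows by fan_out columns is an assumption made for this illustration.

```python
import math
import random


def xavier_uniform_list(fan_in, fan_out, gain=1.0, rng=random):
    """Return a fan_in x fan_out nested list with entries from U(-a, a),
    where a = gain * sqrt(6 / (fan_in + fan_out))."""
    a = gain * math.sqrt(6 / (fan_in + fan_out))
    return [[rng.uniform(-a, a) for _ in range(fan_out)] for _ in range(fan_in)]


# Usage with the Linear(784, 256) example: a = sqrt(6 / 1040).
weights = xavier_uniform_list(784, 256, rng=random.Random(0))
bound = math.sqrt(6 / (784 + 256))
print(f"bound a = {bound:.4f}")
print("largest |w| sampled:", max(abs(w) for row in weights for w in row))
```

For this layer a is about 0.076, so every weight starts inside [-0.076, 0.076].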
7.1.2 Xavier Normal
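For a normal distribution the goal is the same, but instead of defining a hard limit, we set the standard deviation (std) of the distribution:
W ~ N(0, std^2), where std = gain * sqrt(2 / (fan_in + fan_out))
With the default gain = 1, this is a variance of 2 / (fan_in + fan_out).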
File: babygrad/init.py
Exercise 7.2: Implement the xavier_normal function.
def xavier_normal(fan_in: int, fan_out: int, gain: float = 1.0, **kwargs):
    """
    Xavier normal initialization.
    Calls Tensor.randn() with the correct standard deviation.
    """
    # std = gain * sqrt(2 / (fan_in + fan_out))
    # Your code here: sample with Tensor.randn and return the Tensor.
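A plain-Python sketch of the normal variant, useful for confirming that the sampled weights really have the intended spread. The function name xavier_normal_list, the nested-list output, and the fan_in x fan_out layout are assumptions for illustration, not babygrad's API.

```python
import math
import random
import statistics


def xavier_normal_list(fan_in, fan_out, gain=1.0, rng=random):
    """Return a fan_in x fan_out nested list with entries from N(0, std^2),
    where std = gain * sqrt(2 / (fan_in + fan_out))."""
    std = gain * math.sqrt(2 / (fan_in + fan_out))
    return [[rng.gauss(0.0, std) for _ in range(fan_out)] for _ in range(fan_in)]


weights = xavier_normal_list(300, 100, rng=random.Random(0))
flat = [w for row in weights for w in row]
target_std = math.sqrt(2 / (300 + 100))
print("target std:", target_std, "empirical std:", statistics.pstdev(flat))
```

Passing gain=2.0 doubles the standard deviation, which is how the gain lets you rescale the same recipe.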
7.2 Kaiming
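While Xavier initialization works well for tanh and sigmoid activations, the ReLU activation function needs a different approach. ReLU sets every negative input to zero, so for zero-centered inputs it roughly halves the mean square of the signal at each layer, and with Xavier-scaled weights the signal keeps shrinking as the network gets deeper. Kaiming (also called He) initialization compensates by doubling the weight variance and basing it on fan_in alone: Var(W) = 2 / fan_in.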
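The factor of 2 can be checked with a small framework-free experiment in plain Python (not babygrad). It stacks square layers of width n, each a random linear map followed by ReLU, and compares weights with variance 1/n (what Xavier gives when fan_in = fan_out = n, since 2 / (n + n) = 1 / n) against variance 2/n (the Kaiming choice). The helper name, width, depth, and seed are arbitrary.

```python
import math
import random


def relu_stack_rms(weight_var, width=100, depth=10, seed=0):
    """Feed a random input through `depth` layers of (random linear map -> ReLU)
    and return the root-mean-square of the last layer's pre-activations."""
    rng = random.Random(seed)
    std = math.sqrt(weight_var)
    h = [rng.gauss(0.0, 1.0) for _ in range(width)]
    y = h
    for _ in range(depth):
        w = [[rng.gauss(0.0, std) for _ in range(width)] for _ in range(width)]
        y = [sum(w_ij * h_j for w_ij, h_j in zip(row, h)) for row in w]
        h = [max(0.0, v) for v in y]  # ReLU: negative values become 0
    return math.sqrt(sum(v * v for v in y) / width)


n = 100
print("Xavier-style,  Var(W) = 1/n:", relu_stack_rms(1 / n, n))
print("Kaiming-style, Var(W) = 2/n:", relu_stack_rms(2 / n, n))
```

With Var(W) = 1/n the mean square of the signal roughly halves at every ReLU, so after 10 layers the RMS has fallen to a few percent of its starting value; with Var(W) = 2/n it stays on the order of 1.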
7.2.1 Kaiming Uniform
W ~ U(-sqrt(6 / fan_in), sqrt(6 / fan_in))
File: babygrad/init.py
Exercise 7.3: Implement the kaiming_uniform function.
def kaiming_uniform(fan_in: int, fan_out: int, **kwargs):
    """
    Kaiming uniform initialization.
    """
    # bound = sqrt(6 / fan_in); sample from U(-bound, bound) with Tensor.rand.
    # Your code here
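This plain-Python sketch shows that fan_out only sets the shape of the Kaiming uniform weights, and checks that the bound sqrt(6 / fan_in) gives a variance of 2 / fan_in. The name kaiming_uniform_list and its nested-list, fan_in x fan_out output are assumptions for illustration rather than babygrad's API.

```python
import math
import random
import statistics


def kaiming_uniform_list(fan_in, fan_out, rng=random):
    """Return a fan_in x fan_out nested list with entries from U(-bound, bound),
    where bound = sqrt(6 / fan_in). fan_out only sets the shape."""
    bound = math.sqrt(6 / fan_in)
    return [[rng.uniform(-bound, bound) for _ in range(fan_out)] for _ in range(fan_in)]


weights = kaiming_uniform_list(256, 128, rng=random.Random(0))
flat = [w for row in weights for w in row]
# U(-bound, bound) has variance bound^2 / 3 = (6 / fan_in) / 3 = 2 / fan_in.
print("empirical variance:", statistics.pvariance(flat), "target:", 2 / 256)
```

This is the same variance as the Kaiming normal formula, just as both Xavier variants share the variance 2 / (fan_in + fan_out).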
7.2.2 Kaiming Normal
W ~ N(0, std^2), where std = sqrt(2 / fan_in)
File: babygrad/init.py
Exercise 7.4: Implement the kaiming_normal function.
def kaiming_normal(fan_in: int, fan_out: int, **kwargs):
    """
    Kaiming normal initialization.
    """
    # std = sqrt(2 / fan_in); sample with Tensor.randn.
    # Your code here
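Finally, a plain-Python sketch of Kaiming normal that prints how the target standard deviation shrinks as fan_in grows while fan_out stays fixed. The function name kaiming_normal_list and its nested-list, fan_in x fan_out output are assumptions for this illustration, not the babygrad implementation.

```python
import math
import random
import statistics


def kaiming_normal_list(fan_in, fan_out, rng=random):
    """Return a fan_in x fan_out nested list with entries from N(0, std^2),
    where std = sqrt(2 / fan_in). fan_out only sets the shape."""
    std = math.sqrt(2 / fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_out)] for _ in range(fan_in)]


# The target std depends only on fan_in: more inputs per neuron, smaller weights.
for fan_in in (64, 256, 1024):
    flat = [w for row in kaiming_normal_list(fan_in, 32, rng=random.Random(fan_in)) for w in row]
    print(f"fan_in={fan_in:5d}  target std={math.sqrt(2 / fan_in):.4f}  empirical={statistics.pstdev(flat):.4f}")
```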
Original: zekcrates/initialization