Chapter 9: Trainer
Almost every training loop we write looks something like this:
model = MyModel()
optimizer = MyOptimizer()
dataloader = MyDataloader()

for epoch in range(10):
    for batch in dataloader:
        optimizer.zero_grad()
        pred = model(batch.x)
        loss = loss_fn(pred, batch.y)
        loss.backward()
        optimizer.step()
This works fine, but it becomes repetitive. Every model and every experiment requires writing the same boilerplate. To make our workflow cleaner, we can wrap this logic inside a Trainer class.
Note: This is what our folder structure currently looks like. In this chapter we will work inside babygrad/trainer.py.
project/
├─ .venv/
├─ babygrad/
│ ├─ trainer.py
│ ├─ __init__.py
│ ├─ data.py
│ ├─ init.py
│ ├─ ops.py
│ ├─ tensor.py
│ ├─ nn.py
│ └─ optim.py
├─ examples/
│ └─ simple_mnist.py
└─ tests/
What does the Trainer class need? It needs everything the training loop above uses:
- Model
- Loss function
- Optimizer
- Dataloaders (train, and optionally validation)
The goal is an interface that looks like this:

trainer = Trainer(model, optimizer, loss_fn, train_loader, val_loader=test_loader)
print("Starting Training...")
trainer.fit(EPOCHS)
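Once training finishes, the same object can report accuracy on the held-out data through the evaluate method we define later in this chapter:

accuracy = trainer.evaluate()  # defaults to the val_loader passed at construction
print(f"Validation accuracy: {accuracy:.4f}")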
File: babygrad/trainer.py
Exercise 9.1
Let's write the fit method inside the Trainer class.
from babygrad.tensor import Tensor


class Trainer:
    def __init__(self, model, optimizer, loss_fn, train_loader,
                 val_loader=None):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.train_loader = train_loader
        self.val_loader = val_loader

    def fit(self, epochs: int):
        """
        Runs the training loop for the specified number of epochs.
        """
        for epoch in range(epochs):
            self.model.train()  # set the model to training mode
            total_loss = 0
            # Your solution here:
            # 1. Iterate over self.train_loader
            # 2. Get the batch data (x, y)
            # 3. Zero the gradients
            # 4. Forward pass
            # 5. Compute the loss
            # 6. Backward pass
            # 7. Optimizer step
            print(f"Epoch {epoch+1} Done.")

    def evaluate(self, loader=None):
        """
        Calculates accuracy on the given loader (defaults to the validation set).
        """
        target_loader = loader if loader is not None else self.val_loader
        if target_loader is None:
            return 0.0
        # Hint:
        # 1. Set the model to evaluation mode: self.model.eval()
        # 2. Loop over target_loader
        # 3. Forward pass only (no backward pass)
        # 4. Compare predictions to the true labels (use argmax(axis=1))
        # 5. Sum the correct predictions and calculate the average.
        pass
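If you get stuck, here is one possible way to fill in the two methods. It is only a sketch: it assumes each batch unpacks into (x, y), that the loss Tensor exposes a scalar via .item(), and that predictions and labels support NumPy-style argmax(axis=1), ==, .sum(), and len(). Adapt those calls to whatever your Tensor actually provides.

# One possible solution for Exercise 9.1 (inside the Trainer class).
# Assumptions: batches unpack as (x, y); loss.item() returns a Python float;
# predictions/labels behave like NumPy arrays for argmax, ==, .sum(), and len().

def fit(self, epochs: int):
    for epoch in range(epochs):
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        for x, y in self.train_loader:        # 1-2: iterate and unpack the batch
            self.optimizer.zero_grad()        # 3: zero the gradients
            pred = self.model(x)              # 4: forward pass
            loss = self.loss_fn(pred, y)      # 5: compute the loss
            loss.backward()                   # 6: backward pass
            self.optimizer.step()             # 7: optimizer step
            total_loss += loss.item()         # assumes the loss exposes .item()
            num_batches += 1
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1} Done. avg_loss={avg_loss:.4f}")

def evaluate(self, loader=None):
    target_loader = loader if loader is not None else self.val_loader
    if target_loader is None:
        return 0.0
    self.model.eval()                          # 1: evaluation mode
    correct, total = 0, 0
    for x, y in target_loader:                 # 2: loop over the batches
        pred = self.model(x)                   # 3: forward pass only
        pred_labels = pred.argmax(axis=1)      # 4: predicted class per sample
        correct += int((pred_labels == y).sum())   # assumes NumPy-style == and .sum()
        total += len(y)
    return correct / total                     # 5: average accuracy

Keeping evaluate free of backward() and optimizer calls makes it cheap to run and safe to call at any point during training.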
Original: zekcrates/trainer