Chapter 6: Data Handling
So far in our simple_mnist example, we've loaded our data and sliced it into batches manually. This is fine for a simple script, but it's not a flexible or reusable solution. What if our data is a huge collection of text files, or a massive CSV? We need a reusable abstraction for loading and batching data.
This is where the Dataset and DataLoader classes come in. They separate the problem into two parts: the Dataset defines what a single sample is and how to fetch it, while the DataLoader decides how to iterate over those samples in batches.
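As a preview of where we're headed (these are the classes we'll build over the rest of this chapter; the filenames are placeholders):

# Dataset answers: "what is sample i?"
dataset = MNISTDataset("images.gz", "labels.gz")    # built in section 6.2
x, y = dataset[0]                                   # one (image, label) pair

# DataLoader answers: "how do I iterate over samples in batches?"
loader = DataLoader(dataset, batch_size=8, shuffle=True)  # built in section 6.3
for x_batch, y_batch in loader:
    ...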
6.1 Dataset
What should the Dataset give us?
- The length of the dataset, via `__len__`.
- A data item at a given index, via `__getitem__`.
FILE : babygrad/data.py
class Dataset:
    """Base class representing a dataset.

    All subclasses should override the `__len__` method
    (which returns the size of the dataset) and the `__getitem__` method
    (which supports fetching a data sample at a given index).

    Args:
        transforms (list, optional): A list of functions that take
            a data sample and return a transformed version. Applied
            in the order they are provided. Defaults to None.

    Example:
        >>> class MyNumberDataset(Dataset):
        ...     def __init__(self, numbers):
        ...         super().__init__()
        ...         self.numbers = numbers
        ...     def __len__(self):
        ...         return len(self.numbers)
        ...     def __getitem__(self, index):
        ...         x = self.numbers[index]
        ...         y = x ** 2
        ...         return self.apply_transform(x), y
        ...
        >>> dataset = MyNumberDataset([1, 2, 3, 4])
        >>> print(f"Dataset size: {len(dataset)}")
        Dataset size: 4
        >>> print(f"Third sample: {dataset[2]}")
        Third sample: (3, 9)
    """

    def __init__(self, transforms=None):
        self.transforms = transforms

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def apply_transform(self, x):
        if not self.transforms:
            return x
        for t in self.transforms:
            x = t(x)
        return x
Dataset is the base class that we will use to create all other datasets.
What are transforms here?
Transforms are functions that modify data samples before they're used for training. They're applied on-the-fly when you fetch an item from the dataset.
Why do we need them?
- Normalization: converting pixel values from [0, 255] to [0, 1].
- Data augmentation: random crops, flips, and rotations for images.
- Type conversion: converting PIL images to tensors.
- Preprocessing: tokenizing text, applying filters, resizing images.
import numpy as np

def normalize(x):
    """Normalize to [0, 1] range."""
    return x / 255.0

def add_noise(x):
    """Add random noise for augmentation."""
    return x + np.random.randn(*x.shape) * 0.01

# Create a dataset with transforms (applied in order: normalize, then add_noise)
dataset = MyDataset(transforms=[normalize, add_noise])
6.2 MNIST Dataset
Let's create our MNIST dataset using the Dataset class above.
Indexing the images gives a flat array of size (784,). We need to reshape it into (28, 28, 1).
The index can be:
- Number: return as (28, 28, 1).
- Slice: return as (slice_length, 28, 28, 1).
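Concretely, with plain NumPy (here using a hypothetical array of 60,000 flattened images), the two cases look like this:

import numpy as np

images = np.zeros((60000, 784))                # flattened images, one per row
single = images[0].reshape(28, 28, 1)          # shape: (28, 28, 1)
batch = images[0:8].reshape(-1, 28, 28, 1)     # shape: (8, 28, 28, 1)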
FILE : babygrad/data.py
Exercise 6.1: Implement the `__getitem__` method.
from typing import List, Optional

class MNISTDataset(Dataset):
    def __init__(
        self,
        image_filename: str,
        label_filename: str,
        transforms: Optional[List] = None,
    ):
        self.images, self.labels = parse_mnist(
            image_filename=image_filename,
            label_filename=label_filename,
        )
        self.transforms = transforms

    def __getitem__(self, index) -> object:
        # get the image and label at the index
        # convert to np arrays
        # reshape image to (28, 28, 1) if single index,
        # else (slice_length, 28, 28, 1)
        # apply transforms, if applicable
        # return (sample_image, sample_label)
        raise NotImplementedError  # your solution here

    def __len__(self) -> int:
        return len(self.images)
Hints: reuse `parse_mnist` from the examples, and use NumPy's `reshape`.
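If you get stuck, here is one possible sketch. It assumes `parse_mnist` returns the images as a float array of shape (num_samples, 784) and the labels as a 1-D integer array; adjust if your version differs.

def __getitem__(self, index) -> object:
    # indexing a NumPy array works for both an int index and a slice
    img = self.images[index]
    label = self.labels[index]
    if img.ndim == 1:
        # single sample: (784,) -> (28, 28, 1)
        img = img.reshape(28, 28, 1)
    else:
        # slice: (n, 784) -> (n, 28, 28, 1)
        img = img.reshape(-1, 28, 28, 1)
    # apply_transform comes from the Dataset base class
    img = self.apply_transform(img)
    return img, label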
6.3 DataLoader
The Dataset nicely gives us a single (x, y) sample pair when given an index.
But we need samples in batches, and we also need to decide whether to draw them in their original order or randomly. Shuffling matters: stochastic gradient descent converges better when consecutive batches are decorrelated, and a fixed sample order (for example, all images of one digit grouped together) can bias the updates.
So we need a DataLoader that gives us batches of samples from the dataset.
What should the DataLoader contain?
- A dataset of type `Dataset` to load from.
- Whether to shuffle or not.
- The batch size.
for (x, y) in DataLoader(your_dataset, shuffle=True, batch_size=8):
    # do something with the batch:
    # forward pass
    # compute loss
    # backward pass
FILE : babygrad/data.py
`__iter__`: should initialize the iteration state (indices, batch counter) and return the iterator.
`__next__`: should return the next batch of samples from the dataset.
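As a refresher, any object implementing this pair of methods is a Python iterator and can be used directly in a for-loop. A minimal standalone example:

class Countdown:
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        # reset iteration state and return the iterator itself
        self.current = self.n
        return self

    def __next__(self):
        if self.current == 0:
            raise StopIteration  # signals the for-loop to stop
        self.current -= 1
        return self.current + 1

print(list(Countdown(3)))  # [3, 2, 1]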
Exercise 6.2: Implement the `__iter__` and `__next__` methods.
class DataLoader:
    """Provides an iterator for easy batching, shuffling, and loading
    of data.

    Args:
        dataset (Dataset): The dataset to load samples from.
        batch_size (int, optional): Number of samples per batch.
            Defaults to 1.
        shuffle (bool, optional): Whether to shuffle the sample order
            on every iteration. Defaults to True.
    """

    def __init__(self,
                 dataset: Dataset,
                 batch_size: int = 1,
                 shuffle: bool = True):
        self.dataset = dataset
        self.shuffle = shuffle
        self.batch_size = batch_size

    def __iter__(self):
        self.indices = np.arange(len(self.dataset))
        # shuffle the indices if shuffle is True
        # initialize batch_idx = 0
        # compute the number of batches (len(dataset) / batch_size)
        # return self

    def __next__(self):
        if self.batch_idx >= self.num_batches:
            raise StopIteration
        start = self.batch_idx * self.batch_size
        batch_indices = self.indices[start: start + self.batch_size]
        samples = [self.dataset[i] for i in batch_indices]
        # your solution:
        # unzip the samples
        # stack them
        # wrap in Tensor and return
        self.batch_idx += 1
Hint: use `np.stack` for stacking, and convert the stacked samples to `Tensor` before returning them.
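For reference, here is one way the two methods could be filled in. This is a minimal sketch: it assumes babygrad's `Tensor` can be constructed directly from a NumPy array, and it serves a final partial batch rather than dropping it.

def __iter__(self):
    self.indices = np.arange(len(self.dataset))
    if self.shuffle:
        np.random.shuffle(self.indices)
    self.batch_idx = 0
    # ceil division, so a final partial batch is still served
    self.num_batches = (len(self.dataset) + self.batch_size - 1) // self.batch_size
    return self

def __next__(self):
    if self.batch_idx >= self.num_batches:
        raise StopIteration
    start = self.batch_idx * self.batch_size
    batch_indices = self.indices[start: start + self.batch_size]
    samples = [self.dataset[i] for i in batch_indices]
    # unzip [(x0, y0), (x1, y1), ...] into (x0, x1, ...) and (y0, y1, ...)
    xs, ys = zip(*samples)
    self.batch_idx += 1
    # stack into batch arrays and wrap them in Tensors
    return Tensor(np.stack(xs)), Tensor(np.stack(ys))

Putting it all together (the filenames are placeholders for wherever your MNIST files live, and `Tensor` is assumed to expose the underlying array's shape):

dataset = MNISTDataset("train-images.gz", "train-labels.gz",
                       transforms=[normalize])
loader = DataLoader(dataset, batch_size=8, shuffle=True)
for x, y in loader:
    print(x.shape)  # (8, 28, 28, 1)
    break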