ML-Workshop/dataset.py
2025-03-20 12:45:57 -04:00

24 lines
750 B
Python

import torchvision
from consts import CIFAR_DIR
# Training split of CIFAR-10, loaded through torchvision. With download=True the
# archive is fetched into CIFAR_DIR when it is not already there.
# Optional transformations (e.g. ToTensor, Normalize) can be passed via the
# `transform` argument; see https://pytorch.org/vision/stable/transforms.html
# With transform=None, each sample's image stays a PIL (Pillow) image.
cifar_data_train = torchvision.datasets.CIFAR10(
    root=CIFAR_DIR,
    train=True,
    transform=None,
    download=True,
)

# Inspect one sample. Each dataset item is an (image, label) tuple: the image is
# a PIL image and the label is an integer class index. The human-readable class
# name is looked up in `cifar_data_train.classes`.
# NOTE: this is the full 10-class dataset; nothing in this file filters it down
# to a subset of classes (such as only cats and dogs).
example_data = cifar_data_train[0]
image, label = example_data
print(f'items in an instance of cifar10: {len(example_data)}')
image.show()
# Show both the raw index and its name, since the index alone is not meaningful.
print(f'class corresponding to image: {label} ({cifar_data_train.classes[label]})')