import os
import gzip

import numpy as np
import torch
from torch.utils.data import Dataset


class FashionDataset(Dataset):
    """Fashion-MNIST dataset read directly from the raw idx .gz archives."""

    def __init__(self, datadir, transform, is_train=True):
        super().__init__()
        self.datadir = datadir
        self.transform = transform
        self.img, self.label = self.load_data(datadir, is_train)
        self.len_data = len(self.img)

    def __getitem__(self, index):
        return self.transform(self.img[index]), self.label[index]

    def __len__(self):
        return self.len_data

    def load_data(self, datadir, is_train):
        # The four raw Fashion-MNIST archives in the standard idx format.
        files = [
            'train-labels-idx1-ubyte.gz',
            'train-images-idx3-ubyte.gz',
            't10k-labels-idx1-ubyte.gz',
            't10k-images-idx3-ubyte.gz',
        ]
        paths = [os.path.join(datadir, f) for f in files]

        # Select the train or test pair of label/image files.
        label_path, img_path = (paths[0], paths[1]) if is_train else (paths[2], paths[3])

        # Labels: 8-byte idx header, then one uint8 label per example.
        with gzip.open(label_path, 'rb') as lbpath:
            label = np.frombuffer(lbpath.read(), np.uint8, offset=8)
        # Images: 16-byte idx header, then 28x28 uint8 pixels per example.
        # .copy() makes the array writable so later tensor conversion does not warn.
        with gzip.open(img_path, 'rb') as imgpath:
            img = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(label), 28, 28).copy()
        return img, label
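

# --- Usage sketch (illustrative, not part of the original class) ---
# Assumes the four Fashion-MNIST .gz files sit in a hypothetical
# "./data/fashion" directory and that torchvision is available for ToTensor.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision import transforms

    # ToTensor turns each (28, 28) uint8 array into a (1, 28, 28) float tensor in [0, 1].
    train_set = FashionDataset("./data/fashion", transforms.ToTensor(), is_train=True)
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

    images, labels = next(iter(train_loader))
    print(images.shape, labels.shape)  # e.g. torch.Size([64, 1, 28, 28]) and torch.Size([64])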