| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 
 | import os
 import gzip
 import numpy as np
 import torch
 from torch.utils.data import Dataset
 
 class FashionDataset(Dataset):
 def __init__(self, datadir, transform, is_train=True):
 super().__init__()
 self.datadir = datadir
 self.transform = transform
 self.img, self.label = self.load_data(datadir, is_train)
 self.len_data = len(self.img)
 
 def __getitem__(self, index):
 return self.transform(self.img[index]), self.label[index]
 
 def __len__(self):
 return self.len_data
 
 def load_data(self, datadir, is_train):
 dirname = os.path.join(datadir)
 files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz']
 paths = [os.path.join(dirname, f) for f in files]
 
 if is_train:
 with gzip.open(paths[0], 'rb') as lbpath:
 label = np.frombuffer(lbpath.read(), np.uint8, offset=8)
 with gzip.open(paths[1], 'rb') as imgpath:
 img = np.frombuffer(imgpath.read(), np.uint8,
 offset=16).reshape(len(label), 28, 28)
 else:
 with gzip.open(paths[2], 'rb') as lbpath:
 label = np.frombuffer(lbpath.read(), np.uint8, offset=8)
 with gzip.open(paths[3], 'rb') as imgpath:
 img = np.frombuffer(imgpath.read(), np.uint8,
 offset=16).reshape(len(label), 28, 28)
 return img, label
 
 
 |