解决AttributeError Can't get attribute on module

之前在本地跑大作业时遇到了一个神秘问题,有点抽象,于是记录一下。

问题

在本地跑大作业的模型训练时,程序一直卡着不动;查看 Jupyter 的命令行窗口后发现报错:`AttributeError: Can't get attribute 'FashionDataset' on <module '__main__' (<class '_frozen_importlib.BuiltinImporter'>)>`
在网上搜索一番后,在一篇文章中找到了问题的原因:

实际上这个地方:在 Linux 中使用的是 fork 的方式,可以把公共变量和方法复制到新开启的进程中;但是 Windows 中是重新创建一个全新的进程,新进程中并不存在这些公共的变量和方法,所以会出现这个问题——未能导入模块中定义的对象。

解决

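把出问题的代码(我这里是 FashionDataset 类)放到一个单独的 .py 文件中,再 import 进来就行。原因是:Windows 下新进程以 spawn 方式启动,需要按"模块名 + 类名"重新导入并反序列化对象;而在 Jupyter 中直接定义的类属于 `__main__`,子进程里的 `__main__` 并没有这个类,于是报错。把类放进独立模块后,子进程就能通过 `import dataset` 找到它。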

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# dataset.py
import os
import gzip
import numpy as np
import torch
from torch.utils.data import Dataset

class FashionDataset(Dataset):
    """Map-style dataset over gzipped IDX files in the MNIST file layout.

    The label/image files for the selected split are read eagerly from
    ``datadir``: labels become a 1-D ``uint8`` array and images a
    ``(len(labels), 28, 28)`` ``uint8`` array. Both arrays are writable copies.

    Args:
        datadir: Directory containing the ``*-idx?-ubyte.gz`` files.
        transform: Callable applied to each ``(28, 28)`` image in
            ``__getitem__``. When ``None``, the raw array is returned.
        is_train: Load the ``train-*`` files when true, the ``t10k-*``
            files otherwise.
    """

    def __init__(self, datadir, transform=None, is_train=True):
        super().__init__()
        self.datadir = datadir
        self.transform = transform
        self.img, self.label = self.load_data(datadir, is_train)
        self.len_data = len(self.img)

    def __getitem__(self, index):
        """Return ``(image, label)`` at ``index``; ``transform`` is applied to the image if set."""
        img = self.img[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, self.label[index]

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return self.len_data

    def load_data(self, datadir, is_train):
        """Read the label and image files of one split from ``datadir``.

        Returns:
            ``(img, label)`` where ``img`` has shape ``(len(label), 28, 28)``.
        """
        if is_train:
            label_file, image_file = ('train-labels-idx1-ubyte.gz',
                                      'train-images-idx3-ubyte.gz')
        else:
            label_file, image_file = ('t10k-labels-idx1-ubyte.gz',
                                      't10k-images-idx3-ubyte.gz')

        # Skip the IDX headers: 8 bytes for label files, 16 bytes for image files.
        label = self._read_idx(os.path.join(datadir, label_file), offset=8)
        img = self._read_idx(os.path.join(datadir, image_file), offset=16)
        return img.reshape(len(label), 28, 28), label

    @staticmethod
    def _read_idx(path, offset):
        """Decompress ``path`` and return the bytes after ``offset`` as a writable uint8 array."""
        with gzip.open(path, 'rb') as fh:
            # np.frombuffer yields a read-only view of the bytes object;
            # copy it because torch.from_numpy warns on non-writable arrays.
            return np.frombuffer(fh.read(), np.uint8, offset=offset).copy()

然后在主文件(notebook)中通过 `from dataset import FashionDataset` 导入即可(dataset.py 需放在 notebook 所在目录,或位于 `sys.path` 中)。