| import os |
|
|
| from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase |
| from dassl.utils import listdir_nohidden |
|
|
| from .imagenet import ImageNet |
|
|
|
|
| @DATASET_REGISTRY.register() |
| class ImageNetV2(DatasetBase): |
| """ImageNetV2. |
| |
| This dataset is used for testing only. |
| """ |
|
|
| dataset_dir = "imagenetv2" |
|
|
| def __init__(self, cfg): |
| root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) |
| self.dataset_dir = os.path.join(root, self.dataset_dir) |
| image_dir = "imagenetv2-matched-frequency-format-val" |
| self.image_dir = os.path.join(self.dataset_dir, image_dir) |
|
|
| text_file = os.path.join(self.dataset_dir, "classnames.txt") |
| classnames = ImageNet.read_classnames(text_file) |
|
|
| data = self.read_data(classnames) |
| |
| _,self.all_classnames = self.get_lab2cname(data) |
| super().__init__(train_x=data, test=data) |
|
|
| def read_data(self, classnames): |
| image_dir = self.image_dir |
| folders = list(classnames.keys()) |
| items = [] |
|
|
| for label in range(1000): |
| class_dir = os.path.join(image_dir, str(label)) |
| imnames = listdir_nohidden(class_dir) |
| folder = folders[label] |
| classname = classnames[folder] |
| for imname in imnames: |
| impath = os.path.join(class_dir, imname) |
| item = Datum(impath=impath, label=label, classname=classname) |
| items.append(item) |
|
|
| return items |
|
|