Pytorch中的数据加载(3)
目标知道数据加载的目的知道如何使用pytorch中的Dataset知道如何使用pytorch中的DataLoader知道如何获取pytorch中内置的数据集1.深入使用数据加载器的目的学习
为了对数据进行批量操作Pytorch中的数据加载(3),将整个数据集分成若干个批次;
1. 批处理是指将多个数据点的张量组合成一张单张Pytorch中的数据加载(3),整体放入一个网络中。以图像处理为例,网络的输入形状为(b,c,w,h),合成张量为batch
2. 为什么我们需要批处理?需要加快计算速度;
2.数据基类2.1 数据集基类
torch.util.data.Dataset:我们定义的所有数据集所有类都应该继承这个基类数据加载,因为在这个基类中自动实现了三个重要的方法:
def __init()__; #初始化数据集对象时,会调用函数__init__();一般会设置train、val,test等数据集的路径
def __len()__; #返回数据集的长度
def __getitem()__; # data[i],对数据集类进行遍历操作时,会自动调用该函数,该函数会提供一个数据点(batch),
# 通常在getitem()里,实现数据的加载,转换,生成groundTruth
我们看一下torch.util.data.Dataset的源码:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
2.数据加载2例
以髋关节数据集,关键点检测Task为例
数据:jpg格式数据加载,
label:csv文件格式,如下图:
class JointsDataset(Dataset):
def __init__(self, cfg, root, state, is_train):
# 是否是训练状态
self.is_train = is_train
# 图片的根路径
self.root = root
# 状态(train,test1,val)
self.state = state
# 对应图片的索引编号
self.indexes = self._get_index()
# 图片类型
self.data_format = cfg.DATASET.DATA_FORMAT
# 生成热力图的类型,默认是高斯
self.target_type = cfg.MODEL.TARGET_TYPE
# 训练输入图片大小
self.image_size = np.array(cfg.MODEL.IMAGE_SIZE)
# 输出热力图的大小
self.heatmap_size = np.array(cfg.MODEL.HEATMAP_SIZE)
# 热力图的sigma
self.sigma = cfg.MODEL.SIGMA
# 是否使用数据增强
self.aug = cfg.DATASET.AUG
# 标签文件
self.annotation = self._get_annotation()
flip = slc.SelectiveStream([ #这个没有用到,数据增强
slt.RandomFlip(p=0.5, axis=1)
# slt.Flip(p=0.5, axis=1)
])
if (self.state == 'train' or self.state== "val" or self.state== "test1") and self.target_type == "segment":
self.my_transform = transform.Compose([
# slt.ImageColorTransform(mode='rgb2gs'),
# w, h,由于opencv中h, w, c = img.shape
slt.ResizeTransform((self.image_size[1], self.image_size[0])),
# partial(get_fpn_heatmap, levels=cfg.MODEL.EXTRA.LEVEL, convert=True, sigma=cfg.MODEL.SIGMA)
partial(solt2seghm, downsample=cfg.DATASET.DOWNSAMPLE, sigma=self.sigma, scale_ld=cfg.DATASET.LANDMARK_NORMAL)
])
if self.state == 'train' and self.target_type != "segment":
self.my_transform = transform.Compose([
# slt.ImageColorTransform(mode='rgb2gs'),
# w, h,由于opencv中h, w, c = img.shape
slt.ResizeTransform((self.image_size[1], self.image_size[0])),
# partial(get_fpn_heatmap, levels=cfg.MODEL.EXTRA.LEVEL, convert=True, sigma=cfg.MODEL.SIGMA)
partial(solt2torchhm, downsample=cfg.DATASET.DOWNSAMPLE, sigma=self.sigma, scale_ld=cfg.DATASET.LANDMARK_NORMAL)
])
if (self.state == 'test1' or self.state == 'val')and self.target_type != "segment":
self.my_transform = transform.Compose([
# slt.ImageColorTransform(mode='rgb2gs'),
# w, h
slt.ResizeTransform((self.image_size[1], self.image_size[0])),
# partial(get_fpn_heatmap, levels=cfg.MODEL.EXTRA.LEVEL, convert=True, sigma=cfg.MODEL.SIGMA)
partial(solt2torchhm, downsample=cfg.DATASET.DOWNSAMPLE, sigma=self.sigma,
scale_ld=cfg.DATASET.LANDMARK_NORMAL)
])
# 加载标签
def _get_annotation(self):
# annotation = pd.read_csv(os.path.join(self.root, 'landmark123.csv'))
# annotation={}
with open(os.path.join(self.root,'landmark_dict') ,'rb') as f:
annotation = pickle.load(f)
return annotation
# 获取文件的索引
def _get_index(self):
img_indexes = []
index_file = os.path.join(self.root, str(self.state) + '.txt')
file = open(index_file, 'r')
for line in file.readlines():
img_indexes.append(line.strip('n'))
return img_indexes
def __len__(self):
return len(self.indexes)
def __getitem__(self, item):
img_index = self.indexes[item] #indexes中保存的是图片的索引
kpts = self.annotation[img_index]
path = os.path.join(self.root, img_index + self.data_format)
img_index = int(img_index)
ori_img = cv2.imread(path)
h, w, c = ori_img.shape
# 将训练图片翻转为同一个方向
type1='R'
if type1 == "L":
ori_img = ori_img[:, ::-1, :]
kpts[:, 0] = w - kpts[:, 0]
kpts_wrapped = sld.KeyPoints(kpts, ori_img.shape[0], ori_img.shape[1])
dc = sld.DataContainer((ori_img, kpts_wrapped), 'IP')
res = self.my_transform(dc)
seg_hm = []
if self.target_type == "segment":
img, target_hm,seg_hm,target_kp = res
else:
img, target_hm,target_kp = res
img = img - img.mean()
img /= img.std()
img /= img.max()
if self.target_type == 'landmark':
meta = {
'img_index': img_index,
'target_kp': target_kp,
'ori_kp': kpts,
'ori_h': h,
'ori_w': w,
'type': type1
}
return img, target_kp, meta
elif self.target_type == "segment":
meta= {
'img_index': img_index,
'target_kp': target_kp,
'ori_kp': kpts,
'ori_h': h,
'seg_hm':seg_hm,
'ori_w': w,
'type': type1
}
return img, target_hm,meta
else:
meta = {
'img_index': img_index,
'target_kp': target_kp,
'ori_kp': kpts,
'ori_h': h,
'ori_w': w,
'type': type1
}
return img, target_hm, meta
3. 迭代数据集
为了完成数据的读取,我们需要用到DataLoader;火炬在 pytorch 中提供。 util.data.DataLoader等方法
DataLoader 使用示例:
#cfg为参数文件
valid_dataset = JointsDataset(cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False)
valid_loader = torch.utils.data.DataLoader( valid_dataset,
batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
shuffle=False,
num_workers=cfg.WORKERS,
pin_memory=True)
model.train()
for i, (input, target, meta) in enumerate(train_loader):
output = model(input.cuda())
target = target.cuda(non_blocking=True)
loss = criterion(output, target)
评论前必须登录!
注册