雀恰营销
专注中国网络营销推广

Pytorch中的数据加载(3)

Pytorch中的数据加载(3)

目标知道数据加载的目的知道如何使用pytorch中的Dataset知道如何使用pytorch中的DataLoader知道如何获取pytorch中内置的数据集1.深入使用数据加载器的目的学习

为了对数据进行批量操作Pytorch中的数据加载(3),将整个数据集分成若干个批次;

1. 批处理是指将多个数据点的张量组合成一张单张Pytorch中的数据加载(3),整体放入一个网络中。以图像处理为例,网络的输入形状为(b,c,w,h),合成张量为batch

Pytorch中的数据加载(3)

2. 为什么我们需要批处理?需要加快计算速度;

2.数据基类2.1 数据集基类

torch.util.data.Dataset:我们定义的所有数据集所有类都应该继承这个基类数据加载,因为在这个基类中自动实现了三个重要的方法:

Pytorch中的数据加载(3)

def __init()__; #初始化数据集对象时,会调用函数__init__();一般会设置train、val,test等数据集的路径
def __len()__; #返回数据集的长度
def __getitem()__; # data[i],对数据集类进行遍历操作时,会自动调用该函数,该函数会提供一个数据点(batch),
# 通常在getitem()里,实现数据的加载,转换,生成groundTruth

我们看一下torch.util.data.Dataset的源码:

class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """
    def __getitem__(self, index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
    def __add__(self, other):
        return ConcatDataset([self, other])

2.数据加载2例

以髋关节数据集,关键点检测Task为例

数据:jpg格式数据加载

Pytorch中的数据加载(3)

label:csv文件格式,如下图:

class JointsDataset(Dataset):
    def __init__(self, cfg, root, state, is_train):
        # 是否是训练状态
        self.is_train = is_train
        # 图片的根路径
        self.root = root
        # 状态(train,test1,val)
        self.state = state
        # 对应图片的索引编号
        self.indexes = self._get_index()
        # 图片类型
        self.data_format = cfg.DATASET.DATA_FORMAT
        # 生成热力图的类型,默认是高斯
        self.target_type = cfg.MODEL.TARGET_TYPE
        # 训练输入图片大小
        self.image_size = np.array(cfg.MODEL.IMAGE_SIZE)
        # 输出热力图的大小
        self.heatmap_size = np.array(cfg.MODEL.HEATMAP_SIZE)
        # 热力图的sigma
        self.sigma = cfg.MODEL.SIGMA
        # 是否使用数据增强
        self.aug = cfg.DATASET.AUG
        # 标签文件
        self.annotation = self._get_annotation()
        flip = slc.SelectiveStream([        #这个没有用到,数据增强
            slt.RandomFlip(p=0.5, axis=1)
            # slt.Flip(p=0.5, axis=1)
        ])
        if  (self.state == 'train' or self.state== "val" or self.state== "test1") and self.target_type == "segment":
            self.my_transform = transform.Compose([
                # slt.ImageColorTransform(mode='rgb2gs'),
                # w, h,由于opencv中h, w, c = img.shape
                slt.ResizeTransform((self.image_size[1], self.image_size[0])),
                # partial(get_fpn_heatmap, levels=cfg.MODEL.EXTRA.LEVEL, convert=True, sigma=cfg.MODEL.SIGMA)
                partial(solt2seghm, downsample=cfg.DATASET.DOWNSAMPLE, sigma=self.sigma, scale_ld=cfg.DATASET.LANDMARK_NORMAL)
            ])
            
        if self.state == 'train' and self.target_type != "segment":
            self.my_transform = transform.Compose([
                # slt.ImageColorTransform(mode='rgb2gs'),
                # w, h,由于opencv中h, w, c = img.shape
                slt.ResizeTransform((self.image_size[1], self.image_size[0])),
                # partial(get_fpn_heatmap, levels=cfg.MODEL.EXTRA.LEVEL, convert=True, sigma=cfg.MODEL.SIGMA)
                partial(solt2torchhm, downsample=cfg.DATASET.DOWNSAMPLE, sigma=self.sigma, scale_ld=cfg.DATASET.LANDMARK_NORMAL)
            ])
        if (self.state == 'test1' or self.state == 'val')and self.target_type != "segment":
            self.my_transform = transform.Compose([
                # slt.ImageColorTransform(mode='rgb2gs'),
                # w, h
                slt.ResizeTransform((self.image_size[1], self.image_size[0])),
                # partial(get_fpn_heatmap, levels=cfg.MODEL.EXTRA.LEVEL, convert=True, sigma=cfg.MODEL.SIGMA)
                partial(solt2torchhm, downsample=cfg.DATASET.DOWNSAMPLE, sigma=self.sigma,
                        scale_ld=cfg.DATASET.LANDMARK_NORMAL)
            ])
    # 加载标签
    def _get_annotation(self):
        # annotation = pd.read_csv(os.path.join(self.root, 'landmark123.csv'))
       # annotation={}
        with open(os.path.join(self.root,'landmark_dict') ,'rb')  as f:
            annotation = pickle.load(f)
        return annotation
    # 获取文件的索引
    def _get_index(self):
        img_indexes = []
        index_file = os.path.join(self.root, str(self.state) + '.txt')
        file = open(index_file, 'r')
        for line in file.readlines():
            img_indexes.append(line.strip('n'))
        return img_indexes
    def __len__(self):
        return len(self.indexes)
    def __getitem__(self, item):   
        img_index = self.indexes[item] #indexes中保存的是图片的索引
        kpts = self.annotation[img_index]
        path = os.path.join(self.root, img_index + self.data_format)
        img_index = int(img_index)
        ori_img = cv2.imread(path)
        h, w, c = ori_img.shape
        # 将训练图片翻转为同一个方向
        type1='R'
        if type1 == "L":
            ori_img = ori_img[:, ::-1, :]
            kpts[:, 0] = w - kpts[:, 0]
        kpts_wrapped = sld.KeyPoints(kpts, ori_img.shape[0], ori_img.shape[1])
     
        dc = sld.DataContainer((ori_img, kpts_wrapped), 'IP')
        res = self.my_transform(dc)
        seg_hm = []
        if self.target_type == "segment":
            img, target_hm,seg_hm,target_kp = res
        else:
            img, target_hm,target_kp = res
        img = img - img.mean()
        img /= img.std()
        img /= img.max()
        if self.target_type == 'landmark':
            meta = {
                'img_index': img_index,
                'target_kp': target_kp,
                'ori_kp': kpts,
                'ori_h': h,
                'ori_w': w,
                'type': type1
            }
            return img, target_kp, meta
        elif self.target_type == "segment":
            meta= {
                'img_index': img_index,
                'target_kp': target_kp,
                'ori_kp': kpts,
                'ori_h': h,
                'seg_hm':seg_hm,
                'ori_w': w,
                'type': type1
            }
            return img, target_hm,meta
        else:
            meta = {
                'img_index': img_index,
                'target_kp': target_kp,
                'ori_kp': kpts,
                'ori_h': h,
                'ori_w': w,
                'type': type1
            }
            return img, target_hm, meta

3. 迭代数据集

为了完成数据的读取,我们需要用到DataLoader;火炬在 pytorch 中提供。 util.data.DataLoader等方法

DataLoader 使用示例:

#cfg为参数文件 
valid_dataset = JointsDataset(cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False)
valid_loader = torch.utils.data.DataLoader( valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=True)
model.train()
for i, (input, target, meta) in enumerate(train_loader):
    output = model(input.cuda())
    target = target.cuda(non_blocking=True)
    loss = criterion(output, target)

赞(0) 打赏
未经允许不得转载:雀恰营销 » Pytorch中的数据加载(3)
分享到: 更多 (0)

评论 抢沙发

评论前必须登录!

 

文章对你有帮助就赞助我一下吧

支付宝扫一扫打赏

微信扫一扫打赏