pytorch学习

pytorch已经用了很久了,但是之前并不怎么需要自己写,方向原因只需要跑一些现成的简单代码,能看懂网络结构和各个参数的意思就行,这次轮到我自己写了,发现里面有很多细节,好好学习一下吧。

tensor什么的肯定是会的,函数也可以现查,从dataset开始吧。

数据预处理

pytorch提供了torch.utils.data.DataLoader和torch.uutils.data.data用于预处理数据。

以FashionMNIST为例

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

root用于指向数据地址,train用于标注是否用于训练(bn层有影响),download
没下载的用true下载一下,tranform设置标签转化。

CustomImageDataset可以帮助我们使用自定义数据集,需要定义这三个函数__init__, len, and getitem.

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

annotations_file长这样

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9

__len__返回图片数量

__getitem__属于组合图片和标签并转化为张量。

训练前处理数据,通常是分批次(batchsize),打乱顺序重排(shuffle)等操作。

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

对于图片和标签本身的性质,torchvision提供了transform和target-transform用于处理归一化等操作

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
dark
sans