pytorch已经用了很久了,但是之前并不怎么需要自己写,方向原因只需要跑一些现成的简单代码,能看懂网络结构和各个参数的意思就行,这次轮到我自己写了,发现里面有很多细节,好好学习一下吧。
tensor什么的肯定是会的,函数也可以现查,从dataset开始吧。
数据预处理
pytorch提供了torch.utils.data.DataLoader和torch.uutils.data.data用于预处理数据。
以FashionMNIST为例
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
root用于指向数据地址,train用于标注是否用于训练(bn层有影响),download
没下载的用true下载一下,tranform设置标签转化。
CustomImageDataset可以帮助我们使用自定义数据集,需要定义这三个函数__init__, len, and getitem.
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
annotations_file长这样
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
__len__返回图片数量
__getitem__属于组合图片和标签并转化为张量。
训练前处理数据,通常是分批次(batchsize),打乱顺序重排(shuffle)等操作。
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
对于图片和标签本身的性质,torchvision提供了transform和target-transform用于处理归一化等操作
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)