快速入门 - PyTorch 教程 2.5.0+cu124 文档 - PyTorch 机器学习库
1. 基础
1.1 使用数据
PyTorch 有两个用于处理数据的原语：torch.utils.data.DataLoader 和 torch.utils.data.Dataset。Dataset 存储样本及其对应的标签，而 DataLoader 将 Dataset 包装为一个可迭代对象。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
def _fashion_mnist(train):
    """Build one FashionMNIST split (training if ``train`` is True, else test).

    Both splits share the same settings: root directory "data",
    ``download=True`` and a ``ToTensor()`` transform.
    """
    return datasets.FashionMNIST(
        root="data",
        train=train,
        download=True,
        transform=ToTensor(),
    )


# Training split first, then the test split, each with its own ToTensor().
training_data = _fashion_mnist(train=True)
test_data = _fashion_mnist(train=False)
batch_size = 64

# Wrap each Dataset in a DataLoader that yields mini-batches of `batch_size`
# samples. The training loader reshuffles on every pass (shuffle=True), so the
# model does not see the identical batch sequence each epoch. The test loader
# keeps the default fixed order so evaluation is reproducible.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Inspect only the first test batch to show the tensor shapes and label dtype;
# `break` stops after one batch (and an empty loader simply prints nothing).
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape} ")
    print(f"Shape of y: {y.shape} {y.dtype} ")
    break
# Prefer CUDA, then MPS, and fall back to the CPU when neither is available.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")


class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten -> (Linear, ReLU) x2 -> Linear.

    The defaults reproduce the original 28*28 -> 512 -> 512 -> 10 network, so
    ``NeuralNetwork()`` builds exactly the same layers and state-dict keys.

    Args:
        in_features: Number of values per sample after flattening every
            non-batch dimension (default 28 * 28).
        hidden_features: Width of both hidden layers (default 512).
        num_classes: Number of output scores per sample (default 10).
    """

    def __init__(self, in_features=28 * 28, hidden_features=512, num_classes=10):
        super().__init__()
        # nn.Flatten() keeps the leading (batch) dimension and merges the rest.
        self.flatten = nn.Flatten()
        # Keep this exact Sequential layout: indices 0/2/4 are the Linear
        # layers, which is what saved state dicts refer to.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return raw, un-normalised scores (logits), one row per sample.

        No softmax is applied here. For a batched input of N samples the
        result has shape (N, num_classes).
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = NeuralNetwork().to(device)
print(model)