PyTorch 导入数据集
对于任何机器学习模型的训练过程,导入数据都是最基础的一步。在 PyTorch 中,可以使用一些 Python 的标准库将数据导入为 numpy array,然后再转换为 torch.*Tensor
。
- For images, packages such as Pillow, OpenCV are useful
- For audio, packages such as scipy and librosa
- For text, either raw Python or Cython based loading, or NLTK and SpaCy are useful
PyTorch 还专门为视觉定制了一个包torchvision
:
torchvision.datasets
可以直接导入一些常用的图形数据库,如 Imagenet, CIFAR10, MNIST 等。torchvision.transforms
还包含了很多图形数据转换器(data transformers for images)。
另外,torchaudio 提供了一个简单的音频数据的 I/O。但它只提供了专用于两个数据集 VCTK, YesNo 的数据导入操作。如果想使用其他数据集,还是要定制自己的torch.utils.data.Dataset
,然后结合torch.utils.data.DataLoader
来使用。举例:
1 | #! usr/bin/env python |
Tips:
- PyTorch 所有的数据集对象都是
torch.utils.data.Dataset
的子类。在继承它的时候必须要重写其__len__
和__getitem__
方法; - 为了方便数据的存储和读入,可以将数据存为
.pt
文件(PyTorch 的标准数据文件); - 对于多通道输入对象(如,3通道的RGB图片、2通道的MP3音乐),应该存储为 nAttributes * nChannels,即
torchaudio.load
导入的信号需要转置再存储; - 标签要转换为标签组的下标后再写入文件,以方便训练和计算损失。
torchaudio 报错:in audio_open(): NoBackendError。–> 缺少解码器:apt install libav-tools
参考资料
[1] 这里把音频转换为了频谱图片然后再导入计算,有点意思。