PyTorch基礎語法:數(shù)據(jù)預處理transforms模塊機制
來源:投稿 作者:阿克西
編輯:學姐
建議搭配視頻食用↓
視頻鏈接:https://ai.deepshare.net/detail/p_5df0ad9a09d37_qYqVmt85/6
1.transforms運行機制
torchvision是pytorch的計算機視覺工具包,主要有以下三個模塊:
torchvision.transforms
:提供了常用的一系列圖像預處理方法,例如數(shù)據(jù)的標準化,中心化,旋轉(zhuǎn),翻轉(zhuǎn)等。torchvision.datasets
:定義了一系列常用的公開數(shù)據(jù)集的datasets,比如MNIST,CIFAR-10,ImageNet等。torchvision.model
:提供了常用的預訓練模型,例如AlexNet,VGG,ResNet,GoogLeNet等。
torchvision.transforms:常用的圖像預處理方法
數(shù)據(jù)中心化,數(shù)據(jù)標準化
縮放,裁剪,旋轉(zhuǎn),翻轉(zhuǎn),填充
噪聲添加,灰度變換,線性變換,仿射變換
亮度、飽和度及對比度變換
深度學習是由數(shù)據(jù)驅(qū)動的,數(shù)據(jù)的數(shù)量以及分布對模型的優(yōu)劣起到?jīng)Q定性作用,所以需要對數(shù)據(jù)進行一定的預處理以及數(shù)據(jù)增強,用來提升模型的泛化能力。

上圖是1張原始圖片經(jīng)過數(shù)據(jù)增強之后生成的一系列數(shù)據(jù),一共有64張圖片。對圖片進行數(shù)據(jù)增強可以豐富訓練數(shù)據(jù),提高模型的泛化能力。因為如果數(shù)據(jù)增強生成了與測試樣本很相似的圖片,那么模型的泛化能力自然可以得到提高。
使用上一節(jié)中介紹的人民幣二分類實驗的代碼的數(shù)據(jù)預處理部分:

2.斷點調(diào)試
同樣,在模型訓練樣本讀取位置設置斷點,進行debug:

點擊step into按鍵,在跳轉(zhuǎn)后的代碼中進行一個是否采用多進程的判斷:

點擊step over,選擇單進程的運行機制,再點擊step into按鍵,進入dataloader.py界面:

光標設置在index = self._next_index() # may raise StopIteration這一行,點擊Run to Cursor,程序就會運行到光標所在的行。這一步的作用是獲取Index,也就是要讀取哪些數(shù)據(jù)。點擊step over,得到Index就可以進入dataset_fetcher.fetch(index),根據(jù)索引去獲取數(shù)據(jù)。點擊step into進入到fetch函數(shù):

在fetch函數(shù)中,代碼data = [self.dataset[idx] for idx in possibly_batched_index]使用了列表生成式,調(diào)用了dataset,接著點擊step over與step into進入dataset所在的代碼位置,dataset代碼位于類RMBDataset(Dataset)中的__getitem__()函數(shù):

在getitem()中根據(jù)索引去獲取圖片的路徑以及標簽,然后采用代碼img = Image.open(path_img).convert('RGB') # 0~255打開圖片,讀取進來的圖片是一個PIL的數(shù)據(jù)類型,然后在getitem中調(diào)用transform()進行圖像預處理操作,在代碼處img = self.transform(img)通過step into進入transforms.py中的def 「call」()函數(shù)

「call」()函數(shù)是一個for循環(huán),也就是依次有序地從compose中去調(diào)用預處理方法,第一個預處理方法是t(img),其功能是是Resize縮放;第二個功能是裁剪,第三個功能是進行張量操作,第四個功能是進行歸一化;對compose的四個功能循環(huán)結(jié)束之后,就會返回代碼處img = self.transform(img)。
transform是在__getitem__()中調(diào)用,并且在__getitem__()中實現(xiàn)數(shù)據(jù)預處理,然后通過__getitem__返回一個樣本。
執(zhí)行step out操作返回fetch()函數(shù),接著就是不斷地循環(huán)index獲取一個batch_size大小的數(shù)據(jù),最后在return的時候調(diào)用collate_fn()函數(shù),將數(shù)據(jù)整理成一個batch_data的形式。

然后執(zhí)行step out操作返回到dataloader.py中的__next__()函數(shù)中,然后再執(zhí)行執(zhí)行step out操作回到訓練代碼中,接著數(shù)據(jù)就讀取進來了。這就是pytorch數(shù)據(jù)讀取和transforms的運行機制。

回顧上面的數(shù)據(jù)讀取流程圖,transforms是在getitem中使用的,在getitem中讀取一張圖片,然后對這一張圖片進行一系列預處理,返回圖片以及標簽。
了解了transforms的機制,現(xiàn)在學習一個比較常用的預處理方法,數(shù)據(jù)的標準化transforms.Normalize。
3.數(shù)據(jù)標準化transforms.normalize
3.1 定義
功能:逐channel的對圖像進行標準化,即數(shù)據(jù)的均值變?yōu)?,標準差變?yōu)?。
計算公式:
mean:各通道的均值
std:各通道的標準差
inplace:是否原位操作
3.2 斷點調(diào)試
回到代碼中看一下normalize的具體實現(xiàn)方法,transform是在dataset的getitem中實現(xiàn)的,所以可以直接去dataset的getitem函數(shù)中設置斷點:

進行debug操作,點擊step into進入詳細代碼環(huán)境,進入了transforms.py中的call()函數(shù)中,在call函數(shù)中循環(huán)transforms。

點擊step over執(zhí)行多次,到normalize實現(xiàn)

接著點擊step into查看normalize的實現(xiàn),來到了normalize()類中的__call__()函數(shù)中,代碼只有一行,實際上這行代碼是調(diào)用了pytorch中的function中normalize方法。pytorch的function提供了很多常用的函數(shù)。

接著使用step into查看normalize中的具體實現(xiàn)。
首先是輸入的合法性判斷,輸入的是tensor,也就是原始的圖像,接著判斷是否要原地操作,如果不是inplace就需要將張量復制一份到新的內(nèi)存空間中。下面的代碼就是獲取數(shù)據(jù)的均值和標準差,并將數(shù)據(jù)轉(zhuǎn)換為張量。注意在sub_和div_后面有下劃線,意思是進行原位操作,這樣就完成了數(shù)據(jù)標準化的操作。
3.3 標準化作用
對數(shù)據(jù)進行標準化之后可以加快模型的收斂。
之前的邏輯回歸代碼bias=1,發(fā)現(xiàn)迭代次數(shù)360次即可得到99%的準確率,損失loss=0.05。

當修改bias=5時,發(fā)現(xiàn)需要迭代960次模型才能收斂,loss=0.14,得到99%的準確率。

原因:模型初始化一般有0均值,需要逐漸靠近最優(yōu)分類平面。
bias=5的初始化距離分類平面較遠


可以看出,如果訓練數(shù)據(jù)有良好的分布或者權(quán)重有良好的初始化,可以加速模型的訓練。
關注“學姐帶你玩AI”公眾號
回復“500”領取AI必讀論文合集(含視頻)