歡迎光臨
每天分享高質量文章

詳解深度學習中的Normalization,不只是BN(1)

 深度神經網路模型訓練之難眾所周知,其中一個重要的現象就是 Internal Covariate Shift. Batch Normalization 大法自 2015 年由Google 提出之後,就成為深度學習必備之神器。自 BN 之後, Layer Norm / Weight Norm / Cosine Norm 等也橫空出世。本文從 Normalization 的背景講起,用一個公式概括 Normalization 的基本思想與通用框架,將各大主流方法一一對號入座進行深入的對比分析,並從引數和資料的伸縮不變性的角度探討 Normalization 有效的深層原因。本文是該系列的第一篇。


01

為什麼需要 Normalization

1.1  獨立同分佈與白化

機器學習界的煉丹師們最喜歡的資料有什麼特點?竊以為,莫過於“獨立同分佈”了,即 independent and identically distributed,簡稱為 i.i.d. 獨立同分佈並非所有機器學習模型的必然要求(比如 Naive Bayes 模型就建立在特徵彼此獨立的基礎之上,而Logistic Regression 和 神經網路 則在非獨立的特徵資料上依然可以訓練出很好的模型),但獨立同分佈的資料可以簡化常規機器學習模型的訓練、提升機器學習模型的預測能力,已經是一個共識。

因此,在把資料喂給機器學習模型之前,“白化(whitening)”是一個重要的資料預處理步驟。白化一般包含兩個目的:

(1)去除特徵之間的相關性 —> 獨立;

(2)使得所有特徵具有相同的均值和方差 —> 同分佈。

白化最典型的方法就是PCA,本文不再展開。

1.2 深度學習中的 Internal Covariate Shift

深度神經網路模型的訓練為什麼會很困難?其中一個重要的原因是,深度神經網路涉及到很多層的疊加,而每一層的引數更新會導致上層的輸入資料分佈發生變化,透過層層疊加,高層的輸入分佈變化會非常劇烈,這就使得高層需要不斷去重新適應底層的引數更新。為了訓好模型,我們需要非常謹慎地去設定學習率、初始化權重、以及盡可能細緻的引數更新策略。

Google 將這一現象總結為 Internal Covariate Shift,簡稱 ICS. 什麼是 ICS 呢?@魏秀參 在一個回答中做出了一個很好的解釋:

大家都知道在統計機器學習中的一個經典假設是“源空間(source domain)和標的空間(target domain)的資料分佈(distribution)是一致的”。如果不一致,那麼就出現了新的機器學習問題,如 transfer learning / domain adaptation 等。而 covariate shift 就是分佈不一致假設之下的一個分支問題,它是指源空間和標的空間的條件機率是一致的,但是其邊緣機率不同,即:對所有有:


但是


大家細想便會發現,的確,對於神經網路的各層輸出,由於它們經過了層內操作作用,其分佈顯然與各層對應的輸入訊號分佈不同,而且差異會隨著網路深度增大而增大,可是它們所能“指示”的樣本標記(label)仍然是不變的,這便符合了 covariate shift 的定義。由於是對層間訊號的分析,也即是 “internal”的來由。

1.3 ICS 會導致什麼問題?

簡而言之,每個神經元的輸入資料不再是“獨立同分佈”。

其一,上層引數需要不斷適應新的輸入資料分佈,降低學習速度。

其二,下層輸入的變化可能趨向於變大或者變小,導致上層落入飽和區,使得學習過早停止。

其三,每層的更新都會影響到其它層,因此每層的引數更新策略需要盡可能的謹慎。

02

Normalization 的基本思想與框架

我們以神經網路中的一個普通神經元為例。神經元接收一組輸入向量


透過某種運算後,輸出一個標量值:


由於 ICS 問題的存在, x 的分佈可能相差很大。要解決獨立同分佈的問題,“理論正確”的方法就是對每一層的資料都進行白化操作。然而標準的白化操作代價高昂,特別是我們還希望白化操作是可微的,保證白化操作可以透過反向傳播來更新梯度。

因此,以 BN 為代表的 Normalization 方法退而求其次,進行了簡化的白化操作。基本思想是:在將 x 送給神經元之前,先對其做平移和伸縮變換, 將 x 的分佈規範化成在固定區間範圍的標準分佈。

通用變換框架就如下所示:


我們來看看這個公式中的各個引數。

(1) μ 是平移引數(shift parameter), σ縮放引數(scale parameter)。透過這兩個引數進行 shift 和 scale 變換: 

得到的資料符合均值為 0、方差為 1 的標準分佈。

(2)b再平移引數(re-shift parameter),b再縮放引數(re-scale parameter)。將 上一步得到的 \hat{x} 進一步變換為: 

最終得到的資料符合均值為 b 、方差為 g^2 的分佈。

奇不奇怪?奇不奇怪?

說好的處理 ICS,第一步都已經得到了標準分佈,第二步怎麼又給變走了?

答案是——為了保證模型的表達能力不因為規範化而下降

我們可以看到,第一步的變換將輸入資料限制到了一個全域性統一的確定範圍(均值為 0、方差為 1)。下層神經元可能很努力地在學習,但不論其如何變化,其輸出的結果在交給上層神經元進行處理之前,將被粗暴地重新調整到這一固定範圍。

沮不沮喪?沮不沮喪?

難道我們底層神經元人民就在做無用功嗎?

所以,為了尊重底層神經網路的學習結果,我們將規範化後的資料進行再平移和再縮放,使得每個神經元對應的輸入範圍是針對該神經元量身定製的一個確定範圍(均值為 b 、方差為 g^2 )。rescale 和 reshift 的引數都是可學習的,這就使得 Normalization 層可以學習如何去尊重底層的學習結果。

除了充分利用底層學習的能力,另一方面的重要意義在於保證獲得非線性的表達能力。Sigmoid 等啟用函式在神經網路中有著重要作用,透過區分飽和區和非飽和區,使得神經網路的資料變換具有了非線性計算能力。而第一步的規範化會將幾乎所有資料對映到啟用函式的非飽和區(線性區),僅利用到了線性變化能力,從而降低了神經網路的表達能力。而進行再變換,則可以將資料從線性區變換到非線性區,恢復模型的表達能力。

那麼問題又來了——

經過這麼的變回來再變過去,會不會跟沒變一樣?

不會。因為,再變換引入的兩個新引數 g 和 b,可以表示舊引數作為輸入的同一族函式,但是新引數有不同的學習動態。在舊引數中, x 的均值取決於下層神經網路的複雜關聯;但在新引數中, 僅由 來確定,去除了與下層計算的密切耦合。新引數很容易透過梯度下降來學習,簡化了神經網路的訓練。

那麼還有一個問題(問題怎麼這麼多!)——

這樣的 Normalization 離標準的白化還有多遠?

標準白化操作的目的是“獨立同分佈”。獨立就不說了,暫不考慮。變換為均值為 b 、方差為 g^2 的分佈,也並不是嚴格的同分佈,只是對映到了一個確定的區間範圍而已。(所以,這個坑還有得研究呢!)



篇幅所限,這篇推送中我們就先談到這裡。

在下一篇中,我將推送本文的第二部分,歡迎繼續關註。先預告一下:

03. 主流 Normalization 方法梳理

——結合本文所述框架,將 BatchNorm / LayerNorm / WeightNorm / CosineNorm 對號入座,梳理各種方法之間的關係與差別。

04. Normalization 為什麼會有效?

——從引數和資料的伸縮不變性探討Normalization有效的深層原因。


@Julius

PhD 畢業於 THU 計算機系。

現在 Tencent AI 從事機器學習和個性化推薦研究與 AI 平臺開發工作。


關於PaperWeekly


PaperWeekly 是一個推薦、解讀、討論、報道人工智慧前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號後臺點選「交流群」,小助手將把你帶入 PaperWeekly 的交流群裡。


▽ 點選 | 閱讀原文 | 加入社群

贊(0)

分享創造快樂