自編碼器應(yīng)用
自編碼器(autoencoder, AE)已在機(jī)器學(xué)習(xí)和人工神經(jīng)網(wǎng)絡(luò)引起了許多人的關(guān)注,事實(shí)上,自編碼器已在醫(yī)療醫(yī)藥、圖像去噪、神經(jīng)機(jī)器翻譯等領(lǐng)域產(chǎn)生了很可觀的成績(jī)。
自編碼器的組成
與大多數(shù)神經(jīng)網(wǎng)絡(luò)一樣,自編碼器通過(guò)向后傳播梯度以?xún)?yōu)化一組數(shù)據(jù)的權(quán)重進(jìn)行學(xué)習(xí),但自編碼器的體系結(jié)構(gòu)與大多數(shù)神經(jīng)網(wǎng)絡(luò)間最顯著差異在于瓶頸層[編者注:bottleneck,簡(jiǎn)單翻譯就是瓶頸層,一般在深度較高的網(wǎng)絡(luò)(如resnet101)中使用]。自編碼器的瓶頸層是將數(shù)據(jù)壓縮為小尺寸表示的一種方法。
而自編碼器的另外兩個(gè)重要部分是編碼器和解碼器。這三個(gè)組件融合在一起,便可形成“原始”自編碼器,當(dāng)然更復(fù)雜的自編碼器可能還具有一些其他組件。
接下來(lái),讓我們單獨(dú)看一下這些組件。
1、編碼器
是將數(shù)據(jù)壓縮和重建的第一階段,編碼器負(fù)責(zé)了數(shù)據(jù)壓縮階段。編碼器是一個(gè)前饋神經(jīng)網(wǎng)絡(luò),它將接收數(shù)據(jù)特征(例如在圖像壓縮的情況下為像素),并輸出小于數(shù)據(jù)特征大小的潛向量。
為了使數(shù)據(jù)重建更強(qiáng)有力,編碼器在訓(xùn)練過(guò)程中優(yōu)化了權(quán)重,以保證將輸入數(shù)據(jù)最重要的特征壓縮到小尺寸的潛在向量中。這確保了解碼器具足夠的輸入數(shù)據(jù)的相關(guān)信息,可以在最小的損失下重建數(shù)據(jù)。
2、潛在向量(瓶頸層)
自編碼器的潛在向量組件(瓶頸層)是最關(guān)鍵的部分,當(dāng)需要選擇潛在向量的大小時(shí),它就會(huì)變得更加關(guān)鍵。
編碼器的輸出為我們提供了潛在向量,而這個(gè)向量則是輸入數(shù)據(jù)最重要的特征,最終將被解碼器解碼,并將有效信息傳播到解碼器以進(jìn)行重建。
選擇較小的潛在向量意味著可以用較少的數(shù)據(jù)數(shù)據(jù)信息來(lái)表示數(shù)據(jù)特征,選擇較大的潛在向量則與自編碼器壓縮的思路相違背,還會(huì)增加計(jì)算成本。
3、解碼器
總結(jié)一下數(shù)據(jù)壓縮和重建的過(guò)程,就如同編碼器一樣,這個(gè)組件也是前饋神經(jīng)網(wǎng)絡(luò),但在結(jié)構(gòu)上和編碼器略有不同。差異主要來(lái)自于解碼器會(huì)將比解碼器輸出小的潛在向量作為輸入。
解碼器的功能是從潛在向量生成一個(gè)非常接近輸入的輸出。
訓(xùn)練自編碼器
在訓(xùn)練自編碼器時(shí),通常會(huì)將組件組合在一起,并不是單獨(dú)構(gòu)建的,最終使用諸如梯度下降或ADAM優(yōu)化器之類(lèi)的優(yōu)化算法對(duì)它們進(jìn)行端到端訓(xùn)練。
1、損失函數(shù)
自編碼器訓(xùn)練過(guò)程中最值得討論的一部分就是損失函數(shù)。數(shù)據(jù)重構(gòu)是一項(xiàng)生成數(shù)據(jù)的任務(wù),與其他機(jī)器學(xué)習(xí)任務(wù)不同(目標(biāo)是最大程度預(yù)測(cè)正確類(lèi)別的可能性),這會(huì)驅(qū)動(dòng)網(wǎng)絡(luò)產(chǎn)生接近輸入的輸出。
可通過(guò)幾個(gè)損失函數(shù)(例如l1、l2、均方誤差等)來(lái)實(shí)現(xiàn)該目標(biāo),這些損失函數(shù)的共同之處在于可以測(cè)量輸入和輸出之間的差異,因此使用任何一個(gè)函數(shù)都是可以的。
2、自編碼器網(wǎng)絡(luò)
一直以來(lái),大都使用多層感知器來(lái)設(shè)計(jì)編碼器和解碼器,但事實(shí)證明,仍然還有更加專(zhuān)業(yè)的框架,例如卷積神經(jīng)網(wǎng)絡(luò)(CNN)來(lái)捕獲有關(guān)輸入數(shù)據(jù)的更多空間信息來(lái)進(jìn)行圖像數(shù)據(jù)壓縮。
令人驚訝的是,研究表明,用作文本數(shù)據(jù)的自編碼器的遞歸網(wǎng)絡(luò)效果非常好,但這并不在本文的范圍內(nèi),不再一一贅述。多層感知器中使用的編碼器的潛在向量解碼器的概念仍然適用于卷積自編碼器。唯一的區(qū)別就是卷積層需要設(shè)計(jì)專(zhuān)用的解碼器和編碼器。
所有這些自編碼器網(wǎng)絡(luò)都可以很好地完成壓縮任務(wù),但是存在一個(gè)問(wèn)題,就是這些網(wǎng)絡(luò)的并不具備什么創(chuàng)造力。這里所指的創(chuàng)造力是這些自編碼器只能產(chǎn)生見(jiàn)過(guò)或訓(xùn)練過(guò)的東西。
不過(guò),通過(guò)稍微調(diào)整一下整體架構(gòu)的設(shè)計(jì),就可以引出一定水平的創(chuàng)造力,這種調(diào)整過(guò)的東西被稱(chēng)之為可變自編碼器。
可變自編碼器
可變自編碼器引入了兩個(gè)主要的設(shè)計(jì)更改:
1、不再將輸入轉(zhuǎn)換為潛在向量,而是輸出兩個(gè)向量參數(shù):均值和方差。
2、KL散度損失的附加損失項(xiàng)會(huì)添加到初始損失函數(shù)中。
可變自編碼器背后的思想是,希望解碼器使用由編碼器生成的均值向量和方差向量?jī)蓚€(gè)參數(shù)化的分布中,采樣出潛在向量來(lái)重構(gòu)數(shù)據(jù)。
這種采樣特征,賦予了編碼器一個(gè)受控空間。而當(dāng)可變自編碼器被訓(xùn)練后,每當(dāng)對(duì)輸入數(shù)據(jù)執(zhí)行前向傳遞時(shí),編碼器都會(huì)生成均值和方差向量,該均值和方差向量負(fù)責(zé)確定從中采樣潛在向量的分布。
均值向量確定了輸入數(shù)據(jù)編碼的中心位置,方差向量確定了要從中選擇編碼以生成真實(shí)輸出的徑向空間或圓。這意味著,對(duì)于具有相同輸入數(shù)據(jù)的前向傳遞,可變自編碼器可以生成以均值向量為中心,以方差向量空間內(nèi)為中心的輸出的不同變體。
為了進(jìn)行比較,在查看標(biāo)準(zhǔn)自編碼器時(shí),當(dāng)嘗試生成尚未經(jīng)過(guò)網(wǎng)絡(luò)訓(xùn)練的輸出時(shí),由于編碼器產(chǎn)生的潛在向量空間的不連續(xù)性,它會(huì)生成一些并不真實(shí)的輸出。
現(xiàn)在我們對(duì)可變自編碼器有了直觀的了解,讓我們看看如何在TensorFlow中構(gòu)建一個(gè)。
計(jì)算梯度的重新參數(shù)化
在上面的過(guò)程中,我們并未提及重參數(shù)化功能,但是它解決了可變自編碼器中一個(gè)非常關(guān)鍵的問(wèn)題。
回想一下,在解碼階段,由編碼器生成的均值和方差向量控制的分布采樣潛在向量編碼,當(dāng)通過(guò)的網(wǎng)絡(luò)前向數(shù)據(jù)傳播時(shí),這不會(huì)產(chǎn)生任何問(wèn)題,但是假若解碼器到編碼器反向傳播時(shí),由于采樣操作是不可逆的,會(huì)引起很大的問(wèn)題。
簡(jiǎn)而言之,就是無(wú)法通過(guò)采樣操作來(lái)計(jì)算梯度。
應(yīng)用重參數(shù)化技巧是解決這個(gè)問(wèn)題的一個(gè)不錯(cuò)的方法,首先通過(guò)生成均值0和方差1的標(biāo)準(zhǔn)高斯分布,再使用編碼器生成的均值和方差對(duì)該分布執(zhí)行微分加法和乘法運(yùn)算,就可完成這項(xiàng)工作。
注意,在將方差轉(zhuǎn)換為代碼中的對(duì)數(shù)空間,是為了確保數(shù)值的穩(wěn)定性。引入額外的損失項(xiàng),即Kullback-Leibler散度損失,以確保生成的分布盡可能接近均值0和方差1的標(biāo)準(zhǔn)高斯分布。
將分布的均值驅(qū)動(dòng)為0,可以確保生成的分布彼此非常接近,防止分布之間的不連續(xù)性。接近1的方差意味著我們有一個(gè)更適度(既不是很大也不是很?。┑目臻g來(lái)生成編碼。
在執(zhí)行重參數(shù)化之后,通過(guò)將方差向量與標(biāo)準(zhǔn)高斯分布相乘,并將結(jié)果添加到均值向量而獲得的分布,與由均值和方差向量立即控制的分布非常相似。
構(gòu)建可變自編碼器的簡(jiǎn)單步驟
最后,進(jìn)行一下技術(shù)總結(jié):
1、建立編碼器和解碼器網(wǎng)絡(luò)。
2、在編碼器和解碼器之間應(yīng)用重參數(shù)化技巧,以允許反向傳播。
3、端到端訓(xùn)練兩個(gè)網(wǎng)絡(luò)。
上文所使用的的完整代碼可在TensorFlow官方網(wǎng)站上找到。