教育行業(yè)A股IPO第一股(股票代碼 003032)

全國(guó)咨詢/投訴熱線:400-618-4000

yolo算法:構(gòu)造訓(xùn)練樣本和設(shè)計(jì)損失函數(shù)

更新時(shí)間:2022年12月08日09時(shí)36分 來(lái)源:傳智教育 瀏覽次數(shù):

在進(jìn)行模型訓(xùn)練時(shí),我們需要構(gòu)造訓(xùn)練樣本和設(shè)計(jì)損失函數(shù),才能利用梯度下降對(duì)網(wǎng)絡(luò)進(jìn)行訓(xùn)練。

訓(xùn)練樣本的構(gòu)建

將一幅圖片輸入到y(tǒng)olo模型中,對(duì)應(yīng)的輸出是一個(gè)7x7x30張量,構(gòu)建標(biāo)簽label時(shí)對(duì)于原圖像中的每一個(gè)網(wǎng)格grid都需要構(gòu)建一個(gè)30維的向量。對(duì)照下圖我們來(lái)構(gòu)建目標(biāo)向量:

1670392915921_28.png

20個(gè)對(duì)象分類的概率

對(duì)于輸入圖像中的每個(gè)對(duì)象,先找到其中心點(diǎn)。比如上圖中自行車,其中心點(diǎn)在黃色圓點(diǎn)位置,中心點(diǎn)落在黃色網(wǎng)格內(nèi),所以這個(gè)黃色網(wǎng)格對(duì)應(yīng)的30維向量中,自行車的概率是1,其它對(duì)象的概率是0。所有其它48個(gè)網(wǎng)格的30維向量中,該自行車的概率都是0。這就是所謂的"中心點(diǎn)所在的網(wǎng)格對(duì)預(yù)測(cè)該對(duì)象負(fù)責(zé)"。狗和汽車的分類概率也是同樣的方法填寫

2個(gè)bounding box的位置

訓(xùn)練樣本的bbox位置應(yīng)該填寫對(duì)象真實(shí)的位置bbox,但一個(gè)對(duì)象對(duì)應(yīng)了2個(gè)bounding box,該填哪一個(gè)呢?需要根據(jù)網(wǎng)絡(luò)輸出的bbox與對(duì)象實(shí)際bbox的IOU來(lái)選擇,所以要在訓(xùn)練過(guò)程中動(dòng)態(tài)決定到底填哪一個(gè)bbox。

2個(gè)bounding box的置信度

預(yù)測(cè)置信度的公式為:

1670393030360_29.png

利用網(wǎng)絡(luò)輸出的2個(gè)bounding box與對(duì)象真實(shí)bounding box計(jì)算出來(lái)。然后看這2個(gè)bounding box的IOU,哪個(gè)比較大,就由哪個(gè)bounding box來(lái)負(fù)責(zé)預(yù)測(cè)該對(duì)象是否存在,即該bounding box的Pr(Object)=1,同時(shí)對(duì)象真實(shí)bounding box的位置也就填入該bounding box。另一個(gè)不負(fù)責(zé)預(yù)測(cè)的bounding box的Pr(Object)=0。

上圖中自行車所在的grid對(duì)應(yīng)的結(jié)果如下圖所示:

樣本標(biāo)簽

損失函數(shù)

損失就是網(wǎng)絡(luò)實(shí)際輸出值與樣本標(biāo)簽值之間的偏差:

損失函數(shù)

yolo給出的損失函數(shù):

損失函數(shù)

模型訓(xùn)練

Yolo先使用ImageNet數(shù)據(jù)集對(duì)前20層卷積網(wǎng)絡(luò)進(jìn)行預(yù)訓(xùn)練,然后使用完整的網(wǎng)絡(luò),在PASCAL VOC數(shù)據(jù)集上進(jìn)行對(duì)象識(shí)別和定位的訓(xùn)練。

Yolo的最后一層采用線性激活函數(shù),其它層都是Leaky ReLU。訓(xùn)練中采用了drop out和數(shù)據(jù)增強(qiáng)(data augmentation)來(lái)防止過(guò)擬合。

模型預(yù)測(cè)

將圖片resize成448x448的大小,送入到y(tǒng)olo網(wǎng)絡(luò)中,輸出一個(gè) 7x7x30 的張量(tensor)來(lái)表示圖片中所有網(wǎng)格包含的對(duì)象(概率)以及該對(duì)象可能的2個(gè)位置(bounding box)和可信程度(置信度)。在采用NMS(Non-maximal suppression,非極大值抑制)算法選出最有可能是目標(biāo)的結(jié)果。

總結(jié):yolo模型預(yù)測(cè)速度非??欤幚硭俣瓤梢赃_(dá)到45fps,其快速版本(網(wǎng)絡(luò)較小)甚至可以達(dá)到155fps。訓(xùn)練和預(yù)測(cè)可以端到端的進(jìn)行,非常簡(jiǎn)便。準(zhǔn)確率會(huì)打折扣對(duì)于小目標(biāo)和靠的很近的目標(biāo)檢測(cè)效果并不好。

0 分享到:
和我們?cè)诰€交談!