SENet ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ (2017)
1. ์๊ฐ
ํฉ์ฑ๊ณฑ ์ ๊ฒฝ๋ง(CNN) ์ํคํ ์ฒ์ ๊ฐ์ ์ ์ํ "Squeeze-and-Excitation (SE) ๋ธ๋ก" ์ ์ํ๋ค.
- SE ๋ธ๋ก: ์ปจ๋ณผ๋ฃจ์ ํน์ง์ ์ฑ๋ ๊ฐ ์ํธ ์์กด์ฑ์ ๋ชจ๋ธ๋งํ์ฌ ๋คํธ์ํฌ์ ํํ๋ ฅ์ ํฅ์์ํจ๋ค.
- ๊ตฌ์ฑ ์์:
- Squeeze ๋จ๊ณ: ์ฑ๋ ์ค๋ช ์(์ฑ๋๋ณ ์ค์๋) ์์ฑํ๋ค.
- Excitation ๋จ๊ณ: ์ฑ๋๋ณ ๊ฐ์ค์น ์ ์ฉํ๋ค.
์ฅ์
- ๊ธฐ์กด CNN ์ํคํ ์ฒ(SOTA๊ธ)์ ์ฝ๊ฒ ํตํฉ ๊ฐ๋ฅํ๋ค.
- ๊ณ์ฐ์ ์ผ๋ก ๊ฐ๋ฒผ์ ๋ชจ๋ธ ๋ณต์ก๋๋ฅผ ํฌ๊ฒ ์ฆ๊ฐ์ํค์ง ์๋๋ค.
๊ฒฐ๊ณผ ๋ฐ ์ฑ๊ณผ
- ImageNet ๋ฐ์ดํฐ์ ์์ SE ๋ธ๋ก์ ์ฌ์ฉํ SENet์ ์ฐ์ํ ์ฑ๋ฅ ๊ฒ์ฆ
- ILSVRC 2017 ๋ถ๋ฅ ๋ํ 1์ ๋ฌ์ฑ
2. ๊ธฐ์กด ๋ฐฉ๋ฒ์ ๋ฌธ์ ์ (SE ๋ธ๋ญ์ด ์๋)
์ ๊ทธ๋ฆผ์ฒ๋ผ ์ฌ๋ฌ ์ฑ๋์ ํตํด ์ฌ๋ฌ ์ ๋ณด๋ฅผ ์ป์ ์ ์๋ค. ์ฌ๋ฌ ์ ๋ณด๋ฅผ ํ์ธํด๋ณด๋ฉด 1, 2๋ฒ์งธ๋ ์์ค์ํ๊ณ 3๋ฒ์งธ ์ฑ๋์ ์ค์ํ ๊ฒ์ ์ ์ ์๋ค.
์ด๋ฌํ ์ํฉ์ด๋ฉด 1, 2๋ฒ์งธ ์ฑ๋์ ๋น์ค์ ์ค์ด๊ณ 3๋ฒ์งธ ์ฑ๋์ ๋น์ค์ ๋๋ฆฌ๋ ๊ฒ์ด ์ ๋ฆฌํ ๊ฒ์ด๋ค.
ํ์ง๋ง ๊ธฐ์กด CNN์ ์ด๋ฌํ ๊ธฐ๋ฅ์ด ์์์ผ๋ฉฐ โSE ๋ธ๋ญโ์ด ์ด ๊ธฐ๋ฅ์ ์ถ๊ฐํ๋ค.
3. ๊ตฌ์กฐ
์ ์ฒด์ ์ธ ๊ตฌ์กฐ
3.1 Squeeze: ์ ๋ณด ์์ถ ๋จ๊ณ
z ๋ ๊ฐ ์ฑ๋์ ๊ฐ์ ์๋ฏธํจ
๊ฐ ์ฑ๋๋ณ ์ค์๋๋ฅผ ํ์ธํ๊ธฐ์ํด ๊ธ๋ก๋ฒ ํ๊ท ํ๋ง์ ์ฌ์ฉํฉ๋๋ค.
์ด๋ ๊ฐ ์ฑ๋์ ๊ฐ๋ค์ ํ๊ท ์ ๋ด์ ๊ฐ ์ฑ๋์ ์ ๋ฐ์ ์ธ ๊ฐ์ ํฌ๊ธฐ๋ฅผ ๊ตฌํฉ๋๋ค.
Input์ ํฌ๊ธฐ (H,W,C)์ด๋ผ๋ฉด ๊ธ๋ก๋ฒ ํ๊ท ํ๋ง์ ์งํํ๋ฉด (1, 1, C)์ ํฌ๊ธฐ๋ก ์ค์ด๋ค๊ฒ ๋ฉ๋๋ค.
3.2 Excitation: ์ค์๋ ๊ณ์ฐ ๋จ๊ณ
s ๋ ์ต์ข ์ฑ๋๋ณ ๊ฐ์ค์น, W1๊ณผ W2๋ FC Layer๋ฅผ ์๋ฏธํ๊ณ ฮด๋ ReLU ฯ๋ Sigmoid์ด๋ค.
Squeeze ๋จ๊ณ์์ ์ง๊ณ๋ ์ ๋ณด๋ ์ด๋ค ์ฑ๋์ด ์ค์ํ์ง์ ๋ํ ์ ๋ณด๋ ๋ฐ์๋์์ง ์์ ์ํ์ ๋๋ค.
๋ฐ๋ผ์ ์ด๋ฅผ ๋ฐ์ํ๊ธฐ ์ํด ํ์ต์ด ์งํ๋์ด์ผํฉ๋๋ค. ์ด๋ฅผ Fully Connected -> ReLU -> Fully Connected -> Sigmoid ์์๋ก ๊ตฌ์ฑํ์ฌ ํ์ตํฉ๋๋ค.
๋ง์ง๋ง ํ์ฑํ ํจ์๋ก Sigmoid๋ฅผ ์ฌ์ฉํด 0~1 ์ฌ์ด์ ๊ฐ์ ๊ฐ์ ธ ์ฑ๋๋ณ ์ค์๋(Attention Score)๋ฅผ ์ฌ์ฉํ ์ ์๋๋ก ํฉ๋๋ค.
3.3 Scale: ์ค์๋ ์ ์ฉ
์ด์ ์ ์ฐ์ฐ์ ํตํด ์ป์ ์ฑ๋๋ณ ์ค์๋(=
์ด ์ฐ์ฐ์ ํตํด ์ค์ํ ์ฑ๋์ ๊ฐ์ ์ ์งํ๊ณ ์ ์ค์ํ ์ฑ๋์ ๊ฐ์ ๋ฌด์ํ๋ฉฐ ํ์ต์ด ์งํ๋๋๋ก ํฉ๋๋ค.
ResNet์ ์ ์ฉํ SE ๋ธ๋ญ์ ๋ชจ์ต์ด๋ค. ์ฌ๊ธฐ์ ์ฒ์ ๋ํ๋๋ r์ ๊ฐ์ FC ๊ตฌ์กฐ๊ฐ BottleNeck ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๋ฉฐ ์ด๋ ์ฌ์ฉํ๋ ์ถ์ ๋น์จ์ ๊ฐ์ด๋ค.
4. ๋ชจ๋ธ ๋ฐ ๊ณ์ฐ ๋ณต์ก๋
SE ๋ธ๋ก์ ์ ์ฉํ๊ณ ์ํ๊ณ ์ ๊ณ์ฐ ๋ณต์ก๋ ์ฐจ์ด๋ฅผ ๋น๊ตํ๊ธฐ์ํด ResNet-50๊ณผ SE-ResNet-50์ ์์๋ก ์ฌ์ฉํฉ๋๋ค.
๊ณ์ฐ ๋น์ฉ ๋น๊ต
- ResNet-50: 224 ร 224 ์ ๋ ฅ ์ด๋ฏธ์ง์ ๋ํด ๋จ์ผ ์๋ฐฉํฅ ํจ์ค ์ ์ฝ 3.86G FLOPs ํ์.
- SE-ResNet-50: ๋์ผํ ์ ๋ ฅ์ ๋ํด ์ฝ 3.87G FLOPs ํ์, ์ด๋ ResNet-50 ๋๋น 0.26%์ ๊ณ์ฐ ๋น์ฉ ์ฆ๊ฐ.
์คํ ์๊ฐ ๋น๊ต
- ResNet-50: ์๋ฐฉํฅ ๋ฐ ์ญ๋ฐฉํฅ ํจ์ค์ ์ฝ 190ms ์์.
- SE-ResNet-50: ์ฝ 209ms ์์(256๊ฐ์ ์ด๋ฏธ์ง ๋ฏธ๋๋ฐฐ์น, 8๊ฐ์ NVIDIA Titan X GPU ์ฌ์ฉ).
- CPU ์ถ๋ก ์๊ฐ: ResNet-50์ 164ms, SE-ResNet-50์ 167ms.
- ๊ฒฐ๋ก ์ ์์ผ๋ก ์ถ๊ฐ ๋น์ฉ์ ๋ฏธ๋ฏธํฉ๋๋ค.
์ถ๊ฐ์ ์ธ ๋งค๊ฐ๋ณ์์ ๋ชจ๋ธ ์ฉ๋
- ResNet-50๊ณผ SE-ResNet-50์ ์ถ๊ฐ์ ์ธ ๋งค๊ฐ๋ณ์ ์ฐจ์ด๋ ๋๊ฐ์ FC ๊ณ์ธต์์๋ง ๋ฐ์ํฉ๋๋ค. ์ด๋ ์ ์ฒด์ ์ธ ๋ชจ๋ธ ์ฉ๋์์ ์์ฃผ ์์ ๋ถ๋ถ์ ์ฐจ์งํ๋ฉฐ ์ฝ 10% ์ฆ๊ฐํ ์ฉ๋์ ๋ณด์ฌ์ค๋๋ค.
5. ์ฑ๋ฅํ๊ฐ
SE ๋ธ๋ญ์ CNN ๋ชจ๋ธ ์ค๊ฐ์ ์ฌ์ฉ๋ ์ ์๋ค๊ณ ํ์ต๋๋ค. ๋ฐ๋ผ์ image Classification ๋ฟ๋ง ์๋๋ผ Scene Classification, Object Detection์์ ์ฑ๋ฅ ํ๊ฐ๋ฅผ ์งํํ์ต๋๋ค.
5.1 Image Classification
Imagenet ๋ฐ์ดํฐ์
์คํ ๊ฒฐ๊ณผ
Imagenet ๋ฐ์ดํฐ์ ์คํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด๋ฉด SE๋ฅผ ์ ์ฉํ ๋ชจ๋ธ์ด ์ฝ๊ฐ์ ์ฐ์ฐ๋ ์ฆ๊ฐ๋ฅผ ๋ณด์ด๊ณ ๋ ๋ฎ์ Error์จ์ ๋ณด์ด๋ ๊ฒ์ ํ์ธ ํ ์ ์์ต๋๋ค.
CIFAR-10, CIFAR-100 ๋ฐ์ดํฐ์
์คํ๊ฒฐ๊ณผ (Error% ๋น๊ต)
๋ค๋ฅธ ๋ ๊ฐ์ง ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ์ฌ ์คํํ ๊ฒฐ๊ณผ SE ๋ธ๋ญ์ ์ถ๊ฐํ ๋ชจ๋ธ์ Error์จ์ด ๋ ๋ฎ์ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
5.2 Scene Classification
Places365 ๋ฐ์ดํฐ ์
์คํ๊ฒฐ๊ณผ
์ฅ๋ฉด ๋ถ๋ฅ ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ์ฌ ์คํํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด๋ฉด SE ๋ธ๋ญ์ ์ถ๊ฐํ ๋ชจ๋ธ์ ์ฑ๋ฅ์ด ๋ ์ข์ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
5.3 Object detection
COCO ๋ฐ์ดํฐ์
์คํ ๊ฒฐ๊ณผ
๊ฐ์ฒด ํ์ง ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ์ฌ ์คํํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด๋ฉด SE ๋ธ๋ญ์ ์ถ๊ฐํ ๋ชจ๋ธ์ ์ฑ๋ฅ์ด ๋ ์ข์ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
์ฌ๋ฌ ๋ฌธ์ ์์ ์ ์ฉํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด๋ฉด ๋ฌธ์ ์ ๊ด๊ณ์์ด ๊ธฐ์กด CNN ๊ตฌ์กฐ์ SE ๋ธ๋ญ์ ์ถ๊ฐํ๋ฉด ์ฑ๋ฅ์ด ๊ฐ์ ๋จ์ ์ ์ ์์์ต๋๋ค.