Week 21 ํ์ต ์ ๋ฆฌ
๋ชจ๋ธ ๊ฒฝ๋ํ: ์์ ํจ์จ์ฑ์ ๋์ด๋ AI ์ต์ ํ ๊ธฐ์
๊ฑฐ๋ AI ๋ชจ๋ธ์ ์๋ง์ ํ๋ผ๋ฏธํฐ๋ก ๊ตฌ์ฑ๋์ด ์์ด, ํ์ต ์ ๋ง์ GPU, ์ ๋ ฅ, ์๊ฐ์ด ํ์ํฉ๋๋ค. ๊ทธ๋ฌ๋ ๋๋ถ๋ถ์ ํ๊ฒฝ์์๋ ์ด๋ฌํ ์์๊ณผ ์๊ฐ์ ์ถฉ๋ถํ ํ๋ณดํ๊ธฐ ์ด๋ ต๊ธฐ ๋๋ฌธ์, ๋ชจ๋ธ์ ์ฑ๋ฅ์ ์ ์งํ๋ฉด์๋ ํฌ๊ธฐ์ ๊ณ์ฐ ๋น์ฉ์ ์ค์ด๋ ๋ชจ๋ธ ๊ฒฝ๋ํ ๊ธฐ์ ์ด ์ค์ํด์ก์ต๋๋ค. ๊ฒฝ๋ํ๋ ๋ชจ๋ธ์ ์ถ๋ก ์๊ฐ๋ ๋จ์ถ๋์ด ์์จ์ฃผํ๊ณผ ๊ฐ์ด ์ค์๊ฐ ์ฒ๋ฆฌ๊ฐ ์๊ตฌ๋๋ ํ์คํฌ์ ์ ํฉํฉ๋๋ค.
๋ชจ๋ธ ๊ฒฝ๋ํ ์ฃผ์ ๊ธฐ๋ฒ
1. Pruning (๊ฐ์ง์น๊ธฐ)
- ๊ฐ๋
:
ํ์ต๋ ๋ชจ๋ธ์์ ์ค์๋๊ฐ ๋ฎ์ ๋ด๋ฐ์ด๋ ์ฐ๊ฒฐ(์๋ ์ค)์ ์ ๊ฑฐํ๋ ๋ฐฉ๋ฒ์ ๋๋ค. - ํจ๊ณผ:
๋ชจ๋ธ์ ํฌ๊ธฐ์ ๊ณ์ฐ ๋น์ฉ์ ์ค์ฌ ๊ฒฝ๋ํ์ ์๋ ํฅ์์ ๋์์ ๋๋ชจํฉ๋๋ค.
2. Knowledge Distillation (์ง์ ์ฆ๋ฅ)
- ๊ฐ๋
:
๊ณ ์ฑ๋ฅ์ Teacher ๋ชจ๋ธ๋ก๋ถํฐ ์ง์์ ์ ๋ฌ๋ฐ์, ๋ ๊ฒฝ๋ํ๋ Student ๋ชจ๋ธ์ ํ์ต์ํค๋ ๊ธฐ๋ฒ์ ๋๋ค. - ํจ๊ณผ:
Teacher ๋ชจ๋ธ์ ์ฑ๋ฅ์ ์ต๋ํ ์ ์งํ๋ฉด์, ์์ ์๋ชจ๋ฅผ ์ค์ธ Student ๋ชจ๋ธ์ ๊ตฌ์ถํ ์ ์์ต๋๋ค.
3. Quantization (์์ํ)
- ๊ฐ๋
:
๋ชจ๋ธ์ ๊ฐ์ค์น์ ํ์ฑํ๋ฅผ ๋ฎ์ ๋นํธ ์ ๋ฐ๋๋ก ๋ณํํ์ฌ ์ ์ฅ ๋ฐ ๊ณ์ฐ ํจ์จ์ฑ์ ๋์ด๋ ๋ฐฉ๋ฒ์ ๋๋ค. - ํจ๊ณผ:
์ฐ์ฐ ์๋๋ฅผ ํฅ์์ํค๊ณ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ๊ฐ์์์ผ, ์ถ๋ก ๋จ๊ณ์์ ํนํ ์ ๋ฆฌํ ์ฑ๋ฅ ๊ฐ์ ์ ์ด๋ฃน๋๋ค.
๊ฒฐ๋ก
๋ชจ๋ธ ๊ฒฝ๋ํ ๊ธฐ์ ์ ์์๊ณผ ์๊ฐ์ด ์ ํ๋ ํ๊ฒฝ์์ ๊ณ ์ฑ๋ฅ AI ๋ชจ๋ธ์ ์ค์ ์๋น์ค์ ์ ์ฉํ๊ธฐ ์ํ ํ์ ์ ๋ต์ ๋๋ค. Pruning, Knowledge Distillation, Quantization๊ณผ ๊ฐ์ ๊ธฐ๋ฒ์ ํตํด ๋ชจ๋ธ์ ํฌ๊ธฐ๋ฅผ ์ค์ด๋ฉด์๋, ์ต์ข ์ฑ๋ฅ์ ์ต๋ํ ์ ์งํ ์ ์์ต๋๋ค. ์ด๋ฅผ ํตํด ์ค์๊ฐ ์ฒ๋ฆฌ๊ฐ ์๊ตฌ๋๋ ๋ค์ํ ์ ํ๋ฆฌ์ผ์ด์ , ์๋ฅผ ๋ค์ด ์์จ์ฃผํ, ๋ชจ๋ฐ์ผ ์ ํ๋ฆฌ์ผ์ด์ ๋ฑ์ ํจ๊ณผ์ ์ผ๋ก ํ์ฉํ ์ ์์ต๋๋ค.
๋ชจ๋ธ ๊ฒฝ๋ํ๋ฅผ ์ํ Pruning ๊ธฐ์
๊ฑฐ๋ AI ๋ชจ๋ธ์ ์๋ง์ ํ๋ผ๋ฏธํฐ๋ก ์ด๋ฃจ์ด์ ธ ์์ด ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋๊ณผ ์ฐ์ฐ ๋น์ฉ์ด ํฝ๋๋ค. ์ด๋ฅผ ์ค์ด๊ธฐ ์ํด ์ค์๋๊ฐ ๋ฎ์ ๋ด๋ฐ์ด๋ ์ฐ๊ฒฐ(์๋ ์ค)์ ์ ๊ฑฐํ๋ pruning ๊ธฐ๋ฒ์ด ์ฌ์ฉ๋ฉ๋๋ค. Pruning์ ํตํด ๋ชจ๋ธ์ ํฌ๊ธฐ๋ฅผ ์ค์ด๊ณ , ๊ณ์ฐ ์๋๋ฅผ ๋์ด๋ฉฐ, ์ถ๋ก ์๊ฐ๋ ๋จ์ถํ ์ ์์ต๋๋ค.
Pruning ๊ธฐ๋ฒ์ ํฌ๊ฒ ๊ตฌ์กฐ(structure), ์ค์ฝ์ด๋ง(scoring), ์ค์ผ์ค๋ง(scheduling), ์ด๊ธฐํ(initialization) ๋ค ๊ฐ์ง ๊ด์ ์์ ์ ๊ทผํ ์ ์์ต๋๋ค.
1. Pruning ๊ธฐ๋ฒ์ ๋ถ๋ฅ
1-1. Structure: ๋ชจ๋ธ ๊ตฌ์กฐ ๋ณํ ์ฌ๋ถ
-
Unstructured Pruning
๊ฐ๋ณ ํ๋ผ๋ฏธํฐ ๋จ์๋ก ์ค์๋๊ฐ ๋ฎ์ ๊ฐ์ 0์ผ๋ก ๋ณ๊ฒฝํ์ฌ ์ ๊ฑฐํฉ๋๋ค.- ์ฅ์ : ๊ตฌํ์ด ์๋์ ์ผ๋ก ์ฝ์ต๋๋ค.
- ๋จ์ : ๋ชจ๋ธ ๊ตฌ์กฐ ์์ฒด๋ ๋ณํ์ง ์์ผ๋ฏ๋ก, 0์ผ๋ก ์ฑ์์ง ๋ถ๋ถ์ ๊ทธ๋๋ก ์ฐ์ฐ์ ์ฌ์ฉํ๊ฒ ๋์ด ์ฐ์ฐ ์๋ ํฅ์์ ๋ฏธ๋ฏธํ ์ ์์ต๋๋ค.
-
Structured Pruning
๋ด๋ฐ, ์ฑ๋, ํน์ ๋ ์ด์ด ์ ์ฒด๋ฅผ ์ ๊ฑฐํ์ฌ ๋ชจ๋ธ์ ๊ตฌ์กฐ๋ฅผ ๋ณ๊ฒฝํฉ๋๋ค.- ์ฅ์ : ๊ตฌ์กฐ ๋ณ๊ฒฝ์ผ๋ก ์ธํด ์ค์ ์ฐ์ฐ ์๋ ํฅ์์ ๊ธฐ๋ํ ์ ์์ต๋๋ค.
- ๋จ์ : ๊ตฌํ์ด ๋ณต์กํ๊ฑฐ๋ ๊ฒฝ์ฐ์ ๋ฐ๋ผ ๋ถ๊ฐ๋ฅํ ์๋ ์์ต๋๋ค.
๋ฐฉ๋ฒ | ๋จ์ | ๊ตฌ์กฐ ๋ณ๊ฒฝ ์ฌ๋ถ | ์ฅ์ | ๋จ์ |
---|---|---|---|---|
Unstructured | ๊ฐ๋ณ ํ๋ผ๋ฏธํฐ | ์์ | ๊ตฌํ์ด ์ฌ์ | ์ฐ์ฐ ์๋ ํฅ์ ๋ฏธํก |
Structured | ๋ด๋ฐ/์ฑ๋/๋ ์ด์ด | ์์ | ์ฐ์ฐ ์๋ ํฅ์ ๊ฐ๋ฅ | ๊ตฌํ์ด ์ด๋ ต๊ฑฐ๋ ์ ํ์ ์ผ ์ ์์ |
1-2. Scoring: ๊ฐ์ง์น๊ธฐํ ํ๋ผ๋ฏธํฐ ์ ์
์ค์๋ ๊ณ์ฐ ๋ฐฉ๋ฒ
- ๊ฐ๋ณ ํ๋ผ๋ฏธํฐ ๊ธฐ์ค: ๊ฐ ํ๋ผ๋ฏธํฐ์ ์ ๋๊ฐ์ ๊ธฐ์ค์ผ๋ก ์ค์๋๋ฅผ ํ๊ฐํ์ฌ, ์ ๋๊ฐ์ด ์์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๊ฑฐํฉ๋๋ค.
- ๋ ์ด์ด ๋ณ Lp-norm ๊ธฐ์ค: ๋ ์ด์ด๋ง๋ค
(์: )์ ๊ณ์ฐํด, ํด๋น ๊ฐ์ด ์์ ๋ ์ด์ด์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๊ฑฐํฉ๋๋ค.
์ค์๋๋ฅผ ๋ฐ์ํ๋ ๋จ์
- Global Pruning:
์ ์ฒด ๋ชจ๋ธ์์ ์ค์๋๊ฐ ๋ฎ์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ํํด ์ ๊ฑฐํฉ๋๋ค.- ์ฅ์ : ์ค์ํ ๋ ์ด์ด์ ํ๋ผ๋ฏธํฐ๋ ์๋์ ์ผ๋ก ๋ณด์กด๋ฉ๋๋ค.
- ๋จ์ : ๊ณ์ฐ๋์ด ๋ง์์ง ์ ์์ต๋๋ค.
- Local Pruning:
๊ฐ ๋ ์ด์ด๋ณ๋ก ์ค์๋ ํ์ ์ผ์ ๋น์จ(์: ํ์ 50%)์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๊ฑฐํฉ๋๋ค.- ์ฅ์ : ํน์ ๋ ์ด์ด์ ๊ณผ๋ํ๊ฒ ์ง์ค๋์ง ์์ ๊ท ํ ์๊ฒ ๊ฐ์ง์น๊ธฐํ ์ ์์ต๋๋ค.
- ๋จ์ : ์ค์ํ ๋ ์ด์ด์์ ๋ถํ์ํ๊ฒ ๋ง์ ํ๋ผ๋ฏธํฐ๊ฐ ์ ๊ฑฐ๋ ์ํ์ด ์์ต๋๋ค.
1-3. Scheduling: ๊ฐ์ง์น๊ธฐ ์งํ ๋ฐฉ์
- One-shot Pruning:
ํ ๋ฒ์ ๊ฐ์ง์น๊ธฐ๋ฅผ ์ํํฉ๋๋ค. ๋น ๋ฅด์ง๋ง ์ฑ๋ฅ์ด ๋ถ์์ ํ ์ ์์ต๋๋ค. - Recursive (Iterative) Pruning:
์ฌ๋ฌ ๋ฒ์ ๊ฑธ์ณ ์กฐ๊ธ์ฉ ๊ฐ์ง์น๊ธฐ๋ฅผ ์งํํฉ๋๋ค. ์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆฌ์ง๋ง ์ฑ๋ฅ ์์ ์ฑ์ด ๋์ต๋๋ค.
1-4. Initialization: Fine-tuning ์์์
๊ฐ์ง์น๊ธฐ ํ ๋ชจ๋ธ์ ์ฌํ์ตํ ๋, ์ด๊ธฐ ์ํ์ ๋ฐ๋ผ ๋ ๊ฐ์ง ๋ฐฉ์์ด ์์ต๋๋ค.
- Weight-Preserving (Classic):
๊ฐ์ง์น๊ธฐ ์งํ์ ๋ชจ๋ธ ์ํ์์ ๋ฐ๋ก fine-tuning์ ์งํํฉ๋๋ค. ํ์ต๊ณผ ์๋ ด์ด ๋น ๋ฅด์ง๋ง ์ฑ๋ฅ์ด ๋ถ์์ ํ ์ ์์ต๋๋ค. - Weight-Reinitializing (Rewinding):
๊ฐ์ง์น๊ธฐ ํ ๋ชจ๋ธ์ ์ผ๋ถ๋ฅผ ๋๋ค ๊ฐ์ผ๋ก ์ด๊ธฐํํ ํ fine-tuning์ ์งํํฉ๋๋ค. ํ์ต ์๊ฐ์ด ๋ ๊ฑธ๋ฆฌ์ง๋ง, ์ฑ๋ฅ ์์ ์ฑ์ด ํฅ์๋ ์ ์์ต๋๋ค.
Iterative Magnitude Pruning (IMP):
๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ pruning ๋ฐฉ๋ฒ์ผ๋ก,
- unstructured ๋ฐฉ์,
- global (ํ๋ผ๋ฏธํฐ๋ณ ์ ๋๊ฐ ๊ธฐ๋ฐ),
- recursive (iterative) ๋ฐฉ์,
- rewinding์ ๊ฒฐํฉํ์ฌ ์ํํฉ๋๋ค.
2. ์ถ๊ฐ ๊ณ ๋ ค ์ฌํญ
2-1. Matrix Sparsity ๋ฌธ์
-
์ ์:
matrix์ ๋๋ถ๋ถ ์์๊ฐ 0์ธ ์ํ๋ฅผ ์๋ฏธํฉ๋๋ค.-
Density:
density=0์ดย ์๋ย ์์์ย ๊ฐ์์ ์ฒดย ์์์ย ๊ฐ์\text{density} = \frac{\text{0์ด ์๋ ์์์ ๊ฐ์}}{\text{์ ์ฒด ์์์ ๊ฐ์}}
-
Sparsity:
sparsity=0์ธย ์์์ย ๊ฐ์์ ์ฒดย ์์์ย ๊ฐ์=1โdensity\text{sparsity} = \frac{\text{0์ธ ์์์ ๊ฐ์}}{\text{์ ์ฒด ์์์ ๊ฐ์}} = 1 - \text
-
-
๋ฌธ์ ์ :
Unstructured pruning์ผ๋ก ์์ฑ๋ 0 ๊ฐ์ด ์ฌ์ ํ ์ฐ์ฐ์ ํฌํจ๋๋ฉด, ๊ณ์ฐ ์๋ ํฅ์์ด ๋ฏธํกํ ์ ์์ต๋๋ค. -
ํด๊ฒฐ ๋ฐฉ์:
- Sparse Matrix Representation:
0์ด ์๋ ์์๋ค์ ์ขํ๋ฅผ ์ ์ฅํ์ฌ ์ฐ์ฐ ์ ํ์ฉ (sparsity๊ฐ ๋งค์ฐ ๋์ ๊ฒฝ์ฐ). - ์ ์ฉ ํ๋์จ์ด:
NVIDIA Tensor Core์ ๊ฐ์ด, sparse ์ฐ์ฐ์ ์ต์ ํํ ํ๋์จ์ด๋ฅผ ์ฌ์ฉํฉ๋๋ค.
- Sparse Matrix Representation:
2-2. Sensitivity (๋ฏผ๊ฐ๋)
- ๊ฐ๋
:
๊ฐ์ง์น๊ธฐํ ํ๋ผ๋ฏธํฐ๋ ๋ ์ด์ด๊ฐ ๋ชจ๋ธ ์ ์ฒด ์ฑ๋ฅ์ ๋ฏธ์น๋ ์ํฅ์ ํ๊ฐํ๋ ์งํ์ ๋๋ค. - ์ค๋ฌด:
์ผ๋ฐ์ ์ผ๋ก ์์ชฝ ๋ ์ด์ด๊ฐ ๋ฏผ๊ฐ๋๊ฐ ๋๊ณ , ๋ค์ชฝ ๋ ์ด์ด๊ฐ ๋ ๋ฏผ๊ฐํ ๊ฒฝํฅ์ด ์์ต๋๋ค. ์ด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก empiricalํ๊ฒ pruning ๋น์จ์ ๊ฒฐ์ ํฉ๋๋ค.
3. In Practice: CNN๊ณผ BERT์์์ Pruning
3-1. CNN์์์ Pruning
- CNN ๊ตฌ์ฑ:
๋ณดํต Convolutional Layer์ Fully-Connected Layer(FC)๋ก ๊ตฌ์ฑ๋ฉ๋๋ค. - ํน์ง:
- ๋๋ถ๋ถ์ ํ๋ผ๋ฏธํฐ๋ FC layer์ ์ง์ค๋์ด ์์ง๋ง, ์ฐ์ฐ ์๋ ๋ณ๋ชฉ์ CNN ๋ ์ด์ด์์ ๋ฐ์ํฉ๋๋ค.
- ํจ์จ์ ์ธ pruning์ ์ํด CNN์ ํํฐ(์ปค๋)์ FC layer ๋ชจ๋๋ฅผ ๋์์ผ๋ก ๊ฐ์ง์น๊ธฐ๋ฅผ ์ํํฉ๋๋ค.
- ๋ฐฉ๋ฒ:
CNN ๋ ์ด์ด์์๋ ์ค์๋๊ฐ ๋ฎ์ ํํฐ๋ฅผ ์ ๊ฑฐํ๋ฉฐ, ์ค์๋๋ ์ฃผ๋ก sparsity ๋๋ Lโ-norm ๊ฐ ๊ธฐ์ค์ผ๋ก ํ๋จํฉ๋๋ค.
3-2. BERT์์์ Pruning
- BERT ๊ตฌ์ฑ:
12๊ฐ์ Transformer ๋ ์ด์ด๋ก ์ด๋ฃจ์ด์ง ๋ค์ฉ๋ ์ธ์ด ๋ชจ๋ธ์ ๋๋ค. - ํน์ง:
- BERT๋ ์์ชฝ ๋ ์ด์ด์์ ๋จ์ด ๋จ์์ ์์ ์ ๋ณด๋ฅผ, ๋ค์ชฝ ๋ ์ด์ด์์ ๋ฌธ์ฅ๊ณผ ๊ฐ์ ํฐ ์ ๋ณด๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค.
- ๋ ์ด์ด๋ณ sparsity๊ฐ ์ผ์ ํ์ง ์์, global ๋๋ structured pruning ์ ์ฑ๋ฅ ์ ํ ์ํ์ด ์์ต๋๋ค.
- ์ ๋๊ฐ ๊ธฐ๋ฐ์ pruning์ด ํจ๊ณผ์ ์ด๋ฉฐ, ์์ชฝ ๋ ์ด์ด์ ๋ํด์๋ local pruning์ ์ ์ฉํด ์ฑ๋ฅ ์ ํ๋ฅผ ์ต์ํํ ์ ์์ต๋๋ค.
4. ์์ ์ฝ๋ (PyTorch ๊ธฐ๋ฐ)
๋ค์์ ๊ฐ๋จํ 2-layer MLP ๋ชจ๋ธ์ ๋ํด unstructured pruning์ ์ ์ฉํ๋ ์์ ์ ๋๋ค.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
# 2-layer MLP ๋ชจ๋ธ ์ ์
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = nn.Linear(2, 10)
self.fc2 = nn.Linear(10, 2)
def forward(self, x):
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
return x
# ๋ชจ๋ธ ์ ์ธ ๋ฐ ํ์ต/ํ
์คํธ ํจ์ ์ ์ (train, test๋ ๋ณ๋ ๊ตฌํ)
model = Model()
# ํ์ต ๋ฐ ํ
์คํธ (pruning ์ )
train(model, train_data)
test(model, test_data)
# ์ ์ฒด ํ๋ผ๋ฏธํฐ ์ ํ์ธ
total_params = sum(p.numel() for p in model.parameters())
total_params_nz = sum((p != 0.0).sum().item() for p in model.parameters())
print("์ ์ฒด ํ๋ผ๋ฏธํฐ:", total_params)
print("0์ด ์๋ ํ๋ผ๋ฏธํฐ:", total_params_nz)
# Pruning ๋์ ๋ ์ด์ด ์ ํ (์: ์ฒซ ๋ฒ์งธ Fully Connected ๋ ์ด์ด)
layer = model.fc1
# 1) ๋๋ค์ผ๋ก ์ฒซ ๋ฒ์งธ ๋ ์ด์ด์์ 50%๋ฅผ ์ ๊ฑฐ (unstructured)
prune.random_unstructured(layer, name='weight', amount=0.5)
# 2) ํ๋ผ๋ฏธํฐ ์ ๋๊ฐ ๊ธฐ์ค ํ์ 50%๋ฅผ ์ ๊ฑฐ (L1-based pruning)
prune.l1_unstructured(layer, name='weight', amount=0.5)
# pruning ํ fine-tuning์ ์ํด ๋ค์ ํ์ต
train(model, train_data)
# pruning ๋ฐ fine-tuning ํ ํ
์คํธ
test(model, test_data)
์ค๋ช :
- ๋จผ์ ๋ชจ๋ธ์ ํ์ต ๋ฐ ํ ์คํธํ ํ, ์ฒซ ๋ฒ์งธ ๋ ์ด์ด์ weight tensor์ ๋ํด unstructured pruning์ ์ ์ฉํฉ๋๋ค.
prune.random_unstructured
๋ ๋ฌด์์๋ก 50%์ ํ๋ผ๋ฏธํฐ๋ฅผ 0์ผ๋ก ๋ง๋ญ๋๋ค.prune.l1_unstructured
๋ ๊ฐ ํ๋ผ๋ฏธํฐ์ ์ ๋๊ฐ์ ๊ธฐ์ค์ผ๋ก ํ์ 50%๋ฅผ 0์ผ๋ก ๋ง๋ญ๋๋ค.- ์ดํ fine-tuning์ ํตํด pruning์ผ๋ก ์ธํ ์ฑ๋ฅ ์ ํ๋ฅผ ๋ณด์ํฉ๋๋ค.
๊ฒฐ๋ก
Pruning์ ๋ชจ๋ธ ๊ฒฝ๋ํ์ ํต์ฌ ๊ธฐ๋ฒ ์ค ํ๋๋ก,
- ๊ตฌ์กฐ์ ์ธ ์ ๊ทผ(structure): unstructured์ structured ๋ฐฉ์ ์ ํ
- ์ค์ฝ์ด๋ง(scoring): ํ๋ผ๋ฏธํฐ ์ค์๋ ํ๊ฐ ๋ฐฉ๋ฒ (์ ๋๊ฐ, Lp-norm) ๋ฐ global vs. local ๋ฐฉ์
- ์ค์ผ์ค๋ง(scheduling): ํ ๋ฒ ํน์ ์ฌ๋ฌ ๋ฒ์ iterative ๋ฐฉ์
- ์ด๊ธฐํ(initialization): fine-tuning ์ weight-preserving ๋๋ reinitializing ์ ํ
์ถ๊ฐ๋ก, matrix sparsity, ์ ์ฉ ํ๋์จ์ด ํ์ฉ, ๊ทธ๋ฆฌ๊ณ ๊ฐ ๋ ์ด์ด์ sensitivity ๋ฑ์ ๊ณ ๋ คํ์ฌ ์ต์ ์ pruning ์ ๋ต์ ์ค๊ณํ ์ ์์ต๋๋ค. CNN์ด๋ BERT์ ๊ฐ์ ๋ชจ๋ธ์ pruning์ ์ ์ฉํ๋ ์ค์ ์ฌ๋ก๋ฅผ ํตํด, ๋ชจ๋ธ์ ํฌ๊ธฐ์ ์ฐ์ฐ ๋น์ฉ์ ํจ๊ณผ์ ์ผ๋ก ์ค์ด๋ฉด์๋ ์ฑ๋ฅ์ ์ ์งํ ์ ์์์ ํ์ธํ ์ ์์ต๋๋ค.
Knowledge Distillation (์ง์ ์ฆ๋ฅ)
Knowledge Distillation(KD)์ ๊ณ ์ฑ๋ฅ์ Teacher ๋ชจ๋ธ๋ก๋ถํฐ ์ง์์ ์ ๋ฌ๋ฐ์, ์๋์ ์ผ๋ก ๊ฒฝ๋ํ๋ Student ๋ชจ๋ธ์ ํ์ต์ํค๋ ๊ธฐ๋ฒ์
๋๋ค.
Teacher ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ ์๊ฐ ๋ง์ ์ฑ๋ฅ์ ์ข์ง๋ง ์ฐ์ฐ ๋น์ฉ๊ณผ ์ถ๋ก ์๋๊ฐ ๋๋ฆฐ ๋ฐ๋ฉด, Student ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ ์๊ฐ ์ ์ด ๋น ๋ฅธ ์ถ๋ก ์๋๋ฅผ ์๋ํฉ๋๋ค. KD๋ฅผ ํตํด ์ฑ๋ฅ์ ์ต๋ํ ์ ์งํ๋ฉด์๋ ๋ชจ๋ธ ๊ฒฝ๋ํ์ ์ฐ์ฐ ํจ์จ์ฑ์ ๋์ผ ์ ์์ต๋๋ค.
1. KD ๊ธฐ๋ฒ์ ๋ถ๋ฅ
1-1. Knowledge ๊ด์
Response-based KD
- Logit-based KD:
Teacher ๋ชจ๋ธ์ logit ๊ฐ(์ฆ, ์ถ๋ ฅ ํ๋ฅ ๋ถํฌ)์ Student ๋ชจ๋ธ์ด ๋ชจ๋ฐฉํ๋๋ก ํ์ตํฉ๋๋ค.- ๋ฐฉ๋ฒ:
Teacher๊ฐ ์์ธกํ ํด๋์ค ํ๋ฅ ๋ถํฌ(์: cat=0.8, cow=0.07, dog=0.13)๋ฅผ Student๊ฐ ์์ธกํ๋๋ก KL divergence๋ฅผ loss๋ก ์ฌ์ฉํฉ๋๋ค. - Temperature T:
- T < 1: ํ๋ฅ ๋ถํฌ๊ฐ ๋ ๋ ์นด๋กญ๊ฒ(๋์ contrast)
- T > 1: ํ๋ฅ ๋ถํฌ๊ฐ ์๋งํ๊ฒ(๋ฎ์ contrast)
์ ์ ํ T ๊ฐ ์ค์ ์ ์ฆ๋ฅ ์ฑ๋ฅ์ ํฐ ์ํฅ์ ๋ฏธ์นฉ๋๋ค.
- ๋ฐฉ๋ฒ:
Feature-based KD
- Teacher ๋ชจ๋ธ์ ์ค๊ฐ ๋ ์ด์ด์์ ์ถ์ถํ feature ๋๋ ํํ(representation)์ Student ๋ชจ๋ธ์ด ๋ชจ๋ฐฉํ๋๋ก ํฉ๋๋ค.
- ๋ณดํต Teacher์ Student์ ์ค๊ฐ ๋ ์ด์ด ์ฐจ์์ด ๋ค๋ฅด๊ธฐ ๋๋ฌธ์, Student ๋ชจ๋ธ์๋ regressor layer(๋๋ projection layer)๋ฅผ ์ถ๊ฐํ์ฌ ์ฐจ์์ ๋ง์ถ๊ณ , ๋ feature map ๊ฐ์ ์ฐจ์ด๋ฅผ MSE loss ๋ฑ์ผ๋ก ์ค์ ๋๋ค.
1-2. Transparency ๊ด์
- White-box KD:
Teacher ๋ชจ๋ธ์ ๋ด๋ถ ๊ตฌ์กฐ์ ํ๋ผ๋ฏธํฐ ๋ฑ์ ์์ ํ ์ด๋ํ ์ ์๋ ๊ฒฝ์ฐ. - Gray-box KD:
Teacher ๋ชจ๋ธ์ output ๋ฐ ์ต์ข logit ๊ฐ ๋ฑ ์ ํ๋ ์ ๋ณด๋ง ์ด๋ ๊ฐ๋ฅํ ๊ฒฝ์ฐ. - Black-box KD (Imitation Learning):
Teacher ๋ชจ๋ธ์ ๋ด๋ถ ๊ตฌ์กฐ๋ ํ๋ผ๋ฏธํฐ๋ ์ ์ ์๊ณ , ์ ๋ ฅ์ ๋ฐ๋ฅธ ๊ฒฐ๊ณผ๋ง์ ๋ฐํ์ผ๋ก Student ๋ชจ๋ธ์ด ๋ชจ๋ฐฉ ํ์ตํ๋ ๋ฐฉ์.- ์ฅ์ : ๋ฐ์ดํฐ ์์ง ๋น์ฉ์ด ๋ฎ๊ณ , ์ธ๊ฐ์ด ํด์ ๊ฐ๋ฅํ ํํ์ ์ง์์ ์ ๋ฌ๋ฐ์ ์ ์์ต๋๋ค.
- ๋จ์ : Teacher๊ฐ ์ค๋ฅ๊ฐ ์๋ ์์ธก์ ํ๋ค๋ฉด ๊ทธ ์ํฅ์ ๊ทธ๋๋ก ๋ฐ์ ์ํ์ด ์์ต๋๋ค.
2. KD ์ ์ฉ ๋จ๊ณ
- ๋ชจ๋ฐฉ ๋ฐ์ดํฐ ์์ง:
Teacher ๋ชจ๋ธ์ ํน์ ์ง๋ฌธ(seed ์ง๋ฌธ)์ ์ ๋ ฅํ์ฌ ์๋ต ๋ฐ์ดํฐ๋ฅผ ์์งํฉ๋๋ค.- ์: "๋ฐํ ์ ์ฑ ์ ์ด๋ป๊ฒ ๋๋์?" โ "๋ฐํ ์ ์ฐจ๊ฐ ๋ช ๋จ๊ณ๋ก ์ด๋ฃจ์ด์ ธ ์๋์ง ๊ฐ ๋จ๊ณ๋ณ๋ก ์์ธํ ์ค๋ช ํด์ฃผ์ธ์."
- ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ:
์์ง๋ ๋ฐ์ดํฐ์์ ๋ถํ์ํ๊ฑฐ๋ ๋ ธ์ด์ฆ๊ฐ ๋ง์ ๋ถ๋ถ(์๋ฏธ ์๋ ๋ํ, ์ง๋์น๊ฒ ์งง์ ๋ต๋ณ, hallucination ๋ฑ)์ ์ ๊ฑฐํ๊ณ , ์ง๋ฌธ-๋ต๋ณ์ ๊ท ํ์ ๋ง์ถฅ๋๋ค. - Student ๋ชจ๋ธ ํ์ต:
์ ์ฒ๋ฆฌ๋ ๋ฐ์ดํฐ๋ฅผ ์ด์ฉํด Student ๋ชจ๋ธ์ ํ์ต์ํต๋๋ค.
3. Logit-based KD ์์ (PyTorch ์ฝ๋)
์๋๋ Teacher ๋ชจ๋ธ๊ณผ Student ๋ชจ๋ธ์ ์ ์ํ๊ณ , Teacher์ ์ง์์ ๊ธฐ๋ฐ์ผ๋ก Student ๋ชจ๋ธ์ ํ์ต์ํค๋ ๊ฐ๋จํ ์์ ์ฝ๋์ ๋๋ค.
import torch
import torch.nn as nn
import torch.nn.functional as F
# Teacher ๋ชจ๋ธ ์ ์ (์์)
class Teacher(nn.Module):
def __init__(self, num_classes=10):
super(Teacher, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, num_classes)
def forward(self, x):
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
return self.fc2(x) # Logits ์ถ๋ ฅ
teacher = Teacher(num_classes=10)
print("Teacher ํ๋ผ๋ฏธํฐ ์:", sum(p.numel() for p in teacher.parameters()))
# Student ๋ชจ๋ธ ์ ์ (๋ ์์ ๋คํธ์ํฌ)
class Student(nn.Module):
def __init__(self, num_classes=10):
super(Student, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, num_classes)
def forward(self, x):
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
return self.fc2(x)
student = Student(num_classes=10)
print("Student ํ๋ผ๋ฏธํฐ ์:", sum(p.numel() for p in student.parameters()))
# Teacher ๋ชจ๋ธ ํ์ต (์ผ๋ฐ์ ์ธ Cross-Entropy Loss ์ฌ์ฉ)
def train_teacher(model, train_data, optimizer, epochs=10):
model.train()
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
for x, y in train_data:
optimizer.zero_grad()
logits = model(x)
loss = criterion(logits, y)
loss.backward()
optimizer.step()
# Teacher ๋ชจ๋ธ ํ๊ฐ
def test_model(model, test_data):
model.eval()
total, correct = 0, 0
with torch.no_grad():
for x, y in test_data:
logits = model(x)
preds = logits.argmax(dim=1)
total += y.size(0)
correct += (preds == y).sum().item()
print("Accuracy: {:.2f}%".format(100 * correct / total))
# Teacher ๋ชจ๋ธ ํ์ต ๋ฐ ํ๊ฐ
optimizer_teacher = torch.optim.Adam(teacher.parameters(), lr=1e-3)
train_teacher(teacher, train_data, optimizer_teacher, epochs=10)
test_model(teacher, test_data)
# Student ๋ชจ๋ธ ํ์ต (Hard label + Soft label with KD)
def train_student(teacher, student, train_data, optimizer, epochs=10, T=2.0, alpha=0.5):
teacher.eval() # Teacher๋ ๊ณ ์
criterion = nn.CrossEntropyLoss()
kl_loss_fn = nn.KLDivLoss(reduction='batchmean')
student.train()
for epoch in range(epochs):
for x, y in train_data:
optimizer.zero_grad()
# Teacher ์์ธก (soft target)
with torch.no_grad():
teacher_logits = teacher(x)
soft_target = F.log_softmax(teacher_logits / T, dim=1)
# Student ์์ธก
student_logits = student(x)
# Hard label loss (Cross-Entropy)
loss_hard = criterion(student_logits, y)
# Soft label loss (KL divergence)
loss_soft = kl_loss_fn(F.log_softmax(student_logits / T, dim=1), soft_target)
# Total loss: alpha ์กฐํฉ
loss = alpha * loss_hard + (1 - alpha) * (T * T) * loss_soft
loss.backward()
optimizer.step()
# Student ๋ชจ๋ธ ํ๊ฐ
optimizer_student = torch.optim.Adam(student.parameters(), lr=1e-3)
train_student(teacher, student, train_data, optimizer_student, epochs=10, T=2.0, alpha=0.5)
test_model(student, test_data)
์ค๋ช :
- Teacher ๋ชจ๋ธ์ ๋ ํฐ ๋คํธ์ํฌ๋ก, Cross-Entropy Loss๋ฅผ ์ฌ์ฉํด ํ์ต๋ฉ๋๋ค.
- Student ๋ชจ๋ธ์ Teacher๋ณด๋ค ์์ ๋คํธ์ํฌ๋ก, ํ์ต ์ hard label (์๋ ์ ๋ต)๊ณผ soft label (Teacher์ logit ๋ถํฌ)์ ๋ชจ๋ ์ฌ์ฉํฉ๋๋ค.
- Temperature TT์ ํผํฉ ๊ณ์ alphaalpha๋ฅผ ํตํด ๋ loss ๊ฐ์ ๊ฐ์ค์น๋ฅผ ์กฐ์ ํฉ๋๋ค.
๊ฒฐ๋ก
Knowledge Distillation์ ๊ณ ์ฑ๋ฅ Teacher ๋ชจ๋ธ์ ์ ๋ณด๋ฅผ ํจ๊ณผ์ ์ผ๋ก ์์ถํ์ฌ, ๊ฒฝ๋ํ๋ Student ๋ชจ๋ธ์ ํ์ต์ํค๋ ๊ธฐ๋ฒ์ ๋๋ค.
- Response-based KD (Logit-based KD): Teacher์ ์ถ๋ ฅ ํ๋ฅ ๋ถํฌ๋ฅผ Student๊ฐ ๋ชจ๋ฐฉํ๋๋ก ํ์ฌ, ํด๋์ค ๊ฐ์ ์ ์ฌ๋ ์ ๋ณด๋ฅผ ํจ๊ป ํ์ตํฉ๋๋ค.
- Feature-based KD: ์ค๊ฐ ๋ ์ด์ด์ feature๋ฅผ ๋ชจ๋ฐฉํ๋ ๋ฐฉ๋ฒ์ผ๋ก, ๋ณดํต regressor๋ฅผ ํตํด ์ฐจ์ ์กฐ์ ํ MSE loss๋ฅผ ์ฌ์ฉํฉ๋๋ค.
- Transparency ๊ด์ : Teacher ๋ชจ๋ธ์ ๋ด๋ถ ๊ตฌ์กฐ๋ฅผ ์ผ๋ง๋ ์ด๋ํ ์ ์๋์ง์ ๋ฐ๋ผ white-box, gray-box, black-box KD๋ก ๊ตฌ๋ถํ ์ ์์ต๋๋ค.
๋ชจ๋ธ ๊ฒฝ๋ํ๋ฅผ ์ํ Quantization(์์ํ) ๊ธฐ๋ฒ ์ดํดํ๊ธฐ
๊ฑฐ๋ AI ๋ชจ๋ธ์ ๊ณ ์ ๋ฐ FP32 ํ์์ ๊ฐ์ค์น์ ํ์ฑํ๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋๊ณผ ์ฐ์ฐ ๋น์ฉ์ด ๋งค์ฐ ํฝ๋๋ค. Quantization(์์ํ)์ ์ด๋ฌํ ๋ชจ๋ธ์ ๊ฐ์ค์น์ ํ์ฑํ๋ฅผ ๋ฎ์ ๋นํธ ์ ๋ฐ๋๋ก ๋ณํํ์ฌ ์ ์ฅ ๋ฐ ๊ณ์ฐ ํจ์จ์ฑ์ ๋์ด๋ ๊ธฐ๋ฒ์ ๋๋ค. ์๋ฅผ ๋ค์ด, ฯ๋ฅผ 3.141592์ฒ๋ผ ๊ณ ์ ๋ฐํ๊ฒ ํํํ๋ฉด ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ํฌ์ง๋ง, 3๊ณผ ๊ฐ์ด ๋ฎ์ ์ ๋ฐ๋๋ก ํํํ๋ฉด ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ์ ์ค์ด๋ค์ง๋ง ์ค์ฐจ๊ฐ ๋ฐ์ํ ์ ์์ต๋๋ค. Quantization์ ํต์ฌ์ ์ค์ฐจ๋ฅผ ์ต์ํํ๋ฉด์๋ ํจ์จ์ ์ธ ๋ฎ์ ์ ๋ฐ๋๋ฅผ ์ฐพ๋ ๊ฒ์ ๋๋ค.
Quantization Mapping
Quantization์ ๋์ ์ ๋ฐ๋์ ๊ฐ์ ๋ฎ์ ๋นํธ ์ ๋ฐ๋์ ๊ฐ์ ๋งคํํ๋ ๊ณผ์ ์ ํฌํจํฉ๋๋ค. ์ผ๋ฐ์ ์ผ๋ก ๋ชจ๋ธ์ ๋ฐ์ดํฐ๋ ํ๋ผ๋ฏธํฐ๋ FP32 ํ์์ผ๋ก ํํ๋๋๋ฐ, ์ด๋ฅผ FP16, INT8 ๋ฑ์ผ๋ก ๋ณํํฉ๋๋ค. ๋ค๋ง, ๊ฐ ์๋ฃํ๋ง๋ค ํํ ๊ฐ๋ฅํ ๊ฐ์ ๋ฒ์๊ฐ ๋ค๋ฆ
๋๋ค.
ํ์ | ํํ ๊ฐ๋ฅ ์ต์ | ํํ ๊ฐ๋ฅ ์ต๋ |
---|---|---|
INT8 | -128 | 127 |
INT16 | -32768 | 32767 |
INT32 | -2147483648 | 2147483647 |
FP16 | -65504 | 65504 |
FP32 | -3.4028235 ร 10^38 | 3.4028235 ร 10^38 |
์๋ฅผ ๋ค์ด, FP32๋ก ํํ๋ 350.5๋ INT8์ ์ต๋ ํํ ๋ฒ์(127)๋ฅผ ์ด๊ณผํ๊ธฐ ๋๋ฌธ์ ์ง์ ์ ์ธ ๋ณํ์ด ๋ถ๊ฐ๋ฅํฉ๋๋ค.
Quantization Mapping์ ๋ณดํต ์๋์ ๊ฐ์ ์์์ ์ฌ์ฉํ์ฌ ์งํ๋ฉ๋๋ค.
: ์๋ณธ ๊ฐ : ์์ํ๋ ๊ฐ : scale factor (๊ธฐ์ธ๊ธฐ) : zero-point (์์ํ ํ 0์ด ๋งคํ๋๋ ์์น)
์์ํ ์, ์ด
Quantization ๊ธฐ๋ฒ์ ์ข ๋ฅ
1. Absmax Quantization
-
๊ฐ๋ :
๋ฐ์ดํฐ ๋ถํฌ์ ์ ๋๊ฐ ์ต๋์น์ ๊ธฐ๋ฐํด scale factor๋ฅผ ๊ฒฐ์ ํฉ๋๋ค.์ด ๋ฐฉ์์์๋
๋ก ๊ณ ์ ๋ฉ๋๋ค. -
์ ์ฉ:
๋ฐ์ดํฐ ๋ถํฌ๊ฐ ๋์นญ์ ์ด๊ฑฐ๋ ํ๊ท ์ด 0์ธ ๊ฒฝ์ฐ(์: tanh ํจ์์ ์ถ๋ ฅ)์ ํจ๊ณผ์ ์ ๋๋ค. -
๋จ์ :
๊ทน๋จ์ ์ธ ๊ฐ(Outlier)์ ๋ฏผ๊ฐํ์ฌ, ์ด๋ฌํ ๊ฐ๋ค์ด scale factor์ ํฐ ์ํฅ์ ์ค ์ ์์ต๋๋ค.
2. Zero-point Quantization
-
๊ฐ๋ :
๋ฐ์ดํฐ ๋ถํฌ๊ฐ ๋น๋์นญ์ ์ด๊ฑฐ๋ ํ๊ท ์ด 0์ด ์๋ ๊ฒฝ์ฐ์ ์ฌ์ฉํฉ๋๋ค.
์ ์ฒด ๋ฒ์๊ฐ ์ผ์ ํ๊ฒ ๋งคํ๋๋๋ก scale factor์ zero-point๋ฅผ ๊ณ์ฐํฉ๋๋ค.
-
์ ์ฉ:
์ฃผ๋ก ReLU์ ๊ฐ์ด ์ถ๋ ฅ์ด 0 ์ด์์ธ ๊ฒฝ์ฐ์ ์ ๋ฆฌํฉ๋๋ค. -
๋จ์ :
๊ธฐ์ค์ (z)์ด ๋น์ ์์ ์ผ๋ก ์ค์ ๋๋ฉด ์ฑ๋ฅ ์ ํ๋ก ์ด์ด์ง ์ ์์ต๋๋ค.
Clipping๊ณผ Calibration
๋ฐ์ดํฐ์ outlier๊ฐ ์กด์ฌํ๋ฉด, ์์ํ mapping์ด ํจ๊ณผ์ ์ด์ง ์์ ์ ์์ต๋๋ค. ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด clipping ๊ธฐ๋ฒ์ ์ฌ์ฉํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ๋ฐ์ดํฐ ๊ฐ์ ๋ฒ์ฃผ๋ฅผ [โ5,5][-5, 5]๋ก ์ ํํ๊ณ , ์ด ๋ฒ์๋ฅผ ๋์ด์๋ ๊ฐ์ ๋ชจ๋ ๊ฐ์ ๊ฐ์ผ๋ก ์ทจ๊ธํฉ๋๋ค. ์ด ๊ณผ์ ์์ ์ ์ ํ ๋ฒ์ฃผ๋ฅผ ์ ํํ๋ ๊ฒ์ calibration์ด๋ผ๊ณ ํฉ๋๋ค.
๊ฒฐ๋ก
Quantization์ ๋ชจ๋ธ์ ๊ฐ์ค์น์ ํ์ฑํ๋ฅผ ๋ฎ์ ๋นํธ ์ ๋ฐ๋๋ก ๋ณํํ์ฌ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ค์ด๊ณ ์ฐ์ฐ ํจ์จ์ฑ์ ๋์ด๋ ์ค์ํ ๋ชจ๋ธ ๊ฒฝ๋ํ ๊ธฐ๋ฒ์ ๋๋ค.
- Absmax Quantization์ ๋ฐ์ดํฐ ๋ถํฌ๊ฐ ๋์นญ์ ์ผ ๋ ์ ๋ฆฌํ๋ฉฐ,
- Zero-point Quantization์ ๋น๋์นญ์ ์ธ ๋ถํฌ์ ํจ๊ณผ์ ์
๋๋ค.
๋ํ, clipping๊ณผ calibration์ ํตํด outlier์ ์ํฅ์ ์ต์ํํ๋ ๊ฒ์ด ์ค์ํฉ๋๋ค.
Parameter-Efficient Fine-Tuning (PEFT): ๊ฑฐ๋ ๋ชจ๋ธ์ ํจ์จ์ ๋ฏธ์ธ์กฐ์
๊ฑฐ๋ AI ๋ชจ๋ธ์ ํ์ต์ ์๋ง์ ํ๋ผ๋ฏธํฐ๋ก ์ธํด ๋ง๋ํ ์์๊ณผ ์๊ฐ์ด ์์๋ฉ๋๋ค. ์ด๋ฌํ ์ํฉ์์ ์ ์ฒด ๋ชจ๋ธ์ ์ฌํ์ตํ๋ ๋์ , ๋ชจ๋ธ์ ์ผ๋ถ๋ถ๋ง ๋ฏธ์ธ์กฐ์ ํ์ฌ ํจ์จ์ฑ์ ๊ทน๋ํํ๋ Parameter-Efficient Fine-Tuning (PEFT) ๊ธฐ๋ฒ์ด ์ฃผ๋ชฉ๋ฐ๊ณ ์์ต๋๋ค. PEFT๋ ์ ์ด ํ์ต(Transfer Learning)์ Fine-tuning ๋จ๊ณ์์ ์ ์ฒด ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธํ์ง ์๊ณ , ์ผ๋ถ ํ๋ผ๋ฏธํฐ๋ง ํ์ต์์ผ ๋น ๋ฅด๊ณ ๋น์ฉ ํจ์จ์ ์ธ ๋ชจ๋ธ ๊ฐ์ ์ ๋ชฉํ๋ก ํฉ๋๋ค.
1. ์ ์ด ํ์ต๊ณผ PEFT
- ์ ์ด ํ์ต(Transfer Learning):
์ด๋ฏธ ํ์ต๋ ๋ชจ๋ธ(Pre-trained Model)์ ์๋ก์ด ์์ ์ ์์์ ์ผ๋ก ํ์ฉํ๋ ๋ฐฉ๋ฒ์ผ๋ก,- Pre-training: ๋ฐฉ๋ํ ์์ ๋ฐ์ดํฐ๋ก ๋ชจ๋ธ์ ์ฌ์ ํ์ต
- Fine-tuning: ํน์ ์์ ์ ๋ง์ถฐ ๋ชจ๋ธ์ ์ฌํ์ต
- Fine-tuning์ ํ์์ฑ:
์ฌ์ ํ์ต ๋ชจ๋ธ์ ์ผ๋ฐ์ ์ธ ๋ฌธ์ ํด๊ฒฐ ๋ฅ๋ ฅ์ด ๋ถ์กฑํ๊ธฐ ๋๋ฌธ์, ์๋ก์ด ํ์คํฌ์ ๋ง์ถฐ ์ฌํ์ตํ๋ ๊ณผ์ ์ด ํ์ํฉ๋๋ค.
PEFT๋ Fine-tuning ์ ์ ์ฒด ๋ชจ๋ธ์ ์ ๋ฐ์ดํธํ์ง ์๊ณ ์ผ๋ถ ํ๋ผ๋ฏธํฐ๋ง ํ์ตํ์ฌ ํจ์จ์ฑ์ ๋์ด๋ ์ ๊ทผ๋ฒ์ ๋๋ค.
2. PEFT์ ์ ๊ทผ ๋ฐฉ์
PEFT๋ ํฌ๊ฒ ๋ ๊ฐ์ง ๋ฐฉ๋ฒ๋ก ์ผ๋ก ๋๋ ์ ์์ต๋๋ค.
2-1. Prompt Tuning ๋ฐฉ์
- ๊ฐ๋
:
๋ชจ๋ธ์ ๊ธฐ์กด ํ๋ผ๋ฏธํฐ๋ฅผ ๋ณ๊ฒฝํ์ง ์๊ณ , ์ ๋ ฅ prompt๋ context๋ฅผ ์กฐ์ ํ์ฌ ์ํ๋ ์ถ๋ ฅ์ผ๋ก ์ ๋ํฉ๋๋ค. - ์์:
Prompt Tuning, Prefix Tuning, P-Tuning ๋ฑ
2-2. ํ๋ผ๋ฏธํฐ ์ฝ์ ๋ฐฉ์
- ๊ฐ๋
:
๋ชจ๋ธ์ ํน์ ์์น์ ์ถ๊ฐ ํ์ต ๊ฐ๋ฅํ ํ๋ผ๋ฏธํฐ(๋ชจ๋)๋ฅผ ์ฝ์ ํ์ฌ ๋ฏธ์ธ์กฐ์ ํฉ๋๋ค. - ์์:
Adapter, LoRA, Compacter ๋ฑ
3. ํ๋ผ๋ฏธํฐ ์ฝ์ ๋ฐฉ์ ์์ธ
3-1. Adapter
- ๊ตฌ์กฐ:
- ๊ธฐ์กด ๋ชจ๋ธ์ ๊ฐ ๋ ์ด์ด ์ฌ์ด์ ์์ adapter ๋ชจ๋์ ์ฝ์
- Bottleneck ๊ตฌ์กฐ: ๋จผ์ ๋ค์ด-ํ๋ก์ ์ (์ฐจ์ ์ถ์) โ ๋น์ ํ ํ์ฑํ โ ์ -ํ๋ก์ ์ (์๋ ์ฐจ์ ๋ณต์)
- Skip-connection์ ํตํด ์๋ ์ ๋ ฅ์ ๋ณต์ ํ ํฉ์ฐ
- ํน์ง:
๊ธฐ์กด ํ๋ผ๋ฏธํฐ๋ ๊ทธ๋๋ก ๋๊ณ adapter๋ง ํ์ตํ์ฌ ํ์คํฌ์ ๋ง๊ฒ ๋ชจ๋ธ์ ๋น ๋ฅด๊ฒ ์ ํํ ์ ์์ต๋๋ค.
3-2. Low-Rank Decomposition
- ๊ฐ๋
:
๊ณ ์ฐจ์ weight matrix๋ฅผ ๋ ๊ฐ์ ์ ์ฐจ์ ํ๋ ฌ(์: mรrm \times r์ rรmr \times m, r<mr < m)์ ๊ณฑ์ผ๋ก ๊ทผ์ฌํ์ฌ ํ๋ผ๋ฏธํฐ ์๋ฅผ ์ค์ด๋ ๋ฐฉ๋ฒ - ์ ์ฉ:
์๋ฅผ ๋ค์ด, 300ร300300 \times 300 ํ๋ ฌ์ 300ร10300 \times 10์ 10ร30010 \times 300 ํ๋ ฌ๋ก ๋ถํดํ์ฌ ๊ฒฝ๋ํ
3-3. LoRA (Low-Rank Adaptation)
- ๊ฐ๋
:
์ฌ์ ํ์ต๋ ๋ชจ๋ธ ๊ฐ์ค์น๋ ๊ณ ์ ํ ์ฑ, ๊ฐ ์ธต์ low-rank ๋ถํด๋ ์ถ๊ฐ ํ๋ผ๋ฏธํฐ๋ฅผ ๋ณ๋ ฌ์ ์ผ๋ก ์ฝ์ ํ์ฌ ํ์ตํฉ๋๋ค. - ํน์ง ๋น๊ต (Adapter vs. LoRA):
ํญ๋ชฉ | Adapter | LoRA |
---|---|---|
์ฐ์ฐ ๋ฐฉ์ | Sequential (์์ฐจ์ ) | Parallel (๋ณ๋ ฌ์ ) |
๋น์ ํ ํจ์ ์ฌ์ฉ ์ฌ๋ถ | ์ฌ์ฉํจ | ์ฌ์ฉํ์ง ์์ |
ํ์ต ํ๋ผ๋ฏธํฐ | Weight ๋ฐ bias | Weight๋ง |
์ฐ์ฐ ์ง์ฐ | ๋ชจ๋ ์ฐ์ฐ์ ๋น๋ก | ๊ฑฐ์ ๋ฐ์ํ์ง ์์ |
LoRA๋ Adapter์ ์ ์ฌํ low-rank ๊ธฐ๋ฒ์ ์ฌ์ฉํ์ง๋ง, ๋ณ๋ ฌ์ ์ผ๋ก ๊ณ์ฐํ์ฌ ์๋ ์ธก๋ฉด์์ ๋ ํจ์จ์ ์ ๋๋ค.
3-4. AdapterFusion
- ๊ฐ๋
:
์ฌ๋ฌ ํ์คํฌ์ ๋ํด ๊ฐ๊ฐ ํ์ตํ adapter ๋ชจ๋๋ค์ ๊ฒฐํฉํ์ฌ ํ๋์ ๋ชจ๋ธ๋ก ๊ตฌ์ฑํ๋ ๋ฐฉ๋ฒ์ ๋๋ค. - ๋์ ๋ฐฉ์:
- Knowledge Extraction: ๊ฐ ํ์คํฌ๋ณ๋ก ๊ฐ๋ณ adapter๋ฅผ ํ์ต
- Knowledge Composition: ์ ๋ ฅ์ ๋ฐ๋ผ ์ฌ๋ฌ adapter์ ์ถ๋ ฅ์ attention ๊ธฐ๋ฐ์ผ๋ก ์ทจํฉํ์ฌ ์ต์ ์ ๊ฒฐ๊ณผ๋ฅผ ์์ฑ
- ์ฅ์ :
๋จ์ผ ๋ชจ๋ธ๋ก ๋ค์ํ ํ์คํฌ๋ฅผ ๋น ๋ฅด๊ฒ ์ ํํ ์ ์์ผ๋ฉฐ, adapter ๋ชจ๋๋ง ๋ณ๊ฒฝํ๋ฉด ๋๋ฏ๋ก ํจ์จ์ ์ ๋๋ค.
3-5. QLoRA
- ๊ฐ๋
:
LoRA์ Quantization ๊ธฐ๋ฒ์ ์ถ๊ฐ ์ ์ฉํ์ฌ, ๋ฉ๋ชจ๋ฆฌ์ ์ฐ์ฐ ํจ์จ์ฑ์ ๊ทน๋ํํ ๋ฐฉ๋ฒ์ ๋๋ค. - ๋ฐฉ๋ฒ:
- ์ฌ์ ํ์ต ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ 16-bit์์ 4-bit Normal-Float Quantization์ผ๋ก ๋ณํ
- ์์ํ ์์(Scale๊ณผ Zero-point)๋ ๋์ ์ ๋ฐ๋๋ก ์ ์ฅํ์ฌ, double quantization์ ์ํ
4. In Practice: LoRA ์ ์ฉ ์ฌ๋ก
์๋ฅผ ๋ค์ด, BERT_base ๋ชจ๋ธ์ LoRA๋ฅผ ์ ์ฉํ์ฌ SQuAD ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ ๊ฒฝ์ฐ๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
# ๊ธฐ์กด ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ ๊ณ ์ (Fine-tuning ์ ๊ณ ์ )
for param in model.parameters():
param.requires_grad = False
# LoRA ์ค์ (์์)
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=2, # Low-rank factor (์: 1-4 ์ ๋ ์ฌ์ฉ)
lora_alpha=16, # Low-rank matrix์ ๊ณฑํด์ง๋ ๊ฐ์ผ๋ก, ์๋ ด์ ๋์ต๋๋ค.
target_modules=["query", "value"], # ์ ์ฉํ ๋ชจ๋ ์ง์
lora_dropout=0.1 # Dropout์ผ๋ก ๊ณผ์ ํฉ ๋ฐฉ์ง
)
# PEFT ๋ชจ๋ธ ์์ฑ
model_with_lora = get_peft_model(model, lora_config)
- ์ค๋ช
:
- ๊ธฐ์กด BERT_base ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ๋ ๊ณ ์ ํ๊ณ , ํน์ ๋ชจ๋(query, value)์ ๋ํด์๋ง low-rank ํ์ต ๊ฐ๋ฅํ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฝ์ ํฉ๋๋ค.
- LoRA ์ค์ ์ ๋ฐ๋ผ, ์ถ๊ฐ๋๋ ํ๋ผ๋ฏธํฐ๋ ํจ์ฌ ์์ผ๋ฉด์๋ ์ถฉ๋ถํ ์ฑ๋ฅ์ ์ ์งํ ์ ์์ต๋๋ค.
๊ฒฐ๋ก
PEFT ๊ธฐ๋ฒ์ ๊ฑฐ๋ ๋ชจ๋ธ์ ์ ์ฒด ์ฌํ์ต ์์ด๋, ๋ชจ๋ธ์ ์ผ๋ถ๋ง ๋ฏธ์ธ์กฐ์ ํ์ฌ ํ์ต ๋น์ฉ๊ณผ ์๊ฐ์ ํฌ๊ฒ ์ค์ผ ์ ์๋ ํจ์จ์ ์ธ ๋ฐฉ๋ฒ์ ๋๋ค.
- Prompt Tuning์ ์ ๋ ฅ๋ง ์กฐ์ ํ๋ ๋ฐฉ๋ฒ์ด๋ฉฐ,
- ํ๋ผ๋ฏธํฐ ์ฝ์ ๋ฐฉ์์ Adapter, LoRA, AdapterFusion, QLoRA ๋ฑ ๋ค์ํ ๊ธฐ๋ฒ์ผ๋ก ๊ตฌ๋ถ๋ฉ๋๋ค.