Week 21 Study Notes

๋ชจ๋ธ ๊ฒฝ๋Ÿ‰ํ™”: ์ž์› ํšจ์œจ์„ฑ์„ ๋†’์ด๋Š” AI ์ตœ์ ํ™” ๊ธฐ์ˆ 

๊ฑฐ๋Œ€ AI ๋ชจ๋ธ์€ ์ˆ˜๋งŽ์€ ํŒŒ๋ผ๋ฏธํ„ฐ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ์–ด, ํ•™์Šต ์‹œ ๋งŽ์€ GPU, ์ „๋ ฅ, ์‹œ๊ฐ„์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ๋Œ€๋ถ€๋ถ„์˜ ํ™˜๊ฒฝ์—์„œ๋Š” ์ด๋Ÿฌํ•œ ์ž์›๊ณผ ์‹œ๊ฐ„์„ ์ถฉ๋ถ„ํžˆ ํ™•๋ณดํ•˜๊ธฐ ์–ด๋ ต๊ธฐ ๋•Œ๋ฌธ์—, ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ์œ ์ง€ํ•˜๋ฉด์„œ๋„ ํฌ๊ธฐ์™€ ๊ณ„์‚ฐ ๋น„์šฉ์„ ์ค„์ด๋Š” ๋ชจ๋ธ ๊ฒฝ๋Ÿ‰ํ™” ๊ธฐ์ˆ ์ด ์ค‘์š”ํ•ด์กŒ์Šต๋‹ˆ๋‹ค. ๊ฒฝ๋Ÿ‰ํ™”๋œ ๋ชจ๋ธ์€ ์ถ”๋ก  ์‹œ๊ฐ„๋„ ๋‹จ์ถ•๋˜์–ด ์ž์œจ์ฃผํ–‰๊ณผ ๊ฐ™์ด ์‹ค์‹œ๊ฐ„ ์ฒ˜๋ฆฌ๊ฐ€ ์š”๊ตฌ๋˜๋Š” ํƒœ์Šคํฌ์— ์ ํ•ฉํ•ฉ๋‹ˆ๋‹ค.


๋ชจ๋ธ ๊ฒฝ๋Ÿ‰ํ™” ์ฃผ์š” ๊ธฐ๋ฒ•

Pasted image 20250311150951.png

1. Pruning

2. Knowledge Distillation

3. Quantization


Conclusion

๋ชจ๋ธ ๊ฒฝ๋Ÿ‰ํ™” ๊ธฐ์ˆ ์€ ์ž์›๊ณผ ์‹œ๊ฐ„์ด ์ œํ•œ๋œ ํ™˜๊ฒฝ์—์„œ ๊ณ ์„ฑ๋Šฅ AI ๋ชจ๋ธ์„ ์‹ค์ œ ์„œ๋น„์Šค์— ์ ์šฉํ•˜๊ธฐ ์œ„ํ•œ ํ•„์ˆ˜ ์ „๋žต์ž…๋‹ˆ๋‹ค. Pruning, Knowledge Distillation, Quantization๊ณผ ๊ฐ™์€ ๊ธฐ๋ฒ•์„ ํ†ตํ•ด ๋ชจ๋ธ์˜ ํฌ๊ธฐ๋ฅผ ์ค„์ด๋ฉด์„œ๋„, ์ตœ์ข… ์„ฑ๋Šฅ์€ ์ตœ๋Œ€ํ•œ ์œ ์ง€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด ์‹ค์‹œ๊ฐ„ ์ฒ˜๋ฆฌ๊ฐ€ ์š”๊ตฌ๋˜๋Š” ๋‹ค์–‘ํ•œ ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜, ์˜ˆ๋ฅผ ๋“ค์–ด ์ž์œจ์ฃผํ–‰, ๋ชจ๋ฐ”์ผ ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜ ๋“ฑ์— ํšจ๊ณผ์ ์œผ๋กœ ํ™œ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.


๋ชจ๋ธ ๊ฒฝ๋Ÿ‰ํ™”๋ฅผ ์œ„ํ•œ Pruning ๊ธฐ์ˆ 

๊ฑฐ๋Œ€ AI ๋ชจ๋ธ์€ ์ˆ˜๋งŽ์€ ํŒŒ๋ผ๋ฏธํ„ฐ๋กœ ์ด๋ฃจ์–ด์ ธ ์žˆ์–ด ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰๊ณผ ์—ฐ์‚ฐ ๋น„์šฉ์ด ํฝ๋‹ˆ๋‹ค. ์ด๋ฅผ ์ค„์ด๊ธฐ ์œ„ํ•ด ์ค‘์š”๋„๊ฐ€ ๋‚ฎ์€ ๋‰ด๋Ÿฐ์ด๋‚˜ ์—ฐ๊ฒฐ(์‹œ๋ƒ…์Šค)์„ ์ œ๊ฑฐํ•˜๋Š” pruning ๊ธฐ๋ฒ•์ด ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค. Pruning์„ ํ†ตํ•ด ๋ชจ๋ธ์˜ ํฌ๊ธฐ๋ฅผ ์ค„์ด๊ณ , ๊ณ„์‚ฐ ์†๋„๋ฅผ ๋†’์ด๋ฉฐ, ์ถ”๋ก  ์‹œ๊ฐ„๋„ ๋‹จ์ถ•ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Pruning ๊ธฐ๋ฒ•์€ ํฌ๊ฒŒ ๊ตฌ์กฐ(structure), ์Šค์ฝ”์–ด๋ง(scoring), ์Šค์ผ€์ค„๋ง(scheduling), ์ดˆ๊ธฐํ™”(initialization) ๋„ค ๊ฐ€์ง€ ๊ด€์ ์—์„œ ์ ‘๊ทผํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.


1. Pruning ๊ธฐ๋ฒ•์˜ ๋ถ„๋ฅ˜

1-1. Structure: Does the Model's Shape Change?

Pasted image 20250311151120.png

Method | Unit | Changes structure? | Advantage | Disadvantage
Unstructured | individual parameters | no | easy to implement | little actual speedup
Structured | neurons/channels/layers | yes | real speedups possible | harder or more restrictive to implement

1-2. Scoring: Which Parameters to Prune

์ค‘์š”๋„ ๊ณ„์‚ฐ ๋ฐฉ๋ฒ•

์ค‘์š”๋„๋ฅผ ๋ฐ˜์˜ํ•˜๋Š” ๋‹จ์œ„

1-3. Scheduling: How Pruning Proceeds

1-4. Initialization: Where Fine-tuning Starts

๊ฐ€์ง€์น˜๊ธฐ ํ›„ ๋ชจ๋ธ์„ ์žฌํ•™์Šตํ•  ๋•Œ, ์ดˆ๊ธฐ ์ƒํƒœ์— ๋”ฐ๋ผ ๋‘ ๊ฐ€์ง€ ๋ฐฉ์‹์ด ์žˆ์Šต๋‹ˆ๋‹ค.

Iterative Magnitude Pruning (IMP):
The most basic pruning method, which combines

  1. an unstructured approach,
  2. global scoring (based on each parameter's absolute value),
  3. a recursive (iterative) schedule, and
  4. rewinding.

2. Additional Considerations

2-1. The Matrix Sparsity Problem

Pasted image 20250311151234.png
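As a rough illustration of the sparsity problem (a hypothetical pruned matrix; a simple COO-style list stands in for real sparse formats):

```python
# After unstructured pruning, most entries are zero. Storing only the
# nonzero entries (here, a COO-style list of (row, col, value) triples)
# saves memory, but dense matrix kernels cannot exploit the zeros;
# sparse kernels or dedicated hardware are needed for real speedups.
W = [[0.9, 0.0, 0.0, 0.0],
     [0.0, 0.0, 0.0, 0.0],
     [0.0, 0.6, 0.0, -0.5]]

coo = [(i, j, v) for i, row in enumerate(W)
       for j, v in enumerate(row) if v != 0.0]
print(len(coo), "nonzeros out of", len(W) * len(W[0]))  # 3 nonzeros out of 12
```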

2-2. Sensitivity (๋ฏผ๊ฐ๋„)


3. In Practice: Pruning in CNNs and BERT

3-1. CNN์—์„œ์˜ Pruning

3-2. BERT์—์„œ์˜ Pruning

Pasted image 20250311151254.png


4. Example Code (PyTorch)

๋‹ค์Œ์€ ๊ฐ„๋‹จํ•œ 2-layer MLP ๋ชจ๋ธ์— ๋Œ€ํ•ด unstructured pruning์„ ์ ์šฉํ•˜๋Š” ์˜ˆ์ œ์ž…๋‹ˆ๋‹ค.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Define a 2-layer MLP model
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(2, 10)
        self.fc2 = nn.Linear(10, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

# ๋ชจ๋ธ ์„ ์–ธ ๋ฐ ํ•™์Šต/ํ…Œ์ŠคํŠธ ํ•จ์ˆ˜ ์ •์˜ (train, test๋Š” ๋ณ„๋„ ๊ตฌํ˜„)
model = Model()

# ํ•™์Šต ๋ฐ ํ…Œ์ŠคํŠธ (pruning ์ „)
train(model, train_data)
test(model, test_data)

# ์ „์ฒด ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜ ํ™•์ธ
total_params = sum(p.numel() for p in model.parameters())
total_params_nz = sum((p != 0.0).sum().item() for p in model.parameters())
print("Total parameters:", total_params)
print("Nonzero parameters:", total_params_nz)

# Select the layer to prune (e.g., the first fully connected layer)
layer = model.fc1

# 1) ๋žœ๋ค์œผ๋กœ ์ฒซ ๋ฒˆ์งธ ๋ ˆ์ด์–ด์—์„œ 50%๋ฅผ ์ œ๊ฑฐ (unstructured)
prune.random_unstructured(layer, name='weight', amount=0.5)

# 2) ํŒŒ๋ผ๋ฏธํ„ฐ ์ ˆ๋Œ“๊ฐ’ ๊ธฐ์ค€ ํ•˜์œ„ 50%๋ฅผ ์ œ๊ฑฐ (L1-based pruning)
prune.l1_unstructured(layer, name='weight', amount=0.5)

# Retrain (fine-tune) after pruning
train(model, train_data)

# Test after pruning and fine-tuning
test(model, test_data)

Explanation:

prune.random_unstructured and prune.l1_unstructured do not physically delete weights. They reparameterize the layer: the original values are kept in weight_orig, a weight_mask buffer is registered, and the effective weight is their product, so pruned entries appear as zeros in the nonzero-parameter count above. Fine-tuning after pruning then recovers most of the lost accuracy.


Conclusion

Pruning์€ ๋ชจ๋ธ ๊ฒฝ๋Ÿ‰ํ™”์˜ ํ•ต์‹ฌ ๊ธฐ๋ฒ• ์ค‘ ํ•˜๋‚˜๋กœ,

์ถ”๊ฐ€๋กœ, matrix sparsity, ์ „์šฉ ํ•˜๋“œ์›จ์–ด ํ™œ์šฉ, ๊ทธ๋ฆฌ๊ณ  ๊ฐ ๋ ˆ์ด์–ด์˜ sensitivity ๋“ฑ์„ ๊ณ ๋ คํ•˜์—ฌ ์ตœ์ ์˜ pruning ์ „๋žต์„ ์„ค๊ณ„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. CNN์ด๋‚˜ BERT์™€ ๊ฐ™์€ ๋ชจ๋ธ์— pruning์„ ์ ์šฉํ•˜๋Š” ์‹ค์ œ ์‚ฌ๋ก€๋ฅผ ํ†ตํ•ด, ๋ชจ๋ธ์˜ ํฌ๊ธฐ์™€ ์—ฐ์‚ฐ ๋น„์šฉ์„ ํšจ๊ณผ์ ์œผ๋กœ ์ค„์ด๋ฉด์„œ๋„ ์„ฑ๋Šฅ์„ ์œ ์ง€ํ•  ์ˆ˜ ์žˆ์Œ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.


Knowledge Distillation

Knowledge Distillation(KD)์€ ๊ณ ์„ฑ๋Šฅ์˜ Teacher ๋ชจ๋ธ๋กœ๋ถ€ํ„ฐ ์ง€์‹์„ ์ „๋‹ฌ๋ฐ›์•„, ์ƒ๋Œ€์ ์œผ๋กœ ๊ฒฝ๋Ÿ‰ํ™”๋œ Student ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ค๋Š” ๊ธฐ๋ฒ•์ž…๋‹ˆ๋‹ค.
Teacher ๋ชจ๋ธ์€ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๊ฐ€ ๋งŽ์•„ ์„ฑ๋Šฅ์€ ์ข‹์ง€๋งŒ ์—ฐ์‚ฐ ๋น„์šฉ๊ณผ ์ถ”๋ก  ์†๋„๊ฐ€ ๋Š๋ฆฐ ๋ฐ˜๋ฉด, Student ๋ชจ๋ธ์€ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๊ฐ€ ์ ์–ด ๋น ๋ฅธ ์ถ”๋ก  ์†๋„๋ฅผ ์ž๋ž‘ํ•ฉ๋‹ˆ๋‹ค. KD๋ฅผ ํ†ตํ•ด ์„ฑ๋Šฅ์€ ์ตœ๋Œ€ํ•œ ์œ ์ง€ํ•˜๋ฉด์„œ๋„ ๋ชจ๋ธ ๊ฒฝ๋Ÿ‰ํ™”์™€ ์—ฐ์‚ฐ ํšจ์œจ์„ฑ์„ ๋†’์ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.


1. KD ๊ธฐ๋ฒ•์˜ ๋ถ„๋ฅ˜

Pasted image 20250311151531.png

1-1. The Knowledge Perspective

Response-based KD

Feature-based KD

1-2. The Transparency Perspective


2. KD ์ ์šฉ ๋‹จ๊ณ„

  1. ๋ชจ๋ฐฉ ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘:
    Teacher ๋ชจ๋ธ์— ํŠน์ • ์งˆ๋ฌธ(seed ์งˆ๋ฌธ)์„ ์ž…๋ ฅํ•˜์—ฌ ์‘๋‹ต ๋ฐ์ดํ„ฐ๋ฅผ ์ˆ˜์ง‘ํ•ฉ๋‹ˆ๋‹ค.
    • ์˜ˆ: "๋ฐ˜ํ’ˆ ์ •์ฑ…์€ ์–ด๋–ป๊ฒŒ ๋˜๋‚˜์š”?" โ†’ "๋ฐ˜ํ’ˆ ์ ˆ์ฐจ๊ฐ€ ๋ช‡ ๋‹จ๊ณ„๋กœ ์ด๋ฃจ์–ด์ ธ ์žˆ๋Š”์ง€ ๊ฐ ๋‹จ๊ณ„๋ณ„๋กœ ์ž์„ธํžˆ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”."
  2. ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ:
    ์ˆ˜์ง‘๋œ ๋ฐ์ดํ„ฐ์—์„œ ๋ถˆํ•„์š”ํ•˜๊ฑฐ๋‚˜ ๋…ธ์ด์ฆˆ๊ฐ€ ๋งŽ์€ ๋ถ€๋ถ„(์˜๋ฏธ ์—†๋Š” ๋Œ€ํ™”, ์ง€๋‚˜์น˜๊ฒŒ ์งง์€ ๋‹ต๋ณ€, hallucination ๋“ฑ)์„ ์ œ๊ฑฐํ•˜๊ณ , ์งˆ๋ฌธ-๋‹ต๋ณ€์˜ ๊ท ํ˜•์„ ๋งž์ถฅ๋‹ˆ๋‹ค.
  3. Student ๋ชจ๋ธ ํ•™์Šต:
    ์ „์ฒ˜๋ฆฌ๋œ ๋ฐ์ดํ„ฐ๋ฅผ ์ด์šฉํ•ด Student ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ต๋‹ˆ๋‹ค.

3. Logit-based KD Example (PyTorch)

The code below defines a teacher and a student model, then trains the student using the teacher's knowledge.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Teacher ๋ชจ๋ธ ์ •์˜ (์˜ˆ์‹œ)
class Teacher(nn.Module):
    def __init__(self, num_classes=10):
        super(Teacher, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, num_classes)
        
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)  # output logits

teacher = Teacher(num_classes=10)
print("Teacher parameter count:", sum(p.numel() for p in teacher.parameters()))

# Student ๋ชจ๋ธ ์ •์˜ (๋” ์–•์€ ๋„คํŠธ์›Œํฌ)
class Student(nn.Module):
    def __init__(self, num_classes=10):
        super(Student, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, num_classes)
        
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

student = Student(num_classes=10)
print("Student parameter count:", sum(p.numel() for p in student.parameters()))

# Teacher ๋ชจ๋ธ ํ•™์Šต (์ผ๋ฐ˜์ ์ธ Cross-Entropy Loss ์‚ฌ์šฉ)
def train_teacher(model, train_data, optimizer, epochs=10):
    model.train()
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for x, y in train_data:
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

# Teacher ๋ชจ๋ธ ํ‰๊ฐ€
def test_model(model, test_data):
    model.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for x, y in test_data:
            logits = model(x)
            preds = logits.argmax(dim=1)
            total += y.size(0)
            correct += (preds == y).sum().item()
    print("Accuracy: {:.2f}%".format(100 * correct / total))

# Teacher ๋ชจ๋ธ ํ•™์Šต ๋ฐ ํ‰๊ฐ€
optimizer_teacher = torch.optim.Adam(teacher.parameters(), lr=1e-3)
train_teacher(teacher, train_data, optimizer_teacher, epochs=10)
test_model(teacher, test_data)

# Student ๋ชจ๋ธ ํ•™์Šต (Hard label + Soft label with KD)
def train_student(teacher, student, train_data, optimizer, epochs=10, T=2.0, alpha=0.5):
    teacher.eval()  # the teacher stays frozen
    criterion = nn.CrossEntropyLoss()
    kl_loss_fn = nn.KLDivLoss(reduction='batchmean')
    
    student.train()
    for epoch in range(epochs):
        for x, y in train_data:
            optimizer.zero_grad()
            # Teacher predictions (soft targets, as probabilities:
            # nn.KLDivLoss expects a probability target by default)
            with torch.no_grad():
                teacher_logits = teacher(x)
                soft_target = F.softmax(teacher_logits / T, dim=1)
            # Student predictions
            student_logits = student(x)
            # Hard label loss (Cross-Entropy)
            loss_hard = criterion(student_logits, y)
            # Soft label loss (KL divergence; the input must be log-probabilities)
            loss_soft = kl_loss_fn(F.log_softmax(student_logits / T, dim=1), soft_target)
            # Total loss: weighted combination via alpha
            loss = alpha * loss_hard + (1 - alpha) * (T * T) * loss_soft
            loss.backward()
            optimizer.step()

# Student ๋ชจ๋ธ ํ‰๊ฐ€
optimizer_student = torch.optim.Adam(student.parameters(), lr=1e-3)
train_student(teacher, student, train_data, optimizer_student, epochs=10, T=2.0, alpha=0.5)
test_model(student, test_data)

Explanation:

The temperature T softens both probability distributions so the student can see the teacher's relative preferences among the non-target classes; alpha balances the hard-label cross-entropy against the soft-label KL term, and the T * T factor keeps the soft-loss gradients on the same scale as the hard loss.


Conclusion

Knowledge Distillation์€ ๊ณ ์„ฑ๋Šฅ Teacher ๋ชจ๋ธ์˜ ์ •๋ณด๋ฅผ ํšจ๊ณผ์ ์œผ๋กœ ์••์ถ•ํ•˜์—ฌ, ๊ฒฝ๋Ÿ‰ํ™”๋œ Student ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ค๋Š” ๊ธฐ๋ฒ•์ž…๋‹ˆ๋‹ค.


๋ชจ๋ธ ๊ฒฝ๋Ÿ‰ํ™”๋ฅผ ์œ„ํ•œ Quantization(์–‘์žํ™”) ๊ธฐ๋ฒ• ์ดํ•ดํ•˜๊ธฐ

๊ฑฐ๋Œ€ AI ๋ชจ๋ธ์€ ๊ณ ์ •๋ฐ€ FP32 ํ˜•์‹์˜ ๊ฐ€์ค‘์น˜์™€ ํ™œ์„ฑํ™”๋ฅผ ์‚ฌ์šฉํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰๊ณผ ์—ฐ์‚ฐ ๋น„์šฉ์ด ๋งค์šฐ ํฝ๋‹ˆ๋‹ค. Quantization(์–‘์žํ™”)์€ ์ด๋Ÿฌํ•œ ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜์™€ ํ™œ์„ฑํ™”๋ฅผ ๋‚ฎ์€ ๋น„ํŠธ ์ •๋ฐ€๋„๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ์ €์žฅ ๋ฐ ๊ณ„์‚ฐ ํšจ์œจ์„ฑ์„ ๋†’์ด๋Š” ๊ธฐ๋ฒ•์ž…๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, ฯ€๋ฅผ 3.141592์ฒ˜๋Ÿผ ๊ณ ์ •๋ฐ€ํ•˜๊ฒŒ ํ‘œํ˜„ํ•˜๋ฉด ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์ด ํฌ์ง€๋งŒ, 3๊ณผ ๊ฐ™์ด ๋‚ฎ์€ ์ •๋ฐ€๋„๋กœ ํ‘œํ˜„ํ•˜๋ฉด ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ์€ ์ค„์–ด๋“ค์ง€๋งŒ ์˜ค์ฐจ๊ฐ€ ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. Quantization์˜ ํ•ต์‹ฌ์€ ์˜ค์ฐจ๋ฅผ ์ตœ์†Œํ™”ํ•˜๋ฉด์„œ๋„ ํšจ์œจ์ ์ธ ๋‚ฎ์€ ์ •๋ฐ€๋„๋ฅผ ์ฐพ๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.


Quantization Mapping

Pasted image 20250311151907.png
Quantization์€ ๋†’์€ ์ •๋ฐ€๋„์˜ ๊ฐ’์„ ๋‚ฎ์€ ๋น„ํŠธ ์ •๋ฐ€๋„์˜ ๊ฐ’์— ๋งคํ•‘ํ•˜๋Š” ๊ณผ์ •์„ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค. ์ผ๋ฐ˜์ ์œผ๋กœ ๋ชจ๋ธ์˜ ๋ฐ์ดํ„ฐ๋‚˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋Š” FP32 ํ˜•์‹์œผ๋กœ ํ‘œํ˜„๋˜๋Š”๋ฐ, ์ด๋ฅผ FP16, INT8 ๋“ฑ์œผ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค. ๋‹ค๋งŒ, ๊ฐ ์ž๋ฃŒํ˜•๋งˆ๋‹ค ํ‘œํ˜„ ๊ฐ€๋Šฅํ•œ ๊ฐ’์˜ ๋ฒ”์œ„๊ฐ€ ๋‹ค๋ฆ…๋‹ˆ๋‹ค.

Type | Representable minimum | Representable maximum
INT8 | -128 | 127
INT16 | -32768 | 32767
INT32 | -2147483648 | 2147483647
FP16 | -65504 | 65504
FP32 | -3.4028235 × 10^38 | 3.4028235 × 10^38

For example, 350.5 in FP32 exceeds INT8's maximum representable value (127), so it cannot be converted directly.

Quantization Mapping์€ ๋ณดํ†ต ์•„๋ž˜์™€ ๊ฐ™์€ ์ˆ˜์‹์„ ์‚ฌ์šฉํ•˜์—ฌ ์ง„ํ–‰๋ฉ๋‹ˆ๋‹ค.

X_quant = round(X / s + z)
X_dequant = s × (X_quant - z)

์–‘์žํ™” ์‹œ, ์ด s์™€ z๋ฅผ ์ €์žฅํ•ด๋‘์–ด ๋‚˜์ค‘์— ๋ณต์›(de-quantization)ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.


Quantization ๊ธฐ๋ฒ•์˜ ์ข…๋ฅ˜

1. Absmax Quantization

2. Zero-point Quantization
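The two schemes differ mainly in how s and z are chosen. A hedged sketch (toy all-positive data; INT8 ranges as in the table above):

```python
# Absmax (symmetric): scale by the largest magnitude, zero-point fixed at 0.
# Zero-point (asymmetric): use the full [min, max] range and shift with z,
# which wastes fewer codes when the data is not centered on zero.
def absmax_params(xs):
    s = max(abs(x) for x in xs) / 127.0
    return s, 0

def zeropoint_params(xs):
    lo, hi = min(xs), max(xs)
    s = (hi - lo) / 255.0
    z = round(-128 - lo / s)
    return s, z

xs = [0.0, 0.5, 1.0, 2.0]            # all-positive toy activations
s_a, z_a = absmax_params(xs)
s_z, z_z = zeropoint_params(xs)
print(s_a, z_a)   # absmax: half the INT8 range is never used here
print(s_z, z_z)   # zero-point: finer scale over the same data
```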


Clipping and Calibration

๋ฐ์ดํ„ฐ์— outlier๊ฐ€ ์กด์žฌํ•˜๋ฉด, ์–‘์žํ™” mapping์ด ํšจ๊ณผ์ ์ด์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด clipping ๊ธฐ๋ฒ•์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, ๋ฐ์ดํ„ฐ ๊ฐ’์˜ ๋ฒ”์ฃผ๋ฅผ [โˆ’5,5][-5, 5]๋กœ ์ œํ•œํ•˜๊ณ , ์ด ๋ฒ”์œ„๋ฅผ ๋„˜์–ด์„œ๋Š” ๊ฐ’์€ ๋ชจ๋‘ ๊ฐ™์€ ๊ฐ’์œผ๋กœ ์ทจ๊ธ‰ํ•ฉ๋‹ˆ๋‹ค. ์ด ๊ณผ์ •์—์„œ ์ ์ ˆํ•œ ๋ฒ”์ฃผ๋ฅผ ์„ ํƒํ•˜๋Š” ๊ฒƒ์„ calibration์ด๋ผ๊ณ  ํ•ฉ๋‹ˆ๋‹ค.


Conclusion

Quantization์€ ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜์™€ ํ™œ์„ฑํ™”๋ฅผ ๋‚ฎ์€ ๋น„ํŠธ ์ •๋ฐ€๋„๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์„ ์ค„์ด๊ณ  ์—ฐ์‚ฐ ํšจ์œจ์„ฑ์„ ๋†’์ด๋Š” ์ค‘์š”ํ•œ ๋ชจ๋ธ ๊ฒฝ๋Ÿ‰ํ™” ๊ธฐ๋ฒ•์ž…๋‹ˆ๋‹ค.


Parameter-Efficient Fine-Tuning (PEFT): Efficient Fine-tuning of Large Models

๊ฑฐ๋Œ€ AI ๋ชจ๋ธ์˜ ํ•™์Šต์€ ์ˆ˜๋งŽ์€ ํŒŒ๋ผ๋ฏธํ„ฐ๋กœ ์ธํ•ด ๋ง‰๋Œ€ํ•œ ์ž์›๊ณผ ์‹œ๊ฐ„์ด ์†Œ์š”๋ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ์ƒํ™ฉ์—์„œ ์ „์ฒด ๋ชจ๋ธ์„ ์žฌํ•™์Šตํ•˜๋Š” ๋Œ€์‹ , ๋ชจ๋ธ์˜ ์ผ๋ถ€๋ถ„๋งŒ ๋ฏธ์„ธ์กฐ์ •ํ•˜์—ฌ ํšจ์œจ์„ฑ์„ ๊ทน๋Œ€ํ™”ํ•˜๋Š” Parameter-Efficient Fine-Tuning (PEFT) ๊ธฐ๋ฒ•์ด ์ฃผ๋ชฉ๋ฐ›๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. PEFT๋Š” ์ „์ด ํ•™์Šต(Transfer Learning)์˜ Fine-tuning ๋‹จ๊ณ„์—์„œ ์ „์ฒด ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์—…๋ฐ์ดํŠธํ•˜์ง€ ์•Š๊ณ , ์ผ๋ถ€ ํŒŒ๋ผ๋ฏธํ„ฐ๋งŒ ํ•™์Šต์‹œ์ผœ ๋น ๋ฅด๊ณ  ๋น„์šฉ ํšจ์œจ์ ์ธ ๋ชจ๋ธ ๊ฐœ์„ ์„ ๋ชฉํ‘œ๋กœ ํ•ฉ๋‹ˆ๋‹ค.


1. Transfer Learning and PEFT

PEFT improves efficiency by training only a subset of parameters during fine-tuning rather than updating the entire model.


2. PEFT์˜ ์ ‘๊ทผ ๋ฐฉ์‹

PEFT๋Š” ํฌ๊ฒŒ ๋‘ ๊ฐ€์ง€ ๋ฐฉ๋ฒ•๋ก ์œผ๋กœ ๋‚˜๋ˆŒ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

2-1. Prompt Tuning Approaches

2-2. ํŒŒ๋ผ๋ฏธํ„ฐ ์‚ฝ์ž… ๋ฐฉ์‹


3. ํŒŒ๋ผ๋ฏธํ„ฐ ์‚ฝ์ž… ๋ฐฉ์‹ ์ƒ์„ธ

3-1. Adapter

Pasted image 20250311152433.png

3-2. Low-Rank Decomposition

Pasted image 20250311152445.png

3-3. LoRA (Low-Rank Adaptation)

Pasted image 20250311152454.png

Aspect | Adapter | LoRA
Computation | sequential | parallel
Uses a nonlinearity | yes | no
Trained parameters | weights and biases | weights only
Added latency | proportional to the module's compute | almost none

LoRA uses a low-rank technique similar to the Adapter's, but computes it in parallel with the frozen layer, making it more efficient in terms of speed.
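The parallel, nonlinearity-free update in the comparison above can be sketched with toy matrix math (pure Python; W0, A, B, the rank, and alpha are illustrative, with h = W0·x + alpha·B(A·x)):

```python
# Toy matrix-vector helpers.
def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def add(a, b):
    return [x + y for x, y in zip(a, b)]

# Frozen pretrained weight W0 (2x2) and a rank-1 update: B (2x1) @ A (1x2).
W0 = [[1.0, 0.0], [0.0, 1.0]]
A = [[0.1, 0.2]]          # r x d  (down-projection)
B = [[0.5], [1.0]]        # d x r  (up-projection)

def lora_forward(x, alpha=1.0):
    # LoRA runs in parallel with the frozen layer and uses no nonlinearity:
    # h = W0 x + alpha * B (A x). Only A and B would be trained.
    base = matvec(W0, x)
    update = matvec(B, matvec(A, x))
    return add(base, [alpha * u for u in update])

print(lora_forward([1.0, 1.0]))
```

Because the low-rank branch is just an addition to the frozen layer's output, B·A can even be merged into W0 after training, which is why LoRA adds almost no inference latency.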

3-4. AdapterFusion

Pasted image 20250311152506.png

3-5. QLoRA

Pasted image 20250311152519.png


4. In Practice: Applying LoRA

As an example, consider applying LoRA to a BERT_base model on the SQuAD dataset.

# ๊ธฐ์กด ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ ๊ณ ์ • (Fine-tuning ์‹œ ๊ณ ์ •)
for param in model.parameters():
    param.requires_grad = False

# LoRA configuration (example)
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=2,               # low-rank factor (small values such as 1-4 are common)
    lora_alpha=16,     # scaling applied to the low-rank update; aids convergence
    target_modules=["query", "value"],  # modules to adapt
    lora_dropout=0.1   # dropout to prevent overfitting
)

# PEFT ๋ชจ๋ธ ์ƒ์„ฑ
model_with_lora = get_peft_model(model, lora_config)

Conclusion

PEFT ๊ธฐ๋ฒ•์€ ๊ฑฐ๋Œ€ ๋ชจ๋ธ์˜ ์ „์ฒด ์žฌํ•™์Šต ์—†์ด๋„, ๋ชจ๋ธ์˜ ์ผ๋ถ€๋งŒ ๋ฏธ์„ธ์กฐ์ •ํ•˜์—ฌ ํ•™์Šต ๋น„์šฉ๊ณผ ์‹œ๊ฐ„์„ ํฌ๊ฒŒ ์ค„์ผ ์ˆ˜ ์žˆ๋Š” ํšจ์œจ์ ์ธ ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค.