출처
- Using LoRA for Efficient Stable Diffusion Fine-Tuning (huggingface.co)
- LoRA는 LLM을 파인튜닝할 수 있는 기법이다
- 사전 학습된 모델 가중치는 Freeze 하고, 학습 가능한 '행렬 분해 행렬'을 삽입한다.
- 이는 학습 파라미터 수와 GPU 요구사항을 대폭 줄였다.
- MS 연구원들은 Transformer Attention 블록에 초점을 맞춰, LoRA를 사용한 미세 조정 품질이 전체 모델 미세 조정과 동등하면서도 훨씬 빠르고 컴퓨팅이 덜 필요하다는 사실을 발견
LoRA for Diffusers
- LoRA는 Transformer 블록을 적용하는 어느 어플리케이션에서도 적용 가능
- Stable Diffusion의 경우 '이미지 표현'을 '그것을 설명하는 프롬프트'와 연관 시키는 Cross-Attention 레이어에 LoRA를 적용할 수 있다.
- cloneofsimo/lora: Using Low-rank adaptation to quickly fine-tune diffusion models. (github.com) 에서 처음 LoRA를 Diffusion에서 적용하였다.
- Cross-Attention 레이어를 조작할 수 있는 유연한 방법을 제공하면, xFormers와 같은 최적화 기술을 쉽게 채택하는 등 이점이 있을 수 있다.
- Prompt-to-Prompt와 같은 프로젝트는 해당 레이어에 더 쉽게 접근할 수 있는 방법을 제공하므로, 우리는 일반적인 방법을 제공하기로 하였다. ==> 결과물
- 학습이 빠르며
- Full-FT가 2080Ti(11GB) 에서 가능하며
- 새 레이어의 가중치(~3MB) 단일파일로 저장 가능하다 (이는 UNet 대비 1/1000)
- Dreambooth Concepts 라이브러리를 등록하여, 사용자의 결과들을 공유 가능하다
LoRA Fine-Tuning
- SD Full-FT는 매우 느리고 어렵기 때문에, 가벼운 모델 DreamBooth, Textual-Inversion이 유명하다.
- Diffusers는 LoRA 파인튜닝 스크립트를 제공한다. (train_text_to_image_lora.py)
- 11GB GPU RAM에서 8bit 최적화와 같은 트릭에 의존하지 않고 실행 가능
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="/sddata/finetune/lora/pokemon"
export HUB_MODEL_ID="pokemon-lora"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--dataloader_num_workers=8 \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--max_train_steps=15000 \
--learning_rate=1e-04 \
--max_grad_norm=1 \
--lr_scheduler="cosine" --lr_warmup_steps=0 \
--output_dir=${OUTPUT_DIR} \
--push_to_hub \
--hub_model_id=${HUB_MODEL_ID} \
--report_to=wandb \
--checkpointing_steps=500 \
--validation_prompt="Totoro" \
--seed=1337
- 1e-4로 lr이 기존 1e-6 대비 크다. 2080Ti(11GB)로 약 5시간 소요되었다.
- T4(16GB RAM)에 돌린 버전 모델 ("sayakpaul/sd-model-finetuned-lora-t4")
추론
- 앞서 논의한 것 처럼 LoRA의 주요 장점은 훨씬 작은 가중치 훈련으로 우수한 결과를 얻을 수 있다는 것
- SD 모델 가중치에 이것의 수정 없이 추가 가중치를 로드할 수 있는 추론 프로세스를 설계하였다.
- 먼저 BaseModel을 정의한다.
from huggingface_hub import model_info
# LoRA weights ~3 MB
model_path = "sayakpaul/sd-model-finetuned-lora-t4"
info = model_info(model_path)
model_base = info.cardData["base_model"]
print(model_base) # CompVis/stable-diffusion-v1-4
- -push_to_hub 옵션을 선택할 경우, BaseModel 정보는 자동으로 FT 스크립트에 의해 수집된다.
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.unet.load_attn_procs(model_path)
pipe.to("cuda")
image = pipe("Green pokemon with menacing face", num_inference_steps=25).images[0]
image.save("green_pokemon.png")
- BaseModel이 결정되면, LoRA를 이용해 FT를 한다. 우리는 여기서 빠른 추론을 위해 DPMSolverMultistepScheduler 를 사용한다.
- pipe.unet.load_attn_procs(model_path) 하는 것 만으로 LoRA 가중치를 업로드 할 수 있다.
Dreamboothing with LoRA
- DreamBooth는 SD에 새로운 개념을 가리키는 것이다. LoRA와도 호환이 가능하며 FT 하는 것과 유사하다
- diffusers/examples/dreambooth at main · huggingface/diffusers (github.com)
- Training Stable Diffusion with Dreambooth using Diffusers (huggingface.co)
- LoRA+DreamBooth를 이용해 싸고, 쉬운 학습 방법을 원한다면 다음을 살펴본다(https://huggingface.co/spaces/lora-library/LoRA-DreamBooth-Training-UI)
Other Methods
- DreamBooth 외에도 Textual_Inversion은 SD에 새로운 개념을 학습시키는 바업ㅂ
- Textual_Inversion을 쓰는 이유는 학습된 가중치가 작고, 공유하기 쉽기 때문
- 하지만, 단일 물체 (혹은 소수)에만 작동하는 반면 LoRA는 범용 미세조정에 사용할 수 있습니다. 즉, 새로운 도메인이나 데이터셋에서 적용할 수 있습니다.
- Pivotal_Tuning은 Textual_Inversion을 LoRA와 결합한 방법입니다.
Diffusers는 3가지 구성요소가 있다
- SOTA 파이프라인 ( Pipelines (huggingface.co) )
- 생성 속도 <-> 품질 간의 균형을 맞추기 위한 교체 가능한 노이즈 스케줄러
- 사전 학습된 모델 ( Models (huggingface.co) )
Quicktour에서는 추론시 DiffusionPipeline을 어떻게 사용하는지 보여주고, 모델과 스케줄러를 결합하여, DiffusionPipeline 내부에서 일어나는 일을 설명한다.
Cross-Attention이란?
- Cross attention은 self attention과 mechanism은 동일하나 input의 출처가 다름
- query, key, value 기반의 attention 구조의 장점을 활용하는 cross attention이 있음
- key value값을 통해 어떻게 attention 해줄지를 설정
− Key, value를 어떻게 설정하느냐에 따라서 특정 inductive bias를 가해줄 수 있음
- Multi scale 또는 multi modal 등을 fusion 하는 연구에 사용
class CrossAttention(Module):
def __init__(
self,
dim,
hidden_dim,
heads = 8,
dim_head = 64,
flash = False
):
super().__init__()
self.attn = Attention(
dim = hidden_dim,
dim_context = dim,
norm_context = True,
num_null_kv = 1,
dim_head = dim_head,
heads = heads,
flash = flash
)
def forward(
self,
condition,
hiddens,
mask = None
):
return self.attn(hiddens, condition, mask = mask) + hiddens
class Attention(Module):
def __init__(
self,
dim,
dim_head = 64,
heads = 8,
dim_context = None,
norm_context = False,
num_null_kv = 0,
flash = False
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
dim_context = default(dim_context, dim)
self.norm = nn.LayerNorm(dim)
self.context_norm = nn.LayerNorm(dim_context) if norm_context else nn.Identity()
self.attend = Attend(flash = flash)
self.num_null_kv = num_null_kv
self.null_kv = nn.Parameter(torch.randn(2, num_null_kv, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(
self,
x,
context = None,
mask = None
):
b = x.shape[0]
if exists(context):
context = self.context_norm(context)
kv_input = default(context, x)
x = self.norm(x)
q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)
if self.num_null_kv > 0:
null_k, null_v = repeat(self.null_kv, 'kv n d -> kv b n d', b = b).unbind(dim = 0)
k = torch.cat((null_k, k), dim = -2)
v = torch.cat((null_v, v), dim = -2)
if exists(mask):
mask = F.pad(mask, (self.num_null_kv, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
out = self.attend(q, k, v, mask = mask)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
In dalle2_pytorch.py
class CrossAttention(nn.Module):
def __init__(
self,
dim,
*,
context_dim = None,
dim_head = 64,
heads = 8,
dropout = 0.,
norm_context = False,
cosine_sim = False,
cosine_sim_scale = 16
):
super().__init__()
self.cosine_sim = cosine_sim
self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
self.heads = heads
inner_dim = dim_head * heads
context_dim = default(context_dim, dim)
self.norm = LayerNorm(dim)
self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim)
)
def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
context = self.norm_context(context)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
# add null key / value for classifier free guidance in prior net
nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
if self.cosine_sim:
q, k = map(l2norm, (q, k))
q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
sim = einsum('b h i d, b h j d -> b h i j', q, k)
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
mask = F.pad(mask, (1, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = attn.type(sim.dtype)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
'논문 리뷰(Paper Review)' 카테고리의 다른 글
"Scaling LLM Test-Time Compute Optimally can be More Effective than Scaling Model Parameters" (0) | 2024.12.31 |
---|---|
DDPM(Denosing Diffusion Probabilistic Model) 개념 정리 (0) | 2024.02.27 |
Dall-e 2 및 주변 기술 리뷰 (1) | 2024.02.14 |
StyleTrasfer 복습 (0) | 2024.02.13 |
Meta Pseudo Labels(CVPR 2021) (0) | 2024.01.02 |