본문 바로가기

논문 리뷰(Paper Review)

LoRA for Efficient Stable Diffusion Fine-Tuning

출처

- Using LoRA for Efficient Stable Diffusion Fine-Tuning (huggingface.co)

 

Using LoRA for Efficient Stable Diffusion Fine-Tuning

Using LoRA for Efficient Stable Diffusion Fine-Tuning LoRA: Low-Rank Adaptation of Large Language Models is a novel technique introduced by Microsoft researchers to deal with the problem of fine-tuning large-language models. Powerful models with billions o

huggingface.co

 


 

- LoRA는 LLM을 파인튜닝할 수 있는 기법이다

- 사전 학습된 모델 가중치는 Freeze 하고, 학습 가능한 '행렬 분해 행렬'을 삽입한다. 

- 이는 학습 파라미터 수와 GPU 요구사항을 대폭 줄였다.

- MS 연구원들은 Transformer Attention 블록에 초점을 맞춰, LoRA를 사용한 미세 조정 품질이 전체 모델 미세 조정과 동등하면서도 훨씬 빠르고 컴퓨팅이 덜 필요하다는 사실을 발견

 

LoRA for Diffusers

- LoRA는 Transformer 블록을 적용하는 어느 어플리케이션에서도 적용 가능

- Stable Diffusion의 경우 '이미지 표현'을 '그것을 설명하는 프롬프트'와 연관 시키는 Cross-Attention 레이어에 LoRA를 적용할 수 있다.

- cloneofsimo/lora: Using Low-rank adaptation to quickly fine-tune diffusion models. (github.com) 에서 처음 LoRA를 Diffusion에서 적용하였다.

- Cross-Attention 레이어를 조작할 수 있는 유연한 방법을 제공하면, xFormers와 같은 최적화 기술을 쉽게 채택하는 등 이점이 있을 수 있다. 

- Prompt-to-Prompt와 같은 프로젝트는 해당 레이어에 더 쉽게 접근할 수 있는 방법을 제공하므로, 우리는 일반적인 방법을 제공하기로 하였다. ==> 결과물

  - 학습이 빠르며

  - Full-FT가 2080Ti(11GB) 에서 가능하며

  - 새 레이어의 가중치(~3MB) 단일파일로 저장 가능하다 (이는 UNet 대비 1/1000)

  - Dreambooth Concepts 라이브러리를 등록하여, 사용자의 결과들을 공유 가능하다

 

LoRA Fine-Tuning

- SD Full-FT는 매우 느리고 어렵기 때문에, 가벼운 모델 DreamBooth, Textual-Inversion이 유명하다.

- Diffusers는 LoRA 파인튜닝 스크립트를 제공한다. (train_text_to_image_lora.py)

- 11GB GPU RAM에서 8bit 최적화와 같은 트릭에 의존하지 않고 실행 가능

export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="/sddata/finetune/lora/pokemon"
export HUB_MODEL_ID="pokemon-lora"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch --mixed_precision="fp16"  train_text_to_image_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$DATASET_NAME \
  --dataloader_num_workers=8 \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-04 \
  --max_grad_norm=1 \
  --lr_scheduler="cosine" --lr_warmup_steps=0 \
  --output_dir=${OUTPUT_DIR} \
  --push_to_hub \
  --hub_model_id=${HUB_MODEL_ID} \
  --report_to=wandb \
  --checkpointing_steps=500 \
  --validation_prompt="Totoro" \
  --seed=1337

 

- 1e-4로 lr이 기존 1e-6 대비 크다. 2080Ti(11GB)로 약 5시간 소요되었다

- T4(16GB RAM)에 돌린 버전 모델 ("sayakpaul/sd-model-finetuned-lora-t4")

 

추론

- 앞서 논의한 것 처럼 LoRA의 주요 장점은 훨씬 작은 가중치 훈련으로 우수한 결과를 얻을 수 있다는 것

- SD 모델 가중치에 이것의 수정 없이 추가 가중치를 로드할 수 있는 추론 프로세스를 설계하였다.

- 먼저 BaseModel을 정의한다. 

from huggingface_hub import model_info

# LoRA weights ~3 MB
model_path = "sayakpaul/sd-model-finetuned-lora-t4"

info = model_info(model_path)
model_base = info.cardData["base_model"]
print(model_base)   # CompVis/stable-diffusion-v1-4

 

- -push_to_hub 옵션을 선택할 경우, BaseModel 정보는 자동으로 FT 스크립트에 의해 수집된다.

 

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

pipe.unet.load_attn_procs(model_path)
pipe.to("cuda")

image = pipe("Green pokemon with menacing face", num_inference_steps=25).images[0]
image.save("green_pokemon.png")

 

- BaseModel이 결정되면, LoRA를 이용해 FT를 한다. 우리는 여기서 빠른 추론을 위해 DPMSolverMultistepScheduler 를 사용한다.

- pipe.unet.load_attn_procs(model_path) 하는 것 만으로 LoRA 가중치를 업로드 할 수 있다.

Dreamboothing with LoRA

Other Methods

  • DreamBooth 외에도 Textual_Inversion SD 새로운 개념을 학습시키는 바업ㅂ
  • Textual_Inversion 쓰는 이유는 학습된 가중치가 작고, 공유하기 쉽기 때문
  • 하지만, 단일 물체 (혹은 소수)에만 작동하는 반면 LoRA 범용 미세조정에 사용할 있습니다. , 새로운 도메인이나 데이터셋에서 적용할 있습니다.
  • Pivotal_Tuning Textual_Inversion LoRA 결합한 방법입니다.

Diffusers는 3가지 구성요소가 있다

- SOTA 파이프라인 ( Pipelines (huggingface.co) )

- 생성 속도 <-> 품질 간의 균형을 맞추기 위한 교체 가능한 노이즈 스케줄러

- 사전 학습된 모델 ( Models (huggingface.co) )

 

Quicktour에서는 추론시 DiffusionPipeline을 어떻게 사용하는지 보여주고, 모델과 스케줄러를 결합하여, DiffusionPipeline 내부에서 일어나는 일을 설명한다.

 

 


Cross-Attention이란?

-  Cross attention은 self attention과 mechanism은 동일하나 input의 출처가 다름

-  query, key, value 기반의 attention 구조의 장점을 활용하는 cross attention이 있음

-  key value값을 통해 어떻게 attention 해줄지를 설정

    − Key, value를 어떻게 설정하느냐에 따라서 특정 inductive bias를 가해줄 수 있음

-  Multi scale 또는 multi modal 등을 fusion 하는 연구에 사용

 

 

 

 

classifier-free-guidance-pytorch/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py at a9edb94d44c61af38dd7f2971bbce7d40e4ea7d0 · lucidrains/classifier-free-guidance-pytorch (github.com)

class CrossAttention(Module):
    def __init__(
        self,
        dim,
        hidden_dim,
        heads = 8,
        dim_head = 64,
        flash = False
    ):
        super().__init__()
        self.attn = Attention(
            dim = hidden_dim,
            dim_context = dim,
            norm_context = True,
            num_null_kv = 1,
            dim_head = dim_head,
            heads = heads,
            flash = flash
        )

    def forward(
        self,
        condition,
        hiddens,
        mask = None
    ):
        return self.attn(hiddens, condition, mask = mask) + hiddens
        
   class Attention(Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dim_context = None,
        norm_context = False,
        num_null_kv = 0,
        flash = False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        dim_context = default(dim_context, dim)

        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(dim_context) if norm_context else nn.Identity()

        self.attend = Attend(flash = flash)        

        self.num_null_kv = num_null_kv
        self.null_kv = nn.Parameter(torch.randn(2, num_null_kv, dim_head))

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
        self,
        x,
        context = None,
        mask = None
    ):
        b = x.shape[0]

        if exists(context):
            context = self.context_norm(context)

        kv_input = default(context, x)

        x = self.norm(x)

        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)

        if self.num_null_kv > 0:
            null_k, null_v = repeat(self.null_kv, 'kv n d -> kv b n d', b = b).unbind(dim = 0)
            k = torch.cat((null_k, k), dim = -2)
            v = torch.cat((null_v, v), dim = -2)

        if exists(mask):
            mask = F.pad(mask, (self.num_null_kv, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        out = self.attend(q, k, v, mask = mask)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

 

 

In dalle2_pytorch.py

DALLE2-pytorch/dalle2_pytorch/dalle2_pytorch.py at 680dfc4d93b70f9ab23c814a22ca18017a738ef6 · lucidrains/DALLE2-pytorch (github.com)

class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        norm_context = False,
        cosine_sim = False,
        cosine_sim_scale = 16
    ):
        super().__init__()
        self.cosine_sim = cosine_sim
        self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
        self.dropout = nn.Dropout(dropout)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        if self.cosine_sim:
            q, k = map(l2norm, (q, k))

        q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))

        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.type(sim.dtype)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
반응형