Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | ||||
4 | 5 | 6 | 7 | 8 | 9 | 10 |
11 | 12 | 13 | 14 | 15 | 16 | 17 |
18 | 19 | 20 | 21 | 22 | 23 | 24 |
25 | 26 | 27 | 28 | 29 | 30 | 31 |
Tags
- Gemma
- llama-4-scout-17b-16e-instruct
- gemma-3
- ViT
- PEFT
- backbone
- langchain
- llama-4
- error: mkl-service + intel(r)
- glibcxx
- multi-gpu
- Mac
- Text-to-Image
- llm
- prompt
- ubuntu
- Fine-tuning
- CPT
- vLLM
- transformer
- lora+
- diffusion
- instruction tuning
- gemma2
- domain-adapted pre-training
- nccl
- tensor-parallel
- sfttrainer
- gemma-3-27b-it
- Lora
Archives
- Today
- Total
꾸준하게
[PEFT] QLoRA Quantization 적용 대상 본문
지금까지 QLoRA에서 Q가 LoRA에 붙어있으니 당연히 LoRA에 적용되는줄 알았다..
디버깅 해보니, LoRA는 fp16, base model layer들은 uint8로 찍힌다.
이때, 4bit가 아닌 8bit로 보이는 이유는, 겉으로는 8bit로 보이지만 내부적으로 2개의 weight를 하나의 8bit로 합쳐서 저장하기 때문으로, 실제로는 4bit로 저장되는게 맞다고 한다.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=False,
)
device_map = {"": device_string} # force data-parallel training
model_name = 'NCSOFT/Llama-VARCO-8B-Instruct'
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path = model_name,
quantization_config=bnb_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2", # 'eager' for gemma2
)
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
lora_alpha=32,
lora_dropout=0.0,
r=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
task_type='CAUSAL_LM'
)
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path = model_name,
quantization_config=bnb_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2", # 'eager' for gemma2
)
model = get_peft_model(model, peft_config)
print(model.base_model.model.model.layers[0].self_attn.q_proj.base_layer.state_dict()['weight'].dtype)
print(model.base_model.model.model.layers[0].self_attn.q_proj.lora_A.state_dict()['default.weight'].dtype)
# Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00, 1.08s/it]
# Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00, 1.09s/it]
# torch.uint8
# torch.float32
model.half()
print(model.base_model.model.model.layers[0].self_attn.q_proj.base_layer.state_dict()['weight'].dtype)
# torch.float16
'LLM' 카테고리의 다른 글
[vLLM] vLLM API Server 구동 방법 (0) | 2025.05.07 |
---|---|
[instruction tuning] instruction label masking (1) | 2024.10.16 |
[LangChain/LangSmith] Agent 작동 방식 (0) | 2024.09.02 |
[RAG] Retrieval 평가지표 (1) | 2024.06.18 |