4 changes: 3 additions & 1 deletion .gitignore
@@ -28,4 +28,6 @@ results/
*.slurm
*.arrow
/shards/*
*.png

env/
28 changes: 27 additions & 1 deletion models/config.py
@@ -74,4 +74,30 @@ class TrainConfig:
use_lmms_eval: bool = True # Use lmms-eval for evaluation
lmms_eval_tasks: str = 'mmstar,mmmu,ocrbench,textvqa' # Pass additional tasks as one string, separated by commas without spaces (e.g. 'mmstar,mmmu,ocrbench')
lmms_eval_limit: int = 2000
lmms_eval_batch_size: int = 128

# LoRA Configuration
use_lora: bool = True
lora_rank: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.1
lora_target_modules: list[str] = field(default_factory=lambda: [
# Vision Transformer
"vision_encoder.blocks.*.attn.qkv_proj",
"vision_encoder.blocks.*.attn.out_proj",
"vision_encoder.blocks.*.mlp.fc1",
"vision_encoder.blocks.*.mlp.fc2",

# Language Model
"decoder.blocks.*.attn.q_proj",
"decoder.blocks.*.attn.k_proj",
"decoder.blocks.*.attn.v_proj",
"decoder.blocks.*.attn.out_proj",
"decoder.blocks.*.mlp.gate_proj",
"decoder.blocks.*.mlp.up_proj",
"decoder.blocks.*.mlp.down_proj",
"decoder.head",

# Modality Projector
"MP.proj",
])
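The patterns above are intended to cover the attention and MLP projections of both backbones plus the LM head and the modality projector. As a quick sanity check (a sketch, not part of this diff), the wildcards can be previewed against the model's actual module names; the import paths and the fnmatch-style `*` semantics are assumptions here, and the installed PEFT version may resolve `target_modules` differently:

```python
# Sketch: preview which submodules the wildcard patterns would cover.
# Assumes fnmatch-style '*' globbing and the import paths below; check how the
# installed PEFT version actually resolves target_modules before relying on this.
from fnmatch import fnmatch

from models.config import TrainConfig, VLMConfig                # assumed locations
from models.vision_language_model import VisionLanguageModel

train_cfg = TrainConfig()
model = VisionLanguageModel(VLMConfig(), load_backbone=False)   # skip backbone weights

matched = [
    name
    for name, _ in model.named_modules()
    if any(fnmatch(name, pattern) for pattern in train_cfg.lora_target_modules)
]
print(f"{len(matched)} modules matched, e.g. {matched[:3]}")
```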
54 changes: 53 additions & 1 deletion models/vision_language_model.py
@@ -33,6 +33,39 @@ def __init__(self, cfg: VLMConfig, load_backbone=True):
self.load_backbone = load_backbone
self.tokenizer = get_tokenizer(cfg.lm_tokenizer, cfg.vlm_extra_tokens, cfg.lm_chat_template)

# Add config attribute for PEFT compatibility
self.config = self._create_hf_compatible_config()

def _create_hf_compatible_config(self):
"""Create a minimal HuggingFace-compatible config for PEFT"""

class HFCompatibleConfig:
"""A config class that behaves like both an object and a dictionary for PEFT compatibility"""
def __init__(self, cfg):
self.model_type = "vision_language_model" # Custom model type
self.vocab_size = cfg.lm_vocab_size
self.hidden_size = cfg.lm_hidden_dim
self.num_hidden_layers = cfg.lm_n_blocks
self.num_attention_heads = cfg.lm_n_heads
self.intermediate_size = cfg.lm_inter_dim
self.max_position_embeddings = cfg.lm_max_position_embeddings
self.rms_norm_eps = cfg.lm_rms_eps
self.tie_word_embeddings = cfg.lm_tie_weights

def get(self, key, default=None):
"""Dictionary-like get method for PEFT compatibility"""
return getattr(self, key, default)

def __getitem__(self, key):
"""Dictionary-like access for PEFT compatibility"""
return getattr(self, key)

def __contains__(self, key):
"""Dictionary-like 'in' operator for PEFT compatibility"""
return hasattr(self, key)

return HFCompatibleConfig(self.cfg)

def _replace_img_tokens_with_embd(self, input_ids, token_embd, image_embd):
"""
Replace every image-token placeholder in `input_ids` with the corresponding slice
@@ -48,7 +81,15 @@ def _replace_img_tokens_with_embd(self, input_ids, token_embd, image_embd):

return updated_token_embd

def forward(self, input_ids, images, attention_mask=None, targets=None):

def forward(self, input_ids, images=None, attention_mask=None, labels=None, **kwargs):
# Handle different argument names - PEFT might pass 'labels' instead of 'targets'
targets = labels if labels is not None else kwargs.get('targets', None)

# Handle cases where images might be passed through kwargs
if images is None:
images = kwargs.get('images', [])

if isinstance(images, list):
if not images: # Handle cases with no images
images = torch.empty(0, self.cfg.vit_channels, self.cfg.vit_image_size, self.cfg.vit_image_size, device=input_ids.device)
@@ -76,6 +117,17 @@ def forward(self, input_ids, images, attention_mask=None, targets=None):
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100)

return logits, loss

def prepare_inputs_for_generation(self, input_ids, **kwargs):
"""
Prepare inputs for generation. Required by PEFT for CAUSAL_LM task type.
This is a minimal implementation that just returns the basic inputs.
"""

return {
"input_ids": input_ids,
**kwargs
}

@torch.inference_mode()
def generate(self, input_ids, images, attention_mask=None, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False):
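The `HFCompatibleConfig` shim and the widened `forward` signature exist only so PEFT can treat this model like a Hugging Face one: the config answers both attribute and dictionary lookups, and `labels` is accepted as an alias for `targets`. A minimal sketch, assuming `model` is a `VisionLanguageModel` instance and `input_ids`, `images`, `attention_mask`, `labels` are batch tensors prepared as in `train.py`:

```python
# Sketch only; `model` and the batch tensors are assumed to exist already.
cfg = model.config                     # HFCompatibleConfig built in __init__
cfg.hidden_size                        # plain attribute access
cfg.get("model_type")                  # dict-style get(), used by some PEFT code paths
cfg["vocab_size"]                      # __getitem__
"num_attention_heads" in cfg           # __contains__ -> True

# With the new signature, both calls below are equivalent:
logits, loss = model(input_ids=input_ids, images=images,
                     attention_mask=attention_mask, labels=labels)
logits, loss = model(input_ids=input_ids, images=images,
                     attention_mask=attention_mask, targets=labels)
```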
23 changes: 21 additions & 2 deletions train.py
@@ -13,6 +13,7 @@
from torch.utils.data import DataLoader, DistributedSampler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from peft import LoraConfig, get_peft_model, TaskType

torch.manual_seed(0)
if torch.cuda.is_available():
@@ -210,6 +211,24 @@ def train(train_cfg, vlm_cfg):
else:
model = VisionLanguageModel(vlm_cfg, load_backbone=vlm_cfg.vlm_load_backbone_weights)

# Apply LoRA if enabled
if train_cfg.use_lora:
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, # Now works with the added prepare_inputs_for_generation method
r=train_cfg.lora_rank,
lora_alpha=train_cfg.lora_alpha,
lora_dropout=train_cfg.lora_dropout,
target_modules=train_cfg.lora_target_modules,
bias="none",
)

model = get_peft_model(model, lora_config)

if is_master():
print("\n=== LoRA Configuration ===")
model.print_trainable_parameters()
print("===========================\n")

if is_master():
print(f"nanoVLM initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Training summary{' (global)' if is_dist() else ''}: {len(train_loader.dataset)} samples, {int(len(train_loader)*get_world_size())} batches/epoch, batch size {int(train_cfg.batch_size*get_world_size()*train_cfg.gradient_accumulation_steps)}{', training on ' + str(get_world_size()) + ' GPUs' if is_dist() else ''}")
@@ -299,7 +318,7 @@ def train(train_cfg, vlm_cfg):
)
with autocast_context:
with context:
_, loss = model(input_ids, images, attention_mask=attention_mask, targets=labels)
_, loss = model(input_ids=input_ids, images=images, attention_mask=attention_mask, labels=labels)

if train_cfg.gradient_accumulation_steps > 1:
loss = loss / train_cfg.gradient_accumulation_steps
@@ -357,7 +376,7 @@ def train(train_cfg, vlm_cfg):
attention_mask = batch["attention_mask"].to(device)

with autocast_context:
_, loss = model(input_ids, images, attention_mask=attention_mask, targets=labels)
_, loss = model(input_ids=input_ids, images=images, attention_mask=attention_mask, labels=labels)

total_val_loss += loss.item()
avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) > 0 else 0
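The diff does not touch checkpointing, so persisting the adapters is left to the caller. A sketch of the standard PEFT pattern (not code from this PR; `model` is the `PeftModel` from `train()` and the paths are placeholders):

```python
import torch

# Save only the adapter weights plus adapter_config.json:
model.save_pretrained("checkpoints/lora_adapter")

# Or fold the adapters into the base weights and recover the plain model:
merged = model.merge_and_unload()      # returns the underlying VisionLanguageModel
torch.save(merged.state_dict(), "checkpoints/merged_vlm.pt")
```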