6 changes: 1 addition & 5 deletions tests/assets/tokenizer/tokenizer.json
@@ -2029,11 +2029,7 @@
"land": 1994,
"?\n": 1995,
" respect": 1996,
"ances": 1997,
"<|image|>": 1998,
"<|begin_of_image|>": 1999,
"<|end_of_image|>": 2000,
"<|pad|>": 2001
"ances": 1997
},
"merges": [
]
35 changes: 0 additions & 35 deletions tests/assets/tokenizer/tokenizer_config.json
@@ -15,46 +15,11 @@
"rstrip": false,
"single_word": false,
"special": true
},
"1998": {
"content": "<|image|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1999": {
"content": "<|begin_of_image|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2000": {
"content": "<|end_of_image|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2001": {
"content": "<|pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<|begin_of_text|>",
"clean_up_tokenization_spaces": true,
"eos_token": "<|end_of_text|>",
"img_token": "<|image|>",
"boi_token": "<|begin_of_image|>",
"eoi_token": "<|end_of_image|>",
"pad_token": "<|pad|>",
"model_input_names": [
"input_ids",
5 changes: 4 additions & 1 deletion torchtitan/components/tokenizer.py
@@ -190,7 +190,10 @@ def _get_token_from_config(self, config: dict[str, Any], key: str) -> Optional[s
return token

def _process_special_token(
self, token_str: str, token_config: dict, token_id: Optional[int] = None
self,
token_str: str,
token_config: dict | None = None,
token_id: int | None = None,
) -> AddedToken:
"""
Process a special token and update BOS/EOS attributes if applicable.
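Making `token_config` optional means a special token can now be processed from just its string. Below is a minimal sketch of what a caller could pass in that case, assuming it wants the same settings the removed tokenizer_config.json entries carried; the helper name is hypothetical and not part of this PR.

```python
# Hypothetical helper, not from this PR: builds an AddedToken with the same
# settings the removed test-asset config entries used
# (lstrip/rstrip/single_word False, normalized False, special True).
from tokenizers import AddedToken


def make_default_special_token(token_str: str) -> AddedToken:
    return AddedToken(
        token_str,
        lstrip=False,
        rstrip=False,
        single_word=False,
        normalized=False,
        special=True,
    )


vlm_specials = ["<|image|>", "<|begin_of_image|>", "<|end_of_image|>", "<|pad|>"]
added_tokens = [make_default_special_token(t) for t in vlm_specials]
```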
8 changes: 4 additions & 4 deletions torchtitan/experiments/vlm/__init__.py
@@ -4,13 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import asdict, replace
from dataclasses import asdict

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.experiments.vlm.tokenizer import build_vlm_tokenizer
from torchtitan.models.llama3 import llama3_configs
from torchtitan.protocols.train_spec import TrainSpec

@@ -29,7 +29,7 @@

llama3_siglip2_configs = {
"debugmodel": Llama3Siglip2ModelArgs(
**asdict(replace(llama3_configs["debugmodel"], vocab_size=2048)),
**asdict(llama3_configs["debugmodel"]),
encoder=Siglip2ModelArgs(
dim=128,
ffn_dim=256,
@@ -50,7 +50,7 @@ def get_train_spec() -> TrainSpec:
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_mm_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
build_tokenizer_fn=build_vlm_tokenizer,
build_loss_fn=build_cross_entropy_loss,
build_validator_fn=build_validator,
)
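`build_vlm_tokenizer` and `VLMTokenizer` come from `torchtitan/experiments/vlm/tokenizer.py`, which is not shown in this diff. The sketch below is only the interface the rest of the change relies on: the attribute names (`img_token`, `boi_token`, `eoi_token`, `pad_token` and their `*_id` counterparts) are taken from the collator and dataset code further down, while the constructor argument, the `self.tokenizer` backend, and the `hf_assets_path` lookup are assumptions.

```python
# Assumed interface only; the real implementation may differ.
from torchtitan.components.tokenizer import HuggingFaceTokenizer


class VLMTokenizer(HuggingFaceTokenizer):
    """Base HF tokenizer plus VLM special tokens registered at load time."""

    def __init__(self, tokenizer_path: str):
        super().__init__(tokenizer_path)
        self.img_token = "<|image|>"
        self.boi_token = "<|begin_of_image|>"
        self.eoi_token = "<|end_of_image|>"
        self.pad_token = "<|pad|>"
        # Register the tokens on the underlying `tokenizers.Tokenizer`
        # instead of baking them into the test tokenizer assets.
        self.tokenizer.add_special_tokens(
            [self.img_token, self.boi_token, self.eoi_token, self.pad_token]
        )
        self.img_id = self.tokenizer.token_to_id(self.img_token)
        self.boi_id = self.tokenizer.token_to_id(self.boi_token)
        self.eoi_id = self.tokenizer.token_to_id(self.eoi_token)
        self.pad_id = self.tokenizer.token_to_id(self.pad_token)


def build_vlm_tokenizer(job_config) -> VLMTokenizer:
    # How the tokenizer path is read off JobConfig is an assumption here.
    return VLMTokenizer(job_config.model.hf_assets_path)
```

Adding the special tokens on top of the text-only vocabulary at runtime is what the tokenizer.json and tokenizer_config.json deletions above reflect.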
21 changes: 11 additions & 10 deletions torchtitan/experiments/vlm/datasets/mm_collator_nld.py
@@ -12,15 +12,16 @@

from torchtitan.tools.logging import logger

from ..model.args import SpecialTokens

from ..tokenizer import VLMTokenizer
from .utils.image import (
convert_to_patches,
pad_empty_images_to_target_batch_size,
pad_patches,
)
from .utils.text import pad_input_ids_and_labels_to_target_batch_size, pad_text_batch

IGNORE_INDEX = -100


@dataclass
class MultiModalCollatorNLD:
@@ -85,7 +86,7 @@ class MultiModalCollatorNLD:
max_images_per_batch: int # Vision Encoder's batch size
max_patches_per_image: int # Vision Encoder's sequence length

special_tokens: SpecialTokens
tokenizer: VLMTokenizer

def collate_images(
self, all_images: list[torch.Tensor]
@@ -145,28 +146,28 @@ def collate_text(
input_ids = pad_sequence(
[s["input_ids"] for s in batch],
batch_first=True,
padding_value=self.special_tokens.pad_id,
padding_value=self.tokenizer.pad_id,
)
labels = pad_sequence(
[s["labels"] for s in batch],
batch_first=True,
padding_value=self.special_tokens.pad_id,
padding_value=self.tokenizer.pad_id,
)

# Handle sequence length
input_ids, labels = pad_text_batch(
input_ids,
labels,
self.seq_len + 1, # Extra token for label shifting
padding_idx=self.special_tokens.pad_id,
ignore_idx=self.special_tokens.ignore_id,
padding_idx=self.tokenizer.pad_id,
ignore_idx=IGNORE_INDEX,
)
input_ids, labels = pad_input_ids_and_labels_to_target_batch_size(
input_ids,
labels,
self.batch_size,
padding_idx=self.special_tokens.pad_id,
ignore_idx=self.special_tokens.ignore_id,
padding_idx=self.tokenizer.pad_id,
ignore_idx=IGNORE_INDEX,
)

return input_ids[:, :-1], labels[:, 1:] # Shift for next token prediction
@@ -221,7 +222,7 @@ def __call__(
"input": input_ids,
"pixel_values": patches,
"grid_thw": grids,
"special_tokens": self.special_tokens,
"img_token_id": self.tokenizer.img_id,
}

return input_dict, labels
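Since the collator now takes the tokenizer itself, construction looks roughly like the sketch below. Only `max_images_per_batch`, `max_patches_per_image`, and `tokenizer` are visible as fields in this hunk; the other field names are inferred from the attributes the collator uses (`self.batch_size`, `self.seq_len`) and from the `build_mm_dataloader` call further down, and the values are placeholders.

```python
# Illustrative construction only; values are placeholders and some field
# names are inferred rather than shown in this diff.
collator = MultiModalCollatorNLD(
    batch_size=8,
    seq_len=2048,
    patch_size=16,
    max_images_per_batch=4,
    max_patches_per_image=256,
    tokenizer=tokenizer,  # a VLMTokenizer; supplies pad_id and img_id
)
input_dict, labels = collator(samples)  # samples: processed dataset items
```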
36 changes: 14 additions & 22 deletions torchtitan/experiments/vlm/datasets/mm_datasets.py
@@ -20,26 +20,27 @@
from torch.utils.data import IterableDataset

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import BaseTokenizer, HuggingFaceTokenizer
from torchtitan.config import JobConfig
from torchtitan.datasets import DatasetConfig
from torchtitan.tools.logging import logger

from ..model.args import SpecialTokens
from ..tokenizer import VLMTokenizer
from .mm_collator_nld import MultiModalCollatorNLD
from .utils.image import calculate_image_tokens, process_image
from .utils.packing import SamplePacker
from .utils.text import process_text_with_images


IGNORE_INDEX = -100 # Pytorch's default for F.cross_entropy


def _process_mm_sample(
texts: list[str] | str,
images: list[bytes] | bytes,
tokenizer: BaseTokenizer,
tokenizer: VLMTokenizer,
patch_size: int,
max_patch_per_image: int,
spatial_merge_size: int,
special_tokens: SpecialTokens,
) -> dict[str, Any] | None:
"""Common processing logic for multimodal samples.

@@ -98,7 +99,7 @@ def _process_mm_sample(
processed_images.append(processed_img)
image_dimensions.append((num_tokens, width, height))
# Replace None with image token
texts_list[idx] = special_tokens.img_token
texts_list[idx] = tokenizer.img_token
else:
# Replace None with empty string if processing failed
texts_list[idx] = ""
@@ -109,7 +110,7 @@

# Process all image tokens at once
processed_text = process_text_with_images(
texts_list, image_dimensions, tokenizer, special_tokens, add_eos=True
texts_list, image_dimensions, tokenizer, add_eos=True
)

tokens = tokenizer.encode(processed_text)
@@ -120,10 +121,10 @@

# Mask special tokens in labels
special_token_ids = torch.tensor(
[special_tokens.boi_id, special_tokens.eoi_id, special_tokens.img_id]
[tokenizer.boi_id, tokenizer.eoi_id, tokenizer.img_id]
)
labels = torch.where(
torch.isin(labels, special_token_ids), special_tokens.ignore_id, labels
torch.isin(labels, special_token_ids), IGNORE_INDEX, labels
)

return {
@@ -139,11 +140,10 @@

def _process_obelics_sample(
sample: dict[str, Any],
tokenizer: HuggingFaceTokenizer,
tokenizer: VLMTokenizer,
patch_size: int,
spatial_merge_size: int,
max_patch_per_image: int,
special_tokens: SpecialTokens,
) -> dict[str, Any] | None:
"""Process a sample from the OBELICS dataset."""
return _process_mm_sample(
@@ -153,17 +153,15 @@ def _process_obelics_sample(
patch_size=patch_size,
spatial_merge_size=spatial_merge_size,
max_patch_per_image=max_patch_per_image,
special_tokens=special_tokens,
)


def _process_cc12_wd_sample(
sample: dict[str, Any],
tokenizer: BaseTokenizer,
tokenizer: VLMTokenizer,
patch_size: int,
spatial_merge_size: int,
max_patch_per_image: int,
special_tokens: SpecialTokens,
) -> dict[str, Any] | None:
"""Process a sample from the CC12-WD dataset.
Transforms CC12-WD format to match Interleaved format:
@@ -184,7 +182,6 @@
patch_size=patch_size,
spatial_merge_size=spatial_merge_size,
max_patch_per_image=max_patch_per_image,
special_tokens=special_tokens,
)


@@ -225,15 +222,14 @@ def __init__(
self,
dataset_name: str,
dataset_path: str | None,
tokenizer: BaseTokenizer,
tokenizer: VLMTokenizer,
batch_size: int,
seq_len: int,
patch_size: int,
spatial_merge_size: int,
max_patches_per_image: int,
max_images_per_batch: int,
packing_buffer_size: int,
special_tokens: SpecialTokens,
dp_rank: int = 0,
dp_world_size: int = 1,
infinite: bool = False,
@@ -254,7 +250,6 @@
self.spatial_merge_size = spatial_merge_size
self.max_patches_per_image = max_patches_per_image
self.max_images_per_batch = max_images_per_batch
self.special_tokens = special_tokens
self.enable_packing = packing_buffer_size > 0
if self.enable_packing:
self.packer = SamplePacker(
@@ -277,7 +272,6 @@ def __iter__(self):
patch_size=self.patch_size,
spatial_merge_size=self.spatial_merge_size,
max_patch_per_image=self.max_patches_per_image,
special_tokens=self.special_tokens,
)
if processed is None:
continue
@@ -366,7 +360,7 @@ def state_dict(self):
def build_mm_dataloader(
dp_world_size: int,
dp_rank: int,
tokenizer: HuggingFaceTokenizer,
tokenizer: VLMTokenizer,
job_config: JobConfig,
infinite: bool = True,
) -> ParallelAwareDataloader:
@@ -393,7 +387,6 @@
patch_size = job_config.data.patch_size
spatial_merge_size = job_config.data.spatial_merge_size
packing_buffer_size = job_config.data.packing_buffer_size
special_tokens = SpecialTokens.from_tokenizer(tokenizer)

dataset = MultiModalDataset(
dataset_name=job_config.training.dataset,
@@ -406,7 +399,6 @@
max_patches_per_image=max_patches_per_image,
max_images_per_batch=max_images_per_batch,
packing_buffer_size=packing_buffer_size,
special_tokens=special_tokens,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
infinite=infinite,
@@ -418,7 +410,7 @@
patch_size=patch_size,
max_images_per_batch=max_images_per_batch,
max_patches_per_image=max_patches_per_image,
special_tokens=special_tokens,
tokenizer=tokenizer,
)

base_dataloader = ParallelAwareDataloader(
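The new module-level `IGNORE_INDEX = -100` replaces the `ignore_id` that used to live on `SpecialTokens`; -100 is the default `ignore_index` of `torch.nn.functional.cross_entropy`, so masked label positions contribute nothing to the loss. A small self-contained check:

```python
# -100 is F.cross_entropy's default ignore_index, so positions labeled
# -100 (special tokens, padding) are dropped from the loss.
import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 16)                # (batch, seq, vocab)
labels = torch.tensor([[3, -100, 7, -100],
                       [1, 2, -100, 5]])      # -100 marks masked positions
loss = F.cross_entropy(logits.reshape(-1, 16), labels.reshape(-1))
```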
13 changes: 7 additions & 6 deletions torchtitan/experiments/vlm/datasets/utils/text.py
@@ -6,6 +6,8 @@

import torch

from ...tokenizer import VLMTokenizer


def pad_text_batch(
input_ids: torch.Tensor,
@@ -97,8 +99,7 @@
def process_text_with_images(
text: list[str],
image_tokens: list[tuple[int, int, int]], # [(total, width, height), ...]
tokenizer,
special_tokens,
tokenizer: VLMTokenizer,
add_eos: bool = True,
) -> str:
"""Process text by interleaving image tokens efficiently.
@@ -122,14 +123,14 @@
image_idx = 0

for part in text:
if part == special_tokens.img_token and image_idx < len(image_tokens):
if part == tokenizer.img_token and image_idx < len(image_tokens):
num_image_tokens, _, _ = image_tokens[image_idx]

parts.extend(
[
special_tokens.boi_token,
*([special_tokens.img_token] * num_image_tokens),
special_tokens.eoi_token,
tokenizer.boi_token,
*([tokenizer.img_token] * num_image_tokens),
tokenizer.eoi_token,
]
)
image_idx += 1
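For reference, a hypothetical call after this change; the exact output string depends on how the parts are joined, which is outside this hunk.

```python
# Hypothetical usage of process_text_with_images with a VLMTokenizer.
# Each image placeholder expands to <|begin_of_image|>, N copies of
# <|image|>, then <|end_of_image|>, per the loop shown above.
text = ["A photo of a cat ", tokenizer.img_token, " on a mat."]
image_tokens = [(4, 2, 2)]  # (num_image_tokens, width, height), illustrative
processed = process_text_with_images(text, image_tokens, tokenizer, add_eos=True)
```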