diff --git a/Multimodal/Fine_tune_Video_LLaVa.ipynb b/Multimodal/Fine_tune_Video_LLaVa.ipynb new file mode 100644 index 0000000..8edcd74 --- /dev/null +++ b/Multimodal/Fine_tune_Video_LLaVa.ipynb @@ -0,0 +1,2372 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "\"Open" + ], + "metadata": { + "id": "ltgr0WOdmMRF" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SqMsvBANlvDv" + }, + "source": [ + "## Prerequisites\n", + "Before we start, make sure you have the following:\n", + "\n", + "- Access to a GPU (preferably an A100, since videos require long sequence lengths).\n", + "- Familiarity with Hugging Face’s Transformers library.\n", + "- Install the necessary packages by running the cell below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-T8dupbwlvDx" + }, + "outputs": [], + "source": [ + "!pip install -U -q transformers accelerate bitsandbytes peft datasets\n", + "!pip install -q av\n", + "!pip install -q lightning" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xTG2AJZT89qk" + }, + "source": [ + "## Fine-tune Video-LLaVa on the MVBench dataset\n", + "\n", + "In this notebook, we are going to fine-tune the [Video-LLaVa](https://huggingface.co/docs/transformers/main/en/model_doc/video_llava) model on the [MVBench](https://huggingface.co/datasets/OpenGVLab/MVBench) dataset, which consists of various video-related tasks. Note that MVBench is quite small and is not intended for fine-tuning. Make sure to choose a bigger dataset for your own use case.\n", + "\n", + "Video-LLaVa is an open-source multimodal model that can accept both images and videos as input in an interleaved manner. The model architecture is very similar to [LLaVa](https://huggingface.co/docs/transformers/main/en/model_doc/llava). However, Video-LLaVa leverages a new universal visual encoder to seamlessly handle both visual modalities. As we'll see, fine-tuning these models is very similar, as their APIs are mostly the same.\n", + "\n", + "The goal for the model in this notebook is to answer multiple-choice questions based on the video. The questions can be related to temporal aspects of the video, pose prediction and so on.\n", + "Sources:\n", + "\n", + "* Video-LLaVa [documentation](https://huggingface.co/docs/transformers/main/en/model_doc/video_llava)\n", + "* Video-LLaVa [checkpoint on the hub](https://huggingface.co/LanguageBind/Video-LLaVA-7B-hf)\n", + "\n", + "**Note: this notebook is a direct adaptation of Niels' [LLaVa notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LLaVa/Fine_tune_LLaVa_on_a_custom_dataset_(with_PyTorch_Lightning).ipynb).**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Szf17AKL89qm" + }, + "source": [ + "## Define variables\n", + "\n", + "We'll first set some variables useful throughout this notebook and do all the necessary imports."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LJtnWc3b89qn", + "outputId": "8306b1f9-be6b-4083-f1c7-429a21984f96" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-05-22 18:06:06.577404: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-05-22 18:06:07.308550: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" + ] + } + ], + "source": [ + "import os\n", + "import av\n", + "import re\n", + "import bisect\n", + "import shutil\n", + "import numpy as np\n", + "from nltk import edit_distance\n", + "\n", + "from transformers import AutoProcessor\n", + "from transformers import BitsAndBytesConfig, VideoLlavaForConditionalGeneration\n", + "from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model\n", + "\n", + "import torch\n", + "from torch.utils.data import Dataset\n", + "from torch.utils.data import DataLoader\n", + "from huggingface_hub import snapshot_download\n", + "from datasets import load_dataset, concatenate_datasets\n", + "\n", + "import lightning as L\n", + "from lightning.pytorch.callbacks.early_stopping import EarlyStopping\n", + "\n", + "\n", + "MAX_LENGTH = 256\n", + "MODEL_ID = \"LanguageBind/Video-LLaVA-7B-hf\"\n", + "REPO_ID = \"RaushanTurganbay/VideoLlava-demo\" # Change to your hf-hub repo\n", + "\n", + "USE_LORA = False\n", + "USE_QLORA = True" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WNZQ3imilvDz" + }, + "source": [ + "We will start by downloading and processing the dataset. Even though MVBench is a small dataset, it still requires around 100GB to store the videos, so make sure you have enough free space.\n", + "\n", + "First, we will use the mapping below to locate each dataset, since each one is a separate subset stored in its own folder. Then we need a few helper functions to download the videos and process them to fit the model's format (8 frames per video); a rough sketch of how these pieces fit together is shown below."
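As a rough orientation, the download-and-preprocess step could be wired together as in the following sketch. This is a hedged illustration, not the notebook's actual cells: it assumes the `config2path` mapping and the `collate_read_video` helper defined in the next cells, that each MVBench subset exposes a `train` split, and that the raw videos sit under a `video/` folder in the dataset repo; the local directory name is hypothetical.

```python
from huggingface_hub import snapshot_download
from datasets import load_dataset, concatenate_datasets

VIDEO_DIR = "./mvbench_videos"  # hypothetical local directory; the videos need ~100GB of free space

# Download the raw videos once (all subsets live in the same dataset repo).
snapshot_download(
    repo_id="OpenGVLab/MVBench",
    repo_type="dataset",
    local_dir=VIDEO_DIR,
    allow_patterns=["video/*"],  # assumption: raw videos are stored under `video/`
)

subsets = []
for config_name, video_subdir in config2path.items():  # mapping defined in the next cell
    ds = load_dataset("OpenGVLab/MVBench", config_name, split="train")  # split name is an assumption
    # Decode each video into 8 uniformly sampled frames with the helper defined below.
    ds = ds.map(collate_read_video, fn_kwargs={"path": f"{VIDEO_DIR}/video/{video_subdir}"})
    subsets.append(ds)

dataset = concatenate_datasets(subsets)
dataset = dataset.train_test_split(test_size=0.2, seed=42)  # small held-out split for evaluation
```

Decoding every clip up front trades disk space and RAM for faster training epochs; alternatively, you could decode videos lazily inside the dataset class.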
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PZGzYh7KlvDz" + }, + "outputs": [], + "source": [ + "config2path = {\n", + "    \"object_interaction\": \"star/Charades_v1_480\",\n", + "    \"action_sequence\": \"star/Charades_v1_480\",\n", + "    \"action_prediction\": \"star/Charades_v1_480\",\n", + "    \"moving_count\": \"clevrer/video_validation\",\n", + "    \"moving_attribute\": \"clevrer/video_validation\",\n", + "    \"object_existence\": \"clevrer/video_validation\",\n", + "    \"moving_direction\": \"clevrer/video_validation\",\n", + "    \"counterfactual_inference\": \"clevrer/video_validation\",\n", + "    \"unexpected_action\": \"FunQA_test/test\",\n", + "    \"episodic_reasoning\": \"tvqa/frames_fps3_hq\",\n", + "    \"action_antonym\": \"ssv2_video\",\n", + "    \"scene_transition\": \"scene_qa/video\",\n", + "    \"fine_grained_pose\": \"nturgbd\",\n", + "    \"object_shuffle\": \"perception/videos\",\n", + "    \"state_change\": \"perception/videos\",\n", + "    \"character_order\": \"perception/videos\",\n", + "    \"action_localization\": \"sta/sta_video\",\n", + "    \"fine_grained_action\": \"Moments_in_Time_Raw/vi\",\n", + "    \"egocentric_navigation\": \"vlnqa\",\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yVdfSXMGlvDz" + }, + "outputs": [], + "source": [ + "def read_video_pyav(video_path, start, end):\n", + "    \"\"\"Reads a video for a given start-end timestamp interval and uniformly samples 8 frames from it\"\"\"\n", + "    container = av.open(video_path)\n", + "    video = container.streams.get(0)[0]\n", + "\n", + "    av_timestamps = [\n", + "        int(packet.pts * video.time_base) for packet in container.demux(video) if packet.pts is not None\n", + "    ]\n", + "\n", + "    av_timestamps.sort()\n", + "    start_id = bisect.bisect_left(av_timestamps, start)\n", + "    end_id = bisect.bisect_left(av_timestamps, end)\n", + "\n", + "    # in case it is a very short video, let's take a longer duration and sample\n", + "    if end_id - start_id < 10:\n", + "        end_id += 10\n", + "        start_id -= 10\n", + "\n", + "    end_id = min(len(av_timestamps) - 1, end_id)\n", + "    start_id = max(1, start_id)\n", + "    indices = np.linspace(start_id, end_id, 8).astype(int)\n", + "\n", + "    frames = []\n", + "    container.seek(0)\n", + "    for i, frame in enumerate(container.decode(video=0)):\n", + "        if i > end_id:\n", + "            break\n", + "        if i >= start_id and i in indices:\n", + "            frames.append(frame)\n", + "    assert len(frames) == 8, f\"Got {len(frames)} frames but should be 8. Check the indices: {indices}; start_id: {start_id}, end_id: {end_id}. Len of video is {len(av_timestamps)} frames.\"\n", + "    return np.stack([x.to_ndarray(format=\"rgb24\") for x in frames])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X9-8O207lvD0" + }, + "outputs": [], + "source": [ + "def collate_read_video(example, path):\n", + "    # Some datasets have a start-end interval, so we try to get it if it exists. 
Otherwise, just set a very large end timestamp\n", + "    clip = read_video_pyav(f'{path}/{example[\"video\"]}', example.get(\"start\", 1), example.get(\"end\", 1e+10))\n", + "    example[\"clip\"] = clip\n", + "    return example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "referenced_widgets": [ + "321b49438e534f3a881e10f7de9528bb" + ] + }, + "id": "BMNQND-ilvD0", + "outputId": "b6ebfc90-1a0e-4a4b-c494-2f9aa0a56a2d" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "321b49438e534f3a881e10f7de9528bb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 40 files:   0%|          | 0/40 [00:00. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + } + ], + "source": [ + "processor = AutoProcessor.from_pretrained(MODEL_ID)\n", + "processor.tokenizer.padding_side = \"right\" # during training, one always uses padding on the right" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rtxYp7h2lvD0" + }, + "source": [ + "## Custom Dataset Class\n", + "\n", + "In the next step, we'll define a custom dataset class and the necessary functions to prepare our data for fine-tuning the Video-LLaVa model. The VideoLlavaDataset class extends the [PyTorch Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) class to facilitate loading and processing \"MVBench\". This class handles the conversion of dataset samples into the format required for training and evaluation by preparing a prompt and converting each video into an array.\n", + "\n", + "NOTE: Video-LLaVa accepts videos in one of the following formats:\n", + "- an array or tensor of shape: (batch-size, frames, channel, height, width), where batch-size is an optional dimension\n", + "- a list of arrays of shape: (frames, channel, height, width)\n", + "- a nested list of video frames, where each frame is an image\n", + "\n", + "\n", + "Next, we define collate functions to handle the batching of data during training and evaluation. These functions ensure that the input data is properly formatted and padded.\n", + "\n", + "It's only here that we're going to use the processor to turn the (video, target token sequence) pairs into the format that the model expects (pixel_values, input_ids, etc.). The reason we do that here is that it allows for dynamic padding of the batches: each batch contains ground-truth sequences of varying lengths. By only using the processor here, we will pad the input_ids up to the longest sequence in the batch.\n", + "\n", + "We also decide to limit the length of the text tokens (input_ids) to a max length due to memory constraints; feel free to expand it if your target token sequences are longer (I'd recommend plotting the average token length of your dataset to determine the optimal value).\n", + "\n", + "The formatting of the input_ids is super important: we need to respect a so-called [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating). 
As of now, Video-LLaVa does not yet support chat templates, so we manually write down the prompt in the correct format (which starts with USER and ends with ASSISTANT). You could also omit this and just train the model on (video, instruction) pairs without a text prompt.\n", + "\n", + "Labels are created for the model by simply copying the inputs to the LLM (input_ids), but with padding tokens replaced by the ignore index of the loss function. This ensures that the model doesn't need to learn to predict padding tokens (used to batch examples together).\n", + "\n", + "Why are the labels a copy of the model inputs, you may ask? The model will internally shift the labels one position to the right so that the model will learn to predict the next token. This shift happens inside the model's forward pass.\n", + "\n", + "The collate function for evaluation is different, since there we only need to feed the prompt to the model, as we'll use the `generate()` method to autoregressively generate a completion." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "un_H3-pPlvD0" + }, + "outputs": [], + "source": [ + "class VideoLlavaDataset(Dataset):\n", + "    \"\"\"\n", + "    PyTorch Dataset for Video-LLaVa fine-tuning. This class takes a Hugging Face Dataset as input.\n", + "    \"\"\"\n", + "\n", + "    def __init__(\n", + "        self,\n", + "        dataset: str,\n", + "    ):\n", + "        super().__init__()\n", + "        self.dataset = dataset\n", + "        self.id2choice = {0: \"A\", 1: \"B\", 2: \"C\", 3: \"D\"}\n", + "\n", + "    def __len__(self) -> int:\n", + "        return len(self.dataset)\n", + "\n", + "    def __getitem__(self, idx: int):\n", + "        sample = self.dataset[idx]\n", + "        # cast to np because ds.map() cast everything to lists, and the processor does not work with the list format\n", + "        clip = np.array(sample[\"clip\"])\n", + "\n", + "        question, candidates = sample[\"question\"], sample[\"candidates\"]\n", + "        answer = candidates.index(sample[\"answer\"])\n", + "        answer = self.id2choice[answer]\n", + "\n", + "        mult_choice = \"\"\n", + "        for i, choice in enumerate(candidates):\n", + "            mult_choice += f\"{self.id2choice[i]}. {choice}; \"\n", + "\n", + "        # Prepare a prompt template; can be changed depending on the dataset and use case\n", + "        prompt = f\"USER:
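To make the dynamic-padding and label-masking logic described above concrete, here is a minimal hedged sketch of what a training collate function could look like. It is not the notebook's exact implementation: the `prompt` key (full text with the answer already appended) and the `clip` key are assumptions about what the dataset's `__getitem__` returns, and `processor` / `MAX_LENGTH` are the objects defined earlier in the notebook.

```python
import torch

def train_collate_fn(examples):
    # `prompt` and `clip` are assumed keys; adapt them to what the dataset actually returns.
    texts = [example["prompt"] for example in examples]
    clips = [example["clip"] for example in examples]  # each clip: (frames, height, width, channels)

    batch = processor(
        text=texts,
        videos=clips,
        padding=True,           # dynamic padding up to the longest sequence in the batch
        truncation=True,
        max_length=MAX_LENGTH,  # defined at the top of the notebook
        return_tensors="pt",
    )

    # Labels are a copy of input_ids, with padding tokens replaced by -100 so the
    # cross-entropy loss ignores them; the model shifts the labels internally.
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels
    return batch
```

An evaluation collate function would instead build the prompt without the answer and keep the ground-truth answer as a string, to compare against the output of `generate()`.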