diff --git a/Multimodal/Fine_tune_Video_LLaVa.ipynb b/Multimodal/Fine_tune_Video_LLaVa.ipynb
new file mode 100644
index 0000000..8edcd74
--- /dev/null
+++ b/Multimodal/Fine_tune_Video_LLaVa.ipynb
@@ -0,0 +1,2372 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ " "
+ ],
+ "metadata": {
+ "id": "ltgr0WOdmMRF"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SqMsvBANlvDv"
+ },
+ "source": [
+ "## Prerequisites\n",
+ "Before we start, make sure you have the following:\n",
+ "\n",
+ "- Access to a GPU (preferably A100 since videos require high sequence lengths).\n",
+ "- Familiarity with Hugging Face’s Transformers library.\n",
+ "- Install the necessary packages by running the cell below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-T8dupbwlvDx"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -U -q transformers accelerate bitsandbytes peft datasets\n",
+ "!pip install -q av\n",
+ "!pip install -q lightning nltk"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xTG2AJZT89qk"
+ },
+ "source": [
+ "## Fine-tune Video-LLaVa on the MVBench dataset\n",
+ "\n",
+ "In this notebook, we are going to fine-tune the [Video-LLaVa](https://huggingface.co/docs/transformers/main/en/model_doc/video_llava) model on the [MVBench](https://huggingface.co/datasets/OpenGVLab/MVBench) dataset, which comprises various video-related tasks. Note that MVBench is quite small and is not designed for fine-tuning, so make sure to choose a bigger dataset for your own use case.\n",
+ "\n",
+ "Video-LLaVa is an open-source multimodal model that accepts both images and videos as input, in an interleaved manner. Its architecture is very similar to [LLaVa](https://huggingface.co/docs/transformers/main/en/model_doc/llava); however, Video-LLaVa leverages a new universal visual encoder to seamlessly handle both visual modalities. As we'll see, fine-tuning these models is much the same, since their APIs are mostly identical.\n",
+ "\n",
+ "The goal for the model in this notebook is to answer multiple-choice questions based on the video. The questions can be related to temporal aspects of the video, pose prediction and so on.\n",
+ "Sources:\n",
+ "\n",
+ "* Video-LLaVa [documentation](https://huggingface.co/docs/transformers/main/en/model_doc/video_llava)\n",
+ "* Video-LLaVa [checkpoint on the hub](https://huggingface.co/LanguageBind/Video-LLaVA-7B-hf)\n",
+ "\n",
+ "**Note: this notebook is a direct adaptation of Niels' [LLaVa notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LLaVa/Fine_tune_LLaVa_on_a_custom_dataset_(with_PyTorch_Lightning).ipynb).**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Szf17AKL89qm"
+ },
+ "source": [
+ "## Define variables\n",
+ "\n",
+ "We'll first set some variables that are useful throughout this notebook and do all the necessary imports."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "LJtnWc3b89qn",
+ "outputId": "8306b1f9-be6b-4083-f1c7-429a21984f96"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2024-05-22 18:06:06.577404: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2024-05-22 18:06:07.308550: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import av\n",
+ "import re\n",
+ "import bisect\n",
+ "import shutil\n",
+ "import numpy as np\n",
+ "from nltk import edit_distance\n",
+ "\n",
+ "from transformers import AutoProcessor\n",
+ "from transformers import BitsAndBytesConfig, VideoLlavaForConditionalGeneration\n",
+ "from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model\n",
+ "\n",
+ "import torch\n",
+ "from torch.utils.data import Dataset\n",
+ "from torch.utils.data import DataLoader\n",
+ "from huggingface_hub import snapshot_download\n",
+ "from datasets import load_dataset, concatenate_datasets\n",
+ "\n",
+ "import lightning as L\n",
+ "from lightning.pytorch.callbacks.early_stopping import EarlyStopping\n",
+ "\n",
+ "\n",
+ "MAX_LENGTH = 256\n",
+ "MODEL_ID = \"LanguageBind/Video-LLaVA-7B-hf\"\n",
+ "REPO_ID = \"RaushanTurganbay/VideoLLava-demo\" # Change to your hf-hub repo\n",
+ "\n",
+ "USE_LORA = False\n",
+ "USE_QLORA = True"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WNZQ3imilvDz"
+ },
+ "source": [
+ "We will start by downloading and processing the dataset. Even though MVBench is a small dataset, it still requires around 100GB to store the videos, so make sure you have enough free space.\n",
+ "\n",
+ "First, we define a mapping from each MVBench config to the folder that holds its videos, because each task is a separate subset stored in its own folder. Then we need a few helper functions to download the videos and process them to fit the model's format (8 frames per video)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "PZGzYh7KlvDz"
+ },
+ "outputs": [],
+ "source": [
+ "config2path = {\n",
+ " \"object_interaction\": \"star/Charades_v1_480\",\n",
+ " \"action_sequence\": \"star/Charades_v1_480\",\n",
+ " \"action_prediction\": \"star/Charades_v1_480\",\n",
+ " \"moving_count\": \"clevrer/video_validation\",\n",
+ " \"moving_attribute\": \"clevrer/video_validation\",\n",
+ " \"object_existence\": \"clevrer/video_validation\",\n",
+ " \"moving_direction\": \"clevrer/video_validation\",\n",
+ " \"counterfactual_inference\": \"clevrer/video_validation\",\n",
+ " \"unexpected_action\": \"FunQA_test/test\",\n",
+ " \"episodic_reasoning\": \"tvqa/frames_fps3_hq\",\n",
+ " \"action_antonym\": \"ssv2_video\",\n",
+ " \"scene_transition\": \"scene_qa/video\",\n",
+ " \"fine_grained_pose\": \"nturgbd\",\n",
+ " \"object_shuffle\": \"perception/videos\",\n",
+ " \"state_change\": \"perception/videos\",\n",
+ " \"character_order\": \"perception/videos\",\n",
+ " \"action_localization\": \"sta/sta_video\",\n",
+ " \"fine_grained_action\": \"Moments_in_Time_Raw/vi\",\n",
+ " \"egocentric_navigation\": \"vlnqa\",\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yVdfSXMGlvDz"
+ },
+ "outputs": [],
+ "source": [
+ "def read_video_pyav(video_path, start, end):\n",
+ "    \"\"\"Reads a video for given start-end timestamps interval and uniformly samples 8 frames of it\"\"\"\n",
+ "    container = av.open(video_path)\n",
+ "    video = container.streams.get(0)[0]\n",
+ "\n",
+ "    av_timestamps = [\n",
+ "        int(packet.pts * video.time_base) for packet in container.demux(video) if packet.pts is not None\n",
+ "    ]\n",
+ "\n",
+ "    av_timestamps.sort()\n",
+ "    start_id = bisect.bisect_left(av_timestamps, start)\n",
+ "    end_id = bisect.bisect_left(av_timestamps, end)\n",
+ "\n",
+ "    # in case it is a very short video, lets take a longer duration and sample\n",
+ "    if end_id - start_id < 10:\n",
+ "        end_id += 10\n",
+ "        start_id -= 10\n",
+ "\n",
+ "    end_id = min(len(av_timestamps) - 1, end_id)\n",
+ "    start_id = max(1, start_id)\n",
+ "    indices = np.linspace(start_id, end_id, 8).astype(int)\n",
+ "\n",
+ "    frames = []\n",
+ "    container.seek(0)\n",
+ "    for i, frame in enumerate(container.decode(video=0)):\n",
+ "        if i > end_id:\n",
+ "            break\n",
+ "        if i >= start_id and i in indices:\n",
+ "            frames.append(frame)\n",
+ "    assert len(frames) == 8, f\"Got {len(frames)} frames but should be 8. Check the indices: {indices};, start_id: {start_id}, end_id: {end_id}. Len of video is {len(av_timestamps)} frames.\"\n",
+ "    return np.stack([x.to_ndarray(format=\"rgb24\") for x in frames])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "X9-8O207lvD0"
+ },
+ "outputs": [],
+ "source": [
+ "def collate_read_video(example, path):\n",
+ "    # Some datasets have a start-end interval, so we use it if it exists. Otherwise we just set a very large end timestamp\n",
+ "    clip = read_video_pyav(f'{path}/{example[\"video\"]}', example.get(\"start\", 1), example.get(\"end\", 1e+10))\n",
+ "    example[\"clip\"] = clip\n",
+ "    return example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "referenced_widgets": [
+ "321b49438e534f3a881e10f7de9528bb"
+ ]
+ },
+ "id": "BMNQND-ilvD0",
+ "outputId": "b6ebfc90-1a0e-4a4b-c494-2f9aa0a56a2d"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "321b49438e534f3a881e10f7de9528bb",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Fetching 40 files: 0%| | 0/40 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Download the videos from datasets repo and unzip. Make sure you have enough free space before downloading and unzipping\n",
+ "videos = snapshot_download(repo_id=\"OpenGVLab/MVBench\", allow_patterns=\"*\", repo_type=\"dataset\")\n",
+ "for zip_file in os.listdir(f\"{videos}/video\"):\n",
+ "    if zip_file.endswith(\".zip\"):\n",
+ "        shutil.unpack_archive(f\"{videos}/video/{zip_file}\", f\"{videos}/videos_unzipped/\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true,
+ "id": "yD9StuiIlvD0"
+ },
+ "outputs": [],
+ "source": [
+ "# Load each config and save in a mapping\n",
+ "config2ds = {}\n",
+ "for config, path in config2path.items():\n",
+ "    ds = load_dataset(\"OpenGVLab/MVBench\", config, split=\"train\")\n",
+ "    ds = ds.map(collate_read_video, batched=False, fn_kwargs={\"path\": f\"{videos}/videos_unzipped/{path}\"})\n",
+ "    config2ds[config] = ds"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GRTrOYp6lvD0",
+ "outputId": "7c90aeec-491a-4f00-de39-43aa0d8d2546"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers\n",
+ "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
+ ]
+ }
+ ],
+ "source": [
+ "processor = AutoProcessor.from_pretrained(MODEL_ID)\n",
+ "processor.tokenizer.padding_side = \"right\" # during training, one always uses padding on the right"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rtxYp7h2lvD0"
+ },
+ "source": [
+ "## Custom Dataset Class\n",
+ "\n",
+ "In the next step, we'll define a custom dataset class and the necessary functions to prepare our data for fine-tuning the Video-LLaVA model. The VideoLlavaDataset class extends the [PyTorch Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) class to facilitate loading and processing MVBench. This class handles the conversion of dataset samples into the format required for training and evaluation by preparing a prompt and converting each video into an array.\n",
+ "\n",
+ "NOTE: Video-LLaVa accepts videos in one of the following formats:\n",
+ "- an array or tensor of shape: (batch-size, frames, channel, height, width) where batch-size is an optional dimension\n",
+ "- a list of arrays of shape: (frames, channel, height, width)\n",
+ "- a nested list of video frames, where each frame is an image\n",
+ "\n",
+ "\n",
+ "Next, we define collate functions to handle the batching of data during training and evaluation. These functions ensure that the input data is properly formatted and padded.\n",
+ "\n",
+ "It's only here that we're going to use the processor to turn the (video, target token sequence) into the format that the model expects (which is pixel_values, input_ids etc.). The reason we do that here is because it allows for dynamic padding of the batches: each batch contains ground truth sequences of varying lengths. By only using the processor here, we will pad the input_ids up to the largest sequence in the batch.\n",
+ "\n",
+ "We also limit the length of the text tokens (input_ids) to a maximum length due to memory constraints. Feel free to increase it if your target token sequences are longer (I'd recommend plotting the distribution of token lengths in your dataset to determine a suitable value).\n",
+ "\n",
+ "The formatting of the input_ids is super important: we need to respect a so-called [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating). As of now, Video-LLaVa does not yet support chat templates, so we manually write down the prompt in the correct format (which starts with USER and ends with ASSISTANT). You could also omit this and just train the model on (video, instruction) pairs without a text prompt.\n",
+ "\n",
+ "Labels are created for the model by simply copying the inputs to the LLM (input_ids), but with padding tokens replaced by the ignore index of the loss function. This ensures that the model doesn't need to learn to predict padding tokens (used to batch examples together).\n",
+ "\n",
+ "Why are the labels a copy of the model inputs, you may ask? The model will internally shift the labels one position to the right so that the model will learn to predict the next token. This can be seen here.\n",
+ "\n",
+ "The collate function for evaluation is different, since there we only need to feed the prompt to the model, as we'll use the `generate()` method to autoregressively generate a completion."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "un_H3-pPlvD0"
+ },
+ "outputs": [],
+ "source": [
+ "class VideoLlavaDataset(Dataset):\n",
+ "    \"\"\"\n",
+ "    PyTorch Dataset for VideoLlavaDataset. This class takes a HuggingFace Dataset as input.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(\n",
+ "        self,\n",
+ "        dataset: \"datasets.Dataset\",\n",
+ "    ):\n",
+ "        super().__init__()\n",
+ "        self.dataset = dataset\n",
+ "        self.id2choice = {0: \"A\", 1: \"B\", 2: \"C\", 3: \"D\"}\n",
+ "\n",
+ "    def __len__(self) -> int:\n",
+ "        return len(self.dataset)\n",
+ "\n",
+ "    def __getitem__(self, idx: int):\n",
+ "        sample = self.dataset[idx]\n",
+ "        # cast to np because ds.map() cast everything to lists, and the processor does not work with the list format\n",
+ "        clip = np.array(sample[\"clip\"])\n",
+ "\n",
+ "        question, candidates = sample[\"question\"], sample[\"candidates\"]\n",
+ "        answer = candidates.index(sample[\"answer\"])\n",
+ "        answer = self.id2choice[answer]\n",
+ "\n",
+ "        mult_choice = \"\"\n",
+ "        for i, choice in enumerate(candidates):\n",
+ "            mult_choice += f\"{self.id2choice[i]}. {choice}; \"\n",
+ "\n",
+ "        # Prepare a prompt template (can be changed depending on the dataset and use case); <video> marks where the clip is inserted\n",
+ "        prompt = f\"USER: <video>\\nAnswer the following multiple choice question based on the video. \" \\\n",
+ "                f\"Question: {question}\\n {mult_choice}\\n ASSISTANT: Answer: {answer}\"\n",
+ "\n",
+ "        return prompt, clip"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "AQ6rlxudlvD1"
+ },
+ "outputs": [],
+ "source": [
+ "def train_collate_fn(examples):\n",
+ "    videos = []\n",
+ "    texts = []\n",
+ "    texts, videos = list(zip(*examples))\n",
+ "\n",
+ "    batch = processor(text=texts, videos=videos, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors=\"pt\")\n",
+ "\n",
+ "    labels = batch[\"input_ids\"].clone()\n",
+ "    labels[labels == processor.tokenizer.pad_token_id] = -100\n",
+ "    batch[\"labels\"] = labels\n",
+ "\n",
+ "    input_ids = batch[\"input_ids\"]\n",
+ "    attention_mask = batch[\"attention_mask\"]\n",
+ "    pixel_values_videos = batch[\"pixel_values_videos\"]\n",
+ "    labels = batch[\"labels\"]\n",
+ "\n",
+ "    return input_ids, attention_mask, pixel_values_videos, labels\n",
+ "\n",
+ "\n",
+ "def eval_collate_fn(examples):\n",
+ "    # We only feed the prompt to the model\n",
+ "    texts, videos = list(zip(*examples))\n",
+ "    # Each text ends with \"Answer: X\": keep the letter X as the ground truth,\n",
+ "    # then drop it (and the preceding space) so that the model has to generate it\n",
+ "    answer_choice = [text[-1] for text in texts]\n",
+ "    texts = [text[:-2] for text in texts]\n",
+ "\n",
+ "    batch = processor(text=texts, videos=videos, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors=\"pt\")\n",
+ "\n",
+ "    input_ids = batch[\"input_ids\"]\n",
+ "    attention_mask = batch[\"attention_mask\"]\n",
+ "    pixel_values_videos = batch[\"pixel_values_videos\"]\n",
+ "\n",
+ "    return input_ids, attention_mask, pixel_values_videos, answer_choice"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0sH2oWArlvD1"
+ },
+ "source": [
+ "## Combining and Splitting the Dataset\n",
+ "We'll combine the per-task datasets, shuffle them, and then split them into training and test sets. This ensures that our model is trained on a diverse and representative sample of the data.\n",
+ "\n",
+ "The result is a DatasetDict, which is a dictionary containing 2 splits: one for training and one for testing. Each split keeps the original MVBench features (such as the question, the candidate answers and the correct answer) along with the `clip` array we decoded from the video.\n",
+ "\n",
+ "Let's check a training example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DenEK2IqlvD1"
+ },
+ "outputs": [],
+ "source": [
+ "datasets_combined = concatenate_datasets(list(config2ds.values()))\n",
+ "datasets_combined = datasets_combined.shuffle(seed=42)\n",
+ "dataset = datasets_combined.train_test_split(test_size=0.2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qZb9QggAlvD1",
+ "outputId": "7db0044a-aa9e-4d42-f3ff-10e995131831"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['video', 'question', 'end', 'candidates', 'accurate_end', 'accurate_start', 'start', 'answer', 'clip'],\n",
+ " num_rows: 1280\n",
+ " })\n",
+ " test: Dataset({\n",
+ " features: ['video', 'question', 'end', 'candidates', 'accurate_end', 'accurate_start', 'start', 'answer', 'clip'],\n",
+ " num_rows: 320\n",
+ " })\n",
+ "})"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "X_DDgpOulvD1"
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline\n",
+ "\n",
+ "from matplotlib import pyplot as plt\n",
+ "from matplotlib import animation\n",
+ "from IPython.display import HTML\n",
+ "\n",
+ "\n",
+ "example = dataset['train'][0]\n",
+ "clip = example[\"clip\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "aLtCWizwlvD1",
+ "outputId": "cefb5908-d399-4d7a-de9b-6d323f4a847a"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " Your browser does not support the video tag.\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# np array with shape (frames, height, width, channels)\n",
+ "video = np.array(clip)\n",
+ "\n",
+ "fig = plt.figure()\n",
+ "im = plt.imshow(video[0,:,:,:])\n",
+ "\n",
+ "plt.close() # this is required to not display the generated image\n",
+ "\n",
+ "def init():\n",
+ "    im.set_data(video[0,:,:,:])\n",
+ "\n",
+ "def animate(i):\n",
+ "    im.set_data(video[i,:,:,:])\n",
+ "    return im\n",
+ "\n",
+ "anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0],\n",
+ "                               interval=100)\n",
+ "HTML(anim.to_html5_video())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2rdNZDG5lvD1",
+ "outputId": "38039271-d891-432e-d221-5a36996e6df7"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "('Which object was eaten by the person?',\n",
+ " ['The refrigerator.', 'The medicine.', 'The picture.', 'The sandwich.'],\n",
+ " 'The medicine.')"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "example[\"question\"], example[\"candidates\"], example[\"answer\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lsc4WadilvD1"
+ },
+ "source": [
+ "And now we wrap it in the PyTorch Dataset class and print one example as a sanity check."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ZXPeOA34lvD1"
+ },
+ "outputs": [],
+ "source": [
+ "train_dataset = VideoLlavaDataset(dataset[\"train\"])\n",
+ "eval_dataset = VideoLlavaDataset(dataset[\"test\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "AeI7LaF4lvD1"
+ },
+ "outputs": [],
+ "source": [
+ "prompt, clip = train_dataset[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fzcocRe1lvD1",
+ "outputId": "b3948347-d146-432e-c8da-7a3ca7e83398"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'USER: <video>\\nAnswer the following multiple choice question based on the video. Question: Which object was eaten by the person?\\n A. The refrigerator.; B. The medicine.; C. The picture.; D. The sandwich.; \\n ASSISTANT: Answer: B'"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "prompt"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "pd4mrynzlvD2"
+ },
+ "source": [
+ "## Load model\n",
+ "Next, we're going to load the Video-LLaVa model from the hub. This is a model with about 7 billion trainable parameters (as it combines a LLaMa-7B language model with a relatively low-parameter vision encoder). Note that the model we load here has already undergone supervised fine-tuning (SFT) on the VideoChat instruction dataset, so we can benefit from that earlier fine-tuning.\n",
+ "\n",
+ "## Full fine-tuning, LoRa and Q-LoRa\n",
+ "As this model has 7 billion trainable parameters, fine-tuning it is going to have quite an impact on the amount of memory used. For reference, fine-tuning a model using the AdamW optimizer (which is often used to optimize neural networks) with mixed precision requires about 18 bytes of GPU RAM per parameter. So in this case, we would need 18 x 7 billion bytes = 126 GB of GPU RAM if we want to update all the parameters of the model! That's huge, and infeasible for most people.\n",
+ "\n",
+ "Luckily, some clever people came up with the LoRa method (LoRa is short for low-rank adaptation). It lets you freeze the existing weights and train only a couple of adapter layers on top of the base model. Hugging Face offers the separate [PEFT library](https://huggingface.co/docs/peft/main/en/index) for easy use of LoRa, along with other Parameter-Efficient Fine-Tuning methods (that's where the name PEFT comes from).\n",
+ "\n",
+ "Moreover, one can not only freeze the existing base model but also quantize it (which means shrinking down its size). A neural network's parameters are typically saved in either float32 (which means 32 bits or 4 bytes are used to store each parameter value) or float16 (which means 16 bits or 2 bytes, also called half precision). However, with some clever algorithms one can shrink each parameter to just 8 or 4 bits (half a byte!), without significant effect on final performance. Read all about it here: https://huggingface.co/blog/4bit-transformers-bitsandbytes.\n",
+ "\n",
+ "This means that we're going to shrink the size of the base Video-LLaVa model considerably using 4-bit quantization, and then only train a couple of adapter layers on top using LoRa (in float16). This idea of combining LoRa with quantization is called Q-LoRa and is the most memory-friendly version.\n",
+ "\n",
+ "Of course, if you have the memory available, feel free to use full fine-tuning or LoRa without quantization! In case of full fine-tuning, the code snippet below instantiates the model with Flash Attention which considerably speeds up computations.\n",
+ "\n",
+ "There exist many forms of quantization, here we leverage the [BitsAndBytes integration](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "referenced_widgets": [
+ "0b428184283c4dcd8f2fa10ff3adbf04"
+ ]
+ },
+ "id": "DQ0nTqbVlvD2",
+ "outputId": "044885b1-6d44-46bb-e2c5-16038a34fdb2"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0b428184283c4dcd8f2fa10ff3adbf04",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "## Load model\n",
+ "# Three options for training, from the lowest precision training to the highest precision training:\n",
+ "# QLoRA: the model is loaded with 4-bit quantization, which reduces memory usage while largely maintaining performance.\n",
+ "# Standard LoRA: the model is loaded in half precision and LoRA adapters are added on top.\n",
+ "# Full fine-tuning: no memory optimizations are applied. In that case Flash Attention is used to speed up training, if the hardware supports it.\n",
+ "\n",
+ "if USE_QLORA or USE_LORA:\n",
+ "    if USE_QLORA:\n",
+ "        bnb_config = BitsAndBytesConfig(\n",
+ "            load_in_4bit=True,\n",
+ "            bnb_4bit_quant_type=\"nf4\",\n",
+ "            bnb_4bit_compute_dtype=torch.float16,\n",
+ "        )\n",
+ "    model = VideoLlavaForConditionalGeneration.from_pretrained(\n",
+ "        MODEL_ID,\n",
+ "        torch_dtype=torch.float16,\n",
+ "        quantization_config=bnb_config if USE_QLORA else None,  # plain LoRA loads without quantization\n",
+ "        device_map=\"auto\",\n",
+ "    )\n",
+ "else:\n",
+ "    # for full fine-tuning, we can speed up the model using Flash Attention\n",
+ "    # only available on certain devices, see https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features\n",
+ "    model = VideoLlavaForConditionalGeneration.from_pretrained(\n",
+ "        MODEL_ID,\n",
+ "        torch_dtype=torch.float16,\n",
+ "        attn_implementation=\"flash_attention_2\",\n",
+ "        device_map=\"auto\",\n",
+ "    )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aNtOGpvplvD2"
+ },
+ "source": [
+ "## Apply PEFT\n",
+ "After loading the base model, we're going to add LoRa adapter layers. We're going to only train these adapter layers (the base model is kept frozen).\n",
+ "\n",
+ "What differs from model to model is the set of layers to which adapters are added (in PEFT this is called target_modules).\n",
+ "\n",
+ "We define a function that finds all linear layers (nn.Linear) in the model, skipping modules whose names contain the multimodal projector or vision model keywords, and use the resulting names as LoRA targets. The intent is to mostly adapt the language model part of Video-LLaVa for our use case. Note that the function returns only the last component of each module name and PEFT matches target modules by name suffix, so names such as `q_proj`, `k_proj` and `v_proj` also match the attention layers of the vision towers, as the printed model below shows."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "MYDW50LslvD2"
+ },
+ "outputs": [],
+ "source": [
+ "def find_all_linear_names(model):\n",
+ "    cls = torch.nn.Linear\n",
+ "    lora_module_names = set()\n",
+ "    multimodal_keywords = ['multi_modal_projector', 'vision_model']\n",
+ "    for name, module in model.named_modules():\n",
+ "        if any(mm_keyword in name for mm_keyword in multimodal_keywords):\n",
+ "            continue\n",
+ "        if isinstance(module, cls):\n",
+ "            names = name.split('.')\n",
+ "            lora_module_names.add(names[0] if len(names) == 1 else names[-1])\n",
+ "\n",
+ "    if 'lm_head' in lora_module_names: # needed for 16-bit\n",
+ "        lora_module_names.remove('lm_head')\n",
+ "    return list(lora_module_names)\n",
+ "\n",
+ "\n",
+ "lora_config = LoraConfig(\n",
+ "    r=8,\n",
+ "    lora_alpha=8,\n",
+ "    lora_dropout=0.1,\n",
+ "    target_modules=find_all_linear_names(model),\n",
+ "    init_lora_weights=\"gaussian\",\n",
+ ")\n",
+ "\n",
+ "model = prepare_model_for_kbit_training(model)\n",
+ "model = get_peft_model(model, lora_config)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true,
+ "id": "QPNYtC1QlvD2",
+ "outputId": "014f17a2-ccf2-41d1-c126-cc68cbf034e9"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "PeftModel(\n",
+ " (base_model): LoraModel(\n",
+ " (model): VideoLlavaForConditionalGeneration(\n",
+ " (video_tower): CLIPVisionModel(\n",
+ " (vision_model): CLIPVisionTransformer(\n",
+ " (embeddings): CLIPVisionEmbeddings(\n",
+ " (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)\n",
+ " (position_embedding): Embedding(257, 1024)\n",
+ " )\n",
+ " (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (encoder): CLIPEncoder(\n",
+ " (layers): ModuleList(\n",
+ " (0-23): 24 x CLIPEncoderLayer(\n",
+ " (self_attn): CLIPAttention(\n",
+ " (k_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=1024, out_features=1024, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=1024, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=1024, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (v_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=1024, out_features=1024, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=1024, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=1024, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (q_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=1024, out_features=1024, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=1024, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=1024, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (out_proj): Linear4bit(in_features=1024, out_features=1024, bias=True)\n",
+ " )\n",
+ " (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (mlp): CLIPMLP(\n",
+ " (activation_fn): QuickGELUActivation()\n",
+ " (fc1): Linear4bit(in_features=1024, out_features=4096, bias=True)\n",
+ " (fc2): Linear4bit(in_features=4096, out_features=1024, bias=True)\n",
+ " )\n",
+ " (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " )\n",
+ " (image_tower): CLIPVisionModel(\n",
+ " (vision_model): CLIPVisionTransformer(\n",
+ " (embeddings): CLIPVisionEmbeddings(\n",
+ " (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)\n",
+ " (position_embedding): Embedding(257, 1024)\n",
+ " )\n",
+ " (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (encoder): CLIPEncoder(\n",
+ " (layers): ModuleList(\n",
+ " (0-23): 24 x CLIPEncoderLayer(\n",
+ " (self_attn): CLIPAttention(\n",
+ " (k_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=1024, out_features=1024, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=1024, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=1024, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (v_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=1024, out_features=1024, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=1024, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=1024, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (q_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=1024, out_features=1024, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=1024, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=1024, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (out_proj): Linear4bit(in_features=1024, out_features=1024, bias=True)\n",
+ " )\n",
+ " (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (mlp): CLIPMLP(\n",
+ " (activation_fn): QuickGELUActivation()\n",
+ " (fc1): Linear4bit(in_features=1024, out_features=4096, bias=True)\n",
+ " (fc2): Linear4bit(in_features=4096, out_features=1024, bias=True)\n",
+ " )\n",
+ " (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " )\n",
+ " (multi_modal_projector): VideoLlavaMultiModalProjector(\n",
+ " (linear_1): Linear4bit(in_features=1024, out_features=4096, bias=True)\n",
+ " (act): GELUActivation()\n",
+ " (linear_2): Linear4bit(in_features=4096, out_features=4096, bias=True)\n",
+ " )\n",
+ " (language_model): LlamaForCausalLM(\n",
+ " (model): LlamaModel(\n",
+ " (embed_tokens): Embedding(32064, 4096, padding_idx=0)\n",
+ " (layers): ModuleList(\n",
+ " (0-31): 32 x LlamaDecoderLayer(\n",
+ " (self_attn): LlamaSdpaAttention(\n",
+ " (q_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=4096, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (k_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=4096, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (v_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=4096, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (o_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=4096, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (rotary_emb): LlamaRotaryEmbedding()\n",
+ " )\n",
+ " (mlp): LlamaMLP(\n",
+ " (gate_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=4096, out_features=11008, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=11008, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (up_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=4096, out_features=11008, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=11008, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (down_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=11008, out_features=4096, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=11008, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=4096, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (act_fn): SiLU()\n",
+ " )\n",
+ " (input_layernorm): LlamaRMSNorm()\n",
+ " (post_attention_layernorm): LlamaRMSNorm()\n",
+ " )\n",
+ " )\n",
+ " (norm): LlamaRMSNorm()\n",
+ " )\n",
+ " (lm_head): Linear(in_features=4096, out_features=32064, bias=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model"
+ ]
+ },
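+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a sanity check on the module tree printed above, the number of trainable LoRA parameters can be worked out by hand. The sketch below is a rough tally, assuming rank 8 (as the `lora_A`/`lora_B` shapes suggest), that only the layers printed as `lora.Linear4bit` carry adapters, and that the video tower has the same 24-layer layout as the image tower. The helper name `lora_params` is made up for illustration.\n",
+ "\n",
+ "```python\n",
+ "def lora_params(in_features, out_features, rank=8):\n",
+ "    # lora_A maps in_features -> rank, lora_B maps rank -> out_features (no biases)\n",
+ "    return in_features * rank + rank * out_features\n",
+ "\n",
+ "# CLIP towers: q_proj, k_proj and v_proj are adapted in each of the 24 layers\n",
+ "clip_tower = 24 * 3 * lora_params(1024, 1024)\n",
+ "# Assumption: the video tower and the image tower have the same layout\n",
+ "vision_total = 2 * clip_tower\n",
+ "\n",
+ "# LLaMA decoder: 4 attention projections and 3 MLP projections per layer, 32 layers\n",
+ "attention = 4 * lora_params(4096, 4096)\n",
+ "mlp = 2 * lora_params(4096, 11008) + lora_params(11008, 4096)\n",
+ "language_total = 32 * (attention + mlp)\n",
+ "\n",
+ "total = vision_total + language_total\n",
+ "print(f'{total:,} trainable LoRA parameters')  # 22,347,776\n",
+ "```\n",
+ "\n",
+ "Under these assumptions the tally comes to roughly 22.3 M parameters, a small fraction of the frozen 4-bit base model."
+ ]
+ },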
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zuv04h-DlvD2"
+ },
+ "source": [
+ "## Define PyTorch Lightning Module for Video-LLaVA\n",
+ "To streamline the training and evaluation of the Video-LLaVA model, we use [LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html), which abstracts away much of the boilerplate code and provides a structured framework for model training. In this section, we define the VideoLlavaModelPLModule, a custom PyTorch Lightning module that encapsulates the model, training loop, validation loop, and optimizer configuration.\n",
+ "\n",
+ "### VideoLlavaModelPLModule Class\n",
+ "\n",
+ "The VideoLlavaModelPLModule class inherits from LightningModule and includes methods for training, validation, and optimizer configuration. This setup ensures a clean and efficient training process.\n",
+ "\n",
+ "Basically, PyTorch Lightning will take care of all device placements (.to(device)) for us, as well as the backward pass, putting the model in training mode, etc.\n",
+ "\n",
+ "Notice the difference between a training step and an evaluation step:\n",
+ "\n",
+ "- a training step only consists of a forward pass, in which we compute the cross-entropy loss between the model's next-token predictions and the ground truth. The loss is computed for all tokens in parallel, a technique known as \"teacher forcing\". The backward pass is handled by PyTorch Lightning.\n",
+ "- an evaluation step consists of making the model autoregressively complete the prompt using the generate() method. After that, we compute an evaluation metric between the predicted sequences and the ground truth ones. This allows us to see how the model is improving over the course of training. The metric we use here is accuracy of answering the question.\n",
+ "\n",
+ "Besides that, we define the optimizer to use (AdamW is a good default choice) and the data loaders, which use the collate functions defined above to batch together items of the PyTorch datasets. Do note that AdamW is a pretty heavy optimizer in terms of memory requirements, but as we're training with QLoRA we only need to store optimizer states for the adapter layers. For full fine-tuning, one could take a look at more memory-friendly optimizers such as 8-bit Adam."
+ ]
+ },
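+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make the training objective concrete, here is a toy, pure-Python version of the teacher-forced loss. It is only an illustration: the real loss is computed on tensors inside the model's forward pass, and `teacher_forced_loss`, `step_probs` and the example values are hypothetical. The `-100` ignore index follows the PyTorch cross-entropy convention for masked label positions.\n",
+ "\n",
+ "```python\n",
+ "import math\n",
+ "\n",
+ "def teacher_forced_loss(step_probs, labels, ignore_index=-100):\n",
+ "    # step_probs[t] is the predicted distribution over the token at position t + 1,\n",
+ "    # computed from the ground-truth tokens up to position t (never from model samples)\n",
+ "    losses = []\n",
+ "    for t in range(len(labels) - 1):\n",
+ "        target = labels[t + 1]  # labels are shifted by one position\n",
+ "        if target == ignore_index:\n",
+ "            continue  # masked positions (e.g. padding) do not contribute\n",
+ "        losses.append(-math.log(step_probs[t][target]))\n",
+ "    return sum(losses) / len(losses)\n",
+ "\n",
+ "# Toy example: vocabulary of 4 tokens and a uniform model\n",
+ "labels = [0, 2, 1, -100]\n",
+ "uniform = [[0.25] * 4 for _ in labels]\n",
+ "loss = teacher_forced_loss(uniform, labels)\n",
+ "print(round(loss, 4))  # 1.3863, i.e. log(4)\n",
+ "```\n",
+ "\n",
+ "Because every position is conditioned on the ground-truth prefix, all positions can be scored in a single forward pass, unlike generation, which needs one pass per new token."
+ ]
+ },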
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WvNdvJ1ylvD2"
+ },
+ "outputs": [],
+ "source": [
+ "class VideoLlavaModelPLModule(L.LightningModule):\n",
+ "    def __init__(self, config, processor, model):\n",
+ "        super().__init__()\n",
+ "        self.config = config\n",
+ "        self.processor = processor\n",
+ "        self.model = model\n",
+ "\n",
+ "        self.batch_size = config.get(\"batch_size\")\n",
+ "\n",
+ "    def training_step(self, batch, batch_idx):\n",
+ "\n",
+ "        input_ids, attention_mask, pixel_values_videos, labels = batch\n",
+ "\n",
+ "        # forward pass with labels, so the model returns the cross-entropy loss\n",
+ "        outputs = self.model(\n",
+ "            input_ids=input_ids,\n",
+ "            attention_mask=attention_mask,\n",
+ "            pixel_values_videos=pixel_values_videos,\n",
+ "            labels=labels\n",
+ "        )\n",
+ "        loss = outputs.loss\n",
+ "\n",
+ "        self.log(\"train_loss\", loss)\n",
+ "\n",
+ "        return loss\n",
+ "\n",
+ "    def validation_step(self, batch, batch_idx, dataloader_idx=0):\n",
+ "\n",
+ "        input_ids, attention_mask, pixel_values_videos, answers = batch\n",
+ "\n",
+ "        # autoregressively generate token IDs\n",
+ "        generated_ids = self.model.generate(\n",
+ "            input_ids=input_ids,\n",
+ "            attention_mask=attention_mask,\n",
+ "            pixel_values_videos=pixel_values_videos,\n",
+ "            max_new_tokens=MAX_LENGTH,\n",
+ "            do_sample=False,\n",
+ "        )\n",
+ "        # turn them back into text, chopping off the prompt\n",
+ "        predictions = self.processor.batch_decode(generated_ids[:, input_ids.size(1):], skip_special_tokens=True)\n",
+ "\n",
+ "        # exact-match accuracy over the batch\n",
+ "        correct = 0\n",
+ "        for pred, answer in zip(predictions, answers):\n",
+ "            correct += (pred.strip().lower() == answer.lower())\n",
+ "        self.log(\"val_accuracy\", correct / len(answers))\n",
+ "\n",
+ "        return correct\n",
+ "\n",
+ "    def configure_optimizers(self):\n",
+ "        # you could also add a learning rate scheduler if you want\n",
+ "        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.get(\"lr\"))\n",
+ "\n",
+ "        return optimizer\n",
+ "\n",
+ "    def train_dataloader(self):\n",
+ "        return DataLoader(train_dataset, collate_fn=train_collate_fn, batch_size=self.batch_size, shuffle=True, num_workers=4)\n",
+ "\n",
+ "    def val_dataloader(self):\n",
+ "        return DataLoader(eval_dataset, collate_fn=eval_collate_fn, batch_size=self.batch_size, shuffle=False, num_workers=4)"
+ ]
+ },
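+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The accuracy computed in `validation_step` is a plain exact-match score. The standalone sketch below isolates that logic so it can be tried without a model; the function name and the sample strings are made up for illustration.\n",
+ "\n",
+ "```python\n",
+ "def exact_match_accuracy(predictions, answers):\n",
+ "    # Case-insensitive comparison after trimming whitespace from the prediction\n",
+ "    correct = sum(\n",
+ "        pred.strip().lower() == answer.lower()\n",
+ "        for pred, answer in zip(predictions, answers)\n",
+ "    )\n",
+ "    return correct / len(answers)\n",
+ "\n",
+ "preds = [' Answer: A', 'answer: b ', 'Answer: C']\n",
+ "golds = ['Answer: A', 'Answer: B', 'Answer: D']\n",
+ "score = exact_match_accuracy(preds, golds)\n",
+ "print(score)  # two of the three predictions match\n",
+ "```\n",
+ "\n",
+ "Exact match is strict: any extra generated token counts as a miss, so a more lenient parser may be worth considering for free-form outputs."
+ ]
+ },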
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "u6iWFbnQlvD2"
+ },
+ "source": [
+ "Let's instantiate it (based on a config dictionary which defines all hyperparameters for training).\n",
+ "\n",
+ "The batch size was determined based on the compute available.\n",
+ "\n",
+ "Do note that one can play around with the hyperparameters. I just use reasonable defaults here: 2 epochs, a learning rate of 1e-4 which I found in the original Idefics2 notebook (linked at the top of this notebook), gradient accumulation over 8 batches, and mixed precision for training (more memory friendly). One could extend this with things like gradient checkpointing.\n",
+ "\n",
+ "I recommend [this guide](https://huggingface.co/docs/transformers/v4.20.1/en/perf_train_gpu_one) which goes over all tips and tricks regarding maximizing fine-tuning performance on consumer hardware."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "K2I7vJaDlvD2"
+ },
+ "outputs": [],
+ "source": [
+ "config = {\"max_epochs\": 2,\n",
+ " # \"val_check_interval\": 0.2, # how many times we want to validate during an epoch\n",
+ " \"check_val_every_n_epoch\": 1,\n",
+ " \"gradient_clip_val\": 1.0,\n",
+ " \"accumulate_grad_batches\": 8,\n",
+ " \"lr\": 1e-4,\n",
+ " \"batch_size\": 1,\n",
+ " \"num_nodes\": 1,\n",
+ " \"warmup_steps\": 50,\n",
+ "}\n",
+ "\n",
+ "model_module = VideoLlavaModelPLModule(config, processor, model)\n",
+ "early_stop_callback = EarlyStopping(monitor=\"val_accuracy\", patience=3, verbose=False, mode=\"max\")"
+ ]
+ },
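+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "With a `batch_size` of 1 and `accumulate_grad_batches` of 8, the optimizer only steps once every 8 batches. A quick sketch of the arithmetic follows; the dataset size is a made-up number, not the size of the dataset used in this notebook.\n",
+ "\n",
+ "```python\n",
+ "import math\n",
+ "\n",
+ "batch_size = 1\n",
+ "accumulate_grad_batches = 8\n",
+ "num_devices = 1\n",
+ "\n",
+ "# Samples that contribute to each optimizer step\n",
+ "effective_batch_size = batch_size * accumulate_grad_batches * num_devices\n",
+ "\n",
+ "# Hypothetical dataset size, for illustration only\n",
+ "num_train_samples = 1000\n",
+ "batches_per_epoch = math.ceil(num_train_samples / batch_size)\n",
+ "optimizer_steps_per_epoch = math.ceil(batches_per_epoch / accumulate_grad_batches)\n",
+ "\n",
+ "print(effective_batch_size, optimizer_steps_per_epoch)  # 8 125\n",
+ "```\n",
+ "\n",
+ "Gradient accumulation trades wall-clock time for memory: the gradients match a batch of 8, while only one sample has to fit on the GPU at a time."
+ ]
+ },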
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "I-EncDAllvD2"
+ },
+ "source": [
+ "## Define callbacks\n",
+ "Optionally, Lightning lets you define so-called [callbacks](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html), which are arbitrary pieces of code that can be executed during training.\n",
+ "\n",
+ "Here I'm adding a PushToHubCallback which will push the model to the hub at the end of every epoch as well as at the end of training. Do note that you could of course also pass the private=True flag when pushing to the hub, if you wish to keep your model private. Hugging Face also offers the Enterprise Hub so that you can easily share models with your colleagues privately in a secure way.\n",
+ "\n",
+ "We'll also use the EarlyStopping callback of Lightning, which will automatically stop training once the evaluation metric (accuracy in our case) hasn't improved for 3 consecutive validation runs."
+ ]
+ },
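+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To build intuition for `patience` and `mode`, here is a simplified pure-Python model of early stopping. It is not Lightning's implementation (which also supports options such as `min_delta`), and the function name and accuracy values are hypothetical.\n",
+ "\n",
+ "```python\n",
+ "def stopping_check(history, patience=3, mode='max'):\n",
+ "    # Returns the index of the validation run at which training would stop, or None\n",
+ "    best = None\n",
+ "    bad_runs = 0\n",
+ "    for i, value in enumerate(history):\n",
+ "        improved = best is None or (value > best if mode == 'max' else value < best)\n",
+ "        if improved:\n",
+ "            best = value\n",
+ "            bad_runs = 0\n",
+ "        else:\n",
+ "            bad_runs += 1\n",
+ "            if bad_runs >= patience:\n",
+ "                return i\n",
+ "    return None\n",
+ "\n",
+ "accuracies = [0.40, 0.55, 0.54, 0.55, 0.53, 0.50]\n",
+ "print(stopping_check(accuracies, mode='max'))  # 4\n",
+ "print(stopping_check(accuracies, mode='min'))  # 3\n",
+ "```\n",
+ "\n",
+ "For a metric such as accuracy, where higher is better, `mode` should be `max`; with `min`, the run above would stop earlier even though accuracy had risen. Note also that in this simplified model a patience of 3 can never be reached when there are only two validation runs."
+ ]
+ },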
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "p2KRo-rclvD2"
+ },
+ "outputs": [],
+ "source": [
+ "from huggingface_hub import HfApi\n",
+ "\n",
+ "api = HfApi()\n",
+ "\n",
+ "class PushToHubCallback(Callback):\n",
+ "    def on_train_epoch_end(self, trainer, pl_module):\n",
+ "        print(f\"Pushing model to the hub, epoch {trainer.current_epoch}\")\n",
+ "        pl_module.model.push_to_hub(REPO_ID,\n",
+ "                                    commit_message=f\"Training in progress, epoch {trainer.current_epoch}\")\n",
+ "\n",
+ "    def on_train_end(self, trainer, pl_module):\n",
+ "        print(\"Pushing model to the hub after training\")\n",
+ "        pl_module.processor.push_to_hub(REPO_ID,\n",
+ "                                        commit_message=\"Training done\")\n",
+ "        pl_module.model.push_to_hub(REPO_ID,\n",
+ "                                    commit_message=\"Training done\")\n",
+ "\n",
+ "# monitor the metric logged in validation_step; accuracy should increase, hence mode=\"max\"\n",
+ "early_stop_callback = EarlyStopping(monitor=\"val_accuracy\", patience=3, verbose=False, mode=\"max\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "S-ylKSrMlvD2"
+ },
+ "source": [
+ "## Train!\n",
+ "Alright, we're set to start training!\n",
+ "\n",
+ "Do note that this Trainer class supports many more flags! See the [docs](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html#lightning.pytorch.trainer.trainer.Trainer)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "liZhW7NhlvD3",
+ "outputId": "16ab8773-d40f-48f7-e485-b3960c2bea47"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using 16bit Automatic Mixed Precision (AMP)\n",
+ "GPU available: True (cuda), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "IPU available: False, using: 0 IPUs\n",
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ }
+ ],
+ "source": [
+ "trainer = L.Trainer(\n",
+ " default_root_dir=\"/raid/.cache/huggingface/video_llava_demo\",\n",
+ " accelerator=\"gpu\",\n",
+ " devices=[0],\n",
+ " max_epochs=config.get(\"max_epochs\"),\n",
+ " accumulate_grad_batches=config.get(\"accumulate_grad_batches\"),\n",
+ " check_val_every_n_epoch=config.get(\"check_val_every_n_epoch\"),\n",
+ " gradient_clip_val=config.get(\"gradient_clip_val\"),\n",
+ " precision=\"16-mixed\",\n",
+ " limit_val_batches=5,\n",
+ " num_sanity_val_steps=1,\n",
+ " callbacks=[early_stop_callback, PushToHubCallback()],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true,
+ "colab": {
+ "referenced_widgets": [
+ "",
+ "9c99e8a134844fc6bc734868514d92fc"
+ ]
+ },
+ "id": "X1XGM6QflvD3",
+ "outputId": "113a2fc3-d7b3-4140-b614-1b46e14b411e"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
+ "Missing logger folder: /raid/.cache/huggingface/video_llava_demo/lightning_logs\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "------------------------------------\n",
+ "0 | model | PeftModel | 3.8 B \n",
+ "------------------------------------\n",
+ "22.3 M Trainable params\n",
+ "3.8 B Non-trainable params\n",
+ "3.8 B Total params\n",
+ "15,352.594Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Sanity Checking: | …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/raushan/transformers/src/transformers/feature_extraction_utils.py:142: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:275.)\n",
+ " return torch.tensor(value)\n",
+ "/home/raushan/env0/lib/python3.8/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "9c99e8a134844fc6bc734868514d92fc",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Training: | …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/raushan/transformers/src/transformers/feature_extraction_utils.py:142: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:275.)\n",
+ " return torch.tensor(value)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/raushan/transformers/src/transformers/feature_extraction_utils.py:142: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:275.)\n",
+ " return torch.tensor(value)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/raushan/transformers/src/transformers/feature_extraction_utils.py:142: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:275.)\n",
+ " return torch.tensor(value)\n",
+ "`Trainer.fit` stopped: `max_epochs=2` reached.\n"
+ ]
+ }
+ ],
+ "source": [
+ "trainer.fit(model_module)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "z6RpsZBClvD6"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hQPiR5IV89qs"
+ },
+ "source": [
+ "## Inference\n",
+ "\n",
+ "Let's see if the model has learned something. We'll load the model from the hub first. Notice that, as we only trained adapters on top of the base model, the repository on the hub to which we pushed only contains the weights and configuration of the adapters. This is a very lightweight file smaller than 100 MB.\n",
+ "\n",
+ "Thanks to the PEFT integration in Transformers, the `from_pretrained` method will automatically load the weights of the base model as well as the adapter weights.\n",
+ "\n",
+ "To reduce memory usage and inference costs, we'll again load the model in 4 bits by passing a `quantization_config`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jc0JHzIY89qs",
+ "outputId": "5845f9fb-0296-47b3-ad21-9dec3fb1093e",
+ "colab": {
+ "referenced_widgets": [
+ "245f0d12878c4efc992829d9b3c78288",
+ "13b98dae2e714ac3ac1df6963ebca7a2",
+ "b4b5698952c2401dabc5924e1cfa1d6d"
+ ]
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "245f0d12878c4efc992829d9b3c78288",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "adapter_config.json: 0%| | 0.00/840 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "13b98dae2e714ac3ac1df6963ebca7a2",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b4b5698952c2401dabc5924e1cfa1d6d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "adapter_model.safetensors: 0%| | 0.00/89.5M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from transformers import AutoProcessor, BitsAndBytesConfig, VideoLlavaForConditionalGeneration\n",
+ "import torch\n",
+ "\n",
+ "processor = AutoProcessor.from_pretrained(MODEL_ID)\n",
+ "\n",
+ "# Define quantization config\n",
+ "quantization_config = BitsAndBytesConfig(\n",
+ " load_in_4bit=True,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.float16,\n",
+ ")\n",
+ "\n",
+ "# Load the base model with adapters on top\n",
+ "model = VideoLlavaForConditionalGeneration.from_pretrained(\n",
+ " REPO_ID,\n",
+ " torch_dtype=torch.float16,\n",
+ " quantization_config=quantization_config,\n",
+ " device_map=\"auto\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5Lqxte3U89qs"
+ },
+ "source": [
+ "Now we're ready to perform inference. We'll take one example from the validation set here and plot 8 frames to see what is happening in the video."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "BNw1jE0n89qw",
+ "outputId": "f93b5f96-8565-4c78-b802-74195ae63d9c"
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "from matplotlib import pyplot as plt\n",
+ "from PIL import Image\n",
+ "\n",
+ "prompt, clip = eval_dataset[1]\n",
+ "fig, axarr = plt.subplots(2, 4, figsize = (10, 10))\n",
+ "fig.tight_layout()\n",
+ "\n",
+ "for i in range(2):\n",
+ "    for j in range(4):\n",
+ "        # row-major index so the 2x4 grid shows frames 0..7 in order\n",
+ "        curr_frame = Image.fromarray(np.uint8(clip[i * 4 + j]))\n",
+ "        axarr[i, j].imshow(curr_frame)\n",
+ "        axarr[i, j].get_xaxis().set_visible(False)\n",
+ "        axarr[i, j].get_yaxis().set_visible(False)\n",
+ "        axarr[i, j].set_aspect('equal')\n",
+ "\n",
+ "plt.subplots_adjust(wspace=None, hspace=None)\n",
+ "plt.axis('off')\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xWWHU1xB89qx"
+ },
+ "source": [
+ "Next we need to prepare the video for the model, along with the text prompt we used during training. We need to apply the chat template to make sure the format is respected.\n",
+ "\n",
+ "Notice that this is exactly the same as what we did for the evaluation data collate function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "BBWCEdiT89qx",
+ "outputId": "3d36269e-c9b7-43e2-85ad-1500c62f146d"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "pixel_values_videos torch.Size([1, 8, 3, 224, 224])\n",
+ "input_ids torch.Size([1, 94])\n",
+ "attention_mask torch.Size([1, 94])\n"
+ ]
+ }
+ ],
+ "source": [
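+ "# The last character of the prompt is the ground-truth answer letter\n",
+ "answer = prompt[-1]\n",
+ "# Strip the answer (and the character before it) so the model has to generate it\n",
+ "prompt = prompt[:-2]\n",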
+ "\n",
+ "inputs = processor(text=prompt, videos=clip, return_tensors=\"pt\").to(model.device)\n",
+ "for k, v in inputs.items():\n",
+ "    print(k, v.shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "w2Brzvy689qx"
+ },
+ "source": [
+ "Next we let the model autoregressively generate tokens using the [generate()](https://huggingface.co/docs/transformers/v4.40.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) method, which is recommended for use at inference time. This method feeds each predicted token back into the model as conditioning for the next time step.\n",
+ "\n",
+ "Note that there are various ways of decoding text. Here we use greedy decoding, which is the default. There are also fancier methods such as beam search and top-k sampling. Refer to [this amazing blog post](https://huggingface.co/blog/how-to-generate) for all the details."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xFSVXwN889qx"
+ },
+ "outputs": [],
+ "source": [
+ "# Generate token IDs\n",
+ "generated_ids = model.generate(**inputs, max_new_tokens=MAX_LENGTH)\n",
+ "\n",
+ "# Decode back into text\n",
+ "generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Bn5yop6589qx",
+ "outputId": "537adbf4-56b6-4513-dbb5-58ba0adaa898"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['USER: \\nAnswer the following multiple choice question based on the video. Question: What will happen if the sphere is removed?\\n A. The blue cylinder collides with the cyan cylinder; B. The blue cylinder collides with the yellow cylinder; C. The green object collides with the yellow object; D. The blue cylinder collides with the green object; \\n ASSISTANT: Answer: A']\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(generated_texts)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ELwUILkk89qx",
+ "outputId": "e457c0df-f6d6-4193-e2bc-b91410c6179a"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "A\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(answer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "bKS_QcmX89qx"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file