From a46a813a811824d1430cc8e92efb6d9092bd605a Mon Sep 17 00:00:00 2001 From: dhingratul Date: Fri, 14 Jul 2023 15:49:18 -0700 Subject: [PATCH] PeFT training --- notebooks/L2_PeFT.ipynb | 2864 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 2864 insertions(+) create mode 100644 notebooks/L2_PeFT.ipynb diff --git a/notebooks/L2_PeFT.ipynb b/notebooks/L2_PeFT.ipynb new file mode 100644 index 0000000..7b65e87 --- /dev/null +++ b/notebooks/L2_PeFT.ipynb @@ -0,0 +1,2864 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "# Fine-Tune a Generative AI Model for Dialogue Summarization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook, you will fine-tune an existing LLM from Hugging Face for enhanced dialogue summarization. You will use the [FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model, which provides a high quality instruction tuned model and can summarize text out of the box. To improve the inferences, you will explore a full fine-tuning approach and evaluate the results with ROUGE metrics. Then you will perform Parameter Efficient Fine-Tuning (PEFT), evaluate the resulting model and see that the benefits of PEFT outweigh the slightly-lower performance metrics." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Table of Contents" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "- [ 1 - Set up Kernel, Load Required Dependencies, Dataset and LLM](#1)\n", + " - [ 1.1 - Set up Kernel and Required Dependencies](#1.1)\n", + " - [ 1.2 - Load Dataset and LLM](#1.2)\n", + " - [ 1.3 - Test the Model with Zero Shot Inferencing](#1.3)\n", + "- [ 2 - Perform Full Fine-Tuning](#2)\n", + " - [ 2.1 - Preprocess the Dialog-Summary Dataset](#2.1)\n", + " - [ 2.2 - Fine-Tune the Model with the Preprocessed Dataset](#2.2)\n", + " - [ 2.3 - Evaluate the Model Qualitatively (Human Evaluation)](#2.3)\n", + " - [ 2.4 - Evaluate the Model Quantitatively (with ROUGE Metric)](#2.4)\n", + "- [ 3 - Perform Parameter Efficient Fine-Tuning (PEFT)](#3)\n", + " - [ 3.1 - Setup the PEFT/LoRA model for Fine-Tuning](#3.1)\n", + " - [ 3.2 - Train PEFT Adapter](#3.2)\n", + " - [ 3.3 - Evaluate the Model Qualitatively (Human Evaluation)](#3.3)\n", + " - [ 3.4 - Evaluate the Model Quantitatively (with ROUGE Metric)](#3.4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## 1 - Set up Kernel, Load Required Dependencies, Dataset and LLM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### 1.1 - Set up Kernel and Required Dependencies" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "To begin with, check that the kernel is selected correctly.\n", + "\n", + "\n", + "\n", + "If you click on that (top right of the screen), you'll be able to see and check the details of the image, kernel, and instance type.\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "Now install the required packages for the LLM and datasets.\n", + "\n", + "\"Time" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: pip in /opt/conda/lib/python3.7/site-packages (23.1.2)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions 
and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "pytest-astropy 0.8.0 requires pytest-cov>=2.0, which is not installed.\n", + "pytest-astropy 0.8.0 requires pytest-filter-subpackage>=0.1, which is not installed.\n", + "spyder 4.0.1 requires pyqt5<5.13; python_version >= \"3\", which is not installed.\n", + "spyder 4.0.1 requires pyqtwebengine<5.13; python_version >= \"3\", which is not installed.\n", + "sagemaker 2.165.0 requires importlib-metadata<5.0,>=1.4.0, but you have importlib-metadata 6.6.0 which is incompatible.\n", + "sparkmagic 0.20.4 requires nest-asyncio==1.5.5, but you have nest-asyncio 1.5.6 which is incompatible.\n", + "spyder 4.0.1 requires jedi==0.14.1, but you have jedi 0.18.2 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install --upgrade pip\n", + "%pip install --disable-pip-version-check \\\n", + " torch==1.13.1 \\\n", + " torchdata==0.5.1 --quiet\n", + "\n", + "%pip install \\\n", + " transformers==4.27.2 \\\n", + " datasets==2.11.0 \\\n", + " evaluate==0.4.0 \\\n", + " rouge_score==0.1.2 \\\n", + " loralib==0.1.1 \\\n", + " peft==0.3.0 --quiet" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "\"Time" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "Import the necessary components. Some of them are new for this week, they will be discussed later in the notebook. " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer\n", + "import torch\n", + "import time\n", + "import evaluate\n", + "import pandas as pd\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "\n", + "### 1.2 - Load Dataset and LLM\n", + "\n", + "You are going to continue experimenting with the [DialogSum](https://huggingface.co/datasets/knkarthick/dialogsum) Hugging Face dataset. It contains 10,000+ dialogues with the corresponding manually labeled summaries and topics. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2a05a19a9bdf4e989ca08528157e3110", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading readme: 0%| | 0.00/4.56k [00:00\n", + "### 1.3 - Test the Model with Zero Shot Inferencing\n", + "\n", + "Test the model with the zero shot inferencing. You can see that the model struggles to summarize the dialogue compared to the baseline summary, but it does pull out some important information from the text which indicates the model can be fine-tuned to the task at hand." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------------------------------------------------------------------------------------------------\n", + "INPUT PROMPT:\n", + "\n", + "Summarize the following conversation.\n", + "\n", + "#Person1#: Have you considered upgrading your system?\n", + "#Person2#: Yes, but I'm not sure what exactly I would need.\n", + "#Person1#: You could consider adding a painting program to your software. It would allow you to make up your own flyers and banners for advertising.\n", + "#Person2#: That would be a definite bonus.\n", + "#Person1#: You might also want to upgrade your hardware because it is pretty outdated now.\n", + "#Person2#: How can we do that?\n", + "#Person1#: You'd probably need a faster processor, to begin with. And you also need a more powerful hard disc, more memory and a faster modem. Do you have a CD-ROM drive?\n", + "#Person2#: No.\n", + "#Person1#: Then you might want to add a CD-ROM drive too, because most new software programs are coming out on Cds.\n", + "#Person2#: That sounds great. 
Thanks.\n", + "\n", + "Summary:\n", + "\n", + "---------------------------------------------------------------------------------------------------\n", + "BASELINE HUMAN SUMMARY:\n", + "#Person1# teaches #Person2# how to upgrade software and hardware in #Person2#'s system.\n", + "\n", + "---------------------------------------------------------------------------------------------------\n", + "MODEL GENERATION - ZERO SHOT:\n", + "#Person1#: I'm thinking of upgrading my computer.\n" + ] + } + ], + "source": [ + "index = 200\n", + "\n", + "dialogue = dataset['test'][index]['dialogue']\n", + "summary = dataset['test'][index]['summary']\n", + "\n", + "prompt = f\"\"\"\n", + "Summarize the following conversation.\n", + "\n", + "{dialogue}\n", + "\n", + "Summary:\n", + "\"\"\"\n", + "\n", + "inputs = tokenizer(prompt, return_tensors='pt')\n", + "output = tokenizer.decode(\n", + " original_model.generate(\n", + " inputs[\"input_ids\"], \n", + " max_new_tokens=200,\n", + " )[0], \n", + " skip_special_tokens=True\n", + ")\n", + "\n", + "dash_line = '-'.join('' for x in range(100))\n", + "print(dash_line)\n", + "print(f'INPUT PROMPT:\\n{prompt}')\n", + "print(dash_line)\n", + "print(f'BASELINE HUMAN SUMMARY:\\n{summary}\\n')\n", + "print(dash_line)\n", + "print(f'MODEL GENERATION - ZERO SHOT:\\n{output}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "\n", + "## 2 - Perform Full Fine-Tuning" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "\n", + "### 2.1 - Preprocess the Dialog-Summary Dataset\n", + "\n", + "You need to convert the dialog-summary (prompt-response) pairs into explicit instructions for the LLM. Prepend an instruction to the start of the dialog with `Summarize the following conversation` and to the start of the summary with `Summary` as follows:\n", + "\n", + "Training prompt (dialogue):\n", + "```\n", + "Summarize the following conversation.\n", + "\n", + " Chris: This is his part of the conversation.\n", + " Antje: This is her part of the conversation.\n", + " \n", + "Summary: \n", + "```\n", + "\n", + "Training response (summary):\n", + "```\n", + "Both Chris and Antje participated in the conversation.\n", + "```\n", + "\n", + "Then preprocess the prompt-response dataset into tokens and pull out their `input_ids` (1 per token)." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map: 0%| | 0/12460 [00:00\n", + "### 2.2 - Fine-Tune the Model with the Preprocessed Dataset\n", + "\n", + "Now utilize the built-in Hugging Face `Trainer` class (see the documentation [here](https://huggingface.co/docs/transformers/main_classes/trainer)). Pass the preprocessed dataset with reference to the original model. Other training parameters are found experimentally and there is no need to go into details about those at the moment." 
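+    ,
+    "\n",
+    "\n",
+    "Note that the training cell below caps the run at a single step (`max_steps=1`) so it finishes quickly in this environment. A full fine-tuning run would drop that cap and train for one or more epochs, roughly along these lines (an illustrative sketch only -- the values are assumptions, not the settings used to produce the checkpoint downloaded later in this notebook):\n",
+    "\n",
+    "```python\n",
+    "# Sketch of a full fine-tuning configuration (assumed values, not run in this lab).\n",
+    "full_training_args = TrainingArguments(\n",
+    "    output_dir=output_dir,\n",
+    "    learning_rate=1e-5,\n",
+    "    num_train_epochs=3,    # assumption: a few full passes over the training set\n",
+    "    weight_decay=0.01,\n",
+    "    logging_steps=100,     # log less frequently on a long run\n",
+    ")\n",
+    "```"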
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "output_dir = f'./dialogue-summary-training-{str(int(time.time()))}'\n", + "\n", + "training_args = TrainingArguments(\n", + " output_dir=output_dir,\n", + " learning_rate=1e-5,\n", + " num_train_epochs=1,\n", + " weight_decay=0.01,\n", + " logging_steps=1,\n", + " max_steps=1\n", + ")\n", + "\n", + "trainer = Trainer(\n", + " model=original_model,\n", + " args=training_args,\n", + " train_dataset=tokenized_datasets['train'],\n", + " eval_dataset=tokenized_datasets['validation']\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "Start training process...\n", + "\n", + "\"Time" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.7/site-packages/transformers/optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", + " FutureWarning,\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [1/1 00:00, Epoch 0/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
149.250000

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "TrainOutput(global_step=1, training_loss=49.25, metrics={'train_runtime': 73.1666, 'train_samples_per_second': 0.109, 'train_steps_per_second': 0.014, 'total_flos': 5478058819584.0, 'train_loss': 49.25, 'epoch': 0.06})" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "\"Time" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Training a fully fine-tuned version of the model would take a few hours on a GPU. To save time, download a checkpoint of the fully fine-tuned model to use in the rest of this notebook. This fully fine-tuned model will also be referred to as the **instruct model** in this lab." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "download: s3://dlai-generative-ai/models/flan-dialogue-summary-checkpoint/generation_config.json to flan-dialogue-summary-checkpoint/generation_config.json\n", + "download: s3://dlai-generative-ai/models/flan-dialogue-summary-checkpoint/scheduler.pt to flan-dialogue-summary-checkpoint/scheduler.pt\n", + "download: s3://dlai-generative-ai/models/flan-dialogue-summary-checkpoint/config.json to flan-dialogue-summary-checkpoint/config.json\n", + "download: s3://dlai-generative-ai/models/flan-dialogue-summary-checkpoint/rng_state.pth to flan-dialogue-summary-checkpoint/rng_state.pth\n", + "download: s3://dlai-generative-ai/models/flan-dialogue-summary-checkpoint/trainer_state.json to flan-dialogue-summary-checkpoint/trainer_state.json\n", + "download: s3://dlai-generative-ai/models/flan-dialogue-summary-checkpoint/training_args.bin to flan-dialogue-summary-checkpoint/training_args.bin\n", + "download: s3://dlai-generative-ai/models/flan-dialogue-summary-checkpoint/pytorch_model.bin to flan-dialogue-summary-checkpoint/pytorch_model.bin\n", + "download: s3://dlai-generative-ai/models/flan-dialogue-summary-checkpoint/optimizer.pt to flan-dialogue-summary-checkpoint/optimizer.pt\n" + ] + } + ], + "source": [ + "!aws s3 cp --recursive s3://dlai-generative-ai/models/flan-dialogue-summary-checkpoint/ ./flan-dialogue-summary-checkpoint/" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "The size of the downloaded instruct model is approximately 1GB." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-rw-r--r-- 1 root root 945M May 15 10:25 ./flan-dialogue-summary-checkpoint/pytorch_model.bin\n" + ] + } + ], + "source": [ + "!ls -alh ./flan-dialogue-summary-checkpoint/pytorch_model.bin" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "Create an instance of the `AutoModelForSeq2SeqLM` class for the instruct model:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "instruct_model = AutoModelForSeq2SeqLM.from_pretrained(\"./flan-dialogue-summary-checkpoint\", torch_dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### 2.3 - Evaluate the Model Qualitatively (Human Evaluation)\n", + "\n", + "As with many GenAI applications, a qualitative approach where you ask yourself the question \"Is my model behaving the way it is supposed to?\" is usually a good starting point. In the example below (the same one we started this notebook with), you can see how the fine-tuned model is able to create a reasonable summary of the dialogue compared to the original inability to understand what is being asked of the model." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------------------------------------------------------------------------------------------------\n", + "BASELINE HUMAN SUMMARY:\n", + "#Person1# teaches #Person2# how to upgrade software and hardware in #Person2#'s system.\n", + "---------------------------------------------------------------------------------------------------\n", + "ORIGINAL MODEL:\n", + "#Person1#: You'd like to upgrade your computer. #Person2: You'd like to upgrade your computer.\n", + "---------------------------------------------------------------------------------------------------\n", + "INSTRUCT MODEL:\n", + "#Person1# suggests #Person2# upgrading #Person2#'s system, hardware, and CD-ROM drive. 
#Person2# thinks it's great.\n" + ] + } + ], + "source": [ + "index = 200\n", + "dialogue = dataset['test'][index]['dialogue']\n", + "human_baseline_summary = dataset['test'][index]['summary']\n", + "\n", + "prompt = f\"\"\"\n", + "Summarize the following conversation.\n", + "\n", + "{dialogue}\n", + "\n", + "Summary:\n", + "\"\"\"\n", + "\n", + "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", + "\n", + "original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))\n", + "original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)\n", + "\n", + "instruct_model_outputs = instruct_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))\n", + "instruct_model_text_output = tokenizer.decode(instruct_model_outputs[0], skip_special_tokens=True)\n", + "\n", + "print(dash_line)\n", + "print(f'BASELINE HUMAN SUMMARY:\\n{human_baseline_summary}')\n", + "print(dash_line)\n", + "print(f'ORIGINAL MODEL:\\n{original_model_text_output}')\n", + "print(dash_line)\n", + "print(f'INSTRUCT MODEL:\\n{instruct_model_text_output}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### 2.4 - Evaluate the Model Quantitatively (with ROUGE Metric)\n", + "\n", + "The [ROUGE metric](https://en.wikipedia.org/wiki/ROUGE_(metric)) helps quantify the validity of summarizations produced by models. It compares summarizations to a \"baseline\" summary which is usually created by a human. While not perfect, it does indicate the overall increase in summarization effectiveness that we have accomplished by fine-tuning." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "64753dda710b47b98c1a5c9e94caf476", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading builder script: 0%| | 0.00/6.27k [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
human_baseline_summariesoriginal_model_summariesinstruct_model_summaries
0Ms. Dawson helps #Person1# to write a memo to ...#Person1#: Thank you for your time.#Person1# asks Ms. Dawson to take a dictation ...
1In order to prevent employees from wasting tim...This memo should go out as an intra-office mem...#Person1# asks Ms. Dawson to take a dictation ...
2Ms. Dawson takes a dictation for #Person1# abo...Employees who use the Instant Messaging progra...#Person1# asks Ms. Dawson to take a dictation ...
3#Person2# arrives late because of traffic jam....#Person1: I'm sorry you're stuck in traffic. #...#Person2# got stuck in traffic again. #Person1...
4#Person2# decides to follow #Person1#'s sugges...#Person1#: I'm finally here. I've got a traffi...#Person2# got stuck in traffic again. #Person1...
5#Person2# complains to #Person1# about the tra...The driver of the car is stuck in a traffic jam.#Person2# got stuck in traffic again. #Person1...
6#Person1# tells Kate that Masha and Hero get d...Masha and Hero are getting divorced.Masha and Hero are getting divorced. Kate can'...
7#Person1# tells Kate that Masha and Hero are g...Masha and Hero are getting married.Masha and Hero are getting divorced. Kate can'...
8#Person1# and Kate talk about the divorce betw...Masha and Hero are getting divorced.Masha and Hero are getting divorced. Kate can'...
9#Person1# and Brian are at the birthday party ...#Person1#: Happy birthday, Brian. #Person2#: H...Brian's birthday is coming. #Person1# invites ...
\n", + "" + ], + "text/plain": [ + " human_baseline_summaries \\\n", + "0 Ms. Dawson helps #Person1# to write a memo to ... \n", + "1 In order to prevent employees from wasting tim... \n", + "2 Ms. Dawson takes a dictation for #Person1# abo... \n", + "3 #Person2# arrives late because of traffic jam.... \n", + "4 #Person2# decides to follow #Person1#'s sugges... \n", + "5 #Person2# complains to #Person1# about the tra... \n", + "6 #Person1# tells Kate that Masha and Hero get d... \n", + "7 #Person1# tells Kate that Masha and Hero are g... \n", + "8 #Person1# and Kate talk about the divorce betw... \n", + "9 #Person1# and Brian are at the birthday party ... \n", + "\n", + " original_model_summaries \\\n", + "0 #Person1#: Thank you for your time. \n", + "1 This memo should go out as an intra-office mem... \n", + "2 Employees who use the Instant Messaging progra... \n", + "3 #Person1: I'm sorry you're stuck in traffic. #... \n", + "4 #Person1#: I'm finally here. I've got a traffi... \n", + "5 The driver of the car is stuck in a traffic jam. \n", + "6 Masha and Hero are getting divorced. \n", + "7 Masha and Hero are getting married. \n", + "8 Masha and Hero are getting divorced. \n", + "9 #Person1#: Happy birthday, Brian. #Person2#: H... \n", + "\n", + " instruct_model_summaries \n", + "0 #Person1# asks Ms. Dawson to take a dictation ... \n", + "1 #Person1# asks Ms. Dawson to take a dictation ... \n", + "2 #Person1# asks Ms. Dawson to take a dictation ... \n", + "3 #Person2# got stuck in traffic again. #Person1... \n", + "4 #Person2# got stuck in traffic again. #Person1... \n", + "5 #Person2# got stuck in traffic again. #Person1... \n", + "6 Masha and Hero are getting divorced. Kate can'... \n", + "7 Masha and Hero are getting divorced. Kate can'... \n", + "8 Masha and Hero are getting divorced. Kate can'... \n", + "9 Brian's birthday is coming. #Person1# invites ... " + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dialogues = dataset['test'][0:10]['dialogue']\n", + "human_baseline_summaries = dataset['test'][0:10]['summary']\n", + "\n", + "original_model_summaries = []\n", + "instruct_model_summaries = []\n", + "\n", + "for _, dialogue in enumerate(dialogues):\n", + " prompt = f\"\"\"\n", + "Summarize the following conversation.\n", + "\n", + "{dialogue}\n", + "\n", + "Summary: \"\"\"\n", + " input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", + "\n", + " original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))\n", + " original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)\n", + " original_model_summaries.append(original_model_text_output)\n", + "\n", + " instruct_model_outputs = instruct_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))\n", + " instruct_model_text_output = tokenizer.decode(instruct_model_outputs[0], skip_special_tokens=True)\n", + " instruct_model_summaries.append(instruct_model_text_output)\n", + " \n", + "zipped_summaries = list(zip(human_baseline_summaries, original_model_summaries, instruct_model_summaries))\n", + " \n", + "df = pd.DataFrame(zipped_summaries, columns = ['human_baseline_summaries', 'original_model_summaries', 'instruct_model_summaries'])\n", + "df" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "Evaluate the models computing ROUGE metrics. Notice the improvement in the results!" 
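+    ,
+    "\n",
+    "\n",
+    "If you want a quick intuition for the numbers first, here is a toy call (illustrative only, using the `rouge` metric object loaded earlier): ROUGE-1 counts overlapping unigrams, so two nearly identical sentences score close to 1.\n",
+    "\n",
+    "```python\n",
+    "# Toy example only -- not part of the lab evaluation.\n",
+    "toy_scores = rouge.compute(\n",
+    "    predictions=['the cat sat on the mat'],\n",
+    "    references=['the cat lay on the mat'],\n",
+    ")\n",
+    "print(toy_scores['rouge1'])  # high, because 5 of the 6 unigrams match\n",
+    "```"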
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ORIGINAL MODEL:\n", + "{'rouge1': 0.24223171760013867, 'rouge2': 0.10614243734192583, 'rougeL': 0.21380459196706333, 'rougeLsum': 0.21740921541379205}\n", + "INSTRUCT MODEL:\n", + "{'rouge1': 0.41026607717457186, 'rouge2': 0.17840645241958838, 'rougeL': 0.2977022096267017, 'rougeLsum': 0.2987374187518165}\n" + ] + } + ], + "source": [ + "original_model_results = rouge.compute(\n", + " predictions=original_model_summaries,\n", + " references=human_baseline_summaries[0:len(original_model_summaries)],\n", + " use_aggregator=True,\n", + " use_stemmer=True,\n", + ")\n", + "\n", + "instruct_model_results = rouge.compute(\n", + " predictions=instruct_model_summaries,\n", + " references=human_baseline_summaries[0:len(instruct_model_summaries)],\n", + " use_aggregator=True,\n", + " use_stemmer=True,\n", + ")\n", + "\n", + "print('ORIGINAL MODEL:')\n", + "print(original_model_results)\n", + "print('INSTRUCT MODEL:')\n", + "print(instruct_model_results)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The file `data/dialogue-summary-training-results.csv` contains a pre-populated list of all model results which you can use to evaluate on a larger section of data. Let's do that for each of the models:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ORIGINAL MODEL:\n", + "{'rouge1': 0.2334158581572823, 'rouge2': 0.07603964187010573, 'rougeL': 0.20145520923859048, 'rougeLsum': 0.20145899339006135}\n", + "INSTRUCT MODEL:\n", + "{'rouge1': 0.42161291557556113, 'rouge2': 0.18035380596301792, 'rougeL': 0.3384439349963909, 'rougeLsum': 0.33835653595561666}\n" + ] + } + ], + "source": [ + "results = pd.read_csv(\"data/dialogue-summary-training-results.csv\")\n", + "\n", + "human_baseline_summaries = results['human_baseline_summaries'].values\n", + "original_model_summaries = results['original_model_summaries'].values\n", + "instruct_model_summaries = results['instruct_model_summaries'].values\n", + "\n", + "original_model_results = rouge.compute(\n", + " predictions=original_model_summaries,\n", + " references=human_baseline_summaries[0:len(original_model_summaries)],\n", + " use_aggregator=True,\n", + " use_stemmer=True,\n", + ")\n", + "\n", + "instruct_model_results = rouge.compute(\n", + " predictions=instruct_model_summaries,\n", + " references=human_baseline_summaries[0:len(instruct_model_summaries)],\n", + " use_aggregator=True,\n", + " use_stemmer=True,\n", + ")\n", + "\n", + "print('ORIGINAL MODEL:')\n", + "print(original_model_results)\n", + "print('INSTRUCT MODEL:')\n", + "print(instruct_model_results)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "The results show substantial improvement in all ROUGE metrics:" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Absolute percentage improvement of INSTRUCT MODEL over HUMAN BASELINE\n", + "rouge1: 18.82%\n", + "rouge2: 10.43%\n", + "rougeL: 13.70%\n", + "rougeLsum: 13.69%\n" + ] + } + ], + "source": [ + "print(\"Absolute percentage improvement of INSTRUCT MODEL over HUMAN BASELINE\")\n", + "\n", + "improvement = (np.array(list(instruct_model_results.values())) 
- np.array(list(original_model_results.values())))\n", + "for key, value in zip(instruct_model_results.keys(), improvement):\n", + " print(f'{key}: {value*100:.2f}%')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## 3 - Perform Parameter Efficient Fine-Tuning (PEFT)\n", + "\n", + "Now, let's perform **Parameter Efficient Fine-Tuning (PEFT)** fine-tuning as opposed to \"full fine-tuning\" as you did above. PEFT is a form of instruction fine-tuning that is much more efficient than full fine-tuning - with comparable evaluation results as you will see soon. \n", + "\n", + "PEFT is a generic term that includes **Low-Rank Adaptation (LoRA)** and prompt tuning (which is NOT THE SAME as prompt engineering!). In most cases, when someone says PEFT, they typically mean LoRA. LoRA, at a very high level, allows the user to fine-tune their model using fewer compute resources (in some cases, a single GPU). After fine-tuning for a specific task, use case, or tenant with LoRA, the result is that the original LLM remains unchanged and a newly-trained “LoRA adapter” emerges. This LoRA adapter is much, much smaller than the original LLM - on the order of a single-digit % of the original LLM size (MBs vs GBs). \n", + "\n", + "That said, at inference time, the LoRA adapter needs to be reunited and combined with its original LLM to serve the inference request. The benefit, however, is that many LoRA adapters can re-use the original LLM which reduces overall memory requirements when serving multiple tasks and use cases." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### 3.1 - Setup the PEFT/LoRA model for Fine-Tuning\n", + "\n", + "You need to set up the PEFT/LoRA model for fine-tuning with a new layer/parameter adapter. Using PEFT/LoRA, you are freezing the underlying LLM and only training the adapter. Have a look at the LoRA configuration below. Note the rank (`r`) hyper-parameter, which defines the rank/dimension of the adapter to be trained." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from peft import LoraConfig, get_peft_model, TaskType\n", + "\n", + "lora_config = LoraConfig(\n", + " r=32, # Rank\n", + " lora_alpha=32,\n", + " target_modules=[\"q\", \"v\"],\n", + " lora_dropout=0.05,\n", + " bias=\"none\",\n", + " task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "Add LoRA adapter layers/parameters to the original LLM to be trained." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "trainable model parameters: 3538944\n", + "all model parameters: 251116800\n", + "percentage of trainable model parameters: 1.41%\n" + ] + } + ], + "source": [ + "peft_model = get_peft_model(original_model, \n", + " lora_config)\n", + "print(print_number_of_trainable_model_parameters(peft_model))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "\n", + "### 3.2 - Train PEFT Adapter\n", + "\n", + "Define training arguments and create `Trainer` instance." 
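+    ,
+    "\n",
+    "\n",
+    "Before running it, here is a minimal numerical sketch (purely illustrative, with assumed dimensions) of the idea behind LoRA: the frozen weight matrix is left untouched, and a low-rank update `B @ A` of rank `r` is learned instead, which is why only a small fraction of parameters ends up trainable.\n",
+    "\n",
+    "```python\n",
+    "# Minimal LoRA intuition -- illustrative dimensions, not FLAN-T5's actual shapes.\n",
+    "import numpy as np\n",
+    "\n",
+    "d, k, r = 768, 768, 32                 # layer in/out sizes (assumed) and LoRA rank\n",
+    "W = np.random.randn(d, k)              # frozen base weight: d*k parameters, never updated\n",
+    "A = np.random.randn(r, k) * 0.01       # trainable low-rank factor\n",
+    "B = np.zeros((d, r))                   # trainable low-rank factor, zero-init so the initial update is zero\n",
+    "alpha = 32                             # scaling factor, analogous to lora_alpha below\n",
+    "\n",
+    "W_adapted = W + (alpha / r) * (B @ A)  # effective weight used in the forward pass\n",
+    "print(f'trainable: {A.size + B.size:,} vs frozen: {W.size:,}')\n",
+    "```"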
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "output_dir = f'./peft-dialogue-summary-training-{str(int(time.time()))}'\n", + "\n", + "peft_training_args = TrainingArguments(\n", + " output_dir=output_dir,\n", + " auto_find_batch_size=True,\n", + " learning_rate=1e-3, # Higher learning rate than full fine-tuning.\n", + " num_train_epochs=1,\n", + " logging_steps=1,\n", + " max_steps=1 \n", + ")\n", + " \n", + "peft_trainer = Trainer(\n", + " model=peft_model,\n", + " args=peft_training_args,\n", + " train_dataset=tokenized_datasets[\"train\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now everything is ready to train the PEFT adapter and save the model.\n", + "\n", + "\"Time" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.7/site-packages/transformers/optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", + " FutureWarning,\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [1/1 00:00, Epoch 0/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
151.000000

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "('./peft-dialogue-summary-checkpoint-local/tokenizer_config.json',\n", + " './peft-dialogue-summary-checkpoint-local/special_tokens_map.json',\n", + " './peft-dialogue-summary-checkpoint-local/tokenizer.json')" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "peft_trainer.train()\n", + "\n", + "peft_model_path=\"./peft-dialogue-summary-checkpoint-local\"\n", + "\n", + "peft_trainer.model.save_pretrained(peft_model_path)\n", + "tokenizer.save_pretrained(peft_model_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "\"Time" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "That training was performed on a subset of data. To load a fully trained PEFT model, read a checkpoint of a PEFT model from S3." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/special_tokens_map.json to peft-dialogue-summary-checkpoint-from-s3/special_tokens_map.json\n", + "download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/adapter_config.json to peft-dialogue-summary-checkpoint-from-s3/adapter_config.json\n", + "download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/tokenizer_config.json to peft-dialogue-summary-checkpoint-from-s3/tokenizer_config.json\n", + "download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/tokenizer.json to peft-dialogue-summary-checkpoint-from-s3/tokenizer.json\n", + "download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/adapter_model.bin to peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin\n" + ] + } + ], + "source": [ + "!aws s3 cp --recursive s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/ ./peft-dialogue-summary-checkpoint-from-s3/ " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "Check that the size of this model is much less than the original LLM:" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-rw-r--r-- 1 root root 14208525 May 15 11:18 ./peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin\n" + ] + } + ], + "source": [ + "!ls -al ./peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "Prepare this model by adding an adapter to the original FLAN-T5 model. You are setting `is_trainable=False` because the plan is only to perform inference with this PEFT model. If you were preparing the model for further training, you would set `is_trainable=True`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from peft import PeftModel, PeftConfig\n", + "\n", + "peft_model_base = AutoModelForSeq2SeqLM.from_pretrained(\"google/flan-t5-base\", torch_dtype=torch.bfloat16)\n", + "tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n", + "\n", + "peft_model = PeftModel.from_pretrained(peft_model_base, \n", + " './peft-dialogue-summary-checkpoint-from-s3/', \n", + " torch_dtype=torch.bfloat16,\n", + " is_trainable=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "The number of trainable parameters will be `0` due to `is_trainable=False` setting:" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "trainable model parameters: 0\n", + "all model parameters: 251116800\n", + "percentage of trainable model parameters: 0.00%\n" + ] + } + ], + "source": [ + "print(print_number_of_trainable_model_parameters(peft_model))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### 3.3 - Evaluate the Model Qualitatively (Human Evaluation)\n", + "\n", + "Make inferences for the same example as in sections [1.3](#1.3) and [2.3](#2.3), with the original model, fully fine-tuned and PEFT model." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------------------------------------------------------------------------------------------------\n", + "BASELINE HUMAN SUMMARY:\n", + "#Person1# teaches #Person2# how to upgrade software and hardware in #Person2#'s system.\n", + "---------------------------------------------------------------------------------------------------\n", + "ORIGINAL MODEL:\n", + "#Pork1: Have you considered upgrading your system? #Person1: Yes, but I'd like to make some improvements. #Pork1: I'd like to make a painting program. #Person1: I'd like to make a flyer. #Pork2: I'd like to make banners. #Person1: I'd like to make a computer graphics program. #Person2: I'd like to make a computer graphics program. #Person1: I'd like to make a computer graphics program. #Person2: Is there anything else you'd like to do? #Person1: I'd like to make a computer graphics program. #Person2: Is there anything else you need? #Person1: I'd like to make a computer graphics program. #Person2: I'\n", + "---------------------------------------------------------------------------------------------------\n", + "INSTRUCT MODEL:\n", + "#Person1# suggests #Person2# upgrading #Person2#'s system, hardware, and CD-ROM drive. #Person2# thinks it's great.\n", + "---------------------------------------------------------------------------------------------------\n", + "PEFT MODEL: #Person1# recommends adding a painting program to #Person2#'s software and upgrading hardware. 
#Person2# also wants to upgrade the hardware because it's outdated now.\n" + ] + } + ], + "source": [ + "index = 200\n", + "dialogue = dataset['test'][index]['dialogue']\n", + "baseline_human_summary = dataset['test'][index]['summary']\n", + "\n", + "prompt = f\"\"\"\n", + "Summarize the following conversation.\n", + "\n", + "{dialogue}\n", + "\n", + "Summary: \"\"\"\n", + "\n", + "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", + "\n", + "original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))\n", + "original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)\n", + "\n", + "instruct_model_outputs = instruct_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))\n", + "instruct_model_text_output = tokenizer.decode(instruct_model_outputs[0], skip_special_tokens=True)\n", + "\n", + "peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))\n", + "peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)\n", + "\n", + "print(dash_line)\n", + "print(f'BASELINE HUMAN SUMMARY:\\n{human_baseline_summary}')\n", + "print(dash_line)\n", + "print(f'ORIGINAL MODEL:\\n{original_model_text_output}')\n", + "print(dash_line)\n", + "print(f'INSTRUCT MODEL:\\n{instruct_model_text_output}')\n", + "print(dash_line)\n", + "print(f'PEFT MODEL: {peft_model_text_output}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### 3.4 - Evaluate the Model Quantitatively (with ROUGE Metric)\n", + "Perform inferences for the sample of the test dataset (only 10 dialogues and summaries to save time). " + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "

\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
human_baseline_summariesoriginal_model_summariesinstruct_model_summariespeft_model_summaries
0Ms. Dawson helps #Person1# to write a memo to ...The new intra-office policy will apply to all ...#Person1# asks Ms. Dawson to take a dictation ...#Person1# asks Ms. Dawson to take a dictation ...
1In order to prevent employees from wasting tim...Ms. Dawson will send an intra-office memo to a...#Person1# asks Ms. Dawson to take a dictation ...#Person1# asks Ms. Dawson to take a dictation ...
2Ms. Dawson takes a dictation for #Person1# abo...The memo should go out today.#Person1# asks Ms. Dawson to take a dictation ...#Person1# asks Ms. Dawson to take a dictation ...
3#Person2# arrives late because of traffic jam....#Person1#: I'm here. #Person2#: I'm here. #Per...#Person2# got stuck in traffic again. #Person1...#Person2# got stuck in traffic and #Person1# s...
4#Person2# decides to follow #Person1#'s sugges...The traffic jam is causing a lot of congestion...#Person2# got stuck in traffic again. #Person1...#Person2# got stuck in traffic and #Person1# s...
5#Person2# complains to #Person1# about the tra...I'm driving home from work.#Person2# got stuck in traffic again. #Person1...#Person2# got stuck in traffic and #Person1# s...
6#Person1# tells Kate that Masha and Hero get d...Masha and Hero are divorced for 2 months.Masha and Hero are getting divorced. Kate can'...Kate tells #Person2# Masha and Hero are gettin...
7#Person1# tells Kate that Masha and Hero are g...Masha and Hero are getting divorced.Masha and Hero are getting divorced. Kate can'...Kate tells #Person2# Masha and Hero are gettin...
8#Person1# and Kate talk about the divorce betw...#Person1#: Masha and Hero are getting divorced...Masha and Hero are getting divorced. Kate can'...Kate tells #Person2# Masha and Hero are gettin...
9#Person1# and Brian are at the birthday party ...#Person1#: Happy birthday, Brian. #Person2#: T...Brian's birthday is coming. #Person1# invites ...Brian remembers his birthday and invites #Pers...
\n", + "
" + ], + "text/plain": [ + " human_baseline_summaries \\\n", + "0 Ms. Dawson helps #Person1# to write a memo to ... \n", + "1 In order to prevent employees from wasting tim... \n", + "2 Ms. Dawson takes a dictation for #Person1# abo... \n", + "3 #Person2# arrives late because of traffic jam.... \n", + "4 #Person2# decides to follow #Person1#'s sugges... \n", + "5 #Person2# complains to #Person1# about the tra... \n", + "6 #Person1# tells Kate that Masha and Hero get d... \n", + "7 #Person1# tells Kate that Masha and Hero are g... \n", + "8 #Person1# and Kate talk about the divorce betw... \n", + "9 #Person1# and Brian are at the birthday party ... \n", + "\n", + " original_model_summaries \\\n", + "0 The new intra-office policy will apply to all ... \n", + "1 Ms. Dawson will send an intra-office memo to a... \n", + "2 The memo should go out today. \n", + "3 #Person1#: I'm here. #Person2#: I'm here. #Per... \n", + "4 The traffic jam is causing a lot of congestion... \n", + "5 I'm driving home from work. \n", + "6 Masha and Hero are divorced for 2 months. \n", + "7 Masha and Hero are getting divorced. \n", + "8 #Person1#: Masha and Hero are getting divorced... \n", + "9 #Person1#: Happy birthday, Brian. #Person2#: T... \n", + "\n", + " instruct_model_summaries \\\n", + "0 #Person1# asks Ms. Dawson to take a dictation ... \n", + "1 #Person1# asks Ms. Dawson to take a dictation ... \n", + "2 #Person1# asks Ms. Dawson to take a dictation ... \n", + "3 #Person2# got stuck in traffic again. #Person1... \n", + "4 #Person2# got stuck in traffic again. #Person1... \n", + "5 #Person2# got stuck in traffic again. #Person1... \n", + "6 Masha and Hero are getting divorced. Kate can'... \n", + "7 Masha and Hero are getting divorced. Kate can'... \n", + "8 Masha and Hero are getting divorced. Kate can'... \n", + "9 Brian's birthday is coming. #Person1# invites ... \n", + "\n", + " peft_model_summaries \n", + "0 #Person1# asks Ms. Dawson to take a dictation ... \n", + "1 #Person1# asks Ms. Dawson to take a dictation ... \n", + "2 #Person1# asks Ms. Dawson to take a dictation ... \n", + "3 #Person2# got stuck in traffic and #Person1# s... \n", + "4 #Person2# got stuck in traffic and #Person1# s... \n", + "5 #Person2# got stuck in traffic and #Person1# s... \n", + "6 Kate tells #Person2# Masha and Hero are gettin... \n", + "7 Kate tells #Person2# Masha and Hero are gettin... \n", + "8 Kate tells #Person2# Masha and Hero are gettin... \n", + "9 Brian remembers his birthday and invites #Pers... 
" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dialogues = dataset['test'][0:10]['dialogue']\n", + "human_baseline_summaries = dataset['test'][0:10]['summary']\n", + "\n", + "original_model_summaries = []\n", + "instruct_model_summaries = []\n", + "peft_model_summaries = []\n", + "\n", + "for idx, dialogue in enumerate(dialogues):\n", + " prompt = f\"\"\"\n", + "Summarize the following conversation.\n", + "\n", + "{dialogue}\n", + "\n", + "Summary: \"\"\"\n", + " \n", + " input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", + "\n", + " human_baseline_text_output = human_baseline_summaries[idx]\n", + " \n", + " original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))\n", + " original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)\n", + "\n", + " instruct_model_outputs = instruct_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))\n", + " instruct_model_text_output = tokenizer.decode(instruct_model_outputs[0], skip_special_tokens=True)\n", + "\n", + " peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))\n", + " peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)\n", + "\n", + " original_model_summaries.append(original_model_text_output)\n", + " instruct_model_summaries.append(instruct_model_text_output)\n", + " peft_model_summaries.append(peft_model_text_output)\n", + "\n", + "zipped_summaries = list(zip(human_baseline_summaries, original_model_summaries, instruct_model_summaries, peft_model_summaries))\n", + " \n", + "df = pd.DataFrame(zipped_summaries, columns = ['human_baseline_summaries', 'original_model_summaries', 'instruct_model_summaries', 'peft_model_summaries'])\n", + "df" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Compute ROUGE score for this subset of the data. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ORIGINAL MODEL:\n", + "{'rouge1': 0.2127769756385947, 'rouge2': 0.07849999999999999, 'rougeL': 0.1803101433337705, 'rougeLsum': 0.1872151390166362}\n", + "INSTRUCT MODEL:\n", + "{'rouge1': 0.41026607717457186, 'rouge2': 0.17840645241958838, 'rougeL': 0.2977022096267017, 'rougeLsum': 0.2987374187518165}\n", + "PEFT MODEL:\n", + "{'rouge1': 0.3725351062275605, 'rouge2': 0.12138811933618107, 'rougeL': 0.27620639623170606, 'rougeLsum': 0.2758134870822362}\n" + ] + } + ], + "source": [ + "rouge = evaluate.load('rouge')\n", + "\n", + "original_model_results = rouge.compute(\n", + " predictions=original_model_summaries,\n", + " references=human_baseline_summaries[0:len(original_model_summaries)],\n", + " use_aggregator=True,\n", + " use_stemmer=True,\n", + ")\n", + "\n", + "instruct_model_results = rouge.compute(\n", + " predictions=instruct_model_summaries,\n", + " references=human_baseline_summaries[0:len(instruct_model_summaries)],\n", + " use_aggregator=True,\n", + " use_stemmer=True,\n", + ")\n", + "\n", + "peft_model_results = rouge.compute(\n", + " predictions=peft_model_summaries,\n", + " references=human_baseline_summaries[0:len(peft_model_summaries)],\n", + " use_aggregator=True,\n", + " use_stemmer=True,\n", + ")\n", + "\n", + "print('ORIGINAL MODEL:')\n", + "print(original_model_results)\n", + "print('INSTRUCT MODEL:')\n", + "print(instruct_model_results)\n", + "print('PEFT MODEL:')\n", + "print(peft_model_results)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice, that PEFT model results are not too bad, while the training process was much easier!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You already computed ROUGE score on the full dataset, after loading the results from the `data/dialogue-summary-training-results.csv` file. Load the values for the PEFT model now and check its performance compared to other models." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ORIGINAL MODEL:\n", + "{'rouge1': 0.2334158581572823, 'rouge2': 0.07603964187010573, 'rougeL': 0.20145520923859048, 'rougeLsum': 0.20145899339006135}\n", + "INSTRUCT MODEL:\n", + "{'rouge1': 0.42161291557556113, 'rouge2': 0.18035380596301792, 'rougeL': 0.3384439349963909, 'rougeLsum': 0.33835653595561666}\n", + "PEFT MODEL:\n", + "{'rouge1': 0.40810631575616746, 'rouge2': 0.1633255794568712, 'rougeL': 0.32507074586565354, 'rougeLsum': 0.3248950182867091}\n" + ] + } + ], + "source": [ + "human_baseline_summaries = results['human_baseline_summaries'].values\n", + "original_model_summaries = results['original_model_summaries'].values\n", + "instruct_model_summaries = results['instruct_model_summaries'].values\n", + "peft_model_summaries = results['peft_model_summaries'].values\n", + "\n", + "original_model_results = rouge.compute(\n", + " predictions=original_model_summaries,\n", + " references=human_baseline_summaries[0:len(original_model_summaries)],\n", + " use_aggregator=True,\n", + " use_stemmer=True,\n", + ")\n", + "\n", + "instruct_model_results = rouge.compute(\n", + " predictions=instruct_model_summaries,\n", + " references=human_baseline_summaries[0:len(instruct_model_summaries)],\n", + " use_aggregator=True,\n", + " use_stemmer=True,\n", + ")\n", + "\n", + "peft_model_results = rouge.compute(\n", + " predictions=peft_model_summaries,\n", + " references=human_baseline_summaries[0:len(peft_model_summaries)],\n", + " use_aggregator=True,\n", + " use_stemmer=True,\n", + ")\n", + "\n", + "print('ORIGINAL MODEL:')\n", + "print(original_model_results)\n", + "print('INSTRUCT MODEL:')\n", + "print(instruct_model_results)\n", + "print('PEFT MODEL:')\n", + "print(peft_model_results)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The results show less of an improvement over full fine-tuning, but the benefits of PEFT typically outweigh the slightly-lower performance metrics.\n", + "\n", + "Calculate the improvement of PEFT over the original model:" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Absolute percentage improvement of PEFT MODEL over HUMAN BASELINE\n", + "rouge1: 17.47%\n", + "rouge2: 8.73%\n", + "rougeL: 12.36%\n", + "rougeLsum: 12.34%\n" + ] + } + ], + "source": [ + "print(\"Absolute percentage improvement of PEFT MODEL over HUMAN BASELINE\")\n", + "\n", + "improvement = (np.array(list(peft_model_results.values())) - np.array(list(original_model_results.values())))\n", + "for key, value in zip(peft_model_results.keys(), improvement):\n", + " print(f'{key}: {value*100:.2f}%')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now calculate the improvement of PEFT over a full fine-tuned model:" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Absolute percentage improvement of PEFT MODEL over INSTRUCT MODEL\n", + "rouge1: -1.35%\n", + "rouge2: -1.70%\n", + "rougeL: -1.34%\n", + "rougeLsum: -1.35%\n" + ] + } + ], + "source": [ + "print(\"Absolute percentage improvement of PEFT MODEL over INSTRUCT MODEL\")\n", + "\n", + "improvement = (np.array(list(peft_model_results.values())) - 
np.array(list(instruct_model_results.values())))\n", + "for key, value in zip(peft_model_results.keys(), improvement):\n", + " print(f'{key}: {value*100:.2f}%')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here you see a small percentage decrease in the ROUGE metrics vs. full fine-tuned. However, the training requires much less computing and memory resources (often just a single GPU)." + ] + } + ], + "metadata": { + "availableInstances": [ + { + "_defaultOrder": 0, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.t3.medium", + "vcpuNum": 2 + }, + { + "_defaultOrder": 1, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.t3.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 2, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.t3.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 3, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.t3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 4, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 5, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 6, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 7, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 8, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 9, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 10, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 11, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 12, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5d.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 13, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5d.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 14, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5d.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 15, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + 
"memoryGiB": 64, + "name": "ml.m5d.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 16, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5d.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 17, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5d.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 18, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5d.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 19, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 20, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": true, + "memoryGiB": 0, + "name": "ml.geospatial.interactive", + "supportedImageNames": [ + "sagemaker-geospatial-v1-0" + ], + "vcpuNum": 0 + }, + { + "_defaultOrder": 21, + "_isFastLaunch": true, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.c5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 22, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.c5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 23, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.c5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 24, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.c5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 25, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 72, + "name": "ml.c5.9xlarge", + "vcpuNum": 36 + }, + { + "_defaultOrder": 26, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 96, + "name": "ml.c5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 27, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 144, + "name": "ml.c5.18xlarge", + "vcpuNum": 72 + }, + { + "_defaultOrder": 28, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.c5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 29, + "_isFastLaunch": true, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g4dn.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 30, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g4dn.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 31, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g4dn.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 32, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g4dn.8xlarge", + "vcpuNum": 32 + }, 
+ { + "_defaultOrder": 33, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g4dn.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 34, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g4dn.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 35, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 61, + "name": "ml.p3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 36, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 244, + "name": "ml.p3.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 37, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 488, + "name": "ml.p3.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 38, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.p3dn.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 39, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.r5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 40, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.r5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 41, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.r5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 42, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.r5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 43, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.r5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 44, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.r5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 45, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 512, + "name": "ml.r5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 46, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.r5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 47, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 48, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 49, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 50, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + 
"hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 51, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 52, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 53, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.g5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 54, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.g5.48xlarge", + "vcpuNum": 192 + }, + { + "_defaultOrder": 55, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 1152, + "name": "ml.p4d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 56, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 1152, + "name": "ml.p4de.24xlarge", + "vcpuNum": 96 + } + ], + "colab": { + "name": "Fine-tune a language model", + "provenance": [] + }, + "instance_type": "ml.m5.2xlarge", + "kernelspec": { + "display_name": "Python 3 (Data Science)", + "language": "python", + "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + }, + "vscode": { + "interpreter": { + "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}