17 changes: 9 additions & 8 deletions finetune-distilbert.ipynb
Expand Up @@ -13,7 +13,7 @@
"source": [
"## Introduction\n",
"\n",
"In this demo, you will use the Hugging Faces `transformers` and `datasets` library with Amazon SageMaker to fine-tune a pre-trained transformer on binary text classification. In particular, you will use the pre-trained DistilBERT model with the Amazon Reviews Polarity dataset.\n",
"In this demo, you will use the Hugging Face `transformers` and `datasets` libraries with Amazon SageMaker to fine-tune a pre-trained transformer for binary text classification. In particular, you will use the pre-trained DistilBERT model with the IMDB dataset.\n",
"You will then deploy the resulting model for inference using a SageMaker endpoint.\n",
"\n",
"### The model\n",
Expand All @@ -22,7 +22,7 @@
"\n",
"### The data\n",
"\n",
"The [Amazon Reviews Polarity dataset](https://github.com/dsk78/Text-Classification---Amazon-Reviews-Polarity) consists of reviews from Amazon. The data span a period of 18 years, including ~35 million reviews up to March 2013. Reviews include product and user information, ratings, and a plaintext review. It's avalaible under the [`amazon_polarity`](https://huggingface.co/datasets/amazon_polarity) dataset on [Hugging Face](https://huggingface.co/)."
"The [IMDB dataset](https://github.com/huggingface/datasets/tree/master/datasets/imdb) consists of movie reviews for binary sentiment classification and contains substantially more data than previous benchmark datasets: 25,000 highly polar movie reviews for training and 25,000 for testing. It's available as the [`imdb`](https://huggingface.co/datasets/imdb) dataset on [Hugging Face](https://huggingface.co/)."
]
},
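A quick offline sketch of the record layout the rest of the notebook relies on: each `imdb` example is a dict with a plaintext review under `text` and a binary label under `label` (0 = negative, 1 = positive, per the Hugging Face dataset card). The review strings and the `class_balance` helper below are made up for illustration:

```python
# Toy stand-ins for imdb records (field names follow the dataset card;
# the review texts are invented for this sketch).
toy_train = [
    {"text": "A moving, beautifully shot film.", "label": 1},
    {"text": "Two hours I will never get back.", "label": 0},
]

def class_balance(examples):
    """Fraction of positive (label == 1) examples in a list of records."""
    return sum(ex["label"] for ex in examples) / len(examples)

print(class_balance(toy_train))  # 0.5 for this toy split
```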
{
Expand Down Expand Up @@ -148,7 +148,7 @@
"# Data preparation\n",
"\n",
"The data preparation is straightforward as you're using the `datasets` library to download and preprocess the `\n",
"amazon_polarity` dataset directly from Hugging face. After preprocessing, the dataset will be uploaded to our `sagemaker_session_bucket` to be used within our training job."
"imdb` dataset directly from Hugging Face. After preprocessing, the dataset will be uploaded to our `sagemaker_session_bucket` to be used within our training job."
]
},
{
Expand All @@ -157,11 +157,11 @@
"metadata": {},
"outputs": [],
"source": [
"dataset_name = 'amazon_polarity'\n",
"dataset_name = 'imdb'\n",
"\n",
"train_dataset, test_dataset = load_dataset(dataset_name, split=['train', 'test'])\n",
"train_dataset = train_dataset.shuffle().select(range(5000)) # limiting the dataset size to speed up the training during the demo\n",
"test_dataset = test_dataset.shuffle().select(range(1000))"
"test_dataset = test_dataset.shuffle().select(range(500))"
]
},
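The `shuffle().select(range(n))` idiom in the cell above amounts to drawing a seeded random subsample. The same idea can be sketched with the standard library, using a plain list as a stand-in for a `Dataset` (the `subsample` helper and seed below are illustrative, not part of the notebook):

```python
import random

def subsample(examples, n, seed=42):
    """Mimic dataset.shuffle(seed=seed).select(range(n)):
    shuffle a copy, then keep the first n examples."""
    rng = random.Random(seed)
    shuffled = examples[:]
    rng.shuffle(shuffled)
    return shuffled[:n]

examples = list(range(25000))  # stand-in for the 25k imdb training examples
small = subsample(examples, 5000)
print(len(small))  # 5000
```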
{
Expand Down Expand Up @@ -246,17 +246,18 @@
"source": [
"# Helper function to get the text to tokenize\n",
"def tokenize(batch):\n",
" return tokenizer(batch['content'], padding='max_length', truncation=True)\n",
" return tokenizer(batch['text'], padding='max_length', truncation=True)\n",
"\n",
"# Tokenize\n",
"train_dataset = train_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))\n",
"test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(test_dataset))\n",
"\n",
"# Set the format to PyTorch\n",
"train_dataset = train_dataset.rename_column(\"label\", \"labels\")\n",
"train_dataset = train_dataset.rename_column(\"text\", \"content\")\n",
"train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])\n",
"test_dataset = test_dataset.rename_column(\"label\", \"labels\")\n",
"test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])\n",
"test_dataset = test_dataset.rename_column(\"text\", \"content\")"
]
},
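What `padding='max_length', truncation=True` produces can be sketched with a toy whitespace tokenizer. The tiny vocabulary and `max_length=8` below are invented for illustration; the real DistilBERT tokenizer uses WordPiece subwords and a 512-token limit:

```python
def toy_tokenize(text, vocab, max_length=8, pad_id=0):
    """Whitespace-split, map tokens to ids (1 = unknown), truncate to
    max_length, then right-pad with pad_id; attention_mask is 1 for
    real tokens and 0 for padding."""
    ids = [vocab.get(tok, 1) for tok in text.lower().split()][:max_length]
    mask = [1] * len(ids) + [0] * (max_length - len(ids))
    ids = ids + [pad_id] * (max_length - len(ids))
    return {"input_ids": ids, "attention_mask": mask}

vocab = {"a": 2, "great": 3, "movie": 4}
enc = toy_tokenize("a great movie", vocab)
print(enc["input_ids"])       # [2, 3, 4, 0, 0, 0, 0, 0]
print(enc["attention_mask"])  # [1, 1, 1, 0, 0, 0, 0, 0]
```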
{
Expand Down Expand Up @@ -624,7 +625,7 @@
"kernelspec": {
"display_name": "Python 3 (Data Science)",
"language": "python",
"name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-west-2:236514542706:image/datascience-1.0"
"name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0"
},
"language_info": {
"codemirror_mode": {
Expand Down