Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
"id": "fcbb5d61-8a0b-47d9-a7c5-0c041c82b8bf",
"metadata": {},
"source": [
"# 🚀 Deploy `deepseek-ai/DeepSeek-R1-Distill-Llama-8B` on Amazon SageMaker"
"# 🚀 Deploy `Qwen/Qwen3-4B-Instruct-2507` on Amazon SageMaker"
]
},
{
"cell_type": "markdown",
"id": "dd210e90-21e1-4f03-a08e-c3fba9aa6979",
"metadata": {},
"source": [
"## Prerequisites\n",
"\n",
"To start off, let's install some packages to help us through the notebooks. **Restart the kernel after packages have been installed.**"
]
},
Expand Down Expand Up @@ -57,6 +59,14 @@
"get_ipython().kernel.do_shutdown(True)"
]
},
{
"cell_type": "markdown",
"id": "a947367a-bea3-498a-9548-d6e6e08f0d10",
"metadata": {},
"source": [
"***"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -66,9 +76,9 @@
"source": [
"import os\n",
"import sagemaker\n",
"from sagemaker.djl_inference import DJLModel\n",
"from ipywidgets import Dropdown\n",
"\n",
"import boto3\n",
"import shutil\n",
"from sagemaker.config import load_sagemaker_config\n",
"import sys\n",
"sys.path.append(os.path.dirname(os.getcwd()))\n",
"\n",
Expand All @@ -78,31 +88,20 @@
" print_dialog,\n",
" format_messages,\n",
" write_eula\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b53f21c-3a65-44fc-b547-712d971cd652",
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"import shutil\n",
"import sagemaker\n",
"from sagemaker.config import load_sagemaker_config\n",
")\n",
"\n",
"sagemaker_session = sagemaker.Session()\n",
"s3_client = boto3.client('s3')\n",
"\n",
"region = sagemaker_session.boto_session.region_name\n",
"bucket_name = sagemaker_session.default_bucket()\n",
"default_prefix = sagemaker_session.default_bucket_prefix\n",
"configs = load_sagemaker_config()\n",
"\n",
"session = sagemaker.Session()\n",
"role = sagemaker.get_execution_role()\n",
"\n",
"\n",
"print(f\"Execution Role: {role}\")\n",
"print(f\"Default S3 Bucket: {bucket_name}\")"
]
Expand Down Expand Up @@ -130,11 +129,14 @@
"metadata": {},
"outputs": [],
"source": [
"inference_image_uri = sagemaker.image_uris.retrieve(\n",
" framework=\"djl-lmi\", \n",
" region=session.boto_session.region_name, \n",
" version=\"0.29.0\"\n",
")\n",
"# commenting until LMI 0.33.0 available via SageMaker SDK\n",
"# inference_image_uri = sagemaker.image_uris.retrieve(\n",
"# framework=\"djl-lmi\", \n",
"# region=session.boto_session.region_name, \n",
"# version=\"0.33.0\"\n",
"# )\n",
"\n",
"inference_image_uri = f\"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.33.0-lmi15.0.0-cu128\"\n",
"pretty_print_html(f\"using image to host: {inference_image_uri}\")"
]
},
Expand All @@ -153,7 +155,7 @@
"metadata": {},
"outputs": [],
"source": [
"model_id = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
"model_id = \"Qwen/Qwen3-4B-Instruct-2507\"\n",
"model_id_filesafe = model_id.replace(\"/\",\"_\")\n",
"\n",
"use_local_model = True #set to false for the training job to download from HF, otherwise True will download locally"
Expand Down Expand Up @@ -225,7 +227,7 @@
"metadata": {},
"outputs": [],
"source": [
"model_name = \"DeepSeek-R1-Distill-Llama-8B\"\n",
"model_name = \"Qwen3-4B-Instruct-2507\"\n",
"\n",
"lmi_model = sagemaker.Model(\n",
" image_uri=inference_image_uri,\n",
Expand All @@ -242,12 +244,15 @@
"metadata": {},
"outputs": [],
"source": [
"base_endpoint_name = f\"{model_name}-endpoint\"\n",
"from sagemaker.utils import name_from_base\n",
"\n",
"endpoint_name = f\"{model_name}-endpoint\"\n",
"BASE_ENDPOINT_NAME = name_from_base(endpoint_name)\n",
"\n",
"predictor = lmi_model.deploy(\n",
" initial_instance_count=1, \n",
" instance_type=\"ml.g5.2xlarge\",\n",
" endpoint_name=base_endpoint_name\n",
" endpoint_name=BASE_ENDPOINT_NAME\n",
")"
]
},
Expand All @@ -258,30 +263,19 @@
"metadata": {},
"outputs": [],
"source": [
"base_prompt = f\"\"\"\n",
"<|begin_of_text|>\n",
"<|start_header_id|>system<|end_header_id|>\n",
"You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. \n",
"SYSTEM_PROMPT = f\"\"\"You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. \n",
"Below is an instruction that describes a task, paired with an input that provides further context. \n",
"Write a response that appropriately completes the request.\n",
"Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.\n",
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"{{question}}<|eot_id|>\n",
"<|start_header_id|>assistant<|end_header_id|>\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b37e7f1-730c-4b31-aa3b-55e2009f8f04",
"metadata": {},
"outputs": [],
"source": [
"prompt = base_prompt.format(\n",
" question=\"A 3-week-old child has been diagnosed with late onset perinatal meningitis, and the CSF culture shows gram-positive bacilli. What characteristic of this bacterium can specifically differentiate it from other bacterial agents?\"\n",
")\n",
"Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.\"\"\"\n",
"\n",
"USER_PROMPT = \"A 3-week-old child has been diagnosed with late onset perinatal meningitis, and the CSF culture shows gram-positive bacilli. What characteristic of this bacterium can specifically differentiate it from other bacterial agents?\"\n",
"\n",
"print(prompt)"
"messages = [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": USER_PROMPT},\n",
"]\n",
"\n",
"messages"
]
},
{
Expand All @@ -292,35 +286,44 @@
"outputs": [],
"source": [
"predictor = sagemaker.Predictor(\n",
" endpoint_name=base_endpoint_name,\n",
" endpoint_name=BASE_ENDPOINT_NAME,\n",
" sagemaker_session=sagemaker_session,\n",
" serializer=sagemaker.serializers.JSONSerializer(),\n",
" deserializer=sagemaker.deserializers.JSONDeserializer(),\n",
")\n",
"\n",
"response = predictor.predict({\n",
"\t\"inputs\": prompt,\n",
"\t\"messages\": messages,\n",
" \"parameters\": {\n",
" \"temperature\": 0.2,\n",
" \"top_p\": 0.9,\n",
" \"return_full_text\": False,\n",
" \"max_new_tokens\": 1024,\n",
" \"stop\": ['<|eot_id|>']\n",
" \"max_new_tokens\": 1024\n",
" }\n",
"})\n",
"\n",
"response = response[\"generated_text\"].split(\"<|eot_id|>\")[0]\n",
"response[\"choices\"][0][\"message\"][\"content\"]"
]
},
{
"cell_type": "markdown",
"id": "165c8660-ee18-411f-9d8a-8032c6171d77",
"metadata": {},
"source": [
"### Store variables\n",
"\n",
"response"
"Save the endpoint name for use later"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dbfc37bb-dc1f-4ba7-9948-6e482c1c86b0",
"id": "0ed6ca9e-705c-4d01-9118-110b86caaef6",
"metadata": {},
"outputs": [],
"source": []
"source": [
"%store BASE_ENDPOINT_NAME"
]
}
],
"metadata": {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
transformers==4.48.2
transformers==4.53.0
peft==0.14.0
accelerate==1.3.0
bitsandbytes==0.45.1
datasets==3.2.0
evaluate==0.4.3
huggingface_hub[hf_transfer]==0.33.4
mlflow
mlflow==2.22.2
safetensors>=0.4.5
sagemaker==2.239.0
sagemaker==2.252.0
sagemaker-mlflow==0.1.0
sentencepiece==0.2.0
scikit-learn==1.6.1
tokenizers>=0.21.0
trl==0.9.6
py7zr
py7zr==1.0.0
Loading