
Commit c2bae34

Created using Colab
1 parent cff45ff commit c2bae34

1 file changed: 346 additions, 0 deletions
@@ -0,0 +1,346 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyPvuy00BRdhVo2HI2/iEvWM",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/masoudshahrian/DeepLearning-Code/blob/main/rsna_2024_lumbar_spine_degenerative_classification_with_tensorflow.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"import os\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pydicom\n",
"import tensorflow as tf\n",
"from tensorflow.keras import layers, models, optimizers, callbacks\n",
"from sklearn.model_selection import train_test_split\n",
"from tqdm import tqdm\n",
"\n",
"# ------------------------------\n",
"# 1. Load the CSV files and define the paths\n",
"# ------------------------------\n",
"train_path = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/'\n",
"\n",
"train_df = pd.read_csv(os.path.join(train_path, 'train.csv'))\n",
"label_df = pd.read_csv(os.path.join(train_path, 'train_label_coordinates.csv'))\n",
"train_desc_df = pd.read_csv(os.path.join(train_path, 'train_series_descriptions.csv'))\n",
"test_desc_df = pd.read_csv(os.path.join(train_path, 'test_series_descriptions.csv'))\n",
"sub = pd.read_csv(os.path.join(train_path, 'sample_submission.csv'))\n",
"\n",
"# ------------------------------\n",
"# 2. Function to generate the image paths\n",
"# ------------------------------\n",
"def generate_image_paths(df, data_dir):\n",
"    image_paths = []\n",
"    for study_id, series_id in zip(df['study_id'], df['series_id']):\n",
"        study_dir = os.path.join(data_dir, str(study_id))\n",
"        series_dir = os.path.join(study_dir, str(series_id))\n",
"        if os.path.exists(series_dir):\n",
"            images = os.listdir(series_dir)\n",
"            image_paths.extend([os.path.join(series_dir, img) for img in images])\n",
"    return image_paths\n",
"\n",
"train_image_paths = generate_image_paths(train_desc_df, os.path.join(train_path, 'train_images'))\n",
"test_image_paths = generate_image_paths(test_desc_df, os.path.join(train_path, 'test_images'))\n",
"\n",
"print(\"Sample train image path:\", train_image_paths[2])\n",
"print(\"Number of train_desc rows:\", len(train_desc_df))\n",
"print(\"Number of train images:\", len(train_image_paths))\n",
"\n",
"# ------------------------------\n",
"# 3. Reshape the train data and merge the dataframes\n",
"# ------------------------------\n",
"def reshape_row(row):\n",
"    data = {'study_id': [], 'condition': [], 'level': [], 'severity': []}\n",
"    for column, value in row.items():\n",
"        if column not in ['study_id', 'series_id', 'instance_number', 'x', 'y', 'series_description']:\n",
"            parts = column.split('_')\n",
"            condition = ' '.join([word.capitalize() for word in parts[:-2]])\n",
"            level = parts[-2].capitalize() + '/' + parts[-1].capitalize()\n",
"            data['study_id'].append(row['study_id'])\n",
"            data['condition'].append(condition)\n",
"            data['level'].append(level)\n",
"            data['severity'].append(value)\n",
"    return pd.DataFrame(data)\n",
"\n",
"new_train_df = pd.concat([reshape_row(row) for _, row in train_df.iterrows()], ignore_index=True)\n",
"\n",
"merged_df = pd.merge(new_train_df, label_df, on=['study_id', 'condition', 'level'], how='inner')\n",
"final_merged_df = pd.merge(merged_df, train_desc_df, on=['series_id','study_id'], how='inner')\n",
"\n",
"final_merged_df['row_id'] = (final_merged_df['study_id'].astype(str) + '_' +\n",
"                             final_merged_df['condition'].str.lower().str.replace(' ', '_') + '_' +\n",
"                             final_merged_df['level'].str.lower().str.replace('/', '_'))\n",
"\n",
"final_merged_df['image_path'] = (os.path.join(train_path, 'train_images') + '/' +\n",
"                                 final_merged_df['study_id'].astype(str) + '/' +\n",
"                                 final_merged_df['series_id'].astype(str) + '/' +\n",
"                                 final_merged_df['instance_number'].astype(str) + '.dcm')\n",
"\n",
"# Map the severity labels to lowercase values\n",
"final_merged_df['severity'] = final_merged_df['severity'].map({\n",
"    'Normal/Mild': 'normal_mild',\n",
"    'Moderate': 'moderate',\n",
"    'Severe': 'severe'\n",
"})\n",
"\n",
"# Keep only the rows whose image path actually exists\n",
"def check_exists(path):\n",
"    return os.path.exists(path)\n",
"final_merged_df = final_merged_df[final_merged_df['image_path'].apply(check_exists)]\n",
"\n",
"# Map the labels to integers\n",
"severity_map = {'normal_mild': 0, 'moderate': 1, 'severe': 2}\n",
"final_merged_df['severity'] = final_merged_df['severity'].map(severity_map)\n",
"\n",
"# Use final_merged_df as the training data\n",
"train_data = final_merged_df.copy()\n",
"\n",
"# ------------------------------\n",
"# 4. Functions to load and preprocess the DICOM images\n",
"# ------------------------------\n",
"def load_dicom_image(path):\n",
"    \"\"\"\n",
"    Load a DICOM image.\n",
"    \"\"\"\n",
"    # Convert the EagerTensor to a NumPy value and decode it to a string\n",
"    path = path.numpy().decode('utf-8')\n",
"    ds = pydicom.dcmread(path)\n",
"    data = ds.pixel_array.astype(np.float32)\n",
"    data = data - np.min(data)\n",
"    if np.max(data) != 0:\n",
"        data = data / np.max(data)\n",
"    data = (data * 255).astype(np.uint8)\n",
"    return data\n",
"\n",
"def load_and_preprocess(path, label=None):\n",
"    \"\"\"\n",
"    - Load the DICOM image via tf.py_function\n",
"    - Add a channel dimension (for grayscale images)\n",
"    - Resize to 224x224, convert grayscale to RGB and normalize to [0, 1]\n",
"    \"\"\"\n",
"    image = tf.py_function(func=lambda p: load_dicom_image(p), inp=[path], Tout=tf.uint8)\n",
"    image.set_shape([None, None])\n",
"    # Add a channel dimension: (height, width) -> (height, width, 1)\n",
"    image = tf.expand_dims(image, axis=-1)\n",
"    image = tf.image.resize(image, [224, 224])\n",
"    image = tf.image.grayscale_to_rgb(image)\n",
"    image = tf.cast(image, tf.float32) / 255.0\n",
"    if label is None:\n",
"        return image\n",
"    else:\n",
"        return image, label\n",
"\n",
"# ------------------------------\n",
"# 5. Build TensorFlow datasets for each series description\n",
"# ------------------------------\n",
"def create_datasets(df, series_description, batch_size=8):\n",
"    filtered_df = df[df['series_description'] == series_description]\n",
"    if filtered_df.empty:\n",
"        raise ValueError(f\"No data found for series description: {series_description}\")\n",
"    train_df_part, val_df_part = train_test_split(filtered_df, test_size=0.2, random_state=42)\n",
"\n",
"    train_paths = train_df_part['image_path'].values\n",
"    train_labels = train_df_part['severity'].values\n",
"    val_paths = val_df_part['image_path'].values\n",
"    val_labels = val_df_part['severity'].values\n",
"\n",
"    train_ds = tf.data.Dataset.from_tensor_slices((train_paths, train_labels))\n",
"    train_ds = train_ds.map(lambda p, l: load_and_preprocess(p, l),\n",
"                            num_parallel_calls=tf.data.AUTOTUNE)\n",
"    train_ds = train_ds.shuffle(buffer_size=len(train_df_part)).batch(batch_size).prefetch(tf.data.AUTOTUNE)\n",
"\n",
"    val_ds = tf.data.Dataset.from_tensor_slices((val_paths, val_labels))\n",
"    val_ds = val_ds.map(lambda p, l: load_and_preprocess(p, l),\n",
"                        num_parallel_calls=tf.data.AUTOTUNE)\n",
"    val_ds = val_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)\n",
"\n",
"    return train_ds, val_ds, len(train_df_part), len(val_df_part)\n",
"\n",
"# Build the datasets for the three series descriptions\n",
"train_ds_t1, val_ds_t1, len_train_t1, len_val_t1 = create_datasets(train_data, 'Sagittal T1', batch_size=8)\n",
"train_ds_t2, val_ds_t2, len_train_t2, len_val_t2 = create_datasets(train_data, 'Axial T2', batch_size=8)\n",
"train_ds_t2stir, val_ds_t2stir, len_train_t2stir, len_val_t2stir = create_datasets(train_data, 'Sagittal T2/STIR', batch_size=8)\n",
"\n",
"# ------------------------------\n",
"# 6. Define the VGG19 model with Keras (TensorFlow)\n",
"# ------------------------------\n",
"def create_vgg19_model(num_classes=3):\n",
"    base_model = tf.keras.applications.VGG19(include_top=False,\n",
"                                             input_shape=(224, 224, 3),\n",
"                                             weights='imagenet')\n",
"    base_model.trainable = False\n",
"    x = layers.Flatten()(base_model.output)\n",
"    x = layers.Dense(4096, activation='relu')(x)\n",
"    x = layers.Dense(4096, activation='relu')(x)\n",
"    outputs = layers.Dense(num_classes, activation='softmax')(x)\n",
"    model = tf.keras.Model(inputs=base_model.input, outputs=outputs)\n",
"    return model\n",
"\n",
"# Create three separate models, one per series description\n",
"model_t1 = create_vgg19_model(num_classes=3)\n",
"model_t2 = create_vgg19_model(num_classes=3)\n",
"model_t2stir = create_vgg19_model(num_classes=3)\n",
"\n",
"model_t1.compile(optimizer=optimizers.Adam(learning_rate=0.001),\n",
"                 loss='sparse_categorical_crossentropy',\n",
"                 metrics=['accuracy'])\n",
"model_t2.compile(optimizer=optimizers.Adam(learning_rate=0.001),\n",
"                 loss='sparse_categorical_crossentropy',\n",
"                 metrics=['accuracy'])\n",
"model_t2stir.compile(optimizer=optimizers.Adam(learning_rate=0.001),\n",
"                     loss='sparse_categorical_crossentropy',\n",
"                     metrics=['accuracy'])\n",
"\n",
"# ------------------------------\n",
"# 7. Train the models\n",
"# ------------------------------\n",
"# Checkpoint files use the .keras extension for compatibility with the Keras format\n",
"es_callback_t1 = callbacks.EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True)\n",
"ckpt_callback_t1 = callbacks.ModelCheckpoint('best_model_t1.keras', monitor='val_accuracy', save_best_only=True)\n",
"\n",
"es_callback_t2 = callbacks.EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True)\n",
"ckpt_callback_t2 = callbacks.ModelCheckpoint('best_model_t2.keras', monitor='val_accuracy', save_best_only=True)\n",
"\n",
"es_callback_t2stir = callbacks.EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True)\n",
"ckpt_callback_t2stir = callbacks.ModelCheckpoint('best_model_t2stir.keras', monitor='val_accuracy', save_best_only=True)\n",
"\n",
"print(\"Training the Sagittal T1 model\")\n",
"history_t1 = model_t1.fit(train_ds_t1, epochs=10, validation_data=val_ds_t1,\n",
"                          callbacks=[es_callback_t1, ckpt_callback_t1])\n",
"\n",
"print(\"Training the Axial T2 model\")\n",
"history_t2 = model_t2.fit(train_ds_t2, epochs=10, validation_data=val_ds_t2,\n",
"                          callbacks=[es_callback_t2, ckpt_callback_t2])\n",
"\n",
"print(\"Training the Sagittal T2/STIR model\")\n",
"history_t2stir = model_t2stir.fit(train_ds_t2stir, epochs=10, validation_data=val_ds_t2stir,\n",
"                                  callbacks=[es_callback_t2stir, ckpt_callback_t2stir])\n",
"\n",
"# ------------------------------\n",
"# 8. Predict on the test data and build the submission\n",
"# ------------------------------\n",
"condition_mapping = {\n",
"    'Sagittal T1': {'left': 'left_neural_foraminal_narrowing', 'right': 'right_neural_foraminal_narrowing'},\n",
"    'Axial T2': {'left': 'left_subarticular_stenosis', 'right': 'right_subarticular_stenosis'},\n",
"    'Sagittal T2/STIR': 'spinal_canal_stenosis'\n",
"}\n",
"\n",
"expanded_rows = []\n",
"for index, row in test_desc_df.iterrows():\n",
"    study_id = row['study_id']\n",
"    series_id = row['series_id']\n",
"    series_description = row['series_description']\n",
"    series_path = os.path.join(train_path, 'test_images', str(study_id), str(series_id))\n",
"    if os.path.exists(series_path):\n",
"        image_files = [os.path.join(series_path, f) for f in os.listdir(series_path)\n",
"                       if os.path.isfile(os.path.join(series_path, f))]\n",
"        conditions = condition_mapping.get(series_description, {})\n",
"        if isinstance(conditions, str):\n",
"            conditions = {'left': conditions, 'right': conditions}\n",
"        for side, condition in conditions.items():\n",
"            for image_path in image_files:\n",
"                expanded_rows.append({\n",
"                    'study_id': study_id,\n",
"                    'series_id': series_id,\n",
"                    'series_description': series_description,\n",
"                    'image_path': image_path,\n",
"                    'condition': condition,\n",
"                    'row_id': f\"{study_id}_{condition}\"\n",
"                })\n",
"\n",
"expanded_test_desc = pd.DataFrame(expanded_rows)\n",
"\n",
"# Update row_id by appending the level\n",
"levels = ['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']\n",
"def update_row_id(row, levels):\n",
"    level = levels[row.name % len(levels)]\n",
"    return f\"{row['study_id']}_{row['condition']}_{level}\"\n",
"\n",
"expanded_test_desc['row_id'] = expanded_test_desc.apply(lambda row: update_row_id(row, levels), axis=1)\n",
"\n",
"# Build the test dataset (no labels)\n",
"test_paths = expanded_test_desc['image_path'].values\n",
"test_ds = tf.data.Dataset.from_tensor_slices(test_paths)\n",
"test_ds = test_ds.map(lambda p: load_and_preprocess(p),\n",
"                      num_parallel_calls=tf.data.AUTOTUNE)\n",
"test_ds = test_ds.batch(1)\n",
"\n",
"# Dictionary of models keyed by series description\n",
"models_dict = {\n",
"    'Sagittal T1': model_t1,\n",
"    'Axial T2': model_t2,\n",
"    'Sagittal T2/STIR': model_t2stir\n",
"}\n",
"\n",
"normal_mild_probs = []\n",
"moderate_probs = []\n",
"severe_probs = []\n",
"predictions_list = []\n",
"\n",
"for i, batch in enumerate(tqdm(test_ds)):\n",
"    series_description = expanded_test_desc.iloc[i]['series_description']\n",
"    model_used = models_dict.get(series_description, None)\n",
"    if model_used is None:\n",
"        normal_mild_probs.append(None)\n",
"        moderate_probs.append(None)\n",
"        severe_probs.append(None)\n",
"        predictions_list.append(None)\n",
"    else:\n",
"        preds = model_used.predict(batch)\n",
"        preds = preds[0]\n",
"        normal_mild_probs.append(preds[0])\n",
"        moderate_probs.append(preds[1])\n",
"        severe_probs.append(preds[2])\n",
"        predictions_list.append(preds)\n",
"\n",
"expanded_test_desc['normal_mild'] = normal_mild_probs\n",
"expanded_test_desc['moderate'] = moderate_probs\n",
"expanded_test_desc['severe'] = severe_probs\n",
"\n",
"submission_df = expanded_test_desc[[\"row_id\", \"normal_mild\", \"moderate\", \"severe\"]]\n",
"grouped_submission = submission_df.groupby('row_id').max().reset_index()\n",
"\n",
"sub[['normal_mild', 'moderate', 'severe']] = grouped_submission[['normal_mild', 'moderate', 'severe']]\n",
"sub.to_csv(\"/kaggle/working/submission.csv\", index=False)\n",
"\n",
"print(\"Sample submission:\")\n",
"print(sub.head())\n",
"\n",
"# ------------------------------\n",
"# 9. Save the models\n",
"# ------------------------------\n",
"model_t1.save(\"Vgg19_t1.keras\")\n",
"model_t2.save(\"Vgg19_t2.keras\")\n",
"model_t2stir.save(\"Vgg19_t2stir.keras\")\n"
],
"metadata": {
"id": "-tseZ4MG7bF9"
},
"execution_count": null,
"outputs": []
}
]
}
