|
4 | 4 | "cell_type": "markdown",
|
5 | 5 | "metadata": {},
|
6 | 6 | "source": [
|
7 |
| - "# Decision Trees" |
8 |
| - ] |
9 |
| - }, |
10 |
| - { |
11 |
| - "cell_type": "markdown", |
12 |
| - "metadata": {}, |
13 |
| - "source": [ |
14 |
| - "<img src = \"imgs/DT0.png\">" |
15 |
| - ] |
16 |
| - }, |
17 |
| - { |
18 |
| - "cell_type": "markdown", |
19 |
| - "metadata": {}, |
20 |
| - "source": [ |
21 |
| - "<img src = \"imgs/DT1.PNG\">" |
22 |
| - ] |
23 |
| - }, |
24 |
| - { |
25 |
| - "cell_type": "markdown", |
26 |
| - "metadata": {}, |
27 |
| - "source": [ |
28 |
| - "<img src = \"imgs/DT2.PNG\">" |
29 |
| - ] |
30 |
| - }, |
31 |
| - { |
32 |
| - "cell_type": "markdown", |
33 |
| - "metadata": {}, |
34 |
| - "source": [ |
35 |
| - "<img src = \"imgs/DT3.PNG\">" |
| 7 | + "# **Decision Trees**\n", |
| 8 | + "<img style=\"float: right;\" width=\"400\" height=\"400\" src = \"imgs/DT0.png\">\n", |
| 9 | + "Decision Trees are a popular and versatile machine learning algorithm used for both classification and regression tasks. A Decision Tree is a tree-like model where each internal node represents a decision based on the value of a particular feature, and each leaf node represents the outcome of a decision or a target variable.\n", |
| 10 | + "\n", |
| 11 | + "### **Algorithms**\n", |
| 12 | + "> 1. ID3\n", |
| 13 | + "> 2. C4.5\n", |
| 14 | + "> 3. C5.0\n", |
| 15 | + "> 4. CART (implemented in sklearn for decision trees)\n", |
| 16 | + "\n", |
| 17 | + "### **Working of CART Algorithm** \n", |
| 18 | + "> 1. The algorithm selects the best feature to split the dataset based on a certain criterion (e.g., Gini impurity for classification, mean squared error for regression).\n", |
| 19 | + "> 2. It recursively divides the dataset into subsets based on the chosen feature.\n", |
| 20 | + "> 3. The process continues until a stopping condition is met (e.g., a predefined depth is reached, or a node contains a minimum number of data points).\n", |
| 21 | + "> 4. Each leaf node corresponds to a class label in classification or a predicted value in regression.\n", |
| 22 | + "\n", |
| 23 | + "### **Key Concept** \n", |
| 24 | + "<img style=\"float: right;\" width=\"500\" height=\"400\" src = \"imgs/DT1.PNG\">\n", |
| 25 | + "\n", |
| 26 | + "1. **Information Gain** \n", |
| 27 | + "In the context of decision tree algorithms, the **Information Gain** is a metric **used to determine the effectiveness of a feature in partitioning the data**. The Information Gain is calculated by measuring the reduction in entropy (or Gini impurity) after a dataset is split based on a particular feature. The formula for Information Gain is often expressed as the weighted sum of entropies of child nodes.\n", |
| 28 | + "\n", |
| 29 | + "> The weighted entropy is beneficial in situations where the classes in the child nodes are not balanced, meaning one child node might have more instances than the other. This is particularly important in datasets where the class distribution is skewed. \n", |
| 30 | + "> The formula for Information Gain with weighted entropy is as follows:\n", |
| 31 | + "\n", |
| 32 | + "$$\\text{Information Gain} = \\text{Entropy(parent)} - \\sum_{i=1}^{p} \\frac{|D_i|}{|D|} \\cdot \\text{Entropy(child}_i) $$\n", |
| 33 | + "\n", |
| 34 | + "<img style=\"float: right;\" width=\"500\" height=\"400\" src = \"imgs/DT3.PNG\">\n", |
| 35 | + "\n", |
| 36 | + "> Here:\n", |
| 37 | + "> - $p$ is the number of child nodes.\n", |
| 38 | + "> - $|D_i|$ is the number of instances in the \\(i\\)-th child node.\n", |
| 39 | + "> - $|D|$ is the total number of instances in the parent node.\n", |
| 40 | + "\n", |
| 41 | + "2. **Entropy (Classification Trees)** `AKA Shannon Entropy` \n", |
| 42 | + "Entropy is a measure used for splitting nodes in Decision Trees, particularly in information theory. It is commonly applied in classification tasks. Entropy is a measure of node impurity. The lower the entropy, the more homogeneous the node is in terms of class labels.\n", |
| 43 | + "The formula for entropy $H(t)$ for a given node is defined as:\n", |
| 44 | + "\n", |
| 45 | + "$$ H(t) = - \\sum_{i=1}^{C} p(y_i) \\log_2(p(y_i)) $$\n", |
| 46 | + "\n", |
| 47 | + "> Here:\n", |
| 48 | + "> - $t$ is the current node.\n", |
| 49 | + "> - $C$ is the number of classes.\n", |
| 50 | + "> - $p(y_i)$ is the probability of class in the $t$ node.\n", |
| 51 | + "> The entropy ranges from 0 to 1\n", |
| 52 | + "\n", |
| 53 | + "<img style=\"float: right;\" width=\"500\" height=\"400\" src = \"imgs/DT2.PNG\">\n", |
| 54 | + "\n", |
| 55 | + "3. **Gini Impurity (Classification Trees)**\n", |
| 56 | + "For a given node $t$, the Gini impurity $G(t)$ is calculated as:\n", |
| 57 | + "\n", |
| 58 | + "$$ G(t) = 1 - \\sum_{i=1}^{C} p(y_i)^2 $$\n", |
| 59 | + "\n", |
| 60 | + "> where $C$ is the number of classes and $p(y_i)$ is the probability of class in the $t$ node. The lower the Gini impurity, the more homogeneous the node is in terms of class labels.\n", |
| 61 | + "\n", |
| 62 | + "4. **Mean Squared Error (Regression Trees)** \n", |
| 63 | + "For regression tasks, the mean squared error (MSE) is commonly used as the criterion for splitting nodes. Given a node $t$ with $|D_i|$ data points and their target values $y_i$, the MSE $MSE(t)$ is calculated as:\n", |
| 64 | + "\n", |
| 65 | + "$$ MSE(t) = \\frac{1}{|D_i|} \\sum_{i \\in t} (y_i - \\bar{y}_t)^2 $$\n", |
| 66 | + "\n", |
| 67 | + "> where $\\bar{y}_t$ is the mean target value in node $t$. The idea is to minimize the variance of the target values within each node.\n", |
| 68 | + "\n", |
| 69 | + "> Similar to classification trees, when constructing a regression tree, the algorithm evaluates the MSE for each feature and its potential split points and selects the feature and split point that result in the lowest overall MSE across the child nodes.\n", |
| 70 | + "\n", |
| 71 | + "\n", |
| 72 | + "### **Advantages** \n", |
| 73 | + "> 1. Requires very little data preparation - No need to rescale numerical columns. For categorical columns apply LabelEncoding\n", |
| 74 | + "> 2. Very fast during prediction time\n", |
| 75 | + "> 3. Interpretability\n", |
| 76 | + "> 4. Feature Importance\n", |
| 77 | + "> 5. Non Linearity\n", |
| 78 | + "> 6. Handle Missing Values\n", |
| 79 | + "> 7. Handle Multi-class Classification Problem\n", |
| 80 | + "\n", |
| 81 | + "### **Disadvantages** \n", |
| 82 | + "> 1. Overfitting\n", |
| 83 | + "> 2. Instable\n", |
| 84 | + "> 3. Sensitive to Outliers\n", |
| 85 | + "> 4. Sensitive to Data Imbalance - Biased towards class that dominates\n", |
| 86 | + "> 5. Very high training time. It is worse if there are numerical features in input data.\n", |
| 87 | + "> 6. Predictions are piecewise constant approximation" |
36 | 88 | ]
|
37 | 89 | },
|
38 | 90 | {
|
39 | 91 | "cell_type": "markdown",
|
40 | 92 | "metadata": {},
|
41 | 93 | "source": [
|
42 |
| - "### Loading the Data" |
| 94 | + "## **Decision Tree on Iris Data - Loading the Data**" |
43 | 95 | ]
|
44 | 96 | },
|
45 | 97 | {
|
|
206 | 258 | "cell_type": "markdown",
|
207 | 259 | "metadata": {},
|
208 | 260 | "source": [
|
209 |
| - "## Identifying Input and Output" |
| 261 | + "## **Identifying Input and Output**" |
210 | 262 | ]
|
211 | 263 | },
|
212 | 264 | {
|
|
223 | 275 | "cell_type": "markdown",
|
224 | 276 | "metadata": {},
|
225 | 277 | "source": [
|
226 |
| - "### Test Train Split" |
| 278 | + "## **Test Train Split**" |
227 | 279 | ]
|
228 | 280 | },
|
229 | 281 | {
|
|
246 | 298 | "cell_type": "markdown",
|
247 | 299 | "metadata": {},
|
248 | 300 | "source": [
|
249 |
| - "### Training" |
| 301 | + "## **Training**" |
250 | 302 | ]
|
251 | 303 | },
|
252 | 304 | {
|
|
283 | 335 | "cell_type": "markdown",
|
284 | 336 | "metadata": {},
|
285 | 337 | "source": [
|
286 |
| - "### Visualizing the Model" |
| 338 | + "## **Visualizing the Model**" |
287 | 339 | ]
|
288 | 340 | },
|
289 | 341 | {
|
|
342 | 394 | "cell_type": "markdown",
|
343 | 395 | "metadata": {},
|
344 | 396 | "source": [
|
345 |
| - "### Feature Importance" |
| 397 | + "## **Feature Importance**" |
346 | 398 | ]
|
347 | 399 | },
|
348 | 400 | {
|
|
393 | 445 | "cell_type": "markdown",
|
394 | 446 | "metadata": {},
|
395 | 447 | "source": [
|
396 |
| - "## **Wine Data**" |
| 448 | + "## **Decision Tree on Wine Data**" |
397 | 449 | ]
|
398 | 450 | },
|
399 | 451 | {
|
|
556 | 608 | }
|
557 | 609 | ],
|
558 | 610 | "source": [
|
| 611 | + "# Load the data\n", |
559 | 612 | "df = pd.read_csv('data/wine_data.csv')\n",
|
560 | 613 | "\n",
|
561 | 614 | "print(\"Shape:\", df.shape)\n",
|
|
630 | 683 | "metadata": {},
|
631 | 684 | "outputs": [],
|
632 | 685 | "source": [
|
| 686 | + "# Identify input and output\n", |
633 | 687 | "y = df['quality']\n",
|
634 | 688 | "\n",
|
635 | 689 | "X = df.drop('quality', axis=1)"
|
|
641 | 695 | "metadata": {},
|
642 | 696 | "outputs": [],
|
643 | 697 | "source": [
|
| 698 | + "# Split data into train and tests\n", |
644 | 699 | "from sklearn.model_selection import train_test_split\n",
|
645 | 700 | "\n",
|
646 | 701 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.25, random_state = 0)"
|
|
660 | 715 | }
|
661 | 716 | ],
|
662 | 717 | "source": [
|
| 718 | + "# Training a model and evaluation of learning\n", |
663 | 719 | "from sklearn.tree import DecisionTreeClassifier\n",
|
664 | 720 | "from sklearn import metrics\n",
|
665 | 721 | "\n",
|
|
688 | 744 | }
|
689 | 745 | ],
|
690 | 746 | "source": [
|
| 747 | + "# Visuallization of DT Model\n", |
691 | 748 | "from sklearn.tree import plot_tree\n",
|
692 | 749 | "\n",
|
693 | 750 | "plt.figure(figsize=(25,10))\n",
|
|
722 | 779 | }
|
723 | 780 | ],
|
724 | 781 | "source": [
|
| 782 | + "# Feature Importance\n", |
725 | 783 | "classifier.feature_importances_"
|
726 | 784 | ]
|
727 | 785 | },
|
|