6 | 6 | from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score
7 | 7 | import seaborn as sns
8 | 8 | import matplotlib.pyplot as plt
9 | | -
| 9 | +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
10 | 10 |
11 | 11 | '''
12 | 12 | This portion uses code from a previous project, specifically this [notebook](https://github.com/DerikVo/DSI_project_4_plant_disease/blob/main/notebooks/01_Potato_PlantVillageEDA.ipynb).
@@ -184,3 +184,18 @@ def plot_confusion_matrix(confusion_matrix, class_paths, title):
184 | 184 |     plt.savefig(f'../Created_images/{title} confusion matrix.png')
185 | 185 |     # displays the image
186 | 186 |     plt.show()
| 187 | +
| 188 | +
| 189 | +
| 190 | +def model_metrics(true_classes, predicted_classes, title):
| 191 | +    '''
| 192 | +    Calculate accuracy, precision, recall, and F1 score (weighted averages).
| 193 | +    The title argument labels the index of the returned one-row DataFrame for the model being evaluated.
| 194 | +    '''
| 195 | +    accuracy = accuracy_score(true_classes, predicted_classes)
| 196 | +    precision = precision_score(true_classes, predicted_classes, average='weighted')
| 197 | +    recall = recall_score(true_classes, predicted_classes, average='weighted')
| 198 | +    f1 = f1_score(true_classes, predicted_classes, average='weighted')
| 199 | +    data = {'Accuracy': [accuracy], 'Precision': [precision], 'Recall': [recall], 'F1 Score': [f1]}
| 200 | +    df = pd.DataFrame(data, index=[f'{title}'])
| 201 | +    return df
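
For reference, a minimal usage sketch of the new `model_metrics` helper. The label arrays and the `'Baseline CNN'` title below are made up for illustration; it assumes `model_metrics` (defined above) is in scope and that `pandas` is imported as `pd` elsewhere in the module.

```python
# Illustrative only: hypothetical labels and model title, not project data.
# Assumes model_metrics (defined above) is in scope and pandas is imported as pd.
true_classes = [0, 0, 1, 1, 2, 2]        # ground-truth class indices
predicted_classes = [0, 1, 1, 1, 2, 2]   # classifier predictions for the same samples

metrics_df = model_metrics(true_classes, predicted_classes, title='Baseline CNN')
print(metrics_df)
# Prints a one-row DataFrame indexed by the title, with Accuracy, Precision,
# Recall, and F1 Score columns (precision/recall/F1 use weighted averaging).
```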