-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.c
More file actions
143 lines (113 loc) · 4.22 KB
/
main.c
File metadata and controls
143 lines (113 loc) · 4.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include "interface/tensor.h"
#include "interface/model.h"
#include "interface/loss.h"
#include "interface/train.h"
#include "interface/image_loader.h"
#include "interface/dataset.h"
/*
 * Build the demo network used by run_train:
 * three 3x3 conv + ReLU stages, then a 7x7 conv that collapses the
 * 32x7x7 feature map to 2x1x1, followed by softmax.
 * Returns NULL if model_create() fails.
 */
static Model* build_cat_dog_model(void) {
    // Input: 1x28x28
    Model* model = model_create();
    if (!model) return NULL;

    // Conv1: 1->8, 3x3, s=1, p=1 => 8x28x28
    model_add_conv(model, 1, 8, 3, 1, 1);
    model_add_relu(model);
    // Conv2: 8->16, 3x3, s=2, p=1 => 16x14x14
    model_add_conv(model, 8, 16, 3, 2, 1);
    model_add_relu(model);
    // Conv3: 16->32, 3x3, s=2, p=1 => 32x7x7
    model_add_conv(model, 16, 32, 3, 2, 1);
    model_add_relu(model);

    // There is no Flatten/FC layer, so a 7x7 kernel with no padding is used
    // instead: it reduces the 32x7x7 map to 2x1x1, which is equivalent to a
    // fully connected layer over the whole feature map.
    // Conv4 (FC equivalent): 32->2, 7x7, s=1, p=0 => 2x1x1
    model_add_conv(model, 32, 2, 7, 1, 0);
    // No ReLU here: the raw outputs go straight into softmax.
    model_add_softmax(model);
    return model;
}

/*
 * Train the demo two-class classifier on the dataset stored at dataset_path.
 * Parameters are updated after every sample; loss and accuracy are printed
 * after every epoch.
 *
 * dataset_path: file passed to dataset_load().
 * model_path:   where to save the trained model; when NULL, nothing is saved.
 *
 * The function returns without training, after printing a message to stderr,
 * if the dataset cannot be loaded, the dataset is empty, or the model cannot
 * be created.
 */
void run_train(const char* dataset_path, const char* model_path) {
    printf("Loading dataset from %s...\n", dataset_path);
    Dataset* d = dataset_load(dataset_path);
    if (!d) {
        fprintf(stderr, "Failed to load dataset from %s\n", dataset_path);
        return;
    }
    // The per-epoch averages divide by d->size; an empty dataset would print NaN.
    if (d->size <= 0) {
        fprintf(stderr, "Dataset %s contains no samples\n", dataset_path);
        dataset_free(d);
        return;
    }

    Model* model = build_cat_dog_model();
    if (!model) {
        fprintf(stderr, "Failed to create model\n");
        dataset_free(d);
        return;
    }
    printf("Model built. Starting training...\n");

    // Training loop
    const float lr = 0.01f;
    const int epochs = 50; // Small number for demo
    for (int epoch = 0; epoch < epochs; epoch++) {
        float total_loss = 0.0f;
        int correct = 0;
        for (int i = 0; i < d->size; i++) {
            Tensor* input = d->images[i];
            // NOTE(review): labels are assumed to be 0 or 1 (two-class output); confirm in dataset loader.
            int target[1] = {d->labels[i]};

            model_forward(model, input);
            // Read the cached output of the final (softmax) layer.
            Tensor* output = model->layers[model->num_layers - 1].output_cache;
            total_loss += cross_entropy_loss(output, target);

            // Predicted class = index of the larger output (ties go to class 1).
            int pred = output->data[0] > output->data[1] ? 0 : 1;
            if (pred == target[0]) correct++;

            // Backpropagate and update parameters after every single sample.
            cross_entropy_backward(output, target);
            model_backward(model, input);
            model_step(model, lr);
        }
        printf("Epoch %d: Loss = %.4f, Acc = %.2f%%\n", epoch, total_loss / d->size, (float)correct / d->size * 100.0f);
    }

    if (model_path) {
        model_save(model, model_path);
    }
    dataset_free(d);
    model_free(model);
}
/*
 * Classify a single image with a previously saved model.
 *
 * model_path: file passed to model_load().
 * image_path: file passed to tensor_load_image().
 *
 * Runs one forward pass and prints the cat and dog scores as percentages,
 * followed by the winning label. A tie is reported as DOG. If either load
 * fails, the function returns early and releases whatever was already loaded.
 */
void run_predict(const char* model_path, const char* image_path) {
    printf("Loading model from %s...\n", model_path);
    Model* net = model_load(model_path);
    if (net == NULL) {
        return;
    }

    printf("Loading image from %s...\n", image_path);
    Tensor* image = tensor_load_image(image_path);
    if (image != NULL) {
        model_forward(net, image);

        // Output slot 0 is read as the cat score, slot 1 as the dog score.
        Tensor* scores = net->layers[net->num_layers - 1].output_cache;
        const float cat = scores->data[0];
        const float dog = scores->data[1];

        printf("\nPrediction:\n");
        printf("Cat: %.2f%%\n", cat * 100.0f);
        printf("Dog: %.2f%%\n", dog * 100.0f);
        fputs(cat > dog ? "Result: CAT\n" : "Result: DOG\n", stdout);

        tensor_free(image);
    }
    model_free(net);
}
/* Print the full command-line usage summary to stderr. */
static void print_usage(const char* prog) {
    fprintf(stderr, "Usage:\n");
    fprintf(stderr, " Train: %s train <dataset_file> <model_save_path>\n", prog);
    fprintf(stderr, " Predict: %s predict <model_path> <image_file>\n", prog);
}

/*
 * Entry point: dispatch to run_train or run_predict based on argv[1].
 *
 * Returns 1 on a usage error (missing arguments or unknown command),
 * and 0 otherwise. run_train and run_predict return void, so failures
 * inside them still produce exit status 0.
 */
int main(int argc, char** argv) {
    // When argc == 0, argv[0] is a null pointer, and passing it to "%s" is
    // undefined behavior, so substitute a fixed program name.
    const char* prog = (argc > 0 && argv[0] != NULL) ? argv[0] : "main";

    // Seed rand(); presumably used for weight initialization elsewhere.
    srand((unsigned)time(NULL));

    if (argc < 2) {
        print_usage(prog);
        return 1;
    }

    const char* cmd = argv[1];
    if (strcmp(cmd, "train") == 0) {
        if (argc < 4) {
            fprintf(stderr, "Usage: %s train <dataset_file> <model_save_path>\n", prog);
            return 1;
        }
        run_train(argv[2], argv[3]);
    } else if (strcmp(cmd, "predict") == 0) {
        if (argc < 4) {
            fprintf(stderr, "Usage: %s predict <model_path> <image_file>\n", prog);
            return 1;
        }
        run_predict(argv[2], argv[3]);
    } else {
        fprintf(stderr, "Unknown command: %s\n", cmd);
        print_usage(prog);
        return 1;
    }
    return 0;
}