-
Notifications
You must be signed in to change notification settings - Fork 572
/
Copy pathtrain.py
47 lines (34 loc) · 1.29 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from joblib import dump
from pathlib import Path
import numpy as np
import pandas as pd
from skimage.io import imread_collection
from skimage.transform import resize
from sklearn.linear_model import SGDClassifier
def load_images(data_frame, column_name):
    """Open the images whose file paths are listed in a DataFrame column.

    Args:
        data_frame: Table that holds the image file paths.
        column_name: Name of the column containing those paths.

    Returns:
        The collection produced by ``imread_collection`` for the listed
        paths, in column order.
    """
    paths = data_frame[column_name].to_list()
    return imread_collection(paths)
def load_labels(data_frame, column_name):
    """Return the values of one DataFrame column as a plain Python list.

    Args:
        data_frame: Table that holds the labels.
        column_name: Name of the column containing the labels.

    Returns:
        A list with one entry per row, in row order.
    """
    return data_frame[column_name].to_list()
def preprocess(image):
    """Resize an image to 100x100x3 and flatten it into a single row.

    Args:
        image: Image array to hand to ``resize``.

    Returns:
        The resized image reshaped to ``(1, 30000)``, i.e. one row holding
        all 100 * 100 * 3 values.
    """
    target_shape = (100, 100, 3)
    row_shape = (1, 100 * 100 * 3)  # (1, 30000): one sample per row
    return resize(image, target_shape).reshape(row_shape)
def load_data(data_path):
    """Build a feature matrix and label list from a CSV of image records.

    The CSV is read with ``pandas.read_csv`` and must provide a ``label``
    column and a ``filename`` column.

    Args:
        data_path: Location of the CSV file.

    Returns:
        A ``(data, labels)`` tuple. ``data`` is the preprocessed images
        joined along axis 0 (one row per image); ``labels`` is the list of
        values from the ``label`` column.
    """
    frame = pd.read_csv(data_path)
    # Labels are read before the images, so a missing ``label`` column is
    # reported before any image is opened.
    labels = load_labels(data_frame=frame, column_name="label")
    images = load_images(data_frame=frame, column_name="filename")
    rows = list(map(preprocess, images))
    return np.concatenate(rows, axis=0), labels
def main(repo_path):
    """Train an SGD classifier on the prepared data and save it to disk.

    Reads ``data/prepared/train.csv`` below ``repo_path``, fits an
    ``SGDClassifier`` with ``max_iter=100`` on it, and writes the fitted
    model to ``model/model.joblib`` below ``repo_path``.

    Args:
        repo_path: ``pathlib.Path`` of the repository root; it is combined
            with relative paths through the ``/`` operator.
    """
    train_csv_path = repo_path / "data/prepared/train.csv"
    train_data, labels = load_data(train_csv_path)

    sgd = SGDClassifier(max_iter=100)
    trained_model = sgd.fit(train_data, labels)

    model_path = repo_path / "model/model.joblib"
    # Make sure the output directory exists; otherwise the dump fails with
    # FileNotFoundError only after training has already completed.
    model_path.parent.mkdir(parents=True, exist_ok=True)
    dump(trained_model, model_path)
if __name__ == "__main__":
    # Use the directory that contains this script's own directory as the
    # repository root handed to main().
    script_dir = Path(__file__).parent
    repo_path = script_dir.parent
    main(repo_path)