import wandb
import torch
from torch import nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import json
import glob
import pandas as pd
import numpy as np

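# Run on the first GPU when available; cudnn benchmarking lets PyTorch pick the
# fastest convolution kernels for fixed-size inputs.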
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

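# Load hyperparameters from JSON and hand them to wandb, so the run can be
# reproduced from the logged config alone.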
with open('configs/config.json', 'r') as f:
    config = json.load(f)

wandb.init(config=config, project="kws-dlaudio")
config = wandb.config

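# Project-local modules: training/evaluation loops, dataset, model, and helpers.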
from utils import set_seed
from train import train_distill, evaluate
from inference import inference
from dataset import SpeechCommands
from model import KWSNet

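# Fix all RNG seeds so the split, shuffling, and weight init are reproducible.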
set_seed(config.random_seed)

print(device)

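# Index the Speech Commands corpus: each utterance is labeled 1 if its folder
# name matches the target keyword, 0 otherwise.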
paths = []
labels = []

for path in glob.glob('speech_commands/*/*.wav'):
    _, label, _ = path.split('/')
    paths.append(path)
    labels.append(int(label == config.target_class))

df = pd.DataFrame({'path': paths, 'label': labels})

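# Stratified 90/10 split keeps the (heavily imbalanced) positive/negative
# ratio identical in the train and validation sets.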
X_train, X_test, y_train, y_test = train_test_split(np.array(df['path']),
                                                    np.array(df['label']),
                                                    test_size=0.1,
                                                    stratify=np.array(df['label']),
                                                    random_state=config.random_seed)

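# Wrap the splits in the project's SpeechCommands dataset; pin_memory speeds up
# host-to-GPU transfers when training on CUDA.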
train_dataset = SpeechCommands(config, X_train, y_train)
test_dataset = SpeechCommands(config, X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
                          num_workers=config.dataloader_num_workers, pin_memory=True)
val_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,
                        num_workers=config.dataloader_num_workers, pin_memory=True)

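# Student: same KWSNet architecture as the teacher but with half the encoder
# width, i.e. the smaller model that distillation is meant to train.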
student_model = KWSNet(config.enc_hidden_size // 2, config.conv_out_channels, config.conv_kernel_size)
student_model = student_model.to(device)

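# Hard-label loss, Adam with weight decay, and a step decay on the learning rate.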
error = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(student_model.parameters(), lr=config.learning_rate,
                             weight_decay=config.weight_decay)

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.lr_scheduler_step_size,
                                               gamma=config.lr_scheduler_gamma)

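# Full-width teacher, restored from its checkpoint trained in a previous run.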
teacher_model = KWSNet(config.enc_hidden_size, config.conv_out_channels, config.conv_kernel_size)
teacher_model.load_state_dict(torch.load('checkpoints/teacher_model.pth', map_location=device))
teacher_model = teacher_model.to(device)
teacher_model.eval()  # the teacher only supplies soft targets, so disable dropout/batch-norm updates

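# alpha weighs the teacher's soft targets against the hard-label loss inside
# train_distill (defined in train.py); presumably the usual
# loss = alpha * distill_term + (1 - alpha) * cross_entropy blend.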
alpha = config.teacher_alpha

for epoch in range(config.num_epochs):
    train_distill(epoch, teacher_model, student_model, alpha, optimizer, error, train_loader, device)
    evaluate(student_model, optimizer, error, val_loader, device)
    lr_scheduler.step()

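# Split the held-out paths by label so one example of each class can be visualized.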
negative_val = []
positive_val = []

for path, label in zip(X_test, y_test):
    if label == 1:
        positive_val.append(path)
    else:
        negative_val.append(path)

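# Plot the student's detections on one positive and one negative validation
# clip, with noise added to make the test harder.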
path = positive_val[1]
inference('results/student_positive_example.png', student_model, path, noise=True, device=device)

path = negative_val[1]
inference('results/student_negative_example.png', student_model, path, noise=True, device=device)

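# Persist the distilled student and upload the checkpoint to the wandb run.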
torch.save(student_model.state_dict(), 'checkpoints/student_model.pth')
wandb.save('checkpoints/student_model.pth')