"""
_________EEG SCHIZOPHRENIA CLASSIFICATION_________
Oguzhan Memis January 2025
1-DATASET DESCRIPTION:
-------------------------------------------------------------------------------
EEG Dataset which contains 2 classes of EEG signals captured from adolescents.
-Classes: Normal (39 people) and Schizophrenia (45 people).
-Properties of the EEG data:
      *16 channels * 128 samples per second * 60 seconds of measurement for each person.
      *Voltages are captured in microvolts (µV, i.e. 10^-6 V).
      *So the amplitudes of the signals vary from about -2000 to +2000 µV.
-Orientation of the data in files:
*Signals are vertically placed into text files, ordered by channel number (1 to 16).
      *Length of one signal is 128*60 = 7680 samples.
      *So each text file contains 16*7680 = 122880 samples, stacked vertically.
SOURCE OF THE DATASET:
http://brain.bio.msu.ru/eeg_schizophrenia.htm
Original article of dataset: https://doi.org/10.1007/s10747-005-0042-z Physiology (Q4)
A recent article which uses this dataset: https://doi.org/10.1007/s11571-024-10121-0 Cognitive Neuroscience (Q2)
2-CODE ORGANIZATION:
-------------------------------------------------------------------------------
The code is divided into separate cells marked with #%%.
RUN EACH CELL ONE BY ONE, CONSECUTIVELY.
The cells are as follows:
1) Importing the data
2) Filtering stage (includes time and frequency plots)
3:
3.1) Visualization of all the healthy EEG channels together
3.2) Visualization of all the patient EEG channels together
4) Feature Examinations (including many statistical features on the signals)
5) Further explorations: Correlation matrix, and Recurrence plot
6) Multi-level Decomposition by DWT (examination)
7:
7.1) DWT Feature Extraction and Data Transformation
7.2) SVM Grid-search
7.3) SVM cross-validation
7.4) MLP model
7.5) Optional part: save the best model
7.6) MLP k-fold cross-validation
7.7) Leave One Out CV on the MLP
8:
8.1) STFT-Feature extraction method
8.2) STFT-MLP
8.3) STFT-SVM (Grid-search)
9:
9.1) STFT Data Transformation
9.2) STFT - CNN
10:
10.1) CWT Data Transformation
10.2) CWT - CNN
10.3) CNN k-fold cross-validation
10.4) Leave One Out CV on the CNN
3-CONSIDERATIONS:
-------------------------------------------------------------------------------
*Before running the classification models, consider the related data transformation/feature extraction methods
 and the input size (for the Deep Learning models).
*The DWT feature extraction method gives an output dataset of size (84, 16, 25);
 the data of every subject are then flattened into 16*25 = 400 features.
*Use different wavelets for the SVM and MLP models, such as 'bior2.8' or 'bior3.3' for the SVM.
*The first STFT feature extraction method gives an output dataset of size (84, 16, 325).
 It uses a downsampled and flattened STFT;
 the data of every subject are then flattened into 16*325 = 5200 features.
*In the second STFT method, the spectrograms of the signals are not flattened, and
 a dataset of size (84, 16, 513, 21) is obtained.
 The CNN model takes the input as 16-channel 513*21 matrices.
*In the last CWT method, the scalograms (downsampled along one axis) of the signals are collected
 into a resultant dataset of size (84, 16, 60, 1920).
 The CNN model takes the input as 16-channel 60*1920 matrices.
*All the MLP models are built with Keras,
 and all the CNN models are built with PyTorch (uses the GPU).
"""
#%% 1) Importing the data
import os
import numpy as np
import matplotlib.pyplot as plt
# Read the files in corresponding folders
path1 = "./normal"
path2 = "./schizophren"
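# (Sketch, assumes the two data folders sit next to this script) Fail early if they are missing:
# for p in (path1, path2):
#     if not os.path.isdir(p):
#         raise FileNotFoundError(f"EEG data folder not found: {p}")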
normals = np.zeros((39,122880)) # Preparing an empty matrix for collecting all the data of normal category.
i=0
# Iterate through each file in the directory
for filename in os.listdir(path1):
    file_path = os.path.join(path1, filename)
    normals[i, :] = np.loadtxt(file_path)
    i += 1
reshaped_normal = np.zeros((39, 16, 7680))
for person in range(39):
    for channel in range(16):
        start_index = channel * 7680
        end_index = (channel + 1) * 7680  # Will help to carefully separate the channels in correct order.
        reshaped_normal[person, channel, :] = normals[person, start_index:end_index]
# __________________________________For the patient (Schizophrenia) category
patients = np.zeros((45,122880))
j=0
for filename2 in os.listdir(path2):
    file_path2 = os.path.join(path2, filename2)
    patients[j, :] = np.loadtxt(file_path2)
    j += 1
reshaped_patient = np.zeros((45, 16, 7680))
for person2 in range(45):
    for channel2 in range(16):
        start_index2 = channel2 * 7680
        end_index2 = (channel2 + 1) * 7680
        reshaped_patient[person2, channel2, :] = patients[person2, start_index2:end_index2]
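# Optional sanity check (a sketch based on the dataset description in the header docstring):
# assert normals.shape == (39, 122880) and patients.shape == (45, 122880)
# assert reshaped_normal.shape == (39, 16, 7680)
# assert reshaped_patient.shape == (45, 16, 7680)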
#%% 2) Filtering stage
# For noise removal, the pre-processing stage should be done by signal processing.
# Initial investigations on frequency content of the signals
import scipy
import pywt
from scipy.signal import welch, butter, filtfilt, iirnotch
from scipy.fft import fft, fftfreq
from scipy.signal import stft
# Example signal is the F4 channel of the 31st healthy person.
signal = reshaped_normal[30,2,:]
fs = 128 # sampling frequency
frequency_resolution = fs/len(signal)
print(f"Frequency Resolution: {frequency_resolution:.2f}%")
time = np.linspace(0, 60, len(signal))
plt.figure(figsize=(50,8))
plt.title("1 channel of EEG")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude in uV")
plt.xlim(0,60)
plt.ylim(-1500, 1500)
plt.plot(time, signal)
# 1.1) FFT of the signal
fft_val = np.abs(fft(signal))
freqs = fftfreq(len(signal), d=1/fs) # Full frequency range (bilateral)
freqs = freqs[:len(freqs)//2] # Take only positive side
fft_val = fft_val[:len(freqs)] # Positive side of the bilateral FFT values
power_vals = fft_val**2 # linear power spectrum
power_vals = 10 * np.log10( power_vals + 1e-12) # for logarithmic power spectrum
plt.figure(figsize=(13,8))
plt.title("Linear-scale FFT")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Magnitude")
plt.plot(freqs, fft_val)
plt.figure(figsize=(13,8))
plt.title("Log-scale Power Spectrum")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Magnitudes in dB")
plt.plot(freqs, power_vals)
#2.1) Periodogram of the signal
f, psd = welch(signal, fs=fs, nperseg=512, noverlap=128, window='hamming')
plt.figure(figsize=(13, 8))
plt.plot(f, psd)
plt.xlabel('Frequency (Hz)'), plt.ylabel('Power (signal^2 /Hz)')
plt.title('Periodogram (PSD) by Welch method')
#3.1) Spectrogram of the signal
f_stft, t_stft, z_m = stft(signal, fs=fs, window='hann', nperseg=512, noverlap=128, nfft=1024)
plt.figure(figsize=(13, 8))
plt.pcolormesh(t_stft, f_stft, np.abs(z_m), shading='gouraud')
plt.title('Magnitude Spectrogram')
plt.ylabel('Frequency (Hz)')
plt.xlabel('Time (s)')
plt.colorbar(label='Magnitude')
plt.tight_layout()
plt.show()
#4.1) Scalogram of the signal
wavelist = pywt.wavelist(kind='continuous')  # The built-in CWT supports only a limited number of wavelets.
scales = np.geomspace(1, 60, 60)  # adjust after verifying the corresponding scale frequencies
scale_frequencies = pywt.scale2frequency("cmor1.5-10", scales)*fs  # always verify that the scales cover the intended frequency range
# The difference from the spectrogram comes from this uneven distribution of wavelet frequencies.
coefficients, frequencies = pywt.cwt(signal, scales, 'cmor1.5-10', 1/fs)
log_coeffs = np.log10(abs(coefficients) + 1) # prevent log(0)
plt.figure(figsize=(13, 8))
f_min = 0
f_max = fs/2
plt.imshow(log_coeffs, extent=[0, len(signal)/fs, f_min, f_max],    # just arranges the axis numbers
           aspect='auto', cmap='viridis', interpolation='bilinear')  # matrix to image
cbar = plt.colorbar()
cbar.set_label('Log10(Magnitude + 1)')
y_axis_labels = list(map(int, np.flip(frequencies))) # assign scale frequencies as y-axis numbers
step = 10 # step size to downsample the indexes
plt.yticks(
    ticks=np.arange(0, len(y_axis_labels), step),
    labels=y_axis_labels[::step]  # Select corresponding labels
)
plt.ylabel('Frequency (Hz)')
plt.xlabel('Time (s)')
plt.title('Log Magnitude Scalogram')
plt.show()
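# Quick check (sketch): print the frequency range actually covered by the chosen scales,
# so the scalogram's y-axis labels can be trusted.
print(f"CWT frequencies span {scale_frequencies.min():.2f} Hz to {scale_frequencies.max():.2f} Hz")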
#Filter 1: IIR Butterworth 4th Order Low-pass 60Hz
nyquist = fs/2
lowpass_cutoff = 60 / nyquist # Most of these built-in functions use normalized frequencies rather than absolute (real) frequencies
b1, a1 = butter(4, lowpass_cutoff, btype='low')
lpf_signal = filtfilt(b1, a1, signal)
#Filter 2: Notch filter at 50 Hz (the mains frequency in Russia, where the dataset was recorded)
notch_freq = 50 / nyquist
quality_factor = 30 # Quality factor for sharpness
b2, a2 = iirnotch(notch_freq, quality_factor)
notch_signal = filtfilt(b2, a2, lpf_signal)
#Filter 3: IIR Butterworth 8th Order High-pass 0.5 Hz
highpass_cutoff = 0.5 / nyquist
b3, a3 = butter(8, highpass_cutoff, btype='high')
filtered_signal = filtfilt(b3, a3, notch_signal)
# Comparison of examples. Same plotting steps repeated.
plt.figure(figsize=(50,8))
plt.title("Filtered channel of EEG")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude in uV")
plt.xlim(0,60)
plt.ylim(-1500, 1500)
plt.plot(time, filtered_signal)
# 1.2) FFT
fft_val = np.abs(fft(filtered_signal))
freqs = fftfreq(len(filtered_signal), d=1/fs)
freqs = freqs[:len(freqs)//2]
fft_val = fft_val[:len(freqs)]
power_vals = fft_val**2
power_vals = 10 * np.log10( power_vals + 1e-12)
plt.figure(figsize=(13,8))
plt.title("Linear-scale FFT")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Magnitude")
plt.plot(freqs, fft_val)
plt.figure(figsize=(13,8))
plt.title("Log-scale Power Spectrum")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Magnitudes in dB")
plt.plot(freqs, power_vals)
#2.2) Periodogram
f, psd = welch(filtered_signal, fs=fs, nperseg=512, noverlap=128, window='hamming')
plt.figure(figsize=(13, 8))
plt.plot(f, psd)
plt.xlabel('Frequency (Hz)'), plt.ylabel('Power (signal^2 /Hz)')
plt.title('Periodogram (PSD) by Welch method')
#3.2) Spectrogram
f_stft, t_stft, z_m = stft(filtered_signal, fs=fs, window='hann', nperseg=512, noverlap=128, nfft=1024)
plt.figure(figsize=(13, 8))
plt.pcolormesh(t_stft, f_stft, np.abs(z_m), shading='gouraud')
plt.title('Magnitude Spectrogram')
plt.ylabel('Frequency (Hz)')
plt.xlabel('Time (s)')
plt.colorbar(label='Magnitude')
plt.tight_layout()
plt.show()
#4.2) Scalogram
scales2 = np.geomspace(1, 60, 60)
scale_frequencies2 = pywt.scale2frequency("cmor1.5-10",scales2)*fs
coefficients2, frequencies2 = pywt.cwt(filtered_signal, scales2, 'cmor1.5-10', 1/fs)  # now on the filtered signal
log_coeffs2 = np.log10(abs(coefficients2) + 1) # prevent log(0)
plt.figure(figsize=(13, 8))
f_min = 0
f_max = fs/2
plt.imshow(log_coeffs2, extent=[0, len(filtered_signal)/fs, f_min, f_max],
           aspect='auto', cmap='viridis', interpolation='bilinear')
cbar = plt.colorbar()
cbar.set_label('Log10(Magnitude + 1)')
y_axis_labels2 = list(map(int, np.flip(frequencies2)))
step = 10
plt.yticks(
    ticks=np.arange(0, len(y_axis_labels2), step),
    labels=y_axis_labels2[::step]
)
plt.ylabel('Frequency (Hz)')
plt.xlabel('Time (s)')
plt.title('Log Magnitude Scalogram')
plt.show()
# No dramatic changes observed; the provided dataset was probably already pre-processed.
# __________________________Now, apply these filters to all channels____________________________________________
filtered_normal = np.zeros((39, 16, 7680))
filtered_patient = np.zeros((45, 16, 7680))
# a) Normals
for aa1 in range(39):
    for cc1 in range(16):
        filtered_normal[aa1, cc1, :] = filtfilt(b1, a1, reshaped_normal[aa1, cc1, :])   # LPF
for aa2 in range(39):
    for cc2 in range(16):
        filtered_normal[aa2, cc2, :] = filtfilt(b2, a2, filtered_normal[aa2, cc2, :])   # NOTCH
for aa3 in range(39):
    for cc3 in range(16):
        filtered_normal[aa3, cc3, :] = filtfilt(b3, a3, filtered_normal[aa3, cc3, :])   # HPF
# b) Patients
for aa1b in range(45):
    for cc1b in range(16):
        filtered_patient[aa1b, cc1b, :] = filtfilt(b1, a1, reshaped_patient[aa1b, cc1b, :])   # LPF
for aa2b in range(45):
    for cc2b in range(16):
        filtered_patient[aa2b, cc2b, :] = filtfilt(b2, a2, filtered_patient[aa2b, cc2b, :])   # NOTCH
for aa3b in range(45):
    for cc3b in range(16):
        filtered_patient[aa3b, cc3b, :] = filtfilt(b3, a3, filtered_patient[aa3b, cc3b, :])   # HPF
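# Note: the loops above could equivalently be replaced by vectorized calls, since
# scipy.signal.filtfilt filters along the last axis by default. A minimal sketch:
#
# def apply_filters(data):
#     out = filtfilt(b1, a1, data, axis=-1)   # low-pass
#     out = filtfilt(b2, a2, out, axis=-1)    # notch
#     out = filtfilt(b3, a3, out, axis=-1)    # high-pass
#     return out
#
# filtered_normal = apply_filters(reshaped_normal)
# filtered_patient = apply_filters(reshaped_patient)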
#%% 3.1) Visualization of all the healthy EEG
normal_person_number = 38
plt.figure(figsize=(75, 38))
for i in range(16):
    plt.subplot(16, 1, i+1)
    plt.xlim(0, 60)
    plt.ylim(-1500, 1500)
    plt.plot(time, filtered_normal[normal_person_number, i, :])
    plt.xticks(np.arange(0, 61, 10))
    plt.ylabel(f'Channel {i+1}')
plt.show()
#%% 3.2) Visualization of all the patient EEG
patient_person_number = 44
plt.figure(figsize=(75, 38))
for j in range(16):
    plt.subplot(16, 1, j+1)
    plt.xlim(0, 60)
    plt.ylim(-1500, 1500)
    plt.plot(time, filtered_patient[patient_person_number, j, :])
    plt.xticks(np.arange(0, 61, 10))
    plt.ylabel(f'Channel {j+1}')
plt.show()
#%% 4) Feature examinations
"""
Effective representation of the differences in the data plays an important role in reaching better results in Machine Learning.
Thus, feature engineering of these signals is a crucial step for obtaining greater insights.
Here, various statistical and engineering measurements will be evaluated.
"""
from scipy.stats import skew, kurtosis, entropy
from scipy.signal import find_peaks
# Definition of the custom function can be changed due to choice of features.
def extract_features(signal, fs=128):
    # Number of samples
    N = len(signal)
    # Time-domain features
    T1 = np.max(signal)
    T2 = np.min(signal)
    T3 = np.mean(signal)
    #T4 = np.var(signal)
    T5 = np.std(signal)
    T6 = np.mean(np.abs(signal - T3))                 # mean absolute deviation
    T7 = np.sqrt(np.mean(signal**2))                  # RMS
    T8 = np.mean(np.abs(np.diff(signal)))             # mean absolute first difference
    #T9 = np.sum(signal**2)
    T10 = np.ptp(signal)                              # peak-to-peak amplitude
    #T11 = np.sum(np.abs(np.diff(signal)))
    hist, bin_edges = np.histogram(signal, bins='auto', density=True)
    probabilities = hist / np.sum(hist)
    T12 = entropy(probabilities, base=2)              # amplitude-histogram entropy
    T13 = np.trapz(signal)                            # area under the signal
    T14 = np.corrcoef(signal[:-1], signal[1:])[0, 1]  # lag-1 autocorrelation
    T15 = np.sum(np.arange(N) * signal**2) / np.sum(signal**2)  # temporal centroid of energy
    peaks, _ = find_peaks(signal)
    T16 = len(peaks)                                  # number of peaks
    #T17 = np.sum(np.sqrt(1 + np.diff(signal)**2))
    #T18 = np.sum(signal**2) / N
    T19 = np.sum(signal[:-1] * signal[1:] < 0) / N    # zero-crossing rate
    T20 = skew(signal)
    T21 = kurtosis(signal)
    T22 = np.sum((signal[:-2] < signal[1:-1]) & (signal[1:-1] > signal[2:]))  # count of local maxima
    T23 = np.sum((signal[:-2] > signal[1:-1]) & (signal[1:-1] < signal[2:]))  # count of local minima
    T24 = np.max(signal) - np.min(signal)
    #T25 = np.sum(np.abs(np.diff(signal)))
    #T26 = np.sqrt(np.var(np.diff(signal)) / T4)
    #T27 = np.sqrt(np.var(np.diff(np.diff(signal))) / np.var(np.diff(signal))) / T26
    fft_val = np.abs(fft(signal))
    freqs = fftfreq(len(signal), d=1/fs)
    freqs = freqs[:len(freqs)//2]
    fft_val = fft_val[:len(freqs)]
    power_vals = fft_val**2
    power_vals = 10 * np.log10(power_vals + 1e-12)    # log-power spectrum (dB)
    # Frequency-domain features
    F1 = np.max(power_vals)
    F2 = freqs[np.argmax(power_vals)]                 # peak frequency
    cumulative_power = np.cumsum(power_vals)
    total_power = cumulative_power[-1]
    F3 = freqs[np.where(cumulative_power >= total_power / 2)[0][0]]   # median frequency
    F4 = np.sum(freqs * power_vals) / np.sum(power_vals)              # spectral centroid
    threshold = 0.5 * np.max(power_vals)
    low_cutoff = np.min(freqs[np.where(power_vals >= threshold)])
    high_cutoff = np.max(freqs[np.where(power_vals >= threshold)])
    F5 = high_cutoff - low_cutoff                     # bandwidth at half of the spectral maximum
    F6 = np.sum((power_vals - power_vals)**2)         # NOTE: evaluates to 0 as written; kept to preserve the feature-vector length
    epsilon = 1e-10
    norm_power_vals = (power_vals + epsilon) / (np.sum(power_vals) + epsilon)
    F7 = -np.sum(norm_power_vals * np.log2(norm_power_vals))                  # spectral entropy
    F8 = np.sqrt(np.sum((freqs - F4)**2 * power_vals) / np.sum(power_vals))   # spectral spread
    F9 = np.sum((freqs - F4)**3 * power_vals) / (F8**3 * np.sum(power_vals))  # spectral skewness
    F10 = np.sum((freqs - F4)**4 * power_vals) / (F8**4 * np.sum(power_vals)) # spectral kurtosis
    F11 = freqs[np.where(cumulative_power >= 0.95 * total_power)[0][0]]       # 95% spectral roll-off
    #F12 = freqs[np.where(cumulative_power >= 1.15 * total_power)[0][0]] if np.any(cumulative_power >= 1.15 * total_power) else None
    F13 = freqs[np.argmax(power_vals)]
    F14 = np.sum(power_vals[(freqs >= 0.6) & (freqs <= 2.5)]) / np.sum(power_vals)  # relative power in 0.6-2.5 Hz
    H_high = np.max(power_vals[freqs >= 0.5 * fs / 2])
    H_low = np.min(power_vals[freqs >= 0.5 * fs / 2])
    K = 1
    F15 = K * (H_high / H_low)                        # max/min ratio in the upper half of the spectrum
    F16 = 1 - np.sum(freqs * power_vals * power_vals) / (
        np.sqrt(np.sum(freqs * power_vals) * np.sum(freqs * power_vals))
    )
    delta_power = np.sum(power_vals[(freqs >= 0.5) & (freqs <= 4)])
    theta_power = np.sum(power_vals[(freqs > 4) & (freqs <= 8)])
    alpha_power = np.sum(power_vals[(freqs > 8) & (freqs <= 13)])
    beta_power = np.sum(power_vals[(freqs > 13) & (freqs <= 30)])
    total_power = np.sum(power_vals)
    F17 = delta_power / total_power                   # relative EEG band powers
    F18 = theta_power / total_power
    F19 = alpha_power / total_power
    F20 = beta_power / total_power
    # Time-frequency domain features
    W1 = pywt.wavedec(signal, 'db4', level=5)
    W2 = [np.sum(np.square(c)) for c in W1]           # energy per wavelet band
    #TF1 = np.sum(W2)
    W3 = np.sum(W2)
    W4 = entropy(np.array(W2) / W3, base=2)           # wavelet energy entropy
    TF2 = W4
    TF3 = [W2[i] / W2[i + 1] for i in range(len(W2) - 1)]  # consecutive band-energy ratios
    #TF4 = [np.var(d) for d in W1]
    # Combine all features into a single 1D array
    features = np.array([
        T1, T2, T3, T5, T6, T7, T8, T10, T12, T13, T14, T15, T16, T19,
        T20, T21, T22, T23, T24, F1, F2, F3, F4, F5, F6, F7,
        F8, F9, F10, F11, F13, F14, F15, F16, F17, F18, F19, F20, TF2, TF3[0], TF3[1],
        TF3[2], TF3[3], TF3[4]
    ])
    return features
healthy_features = np.zeros((39, 16, 44)) # consider the size of the feature vector here
patient_features = np.zeros((45, 16, 44))
for u in range(39):
    for p in range(16):
        healthy_features[u, p, :] = extract_features(filtered_normal[u, p, :])
for m in range(45):
    for n in range(16):
        patient_features[m, n, :] = extract_features(filtered_patient[m, n, :])
feature_set = np.concatenate((healthy_features,patient_features))
# Choose example channels for examination
feature1 = feature_set[30,2,:]
feature2 = feature_set[74,2,:]
plt.figure(figsize=(10,10))
plt.ylabel("Feature Values")
plt.xlabel("Feature Number")
plt.title("Feature Vector")
plt.plot(feature1)
plt.show()
plt.figure(figsize=(10,10))
plt.ylabel("Feature Values")
plt.xlabel("Feature Number")
plt.title("Feature Vector")
plt.plot(feature2)
plt.show()
#%% 5) Further explorations: Connectivity related plots
import seaborn as sns
from scipy.spatial.distance import pdist, squareform
# 1) Correlation matrix between the channels, for a specific individual
eeg1 = filtered_normal[30, :, :] # 31st healthy person's channels
eeg2 = filtered_patient[39,:,:] # 40th patient channels
correlation_matrix1 = np.corrcoef(eeg1)
correlation_matrix2 = np.corrcoef(eeg2)
# Heatmap
plt.figure(figsize=(15, 10))
sns.heatmap(correlation_matrix1,
            cmap='coolwarm',
            center=0,
            annot=True,
            fmt='.2f',
            square=True)
plt.title('EEG Channel Correlation Matrix - Healthy')
plt.show()
plt.figure(figsize=(15, 10))
sns.heatmap(correlation_matrix2,
            cmap='coolwarm',
            center=0,
            annot=True,
            fmt='.2f',
            square=True)
plt.title('EEG Channel Correlation Matrix - Schizophrenia')
plt.show()
# 2) Recurrence Plot
"""
It is used for nonlinear time series analysis.
The plot shows how much the signal revisits its previous states.
"""
signal1 = filtered_normal[30, 2, :] # F4 channel from previous plots.
signal2 = filtered_patient[39, 2, :]
signal_reshaped1 = signal1.reshape(-1, 1)
distance_matrix1 = squareform(pdist(signal_reshaped1)) # The distance matrix
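# Note: with 7680 samples, squareform(pdist(...)) builds a 7680 x 7680 distance matrix
# (roughly 470 MB as float64), so this step needs a fair amount of RAM.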
threshold1 = 0.1 * np.max(distance_matrix1) # Threshold 10% of the maximum distance
recurrence_matrix1 = distance_matrix1 < threshold1 # recurrence matrix
plt.figure(figsize=(15, 10))
plt.imshow(recurrence_matrix1, cmap='binary', interpolation='nearest')
plt.colorbar(label='Recurrence')
plt.xlabel('Time Index')
plt.ylabel('Time Index')
plt.title('Recurrence Plot of Channel F4 - Healthy')
plt.tight_layout()
plt.show()
signal_reshaped2 = signal2.reshape(-1, 1)
distance_matrix2 = squareform(pdist(signal_reshaped2))
threshold2 = 0.1 * np.max(distance_matrix2)
recurrence_matrix2 = distance_matrix2 < threshold2
plt.figure(figsize=(15, 10))
plt.imshow(recurrence_matrix2, cmap='binary', interpolation='nearest')
plt.colorbar(label='Recurrence')
plt.xlabel('Time Index')
plt.ylabel('Time Index')
plt.title('Recurrence Plot of Channel F4 - Schizophrenia')
plt.tight_layout()
plt.show()
#%% 6) Multi-level Decomposition by DWT
A5, D5, D4, D3, D2, D1 = pywt.wavedec(signal1, 'db19', level=5) # Perform wavelet decomposition
"""
________Multi-level Decomposition by DWT____________
Original Signal [0-128Hz]
___|________
Frequency content | |
halves in each step | |
A1[0-64Hz] D1[64-128Hz]
___|_______
Also the | |
downsampling | |
occurs in A2[0-32Hz] D2[32-64Hz]
each step ___|________
| |
| |
A3[0-16Hz] D3[16-32Hz]
___|_________
| |
| |
A4[0-8Hz] D4[8-16Hz]
___|___________
| |
| |
A5[0-4Hz] D5[4-8Hz]
Here A5 can be thought as delta, D5 is theta,
D4 is alpha, D3 is beta, and D2 is gamma waves.
"""
# Visualizations of decomposed bands
plt.figure(figsize=(15, 15))
plt.subplot(3, 2, 1)
plt.plot(D1)
plt.ylim(-2000, 2000)
plt.title("D1 [64-128Hz]")
plt.subplot(3, 2, 2)
plt.plot(D2)
plt.ylim(-2000, 2000)
plt.title("D2 [32-64Hz]")
plt.subplot(3, 2, 3)
plt.plot(D3)
plt.ylim(-2000, 2000)
plt.title("D3 [16-32Hz]")
plt.subplot(3, 2, 4)
plt.plot(D4)
plt.ylim(-2000, 2000)
plt.title("D4 [8-16Hz]")
plt.subplot(3, 2, 5)
plt.plot(D5)
plt.ylim(-2000, 2000)
plt.title("D5 [4-8Hz]")
plt.subplot(3, 2, 6)
plt.plot(A5)
plt.ylim(-2000, 2000)
plt.title("A5 [0-4Hz]")
plt.show()
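# Sketch (optional): the coefficient arrays plotted above are at decimated lengths.
# To inspect each band at the original sampling rate, each level can be reconstructed
# separately with pywt.waverec by zeroing all other coefficients (same 'db19', level=5
# decomposition as above):
#
# coeffs = pywt.wavedec(signal1, 'db19', level=5)
# names = ['A5', 'D5', 'D4', 'D3', 'D2', 'D1']
# band_signals = {}
# for k, name in enumerate(names):
#     kept = [c if idx == k else np.zeros_like(c) for idx, c in enumerate(coeffs)]
#     band_signals[name] = pywt.waverec(kept, 'db19')[:len(signal1)]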
#%% 7.1) DWT Feature Extraction and Data Transformation
"""
After all the pre-processing and the observations, now the data should be prepared for the model input.
Many Feature Extraction, Data Transformation or Dimensionality Reduction methods can be applied such as:
Multi-level Decomposition by DWT, Time and Frequency statistics, EMD, PCA, Spectrogram, Recurrence Plot or similar.
"""
from scipy import stats, integrate
dataset = np.concatenate((filtered_normal, filtered_patient))
#1) Extract statistical features from DWT decomposition coefficients
def extract_dwt_features(x):
    dwt = pywt.wavedec(x, 'db19', level=5)
    A5, D5, D4, D3, D2, D1 = dwt
    energies = [np.sum(np.square(ii)) for ii in dwt]  # energy of each wavelet band
    total = np.sum(energies)  # Total energy
    # Relative energy between consecutive wavelet bands
    bandratio = [energies[i] / energies[i + 1] for i in range(len(energies) - 1)]
    features = []
    # Average area under the absolute coefficients of the five retained bands
    # (constant over the loop below, so it is computed once here).
    # On newer SciPy versions, integrate.simps has been renamed to integrate.simpson.
    avg_integrate = (integrate.simps(np.abs(A5)) + integrate.simps(np.abs(D5)) +
                     integrate.simps(np.abs(D4)) + integrate.simps(np.abs(D3)) +
                     integrate.simps(np.abs(D2))) / 5
    # Process each coefficient band (A5, D5, D4, D3, D2)
    for band in [A5, D5, D4, D3, D2]:
        freqs, psd = welch(band, fs=128)
        # abs(band) / np.mean(abs(band))
        features.extend([
            (np.sum(freqs * psd) / np.sum(psd)) / bandratio[4],              # scaled spectral centroid
            kurtosis(band),
            (freqs[np.cumsum(psd) >= 0.5 * np.sum(psd)][0]) / bandratio[4],  # scaled median frequency
            integrate.simps(np.abs(band)) / avg_integrate                    # band area relative to the average
        ])
    features.extend([
        bandratio[0],  # Band energy ratios
        bandratio[1],
        bandratio[2],
        bandratio[3],
        bandratio[4]
    ])
    return np.array(features)
example1 = extract_dwt_features(signal1)
example2 = extract_dwt_features(signal2)
plt.figure(figsize=(10,10))
plt.ylabel("Feature Values")
plt.xlabel("Feature Number")
plt.title("Feature Vector")
plt.ylim(-1, 16)
plt.plot(example1)
plt.show()
plt.figure(figsize=(10,10))
plt.ylabel("Feature Values")
plt.xlabel("Feature Number")
plt.title("Feature Vector")
plt.ylim(-1, 16)
plt.plot(example2)
plt.show()
#2) Apply the transformation to whole data
dwt_data = np.zeros((84, 16, 25)) # consider the size of the feature vector
for ii in range(84):
    for jj in range(16):
        dwt_data[ii, jj, :] = extract_dwt_features(dataset[ii, jj, :])
#%% 7.2) SVM Grid Search
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import minmax_scale , StandardScaler
from sklearn.metrics import accuracy_score , f1_score , recall_score, precision_score , confusion_matrix
from sklearn.model_selection import GridSearchCV
# Matrices are flattened to a 1D vector
x = dwt_data.reshape(84, -1)
y = np.array([0] * 39 + [1] * 45) # Binary labels
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=42, stratify=y)
mymodel = svm.SVC(probability=True) # Classifier model
mymodel.get_params() # Look for which parameters can be chosen
'''
Grid search simply iterates over every combination of the parameter values defined below, to obtain the best-performing model.
It also uses cross-validation internally, and a parameter grid needs to be defined.
'''
# Define the parameter grid
param_grid = {
    'C': [0.1, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5, 6, 6.5, 7, 7.5, 8, 8.5, 9],
    'kernel': ['linear', 'rbf', 'sigmoid', 'poly'],
    'gamma': ['scale', 'auto'],