Skip to content

Commit

Permalink
dll bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zikai committed Nov 14, 2024
1 parent a549a80 commit 3bfb55a
Show file tree
Hide file tree
Showing 19 changed files with 716 additions and 141 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore

Wearable/
*.csv
*.mat
*.txt

# User-specific files
*.rsuser
Expand Down
2 changes: 1 addition & 1 deletion Cheby1Filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Cheby1Filter::Cheby1Filter(){}

// type: Lowpass, Highpass, bandPass, bandStop
Cheby1Filter::Cheby1Filter(int order, int ripple, double wn1, double wn2, double srate, char type) {
Cheby1Filter::Cheby1Filter(int order, double ripple, double wn1, double wn2, double srate, char type) {
order_ = order;
ripple_ = ripple;
srate_ = srate;
Expand Down
4 changes: 2 additions & 2 deletions Cheby1Filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
#define M_PI 3.1415926

class Cheby1Filter {
public:
public:
Eigen::Tensor<double, 1> b_;
Eigen::Tensor<double, 1> a_;
Cheby1Filter();
Cheby1Filter(int order, int ripple, double wn1, double wn2, double srate, char type);
Cheby1Filter(int order, double ripple, double wn1, double wn2, double srate, char type);

private:
int order_;
Expand Down
4 changes: 3 additions & 1 deletion Preprocess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ Preprocess::Preprocess(int s_rate, int subbands, int electrodes, int num_samples
// Lowpass, Highpass, bandPass, bandStop
bsf_ = std::make_unique<Cheby1Filter>(4, 2, 47, 53, s_rate_, 's');
bpf_ = std::make_unique<Cheby1Filter[]>(subbands);
std::vector<int> passband = { 6, 14, 22, 30, 38, 46, 54, 62, 70, 78 };

for (int i = 0; i < subbands; i++) {
bpf_[i] = Cheby1Filter(4, 1, 9 * (i + 1), 90, s_rate_, 'p');
//bpf_[i] = Cheby1Filter(4, 1, 9 * (i + 1), 90, s_rate_, 'p');
bpf_[i] = Cheby1Filter(6, 0.5, passband[i], 90, s_rate_, 'p');
}
}

Expand Down
271 changes: 241 additions & 30 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,41 +1,252 @@
## TRCA-cpp-src
- Cheby1Filter类: 带通、带阻滤波器设计

- Preprocess类: notch滤波、filterBank、detrend、均值计算、标准差计算

- TRCA类: TRCA算法实现
- utils: 包含数据函数、调试函数,其中一部分数据函数应该放到别的类里面
- vstudio默认gb2312编码, vs code默认utf8编码, 大部分含中文注释代码的文件已经转换为utf8编码

## TRCA-dll
- 具体实现参照[dll.cpp](./dll.cpp), python调用参照[dllvalid.py](./dllvalid.py)
- utils: 包含数据函数、调试函数
- TODO: 其中一部分数据函数应该放到别的类里面

- vstudio默认gb2312编码,vs code默认utf8编码,大部分含中文注释代码的文件已经转换为utf8编码

## TRCA.dll
- 具体实现参照[dll.cpp](./dll.cpp),python调用参照[dllValid.py](./dllValid.py)[onlineValid.py](./onlineValid.py)

### FilterBank
完成notch滤波、filterBank滤波(最多10组6阶带通,[{6, 14, 22, 30, 38, 46, 54, 62, 70, 78},90])、去直流、标准化操作,滤波器设计为Cheby1

- 输入(2指针、7int)
- darray: double指针(行优先的4D数组,[轮数,刺激数,电极数,信号点数]),一般情况下,轮数=1,刺激数=trials数即可

- dout: double指针(行优先的4D数组,[轮数*刺激数,filterBank数,电极数,信号点数]

- s_rate: int(采样率)

- subbands: int(filterBank数量)

- len: int(轮数)

- stimulus: int(刺激数)

- electrodes: int(电极数)

- num_samples: int(信号点数)

- debug: int,传入1进入调试,将filterBank数据覆盖写入路径下csv文件

- 输出
- dout: 指针拷贝数据

- 返回错误码(还没做,下同)

### TrcaTrain
- 输入
- darray: double指针(行优先的4D数组, [训练轮数, 刺激数, 电极数, 信号点数])
- pTemplate: double指针(4D数组, [刺激数, filterBank数, 电极数, 信号点数])
- pU: double指针(4D数组, [filterBank数量, 刺激数, 电极数, 1])
- s_rate: int(采样率)
- subbands: int(filterBank数量)
- train_len: int(训练轮数)
- stimulus: int(刺激数)
- electrodes: int(电极数)
- num_samples: int(信号点数)
完整的TRCA训练流程,filterBank+TrcaTrainOnly

- 输入(3指针、7int)
- darray: double指针(行优先的4D数组,[训练轮数,刺激数,电极数,信号点数]

- pTemplate: double指针(4D数组,[刺激数,filterBank数,电极数,信号点数]),一般不考虑该数组的存储主序,因为会直接输入给Test使用,当前输出的是列优先的数据,**下同**

- pU: double指针(4D数组,[filterBank数量,刺激数,电极数,1]),一般不考虑该数组的存储主序,因为会直接输入给Test使用,当前输出的是列优先的数据,**下同**

- s_rate: int(采样率)

- subbands: int(filterBank数量)

- train_len: int(训练轮数)

- stimulus: int(刺激数)

- electrodes: int(电极数)

- num_samples: int(信号点数)

- debug: int,传入1时将templates和u覆盖写入路径下csv文件,传入2时将input、filterBank数据、templates和u覆盖写入路径下csv文件

- 输出
- 计算得到的template和U通过memcpy的方式copy到pTemplate和pU地址上
- 返回错误码(还没做)
- pTemplate: 指针拷贝数据

- pU: 指针拷贝数据

- 返回错误码


### TrcaTrainOnly
只执行Trca训练功能

- 输入(3指针、7int)
- darray: double指针(行优先的4D数组,[训练轮数*刺激数,filterBank数,电极数,信号点数]),即filterBank完成后的数据

- pTemplate: double指针(4D数组,[刺激数,filterBank数,电极数,信号点数]

- pU: double指针(4D数组,[filterBank数量,刺激数,电极数,1]

- s_rate: int(采样率)

- subbands: int(filterBank数量)

- train_len: int(训练轮数)

- stimulus: int(刺激数)

- electrodes: int(电极数)

- num_samples: int(信号点数)

- debug: int,传入1时将templates和u覆盖写入路径下csv文件,传入2时将input、filterBank数据、templates和u覆盖写入路径下csv文件

- 输出
- pTemplate: 指针拷贝数据

- pU: 指针拷贝数据

- 返回错误码

### TrcaTest
- 输入
- darray: double指针(行优先的3D数组, [测试次数, 电极数, 信号点数])
- pTemplate: train得到的指针
- pU: train得到的指针
- pPred: int指针(1D数组, [测试次数])
- s_rate: 同上
- subbands: 同上
- stimulus: 同上
- electrodes: 同上
- num_samples: 同上
完整的TRCA测试流程,filterBank+TrcaTestOnly

- 输入(5指针、7int)
- darray: double指针(行优先的4D数组,[1,测试次数,电极数,信号点数]

- pTemplate: double指针,train得到的指针

- pU: double指针,train得到的指针

- pcoeff: double指针(1D数组,[测试次数*stimulus]

- pPred: int指针(1D数组,[测试次数]

- s_rate: int(采样率)

- subbands: int(filterBank数量)

- test_len: int(训练轮数)

- stimulus: int(刺激数)

- electrodes: int(电极数)

- num_samples: int(信号点数)

- debug: int,传入1时进入调试,将templates、u和filterBank数据覆盖写入路径下csv文件

- 输出
- pPred: 指针拷贝数据

- pcoeff: 指针拷贝数据

- 返回错误码

### TrcaTestOnly
只执行Trca测试功能

- 输入(5指针、7int)
- darray: double指针(行优先的4D数组,[1*测试次数,filterBank数,电极数,信号点数]

- pTemplate: double指针,train得到的指针

- pU: double指针,train得到的指针

- pcoeff: double指针(1D数组,[测试次数*stimulus]

- pPred: int指针(1D数组,[测试次数]

- s_rate: int(采样率)

- subbands: int(filterBank数量)

- test_len: int(训练轮数)

- stimulus: int(刺激数)

- electrodes: int(电极数)

- num_samples: int(信号点数)

- debug: int,传入1时进入调试,将templates、u和输入数据覆盖写入路径下csv文件

- 输出
- pPred: 指针拷贝数据

- pcoeff: 指针拷贝数据

- 返回错误码

### TrcaTestCsv
完整的TRCA测试流程,filterBank+TrcaTestOnly,使用csv文件输入template和u

- 输入(5指针、7int)
- darray: double指针(行优先的4D数组,[1,测试次数,电极数,信号点数]

- pTemplate: char指针,存放templates的csv文件路径

- pU: char指针,存放u的csv文件路径

- pcoeff: double指针(1D数组,[测试次数*stimulus]

- pPred: int指针(1D数组,[测试次数]

- s_rate: int(采样率)

- subbands: int(filterBank数量)

- test_len: int(训练轮数)

- stimulus: int(刺激数)

- electrodes: int(电极数)

- num_samples: int(信号点数)

- debug: int,传入1时进入调试,将templates、u和filterBank数据覆盖写入路径下csv文件

- 输出
- 计算得到的标签通过memcpy方式拷贝到pPred
- 返回错误码(还没做)
- pPred: 指针拷贝数据

- pcoeff: 指针拷贝数据

- 返回错误码

### TrcaTestOnlyCsv
只执行Trca测试功能,使用csv文件输入template和u

- 输入(5指针、7int)
- darray: double指针(行优先的4D数组,[1,测试次数,电极数,信号点数]

- pTemplate: char指针,存放templates的csv文件路径

- pU: char指针,存放u的csv文件路径

- pcoeff: double指针(1D数组,[测试次数*stimulus]

- pPred: int指针(1D数组,[测试次数]

- s_rate: int(采样率)

- subbands: int(filterBank数量)

- test_len: int(训练轮数)

- stimulus: int(刺激数)

- electrodes: int(电极数)

- num_samples: int(信号点数)

- debug: int,传入1时进入调试,将templates、u和filterBank数据覆盖写入路径下csv文件

- 输出
- pPred: 指针拷贝数据

- pcoeff: 指针拷贝数据

- 返回错误码


## 测试
- dllvalid.py: 使用[SSVEP-AnaTool](https://github.com/pikipity/SSVEP-Analysis-Toolbox)测试, 需确认lib版本, 也可使用本repo中提供的lib
- 使用Wearable-SSVEP(wet)数据集测试, 其中使用SSVEPAnalysisToolbox库测试时, 需要确认库get_data方法截取的数据是否正确, 需比照lib代码和数据集说明
- dllvalid.py: 使用[SSVEP-AnaTool]https://github.com/pikipity/SSVEP-Analysis-Toolbox)测试,需确认toolbox版本,也可使用本repo中提供的toolbox

- 使用Wearable-SSVEP(wet)数据集测试,其中使用SSVEPAnalysisToolbox库测试时,需要确认库get_data方法截取的数据是否正确,需比照toolbox代码和数据集说明

- 测试数据重排请使用可控的循环实现,避免调用transpose等api
6 changes: 4 additions & 2 deletions TRCA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Trca::Trca(int subbands, int stimulus, int electrodes, int num_samples, int trai
}
else {
for (int i = 0; i < subbands_; i++) {
filter_banks_weights(i) = pow(i+1, -1.25) + 0.15;
filter_banks_weights(i) = pow(i+1, -1.25) + 0.25;
}
}
filter_banks_weights_ = filter_banks_weights;
Expand Down Expand Up @@ -86,9 +86,10 @@ Eigen::Tensor<double, 2> Trca::trcaU(const Eigen::Tensor<double, 3>& trials) con
}

Eigen::Tensor<int, 1> Trca::predict(const Eigen::Tensor<double, 4>& trials, const Eigen::Tensor<double, 4>& templates,
const Eigen::Tensor<double, 4>& U, const Eigen::Tensor<double, 4>& V) const {
const Eigen::Tensor<double, 4>& U, const Eigen::Tensor<double, 4>& V, std::vector<double> &coeff) const {
Eigen::Tensor<int, 1> pred_labels(trials.dimension(0));
Eigen::array<Eigen::IndexPair<int>, 1> product_dims = { Eigen::IndexPair<int>(1, 0) };

for (int i = 0; i < trials.dimension(0); i++) {
Eigen::Tensor<double, 2> r = tensor1to2(filter_banks_weights_).
contract(canoncorrWithUV(trials.chip<0>(i), templates, U, V), product_dims);
Expand All @@ -97,6 +98,7 @@ Eigen::Tensor<int, 1> Trca::predict(const Eigen::Tensor<double, 4>& trials, cons
if (r(j) == max_coeff(0)) {
pred_labels(i) = j;
}
coeff.push_back(r(j));
}
}
return pred_labels;
Expand Down
Binary file removed TRCA.dll
Binary file not shown.
4 changes: 2 additions & 2 deletions TRCA.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
class Trca {
public:
~Trca();
Trca(int subbands, int stimulus, int electrodes, int num_samples, int train_len=1, int fb_weights_type=1);
Trca(int subbands, int stimulus, int electrodes, int num_samples, int train_len=1, int fb_weights_type=0);
Eigen::Tensor<double, 4> fit(const Eigen::Tensor<double, 4>& trials, const Eigen::Tensor<double, 4>& templates);
Eigen::Tensor<int, 1> predict(const Eigen::Tensor<double, 4>& trials, const Eigen::Tensor<double, 4>& templates,
const Eigen::Tensor<double, 4>& U, const Eigen::Tensor<double, 4>& V) const;
const Eigen::Tensor<double, 4>& U, const Eigen::Tensor<double, 4>& V, std::vector<double>& coeff) const;

private:
int subbands_;
Expand Down
10 changes: 10 additions & 0 deletions TRCA.sln
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ VisualStudioVersion = 17.7.34024.191
MinimumVisualStudioVersion = 10.0.40219.1
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "TRCA", "TRCA.vcxproj", "{95A805F6-EFC1-4F9F-ACD5-C88D3F110DA2}"
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "test", "..\test\test.vcxproj", "{FB92ABE4-E904-46F7-9BAB-7B92CA3352A3}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|x64 = Debug|x64
Expand All @@ -21,6 +23,14 @@ Global
{95A805F6-EFC1-4F9F-ACD5-C88D3F110DA2}.Release|x64.Build.0 = Release|x64
{95A805F6-EFC1-4F9F-ACD5-C88D3F110DA2}.Release|x86.ActiveCfg = Release|Win32
{95A805F6-EFC1-4F9F-ACD5-C88D3F110DA2}.Release|x86.Build.0 = Release|Win32
{FB92ABE4-E904-46F7-9BAB-7B92CA3352A3}.Debug|x64.ActiveCfg = Debug|x64
{FB92ABE4-E904-46F7-9BAB-7B92CA3352A3}.Debug|x64.Build.0 = Debug|x64
{FB92ABE4-E904-46F7-9BAB-7B92CA3352A3}.Debug|x86.ActiveCfg = Debug|Win32
{FB92ABE4-E904-46F7-9BAB-7B92CA3352A3}.Debug|x86.Build.0 = Debug|Win32
{FB92ABE4-E904-46F7-9BAB-7B92CA3352A3}.Release|x64.ActiveCfg = Release|x64
{FB92ABE4-E904-46F7-9BAB-7B92CA3352A3}.Release|x64.Build.0 = Release|x64
{FB92ABE4-E904-46F7-9BAB-7B92CA3352A3}.Release|x86.ActiveCfg = Release|Win32
{FB92ABE4-E904-46F7-9BAB-7B92CA3352A3}.Release|x86.Build.0 = Release|Win32
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
Loading

0 comments on commit 3bfb55a

Please sign in to comment.