Skip to content

Commit

Permalink
trca dll release
Browse files Browse the repository at this point in the history
  • Loading branch information
zikai committed Nov 12, 2024
1 parent d929598 commit a549a80
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 40 deletions.
84 changes: 84 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
{
"files.associations": {
"memory": "cpp",
"algorithm": "cpp",
"array": "cpp",
"atomic": "cpp",
"bit": "cpp",
"cctype": "cpp",
"charconv": "cpp",
"chrono": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"compare": "cpp",
"complex": "cpp",
"concepts": "cpp",
"condition_variable": "cpp",
"cstddef": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"deque": "cpp",
"exception": "cpp",
"format": "cpp",
"forward_list": "cpp",
"fstream": "cpp",
"functional": "cpp",
"future": "cpp",
"initializer_list": "cpp",
"iomanip": "cpp",
"ios": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"iterator": "cpp",
"limits": "cpp",
"list": "cpp",
"locale": "cpp",
"map": "cpp",
"mutex": "cpp",
"new": "cpp",
"numeric": "cpp",
"optional": "cpp",
"ostream": "cpp",
"queue": "cpp",
"random": "cpp",
"ranges": "cpp",
"ratio": "cpp",
"set": "cpp",
"span": "cpp",
"sstream": "cpp",
"stdexcept": "cpp",
"stop_token": "cpp",
"streambuf": "cpp",
"string": "cpp",
"system_error": "cpp",
"thread": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"typeinfo": "cpp",
"unordered_map": "cpp",
"unordered_set": "cpp",
"utility": "cpp",
"valarray": "cpp",
"vector": "cpp",
"xfacet": "cpp",
"xhash": "cpp",
"xiosbase": "cpp",
"xlocale": "cpp",
"xlocbuf": "cpp",
"xlocinfo": "cpp",
"xlocmes": "cpp",
"xlocmon": "cpp",
"xlocnum": "cpp",
"xloctime": "cpp",
"xmemory": "cpp",
"xstring": "cpp",
"xtr1common": "cpp",
"xtree": "cpp",
"xutility": "cpp"
}
}
2 changes: 1 addition & 1 deletion Preprocess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Preprocess::Preprocess(int s_rate, int subbands, int electrodes, int num_samples
subbands_ = subbands;
electrodes_ = electrodes;
num_samples_ = num_samples;
// 写死参数,可以改为config.ini配置.
// 写死参数,可以改为config.ini配置.
// Lowpass, Highpass, bandPass, bandStop
bsf_ = std::make_unique<Cheby1Filter>(4, 2, 47, 53, s_rate_, 's');
bpf_ = std::make_unique<Cheby1Filter[]>(subbands);
Expand Down
48 changes: 40 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,41 @@
### TRCA-cpp
- Cheby1Filter类: 带通、带阻滤波器设计
- Preprocess类: notch滤波、filterBank、detrend、均值计算、标准差计算
- TRCA类: TRCA算法实现
- utils: 包含数据函数、调试函数,其中一部分数据函数应该放到别的类里面
## TRCA-cpp-src
- Cheby1Filter类: 带通、带阻滤波器设计
- Preprocess类: notch滤波、filterBank、detrend、均值计算、标准差计算
- TRCA类: TRCA算法实现
- utils: 包含数据函数、调试函数,其中一部分数据函数应该放到别的类里面
- vstudio默认gb2312编码, vs code默认utf8编码, 大部分含中文注释代码的文件已经转换为utf8编码

### 测试
- dllvalid.py: 使用[SSVEP-AnaTool](https://github.com/pikipity/SSVEP-Analysis-Toolbox)测试, 需确认lib版本, 也可使用本repo中提供的lib
- 使用Wearable-SSVEP(wet)数据集测试, 其中使用SSVEPAnalysisToolbox库测试时, 需要确认库get_data方法截取的数据是否正确, 需比照lib代码和数据集说明
## TRCA-dll
- 具体实现参照[dll.cpp](./dll.cpp), python调用参照[dllvalid.py](./dllvalid.py)
### TrcaTrain
- 输入
- darray: double指针(行优先的4D数组, [训练轮数, 刺激数, 电极数, 信号点数])
- pTemplate: double指针(4D数组, [刺激数, filterBank数, 电极数, 信号点数])
- pU: double指针(4D数组, [filterBank数量, 刺激数, 电极数, 1])
- s_rate: int(采样率)
- subbands: int(filterBank数量)
- train_len: int(训练轮数)
- stimulus: int(刺激数)
- electrodes: int(电极数)
- num_samples: int(信号点数)
- 输出
- 计算得到的template和U通过memcpy的方式copy到pTemplate和pU地址上
- 返回错误码(还没做)
### TrcaTest
- 输入
- darray: double指针(行优先的3D数组, [测试次数, 电极数, 信号点数])
- pTemplate: train得到的指针
- pU: train得到的指针
- pPred: int指针(1D数组, [测试次数])
- s_rate: 同上
- subbands: 同上
- stimulus: 同上
- electrodes: 同上
- num_samples: 同上
- 输出
- 计算得到的标签通过memcpy方式拷贝到pPred
- 返回错误码(还没做)

## 测试
- dllvalid.py: 使用[SSVEP-AnaTool](https://github.com/pikipity/SSVEP-Analysis-Toolbox)测试, 需确认lib版本, 也可使用本repo中提供的lib
- 使用Wearable-SSVEP(wet)数据集测试, 其中使用SSVEPAnalysisToolbox库测试时, 需要确认库get_data方法截取的数据是否正确, 需比照lib代码和数据集说明
Binary file modified TRCA.dll
Binary file not shown.
3 changes: 0 additions & 3 deletions TRCA.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,6 @@
<ClInclude Include="TRCA.h" />
<ClInclude Include="utils.h" />
</ItemGroup>
<ItemGroup>
<None Include="Readme.md" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
Expand Down
3 changes: 0 additions & 3 deletions TRCA.vcxproj.filters
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,4 @@
<Filter>头文件</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<None Include="Readme.md" />
</ItemGroup>
</Project>
24 changes: 12 additions & 12 deletions dll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ extern "C" __declspec(dllexport) int TrcaTrain(double* darray, double* pTemplate
Eigen::Tensor<double, 4> templates = Eigen::Tensor<double, 4>(stimulus, subbands, electrodes, num_samples);
Eigen::Tensor<double, 4> U_trca_ = Eigen::Tensor<double, 4>(subbands, stimulus, electrodes, 1);

//行优先->列优先:转置
//列优先->行优先:reshape,其中double需要cast到float上面才能reshape,
// reshape之后要赋值给tensor float变量之后才能cast到double,原因是reshape传回参数不能被cast解析
//行优先->列优先:转置
//列优先->行优先:reshape,其中double需要cast到float上面才能reshape,
// reshape之后要赋值给tensor float变量之后才能cast到double,原因是reshape传回参数不能被cast解析
Eigen::Tensor<double, 4, Eigen::RowMajor> dtensor = Eigen::TensorMap<Eigen::Tensor<double, 4, Eigen::RowMajor>>(
darray, train_len, stimulus, electrodes, num_samples);
Eigen::Tensor<double, 4> input = dtensor.swap_layout().shuffle(Eigen::array<int, 4>{3,2,1,0});

//@zikai 4维:(训练次数,目标数,电极通道数,单通道数据)
//@zikai train_trials init,优化trials分割
//@zikai 4维:(训练次数,目标数,电极通道数,单通道数据)
//@zikai train_trials init,优化trials分割
for (int block = 0; block < train_len; block++) {
for (int i = 0; i < stimulus; i++) {
Eigen::Tensor<double, 2> single_trial = pe->notch(input.chip(block, 0).chip(i, 0));
Expand All @@ -35,23 +35,23 @@ extern "C" __declspec(dllexport) int TrcaTrain(double* darray, double* pTemplate
}

extern "C" __declspec(dllexport) int TrcaTest(double* darray, double* pTemplate, double* pU, int* pPred,
int s_rate, int subbands, int stimulus, int electrodes, int num_samples)
int s_rate, int subbands, int test_len, int stimulus, int electrodes, int num_samples)
{
std::unique_ptr<Preprocess> pe = std::make_unique<Preprocess>(s_rate, subbands, electrodes, num_samples);
std::unique_ptr<Trca> te = std::make_unique<Trca>(subbands, stimulus, electrodes, num_samples);
Eigen::Tensor<double, 4> test_trial = Eigen::Tensor<double, 4>(stimulus, subbands, electrodes, num_samples);
Eigen::Tensor<double, 4> test_trial = Eigen::Tensor<double, 4>(test_len, subbands, electrodes, num_samples);

Eigen::Tensor<double, 4, Eigen::RowMajor> dtensor = Eigen::TensorMap<Eigen::Tensor<double, 4, Eigen::RowMajor>>(
darray, 1, stimulus, electrodes, num_samples);
Eigen::Tensor<double, 4> input = dtensor.swap_layout().shuffle(Eigen::array<int, 4>{3, 2, 1, 0});
Eigen::Tensor<double, 3, Eigen::RowMajor> dtensor = Eigen::TensorMap<Eigen::Tensor<double, 3, Eigen::RowMajor>>(
darray, test_len, electrodes, num_samples);
Eigen::Tensor<double, 3> input = dtensor.swap_layout().shuffle(Eigen::array<int, 3>{2, 1, 0});
Eigen::Tensor<double, 4> templates = Eigen::TensorMap<Eigen::Tensor<double, 4>>(
pTemplate, subbands, stimulus, electrodes, num_samples);
Eigen::Tensor<double, 4> U_trca = Eigen::TensorMap<Eigen::Tensor<double, 4>>(
pU, subbands, stimulus, electrodes, 1);

//@zikai test_trials init
for (int i = 0; i < stimulus; i++) {
Eigen::Tensor<double, 2> single_trial = pe->notch(input.chip(0, 0).chip(i, 0));
for (int i = 0; i < test_len; i++) {
Eigen::Tensor<double, 2> single_trial = pe->notch(input.chip(i, 0));
test_trial.chip<0>(i) = pe->filterBank(single_trial);
}

Expand Down
2 changes: 1 addition & 1 deletion dll.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
extern "C" __declspec(dllexport) int TrcaTrain(double* darray, double* pTemplate, double* pU,
int s_rate, int subbands, int train_len, int stimulus, int electrodes, int num_samples);
extern "C" __declspec(dllexport) int TrcaTest(double* darray, double* pTemplate, double* pU, int* pPred,
int s_rate, int subbands, int stimulus, int electrodes, int num_samples);
int s_rate, int subbands, int test_len, int stimulus, int electrodes, int num_samples);

#endif // DLL_H
24 changes: 14 additions & 10 deletions dllvalid.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,24 @@ def dll():
channels=ch_used,
sig_len=tw)
arr = np.array(X_test)
X_test = arr.reshape((1, 12, 8, 500))
Pred = np.empty((12), dtype=int)
arr = arr.reshape((1, 12, 8, 500)).squeeze()
ans = []
for i in range(0, 12):
X_test = arr[i, :, :]
Pred = np.empty((1), dtype=int)

pX_test = X_test.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
dPred = Pred.ctypes.data_as(ctypes.POINTER(ctypes.c_int))
dll.TrcaTest(pX_test, dTemplate, dU, dPred, 250, 5, 12, 8, 500)
Pred = np.ctypeslib.as_array(ctypes.cast(dPred, ctypes.POINTER(ctypes.c_int)), Pred.shape)
acc = cal_acc(Y_true=Y_test, Y_pred=Pred)
pX_test = X_test.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
dPred = Pred.ctypes.data_as(ctypes.POINTER(ctypes.c_int))
dll.TrcaTest(pX_test, dTemplate, dU, dPred, 250, 5, 1, 12, 8, 500)
Pred = np.ctypeslib.as_array(ctypes.cast(dPred, ctypes.POINTER(ctypes.c_int)), Pred.shape)
ans.append(Pred[0])
acc = cal_acc(Y_true=Y_test, Y_pred=ans)

return Pred, acc
return ans, acc


RUN_TEST = 1
RUN_ORI = 1
RUN_ORI = 0
RUN_BOTH = 0

if __name__ == '__main__':
Expand All @@ -151,7 +155,7 @@ def dll():
dllAcc=0
oriTime=0
oriAcc=0
for sub_idx in range(54, 55):
for sub_idx in range(1, 101):
print(sub_idx)

if RUN_TEST or RUN_BOTH:
Expand Down
4 changes: 2 additions & 2 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
/*void trcaTest(int subject, std::string path);*/

int main() {
//@zikai 23.12.2 内存释放问题
//@zikai 23.12.2 内存释放问题
/*getchar();
for (int subject = 20; subject < 30; subject++) {
std::cout << "S0" << subject << std::endl;
Expand All @@ -25,7 +25,7 @@ void trcaTest(int subject, std::string path) {
pe = new PreprocessEngine(data);
te = new TrcaEngine(data);
//@zikai 4维:(训练次数,目标数,电极通道数,单通道数据)
//@zikai 4维:(训练次数,目标数,电极通道数,单通道数据)
auto start = std::chrono::high_resolution_clock::now();
for (int block = 0; block < data->train_len_; block++) {
for (int stimulus = 0; stimulus < data->stimulus_; stimulus++) {
Expand Down

0 comments on commit a549a80

Please sign in to comment.