diff --git a/include/AddTaskDispatcher.h b/include/AddTaskDispatcher.h deleted file mode 100644 index 197602a..0000000 --- a/include/AddTaskDispatcher.h +++ /dev/null @@ -1,28 +0,0 @@ -/************************************************************************* - > File Name: ADD_TaskDispatcher.h - > Author: - > Mail: - > Created Time: Mon 13 Apr 2015 03:23:50 AM EDT - ************************************************************************/ - -#ifndef GET_ADD_TASKDISPATCHER_H -#define GET_ADD_TASKDISPATCHER_H -#include "TaskDispatcher.h" - -template -class AddTaskDispatcher : public TaskDispatcher -{ - public: - ~AddTaskDispatcher() {}; - explicit AddTaskDispatcher(GET::TaskParam task_param, DeviceManager* device_manager):TaskDispatcher(task_param, device_manager) - { - add_param_ = task_param.add_param(); - } - - void PreTaskDispatch(); - void TaskDispatch(); - - protected: - GET::AddParam add_param_; -}; -#endif diff --git a/include/BaseTask.h b/include/BaseTask.h index 91a68f3..0968cae 100644 --- a/include/BaseTask.h +++ b/include/BaseTask.h @@ -11,7 +11,8 @@ using namespace std; #include"BaseDevice.h" #include"DataBlob.h" #include - +#include +#include"proto/GET.pb.h" template class BaseTask @@ -29,12 +30,6 @@ class BaseTask virtual void PostCompute() {}; //virtual void SetParams(TaskParam param); - inline void TaskOn() - { - PreCompute(); - Compute(); - PostCompute(); - } protected: vector > datas_; @@ -43,4 +38,232 @@ class BaseTask vector device_buffers_; }; + +template +static void * +BaseTaskOn(void *args) +{ + BaseTask* task = (BaseTask*)args; + task->PreCompute(); + task->Compute(); + task->PostCompute(); + pthread_exit(NULL); +} + +template +class AddTask : public BaseTask +{ + public: + ~AddTask() {}; + //explicit AddTask(BaseDevice* device) {device_ = device;} + AddTask(BaseDevice* device): + BaseTask(device) + { + + } + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(GET::AddParam param) {}; + void SetParams(int channels, int height, int width) + { + channels_ = channels; + height_ = (cl_int)height; + width_ = (cl_int)width; + } + + protected: + cl_int height_, width_; + int channels_; +}; + +template +class SubTask : public BaseTask +{ + public: + ~SubTask() {}; + //explicit AddTask(BaseDevice* device) {device_ = device;} + SubTask(BaseDevice* device): + BaseTask(device) + { + + } + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(GET::SubParam param) {}; + void SetParams(int channels, int height, int width) + { + channels_ = channels; + height_ = (cl_int)height; + width_ = (cl_int)width; + } + + protected: + cl_int height_, width_; + int channels_; +}; + +template +class MulTask : public BaseTask +{ + public: + ~MulTask() {}; + //explicit AddTask(BaseDevice* device) {device_ = device;} + MulTask(BaseDevice* device): + BaseTask(device) + { + + } + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(GET::MulParam param) {}; + void SetParams(int channels, int M, int K, int N) + { + channels_ = channels; + M_ = (cl_uint)M; + K_ = (cl_uint)K; + N_ = (cl_uint)N; + } + + protected: + cl_uint M_, K_, N_; + int channels_; +}; + + +template +class ConvTask : public BaseTask +{ + public: + ~ConvTask() {}; + ConvTask(BaseDevice* device): + BaseTask(device) + { + + } + + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(GET::ConvParam param); + void SetParams(int channels, int data_h, int data_w, int filter_h, int filter_w, int stride_h, int stride_w, int pad_h, int pad_w) + { + channels_ = (cl_uint)channels; + data_h_ = (cl_uint)data_h; + data_w_ = (cl_uint)data_w; + filter_h_ = (cl_uint)filter_h; + filter_w_ = (cl_uint)filter_w; + stride_h_ = (cl_uint)stride_h; + stride_w_ = (cl_uint)stride_w; + pad_h_ = (cl_uint)pad_h; + pad_w_ = (cl_uint)pad_w; + output_h_ = (data_h_ - filter_h_) / stride_h_ + 1; + output_w_ = (data_w_ - filter_w_) / stride_w_ + 1; + } + + protected: + cl_uint data_h_, data_w_; + cl_uint filter_h_, filter_w_; + cl_uint stride_h_, stride_w_; + cl_uint pad_h_, pad_w_; + cl_uint output_h_, output_w_; + cl_uint channels_; + +}; + +template +class PoolTask : public BaseTask +{ + public: + ~PoolTask() {}; + PoolTask(BaseDevice* device): + BaseTask(device) + { + + } + + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(GET::PoolParam param); + void SetParams(int channels, int data_h, int data_w, int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w) + { + channels_ = (cl_uint)channels; + data_h_ = (cl_uint)data_h; + data_w_ = (cl_uint)data_w; + kernel_h_ = (cl_uint)kernel_h; + kernel_w_ = (cl_uint)kernel_w; + stride_h_ = (cl_uint)stride_h; + stride_w_ = (cl_uint)stride_w; + pad_h_ = (cl_uint)pad_h; + pad_w_ = (cl_uint)pad_w; + output_h_ = (data_h_ - kernel_h_) / stride_h_ + 1; + output_w_ = (data_w_ - kernel_w_) / stride_w_ + 1; + } + + protected: + cl_uint data_h_, data_w_; + cl_uint kernel_h_, kernel_w_; + cl_uint stride_h_, stride_w_; + cl_uint pad_h_, pad_w_; + cl_uint output_h_, output_w_; + cl_uint channels_; + +}; + +template +class ReLUTask : public BaseTask +{ + public: + ~ReLUTask() {}; + //explicit AddTask(BaseDevice* device) {device_ = device;} + ReLUTask(BaseDevice* device): + BaseTask(device) + { + + } + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(GET::ReLUParam param) {}; + void SetParams(int channels, int height, int width) + { + channels_ = channels; + height_ = (cl_int)height; + width_ = (cl_int)width; + } + + protected: + cl_int height_, width_; + int channels_; +}; + +template +class SigmoidTask : public BaseTask +{ + public: + ~SigmoidTask() {}; + //explicit AddTask(BaseDevice* device) {device_ = device;} + SigmoidTask(BaseDevice* device): + BaseTask(device) + { + + } + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(GET::ReLUParam param) {}; + void SetParams(int channels, int height, int width) + { + channels_ = channels; + height_ = (cl_int)height; + width_ = (cl_int)width; + } + + protected: + cl_int height_, width_; + int channels_; +}; #endif diff --git a/include/CommonTask.h b/include/CommonTask.h deleted file mode 100644 index 4549873..0000000 --- a/include/CommonTask.h +++ /dev/null @@ -1,13 +0,0 @@ -/************************************************************************* - > File Name: CommonTask.h - > Author: - > Mail: - > Created Time: Sun 19 Apr 2015 09:58:51 PM EDT - ************************************************************************/ - -#ifndef GET_COMMONTASK_H -#define GET_COMMONTASK_H -#include"AddTask.h" -#include"SubTask.h" -#include"AddStreamTask.h" -#endif diff --git a/include/StreamTask.h b/include/StreamTask.h index 64dedc0..817b697 100644 --- a/include/StreamTask.h +++ b/include/StreamTask.h @@ -10,6 +10,7 @@ using namespace std; #include"BaseDevice.h" #include"DataBlob.h" +#include"proto/GET.pb.h" #include @@ -49,5 +50,212 @@ class StreamTask long data_per_block_; int blocks_num_; +}; + +template +class AddStreamTask: public StreamTask +{ + public: + ~AddStreamTask(); + explicit AddStreamTask(BaseDevice* device): + StreamTask(device) + { + this->data_per_block_ =device->GlobalMemory() / 6; + } + + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(int channels, int height, int width) + { + channels_ = channels; + height_ = height; + width_ = width; + } + protected: + cl_int height_, width_; + int channels_; + +}; + +template +class SubStreamTask: public StreamTask +{ + public: + ~SubStreamTask(); + explicit SubStreamTask(BaseDevice* device): + StreamTask(device) + { + this->data_per_block_ =device->GlobalMemory() / 6; + } + + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(int channels, int height, int width) + { + channels_ = channels; + height_ = height; + width_ = width; + } + protected: + cl_int height_, width_; + int channels_; + +}; + +template +class MulStreamTask: public StreamTask +{ + public: + ~MulStreamTask(); + explicit MulStreamTask(BaseDevice* device): + StreamTask(device) + { + this->data_per_block_ =device->GlobalMemory() / 6; + } + + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(int channels, int M, int K, int N) + { + channels_ = channels; + M_ = (cl_uint)M; + K_ = (cl_uint)K; + N_ = (cl_uint)N; + } + protected: + cl_uint M_, K_, N_; + int channels_; + +}; + +template +class ConvStreamTask: public StreamTask +{ + public: + ~ConvStreamTask(); + explicit ConvStreamTask(BaseDevice* device): + StreamTask(device) + { + this->data_per_block_ =device->GlobalMemory() / 6; + } + + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(int channels, int data_h, int data_w, int filter_h, int filter_w, int stride_h, int stride_w, int pad_h, int pad_w) + { + channels_ = (cl_uint)channels; + data_h_ = (cl_uint)data_h; + data_w_ = (cl_uint)data_w; + filter_h_ = (cl_uint)filter_h; + filter_w_ = (cl_uint)filter_w; + stride_h_ = (cl_uint)stride_h; + stride_w_ = (cl_uint)stride_w; + pad_h_ = (cl_uint)pad_h; + pad_w_ = (cl_uint)pad_w; + output_h_ = (data_h_ - filter_h_) / stride_h_ + 1; + output_w_ = (data_w_ - filter_w_) / stride_w_ + 1; + } + protected: + cl_uint data_h_, data_w_; + cl_uint filter_h_, filter_w_; + cl_uint stride_h_, stride_w_; + cl_uint pad_h_, pad_w_; + cl_uint output_h_, output_w_; + cl_uint channels_; + +}; + +template +class PoolStreamTask: public StreamTask +{ + public: + ~PoolStreamTask(); + explicit PoolStreamTask(BaseDevice* device): + StreamTask(device) + { + this->data_per_block_ =device->GlobalMemory() / 6; + } + + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(int channels, int data_h, int data_w, int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w) + { + channels_ = (cl_uint)channels; + data_h_ = (cl_uint)data_h; + data_w_ = (cl_uint)data_w; + kernel_h_ = (cl_uint)kernel_h; + kernel_w_ = (cl_uint)kernel_w; + stride_h_ = (cl_uint)stride_h; + stride_w_ = (cl_uint)stride_w; + pad_h_ = (cl_uint)pad_h; + pad_w_ = (cl_uint)pad_w; + output_h_ = (data_h_ - kernel_h_) / stride_h_ + 1; + output_w_ = (data_w_ - kernel_w_) / stride_w_ + 1; + } + protected: + cl_uint data_h_, data_w_; + cl_uint kernel_h_, kernel_w_; + cl_uint stride_h_, stride_w_; + cl_uint pad_h_, pad_w_; + cl_uint output_h_, output_w_; + cl_uint channels_; + +}; + +template +class ReLUStreamTask: public StreamTask +{ + public: + ~ReLUStreamTask(); + explicit ReLUStreamTask(BaseDevice* device): + StreamTask(device) + { + this->data_per_block_ =device->GlobalMemory() / 6; + } + + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(int channels, int height, int width) + { + channels_ = channels; + height_ = height; + width_ = width; + } + protected: + cl_int height_, width_; + int channels_; + +}; + +template +class SigmoidStreamTask: public StreamTask +{ + public: + ~SigmoidStreamTask(); + explicit SigmoidStreamTask(BaseDevice* device): + StreamTask(device) + { + this->data_per_block_ =device->GlobalMemory() / 6; + } + + void PreCompute(); + void Compute(); + void PostCompute(); + void SetParams(int channels, int height, int width) + { + channels_ = channels; + height_ = height; + width_ = width; + } + protected: + cl_int height_, width_; + int channels_; + }; #endif diff --git a/include/TaskDispatcher.h b/include/TaskDispatcher.h index 10bc111..53fd9e5 100644 --- a/include/TaskDispatcher.h +++ b/include/TaskDispatcher.h @@ -12,7 +12,7 @@ #include"BaseTask.h" #include"DeviceManager.h" #include"BaseDevice.h" -#include"CommonTask.h" +#include"StreamTask.h" #include enum TaskProcessType {UNINITIAL, ORDINARY, STREAM, HYBRID}; @@ -53,4 +53,123 @@ class TaskDispatcher TaskProcessType process_type_; }; + +template +class AddTaskDispatcher : public TaskDispatcher +{ + public: + ~AddTaskDispatcher() {}; + explicit AddTaskDispatcher(GET::TaskParam task_param, DeviceManager* device_manager):TaskDispatcher(task_param, device_manager) + { + add_param_ = task_param.add_param(); + } + + void PreTaskDispatch(); + void TaskDispatch(); + + protected: + GET::AddParam add_param_; +}; + +template +class SubTaskDispatcher : public TaskDispatcher +{ + public: + ~SubTaskDispatcher() {}; + explicit SubTaskDispatcher(GET::TaskParam task_param, DeviceManager* device_manager):TaskDispatcher(task_param, device_manager) + { + sub_param_ = task_param.sub_param(); + } + + void PreTaskDispatch(); + void TaskDispatch(); + + protected: + GET::SubParam sub_param_; +}; + +template +class MulTaskDispatcher : public TaskDispatcher +{ + public: + ~MulTaskDispatcher() {}; + explicit MulTaskDispatcher(GET::TaskParam task_param, DeviceManager* device_manager):TaskDispatcher(task_param, device_manager) + { + mul_param_ = task_param.mul_param(); + } + + void PreTaskDispatch(); + void TaskDispatch(); + + protected: + GET::MulParam mul_param_; +}; + +template +class ConvTaskDispatcher : public TaskDispatcher +{ + public: + ~ConvTaskDispatcher() {}; + explicit ConvTaskDispatcher(GET::TaskParam task_param, DeviceManager* device_manager):TaskDispatcher(task_param, device_manager) + { + conv_param_ = task_param.conv_param(); + } + + void PreTaskDispatch(); + void TaskDispatch(); + + protected: + GET::ConvParam conv_param_; +}; + +template +class PoolTaskDispatcher : public TaskDispatcher +{ + public: + ~PoolTaskDispatcher() {}; + explicit PoolTaskDispatcher(GET::TaskParam task_param, DeviceManager* device_manager):TaskDispatcher(task_param, device_manager) + { + pool_param_ = task_param.pool_param(); + } + + void PreTaskDispatch(); + void TaskDispatch(); + + protected: + GET::PoolParam pool_param_; +}; + +template +class ReLUTaskDispatcher : public TaskDispatcher +{ + public: + ~ReLUTaskDispatcher() {}; + explicit ReLUTaskDispatcher(GET::TaskParam task_param, DeviceManager* device_manager):TaskDispatcher(task_param, device_manager) + { + relu_param_ = task_param.relu_param(); + } + + void PreTaskDispatch(); + void TaskDispatch(); + + protected: + GET::ReLUParam relu_param_; +}; + +template +class SigmoidTaskDispatcher : public TaskDispatcher +{ + public: + ~SigmoidTaskDispatcher() {}; + explicit SigmoidTaskDispatcher(GET::TaskParam task_param, DeviceManager* device_manager):TaskDispatcher(task_param, device_manager) + { + sig_param_ = task_param.sigmoid_param(); + } + + void PreTaskDispatch(); + void TaskDispatch(); + + protected: + GET::SigmoidParam sig_param_; +}; #endif diff --git a/include/TaskManager.h b/include/TaskManager.h index 6d0e969..e6fb90e 100644 --- a/include/TaskManager.h +++ b/include/TaskManager.h @@ -26,11 +26,12 @@ class TaskManager int TaskOn(); void TaskRequestRemote(); void Init(); + int GetTaskDispatcher(GET::TaskParam param); typedef enum {TASKWAIT,TASKON,TASKFINISHED} TaskStatus; protected: - map > tasks_; + map* > tasks_; //map tasks_params_; map tasks_status_; DeviceManager device_manager_; diff --git a/include/AddStreamTask.h b/include/backup/AddStreamTask.h similarity index 100% rename from include/AddStreamTask.h rename to include/backup/AddStreamTask.h diff --git a/include/AddTask.h b/include/backup/AddTask.h.backup similarity index 100% rename from include/AddTask.h rename to include/backup/AddTask.h.backup diff --git a/include/SubTaskDispatcher.h b/include/backup/AddTaskDispatcher.h similarity index 100% rename from include/SubTaskDispatcher.h rename to include/backup/AddTaskDispatcher.h diff --git a/include/ConvTask.h b/include/backup/ConvTask.h similarity index 100% rename from include/ConvTask.h rename to include/backup/ConvTask.h diff --git a/include/SubTask.h b/include/backup/SubTask.h similarity index 100% rename from include/SubTask.h rename to include/backup/SubTask.h diff --git a/include/get.h b/include/get.h index ac9460e..49f2760 100644 --- a/include/get.h +++ b/include/get.h @@ -10,7 +10,8 @@ #include"TaskManager.h" #include"DeviceManager.h" #include"io.h" -#include"CommonTask.h" +#include"BaseTask.h" +#include"StreamTask.h" #include using namespace std; diff --git a/include/opencl_z.h b/include/opencl_z.h index a05f1af..39bbbcd 100644 --- a/include/opencl_z.h +++ b/include/opencl_z.h @@ -17,6 +17,8 @@ cl_pStatus(cl_int status, string funcName); char* cl_readSource(char* srcPath); +unsigned int +RoundUp(unsigned int value, unsigned int multiple); //float* //cl_readImagef_gray_cv(char* imgPath, int* img_width, int* img_height, void* Array); #endif diff --git a/include/proto/GET.pb.cc b/include/proto/GET.pb.cc index 0b62922..23dd232 100644 --- a/include/proto/GET.pb.cc +++ b/include/proto/GET.pb.cc @@ -34,6 +34,21 @@ const ::google::protobuf::internal::GeneratedMessageReflection* const ::google::protobuf::Descriptor* MulParam_descriptor_ = NULL; const ::google::protobuf::internal::GeneratedMessageReflection* MulParam_reflection_ = NULL; +const ::google::protobuf::Descriptor* ConvParam_descriptor_ = NULL; +const ::google::protobuf::internal::GeneratedMessageReflection* + ConvParam_reflection_ = NULL; +const ::google::protobuf::Descriptor* PoolParam_descriptor_ = NULL; +const ::google::protobuf::internal::GeneratedMessageReflection* + PoolParam_reflection_ = NULL; +const ::google::protobuf::Descriptor* LRNParam_descriptor_ = NULL; +const ::google::protobuf::internal::GeneratedMessageReflection* + LRNParam_reflection_ = NULL; +const ::google::protobuf::Descriptor* ReLUParam_descriptor_ = NULL; +const ::google::protobuf::internal::GeneratedMessageReflection* + ReLUParam_reflection_ = NULL; +const ::google::protobuf::Descriptor* SigmoidParam_descriptor_ = NULL; +const ::google::protobuf::internal::GeneratedMessageReflection* + SigmoidParam_reflection_ = NULL; } // namespace @@ -45,7 +60,7 @@ void protobuf_AssignDesc_GET_2eproto() { "GET.proto"); GOOGLE_CHECK(file != NULL); TaskParam_descriptor_ = file->message_type(0); - static const int TaskParam_offsets_[10] = { + static const int TaskParam_offsets_[15] = { GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(TaskParam, source_pos_), GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(TaskParam, sourcef_), GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(TaskParam, sourcem_), @@ -56,6 +71,11 @@ void protobuf_AssignDesc_GET_2eproto() { GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(TaskParam, add_param_), GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(TaskParam, sub_param_), GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(TaskParam, mul_param_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(TaskParam, conv_param_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(TaskParam, pool_param_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(TaskParam, lrn_param_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(TaskParam, relu_param_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(TaskParam, sigmoid_param_), }; TaskParam_reflection_ = new ::google::protobuf::internal::GeneratedMessageReflection( @@ -105,11 +125,10 @@ void protobuf_AssignDesc_GET_2eproto() { ::google::protobuf::MessageFactory::generated_factory(), sizeof(SubParam)); MulParam_descriptor_ = file->message_type(3); - static const int MulParam_offsets_[5] = { - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(MulParam, height_a_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(MulParam, width_a_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(MulParam, height_b_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(MulParam, width_b_), + static const int MulParam_offsets_[4] = { + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(MulParam, m_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(MulParam, k_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(MulParam, n_), GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(MulParam, channels_), }; MulParam_reflection_ = @@ -123,6 +142,101 @@ void protobuf_AssignDesc_GET_2eproto() { ::google::protobuf::DescriptorPool::generated_pool(), ::google::protobuf::MessageFactory::generated_factory(), sizeof(MulParam)); + ConvParam_descriptor_ = file->message_type(4); + static const int ConvParam_offsets_[9] = { + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvParam, data_h_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvParam, data_w_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvParam, filter_h_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvParam, filter_w_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvParam, stride_h_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvParam, stride_w_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvParam, pad_h_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvParam, pad_w_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvParam, channels_), + }; + ConvParam_reflection_ = + new ::google::protobuf::internal::GeneratedMessageReflection( + ConvParam_descriptor_, + ConvParam::default_instance_, + ConvParam_offsets_, + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvParam, _has_bits_[0]), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvParam, _unknown_fields_), + -1, + ::google::protobuf::DescriptorPool::generated_pool(), + ::google::protobuf::MessageFactory::generated_factory(), + sizeof(ConvParam)); + PoolParam_descriptor_ = file->message_type(5); + static const int PoolParam_offsets_[9] = { + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(PoolParam, data_h_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(PoolParam, data_w_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(PoolParam, kernel_h_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(PoolParam, kernel_w_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(PoolParam, stride_h_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(PoolParam, stride_w_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(PoolParam, pad_h_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(PoolParam, pad_w_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(PoolParam, channels_), + }; + PoolParam_reflection_ = + new ::google::protobuf::internal::GeneratedMessageReflection( + PoolParam_descriptor_, + PoolParam::default_instance_, + PoolParam_offsets_, + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(PoolParam, _has_bits_[0]), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(PoolParam, _unknown_fields_), + -1, + ::google::protobuf::DescriptorPool::generated_pool(), + ::google::protobuf::MessageFactory::generated_factory(), + sizeof(PoolParam)); + LRNParam_descriptor_ = file->message_type(6); + static const int LRNParam_offsets_[1] = { + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(LRNParam, channels_), + }; + LRNParam_reflection_ = + new ::google::protobuf::internal::GeneratedMessageReflection( + LRNParam_descriptor_, + LRNParam::default_instance_, + LRNParam_offsets_, + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(LRNParam, _has_bits_[0]), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(LRNParam, _unknown_fields_), + -1, + ::google::protobuf::DescriptorPool::generated_pool(), + ::google::protobuf::MessageFactory::generated_factory(), + sizeof(LRNParam)); + ReLUParam_descriptor_ = file->message_type(7); + static const int ReLUParam_offsets_[3] = { + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ReLUParam, height_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ReLUParam, width_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ReLUParam, channels_), + }; + ReLUParam_reflection_ = + new ::google::protobuf::internal::GeneratedMessageReflection( + ReLUParam_descriptor_, + ReLUParam::default_instance_, + ReLUParam_offsets_, + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ReLUParam, _has_bits_[0]), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ReLUParam, _unknown_fields_), + -1, + ::google::protobuf::DescriptorPool::generated_pool(), + ::google::protobuf::MessageFactory::generated_factory(), + sizeof(ReLUParam)); + SigmoidParam_descriptor_ = file->message_type(8); + static const int SigmoidParam_offsets_[3] = { + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(SigmoidParam, height_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(SigmoidParam, width_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(SigmoidParam, channels_), + }; + SigmoidParam_reflection_ = + new ::google::protobuf::internal::GeneratedMessageReflection( + SigmoidParam_descriptor_, + SigmoidParam::default_instance_, + SigmoidParam_offsets_, + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(SigmoidParam, _has_bits_[0]), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(SigmoidParam, _unknown_fields_), + -1, + ::google::protobuf::DescriptorPool::generated_pool(), + ::google::protobuf::MessageFactory::generated_factory(), + sizeof(SigmoidParam)); } namespace { @@ -143,6 +257,16 @@ void protobuf_RegisterTypes(const ::std::string&) { SubParam_descriptor_, &SubParam::default_instance()); ::google::protobuf::MessageFactory::InternalRegisterGeneratedMessage( MulParam_descriptor_, &MulParam::default_instance()); + ::google::protobuf::MessageFactory::InternalRegisterGeneratedMessage( + ConvParam_descriptor_, &ConvParam::default_instance()); + ::google::protobuf::MessageFactory::InternalRegisterGeneratedMessage( + PoolParam_descriptor_, &PoolParam::default_instance()); + ::google::protobuf::MessageFactory::InternalRegisterGeneratedMessage( + LRNParam_descriptor_, &LRNParam::default_instance()); + ::google::protobuf::MessageFactory::InternalRegisterGeneratedMessage( + ReLUParam_descriptor_, &ReLUParam::default_instance()); + ::google::protobuf::MessageFactory::InternalRegisterGeneratedMessage( + SigmoidParam_descriptor_, &SigmoidParam::default_instance()); } } // namespace @@ -156,6 +280,16 @@ void protobuf_ShutdownFile_GET_2eproto() { delete SubParam_reflection_; delete MulParam::default_instance_; delete MulParam_reflection_; + delete ConvParam::default_instance_; + delete ConvParam_reflection_; + delete PoolParam::default_instance_; + delete PoolParam_reflection_; + delete LRNParam::default_instance_; + delete LRNParam_reflection_; + delete ReLUParam::default_instance_; + delete ReLUParam_reflection_; + delete SigmoidParam::default_instance_; + delete SigmoidParam_reflection_; } void protobuf_AddDesc_GET_2eproto() { @@ -165,7 +299,7 @@ void protobuf_AddDesc_GET_2eproto() { GOOGLE_PROTOBUF_VERIFY_VERSION; ::google::protobuf::DescriptorPool::InternalAddGeneratedFile( - "\n\tGET.proto\022\003GET\"\276\003\n\tTaskParam\0228\n\nsource" + "\n\tGET.proto\022\003GET\"\226\005\n\tTaskParam\0228\n\nsource" "_pos\030\001 \001(\0162\033.GET.TaskParam.DataPosition:" "\007HOSTMEM\022\017\n\007sourcef\030\002 \003(\t\022\017\n\007sourcem\030\003 \003" "(\004\0228\n\nresult_pos\030\004 \001(\0162\033.GET.TaskParam.D" @@ -173,26 +307,54 @@ void protobuf_AddDesc_GET_2eproto() { "\007resultm\030\006 \003(\004\022%\n\004type\030\007 \001(\0162\027.GET.TaskP" "aram.TaskType\022 \n\tadd_param\030\010 \001(\0132\r.GET.A" "ddParam\022 \n\tsub_param\030\t \001(\0132\r.GET.SubPara" - "m\022 \n\tmul_param\030\n \001(\0132\r.GET.MulParam\"%\n\014D" - "ataPosition\022\010\n\004FILE\020\000\022\013\n\007HOSTMEM\020\001\"E\n\010Ta" - "skType\022\007\n\003ADD\020\000\022\007\n\003SUB\020\001\022\t\n\005MULTI\020\002\022\017\n\013C" - "ONVOLUTION\020\003\022\013\n\007POOLING\020\004\">\n\010AddParam\022\016\n" + "m\022 \n\tmul_param\030\n \001(\0132\r.GET.MulParam\022\"\n\nc" + "onv_param\030\013 \001(\0132\016.GET.ConvParam\022\"\n\npool_" + "param\030\014 \001(\0132\016.GET.PoolParam\022 \n\tlrn_param" + "\030\r \001(\0132\r.GET.LRNParam\022\"\n\nrelu_param\030\016 \001(" + "\0132\016.GET.ReLUParam\022(\n\rsigmoid_param\030\017 \001(\013" + "2\021.GET.SigmoidParam\"%\n\014DataPosition\022\010\n\004F" + "ILE\020\000\022\013\n\007HOSTMEM\020\001\"e\n\010TaskType\022\007\n\003ADD\020\000\022" + "\007\n\003SUB\020\001\022\t\n\005MULTI\020\002\022\017\n\013CONVOLUTION\020\003\022\013\n\007" + "POOLING\020\004\022\007\n\003LRN\020\005\022\010\n\004RELU\020\006\022\013\n\007Sigmoid\020" + "\007\">\n\010AddParam\022\016\n\006height\030\001 \001(\005\022\r\n\005width\030\002" + " \001(\005\022\023\n\010channels\030\003 \001(\005:\0011\">\n\010SubParam\022\016\n" "\006height\030\001 \001(\005\022\r\n\005width\030\002 \001(\005\022\023\n\010channels" - "\030\003 \001(\005:\0011\">\n\010SubParam\022\016\n\006height\030\001 \001(\005\022\r\n" - "\005width\030\002 \001(\005\022\023\n\010channels\030\003 \001(\005:\0011\"e\n\010Mul" - "Param\022\020\n\010height_A\030\001 \001(\005\022\017\n\007width_A\030\002 \001(\005" - "\022\020\n\010height_B\030\003 \001(\005\022\017\n\007width_B\030\004 \001(\005\022\023\n\010c" - "hannels\030\005 \001(\005:\0011", 696); + "\030\003 \001(\005:\0011\"@\n\010MulParam\022\t\n\001M\030\001 \001(\005\022\t\n\001K\030\002 " + "\001(\005\022\t\n\001N\030\003 \001(\005\022\023\n\010channels\030\004 \001(\005:\0011\"\262\001\n\t" + "ConvParam\022\016\n\006data_h\030\001 \001(\005\022\016\n\006data_w\030\002 \001(" + "\005\022\020\n\010filter_h\030\003 \001(\005\022\020\n\010filter_w\030\004 \001(\005\022\023\n" + "\010stride_h\030\005 \001(\005:\0011\022\023\n\010stride_w\030\006 \001(\005:\0011\022" + "\020\n\005pad_h\030\007 \001(\005:\0010\022\020\n\005pad_w\030\010 \001(\005:\0010\022\023\n\010c" + "hannels\030\t \001(\005:\0011\"\262\001\n\tPoolParam\022\016\n\006data_h" + "\030\001 \001(\005\022\016\n\006data_w\030\002 \001(\005\022\020\n\010kernel_h\030\003 \001(\005" + "\022\020\n\010kernel_w\030\004 \001(\005\022\023\n\010stride_h\030\005 \001(\005:\0011\022" + "\023\n\010stride_w\030\006 \001(\005:\0011\022\020\n\005pad_h\030\007 \001(\005:\0010\022\020" + "\n\005pad_w\030\010 \001(\005:\0010\022\023\n\010channels\030\t \001(\005:\0011\"\037\n" + "\010LRNParam\022\023\n\010channels\030\001 \001(\005:\0010\"B\n\tReLUPa" + "ram\022\021\n\006height\030\001 \001(\005:\0011\022\r\n\005width\030\002 \001(\005\022\023\n" + "\010channels\030\003 \001(\005:\0011\"E\n\014SigmoidParam\022\021\n\006he" + "ight\030\001 \001(\005:\0011\022\r\n\005width\030\002 \001(\005\022\023\n\010channels" + "\030\003 \001(\005:\0011", 1409); ::google::protobuf::MessageFactory::InternalRegisterGeneratedFile( "GET.proto", &protobuf_RegisterTypes); TaskParam::default_instance_ = new TaskParam(); AddParam::default_instance_ = new AddParam(); SubParam::default_instance_ = new SubParam(); MulParam::default_instance_ = new MulParam(); + ConvParam::default_instance_ = new ConvParam(); + PoolParam::default_instance_ = new PoolParam(); + LRNParam::default_instance_ = new LRNParam(); + ReLUParam::default_instance_ = new ReLUParam(); + SigmoidParam::default_instance_ = new SigmoidParam(); TaskParam::default_instance_->InitAsDefaultInstance(); AddParam::default_instance_->InitAsDefaultInstance(); SubParam::default_instance_->InitAsDefaultInstance(); MulParam::default_instance_->InitAsDefaultInstance(); + ConvParam::default_instance_->InitAsDefaultInstance(); + PoolParam::default_instance_->InitAsDefaultInstance(); + LRNParam::default_instance_->InitAsDefaultInstance(); + ReLUParam::default_instance_->InitAsDefaultInstance(); + SigmoidParam::default_instance_->InitAsDefaultInstance(); ::google::protobuf::internal::OnShutdown(&protobuf_ShutdownFile_GET_2eproto); } @@ -237,6 +399,9 @@ bool TaskParam_TaskType_IsValid(int value) { case 2: case 3: case 4: + case 5: + case 6: + case 7: return true; default: return false; @@ -249,6 +414,9 @@ const TaskParam_TaskType TaskParam::SUB; const TaskParam_TaskType TaskParam::MULTI; const TaskParam_TaskType TaskParam::CONVOLUTION; const TaskParam_TaskType TaskParam::POOLING; +const TaskParam_TaskType TaskParam::LRN; +const TaskParam_TaskType TaskParam::RELU; +const TaskParam_TaskType TaskParam::Sigmoid; const TaskParam_TaskType TaskParam::TaskType_MIN; const TaskParam_TaskType TaskParam::TaskType_MAX; const int TaskParam::TaskType_ARRAYSIZE; @@ -264,6 +432,11 @@ const int TaskParam::kTypeFieldNumber; const int TaskParam::kAddParamFieldNumber; const int TaskParam::kSubParamFieldNumber; const int TaskParam::kMulParamFieldNumber; +const int TaskParam::kConvParamFieldNumber; +const int TaskParam::kPoolParamFieldNumber; +const int TaskParam::kLrnParamFieldNumber; +const int TaskParam::kReluParamFieldNumber; +const int TaskParam::kSigmoidParamFieldNumber; #endif // !_MSC_VER TaskParam::TaskParam() @@ -275,6 +448,11 @@ void TaskParam::InitAsDefaultInstance() { add_param_ = const_cast< ::GET::AddParam*>(&::GET::AddParam::default_instance()); sub_param_ = const_cast< ::GET::SubParam*>(&::GET::SubParam::default_instance()); mul_param_ = const_cast< ::GET::MulParam*>(&::GET::MulParam::default_instance()); + conv_param_ = const_cast< ::GET::ConvParam*>(&::GET::ConvParam::default_instance()); + pool_param_ = const_cast< ::GET::PoolParam*>(&::GET::PoolParam::default_instance()); + lrn_param_ = const_cast< ::GET::LRNParam*>(&::GET::LRNParam::default_instance()); + relu_param_ = const_cast< ::GET::ReLUParam*>(&::GET::ReLUParam::default_instance()); + sigmoid_param_ = const_cast< ::GET::SigmoidParam*>(&::GET::SigmoidParam::default_instance()); } TaskParam::TaskParam(const TaskParam& from) @@ -291,6 +469,11 @@ void TaskParam::SharedCtor() { add_param_ = NULL; sub_param_ = NULL; mul_param_ = NULL; + conv_param_ = NULL; + pool_param_ = NULL; + lrn_param_ = NULL; + relu_param_ = NULL; + sigmoid_param_ = NULL; ::memset(_has_bits_, 0, sizeof(_has_bits_)); } @@ -303,6 +486,11 @@ void TaskParam::SharedDtor() { delete add_param_; delete sub_param_; delete mul_param_; + delete conv_param_; + delete pool_param_; + delete lrn_param_; + delete relu_param_; + delete sigmoid_param_; } } @@ -343,6 +531,21 @@ void TaskParam::Clear() { if (has_mul_param()) { if (mul_param_ != NULL) mul_param_->::GET::MulParam::Clear(); } + if (has_conv_param()) { + if (conv_param_ != NULL) conv_param_->::GET::ConvParam::Clear(); + } + if (has_pool_param()) { + if (pool_param_ != NULL) pool_param_->::GET::PoolParam::Clear(); + } + if (has_lrn_param()) { + if (lrn_param_ != NULL) lrn_param_->::GET::LRNParam::Clear(); + } + if (has_relu_param()) { + if (relu_param_ != NULL) relu_param_->::GET::ReLUParam::Clear(); + } + if (has_sigmoid_param()) { + if (sigmoid_param_ != NULL) sigmoid_param_->::GET::SigmoidParam::Clear(); + } } sourcef_.Clear(); sourcem_.Clear(); @@ -540,6 +743,76 @@ bool TaskParam::MergePartialFromCodedStream( } else { goto handle_uninterpreted; } + if (input->ExpectTag(90)) goto parse_conv_param; + break; + } + + // optional .GET.ConvParam conv_param = 11; + case 11: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED) { + parse_conv_param: + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, mutable_conv_param())); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(98)) goto parse_pool_param; + break; + } + + // optional .GET.PoolParam pool_param = 12; + case 12: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED) { + parse_pool_param: + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, mutable_pool_param())); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(106)) goto parse_lrn_param; + break; + } + + // optional .GET.LRNParam lrn_param = 13; + case 13: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED) { + parse_lrn_param: + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, mutable_lrn_param())); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(114)) goto parse_relu_param; + break; + } + + // optional .GET.ReLUParam relu_param = 14; + case 14: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED) { + parse_relu_param: + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, mutable_relu_param())); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(122)) goto parse_sigmoid_param; + break; + } + + // optional .GET.SigmoidParam sigmoid_param = 15; + case 15: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED) { + parse_sigmoid_param: + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, mutable_sigmoid_param())); + } else { + goto handle_uninterpreted; + } if (input->ExpectAtEnd()) return true; break; } @@ -628,6 +901,36 @@ void TaskParam::SerializeWithCachedSizes( 10, this->mul_param(), output); } + // optional .GET.ConvParam conv_param = 11; + if (has_conv_param()) { + ::google::protobuf::internal::WireFormatLite::WriteMessageMaybeToArray( + 11, this->conv_param(), output); + } + + // optional .GET.PoolParam pool_param = 12; + if (has_pool_param()) { + ::google::protobuf::internal::WireFormatLite::WriteMessageMaybeToArray( + 12, this->pool_param(), output); + } + + // optional .GET.LRNParam lrn_param = 13; + if (has_lrn_param()) { + ::google::protobuf::internal::WireFormatLite::WriteMessageMaybeToArray( + 13, this->lrn_param(), output); + } + + // optional .GET.ReLUParam relu_param = 14; + if (has_relu_param()) { + ::google::protobuf::internal::WireFormatLite::WriteMessageMaybeToArray( + 14, this->relu_param(), output); + } + + // optional .GET.SigmoidParam sigmoid_param = 15; + if (has_sigmoid_param()) { + ::google::protobuf::internal::WireFormatLite::WriteMessageMaybeToArray( + 15, this->sigmoid_param(), output); + } + if (!unknown_fields().empty()) { ::google::protobuf::internal::WireFormat::SerializeUnknownFields( unknown_fields(), output); @@ -705,6 +1008,41 @@ ::google::protobuf::uint8* TaskParam::SerializeWithCachedSizesToArray( 10, this->mul_param(), target); } + // optional .GET.ConvParam conv_param = 11; + if (has_conv_param()) { + target = ::google::protobuf::internal::WireFormatLite:: + WriteMessageNoVirtualToArray( + 11, this->conv_param(), target); + } + + // optional .GET.PoolParam pool_param = 12; + if (has_pool_param()) { + target = ::google::protobuf::internal::WireFormatLite:: + WriteMessageNoVirtualToArray( + 12, this->pool_param(), target); + } + + // optional .GET.LRNParam lrn_param = 13; + if (has_lrn_param()) { + target = ::google::protobuf::internal::WireFormatLite:: + WriteMessageNoVirtualToArray( + 13, this->lrn_param(), target); + } + + // optional .GET.ReLUParam relu_param = 14; + if (has_relu_param()) { + target = ::google::protobuf::internal::WireFormatLite:: + WriteMessageNoVirtualToArray( + 14, this->relu_param(), target); + } + + // optional .GET.SigmoidParam sigmoid_param = 15; + if (has_sigmoid_param()) { + target = ::google::protobuf::internal::WireFormatLite:: + WriteMessageNoVirtualToArray( + 15, this->sigmoid_param(), target); + } + if (!unknown_fields().empty()) { target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( unknown_fields(), target); @@ -757,6 +1095,41 @@ int TaskParam::ByteSize() const { this->mul_param()); } + // optional .GET.ConvParam conv_param = 11; + if (has_conv_param()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + this->conv_param()); + } + + // optional .GET.PoolParam pool_param = 12; + if (has_pool_param()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + this->pool_param()); + } + + // optional .GET.LRNParam lrn_param = 13; + if (has_lrn_param()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + this->lrn_param()); + } + + // optional .GET.ReLUParam relu_param = 14; + if (has_relu_param()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + this->relu_param()); + } + + // optional .GET.SigmoidParam sigmoid_param = 15; + if (has_sigmoid_param()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + this->sigmoid_param()); + } + } // repeated string sourcef = 2; total_size += 1 * this->sourcef_size(); @@ -842,6 +1215,21 @@ void TaskParam::MergeFrom(const TaskParam& from) { if (from.has_mul_param()) { mutable_mul_param()->::GET::MulParam::MergeFrom(from.mul_param()); } + if (from.has_conv_param()) { + mutable_conv_param()->::GET::ConvParam::MergeFrom(from.conv_param()); + } + if (from.has_pool_param()) { + mutable_pool_param()->::GET::PoolParam::MergeFrom(from.pool_param()); + } + if (from.has_lrn_param()) { + mutable_lrn_param()->::GET::LRNParam::MergeFrom(from.lrn_param()); + } + if (from.has_relu_param()) { + mutable_relu_param()->::GET::ReLUParam::MergeFrom(from.relu_param()); + } + if (from.has_sigmoid_param()) { + mutable_sigmoid_param()->::GET::SigmoidParam::MergeFrom(from.sigmoid_param()); + } } mutable_unknown_fields()->MergeFrom(from.unknown_fields()); } @@ -875,6 +1263,11 @@ void TaskParam::Swap(TaskParam* other) { std::swap(add_param_, other->add_param_); std::swap(sub_param_, other->sub_param_); std::swap(mul_param_, other->mul_param_); + std::swap(conv_param_, other->conv_param_); + std::swap(pool_param_, other->pool_param_); + std::swap(lrn_param_, other->lrn_param_); + std::swap(relu_param_, other->relu_param_); + std::swap(sigmoid_param_, other->sigmoid_param_); std::swap(_has_bits_[0], other->_has_bits_[0]); _unknown_fields_.Swap(&other->_unknown_fields_); std::swap(_cached_size_, other->_cached_size_); @@ -1469,10 +1862,9 @@ ::google::protobuf::Metadata SubParam::GetMetadata() const { // =================================================================== #ifndef _MSC_VER -const int MulParam::kHeightAFieldNumber; -const int MulParam::kWidthAFieldNumber; -const int MulParam::kHeightBFieldNumber; -const int MulParam::kWidthBFieldNumber; +const int MulParam::kMFieldNumber; +const int MulParam::kKFieldNumber; +const int MulParam::kNFieldNumber; const int MulParam::kChannelsFieldNumber; #endif // !_MSC_VER @@ -1492,10 +1884,9 @@ MulParam::MulParam(const MulParam& from) void MulParam::SharedCtor() { _cached_size_ = 0; - height_a_ = 0; - width_a_ = 0; - height_b_ = 0; - width_b_ = 0; + m_ = 0; + k_ = 0; + n_ = 0; channels_ = 1; ::memset(_has_bits_, 0, sizeof(_has_bits_)); } @@ -1532,10 +1923,9 @@ MulParam* MulParam::New() const { void MulParam::Clear() { if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { - height_a_ = 0; - width_a_ = 0; - height_b_ = 0; - width_b_ = 0; + m_ = 0; + k_ = 0; + n_ = 0; channels_ = 1; } ::memset(_has_bits_, 0, sizeof(_has_bits_)); @@ -1548,71 +1938,55 @@ bool MulParam::MergePartialFromCodedStream( ::google::protobuf::uint32 tag; while ((tag = input->ReadTag()) != 0) { switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { - // optional int32 height_A = 1; + // optional int32 M = 1; case 1: { if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( - input, &height_a_))); - set_has_height_a(); + input, &m_))); + set_has_m(); } else { goto handle_uninterpreted; } - if (input->ExpectTag(16)) goto parse_width_A; + if (input->ExpectTag(16)) goto parse_K; break; } - // optional int32 width_A = 2; + // optional int32 K = 2; case 2: { if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { - parse_width_A: + parse_K: DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( - input, &width_a_))); - set_has_width_a(); + input, &k_))); + set_has_k(); } else { goto handle_uninterpreted; } - if (input->ExpectTag(24)) goto parse_height_B; + if (input->ExpectTag(24)) goto parse_N; break; } - // optional int32 height_B = 3; + // optional int32 N = 3; case 3: { if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { - parse_height_B: + parse_N: DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( - input, &height_b_))); - set_has_height_b(); + input, &n_))); + set_has_n(); } else { goto handle_uninterpreted; } - if (input->ExpectTag(32)) goto parse_width_B; + if (input->ExpectTag(32)) goto parse_channels; break; } - // optional int32 width_B = 4; + // optional int32 channels = 4 [default = 1]; case 4: { - if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == - ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { - parse_width_B: - DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< - ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( - input, &width_b_))); - set_has_width_b(); - } else { - goto handle_uninterpreted; - } - if (input->ExpectTag(40)) goto parse_channels; - break; - } - - // optional int32 channels = 5 [default = 1]; - case 5: { if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { parse_channels: @@ -1645,29 +2019,24 @@ bool MulParam::MergePartialFromCodedStream( void MulParam::SerializeWithCachedSizes( ::google::protobuf::io::CodedOutputStream* output) const { - // optional int32 height_A = 1; - if (has_height_a()) { - ::google::protobuf::internal::WireFormatLite::WriteInt32(1, this->height_a(), output); + // optional int32 M = 1; + if (has_m()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(1, this->m(), output); } - // optional int32 width_A = 2; - if (has_width_a()) { - ::google::protobuf::internal::WireFormatLite::WriteInt32(2, this->width_a(), output); + // optional int32 K = 2; + if (has_k()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(2, this->k(), output); } - // optional int32 height_B = 3; - if (has_height_b()) { - ::google::protobuf::internal::WireFormatLite::WriteInt32(3, this->height_b(), output); + // optional int32 N = 3; + if (has_n()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(3, this->n(), output); } - // optional int32 width_B = 4; - if (has_width_b()) { - ::google::protobuf::internal::WireFormatLite::WriteInt32(4, this->width_b(), output); - } - - // optional int32 channels = 5 [default = 1]; + // optional int32 channels = 4 [default = 1]; if (has_channels()) { - ::google::protobuf::internal::WireFormatLite::WriteInt32(5, this->channels(), output); + ::google::protobuf::internal::WireFormatLite::WriteInt32(4, this->channels(), output); } if (!unknown_fields().empty()) { @@ -1678,29 +2047,24 @@ void MulParam::SerializeWithCachedSizes( ::google::protobuf::uint8* MulParam::SerializeWithCachedSizesToArray( ::google::protobuf::uint8* target) const { - // optional int32 height_A = 1; - if (has_height_a()) { - target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(1, this->height_a(), target); - } - - // optional int32 width_A = 2; - if (has_width_a()) { - target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(2, this->width_a(), target); + // optional int32 M = 1; + if (has_m()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(1, this->m(), target); } - // optional int32 height_B = 3; - if (has_height_b()) { - target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(3, this->height_b(), target); + // optional int32 K = 2; + if (has_k()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(2, this->k(), target); } - // optional int32 width_B = 4; - if (has_width_b()) { - target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(4, this->width_b(), target); + // optional int32 N = 3; + if (has_n()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(3, this->n(), target); } - // optional int32 channels = 5 [default = 1]; + // optional int32 channels = 4 [default = 1]; if (has_channels()) { - target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(5, this->channels(), target); + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(4, this->channels(), target); } if (!unknown_fields().empty()) { @@ -1714,35 +2078,28 @@ int MulParam::ByteSize() const { int total_size = 0; if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { - // optional int32 height_A = 1; - if (has_height_a()) { + // optional int32 M = 1; + if (has_m()) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::Int32Size( - this->height_a()); + this->m()); } - // optional int32 width_A = 2; - if (has_width_a()) { + // optional int32 K = 2; + if (has_k()) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::Int32Size( - this->width_a()); + this->k()); } - // optional int32 height_B = 3; - if (has_height_b()) { + // optional int32 N = 3; + if (has_n()) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::Int32Size( - this->height_b()); + this->n()); } - // optional int32 width_B = 4; - if (has_width_b()) { - total_size += 1 + - ::google::protobuf::internal::WireFormatLite::Int32Size( - this->width_b()); - } - - // optional int32 channels = 5 [default = 1]; + // optional int32 channels = 4 [default = 1]; if (has_channels()) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::Int32Size( @@ -1776,17 +2133,14 @@ void MulParam::MergeFrom(const ::google::protobuf::Message& from) { void MulParam::MergeFrom(const MulParam& from) { GOOGLE_CHECK_NE(&from, this); if (from._has_bits_[0 / 32] & (0xffu << (0 % 32))) { - if (from.has_height_a()) { - set_height_a(from.height_a()); + if (from.has_m()) { + set_m(from.m()); } - if (from.has_width_a()) { - set_width_a(from.width_a()); + if (from.has_k()) { + set_k(from.k()); } - if (from.has_height_b()) { - set_height_b(from.height_b()); - } - if (from.has_width_b()) { - set_width_b(from.width_b()); + if (from.has_n()) { + set_n(from.n()); } if (from.has_channels()) { set_channels(from.channels()); @@ -1814,10 +2168,9 @@ bool MulParam::IsInitialized() const { void MulParam::Swap(MulParam* other) { if (other != this) { - std::swap(height_a_, other->height_a_); - std::swap(width_a_, other->width_a_); - std::swap(height_b_, other->height_b_); - std::swap(width_b_, other->width_b_); + std::swap(m_, other->m_); + std::swap(k_, other->k_); + std::swap(n_, other->n_); std::swap(channels_, other->channels_); std::swap(_has_bits_[0], other->_has_bits_[0]); _unknown_fields_.Swap(&other->_unknown_fields_); @@ -1834,6 +2187,1858 @@ ::google::protobuf::Metadata MulParam::GetMetadata() const { } +// =================================================================== + +#ifndef _MSC_VER +const int ConvParam::kDataHFieldNumber; +const int ConvParam::kDataWFieldNumber; +const int ConvParam::kFilterHFieldNumber; +const int ConvParam::kFilterWFieldNumber; +const int ConvParam::kStrideHFieldNumber; +const int ConvParam::kStrideWFieldNumber; +const int ConvParam::kPadHFieldNumber; +const int ConvParam::kPadWFieldNumber; +const int ConvParam::kChannelsFieldNumber; +#endif // !_MSC_VER + +ConvParam::ConvParam() + : ::google::protobuf::Message() { + SharedCtor(); +} + +void ConvParam::InitAsDefaultInstance() { +} + +ConvParam::ConvParam(const ConvParam& from) + : ::google::protobuf::Message() { + SharedCtor(); + MergeFrom(from); +} + +void ConvParam::SharedCtor() { + _cached_size_ = 0; + data_h_ = 0; + data_w_ = 0; + filter_h_ = 0; + filter_w_ = 0; + stride_h_ = 1; + stride_w_ = 1; + pad_h_ = 0; + pad_w_ = 0; + channels_ = 1; + ::memset(_has_bits_, 0, sizeof(_has_bits_)); +} + +ConvParam::~ConvParam() { + SharedDtor(); +} + +void ConvParam::SharedDtor() { + if (this != default_instance_) { + } +} + +void ConvParam::SetCachedSize(int size) const { + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); +} +const ::google::protobuf::Descriptor* ConvParam::descriptor() { + protobuf_AssignDescriptorsOnce(); + return ConvParam_descriptor_; +} + +const ConvParam& ConvParam::default_instance() { + if (default_instance_ == NULL) protobuf_AddDesc_GET_2eproto(); + return *default_instance_; +} + +ConvParam* ConvParam::default_instance_ = NULL; + +ConvParam* ConvParam::New() const { + return new ConvParam; +} + +void ConvParam::Clear() { + if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { + data_h_ = 0; + data_w_ = 0; + filter_h_ = 0; + filter_w_ = 0; + stride_h_ = 1; + stride_w_ = 1; + pad_h_ = 0; + pad_w_ = 0; + } + if (_has_bits_[8 / 32] & (0xffu << (8 % 32))) { + channels_ = 1; + } + ::memset(_has_bits_, 0, sizeof(_has_bits_)); + mutable_unknown_fields()->Clear(); +} + +bool ConvParam::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!(EXPRESSION)) return false + ::google::protobuf::uint32 tag; + while ((tag = input->ReadTag()) != 0) { + switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { + // optional int32 data_h = 1; + case 1: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &data_h_))); + set_has_data_h(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(16)) goto parse_data_w; + break; + } + + // optional int32 data_w = 2; + case 2: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_data_w: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &data_w_))); + set_has_data_w(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(24)) goto parse_filter_h; + break; + } + + // optional int32 filter_h = 3; + case 3: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_filter_h: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &filter_h_))); + set_has_filter_h(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(32)) goto parse_filter_w; + break; + } + + // optional int32 filter_w = 4; + case 4: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_filter_w: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &filter_w_))); + set_has_filter_w(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(40)) goto parse_stride_h; + break; + } + + // optional int32 stride_h = 5 [default = 1]; + case 5: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_stride_h: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &stride_h_))); + set_has_stride_h(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(48)) goto parse_stride_w; + break; + } + + // optional int32 stride_w = 6 [default = 1]; + case 6: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_stride_w: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &stride_w_))); + set_has_stride_w(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(56)) goto parse_pad_h; + break; + } + + // optional int32 pad_h = 7 [default = 0]; + case 7: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_pad_h: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &pad_h_))); + set_has_pad_h(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(64)) goto parse_pad_w; + break; + } + + // optional int32 pad_w = 8 [default = 0]; + case 8: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_pad_w: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &pad_w_))); + set_has_pad_w(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(72)) goto parse_channels; + break; + } + + // optional int32 channels = 9 [default = 1]; + case 9: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_channels: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &channels_))); + set_has_channels(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectAtEnd()) return true; + break; + } + + default: { + handle_uninterpreted: + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_END_GROUP) { + return true; + } + DO_(::google::protobuf::internal::WireFormat::SkipField( + input, tag, mutable_unknown_fields())); + break; + } + } + } + return true; +#undef DO_ +} + +void ConvParam::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // optional int32 data_h = 1; + if (has_data_h()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(1, this->data_h(), output); + } + + // optional int32 data_w = 2; + if (has_data_w()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(2, this->data_w(), output); + } + + // optional int32 filter_h = 3; + if (has_filter_h()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(3, this->filter_h(), output); + } + + // optional int32 filter_w = 4; + if (has_filter_w()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(4, this->filter_w(), output); + } + + // optional int32 stride_h = 5 [default = 1]; + if (has_stride_h()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(5, this->stride_h(), output); + } + + // optional int32 stride_w = 6 [default = 1]; + if (has_stride_w()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(6, this->stride_w(), output); + } + + // optional int32 pad_h = 7 [default = 0]; + if (has_pad_h()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(7, this->pad_h(), output); + } + + // optional int32 pad_w = 8 [default = 0]; + if (has_pad_w()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(8, this->pad_w(), output); + } + + // optional int32 channels = 9 [default = 1]; + if (has_channels()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(9, this->channels(), output); + } + + if (!unknown_fields().empty()) { + ::google::protobuf::internal::WireFormat::SerializeUnknownFields( + unknown_fields(), output); + } +} + +::google::protobuf::uint8* ConvParam::SerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const { + // optional int32 data_h = 1; + if (has_data_h()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(1, this->data_h(), target); + } + + // optional int32 data_w = 2; + if (has_data_w()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(2, this->data_w(), target); + } + + // optional int32 filter_h = 3; + if (has_filter_h()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(3, this->filter_h(), target); + } + + // optional int32 filter_w = 4; + if (has_filter_w()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(4, this->filter_w(), target); + } + + // optional int32 stride_h = 5 [default = 1]; + if (has_stride_h()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(5, this->stride_h(), target); + } + + // optional int32 stride_w = 6 [default = 1]; + if (has_stride_w()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(6, this->stride_w(), target); + } + + // optional int32 pad_h = 7 [default = 0]; + if (has_pad_h()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(7, this->pad_h(), target); + } + + // optional int32 pad_w = 8 [default = 0]; + if (has_pad_w()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(8, this->pad_w(), target); + } + + // optional int32 channels = 9 [default = 1]; + if (has_channels()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(9, this->channels(), target); + } + + if (!unknown_fields().empty()) { + target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( + unknown_fields(), target); + } + return target; +} + +int ConvParam::ByteSize() const { + int total_size = 0; + + if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { + // optional int32 data_h = 1; + if (has_data_h()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->data_h()); + } + + // optional int32 data_w = 2; + if (has_data_w()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->data_w()); + } + + // optional int32 filter_h = 3; + if (has_filter_h()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->filter_h()); + } + + // optional int32 filter_w = 4; + if (has_filter_w()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->filter_w()); + } + + // optional int32 stride_h = 5 [default = 1]; + if (has_stride_h()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->stride_h()); + } + + // optional int32 stride_w = 6 [default = 1]; + if (has_stride_w()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->stride_w()); + } + + // optional int32 pad_h = 7 [default = 0]; + if (has_pad_h()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->pad_h()); + } + + // optional int32 pad_w = 8 [default = 0]; + if (has_pad_w()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->pad_w()); + } + + } + if (_has_bits_[8 / 32] & (0xffu << (8 % 32))) { + // optional int32 channels = 9 [default = 1]; + if (has_channels()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->channels()); + } + + } + if (!unknown_fields().empty()) { + total_size += + ::google::protobuf::internal::WireFormat::ComputeUnknownFieldsSize( + unknown_fields()); + } + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = total_size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); + return total_size; +} + +void ConvParam::MergeFrom(const ::google::protobuf::Message& from) { + GOOGLE_CHECK_NE(&from, this); + const ConvParam* source = + ::google::protobuf::internal::dynamic_cast_if_available( + &from); + if (source == NULL) { + ::google::protobuf::internal::ReflectionOps::Merge(from, this); + } else { + MergeFrom(*source); + } +} + +void ConvParam::MergeFrom(const ConvParam& from) { + GOOGLE_CHECK_NE(&from, this); + if (from._has_bits_[0 / 32] & (0xffu << (0 % 32))) { + if (from.has_data_h()) { + set_data_h(from.data_h()); + } + if (from.has_data_w()) { + set_data_w(from.data_w()); + } + if (from.has_filter_h()) { + set_filter_h(from.filter_h()); + } + if (from.has_filter_w()) { + set_filter_w(from.filter_w()); + } + if (from.has_stride_h()) { + set_stride_h(from.stride_h()); + } + if (from.has_stride_w()) { + set_stride_w(from.stride_w()); + } + if (from.has_pad_h()) { + set_pad_h(from.pad_h()); + } + if (from.has_pad_w()) { + set_pad_w(from.pad_w()); + } + } + if (from._has_bits_[8 / 32] & (0xffu << (8 % 32))) { + if (from.has_channels()) { + set_channels(from.channels()); + } + } + mutable_unknown_fields()->MergeFrom(from.unknown_fields()); +} + +void ConvParam::CopyFrom(const ::google::protobuf::Message& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void ConvParam::CopyFrom(const ConvParam& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool ConvParam::IsInitialized() const { + + return true; +} + +void ConvParam::Swap(ConvParam* other) { + if (other != this) { + std::swap(data_h_, other->data_h_); + std::swap(data_w_, other->data_w_); + std::swap(filter_h_, other->filter_h_); + std::swap(filter_w_, other->filter_w_); + std::swap(stride_h_, other->stride_h_); + std::swap(stride_w_, other->stride_w_); + std::swap(pad_h_, other->pad_h_); + std::swap(pad_w_, other->pad_w_); + std::swap(channels_, other->channels_); + std::swap(_has_bits_[0], other->_has_bits_[0]); + _unknown_fields_.Swap(&other->_unknown_fields_); + std::swap(_cached_size_, other->_cached_size_); + } +} + +::google::protobuf::Metadata ConvParam::GetMetadata() const { + protobuf_AssignDescriptorsOnce(); + ::google::protobuf::Metadata metadata; + metadata.descriptor = ConvParam_descriptor_; + metadata.reflection = ConvParam_reflection_; + return metadata; +} + + +// =================================================================== + +#ifndef _MSC_VER +const int PoolParam::kDataHFieldNumber; +const int PoolParam::kDataWFieldNumber; +const int PoolParam::kKernelHFieldNumber; +const int PoolParam::kKernelWFieldNumber; +const int PoolParam::kStrideHFieldNumber; +const int PoolParam::kStrideWFieldNumber; +const int PoolParam::kPadHFieldNumber; +const int PoolParam::kPadWFieldNumber; +const int PoolParam::kChannelsFieldNumber; +#endif // !_MSC_VER + +PoolParam::PoolParam() + : ::google::protobuf::Message() { + SharedCtor(); +} + +void PoolParam::InitAsDefaultInstance() { +} + +PoolParam::PoolParam(const PoolParam& from) + : ::google::protobuf::Message() { + SharedCtor(); + MergeFrom(from); +} + +void PoolParam::SharedCtor() { + _cached_size_ = 0; + data_h_ = 0; + data_w_ = 0; + kernel_h_ = 0; + kernel_w_ = 0; + stride_h_ = 1; + stride_w_ = 1; + pad_h_ = 0; + pad_w_ = 0; + channels_ = 1; + ::memset(_has_bits_, 0, sizeof(_has_bits_)); +} + +PoolParam::~PoolParam() { + SharedDtor(); +} + +void PoolParam::SharedDtor() { + if (this != default_instance_) { + } +} + +void PoolParam::SetCachedSize(int size) const { + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); +} +const ::google::protobuf::Descriptor* PoolParam::descriptor() { + protobuf_AssignDescriptorsOnce(); + return PoolParam_descriptor_; +} + +const PoolParam& PoolParam::default_instance() { + if (default_instance_ == NULL) protobuf_AddDesc_GET_2eproto(); + return *default_instance_; +} + +PoolParam* PoolParam::default_instance_ = NULL; + +PoolParam* PoolParam::New() const { + return new PoolParam; +} + +void PoolParam::Clear() { + if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { + data_h_ = 0; + data_w_ = 0; + kernel_h_ = 0; + kernel_w_ = 0; + stride_h_ = 1; + stride_w_ = 1; + pad_h_ = 0; + pad_w_ = 0; + } + if (_has_bits_[8 / 32] & (0xffu << (8 % 32))) { + channels_ = 1; + } + ::memset(_has_bits_, 0, sizeof(_has_bits_)); + mutable_unknown_fields()->Clear(); +} + +bool PoolParam::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!(EXPRESSION)) return false + ::google::protobuf::uint32 tag; + while ((tag = input->ReadTag()) != 0) { + switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { + // optional int32 data_h = 1; + case 1: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &data_h_))); + set_has_data_h(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(16)) goto parse_data_w; + break; + } + + // optional int32 data_w = 2; + case 2: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_data_w: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &data_w_))); + set_has_data_w(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(24)) goto parse_kernel_h; + break; + } + + // optional int32 kernel_h = 3; + case 3: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_kernel_h: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &kernel_h_))); + set_has_kernel_h(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(32)) goto parse_kernel_w; + break; + } + + // optional int32 kernel_w = 4; + case 4: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_kernel_w: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &kernel_w_))); + set_has_kernel_w(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(40)) goto parse_stride_h; + break; + } + + // optional int32 stride_h = 5 [default = 1]; + case 5: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_stride_h: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &stride_h_))); + set_has_stride_h(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(48)) goto parse_stride_w; + break; + } + + // optional int32 stride_w = 6 [default = 1]; + case 6: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_stride_w: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &stride_w_))); + set_has_stride_w(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(56)) goto parse_pad_h; + break; + } + + // optional int32 pad_h = 7 [default = 0]; + case 7: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_pad_h: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &pad_h_))); + set_has_pad_h(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(64)) goto parse_pad_w; + break; + } + + // optional int32 pad_w = 8 [default = 0]; + case 8: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_pad_w: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &pad_w_))); + set_has_pad_w(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(72)) goto parse_channels; + break; + } + + // optional int32 channels = 9 [default = 1]; + case 9: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_channels: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &channels_))); + set_has_channels(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectAtEnd()) return true; + break; + } + + default: { + handle_uninterpreted: + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_END_GROUP) { + return true; + } + DO_(::google::protobuf::internal::WireFormat::SkipField( + input, tag, mutable_unknown_fields())); + break; + } + } + } + return true; +#undef DO_ +} + +void PoolParam::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // optional int32 data_h = 1; + if (has_data_h()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(1, this->data_h(), output); + } + + // optional int32 data_w = 2; + if (has_data_w()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(2, this->data_w(), output); + } + + // optional int32 kernel_h = 3; + if (has_kernel_h()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(3, this->kernel_h(), output); + } + + // optional int32 kernel_w = 4; + if (has_kernel_w()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(4, this->kernel_w(), output); + } + + // optional int32 stride_h = 5 [default = 1]; + if (has_stride_h()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(5, this->stride_h(), output); + } + + // optional int32 stride_w = 6 [default = 1]; + if (has_stride_w()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(6, this->stride_w(), output); + } + + // optional int32 pad_h = 7 [default = 0]; + if (has_pad_h()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(7, this->pad_h(), output); + } + + // optional int32 pad_w = 8 [default = 0]; + if (has_pad_w()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(8, this->pad_w(), output); + } + + // optional int32 channels = 9 [default = 1]; + if (has_channels()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(9, this->channels(), output); + } + + if (!unknown_fields().empty()) { + ::google::protobuf::internal::WireFormat::SerializeUnknownFields( + unknown_fields(), output); + } +} + +::google::protobuf::uint8* PoolParam::SerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const { + // optional int32 data_h = 1; + if (has_data_h()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(1, this->data_h(), target); + } + + // optional int32 data_w = 2; + if (has_data_w()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(2, this->data_w(), target); + } + + // optional int32 kernel_h = 3; + if (has_kernel_h()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(3, this->kernel_h(), target); + } + + // optional int32 kernel_w = 4; + if (has_kernel_w()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(4, this->kernel_w(), target); + } + + // optional int32 stride_h = 5 [default = 1]; + if (has_stride_h()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(5, this->stride_h(), target); + } + + // optional int32 stride_w = 6 [default = 1]; + if (has_stride_w()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(6, this->stride_w(), target); + } + + // optional int32 pad_h = 7 [default = 0]; + if (has_pad_h()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(7, this->pad_h(), target); + } + + // optional int32 pad_w = 8 [default = 0]; + if (has_pad_w()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(8, this->pad_w(), target); + } + + // optional int32 channels = 9 [default = 1]; + if (has_channels()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(9, this->channels(), target); + } + + if (!unknown_fields().empty()) { + target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( + unknown_fields(), target); + } + return target; +} + +int PoolParam::ByteSize() const { + int total_size = 0; + + if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { + // optional int32 data_h = 1; + if (has_data_h()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->data_h()); + } + + // optional int32 data_w = 2; + if (has_data_w()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->data_w()); + } + + // optional int32 kernel_h = 3; + if (has_kernel_h()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->kernel_h()); + } + + // optional int32 kernel_w = 4; + if (has_kernel_w()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->kernel_w()); + } + + // optional int32 stride_h = 5 [default = 1]; + if (has_stride_h()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->stride_h()); + } + + // optional int32 stride_w = 6 [default = 1]; + if (has_stride_w()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->stride_w()); + } + + // optional int32 pad_h = 7 [default = 0]; + if (has_pad_h()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->pad_h()); + } + + // optional int32 pad_w = 8 [default = 0]; + if (has_pad_w()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->pad_w()); + } + + } + if (_has_bits_[8 / 32] & (0xffu << (8 % 32))) { + // optional int32 channels = 9 [default = 1]; + if (has_channels()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->channels()); + } + + } + if (!unknown_fields().empty()) { + total_size += + ::google::protobuf::internal::WireFormat::ComputeUnknownFieldsSize( + unknown_fields()); + } + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = total_size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); + return total_size; +} + +void PoolParam::MergeFrom(const ::google::protobuf::Message& from) { + GOOGLE_CHECK_NE(&from, this); + const PoolParam* source = + ::google::protobuf::internal::dynamic_cast_if_available( + &from); + if (source == NULL) { + ::google::protobuf::internal::ReflectionOps::Merge(from, this); + } else { + MergeFrom(*source); + } +} + +void PoolParam::MergeFrom(const PoolParam& from) { + GOOGLE_CHECK_NE(&from, this); + if (from._has_bits_[0 / 32] & (0xffu << (0 % 32))) { + if (from.has_data_h()) { + set_data_h(from.data_h()); + } + if (from.has_data_w()) { + set_data_w(from.data_w()); + } + if (from.has_kernel_h()) { + set_kernel_h(from.kernel_h()); + } + if (from.has_kernel_w()) { + set_kernel_w(from.kernel_w()); + } + if (from.has_stride_h()) { + set_stride_h(from.stride_h()); + } + if (from.has_stride_w()) { + set_stride_w(from.stride_w()); + } + if (from.has_pad_h()) { + set_pad_h(from.pad_h()); + } + if (from.has_pad_w()) { + set_pad_w(from.pad_w()); + } + } + if (from._has_bits_[8 / 32] & (0xffu << (8 % 32))) { + if (from.has_channels()) { + set_channels(from.channels()); + } + } + mutable_unknown_fields()->MergeFrom(from.unknown_fields()); +} + +void PoolParam::CopyFrom(const ::google::protobuf::Message& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void PoolParam::CopyFrom(const PoolParam& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool PoolParam::IsInitialized() const { + + return true; +} + +void PoolParam::Swap(PoolParam* other) { + if (other != this) { + std::swap(data_h_, other->data_h_); + std::swap(data_w_, other->data_w_); + std::swap(kernel_h_, other->kernel_h_); + std::swap(kernel_w_, other->kernel_w_); + std::swap(stride_h_, other->stride_h_); + std::swap(stride_w_, other->stride_w_); + std::swap(pad_h_, other->pad_h_); + std::swap(pad_w_, other->pad_w_); + std::swap(channels_, other->channels_); + std::swap(_has_bits_[0], other->_has_bits_[0]); + _unknown_fields_.Swap(&other->_unknown_fields_); + std::swap(_cached_size_, other->_cached_size_); + } +} + +::google::protobuf::Metadata PoolParam::GetMetadata() const { + protobuf_AssignDescriptorsOnce(); + ::google::protobuf::Metadata metadata; + metadata.descriptor = PoolParam_descriptor_; + metadata.reflection = PoolParam_reflection_; + return metadata; +} + + +// =================================================================== + +#ifndef _MSC_VER +const int LRNParam::kChannelsFieldNumber; +#endif // !_MSC_VER + +LRNParam::LRNParam() + : ::google::protobuf::Message() { + SharedCtor(); +} + +void LRNParam::InitAsDefaultInstance() { +} + +LRNParam::LRNParam(const LRNParam& from) + : ::google::protobuf::Message() { + SharedCtor(); + MergeFrom(from); +} + +void LRNParam::SharedCtor() { + _cached_size_ = 0; + channels_ = 0; + ::memset(_has_bits_, 0, sizeof(_has_bits_)); +} + +LRNParam::~LRNParam() { + SharedDtor(); +} + +void LRNParam::SharedDtor() { + if (this != default_instance_) { + } +} + +void LRNParam::SetCachedSize(int size) const { + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); +} +const ::google::protobuf::Descriptor* LRNParam::descriptor() { + protobuf_AssignDescriptorsOnce(); + return LRNParam_descriptor_; +} + +const LRNParam& LRNParam::default_instance() { + if (default_instance_ == NULL) protobuf_AddDesc_GET_2eproto(); + return *default_instance_; +} + +LRNParam* LRNParam::default_instance_ = NULL; + +LRNParam* LRNParam::New() const { + return new LRNParam; +} + +void LRNParam::Clear() { + if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { + channels_ = 0; + } + ::memset(_has_bits_, 0, sizeof(_has_bits_)); + mutable_unknown_fields()->Clear(); +} + +bool LRNParam::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!(EXPRESSION)) return false + ::google::protobuf::uint32 tag; + while ((tag = input->ReadTag()) != 0) { + switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { + // optional int32 channels = 1 [default = 0]; + case 1: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &channels_))); + set_has_channels(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectAtEnd()) return true; + break; + } + + default: { + handle_uninterpreted: + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_END_GROUP) { + return true; + } + DO_(::google::protobuf::internal::WireFormat::SkipField( + input, tag, mutable_unknown_fields())); + break; + } + } + } + return true; +#undef DO_ +} + +void LRNParam::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // optional int32 channels = 1 [default = 0]; + if (has_channels()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(1, this->channels(), output); + } + + if (!unknown_fields().empty()) { + ::google::protobuf::internal::WireFormat::SerializeUnknownFields( + unknown_fields(), output); + } +} + +::google::protobuf::uint8* LRNParam::SerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const { + // optional int32 channels = 1 [default = 0]; + if (has_channels()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(1, this->channels(), target); + } + + if (!unknown_fields().empty()) { + target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( + unknown_fields(), target); + } + return target; +} + +int LRNParam::ByteSize() const { + int total_size = 0; + + if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { + // optional int32 channels = 1 [default = 0]; + if (has_channels()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->channels()); + } + + } + if (!unknown_fields().empty()) { + total_size += + ::google::protobuf::internal::WireFormat::ComputeUnknownFieldsSize( + unknown_fields()); + } + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = total_size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); + return total_size; +} + +void LRNParam::MergeFrom(const ::google::protobuf::Message& from) { + GOOGLE_CHECK_NE(&from, this); + const LRNParam* source = + ::google::protobuf::internal::dynamic_cast_if_available( + &from); + if (source == NULL) { + ::google::protobuf::internal::ReflectionOps::Merge(from, this); + } else { + MergeFrom(*source); + } +} + +void LRNParam::MergeFrom(const LRNParam& from) { + GOOGLE_CHECK_NE(&from, this); + if (from._has_bits_[0 / 32] & (0xffu << (0 % 32))) { + if (from.has_channels()) { + set_channels(from.channels()); + } + } + mutable_unknown_fields()->MergeFrom(from.unknown_fields()); +} + +void LRNParam::CopyFrom(const ::google::protobuf::Message& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void LRNParam::CopyFrom(const LRNParam& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool LRNParam::IsInitialized() const { + + return true; +} + +void LRNParam::Swap(LRNParam* other) { + if (other != this) { + std::swap(channels_, other->channels_); + std::swap(_has_bits_[0], other->_has_bits_[0]); + _unknown_fields_.Swap(&other->_unknown_fields_); + std::swap(_cached_size_, other->_cached_size_); + } +} + +::google::protobuf::Metadata LRNParam::GetMetadata() const { + protobuf_AssignDescriptorsOnce(); + ::google::protobuf::Metadata metadata; + metadata.descriptor = LRNParam_descriptor_; + metadata.reflection = LRNParam_reflection_; + return metadata; +} + + +// =================================================================== + +#ifndef _MSC_VER +const int ReLUParam::kHeightFieldNumber; +const int ReLUParam::kWidthFieldNumber; +const int ReLUParam::kChannelsFieldNumber; +#endif // !_MSC_VER + +ReLUParam::ReLUParam() + : ::google::protobuf::Message() { + SharedCtor(); +} + +void ReLUParam::InitAsDefaultInstance() { +} + +ReLUParam::ReLUParam(const ReLUParam& from) + : ::google::protobuf::Message() { + SharedCtor(); + MergeFrom(from); +} + +void ReLUParam::SharedCtor() { + _cached_size_ = 0; + height_ = 1; + width_ = 0; + channels_ = 1; + ::memset(_has_bits_, 0, sizeof(_has_bits_)); +} + +ReLUParam::~ReLUParam() { + SharedDtor(); +} + +void ReLUParam::SharedDtor() { + if (this != default_instance_) { + } +} + +void ReLUParam::SetCachedSize(int size) const { + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); +} +const ::google::protobuf::Descriptor* ReLUParam::descriptor() { + protobuf_AssignDescriptorsOnce(); + return ReLUParam_descriptor_; +} + +const ReLUParam& ReLUParam::default_instance() { + if (default_instance_ == NULL) protobuf_AddDesc_GET_2eproto(); + return *default_instance_; +} + +ReLUParam* ReLUParam::default_instance_ = NULL; + +ReLUParam* ReLUParam::New() const { + return new ReLUParam; +} + +void ReLUParam::Clear() { + if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { + height_ = 1; + width_ = 0; + channels_ = 1; + } + ::memset(_has_bits_, 0, sizeof(_has_bits_)); + mutable_unknown_fields()->Clear(); +} + +bool ReLUParam::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!(EXPRESSION)) return false + ::google::protobuf::uint32 tag; + while ((tag = input->ReadTag()) != 0) { + switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { + // optional int32 height = 1 [default = 1]; + case 1: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &height_))); + set_has_height(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(16)) goto parse_width; + break; + } + + // optional int32 width = 2; + case 2: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_width: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &width_))); + set_has_width(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(24)) goto parse_channels; + break; + } + + // optional int32 channels = 3 [default = 1]; + case 3: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_channels: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &channels_))); + set_has_channels(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectAtEnd()) return true; + break; + } + + default: { + handle_uninterpreted: + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_END_GROUP) { + return true; + } + DO_(::google::protobuf::internal::WireFormat::SkipField( + input, tag, mutable_unknown_fields())); + break; + } + } + } + return true; +#undef DO_ +} + +void ReLUParam::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // optional int32 height = 1 [default = 1]; + if (has_height()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(1, this->height(), output); + } + + // optional int32 width = 2; + if (has_width()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(2, this->width(), output); + } + + // optional int32 channels = 3 [default = 1]; + if (has_channels()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(3, this->channels(), output); + } + + if (!unknown_fields().empty()) { + ::google::protobuf::internal::WireFormat::SerializeUnknownFields( + unknown_fields(), output); + } +} + +::google::protobuf::uint8* ReLUParam::SerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const { + // optional int32 height = 1 [default = 1]; + if (has_height()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(1, this->height(), target); + } + + // optional int32 width = 2; + if (has_width()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(2, this->width(), target); + } + + // optional int32 channels = 3 [default = 1]; + if (has_channels()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(3, this->channels(), target); + } + + if (!unknown_fields().empty()) { + target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( + unknown_fields(), target); + } + return target; +} + +int ReLUParam::ByteSize() const { + int total_size = 0; + + if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { + // optional int32 height = 1 [default = 1]; + if (has_height()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->height()); + } + + // optional int32 width = 2; + if (has_width()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->width()); + } + + // optional int32 channels = 3 [default = 1]; + if (has_channels()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->channels()); + } + + } + if (!unknown_fields().empty()) { + total_size += + ::google::protobuf::internal::WireFormat::ComputeUnknownFieldsSize( + unknown_fields()); + } + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = total_size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); + return total_size; +} + +void ReLUParam::MergeFrom(const ::google::protobuf::Message& from) { + GOOGLE_CHECK_NE(&from, this); + const ReLUParam* source = + ::google::protobuf::internal::dynamic_cast_if_available( + &from); + if (source == NULL) { + ::google::protobuf::internal::ReflectionOps::Merge(from, this); + } else { + MergeFrom(*source); + } +} + +void ReLUParam::MergeFrom(const ReLUParam& from) { + GOOGLE_CHECK_NE(&from, this); + if (from._has_bits_[0 / 32] & (0xffu << (0 % 32))) { + if (from.has_height()) { + set_height(from.height()); + } + if (from.has_width()) { + set_width(from.width()); + } + if (from.has_channels()) { + set_channels(from.channels()); + } + } + mutable_unknown_fields()->MergeFrom(from.unknown_fields()); +} + +void ReLUParam::CopyFrom(const ::google::protobuf::Message& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void ReLUParam::CopyFrom(const ReLUParam& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool ReLUParam::IsInitialized() const { + + return true; +} + +void ReLUParam::Swap(ReLUParam* other) { + if (other != this) { + std::swap(height_, other->height_); + std::swap(width_, other->width_); + std::swap(channels_, other->channels_); + std::swap(_has_bits_[0], other->_has_bits_[0]); + _unknown_fields_.Swap(&other->_unknown_fields_); + std::swap(_cached_size_, other->_cached_size_); + } +} + +::google::protobuf::Metadata ReLUParam::GetMetadata() const { + protobuf_AssignDescriptorsOnce(); + ::google::protobuf::Metadata metadata; + metadata.descriptor = ReLUParam_descriptor_; + metadata.reflection = ReLUParam_reflection_; + return metadata; +} + + +// =================================================================== + +#ifndef _MSC_VER +const int SigmoidParam::kHeightFieldNumber; +const int SigmoidParam::kWidthFieldNumber; +const int SigmoidParam::kChannelsFieldNumber; +#endif // !_MSC_VER + +SigmoidParam::SigmoidParam() + : ::google::protobuf::Message() { + SharedCtor(); +} + +void SigmoidParam::InitAsDefaultInstance() { +} + +SigmoidParam::SigmoidParam(const SigmoidParam& from) + : ::google::protobuf::Message() { + SharedCtor(); + MergeFrom(from); +} + +void SigmoidParam::SharedCtor() { + _cached_size_ = 0; + height_ = 1; + width_ = 0; + channels_ = 1; + ::memset(_has_bits_, 0, sizeof(_has_bits_)); +} + +SigmoidParam::~SigmoidParam() { + SharedDtor(); +} + +void SigmoidParam::SharedDtor() { + if (this != default_instance_) { + } +} + +void SigmoidParam::SetCachedSize(int size) const { + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); +} +const ::google::protobuf::Descriptor* SigmoidParam::descriptor() { + protobuf_AssignDescriptorsOnce(); + return SigmoidParam_descriptor_; +} + +const SigmoidParam& SigmoidParam::default_instance() { + if (default_instance_ == NULL) protobuf_AddDesc_GET_2eproto(); + return *default_instance_; +} + +SigmoidParam* SigmoidParam::default_instance_ = NULL; + +SigmoidParam* SigmoidParam::New() const { + return new SigmoidParam; +} + +void SigmoidParam::Clear() { + if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { + height_ = 1; + width_ = 0; + channels_ = 1; + } + ::memset(_has_bits_, 0, sizeof(_has_bits_)); + mutable_unknown_fields()->Clear(); +} + +bool SigmoidParam::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!(EXPRESSION)) return false + ::google::protobuf::uint32 tag; + while ((tag = input->ReadTag()) != 0) { + switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { + // optional int32 height = 1 [default = 1]; + case 1: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &height_))); + set_has_height(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(16)) goto parse_width; + break; + } + + // optional int32 width = 2; + case 2: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_width: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &width_))); + set_has_width(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectTag(24)) goto parse_channels; + break; + } + + // optional int32 channels = 3 [default = 1]; + case 3: { + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_VARINT) { + parse_channels: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &channels_))); + set_has_channels(); + } else { + goto handle_uninterpreted; + } + if (input->ExpectAtEnd()) return true; + break; + } + + default: { + handle_uninterpreted: + if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_END_GROUP) { + return true; + } + DO_(::google::protobuf::internal::WireFormat::SkipField( + input, tag, mutable_unknown_fields())); + break; + } + } + } + return true; +#undef DO_ +} + +void SigmoidParam::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // optional int32 height = 1 [default = 1]; + if (has_height()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(1, this->height(), output); + } + + // optional int32 width = 2; + if (has_width()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(2, this->width(), output); + } + + // optional int32 channels = 3 [default = 1]; + if (has_channels()) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(3, this->channels(), output); + } + + if (!unknown_fields().empty()) { + ::google::protobuf::internal::WireFormat::SerializeUnknownFields( + unknown_fields(), output); + } +} + +::google::protobuf::uint8* SigmoidParam::SerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const { + // optional int32 height = 1 [default = 1]; + if (has_height()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(1, this->height(), target); + } + + // optional int32 width = 2; + if (has_width()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(2, this->width(), target); + } + + // optional int32 channels = 3 [default = 1]; + if (has_channels()) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(3, this->channels(), target); + } + + if (!unknown_fields().empty()) { + target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( + unknown_fields(), target); + } + return target; +} + +int SigmoidParam::ByteSize() const { + int total_size = 0; + + if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { + // optional int32 height = 1 [default = 1]; + if (has_height()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->height()); + } + + // optional int32 width = 2; + if (has_width()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->width()); + } + + // optional int32 channels = 3 [default = 1]; + if (has_channels()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->channels()); + } + + } + if (!unknown_fields().empty()) { + total_size += + ::google::protobuf::internal::WireFormat::ComputeUnknownFieldsSize( + unknown_fields()); + } + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = total_size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); + return total_size; +} + +void SigmoidParam::MergeFrom(const ::google::protobuf::Message& from) { + GOOGLE_CHECK_NE(&from, this); + const SigmoidParam* source = + ::google::protobuf::internal::dynamic_cast_if_available( + &from); + if (source == NULL) { + ::google::protobuf::internal::ReflectionOps::Merge(from, this); + } else { + MergeFrom(*source); + } +} + +void SigmoidParam::MergeFrom(const SigmoidParam& from) { + GOOGLE_CHECK_NE(&from, this); + if (from._has_bits_[0 / 32] & (0xffu << (0 % 32))) { + if (from.has_height()) { + set_height(from.height()); + } + if (from.has_width()) { + set_width(from.width()); + } + if (from.has_channels()) { + set_channels(from.channels()); + } + } + mutable_unknown_fields()->MergeFrom(from.unknown_fields()); +} + +void SigmoidParam::CopyFrom(const ::google::protobuf::Message& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void SigmoidParam::CopyFrom(const SigmoidParam& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool SigmoidParam::IsInitialized() const { + + return true; +} + +void SigmoidParam::Swap(SigmoidParam* other) { + if (other != this) { + std::swap(height_, other->height_); + std::swap(width_, other->width_); + std::swap(channels_, other->channels_); + std::swap(_has_bits_[0], other->_has_bits_[0]); + _unknown_fields_.Swap(&other->_unknown_fields_); + std::swap(_cached_size_, other->_cached_size_); + } +} + +::google::protobuf::Metadata SigmoidParam::GetMetadata() const { + protobuf_AssignDescriptorsOnce(); + ::google::protobuf::Metadata metadata; + metadata.descriptor = SigmoidParam_descriptor_; + metadata.reflection = SigmoidParam_reflection_; + return metadata; +} + + // @@protoc_insertion_point(namespace_scope) } // namespace GET diff --git a/include/proto/GET.pb.h b/include/proto/GET.pb.h index a0fb8ac..7e02b22 100644 --- a/include/proto/GET.pb.h +++ b/include/proto/GET.pb.h @@ -38,6 +38,11 @@ class TaskParam; class AddParam; class SubParam; class MulParam; +class ConvParam; +class PoolParam; +class LRNParam; +class ReLUParam; +class SigmoidParam; enum TaskParam_DataPosition { TaskParam_DataPosition_FILE = 0, @@ -63,11 +68,14 @@ enum TaskParam_TaskType { TaskParam_TaskType_SUB = 1, TaskParam_TaskType_MULTI = 2, TaskParam_TaskType_CONVOLUTION = 3, - TaskParam_TaskType_POOLING = 4 + TaskParam_TaskType_POOLING = 4, + TaskParam_TaskType_LRN = 5, + TaskParam_TaskType_RELU = 6, + TaskParam_TaskType_Sigmoid = 7 }; bool TaskParam_TaskType_IsValid(int value); const TaskParam_TaskType TaskParam_TaskType_TaskType_MIN = TaskParam_TaskType_ADD; -const TaskParam_TaskType TaskParam_TaskType_TaskType_MAX = TaskParam_TaskType_POOLING; +const TaskParam_TaskType TaskParam_TaskType_TaskType_MAX = TaskParam_TaskType_Sigmoid; const int TaskParam_TaskType_TaskType_ARRAYSIZE = TaskParam_TaskType_TaskType_MAX + 1; const ::google::protobuf::EnumDescriptor* TaskParam_TaskType_descriptor(); @@ -164,6 +172,9 @@ class TaskParam : public ::google::protobuf::Message { static const TaskType MULTI = TaskParam_TaskType_MULTI; static const TaskType CONVOLUTION = TaskParam_TaskType_CONVOLUTION; static const TaskType POOLING = TaskParam_TaskType_POOLING; + static const TaskType LRN = TaskParam_TaskType_LRN; + static const TaskType RELU = TaskParam_TaskType_RELU; + static const TaskType Sigmoid = TaskParam_TaskType_Sigmoid; static inline bool TaskType_IsValid(int value) { return TaskParam_TaskType_IsValid(value); } @@ -291,6 +302,51 @@ class TaskParam : public ::google::protobuf::Message { inline ::GET::MulParam* release_mul_param(); inline void set_allocated_mul_param(::GET::MulParam* mul_param); + // optional .GET.ConvParam conv_param = 11; + inline bool has_conv_param() const; + inline void clear_conv_param(); + static const int kConvParamFieldNumber = 11; + inline const ::GET::ConvParam& conv_param() const; + inline ::GET::ConvParam* mutable_conv_param(); + inline ::GET::ConvParam* release_conv_param(); + inline void set_allocated_conv_param(::GET::ConvParam* conv_param); + + // optional .GET.PoolParam pool_param = 12; + inline bool has_pool_param() const; + inline void clear_pool_param(); + static const int kPoolParamFieldNumber = 12; + inline const ::GET::PoolParam& pool_param() const; + inline ::GET::PoolParam* mutable_pool_param(); + inline ::GET::PoolParam* release_pool_param(); + inline void set_allocated_pool_param(::GET::PoolParam* pool_param); + + // optional .GET.LRNParam lrn_param = 13; + inline bool has_lrn_param() const; + inline void clear_lrn_param(); + static const int kLrnParamFieldNumber = 13; + inline const ::GET::LRNParam& lrn_param() const; + inline ::GET::LRNParam* mutable_lrn_param(); + inline ::GET::LRNParam* release_lrn_param(); + inline void set_allocated_lrn_param(::GET::LRNParam* lrn_param); + + // optional .GET.ReLUParam relu_param = 14; + inline bool has_relu_param() const; + inline void clear_relu_param(); + static const int kReluParamFieldNumber = 14; + inline const ::GET::ReLUParam& relu_param() const; + inline ::GET::ReLUParam* mutable_relu_param(); + inline ::GET::ReLUParam* release_relu_param(); + inline void set_allocated_relu_param(::GET::ReLUParam* relu_param); + + // optional .GET.SigmoidParam sigmoid_param = 15; + inline bool has_sigmoid_param() const; + inline void clear_sigmoid_param(); + static const int kSigmoidParamFieldNumber = 15; + inline const ::GET::SigmoidParam& sigmoid_param() const; + inline ::GET::SigmoidParam* mutable_sigmoid_param(); + inline ::GET::SigmoidParam* release_sigmoid_param(); + inline void set_allocated_sigmoid_param(::GET::SigmoidParam* sigmoid_param); + // @@protoc_insertion_point(class_scope:GET.TaskParam) private: inline void set_has_source_pos(); @@ -305,6 +361,16 @@ class TaskParam : public ::google::protobuf::Message { inline void clear_has_sub_param(); inline void set_has_mul_param(); inline void clear_has_mul_param(); + inline void set_has_conv_param(); + inline void clear_has_conv_param(); + inline void set_has_pool_param(); + inline void clear_has_pool_param(); + inline void set_has_lrn_param(); + inline void clear_has_lrn_param(); + inline void set_has_relu_param(); + inline void clear_has_relu_param(); + inline void set_has_sigmoid_param(); + inline void clear_has_sigmoid_param(); ::google::protobuf::UnknownFieldSet _unknown_fields_; @@ -317,10 +383,15 @@ class TaskParam : public ::google::protobuf::Message { ::GET::AddParam* add_param_; ::GET::SubParam* sub_param_; ::GET::MulParam* mul_param_; + ::GET::ConvParam* conv_param_; + ::GET::PoolParam* pool_param_; + ::GET::LRNParam* lrn_param_; + ::GET::ReLUParam* relu_param_; + ::GET::SigmoidParam* sigmoid_param_; int type_; mutable int _cached_size_; - ::google::protobuf::uint32 _has_bits_[(10 + 31) / 32]; + ::google::protobuf::uint32 _has_bits_[(15 + 31) / 32]; friend void protobuf_AddDesc_GET_2eproto(); friend void protobuf_AssignDesc_GET_2eproto(); @@ -589,64 +660,54 @@ class MulParam : public ::google::protobuf::Message { // accessors ------------------------------------------------------- - // optional int32 height_A = 1; - inline bool has_height_a() const; - inline void clear_height_a(); - static const int kHeightAFieldNumber = 1; - inline ::google::protobuf::int32 height_a() const; - inline void set_height_a(::google::protobuf::int32 value); - - // optional int32 width_A = 2; - inline bool has_width_a() const; - inline void clear_width_a(); - static const int kWidthAFieldNumber = 2; - inline ::google::protobuf::int32 width_a() const; - inline void set_width_a(::google::protobuf::int32 value); - - // optional int32 height_B = 3; - inline bool has_height_b() const; - inline void clear_height_b(); - static const int kHeightBFieldNumber = 3; - inline ::google::protobuf::int32 height_b() const; - inline void set_height_b(::google::protobuf::int32 value); - - // optional int32 width_B = 4; - inline bool has_width_b() const; - inline void clear_width_b(); - static const int kWidthBFieldNumber = 4; - inline ::google::protobuf::int32 width_b() const; - inline void set_width_b(::google::protobuf::int32 value); - - // optional int32 channels = 5 [default = 1]; + // optional int32 M = 1; + inline bool has_m() const; + inline void clear_m(); + static const int kMFieldNumber = 1; + inline ::google::protobuf::int32 m() const; + inline void set_m(::google::protobuf::int32 value); + + // optional int32 K = 2; + inline bool has_k() const; + inline void clear_k(); + static const int kKFieldNumber = 2; + inline ::google::protobuf::int32 k() const; + inline void set_k(::google::protobuf::int32 value); + + // optional int32 N = 3; + inline bool has_n() const; + inline void clear_n(); + static const int kNFieldNumber = 3; + inline ::google::protobuf::int32 n() const; + inline void set_n(::google::protobuf::int32 value); + + // optional int32 channels = 4 [default = 1]; inline bool has_channels() const; inline void clear_channels(); - static const int kChannelsFieldNumber = 5; + static const int kChannelsFieldNumber = 4; inline ::google::protobuf::int32 channels() const; inline void set_channels(::google::protobuf::int32 value); // @@protoc_insertion_point(class_scope:GET.MulParam) private: - inline void set_has_height_a(); - inline void clear_has_height_a(); - inline void set_has_width_a(); - inline void clear_has_width_a(); - inline void set_has_height_b(); - inline void clear_has_height_b(); - inline void set_has_width_b(); - inline void clear_has_width_b(); + inline void set_has_m(); + inline void clear_has_m(); + inline void set_has_k(); + inline void clear_has_k(); + inline void set_has_n(); + inline void clear_has_n(); inline void set_has_channels(); inline void clear_has_channels(); ::google::protobuf::UnknownFieldSet _unknown_fields_; - ::google::protobuf::int32 height_a_; - ::google::protobuf::int32 width_a_; - ::google::protobuf::int32 height_b_; - ::google::protobuf::int32 width_b_; + ::google::protobuf::int32 m_; + ::google::protobuf::int32 k_; + ::google::protobuf::int32 n_; ::google::protobuf::int32 channels_; mutable int _cached_size_; - ::google::protobuf::uint32 _has_bits_[(5 + 31) / 32]; + ::google::protobuf::uint32 _has_bits_[(4 + 31) / 32]; friend void protobuf_AddDesc_GET_2eproto(); friend void protobuf_AssignDesc_GET_2eproto(); @@ -655,584 +716,1932 @@ class MulParam : public ::google::protobuf::Message { void InitAsDefaultInstance(); static MulParam* default_instance_; }; -// =================================================================== +// ------------------------------------------------------------------- +class ConvParam : public ::google::protobuf::Message { + public: + ConvParam(); + virtual ~ConvParam(); -// =================================================================== + ConvParam(const ConvParam& from); -// TaskParam + inline ConvParam& operator=(const ConvParam& from) { + CopyFrom(from); + return *this; + } -// optional .GET.TaskParam.DataPosition source_pos = 1 [default = HOSTMEM]; -inline bool TaskParam::has_source_pos() const { - return (_has_bits_[0] & 0x00000001u) != 0; -} -inline void TaskParam::set_has_source_pos() { - _has_bits_[0] |= 0x00000001u; -} -inline void TaskParam::clear_has_source_pos() { - _has_bits_[0] &= ~0x00000001u; -} -inline void TaskParam::clear_source_pos() { - source_pos_ = 1; - clear_has_source_pos(); -} -inline ::GET::TaskParam_DataPosition TaskParam::source_pos() const { - return static_cast< ::GET::TaskParam_DataPosition >(source_pos_); -} -inline void TaskParam::set_source_pos(::GET::TaskParam_DataPosition value) { - assert(::GET::TaskParam_DataPosition_IsValid(value)); - set_has_source_pos(); - source_pos_ = value; -} + inline const ::google::protobuf::UnknownFieldSet& unknown_fields() const { + return _unknown_fields_; + } -// repeated string sourcef = 2; -inline int TaskParam::sourcef_size() const { - return sourcef_.size(); -} -inline void TaskParam::clear_sourcef() { - sourcef_.Clear(); -} -inline const ::std::string& TaskParam::sourcef(int index) const { - return sourcef_.Get(index); -} -inline ::std::string* TaskParam::mutable_sourcef(int index) { - return sourcef_.Mutable(index); -} -inline void TaskParam::set_sourcef(int index, const ::std::string& value) { - sourcef_.Mutable(index)->assign(value); -} -inline void TaskParam::set_sourcef(int index, const char* value) { - sourcef_.Mutable(index)->assign(value); -} -inline void TaskParam::set_sourcef(int index, const char* value, size_t size) { - sourcef_.Mutable(index)->assign( - reinterpret_cast(value), size); -} -inline ::std::string* TaskParam::add_sourcef() { - return sourcef_.Add(); -} -inline void TaskParam::add_sourcef(const ::std::string& value) { - sourcef_.Add()->assign(value); -} -inline void TaskParam::add_sourcef(const char* value) { - sourcef_.Add()->assign(value); -} -inline void TaskParam::add_sourcef(const char* value, size_t size) { - sourcef_.Add()->assign(reinterpret_cast(value), size); -} -inline const ::google::protobuf::RepeatedPtrField< ::std::string>& -TaskParam::sourcef() const { - return sourcef_; -} -inline ::google::protobuf::RepeatedPtrField< ::std::string>* -TaskParam::mutable_sourcef() { - return &sourcef_; -} + inline ::google::protobuf::UnknownFieldSet* mutable_unknown_fields() { + return &_unknown_fields_; + } -// repeated uint64 sourcem = 3; -inline int TaskParam::sourcem_size() const { - return sourcem_.size(); -} -inline void TaskParam::clear_sourcem() { - sourcem_.Clear(); -} -inline ::google::protobuf::uint64 TaskParam::sourcem(int index) const { - return sourcem_.Get(index); -} -inline void TaskParam::set_sourcem(int index, ::google::protobuf::uint64 value) { - sourcem_.Set(index, value); -} -inline void TaskParam::add_sourcem(::google::protobuf::uint64 value) { - sourcem_.Add(value); -} -inline const ::google::protobuf::RepeatedField< ::google::protobuf::uint64 >& -TaskParam::sourcem() const { - return sourcem_; -} -inline ::google::protobuf::RepeatedField< ::google::protobuf::uint64 >* -TaskParam::mutable_sourcem() { - return &sourcem_; -} + static const ::google::protobuf::Descriptor* descriptor(); + static const ConvParam& default_instance(); -// optional .GET.TaskParam.DataPosition result_pos = 4 [default = HOSTMEM]; -inline bool TaskParam::has_result_pos() const { - return (_has_bits_[0] & 0x00000008u) != 0; -} -inline void TaskParam::set_has_result_pos() { - _has_bits_[0] |= 0x00000008u; -} -inline void TaskParam::clear_has_result_pos() { - _has_bits_[0] &= ~0x00000008u; -} -inline void TaskParam::clear_result_pos() { - result_pos_ = 1; - clear_has_result_pos(); -} -inline ::GET::TaskParam_DataPosition TaskParam::result_pos() const { - return static_cast< ::GET::TaskParam_DataPosition >(result_pos_); -} -inline void TaskParam::set_result_pos(::GET::TaskParam_DataPosition value) { - assert(::GET::TaskParam_DataPosition_IsValid(value)); - set_has_result_pos(); - result_pos_ = value; -} + void Swap(ConvParam* other); -// repeated string resultf = 5; -inline int TaskParam::resultf_size() const { - return resultf_.size(); -} -inline void TaskParam::clear_resultf() { - resultf_.Clear(); -} -inline const ::std::string& TaskParam::resultf(int index) const { - return resultf_.Get(index); -} -inline ::std::string* TaskParam::mutable_resultf(int index) { - return resultf_.Mutable(index); -} -inline void TaskParam::set_resultf(int index, const ::std::string& value) { - resultf_.Mutable(index)->assign(value); -} -inline void TaskParam::set_resultf(int index, const char* value) { - resultf_.Mutable(index)->assign(value); -} -inline void TaskParam::set_resultf(int index, const char* value, size_t size) { - resultf_.Mutable(index)->assign( - reinterpret_cast(value), size); -} -inline ::std::string* TaskParam::add_resultf() { - return resultf_.Add(); -} -inline void TaskParam::add_resultf(const ::std::string& value) { - resultf_.Add()->assign(value); -} -inline void TaskParam::add_resultf(const char* value) { - resultf_.Add()->assign(value); -} -inline void TaskParam::add_resultf(const char* value, size_t size) { - resultf_.Add()->assign(reinterpret_cast(value), size); -} -inline const ::google::protobuf::RepeatedPtrField< ::std::string>& -TaskParam::resultf() const { - return resultf_; -} -inline ::google::protobuf::RepeatedPtrField< ::std::string>* -TaskParam::mutable_resultf() { - return &resultf_; -} + // implements Message ---------------------------------------------- -// repeated uint64 resultm = 6; -inline int TaskParam::resultm_size() const { - return resultm_.size(); -} -inline void TaskParam::clear_resultm() { - resultm_.Clear(); -} -inline ::google::protobuf::uint64 TaskParam::resultm(int index) const { - return resultm_.Get(index); -} -inline void TaskParam::set_resultm(int index, ::google::protobuf::uint64 value) { - resultm_.Set(index, value); -} -inline void TaskParam::add_resultm(::google::protobuf::uint64 value) { - resultm_.Add(value); -} -inline const ::google::protobuf::RepeatedField< ::google::protobuf::uint64 >& -TaskParam::resultm() const { - return resultm_; -} -inline ::google::protobuf::RepeatedField< ::google::protobuf::uint64 >* -TaskParam::mutable_resultm() { - return &resultm_; -} + ConvParam* New() const; + void CopyFrom(const ::google::protobuf::Message& from); + void MergeFrom(const ::google::protobuf::Message& from); + void CopyFrom(const ConvParam& from); + void MergeFrom(const ConvParam& from); + void Clear(); + bool IsInitialized() const; -// optional .GET.TaskParam.TaskType type = 7; -inline bool TaskParam::has_type() const { - return (_has_bits_[0] & 0x00000040u) != 0; -} -inline void TaskParam::set_has_type() { - _has_bits_[0] |= 0x00000040u; -} -inline void TaskParam::clear_has_type() { - _has_bits_[0] &= ~0x00000040u; -} -inline void TaskParam::clear_type() { - type_ = 0; - clear_has_type(); -} -inline ::GET::TaskParam_TaskType TaskParam::type() const { - return static_cast< ::GET::TaskParam_TaskType >(type_); -} -inline void TaskParam::set_type(::GET::TaskParam_TaskType value) { - assert(::GET::TaskParam_TaskType_IsValid(value)); - set_has_type(); - type_ = value; -} + int ByteSize() const; + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input); + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const; + ::google::protobuf::uint8* SerializeWithCachedSizesToArray(::google::protobuf::uint8* output) const; + int GetCachedSize() const { return _cached_size_; } + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const; + public: -// optional .GET.AddParam add_param = 8; -inline bool TaskParam::has_add_param() const { - return (_has_bits_[0] & 0x00000080u) != 0; -} -inline void TaskParam::set_has_add_param() { - _has_bits_[0] |= 0x00000080u; -} -inline void TaskParam::clear_has_add_param() { - _has_bits_[0] &= ~0x00000080u; -} -inline void TaskParam::clear_add_param() { - if (add_param_ != NULL) add_param_->::GET::AddParam::Clear(); - clear_has_add_param(); -} -inline const ::GET::AddParam& TaskParam::add_param() const { - return add_param_ != NULL ? *add_param_ : *default_instance_->add_param_; -} -inline ::GET::AddParam* TaskParam::mutable_add_param() { - set_has_add_param(); - if (add_param_ == NULL) add_param_ = new ::GET::AddParam; - return add_param_; -} -inline ::GET::AddParam* TaskParam::release_add_param() { + ::google::protobuf::Metadata GetMetadata() const; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // optional int32 data_h = 1; + inline bool has_data_h() const; + inline void clear_data_h(); + static const int kDataHFieldNumber = 1; + inline ::google::protobuf::int32 data_h() const; + inline void set_data_h(::google::protobuf::int32 value); + + // optional int32 data_w = 2; + inline bool has_data_w() const; + inline void clear_data_w(); + static const int kDataWFieldNumber = 2; + inline ::google::protobuf::int32 data_w() const; + inline void set_data_w(::google::protobuf::int32 value); + + // optional int32 filter_h = 3; + inline bool has_filter_h() const; + inline void clear_filter_h(); + static const int kFilterHFieldNumber = 3; + inline ::google::protobuf::int32 filter_h() const; + inline void set_filter_h(::google::protobuf::int32 value); + + // optional int32 filter_w = 4; + inline bool has_filter_w() const; + inline void clear_filter_w(); + static const int kFilterWFieldNumber = 4; + inline ::google::protobuf::int32 filter_w() const; + inline void set_filter_w(::google::protobuf::int32 value); + + // optional int32 stride_h = 5 [default = 1]; + inline bool has_stride_h() const; + inline void clear_stride_h(); + static const int kStrideHFieldNumber = 5; + inline ::google::protobuf::int32 stride_h() const; + inline void set_stride_h(::google::protobuf::int32 value); + + // optional int32 stride_w = 6 [default = 1]; + inline bool has_stride_w() const; + inline void clear_stride_w(); + static const int kStrideWFieldNumber = 6; + inline ::google::protobuf::int32 stride_w() const; + inline void set_stride_w(::google::protobuf::int32 value); + + // optional int32 pad_h = 7 [default = 0]; + inline bool has_pad_h() const; + inline void clear_pad_h(); + static const int kPadHFieldNumber = 7; + inline ::google::protobuf::int32 pad_h() const; + inline void set_pad_h(::google::protobuf::int32 value); + + // optional int32 pad_w = 8 [default = 0]; + inline bool has_pad_w() const; + inline void clear_pad_w(); + static const int kPadWFieldNumber = 8; + inline ::google::protobuf::int32 pad_w() const; + inline void set_pad_w(::google::protobuf::int32 value); + + // optional int32 channels = 9 [default = 1]; + inline bool has_channels() const; + inline void clear_channels(); + static const int kChannelsFieldNumber = 9; + inline ::google::protobuf::int32 channels() const; + inline void set_channels(::google::protobuf::int32 value); + + // @@protoc_insertion_point(class_scope:GET.ConvParam) + private: + inline void set_has_data_h(); + inline void clear_has_data_h(); + inline void set_has_data_w(); + inline void clear_has_data_w(); + inline void set_has_filter_h(); + inline void clear_has_filter_h(); + inline void set_has_filter_w(); + inline void clear_has_filter_w(); + inline void set_has_stride_h(); + inline void clear_has_stride_h(); + inline void set_has_stride_w(); + inline void clear_has_stride_w(); + inline void set_has_pad_h(); + inline void clear_has_pad_h(); + inline void set_has_pad_w(); + inline void clear_has_pad_w(); + inline void set_has_channels(); + inline void clear_has_channels(); + + ::google::protobuf::UnknownFieldSet _unknown_fields_; + + ::google::protobuf::int32 data_h_; + ::google::protobuf::int32 data_w_; + ::google::protobuf::int32 filter_h_; + ::google::protobuf::int32 filter_w_; + ::google::protobuf::int32 stride_h_; + ::google::protobuf::int32 stride_w_; + ::google::protobuf::int32 pad_h_; + ::google::protobuf::int32 pad_w_; + ::google::protobuf::int32 channels_; + + mutable int _cached_size_; + ::google::protobuf::uint32 _has_bits_[(9 + 31) / 32]; + + friend void protobuf_AddDesc_GET_2eproto(); + friend void protobuf_AssignDesc_GET_2eproto(); + friend void protobuf_ShutdownFile_GET_2eproto(); + + void InitAsDefaultInstance(); + static ConvParam* default_instance_; +}; +// ------------------------------------------------------------------- + +class PoolParam : public ::google::protobuf::Message { + public: + PoolParam(); + virtual ~PoolParam(); + + PoolParam(const PoolParam& from); + + inline PoolParam& operator=(const PoolParam& from) { + CopyFrom(from); + return *this; + } + + inline const ::google::protobuf::UnknownFieldSet& unknown_fields() const { + return _unknown_fields_; + } + + inline ::google::protobuf::UnknownFieldSet* mutable_unknown_fields() { + return &_unknown_fields_; + } + + static const ::google::protobuf::Descriptor* descriptor(); + static const PoolParam& default_instance(); + + void Swap(PoolParam* other); + + // implements Message ---------------------------------------------- + + PoolParam* New() const; + void CopyFrom(const ::google::protobuf::Message& from); + void MergeFrom(const ::google::protobuf::Message& from); + void CopyFrom(const PoolParam& from); + void MergeFrom(const PoolParam& from); + void Clear(); + bool IsInitialized() const; + + int ByteSize() const; + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input); + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const; + ::google::protobuf::uint8* SerializeWithCachedSizesToArray(::google::protobuf::uint8* output) const; + int GetCachedSize() const { return _cached_size_; } + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const; + public: + + ::google::protobuf::Metadata GetMetadata() const; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // optional int32 data_h = 1; + inline bool has_data_h() const; + inline void clear_data_h(); + static const int kDataHFieldNumber = 1; + inline ::google::protobuf::int32 data_h() const; + inline void set_data_h(::google::protobuf::int32 value); + + // optional int32 data_w = 2; + inline bool has_data_w() const; + inline void clear_data_w(); + static const int kDataWFieldNumber = 2; + inline ::google::protobuf::int32 data_w() const; + inline void set_data_w(::google::protobuf::int32 value); + + // optional int32 kernel_h = 3; + inline bool has_kernel_h() const; + inline void clear_kernel_h(); + static const int kKernelHFieldNumber = 3; + inline ::google::protobuf::int32 kernel_h() const; + inline void set_kernel_h(::google::protobuf::int32 value); + + // optional int32 kernel_w = 4; + inline bool has_kernel_w() const; + inline void clear_kernel_w(); + static const int kKernelWFieldNumber = 4; + inline ::google::protobuf::int32 kernel_w() const; + inline void set_kernel_w(::google::protobuf::int32 value); + + // optional int32 stride_h = 5 [default = 1]; + inline bool has_stride_h() const; + inline void clear_stride_h(); + static const int kStrideHFieldNumber = 5; + inline ::google::protobuf::int32 stride_h() const; + inline void set_stride_h(::google::protobuf::int32 value); + + // optional int32 stride_w = 6 [default = 1]; + inline bool has_stride_w() const; + inline void clear_stride_w(); + static const int kStrideWFieldNumber = 6; + inline ::google::protobuf::int32 stride_w() const; + inline void set_stride_w(::google::protobuf::int32 value); + + // optional int32 pad_h = 7 [default = 0]; + inline bool has_pad_h() const; + inline void clear_pad_h(); + static const int kPadHFieldNumber = 7; + inline ::google::protobuf::int32 pad_h() const; + inline void set_pad_h(::google::protobuf::int32 value); + + // optional int32 pad_w = 8 [default = 0]; + inline bool has_pad_w() const; + inline void clear_pad_w(); + static const int kPadWFieldNumber = 8; + inline ::google::protobuf::int32 pad_w() const; + inline void set_pad_w(::google::protobuf::int32 value); + + // optional int32 channels = 9 [default = 1]; + inline bool has_channels() const; + inline void clear_channels(); + static const int kChannelsFieldNumber = 9; + inline ::google::protobuf::int32 channels() const; + inline void set_channels(::google::protobuf::int32 value); + + // @@protoc_insertion_point(class_scope:GET.PoolParam) + private: + inline void set_has_data_h(); + inline void clear_has_data_h(); + inline void set_has_data_w(); + inline void clear_has_data_w(); + inline void set_has_kernel_h(); + inline void clear_has_kernel_h(); + inline void set_has_kernel_w(); + inline void clear_has_kernel_w(); + inline void set_has_stride_h(); + inline void clear_has_stride_h(); + inline void set_has_stride_w(); + inline void clear_has_stride_w(); + inline void set_has_pad_h(); + inline void clear_has_pad_h(); + inline void set_has_pad_w(); + inline void clear_has_pad_w(); + inline void set_has_channels(); + inline void clear_has_channels(); + + ::google::protobuf::UnknownFieldSet _unknown_fields_; + + ::google::protobuf::int32 data_h_; + ::google::protobuf::int32 data_w_; + ::google::protobuf::int32 kernel_h_; + ::google::protobuf::int32 kernel_w_; + ::google::protobuf::int32 stride_h_; + ::google::protobuf::int32 stride_w_; + ::google::protobuf::int32 pad_h_; + ::google::protobuf::int32 pad_w_; + ::google::protobuf::int32 channels_; + + mutable int _cached_size_; + ::google::protobuf::uint32 _has_bits_[(9 + 31) / 32]; + + friend void protobuf_AddDesc_GET_2eproto(); + friend void protobuf_AssignDesc_GET_2eproto(); + friend void protobuf_ShutdownFile_GET_2eproto(); + + void InitAsDefaultInstance(); + static PoolParam* default_instance_; +}; +// ------------------------------------------------------------------- + +class LRNParam : public ::google::protobuf::Message { + public: + LRNParam(); + virtual ~LRNParam(); + + LRNParam(const LRNParam& from); + + inline LRNParam& operator=(const LRNParam& from) { + CopyFrom(from); + return *this; + } + + inline const ::google::protobuf::UnknownFieldSet& unknown_fields() const { + return _unknown_fields_; + } + + inline ::google::protobuf::UnknownFieldSet* mutable_unknown_fields() { + return &_unknown_fields_; + } + + static const ::google::protobuf::Descriptor* descriptor(); + static const LRNParam& default_instance(); + + void Swap(LRNParam* other); + + // implements Message ---------------------------------------------- + + LRNParam* New() const; + void CopyFrom(const ::google::protobuf::Message& from); + void MergeFrom(const ::google::protobuf::Message& from); + void CopyFrom(const LRNParam& from); + void MergeFrom(const LRNParam& from); + void Clear(); + bool IsInitialized() const; + + int ByteSize() const; + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input); + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const; + ::google::protobuf::uint8* SerializeWithCachedSizesToArray(::google::protobuf::uint8* output) const; + int GetCachedSize() const { return _cached_size_; } + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const; + public: + + ::google::protobuf::Metadata GetMetadata() const; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // optional int32 channels = 1 [default = 0]; + inline bool has_channels() const; + inline void clear_channels(); + static const int kChannelsFieldNumber = 1; + inline ::google::protobuf::int32 channels() const; + inline void set_channels(::google::protobuf::int32 value); + + // @@protoc_insertion_point(class_scope:GET.LRNParam) + private: + inline void set_has_channels(); + inline void clear_has_channels(); + + ::google::protobuf::UnknownFieldSet _unknown_fields_; + + ::google::protobuf::int32 channels_; + + mutable int _cached_size_; + ::google::protobuf::uint32 _has_bits_[(1 + 31) / 32]; + + friend void protobuf_AddDesc_GET_2eproto(); + friend void protobuf_AssignDesc_GET_2eproto(); + friend void protobuf_ShutdownFile_GET_2eproto(); + + void InitAsDefaultInstance(); + static LRNParam* default_instance_; +}; +// ------------------------------------------------------------------- + +class ReLUParam : public ::google::protobuf::Message { + public: + ReLUParam(); + virtual ~ReLUParam(); + + ReLUParam(const ReLUParam& from); + + inline ReLUParam& operator=(const ReLUParam& from) { + CopyFrom(from); + return *this; + } + + inline const ::google::protobuf::UnknownFieldSet& unknown_fields() const { + return _unknown_fields_; + } + + inline ::google::protobuf::UnknownFieldSet* mutable_unknown_fields() { + return &_unknown_fields_; + } + + static const ::google::protobuf::Descriptor* descriptor(); + static const ReLUParam& default_instance(); + + void Swap(ReLUParam* other); + + // implements Message ---------------------------------------------- + + ReLUParam* New() const; + void CopyFrom(const ::google::protobuf::Message& from); + void MergeFrom(const ::google::protobuf::Message& from); + void CopyFrom(const ReLUParam& from); + void MergeFrom(const ReLUParam& from); + void Clear(); + bool IsInitialized() const; + + int ByteSize() const; + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input); + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const; + ::google::protobuf::uint8* SerializeWithCachedSizesToArray(::google::protobuf::uint8* output) const; + int GetCachedSize() const { return _cached_size_; } + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const; + public: + + ::google::protobuf::Metadata GetMetadata() const; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // optional int32 height = 1 [default = 1]; + inline bool has_height() const; + inline void clear_height(); + static const int kHeightFieldNumber = 1; + inline ::google::protobuf::int32 height() const; + inline void set_height(::google::protobuf::int32 value); + + // optional int32 width = 2; + inline bool has_width() const; + inline void clear_width(); + static const int kWidthFieldNumber = 2; + inline ::google::protobuf::int32 width() const; + inline void set_width(::google::protobuf::int32 value); + + // optional int32 channels = 3 [default = 1]; + inline bool has_channels() const; + inline void clear_channels(); + static const int kChannelsFieldNumber = 3; + inline ::google::protobuf::int32 channels() const; + inline void set_channels(::google::protobuf::int32 value); + + // @@protoc_insertion_point(class_scope:GET.ReLUParam) + private: + inline void set_has_height(); + inline void clear_has_height(); + inline void set_has_width(); + inline void clear_has_width(); + inline void set_has_channels(); + inline void clear_has_channels(); + + ::google::protobuf::UnknownFieldSet _unknown_fields_; + + ::google::protobuf::int32 height_; + ::google::protobuf::int32 width_; + ::google::protobuf::int32 channels_; + + mutable int _cached_size_; + ::google::protobuf::uint32 _has_bits_[(3 + 31) / 32]; + + friend void protobuf_AddDesc_GET_2eproto(); + friend void protobuf_AssignDesc_GET_2eproto(); + friend void protobuf_ShutdownFile_GET_2eproto(); + + void InitAsDefaultInstance(); + static ReLUParam* default_instance_; +}; +// ------------------------------------------------------------------- + +class SigmoidParam : public ::google::protobuf::Message { + public: + SigmoidParam(); + virtual ~SigmoidParam(); + + SigmoidParam(const SigmoidParam& from); + + inline SigmoidParam& operator=(const SigmoidParam& from) { + CopyFrom(from); + return *this; + } + + inline const ::google::protobuf::UnknownFieldSet& unknown_fields() const { + return _unknown_fields_; + } + + inline ::google::protobuf::UnknownFieldSet* mutable_unknown_fields() { + return &_unknown_fields_; + } + + static const ::google::protobuf::Descriptor* descriptor(); + static const SigmoidParam& default_instance(); + + void Swap(SigmoidParam* other); + + // implements Message ---------------------------------------------- + + SigmoidParam* New() const; + void CopyFrom(const ::google::protobuf::Message& from); + void MergeFrom(const ::google::protobuf::Message& from); + void CopyFrom(const SigmoidParam& from); + void MergeFrom(const SigmoidParam& from); + void Clear(); + bool IsInitialized() const; + + int ByteSize() const; + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input); + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const; + ::google::protobuf::uint8* SerializeWithCachedSizesToArray(::google::protobuf::uint8* output) const; + int GetCachedSize() const { return _cached_size_; } + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const; + public: + + ::google::protobuf::Metadata GetMetadata() const; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // optional int32 height = 1 [default = 1]; + inline bool has_height() const; + inline void clear_height(); + static const int kHeightFieldNumber = 1; + inline ::google::protobuf::int32 height() const; + inline void set_height(::google::protobuf::int32 value); + + // optional int32 width = 2; + inline bool has_width() const; + inline void clear_width(); + static const int kWidthFieldNumber = 2; + inline ::google::protobuf::int32 width() const; + inline void set_width(::google::protobuf::int32 value); + + // optional int32 channels = 3 [default = 1]; + inline bool has_channels() const; + inline void clear_channels(); + static const int kChannelsFieldNumber = 3; + inline ::google::protobuf::int32 channels() const; + inline void set_channels(::google::protobuf::int32 value); + + // @@protoc_insertion_point(class_scope:GET.SigmoidParam) + private: + inline void set_has_height(); + inline void clear_has_height(); + inline void set_has_width(); + inline void clear_has_width(); + inline void set_has_channels(); + inline void clear_has_channels(); + + ::google::protobuf::UnknownFieldSet _unknown_fields_; + + ::google::protobuf::int32 height_; + ::google::protobuf::int32 width_; + ::google::protobuf::int32 channels_; + + mutable int _cached_size_; + ::google::protobuf::uint32 _has_bits_[(3 + 31) / 32]; + + friend void protobuf_AddDesc_GET_2eproto(); + friend void protobuf_AssignDesc_GET_2eproto(); + friend void protobuf_ShutdownFile_GET_2eproto(); + + void InitAsDefaultInstance(); + static SigmoidParam* default_instance_; +}; +// =================================================================== + + +// =================================================================== + +// TaskParam + +// optional .GET.TaskParam.DataPosition source_pos = 1 [default = HOSTMEM]; +inline bool TaskParam::has_source_pos() const { + return (_has_bits_[0] & 0x00000001u) != 0; +} +inline void TaskParam::set_has_source_pos() { + _has_bits_[0] |= 0x00000001u; +} +inline void TaskParam::clear_has_source_pos() { + _has_bits_[0] &= ~0x00000001u; +} +inline void TaskParam::clear_source_pos() { + source_pos_ = 1; + clear_has_source_pos(); +} +inline ::GET::TaskParam_DataPosition TaskParam::source_pos() const { + return static_cast< ::GET::TaskParam_DataPosition >(source_pos_); +} +inline void TaskParam::set_source_pos(::GET::TaskParam_DataPosition value) { + assert(::GET::TaskParam_DataPosition_IsValid(value)); + set_has_source_pos(); + source_pos_ = value; +} + +// repeated string sourcef = 2; +inline int TaskParam::sourcef_size() const { + return sourcef_.size(); +} +inline void TaskParam::clear_sourcef() { + sourcef_.Clear(); +} +inline const ::std::string& TaskParam::sourcef(int index) const { + return sourcef_.Get(index); +} +inline ::std::string* TaskParam::mutable_sourcef(int index) { + return sourcef_.Mutable(index); +} +inline void TaskParam::set_sourcef(int index, const ::std::string& value) { + sourcef_.Mutable(index)->assign(value); +} +inline void TaskParam::set_sourcef(int index, const char* value) { + sourcef_.Mutable(index)->assign(value); +} +inline void TaskParam::set_sourcef(int index, const char* value, size_t size) { + sourcef_.Mutable(index)->assign( + reinterpret_cast(value), size); +} +inline ::std::string* TaskParam::add_sourcef() { + return sourcef_.Add(); +} +inline void TaskParam::add_sourcef(const ::std::string& value) { + sourcef_.Add()->assign(value); +} +inline void TaskParam::add_sourcef(const char* value) { + sourcef_.Add()->assign(value); +} +inline void TaskParam::add_sourcef(const char* value, size_t size) { + sourcef_.Add()->assign(reinterpret_cast(value), size); +} +inline const ::google::protobuf::RepeatedPtrField< ::std::string>& +TaskParam::sourcef() const { + return sourcef_; +} +inline ::google::protobuf::RepeatedPtrField< ::std::string>* +TaskParam::mutable_sourcef() { + return &sourcef_; +} + +// repeated uint64 sourcem = 3; +inline int TaskParam::sourcem_size() const { + return sourcem_.size(); +} +inline void TaskParam::clear_sourcem() { + sourcem_.Clear(); +} +inline ::google::protobuf::uint64 TaskParam::sourcem(int index) const { + return sourcem_.Get(index); +} +inline void TaskParam::set_sourcem(int index, ::google::protobuf::uint64 value) { + sourcem_.Set(index, value); +} +inline void TaskParam::add_sourcem(::google::protobuf::uint64 value) { + sourcem_.Add(value); +} +inline const ::google::protobuf::RepeatedField< ::google::protobuf::uint64 >& +TaskParam::sourcem() const { + return sourcem_; +} +inline ::google::protobuf::RepeatedField< ::google::protobuf::uint64 >* +TaskParam::mutable_sourcem() { + return &sourcem_; +} + +// optional .GET.TaskParam.DataPosition result_pos = 4 [default = HOSTMEM]; +inline bool TaskParam::has_result_pos() const { + return (_has_bits_[0] & 0x00000008u) != 0; +} +inline void TaskParam::set_has_result_pos() { + _has_bits_[0] |= 0x00000008u; +} +inline void TaskParam::clear_has_result_pos() { + _has_bits_[0] &= ~0x00000008u; +} +inline void TaskParam::clear_result_pos() { + result_pos_ = 1; + clear_has_result_pos(); +} +inline ::GET::TaskParam_DataPosition TaskParam::result_pos() const { + return static_cast< ::GET::TaskParam_DataPosition >(result_pos_); +} +inline void TaskParam::set_result_pos(::GET::TaskParam_DataPosition value) { + assert(::GET::TaskParam_DataPosition_IsValid(value)); + set_has_result_pos(); + result_pos_ = value; +} + +// repeated string resultf = 5; +inline int TaskParam::resultf_size() const { + return resultf_.size(); +} +inline void TaskParam::clear_resultf() { + resultf_.Clear(); +} +inline const ::std::string& TaskParam::resultf(int index) const { + return resultf_.Get(index); +} +inline ::std::string* TaskParam::mutable_resultf(int index) { + return resultf_.Mutable(index); +} +inline void TaskParam::set_resultf(int index, const ::std::string& value) { + resultf_.Mutable(index)->assign(value); +} +inline void TaskParam::set_resultf(int index, const char* value) { + resultf_.Mutable(index)->assign(value); +} +inline void TaskParam::set_resultf(int index, const char* value, size_t size) { + resultf_.Mutable(index)->assign( + reinterpret_cast(value), size); +} +inline ::std::string* TaskParam::add_resultf() { + return resultf_.Add(); +} +inline void TaskParam::add_resultf(const ::std::string& value) { + resultf_.Add()->assign(value); +} +inline void TaskParam::add_resultf(const char* value) { + resultf_.Add()->assign(value); +} +inline void TaskParam::add_resultf(const char* value, size_t size) { + resultf_.Add()->assign(reinterpret_cast(value), size); +} +inline const ::google::protobuf::RepeatedPtrField< ::std::string>& +TaskParam::resultf() const { + return resultf_; +} +inline ::google::protobuf::RepeatedPtrField< ::std::string>* +TaskParam::mutable_resultf() { + return &resultf_; +} + +// repeated uint64 resultm = 6; +inline int TaskParam::resultm_size() const { + return resultm_.size(); +} +inline void TaskParam::clear_resultm() { + resultm_.Clear(); +} +inline ::google::protobuf::uint64 TaskParam::resultm(int index) const { + return resultm_.Get(index); +} +inline void TaskParam::set_resultm(int index, ::google::protobuf::uint64 value) { + resultm_.Set(index, value); +} +inline void TaskParam::add_resultm(::google::protobuf::uint64 value) { + resultm_.Add(value); +} +inline const ::google::protobuf::RepeatedField< ::google::protobuf::uint64 >& +TaskParam::resultm() const { + return resultm_; +} +inline ::google::protobuf::RepeatedField< ::google::protobuf::uint64 >* +TaskParam::mutable_resultm() { + return &resultm_; +} + +// optional .GET.TaskParam.TaskType type = 7; +inline bool TaskParam::has_type() const { + return (_has_bits_[0] & 0x00000040u) != 0; +} +inline void TaskParam::set_has_type() { + _has_bits_[0] |= 0x00000040u; +} +inline void TaskParam::clear_has_type() { + _has_bits_[0] &= ~0x00000040u; +} +inline void TaskParam::clear_type() { + type_ = 0; + clear_has_type(); +} +inline ::GET::TaskParam_TaskType TaskParam::type() const { + return static_cast< ::GET::TaskParam_TaskType >(type_); +} +inline void TaskParam::set_type(::GET::TaskParam_TaskType value) { + assert(::GET::TaskParam_TaskType_IsValid(value)); + set_has_type(); + type_ = value; +} + +// optional .GET.AddParam add_param = 8; +inline bool TaskParam::has_add_param() const { + return (_has_bits_[0] & 0x00000080u) != 0; +} +inline void TaskParam::set_has_add_param() { + _has_bits_[0] |= 0x00000080u; +} +inline void TaskParam::clear_has_add_param() { + _has_bits_[0] &= ~0x00000080u; +} +inline void TaskParam::clear_add_param() { + if (add_param_ != NULL) add_param_->::GET::AddParam::Clear(); + clear_has_add_param(); +} +inline const ::GET::AddParam& TaskParam::add_param() const { + return add_param_ != NULL ? *add_param_ : *default_instance_->add_param_; +} +inline ::GET::AddParam* TaskParam::mutable_add_param() { + set_has_add_param(); + if (add_param_ == NULL) add_param_ = new ::GET::AddParam; + return add_param_; +} +inline ::GET::AddParam* TaskParam::release_add_param() { clear_has_add_param(); ::GET::AddParam* temp = add_param_; add_param_ = NULL; return temp; } -inline void TaskParam::set_allocated_add_param(::GET::AddParam* add_param) { - delete add_param_; - add_param_ = add_param; - if (add_param) { - set_has_add_param(); - } else { - clear_has_add_param(); - } +inline void TaskParam::set_allocated_add_param(::GET::AddParam* add_param) { + delete add_param_; + add_param_ = add_param; + if (add_param) { + set_has_add_param(); + } else { + clear_has_add_param(); + } +} + +// optional .GET.SubParam sub_param = 9; +inline bool TaskParam::has_sub_param() const { + return (_has_bits_[0] & 0x00000100u) != 0; +} +inline void TaskParam::set_has_sub_param() { + _has_bits_[0] |= 0x00000100u; +} +inline void TaskParam::clear_has_sub_param() { + _has_bits_[0] &= ~0x00000100u; +} +inline void TaskParam::clear_sub_param() { + if (sub_param_ != NULL) sub_param_->::GET::SubParam::Clear(); + clear_has_sub_param(); +} +inline const ::GET::SubParam& TaskParam::sub_param() const { + return sub_param_ != NULL ? *sub_param_ : *default_instance_->sub_param_; +} +inline ::GET::SubParam* TaskParam::mutable_sub_param() { + set_has_sub_param(); + if (sub_param_ == NULL) sub_param_ = new ::GET::SubParam; + return sub_param_; +} +inline ::GET::SubParam* TaskParam::release_sub_param() { + clear_has_sub_param(); + ::GET::SubParam* temp = sub_param_; + sub_param_ = NULL; + return temp; +} +inline void TaskParam::set_allocated_sub_param(::GET::SubParam* sub_param) { + delete sub_param_; + sub_param_ = sub_param; + if (sub_param) { + set_has_sub_param(); + } else { + clear_has_sub_param(); + } +} + +// optional .GET.MulParam mul_param = 10; +inline bool TaskParam::has_mul_param() const { + return (_has_bits_[0] & 0x00000200u) != 0; +} +inline void TaskParam::set_has_mul_param() { + _has_bits_[0] |= 0x00000200u; +} +inline void TaskParam::clear_has_mul_param() { + _has_bits_[0] &= ~0x00000200u; +} +inline void TaskParam::clear_mul_param() { + if (mul_param_ != NULL) mul_param_->::GET::MulParam::Clear(); + clear_has_mul_param(); +} +inline const ::GET::MulParam& TaskParam::mul_param() const { + return mul_param_ != NULL ? *mul_param_ : *default_instance_->mul_param_; +} +inline ::GET::MulParam* TaskParam::mutable_mul_param() { + set_has_mul_param(); + if (mul_param_ == NULL) mul_param_ = new ::GET::MulParam; + return mul_param_; +} +inline ::GET::MulParam* TaskParam::release_mul_param() { + clear_has_mul_param(); + ::GET::MulParam* temp = mul_param_; + mul_param_ = NULL; + return temp; +} +inline void TaskParam::set_allocated_mul_param(::GET::MulParam* mul_param) { + delete mul_param_; + mul_param_ = mul_param; + if (mul_param) { + set_has_mul_param(); + } else { + clear_has_mul_param(); + } +} + +// optional .GET.ConvParam conv_param = 11; +inline bool TaskParam::has_conv_param() const { + return (_has_bits_[0] & 0x00000400u) != 0; +} +inline void TaskParam::set_has_conv_param() { + _has_bits_[0] |= 0x00000400u; +} +inline void TaskParam::clear_has_conv_param() { + _has_bits_[0] &= ~0x00000400u; +} +inline void TaskParam::clear_conv_param() { + if (conv_param_ != NULL) conv_param_->::GET::ConvParam::Clear(); + clear_has_conv_param(); +} +inline const ::GET::ConvParam& TaskParam::conv_param() const { + return conv_param_ != NULL ? *conv_param_ : *default_instance_->conv_param_; +} +inline ::GET::ConvParam* TaskParam::mutable_conv_param() { + set_has_conv_param(); + if (conv_param_ == NULL) conv_param_ = new ::GET::ConvParam; + return conv_param_; +} +inline ::GET::ConvParam* TaskParam::release_conv_param() { + clear_has_conv_param(); + ::GET::ConvParam* temp = conv_param_; + conv_param_ = NULL; + return temp; +} +inline void TaskParam::set_allocated_conv_param(::GET::ConvParam* conv_param) { + delete conv_param_; + conv_param_ = conv_param; + if (conv_param) { + set_has_conv_param(); + } else { + clear_has_conv_param(); + } +} + +// optional .GET.PoolParam pool_param = 12; +inline bool TaskParam::has_pool_param() const { + return (_has_bits_[0] & 0x00000800u) != 0; +} +inline void TaskParam::set_has_pool_param() { + _has_bits_[0] |= 0x00000800u; +} +inline void TaskParam::clear_has_pool_param() { + _has_bits_[0] &= ~0x00000800u; +} +inline void TaskParam::clear_pool_param() { + if (pool_param_ != NULL) pool_param_->::GET::PoolParam::Clear(); + clear_has_pool_param(); +} +inline const ::GET::PoolParam& TaskParam::pool_param() const { + return pool_param_ != NULL ? *pool_param_ : *default_instance_->pool_param_; +} +inline ::GET::PoolParam* TaskParam::mutable_pool_param() { + set_has_pool_param(); + if (pool_param_ == NULL) pool_param_ = new ::GET::PoolParam; + return pool_param_; +} +inline ::GET::PoolParam* TaskParam::release_pool_param() { + clear_has_pool_param(); + ::GET::PoolParam* temp = pool_param_; + pool_param_ = NULL; + return temp; +} +inline void TaskParam::set_allocated_pool_param(::GET::PoolParam* pool_param) { + delete pool_param_; + pool_param_ = pool_param; + if (pool_param) { + set_has_pool_param(); + } else { + clear_has_pool_param(); + } +} + +// optional .GET.LRNParam lrn_param = 13; +inline bool TaskParam::has_lrn_param() const { + return (_has_bits_[0] & 0x00001000u) != 0; +} +inline void TaskParam::set_has_lrn_param() { + _has_bits_[0] |= 0x00001000u; +} +inline void TaskParam::clear_has_lrn_param() { + _has_bits_[0] &= ~0x00001000u; +} +inline void TaskParam::clear_lrn_param() { + if (lrn_param_ != NULL) lrn_param_->::GET::LRNParam::Clear(); + clear_has_lrn_param(); +} +inline const ::GET::LRNParam& TaskParam::lrn_param() const { + return lrn_param_ != NULL ? *lrn_param_ : *default_instance_->lrn_param_; +} +inline ::GET::LRNParam* TaskParam::mutable_lrn_param() { + set_has_lrn_param(); + if (lrn_param_ == NULL) lrn_param_ = new ::GET::LRNParam; + return lrn_param_; +} +inline ::GET::LRNParam* TaskParam::release_lrn_param() { + clear_has_lrn_param(); + ::GET::LRNParam* temp = lrn_param_; + lrn_param_ = NULL; + return temp; +} +inline void TaskParam::set_allocated_lrn_param(::GET::LRNParam* lrn_param) { + delete lrn_param_; + lrn_param_ = lrn_param; + if (lrn_param) { + set_has_lrn_param(); + } else { + clear_has_lrn_param(); + } +} + +// optional .GET.ReLUParam relu_param = 14; +inline bool TaskParam::has_relu_param() const { + return (_has_bits_[0] & 0x00002000u) != 0; +} +inline void TaskParam::set_has_relu_param() { + _has_bits_[0] |= 0x00002000u; +} +inline void TaskParam::clear_has_relu_param() { + _has_bits_[0] &= ~0x00002000u; +} +inline void TaskParam::clear_relu_param() { + if (relu_param_ != NULL) relu_param_->::GET::ReLUParam::Clear(); + clear_has_relu_param(); +} +inline const ::GET::ReLUParam& TaskParam::relu_param() const { + return relu_param_ != NULL ? *relu_param_ : *default_instance_->relu_param_; +} +inline ::GET::ReLUParam* TaskParam::mutable_relu_param() { + set_has_relu_param(); + if (relu_param_ == NULL) relu_param_ = new ::GET::ReLUParam; + return relu_param_; +} +inline ::GET::ReLUParam* TaskParam::release_relu_param() { + clear_has_relu_param(); + ::GET::ReLUParam* temp = relu_param_; + relu_param_ = NULL; + return temp; +} +inline void TaskParam::set_allocated_relu_param(::GET::ReLUParam* relu_param) { + delete relu_param_; + relu_param_ = relu_param; + if (relu_param) { + set_has_relu_param(); + } else { + clear_has_relu_param(); + } +} + +// optional .GET.SigmoidParam sigmoid_param = 15; +inline bool TaskParam::has_sigmoid_param() const { + return (_has_bits_[0] & 0x00004000u) != 0; +} +inline void TaskParam::set_has_sigmoid_param() { + _has_bits_[0] |= 0x00004000u; +} +inline void TaskParam::clear_has_sigmoid_param() { + _has_bits_[0] &= ~0x00004000u; +} +inline void TaskParam::clear_sigmoid_param() { + if (sigmoid_param_ != NULL) sigmoid_param_->::GET::SigmoidParam::Clear(); + clear_has_sigmoid_param(); +} +inline const ::GET::SigmoidParam& TaskParam::sigmoid_param() const { + return sigmoid_param_ != NULL ? *sigmoid_param_ : *default_instance_->sigmoid_param_; +} +inline ::GET::SigmoidParam* TaskParam::mutable_sigmoid_param() { + set_has_sigmoid_param(); + if (sigmoid_param_ == NULL) sigmoid_param_ = new ::GET::SigmoidParam; + return sigmoid_param_; +} +inline ::GET::SigmoidParam* TaskParam::release_sigmoid_param() { + clear_has_sigmoid_param(); + ::GET::SigmoidParam* temp = sigmoid_param_; + sigmoid_param_ = NULL; + return temp; +} +inline void TaskParam::set_allocated_sigmoid_param(::GET::SigmoidParam* sigmoid_param) { + delete sigmoid_param_; + sigmoid_param_ = sigmoid_param; + if (sigmoid_param) { + set_has_sigmoid_param(); + } else { + clear_has_sigmoid_param(); + } +} + +// ------------------------------------------------------------------- + +// AddParam + +// optional int32 height = 1; +inline bool AddParam::has_height() const { + return (_has_bits_[0] & 0x00000001u) != 0; +} +inline void AddParam::set_has_height() { + _has_bits_[0] |= 0x00000001u; +} +inline void AddParam::clear_has_height() { + _has_bits_[0] &= ~0x00000001u; +} +inline void AddParam::clear_height() { + height_ = 0; + clear_has_height(); +} +inline ::google::protobuf::int32 AddParam::height() const { + return height_; +} +inline void AddParam::set_height(::google::protobuf::int32 value) { + set_has_height(); + height_ = value; +} + +// optional int32 width = 2; +inline bool AddParam::has_width() const { + return (_has_bits_[0] & 0x00000002u) != 0; +} +inline void AddParam::set_has_width() { + _has_bits_[0] |= 0x00000002u; +} +inline void AddParam::clear_has_width() { + _has_bits_[0] &= ~0x00000002u; +} +inline void AddParam::clear_width() { + width_ = 0; + clear_has_width(); +} +inline ::google::protobuf::int32 AddParam::width() const { + return width_; +} +inline void AddParam::set_width(::google::protobuf::int32 value) { + set_has_width(); + width_ = value; +} + +// optional int32 channels = 3 [default = 1]; +inline bool AddParam::has_channels() const { + return (_has_bits_[0] & 0x00000004u) != 0; +} +inline void AddParam::set_has_channels() { + _has_bits_[0] |= 0x00000004u; +} +inline void AddParam::clear_has_channels() { + _has_bits_[0] &= ~0x00000004u; +} +inline void AddParam::clear_channels() { + channels_ = 1; + clear_has_channels(); +} +inline ::google::protobuf::int32 AddParam::channels() const { + return channels_; +} +inline void AddParam::set_channels(::google::protobuf::int32 value) { + set_has_channels(); + channels_ = value; +} + +// ------------------------------------------------------------------- + +// SubParam + +// optional int32 height = 1; +inline bool SubParam::has_height() const { + return (_has_bits_[0] & 0x00000001u) != 0; +} +inline void SubParam::set_has_height() { + _has_bits_[0] |= 0x00000001u; +} +inline void SubParam::clear_has_height() { + _has_bits_[0] &= ~0x00000001u; +} +inline void SubParam::clear_height() { + height_ = 0; + clear_has_height(); +} +inline ::google::protobuf::int32 SubParam::height() const { + return height_; +} +inline void SubParam::set_height(::google::protobuf::int32 value) { + set_has_height(); + height_ = value; +} + +// optional int32 width = 2; +inline bool SubParam::has_width() const { + return (_has_bits_[0] & 0x00000002u) != 0; +} +inline void SubParam::set_has_width() { + _has_bits_[0] |= 0x00000002u; +} +inline void SubParam::clear_has_width() { + _has_bits_[0] &= ~0x00000002u; +} +inline void SubParam::clear_width() { + width_ = 0; + clear_has_width(); +} +inline ::google::protobuf::int32 SubParam::width() const { + return width_; +} +inline void SubParam::set_width(::google::protobuf::int32 value) { + set_has_width(); + width_ = value; +} + +// optional int32 channels = 3 [default = 1]; +inline bool SubParam::has_channels() const { + return (_has_bits_[0] & 0x00000004u) != 0; +} +inline void SubParam::set_has_channels() { + _has_bits_[0] |= 0x00000004u; +} +inline void SubParam::clear_has_channels() { + _has_bits_[0] &= ~0x00000004u; +} +inline void SubParam::clear_channels() { + channels_ = 1; + clear_has_channels(); +} +inline ::google::protobuf::int32 SubParam::channels() const { + return channels_; +} +inline void SubParam::set_channels(::google::protobuf::int32 value) { + set_has_channels(); + channels_ = value; +} + +// ------------------------------------------------------------------- + +// MulParam + +// optional int32 M = 1; +inline bool MulParam::has_m() const { + return (_has_bits_[0] & 0x00000001u) != 0; +} +inline void MulParam::set_has_m() { + _has_bits_[0] |= 0x00000001u; +} +inline void MulParam::clear_has_m() { + _has_bits_[0] &= ~0x00000001u; +} +inline void MulParam::clear_m() { + m_ = 0; + clear_has_m(); +} +inline ::google::protobuf::int32 MulParam::m() const { + return m_; +} +inline void MulParam::set_m(::google::protobuf::int32 value) { + set_has_m(); + m_ = value; +} + +// optional int32 K = 2; +inline bool MulParam::has_k() const { + return (_has_bits_[0] & 0x00000002u) != 0; +} +inline void MulParam::set_has_k() { + _has_bits_[0] |= 0x00000002u; +} +inline void MulParam::clear_has_k() { + _has_bits_[0] &= ~0x00000002u; +} +inline void MulParam::clear_k() { + k_ = 0; + clear_has_k(); +} +inline ::google::protobuf::int32 MulParam::k() const { + return k_; +} +inline void MulParam::set_k(::google::protobuf::int32 value) { + set_has_k(); + k_ = value; +} + +// optional int32 N = 3; +inline bool MulParam::has_n() const { + return (_has_bits_[0] & 0x00000004u) != 0; +} +inline void MulParam::set_has_n() { + _has_bits_[0] |= 0x00000004u; +} +inline void MulParam::clear_has_n() { + _has_bits_[0] &= ~0x00000004u; +} +inline void MulParam::clear_n() { + n_ = 0; + clear_has_n(); +} +inline ::google::protobuf::int32 MulParam::n() const { + return n_; +} +inline void MulParam::set_n(::google::protobuf::int32 value) { + set_has_n(); + n_ = value; +} + +// optional int32 channels = 4 [default = 1]; +inline bool MulParam::has_channels() const { + return (_has_bits_[0] & 0x00000008u) != 0; +} +inline void MulParam::set_has_channels() { + _has_bits_[0] |= 0x00000008u; +} +inline void MulParam::clear_has_channels() { + _has_bits_[0] &= ~0x00000008u; +} +inline void MulParam::clear_channels() { + channels_ = 1; + clear_has_channels(); +} +inline ::google::protobuf::int32 MulParam::channels() const { + return channels_; +} +inline void MulParam::set_channels(::google::protobuf::int32 value) { + set_has_channels(); + channels_ = value; +} + +// ------------------------------------------------------------------- + +// ConvParam + +// optional int32 data_h = 1; +inline bool ConvParam::has_data_h() const { + return (_has_bits_[0] & 0x00000001u) != 0; +} +inline void ConvParam::set_has_data_h() { + _has_bits_[0] |= 0x00000001u; +} +inline void ConvParam::clear_has_data_h() { + _has_bits_[0] &= ~0x00000001u; +} +inline void ConvParam::clear_data_h() { + data_h_ = 0; + clear_has_data_h(); +} +inline ::google::protobuf::int32 ConvParam::data_h() const { + return data_h_; +} +inline void ConvParam::set_data_h(::google::protobuf::int32 value) { + set_has_data_h(); + data_h_ = value; +} + +// optional int32 data_w = 2; +inline bool ConvParam::has_data_w() const { + return (_has_bits_[0] & 0x00000002u) != 0; +} +inline void ConvParam::set_has_data_w() { + _has_bits_[0] |= 0x00000002u; +} +inline void ConvParam::clear_has_data_w() { + _has_bits_[0] &= ~0x00000002u; +} +inline void ConvParam::clear_data_w() { + data_w_ = 0; + clear_has_data_w(); +} +inline ::google::protobuf::int32 ConvParam::data_w() const { + return data_w_; +} +inline void ConvParam::set_data_w(::google::protobuf::int32 value) { + set_has_data_w(); + data_w_ = value; +} + +// optional int32 filter_h = 3; +inline bool ConvParam::has_filter_h() const { + return (_has_bits_[0] & 0x00000004u) != 0; +} +inline void ConvParam::set_has_filter_h() { + _has_bits_[0] |= 0x00000004u; +} +inline void ConvParam::clear_has_filter_h() { + _has_bits_[0] &= ~0x00000004u; +} +inline void ConvParam::clear_filter_h() { + filter_h_ = 0; + clear_has_filter_h(); +} +inline ::google::protobuf::int32 ConvParam::filter_h() const { + return filter_h_; +} +inline void ConvParam::set_filter_h(::google::protobuf::int32 value) { + set_has_filter_h(); + filter_h_ = value; +} + +// optional int32 filter_w = 4; +inline bool ConvParam::has_filter_w() const { + return (_has_bits_[0] & 0x00000008u) != 0; +} +inline void ConvParam::set_has_filter_w() { + _has_bits_[0] |= 0x00000008u; +} +inline void ConvParam::clear_has_filter_w() { + _has_bits_[0] &= ~0x00000008u; +} +inline void ConvParam::clear_filter_w() { + filter_w_ = 0; + clear_has_filter_w(); +} +inline ::google::protobuf::int32 ConvParam::filter_w() const { + return filter_w_; +} +inline void ConvParam::set_filter_w(::google::protobuf::int32 value) { + set_has_filter_w(); + filter_w_ = value; +} + +// optional int32 stride_h = 5 [default = 1]; +inline bool ConvParam::has_stride_h() const { + return (_has_bits_[0] & 0x00000010u) != 0; +} +inline void ConvParam::set_has_stride_h() { + _has_bits_[0] |= 0x00000010u; +} +inline void ConvParam::clear_has_stride_h() { + _has_bits_[0] &= ~0x00000010u; +} +inline void ConvParam::clear_stride_h() { + stride_h_ = 1; + clear_has_stride_h(); +} +inline ::google::protobuf::int32 ConvParam::stride_h() const { + return stride_h_; +} +inline void ConvParam::set_stride_h(::google::protobuf::int32 value) { + set_has_stride_h(); + stride_h_ = value; } -// optional .GET.SubParam sub_param = 9; -inline bool TaskParam::has_sub_param() const { - return (_has_bits_[0] & 0x00000100u) != 0; +// optional int32 stride_w = 6 [default = 1]; +inline bool ConvParam::has_stride_w() const { + return (_has_bits_[0] & 0x00000020u) != 0; } -inline void TaskParam::set_has_sub_param() { - _has_bits_[0] |= 0x00000100u; +inline void ConvParam::set_has_stride_w() { + _has_bits_[0] |= 0x00000020u; } -inline void TaskParam::clear_has_sub_param() { - _has_bits_[0] &= ~0x00000100u; +inline void ConvParam::clear_has_stride_w() { + _has_bits_[0] &= ~0x00000020u; } -inline void TaskParam::clear_sub_param() { - if (sub_param_ != NULL) sub_param_->::GET::SubParam::Clear(); - clear_has_sub_param(); +inline void ConvParam::clear_stride_w() { + stride_w_ = 1; + clear_has_stride_w(); } -inline const ::GET::SubParam& TaskParam::sub_param() const { - return sub_param_ != NULL ? *sub_param_ : *default_instance_->sub_param_; +inline ::google::protobuf::int32 ConvParam::stride_w() const { + return stride_w_; } -inline ::GET::SubParam* TaskParam::mutable_sub_param() { - set_has_sub_param(); - if (sub_param_ == NULL) sub_param_ = new ::GET::SubParam; - return sub_param_; +inline void ConvParam::set_stride_w(::google::protobuf::int32 value) { + set_has_stride_w(); + stride_w_ = value; } -inline ::GET::SubParam* TaskParam::release_sub_param() { - clear_has_sub_param(); - ::GET::SubParam* temp = sub_param_; - sub_param_ = NULL; - return temp; + +// optional int32 pad_h = 7 [default = 0]; +inline bool ConvParam::has_pad_h() const { + return (_has_bits_[0] & 0x00000040u) != 0; } -inline void TaskParam::set_allocated_sub_param(::GET::SubParam* sub_param) { - delete sub_param_; - sub_param_ = sub_param; - if (sub_param) { - set_has_sub_param(); - } else { - clear_has_sub_param(); - } +inline void ConvParam::set_has_pad_h() { + _has_bits_[0] |= 0x00000040u; +} +inline void ConvParam::clear_has_pad_h() { + _has_bits_[0] &= ~0x00000040u; +} +inline void ConvParam::clear_pad_h() { + pad_h_ = 0; + clear_has_pad_h(); +} +inline ::google::protobuf::int32 ConvParam::pad_h() const { + return pad_h_; +} +inline void ConvParam::set_pad_h(::google::protobuf::int32 value) { + set_has_pad_h(); + pad_h_ = value; } -// optional .GET.MulParam mul_param = 10; -inline bool TaskParam::has_mul_param() const { - return (_has_bits_[0] & 0x00000200u) != 0; +// optional int32 pad_w = 8 [default = 0]; +inline bool ConvParam::has_pad_w() const { + return (_has_bits_[0] & 0x00000080u) != 0; } -inline void TaskParam::set_has_mul_param() { - _has_bits_[0] |= 0x00000200u; +inline void ConvParam::set_has_pad_w() { + _has_bits_[0] |= 0x00000080u; } -inline void TaskParam::clear_has_mul_param() { - _has_bits_[0] &= ~0x00000200u; +inline void ConvParam::clear_has_pad_w() { + _has_bits_[0] &= ~0x00000080u; } -inline void TaskParam::clear_mul_param() { - if (mul_param_ != NULL) mul_param_->::GET::MulParam::Clear(); - clear_has_mul_param(); +inline void ConvParam::clear_pad_w() { + pad_w_ = 0; + clear_has_pad_w(); } -inline const ::GET::MulParam& TaskParam::mul_param() const { - return mul_param_ != NULL ? *mul_param_ : *default_instance_->mul_param_; +inline ::google::protobuf::int32 ConvParam::pad_w() const { + return pad_w_; } -inline ::GET::MulParam* TaskParam::mutable_mul_param() { - set_has_mul_param(); - if (mul_param_ == NULL) mul_param_ = new ::GET::MulParam; - return mul_param_; +inline void ConvParam::set_pad_w(::google::protobuf::int32 value) { + set_has_pad_w(); + pad_w_ = value; } -inline ::GET::MulParam* TaskParam::release_mul_param() { - clear_has_mul_param(); - ::GET::MulParam* temp = mul_param_; - mul_param_ = NULL; - return temp; + +// optional int32 channels = 9 [default = 1]; +inline bool ConvParam::has_channels() const { + return (_has_bits_[0] & 0x00000100u) != 0; } -inline void TaskParam::set_allocated_mul_param(::GET::MulParam* mul_param) { - delete mul_param_; - mul_param_ = mul_param; - if (mul_param) { - set_has_mul_param(); - } else { - clear_has_mul_param(); - } +inline void ConvParam::set_has_channels() { + _has_bits_[0] |= 0x00000100u; +} +inline void ConvParam::clear_has_channels() { + _has_bits_[0] &= ~0x00000100u; +} +inline void ConvParam::clear_channels() { + channels_ = 1; + clear_has_channels(); +} +inline ::google::protobuf::int32 ConvParam::channels() const { + return channels_; +} +inline void ConvParam::set_channels(::google::protobuf::int32 value) { + set_has_channels(); + channels_ = value; } // ------------------------------------------------------------------- -// AddParam +// PoolParam -// optional int32 height = 1; -inline bool AddParam::has_height() const { +// optional int32 data_h = 1; +inline bool PoolParam::has_data_h() const { return (_has_bits_[0] & 0x00000001u) != 0; } -inline void AddParam::set_has_height() { +inline void PoolParam::set_has_data_h() { _has_bits_[0] |= 0x00000001u; } -inline void AddParam::clear_has_height() { +inline void PoolParam::clear_has_data_h() { _has_bits_[0] &= ~0x00000001u; } -inline void AddParam::clear_height() { - height_ = 0; - clear_has_height(); +inline void PoolParam::clear_data_h() { + data_h_ = 0; + clear_has_data_h(); } -inline ::google::protobuf::int32 AddParam::height() const { - return height_; +inline ::google::protobuf::int32 PoolParam::data_h() const { + return data_h_; } -inline void AddParam::set_height(::google::protobuf::int32 value) { - set_has_height(); - height_ = value; +inline void PoolParam::set_data_h(::google::protobuf::int32 value) { + set_has_data_h(); + data_h_ = value; } -// optional int32 width = 2; -inline bool AddParam::has_width() const { +// optional int32 data_w = 2; +inline bool PoolParam::has_data_w() const { return (_has_bits_[0] & 0x00000002u) != 0; } -inline void AddParam::set_has_width() { +inline void PoolParam::set_has_data_w() { _has_bits_[0] |= 0x00000002u; } -inline void AddParam::clear_has_width() { +inline void PoolParam::clear_has_data_w() { _has_bits_[0] &= ~0x00000002u; } -inline void AddParam::clear_width() { - width_ = 0; - clear_has_width(); +inline void PoolParam::clear_data_w() { + data_w_ = 0; + clear_has_data_w(); } -inline ::google::protobuf::int32 AddParam::width() const { - return width_; +inline ::google::protobuf::int32 PoolParam::data_w() const { + return data_w_; } -inline void AddParam::set_width(::google::protobuf::int32 value) { - set_has_width(); - width_ = value; +inline void PoolParam::set_data_w(::google::protobuf::int32 value) { + set_has_data_w(); + data_w_ = value; } -// optional int32 channels = 3 [default = 1]; -inline bool AddParam::has_channels() const { +// optional int32 kernel_h = 3; +inline bool PoolParam::has_kernel_h() const { return (_has_bits_[0] & 0x00000004u) != 0; } -inline void AddParam::set_has_channels() { +inline void PoolParam::set_has_kernel_h() { _has_bits_[0] |= 0x00000004u; } -inline void AddParam::clear_has_channels() { +inline void PoolParam::clear_has_kernel_h() { _has_bits_[0] &= ~0x00000004u; } -inline void AddParam::clear_channels() { +inline void PoolParam::clear_kernel_h() { + kernel_h_ = 0; + clear_has_kernel_h(); +} +inline ::google::protobuf::int32 PoolParam::kernel_h() const { + return kernel_h_; +} +inline void PoolParam::set_kernel_h(::google::protobuf::int32 value) { + set_has_kernel_h(); + kernel_h_ = value; +} + +// optional int32 kernel_w = 4; +inline bool PoolParam::has_kernel_w() const { + return (_has_bits_[0] & 0x00000008u) != 0; +} +inline void PoolParam::set_has_kernel_w() { + _has_bits_[0] |= 0x00000008u; +} +inline void PoolParam::clear_has_kernel_w() { + _has_bits_[0] &= ~0x00000008u; +} +inline void PoolParam::clear_kernel_w() { + kernel_w_ = 0; + clear_has_kernel_w(); +} +inline ::google::protobuf::int32 PoolParam::kernel_w() const { + return kernel_w_; +} +inline void PoolParam::set_kernel_w(::google::protobuf::int32 value) { + set_has_kernel_w(); + kernel_w_ = value; +} + +// optional int32 stride_h = 5 [default = 1]; +inline bool PoolParam::has_stride_h() const { + return (_has_bits_[0] & 0x00000010u) != 0; +} +inline void PoolParam::set_has_stride_h() { + _has_bits_[0] |= 0x00000010u; +} +inline void PoolParam::clear_has_stride_h() { + _has_bits_[0] &= ~0x00000010u; +} +inline void PoolParam::clear_stride_h() { + stride_h_ = 1; + clear_has_stride_h(); +} +inline ::google::protobuf::int32 PoolParam::stride_h() const { + return stride_h_; +} +inline void PoolParam::set_stride_h(::google::protobuf::int32 value) { + set_has_stride_h(); + stride_h_ = value; +} + +// optional int32 stride_w = 6 [default = 1]; +inline bool PoolParam::has_stride_w() const { + return (_has_bits_[0] & 0x00000020u) != 0; +} +inline void PoolParam::set_has_stride_w() { + _has_bits_[0] |= 0x00000020u; +} +inline void PoolParam::clear_has_stride_w() { + _has_bits_[0] &= ~0x00000020u; +} +inline void PoolParam::clear_stride_w() { + stride_w_ = 1; + clear_has_stride_w(); +} +inline ::google::protobuf::int32 PoolParam::stride_w() const { + return stride_w_; +} +inline void PoolParam::set_stride_w(::google::protobuf::int32 value) { + set_has_stride_w(); + stride_w_ = value; +} + +// optional int32 pad_h = 7 [default = 0]; +inline bool PoolParam::has_pad_h() const { + return (_has_bits_[0] & 0x00000040u) != 0; +} +inline void PoolParam::set_has_pad_h() { + _has_bits_[0] |= 0x00000040u; +} +inline void PoolParam::clear_has_pad_h() { + _has_bits_[0] &= ~0x00000040u; +} +inline void PoolParam::clear_pad_h() { + pad_h_ = 0; + clear_has_pad_h(); +} +inline ::google::protobuf::int32 PoolParam::pad_h() const { + return pad_h_; +} +inline void PoolParam::set_pad_h(::google::protobuf::int32 value) { + set_has_pad_h(); + pad_h_ = value; +} + +// optional int32 pad_w = 8 [default = 0]; +inline bool PoolParam::has_pad_w() const { + return (_has_bits_[0] & 0x00000080u) != 0; +} +inline void PoolParam::set_has_pad_w() { + _has_bits_[0] |= 0x00000080u; +} +inline void PoolParam::clear_has_pad_w() { + _has_bits_[0] &= ~0x00000080u; +} +inline void PoolParam::clear_pad_w() { + pad_w_ = 0; + clear_has_pad_w(); +} +inline ::google::protobuf::int32 PoolParam::pad_w() const { + return pad_w_; +} +inline void PoolParam::set_pad_w(::google::protobuf::int32 value) { + set_has_pad_w(); + pad_w_ = value; +} + +// optional int32 channels = 9 [default = 1]; +inline bool PoolParam::has_channels() const { + return (_has_bits_[0] & 0x00000100u) != 0; +} +inline void PoolParam::set_has_channels() { + _has_bits_[0] |= 0x00000100u; +} +inline void PoolParam::clear_has_channels() { + _has_bits_[0] &= ~0x00000100u; +} +inline void PoolParam::clear_channels() { channels_ = 1; clear_has_channels(); } -inline ::google::protobuf::int32 AddParam::channels() const { +inline ::google::protobuf::int32 PoolParam::channels() const { return channels_; } -inline void AddParam::set_channels(::google::protobuf::int32 value) { +inline void PoolParam::set_channels(::google::protobuf::int32 value) { set_has_channels(); channels_ = value; } // ------------------------------------------------------------------- -// SubParam +// LRNParam -// optional int32 height = 1; -inline bool SubParam::has_height() const { +// optional int32 channels = 1 [default = 0]; +inline bool LRNParam::has_channels() const { return (_has_bits_[0] & 0x00000001u) != 0; } -inline void SubParam::set_has_height() { +inline void LRNParam::set_has_channels() { _has_bits_[0] |= 0x00000001u; } -inline void SubParam::clear_has_height() { +inline void LRNParam::clear_has_channels() { _has_bits_[0] &= ~0x00000001u; } -inline void SubParam::clear_height() { - height_ = 0; +inline void LRNParam::clear_channels() { + channels_ = 0; + clear_has_channels(); +} +inline ::google::protobuf::int32 LRNParam::channels() const { + return channels_; +} +inline void LRNParam::set_channels(::google::protobuf::int32 value) { + set_has_channels(); + channels_ = value; +} + +// ------------------------------------------------------------------- + +// ReLUParam + +// optional int32 height = 1 [default = 1]; +inline bool ReLUParam::has_height() const { + return (_has_bits_[0] & 0x00000001u) != 0; +} +inline void ReLUParam::set_has_height() { + _has_bits_[0] |= 0x00000001u; +} +inline void ReLUParam::clear_has_height() { + _has_bits_[0] &= ~0x00000001u; +} +inline void ReLUParam::clear_height() { + height_ = 1; clear_has_height(); } -inline ::google::protobuf::int32 SubParam::height() const { +inline ::google::protobuf::int32 ReLUParam::height() const { return height_; } -inline void SubParam::set_height(::google::protobuf::int32 value) { +inline void ReLUParam::set_height(::google::protobuf::int32 value) { set_has_height(); height_ = value; } // optional int32 width = 2; -inline bool SubParam::has_width() const { +inline bool ReLUParam::has_width() const { return (_has_bits_[0] & 0x00000002u) != 0; } -inline void SubParam::set_has_width() { +inline void ReLUParam::set_has_width() { _has_bits_[0] |= 0x00000002u; } -inline void SubParam::clear_has_width() { +inline void ReLUParam::clear_has_width() { _has_bits_[0] &= ~0x00000002u; } -inline void SubParam::clear_width() { +inline void ReLUParam::clear_width() { width_ = 0; clear_has_width(); } -inline ::google::protobuf::int32 SubParam::width() const { +inline ::google::protobuf::int32 ReLUParam::width() const { return width_; } -inline void SubParam::set_width(::google::protobuf::int32 value) { +inline void ReLUParam::set_width(::google::protobuf::int32 value) { set_has_width(); width_ = value; } // optional int32 channels = 3 [default = 1]; -inline bool SubParam::has_channels() const { +inline bool ReLUParam::has_channels() const { return (_has_bits_[0] & 0x00000004u) != 0; } -inline void SubParam::set_has_channels() { +inline void ReLUParam::set_has_channels() { _has_bits_[0] |= 0x00000004u; } -inline void SubParam::clear_has_channels() { +inline void ReLUParam::clear_has_channels() { _has_bits_[0] &= ~0x00000004u; } -inline void SubParam::clear_channels() { +inline void ReLUParam::clear_channels() { channels_ = 1; clear_has_channels(); } -inline ::google::protobuf::int32 SubParam::channels() const { +inline ::google::protobuf::int32 ReLUParam::channels() const { return channels_; } -inline void SubParam::set_channels(::google::protobuf::int32 value) { +inline void ReLUParam::set_channels(::google::protobuf::int32 value) { set_has_channels(); channels_ = value; } // ------------------------------------------------------------------- -// MulParam +// SigmoidParam -// optional int32 height_A = 1; -inline bool MulParam::has_height_a() const { +// optional int32 height = 1 [default = 1]; +inline bool SigmoidParam::has_height() const { return (_has_bits_[0] & 0x00000001u) != 0; } -inline void MulParam::set_has_height_a() { +inline void SigmoidParam::set_has_height() { _has_bits_[0] |= 0x00000001u; } -inline void MulParam::clear_has_height_a() { +inline void SigmoidParam::clear_has_height() { _has_bits_[0] &= ~0x00000001u; } -inline void MulParam::clear_height_a() { - height_a_ = 0; - clear_has_height_a(); +inline void SigmoidParam::clear_height() { + height_ = 1; + clear_has_height(); } -inline ::google::protobuf::int32 MulParam::height_a() const { - return height_a_; +inline ::google::protobuf::int32 SigmoidParam::height() const { + return height_; } -inline void MulParam::set_height_a(::google::protobuf::int32 value) { - set_has_height_a(); - height_a_ = value; +inline void SigmoidParam::set_height(::google::protobuf::int32 value) { + set_has_height(); + height_ = value; } -// optional int32 width_A = 2; -inline bool MulParam::has_width_a() const { +// optional int32 width = 2; +inline bool SigmoidParam::has_width() const { return (_has_bits_[0] & 0x00000002u) != 0; } -inline void MulParam::set_has_width_a() { +inline void SigmoidParam::set_has_width() { _has_bits_[0] |= 0x00000002u; } -inline void MulParam::clear_has_width_a() { +inline void SigmoidParam::clear_has_width() { _has_bits_[0] &= ~0x00000002u; } -inline void MulParam::clear_width_a() { - width_a_ = 0; - clear_has_width_a(); +inline void SigmoidParam::clear_width() { + width_ = 0; + clear_has_width(); } -inline ::google::protobuf::int32 MulParam::width_a() const { - return width_a_; +inline ::google::protobuf::int32 SigmoidParam::width() const { + return width_; } -inline void MulParam::set_width_a(::google::protobuf::int32 value) { - set_has_width_a(); - width_a_ = value; +inline void SigmoidParam::set_width(::google::protobuf::int32 value) { + set_has_width(); + width_ = value; } -// optional int32 height_B = 3; -inline bool MulParam::has_height_b() const { +// optional int32 channels = 3 [default = 1]; +inline bool SigmoidParam::has_channels() const { return (_has_bits_[0] & 0x00000004u) != 0; } -inline void MulParam::set_has_height_b() { +inline void SigmoidParam::set_has_channels() { _has_bits_[0] |= 0x00000004u; } -inline void MulParam::clear_has_height_b() { +inline void SigmoidParam::clear_has_channels() { _has_bits_[0] &= ~0x00000004u; } -inline void MulParam::clear_height_b() { - height_b_ = 0; - clear_has_height_b(); -} -inline ::google::protobuf::int32 MulParam::height_b() const { - return height_b_; -} -inline void MulParam::set_height_b(::google::protobuf::int32 value) { - set_has_height_b(); - height_b_ = value; -} - -// optional int32 width_B = 4; -inline bool MulParam::has_width_b() const { - return (_has_bits_[0] & 0x00000008u) != 0; -} -inline void MulParam::set_has_width_b() { - _has_bits_[0] |= 0x00000008u; -} -inline void MulParam::clear_has_width_b() { - _has_bits_[0] &= ~0x00000008u; -} -inline void MulParam::clear_width_b() { - width_b_ = 0; - clear_has_width_b(); -} -inline ::google::protobuf::int32 MulParam::width_b() const { - return width_b_; -} -inline void MulParam::set_width_b(::google::protobuf::int32 value) { - set_has_width_b(); - width_b_ = value; -} - -// optional int32 channels = 5 [default = 1]; -inline bool MulParam::has_channels() const { - return (_has_bits_[0] & 0x00000010u) != 0; -} -inline void MulParam::set_has_channels() { - _has_bits_[0] |= 0x00000010u; -} -inline void MulParam::clear_has_channels() { - _has_bits_[0] &= ~0x00000010u; -} -inline void MulParam::clear_channels() { +inline void SigmoidParam::clear_channels() { channels_ = 1; clear_has_channels(); } -inline ::google::protobuf::int32 MulParam::channels() const { +inline ::google::protobuf::int32 SigmoidParam::channels() const { return channels_; } -inline void MulParam::set_channels(::google::protobuf::int32 value) { +inline void SigmoidParam::set_channels(::google::protobuf::int32 value) { set_has_channels(); channels_ = value; } diff --git a/include/proto/GET.proto b/include/proto/GET.proto index 94d94fc..837093c 100644 --- a/include/proto/GET.proto +++ b/include/proto/GET.proto @@ -22,6 +22,9 @@ message TaskParam MULTI = 2; CONVOLUTION = 3; POOLING = 4; + LRN = 5; + RELU = 6; + Sigmoid = 7; } optional TaskType type = 7; @@ -29,8 +32,11 @@ message TaskParam optional AddParam add_param = 8; optional SubParam sub_param = 9; optional MulParam mul_param = 10; - // optional ConvParam conv_param_ = 11; - // optional PoolParam pool_param_ = 12; + optional ConvParam conv_param = 11; + optional PoolParam pool_param = 12; + optional LRNParam lrn_param = 13; + optional ReLUParam relu_param =14; + optional SigmoidParam sigmoid_param = 15; } message AddParam @@ -49,9 +55,54 @@ message SubParam message MulParam { - optional int32 height_A = 1; - optional int32 width_A = 2; - optional int32 height_B = 3; - optional int32 width_B = 4; - optional int32 channels = 5 [default = 1]; + optional int32 M = 1; + optional int32 K = 2; + optional int32 N = 3; + optional int32 channels = 4 [default = 1]; +} + +message ConvParam +{ + optional int32 data_h = 1; + optional int32 data_w = 2; + optional int32 filter_h = 3; + optional int32 filter_w = 4; + optional int32 stride_h = 5 [default = 1]; + optional int32 stride_w = 6 [default = 1]; + optional int32 pad_h = 7 [default = 0]; + optional int32 pad_w = 8 [default = 0]; + optional int32 channels = 9 [default = 1]; +} + +message PoolParam +{ + optional int32 data_h = 1; + optional int32 data_w = 2; + optional int32 kernel_h = 3; + optional int32 kernel_w = 4; + optional int32 stride_h = 5 [default = 1]; + optional int32 stride_w = 6 [default = 1]; + optional int32 pad_h = 7 [default = 0]; + optional int32 pad_w = 8 [default = 0]; + optional int32 channels = 9 [default = 1]; +} + +message LRNParam +{ + optional int32 channels = 1 [default = 0]; +} + + +message ReLUParam +{ + optional int32 height = 1 [default = 1]; + optional int32 width = 2; + optional int32 channels = 3 [default = 1]; +} + +message SigmoidParam +{ + optional int32 height = 1 [default = 1]; + optional int32 width = 2; + optional int32 channels = 3 [default = 1]; } diff --git a/src/get/AddStreamTask.cpp b/src/get/AddStreamTask.cpp index c1264bf..9e52804 100644 --- a/src/get/AddStreamTask.cpp +++ b/src/get/AddStreamTask.cpp @@ -7,8 +7,9 @@ #include using namespace std; -#include"AddStreamTask.h" +#include"StreamTask.h" #include"opencl_z.h" +#include #define GLOBAL_WORKGROUP_HEIGHT 4096 #define GLOBAL_WORKGROUP_WIDTH 4096 @@ -104,7 +105,17 @@ AddStreamTask::Compute() { cl_int status; - char* program_source = cl_readSource("/home/zy/GET/src/get/add_task.cl"); + char* program_source =NULL; + + if (typeid(Dtype) == typeid(int)) + program_source = cl_readSource("/home/zy/GET/src/get/add_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + program_source = cl_readSource("/home/zy/GET/src/get/add_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + program_source = cl_readSource("/home/zy/GET/src/get/add_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + program_source = cl_readSource("/home/zy/GET/src/get/add_task_double.cl"); + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); cl_pStatus(status, "clCreateProgramWithSource"); diff --git a/src/get/AddTask.cpp b/src/get/AddTask.cpp index a0c0759..d6fe862 100644 --- a/src/get/AddTask.cpp +++ b/src/get/AddTask.cpp @@ -7,9 +7,10 @@ #include using namespace std; -#include"AddTask.h" +#include"BaseTask.h" #include #include"opencl_z.h" +#include #define GLOBAL_WORKGROUP_HEIGHT 4096 #define GLOBAL_WORKGROUP_WIDTH 4096 @@ -55,7 +56,18 @@ AddTask::Compute() { //read program source and create kernel cl_int status; - char* program_source = cl_readSource("/home/zy/GET/src/get/add_task.cl"); + + char* program_source =NULL; + + if (typeid(Dtype) == typeid(int)) + program_source = cl_readSource("/home/zy/GET/src/get/add_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + program_source = cl_readSource("/home/zy/GET/src/get/add_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + program_source = cl_readSource("/home/zy/GET/src/get/add_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + program_source = cl_readSource("/home/zy/GET/src/get/add_task_double.cl"); + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); cl_pStatus(status, "clCreateProgramWithSource"); diff --git a/src/get/AddTaskDispatcher.cpp b/src/get/AddTaskDispatcher.cpp index 1303768..860d791 100644 --- a/src/get/AddTaskDispatcher.cpp +++ b/src/get/AddTaskDispatcher.cpp @@ -7,7 +7,7 @@ #include using namespace std; -#include"AddTaskDispatcher.h" +#include"TaskDispatcher.h" template void @@ -127,7 +127,7 @@ AddTaskDispatcher::TaskDispatch() this->ordinary_tasks_.push_back(temp_task); cout<<"create " << i << " addtask" < File Name: ConvStreamTask.cpp + > Author: + > Mail: + > Created Time: Mon 20 Apr 2015 05:14:02 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"StreamTask.h" +#include"opencl_z.h" +#include + +#define GLOBAL_WORKGROUP_HEIGHT 4096 +#define GLOBAL_WORKGROUP_WIDTH 4096 +#define LOCAL_WORKGROUP_HEIGHT 16 +#define LOCAL_WORKGROUP_WIDTH 16 + + + +template +ConvStreamTask::~ConvStreamTask() +{ + +} + +template +void +ConvStreamTask :: PreCompute() +{ + //reshape the block for addstream + + int stride = stride_h_; + int width_filter = filter_w_; + int input_height = data_h_; + int input_width = data_w_; + int output_height = output_h_; + int output_width = output_w_; + long total_device_size = this->device_->GlobalMemory(); + DataBlob *input_matrix,*filter,*output_matrix; + Dtype *host_data_inputmatrix,*host_data_outputmatrix; + host_data_inputmatrix = (Dtype *)this->datas_[0].host_data(); + host_data_outputmatrix = (Dtype *)this->results_[0].host_data(); + int height_per_inputmatrix = (total_device_size/sizeof(Dtype) - width_filter * width_filter)/(input_width + output_width); + int height_per_outputmatrix; + int heightgoback = height_per_inputmatrix - (height_per_inputmatrix - width_filter + stride)/stride * stride; + int blocks_num = 0; + while(input_height > width_filter) + { + //filter = data_input_[1]; + DataBlob filter((void *)this->datas_[1].host_data(),1,1,this->datas_[1].height(),this->datas_[1].width()); + height_per_inputmatrix = height_per_inputmatrix > input_height?input_height:height_per_inputmatrix; + height_per_outputmatrix = (height_per_inputmatrix - width_filter + stride)/stride; + DataBlob input_matrix((void *)host_data_inputmatrix,1,1,height_per_inputmatrix,input_width); + DataBlob output_matrix((void *)host_data_outputmatrix,1,1,height_per_outputmatrix,output_width); + host_data_inputmatrix += (height_per_inputmatrix - heightgoback)*input_width; + host_data_outputmatrix += height_per_outputmatrix* output_width; + input_height = input_height - height_per_inputmatrix + heightgoback; + this->inner_datas_.push_back(input_matrix); + this->inner_datas_.push_back(filter); + this->inner_results_.push_back(output_matrix); + blocks_num++; + } + this->blocks_num_ = blocks_num; + cl_int datasize_i = this->inner_datas_[0].height()*input_width*sizeof(Dtype); + cl_int datasize_f = filter_h_*filter_w_*sizeof(Dtype); + cl_int datasize_o = this->inner_results_[0].height()*this->inner_results_[0].width()*sizeof(Dtype); + + + //datasize = this->device_->GlobalMemory() / 3; + + cl_mem temp_buffer; + cl_int status; + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, datasize_i , NULL, &status); + cl_pStatus(status, "clCreateBuffer 1"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, datasize_f, NULL, &status); + cl_pStatus(status, "clCreateBuffer 2"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, datasize_o, NULL, &status); + cl_pStatus(status, "clCreateBuffer 3"); + this->device_buffers_.push_back(temp_buffer); + + /* + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_each *width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 4"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_each *width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 5"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, height_each*width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 6"); + this->device_buffers_.push_back(temp_buffer); +*/ + +} +template +void +ConvStreamTask::Compute() +{ + cl_int status; + + char* program_source = NULL; + + if (typeid(Dtype) == typeid(int)) + program_source = cl_readSource("/home/zy/GET/src/get/conv_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + program_source = cl_readSource("/home/zy/GET/src/get/conv_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + program_source = cl_readSource("/home/zy/GET/src/get/conv_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + program_source = cl_readSource("/home/zy/GET/src/get/conv_task_double.cl"); + + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); + cl_pStatus(status, "clCreateProgramWithSource"); + + cl_device_id temp_device = this->device_->DeviceID(); + status = clBuildProgram(program, 1, &temp_device, NULL, NULL, NULL); + cl_pStatus(status, "clBuildProgram"); + + + int padding_pixels = (int)(filter_w_ / 2) * 2; + + unsigned int global_itemh = RoundUp(this->inner_datas_[0].height() - padding_pixels , LOCAL_WORKGROUP_HEIGHT); + unsigned int global_itemw = RoundUp(this->inner_datas_[0].width() - padding_pixels, LOCAL_WORKGROUP_WIDTH); + size_t global_worksize[2] = {global_itemh , global_itemw}; + size_t local_worksize[2] = {LOCAL_WORKGROUP_HEIGHT, LOCAL_WORKGROUP_WIDTH}; + + cl_uint local_height = LOCAL_WORKGROUP_HEIGHT + padding_pixels; + cl_uint local_width = LOCAL_WORKGROUP_WIDTH + padding_pixels; + + size_t local_mem = local_height * local_width * sizeof(Dtype); + + int blocks_num = this->blocks_num_; + cl_kernel kernel[blocks_num]; + + int i; + cout <<"blocks_num : "<inner_datas_[0].height(); + int in_width_each = this->inner_datas_[0].width(); + int out_height_each = this->inner_results_[0].height(); + int out_width_each = this->inner_results_[0].width(); + for(i = 0; i < blocks_num-1 ; i++) + { + kernel[i] = clCreateKernel(program, "convMatrix", &status); + cl_pStatus(status, "clCreateKernel"); + status = clSetKernelArg(kernel[i], 0, sizeof(cl_mem), &(this->device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel[i], 1, sizeof(cl_mem), &(this->device_buffers_[2])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel[i], 2, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel[i], 3, sizeof(cl_int), &out_height_each); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel[i], 4, sizeof(cl_int), &out_width_each); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel[i], 5, sizeof(cl_int), &in_height_each); + cl_pStatus(status, "clSetKernelArg_5"); + status = clSetKernelArg(kernel[i], 6, sizeof(cl_int), &in_width_each); + cl_pStatus(status, "clSetKernelArg_6"); + status = clSetKernelArg(kernel[i], 7, sizeof(cl_uint), &filter_w_); + cl_pStatus(status, "clSetKernelArg_7"); + status = clSetKernelArg(kernel[i], 8, local_mem, NULL); + cl_pStatus(status, "clSetKernelArg_8"); + status = clSetKernelArg(kernel[i], 9, sizeof(cl_uint), &local_height); + cl_pStatus(status, "clSetKernelArg_9"); + status = clSetKernelArg(kernel[i], 10, sizeof(cl_uint), &local_width); + cl_pStatus(status, "clSetKernelArg_10"); + + } + + int in_height_left = this->inner_datas_[2*i].height(); + int in_width = this->inner_datas_[2*i].width(); + int out_height = this->inner_results_[i].height(); + int out_width = this->inner_results_[i].width(); + + kernel[i] = clCreateKernel(program, "convMatrix", &status); + cl_pStatus(status, "clCreateKernel"); + status = clSetKernelArg(kernel[i], 0, sizeof(cl_mem), &(this->device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel[i], 1, sizeof(cl_mem), &(this->device_buffers_[2])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel[i], 2, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel[i], 3, sizeof(cl_int), &out_height); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel[i], 4, sizeof(cl_int), &out_width); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel[i], 5, sizeof(cl_int), &in_height_left); + cl_pStatus(status, "clSetKernelArg_5"); + status = clSetKernelArg(kernel[i], 6, sizeof(cl_int), &in_width); + cl_pStatus(status, "clSetKernelArg_6"); + status = clSetKernelArg(kernel[i], 7, sizeof(cl_uint), &filter_w_); + cl_pStatus(status, "clSetKernelArg_7"); + status = clSetKernelArg(kernel[i], 8, local_mem, NULL); + cl_pStatus(status, "clSetKernelArg_8"); + status = clSetKernelArg(kernel[i], 9, sizeof(cl_uint), &local_height); + cl_pStatus(status, "clSetKernelArg_9"); + status = clSetKernelArg(kernel[i], 10, sizeof(cl_uint), &local_width); + cl_pStatus(status, "clSetKernelArg_10"); + + + for(int i = 0; i < blocks_num ; i+=2) + { + //cout<<"address : "<inner_datas_[2*i].host_data()<device_->DeviceCmdQueA(),this->device_buffers_[0], CL_FALSE, 0, this->inner_datas_[2*i].height()*data_w_*sizeof(Dtype), (Dtype*)(this->inner_datas_[2*i].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[1], CL_FALSE, 0, this->inner_datas_[2*i+1].height()*filter_w_*sizeof(Dtype), (Dtype*)(this->inner_datas_[2*i+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueA(), kernel[i] , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + + if (i+1 == blocks_num) + { + + } + else + { + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[0], CL_FALSE, 0, this->inner_datas_[2*(i+1)].height()*data_w_*sizeof(Dtype), (Dtype*)(this->inner_datas_[2*(i+1)].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[1], CL_FALSE, 0, this->inner_datas_[2*(i+1)+1].height()*filter_w_*sizeof(Dtype), (Dtype*)(this->inner_datas_[2*(i+1)+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueB(), kernel[i+1] , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + } + + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[2], CL_FALSE, 0, this->inner_results_[i].height()*output_w_*sizeof(Dtype), (Dtype*)(this->inner_results_[i].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + if (i+1 == blocks_num) + { + + } + else + { + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[2], CL_FALSE, 0, this->inner_results_[i+1].height()*output_w_*sizeof(Dtype), (Dtype*)(this->inner_results_[i+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + } + } + + + clFinish(this->device_->DeviceCmdQueA()); + clFinish(this->device_->DeviceCmdQueB()); + + //clean + for(int i = 0; i < blocks_num; i++) + { + clReleaseKernel(kernel[i]); + } + clReleaseProgram(program); +} +template +void +ConvStreamTask::PostCompute() +{ +} + + +template class ConvStreamTask; +template class ConvStreamTask; +template class ConvStreamTask; +template class ConvStreamTask; diff --git a/src/get/ConvTask.cpp b/src/get/ConvTask.cpp index e8b9387..61201f6 100644 --- a/src/get/ConvTask.cpp +++ b/src/get/ConvTask.cpp @@ -1,10 +1,151 @@ /************************************************************************* - > File Name: ConvTask.cpp - > Author: - > Mail: - > Created Time: Thu 23 Apr 2015 09:18:35 AM EDT + > File Name: ConvTask.cpp + > Author: crows + > Mail: 136211494@qq.com + > Created Time: Tue 24 Mar 2015 02:55:48 AM EDT ************************************************************************/ #include using namespace std; +#include"BaseTask.h" +#include +#include"opencl_z.h" +#include + +#define GLOBAL_WORKGROUP_HEIGHT 4096 +#define GLOBAL_WORKGROUP_WIDTH 4096 +#define LOCAL_WORKGROUP_HEIGHT 4 +#define LOCAL_WORKGROUP_WIDTH 4 + + +template +void +ConvTask::PreCompute() +{ + //set device buffer + cl_mem temp_buffer; + cl_int status; + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, data_h_*data_w_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 1"); + + this->device_buffers_.push_back(temp_buffer); + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, filter_h_*filter_w_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 2"); + + this->device_buffers_.push_back(temp_buffer); + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, output_h_*output_w_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 3"); + + this->device_buffers_.push_back(temp_buffer); + + //copy data from host to device + cout<<"A address : " <datas_[0].host_data() <datas_[1].host_data() <device_->DeviceCmdQueA(),this->device_buffers_[0], CL_TRUE, 0, data_h_*data_w_*sizeof(Dtype), (Dtype*)(this->datas_[0].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[1], CL_TRUE, 0, filter_h_*filter_w_*sizeof(Dtype), (Dtype*)(this->datas_[1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); +} + +template +void +ConvTask::Compute() +{ + //read program source and create kernel + cl_int status; + + char * program_source = NULL; + + if (typeid(Dtype) == typeid(int)) + program_source = cl_readSource("/home/zy/GET/src/get/conv_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + program_source = cl_readSource("/home/zy/GET/src/get/conv_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + program_source = cl_readSource("/home/zy/GET/src/get/conv_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + program_source = cl_readSource("/home/zy/GET/src/get/conv_task_double.cl"); + + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); + cl_pStatus(status, "clCreateProgramWithSource"); + + cl_device_id temp_device = this->device_->DeviceID(); + status = clBuildProgram(program, 1, &temp_device, NULL, NULL, NULL); + if (status != CL_SUCCESS) + { + char buildLog[16384]; + clGetProgramBuildInfo(program, temp_device, CL_PROGRAM_BUILD_LOG, sizeof(buildLog), buildLog, NULL); + cout << "Error in Kernel : "<device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &(this->device_buffers_[2])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel, 2, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel, 3, sizeof(cl_int), &output_h_); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel, 4, sizeof(cl_int), &output_w_); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel, 5, sizeof(cl_int), &data_h_); + cl_pStatus(status, "clSetKernelArg_5"); + status = clSetKernelArg(kernel, 6, sizeof(cl_int), &data_w_); + cl_pStatus(status, "clSetKernelArg_6"); + status = clSetKernelArg(kernel, 7, sizeof(cl_uint), &filter_w_); + cl_pStatus(status, "clSetKernelArg_7"); + status = clSetKernelArg(kernel, 8, local_mem, NULL); + cl_pStatus(status, "clSetKernelArg_8"); + status = clSetKernelArg(kernel, 9, sizeof(cl_uint), &local_height); + cl_pStatus(status, "clSetKernelArg_9"); + status = clSetKernelArg(kernel, 10, sizeof(cl_uint), &local_width); + cl_pStatus(status, "clSetKernelArg_10"); + + //start computing + + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueA(), kernel , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + cout<<"task start on GPU"<device_->DeviceCmdQueA()); +} + +template +void +ConvTask::PostCompute() +{ + //copy memory from device to host + cl_int status; + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[2], CL_TRUE, 0, output_h_*output_w_*sizeof(Dtype), (Dtype*)(this->results_[0].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + clFinish(this->device_->DeviceCmdQueA()); + // clReleaseMemObject(this->device_buffers_[0]); + // clReleaseMemObject(this->device_buffers_[1]); + // clReleaseMemObject(this->device_buffers_[2]); +} + +template class ConvTask; +template class ConvTask; +template class ConvTask; +template class ConvTask; diff --git a/src/get/ConvTaskDispatcher.cpp b/src/get/ConvTaskDispatcher.cpp new file mode 100644 index 0000000..cae1e0b --- /dev/null +++ b/src/get/ConvTaskDispatcher.cpp @@ -0,0 +1,227 @@ +/************************************************************************* + > File Name: ConvTaskDispatcher.cpp + > Author: + > Mail: + > Created Time: Mon 13 Apr 2015 03:43:08 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"TaskDispatcher.h" + +template +void +ConvTaskDispatcher::PreTaskDispatch() +{ + //create input blobs and output blobs + GET::TaskParam_DataPosition source_pos = this->task_param_.source_pos(); + int input_blob_num = 0; + if (source_pos == 1) + input_blob_num = this->task_param_.sourcem_size(); + else if (source_pos == 0) + input_blob_num = this->task_param_.sourcef_size(); + + cout <<"input_blob_num = "<task_param_.result_pos(); + int output_blob_num = 0; + if (result_pos == 1) + output_blob_num = this->task_param_.resultm_size(); + else if (result_pos == 0) + output_blob_num = this->task_param_.resultf_size(); + + DataBlob blob; + //prepare A + blob.ReshapeLike(1, 1, conv_param_.data_h(), conv_param_.data_w()); + this->datablobs_input_.push_back(blob); + //prepare B + blob.ReshapeLike(1, 1, conv_param_.filter_h(), conv_param_.filter_w()); + this->datablobs_input_.push_back(blob); + //prepare input + this->PrepareBlob(input_blob_num, source_pos,(this->datablobs_input_), SOURCE); + //prepare output + unsigned int output_height = (conv_param_.data_h() - conv_param_.filter_h()) / conv_param_.stride_h() + 1; + unsigned int output_width = (conv_param_.data_w() - conv_param_.filter_w()) / conv_param_.stride_w() + 1; + blob.ReshapeLike(1, 1, output_height, output_width); + this->datablobs_output_.push_back(blob); + this->PrepareBlob(output_blob_num, result_pos,(this->datablobs_output_), RESULT); + +} + +template +void +ConvTaskDispatcher::TaskDispatch() +{ + //get device : mem_used + long mem_used = 0; + for(int i = 0; i< this->datablobs_input_.size(); i++) + mem_used += this->datablobs_input_[i].count() * sizeof(Dtype); + for(int i = 0; idatablobs_output_.size(); i++) + mem_used += this->datablobs_output_[i].count() * sizeof(Dtype); + this->mem_used_ =mem_used; + + bool device_enough = false; + while (!device_enough) + { + BaseDevice* device; + if ((device = this->device_manager_->GetAvailableDevice()) != NULL) + { + this->devices_needed_.push_back(device); + if(mem_used < (device->GlobalMemory())) + { + device_enough = true; + } + else + { + mem_used -= device->GlobalMemory(); + } + } + else + { + break; + } + } + + cout << "Need Devices = "<< this->devices_needed_.size() <devices_needed_.size()); + //compute the weight for load banlancing + unsigned int offset = 0; + int denum = this->devices_needed_.size(); + unsigned long sum_gl_mem = 0; + float left_part =0; + float weight[denum]; + + for(int i = 0; i devices_needed_[i]->GlobalMemory(); + } + + for(int i =0; idevices_needed_[i]->GlobalMemory())/((float)sum_gl_mem); + left_part += weight[i]; + } + weight[denum-1] = 1 - left_part; + + unsigned int height_current = 0; + int heightgoback = 0; + int width_filter = conv_param_.filter_w(); + int stride = conv_param_.stride_h(); + int i= 0; + int height_each = conv_param_.data_h() * weight[0]; + int height_out_each = 0; + int input_width = conv_param_.data_w(); + int output_width = (input_width - width_filter)/stride + 1; + Dtype* input_data = (Dtype*)this->datablobs_input_[0].host_data(); + Dtype* output_data = (Dtype*)this->datablobs_output_[0].host_data(); + for( i = 0; i < this->devices_needed_.size()-1;i++) + { + + heightgoback = height_each - (height_each - width_filter + stride)/stride * stride; + height_current += (height_each - heightgoback); + ConvTask* temp_task = new ConvTask(this->devices_needed_[i]); + temp_task->SetParams(1, height_each, conv_param_.data_w(), conv_param_.filter_h(), conv_param_.filter_w(), conv_param_.stride_h(), conv_param_.stride_w(), conv_param_.pad_h(), conv_param_.pad_w()); + //create inner + //prepare in + DataBlob inner_i((void*)input_data, 1, 1 , height_each, input_width); + temp_task->AddtoDatas(inner_i); + //prepare filter + DataBlob filter((void *)this->datablobs_input_[1].host_data(),1,1,conv_param_.filter_h(), conv_param_.filter_w()); + temp_task->AddtoDatas(filter); + //create inner_results + //prepare out + height_out_each = (height_each - width_filter + stride)/stride; + DataBlob inner_o((void*)output_data,1, 1, height_out_each, output_width); + temp_task->AddtoResults(inner_o); + + this->ordinary_tasks_.push_back(temp_task); + cout<<"create " << i << " convtask" <* temp_task = new ConvTask(this->devices_needed_[i]); + temp_task->SetParams(1, height_each, conv_param_.data_w(), conv_param_.filter_h(), conv_param_.filter_w(), conv_param_.stride_h(), conv_param_.stride_w(), conv_param_.pad_h(), conv_param_.pad_w()); + //create inner + //prepare in + DataBlob inner_i((void*)input_data, 1, 1 , height_each, input_width); + temp_task->AddtoDatas(inner_i); + //prepare filter + DataBlob filter((void *)this->datablobs_input_[1].host_data(),1,1,conv_param_.filter_h(), conv_param_.filter_w()); + temp_task->AddtoDatas(filter); + //create inner_results + //prepare out + height_out_each = (height_each - width_filter + stride)/stride; + DataBlob inner_o((void*)output_data,1, 1, height_out_each, output_width); + temp_task->AddtoResults(inner_o); + + //attach the blob to the task + + this->ordinary_tasks_.push_back(temp_task); + + this->process_type_ = ORDINARY; + } + else + { + // stream task + //decide to compute on one or more devices + + //compute on the device with the biggest global memory + unsigned long max_glmem = 0; + BaseDevice* stream_device = NULL ; + for(int i = 0; i < this->devices_needed_.size(); i++) + { + if(max_glmem < this->devices_needed_[i]->GlobalMemory()) + { + stream_device = this->devices_needed_[i]; + max_glmem = this->devices_needed_[i]->GlobalMemory(); + } + } + cout<< "stream device determined"<::iterator l_it = this->devices_needed_.begin(); + + while(l_it != this->devices_needed_.end()) + { + if((*l_it) != stream_device) + { + this->device_manager_->FreeDevice((*l_it)->DeviceIndex()); + l_it = this->devices_needed_.erase(l_it); + } + else + { + l_it++; + } + } + + ConvStreamTask* temp_adst = new ConvStreamTask(stream_device); + temp_adst->SetParams(1, conv_param_.data_h(), conv_param_.data_w(), conv_param_.filter_h(), conv_param_.filter_w(), conv_param_.stride_h(), conv_param_.stride_w(), conv_param_.pad_h(), conv_param_.pad_w()); + temp_adst->AddtoDatas(this->datablobs_input_[0]); + temp_adst->AddtoDatas(this->datablobs_input_[1]); + temp_adst->AddtoResults(this->datablobs_output_[0]); + this->stream_tasks_.push_back(temp_adst); + + cout<< "create stream tasks" <process_type_ =STREAM; + + + } + +} + +template class ConvTaskDispatcher; +template class ConvTaskDispatcher; +template class ConvTaskDispatcher; +template class ConvTaskDispatcher; + diff --git a/src/get/DataBlob.cpp b/src/get/DataBlob.cpp index 2bc4ca0..2cc9579 100644 --- a/src/get/DataBlob.cpp +++ b/src/get/DataBlob.cpp @@ -19,11 +19,12 @@ DataBlob::Reshape_HOST(const int n, const int c, const int h, const int w height_ = h; width_ = w; count_ = num_ * channels_ * height_ * width_; - if (count_ > capacity_ ) +/* if (count_ > capacity_ ) { capacity_ = count_; host_data_ptr_ =(Dtype *) malloc(capacity_ * sizeof(Dtype)); } +*/ } template @@ -63,6 +64,16 @@ DataBlob::Reshape_DEVICE(const int n, const int c, const int h, const int } */ +template +DataBlob::DataBlob(void* ptr, const int n, const int c , const int h, const int w) +{ + num_ = n; + channels_ = c; + height_ = h; + width_ = w; + host_data_ptr_ = ptr; +} + template DataBlob::DataBlob(const int n, const int c, const int h, const int w) : capacity_(0) diff --git a/src/get/MulStreamTask.cpp b/src/get/MulStreamTask.cpp new file mode 100644 index 0000000..1cb10d2 --- /dev/null +++ b/src/get/MulStreamTask.cpp @@ -0,0 +1,250 @@ +/************************************************************************* + > File Name: MulStreamTask.cpp + > Author: + > Mail: + > Created Time: Mon 20 Apr 2015 05:14:02 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"StreamTask.h" +#include"opencl_z.h" +#include + +#define GLOBAL_WORKGROUP_HEIGHT 4096 +#define GLOBAL_WORKGROUP_WIDTH 4096 +#define LOCAL_WORKGROUP_HEIGHT 16 +#define LOCAL_WORKGROUP_WIDTH 16 + + +template +MulStreamTask::~MulStreamTask() +{ + +} + +template +void +MulStreamTask :: PreCompute() +{ + //reshape the block for addstream + long mem_used = 0; + mem_used += this->datas_[0].count() * sizeof(Dtype); + + int temp_num = 0; + + temp_num = (mem_used + this->data_per_block_ - 1) / this->data_per_block_; + + //create inner blocks + unsigned int height_each = M_ / temp_num; + unsigned int height_left = 0; + unsigned int offset = 0; + for(int i = 0; i < temp_num - 1 ; i++) + { + DataBlob inner; + + inner.ReshapeLike(1, 1, height_each, K_); + inner.CopyFromMemory(this->datas_[0].host_data()+offset*K_); + this->inner_datas_.push_back(inner); + inner.ReshapeLike(1, 1, K_, N_); + inner.CopyFromMemory(this->datas_[1].host_data()); + this->inner_datas_.push_back(inner); + + inner.ReshapeLike(1, 1, height_each, N_); + inner.CopyFromMemory(this->results_[0].host_data()+offset*N_); + this->inner_results_.push_back(inner); + + offset += height_each*sizeof(Dtype); + height_left += height_each; + } + + height_left = M_ - height_left; + DataBlob inner; + inner.ReshapeLike(1, 1, height_left, K_); + inner.CopyFromMemory(this->datas_[0].host_data()+offset*K_); + this->inner_datas_.push_back(inner); + inner.ReshapeLike(1, 1, K_, N_); + inner.CopyFromMemory(this->datas_[1].host_data()); + this->inner_datas_.push_back(inner); + + inner.ReshapeLike(1, 1, height_left, N_); + inner.CopyFromMemory(this->results_[0].host_data()+offset*N_); + this->inner_results_.push_back(inner); + + this->blocks_num_ = temp_num; + //create the buffer + cl_int status; + cl_mem temp_buffer; + + cl_int datasize; + + //datasize = this->device_->GlobalMemory() / 3; + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_each*K_*sizeof(Dtype) , NULL, &status); + cl_pStatus(status, "clCreateBuffer 1"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, K_*N_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 2"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, height_each*N_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 3"); + this->device_buffers_.push_back(temp_buffer); + + /* + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_each *width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 4"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_each *width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 5"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, height_each*width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 6"); + this->device_buffers_.push_back(temp_buffer); +*/ + +} +template +void +MulStreamTask::Compute() +{ + cl_int status; + + char* program_source = NULL; + + if (typeid(Dtype) == typeid(int)) + program_source = cl_readSource("/home/zy/GET/src/get/mul_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + program_source = cl_readSource("/home/zy/GET/src/get/mul_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + program_source = cl_readSource("/home/zy/GET/src/get/mul_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + program_source = cl_readSource("/home/zy/GET/src/get/mul_task_double.cl"); + + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); + cl_pStatus(status, "clCreateProgramWithSource"); + + cl_device_id temp_device = this->device_->DeviceID(); + status = clBuildProgram(program, 1, &temp_device, NULL, NULL, NULL); + cl_pStatus(status, "clBuildProgram"); + + cl_uint thread_division_row = 1; + cl_uint thread_division_col = 1; + int height_each = this->inner_datas_[0].height(); + + if (height_each > GLOBAL_WORKGROUP_HEIGHT) + thread_division_row = (height_each + GLOBAL_WORKGROUP_HEIGHT- 1 ) / GLOBAL_WORKGROUP_HEIGHT; + + if (N_ > GLOBAL_WORKGROUP_WIDTH) + thread_division_col = (N_ + GLOBAL_WORKGROUP_WIDTH -1 ) / GLOBAL_WORKGROUP_WIDTH; + + size_t global_worksize[2] = {GLOBAL_WORKGROUP_HEIGHT, GLOBAL_WORKGROUP_WIDTH}; + size_t local_worksize[2] = {LOCAL_WORKGROUP_HEIGHT, LOCAL_WORKGROUP_WIDTH}; + + int blocks_num = this->blocks_num_; + cl_kernel kernel[blocks_num]; + + int i; + cout <<"blocks_num : "<device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel[i], 1, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel[i], 2, sizeof(cl_mem), &(this->device_buffers_[2])); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel[i], 3, sizeof(cl_uint), &height_each); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel[i], 4, sizeof(cl_uint), &K_); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel[i], 5, sizeof(cl_uint), &N_); + cl_pStatus(status, "clSetKernelArg_5"); + status = clSetKernelArg(kernel[i], 6, sizeof(cl_uint), &thread_division_row); + cl_pStatus(status, "clSetKernelArg_6"); + status = clSetKernelArg(kernel[i], 7, sizeof(cl_uint), &thread_division_col); + cl_pStatus(status, "clSetKernelArg_7"); + + } + + int height_left = this->inner_datas_[blocks_num - 1].height(); + kernel[i] = clCreateKernel(program, "mulMatrix", &status); + cl_pStatus(status, "clCreateKernel"); + status = clSetKernelArg(kernel[blocks_num-1], 0, sizeof(cl_mem), &(this->device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel[blocks_num-1], 1, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel[blocks_num-1], 2, sizeof(cl_mem), &(this->device_buffers_[2])); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel[blocks_num-1], 3, sizeof(cl_uint), &height_left); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel[blocks_num-1], 4, sizeof(cl_uint), &K_); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel[blocks_num-1], 5, sizeof(cl_uint), &N_); + cl_pStatus(status, "clSetKernelArg_5"); + status = clSetKernelArg(kernel[blocks_num-1], 6, sizeof(cl_uint), &thread_division_row); + cl_pStatus(status, "clSetKernelArg_6"); + status = clSetKernelArg(kernel[blocks_num-1], 7, sizeof(cl_uint), &thread_division_col); + cl_pStatus(status, "clSetKernelArg_7"); + + + + for(int i = 0; i < blocks_num ; i+=2) + { + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[0], CL_FALSE, 0, this->inner_datas_[2*i].height()*K_*sizeof(Dtype), (Dtype*)(this->inner_datas_[2*i].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[1], CL_FALSE, 0, this->inner_datas_[2*i+1].height()*N_*sizeof(Dtype), (Dtype*)(this->inner_datas_[2*i+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueA(), kernel[i] , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + + if (i+1 == blocks_num) + { + + } + else + { + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[0], CL_FALSE, 0, this->inner_datas_[2*(i+1)].height()*K_*sizeof(Dtype), (Dtype*)(this->inner_datas_[2*(i+1)].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[1], CL_FALSE, 0, this->inner_datas_[2*(i+1)+1].height()*N_*sizeof(Dtype), (Dtype*)(this->inner_datas_[2*(i+1)+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueB(), kernel[i+1] , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + } + + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[2], CL_FALSE, 0, this->inner_results_[i].height()*N_*sizeof(Dtype), (Dtype*)(this->inner_results_[i].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + if (i+1 == blocks_num) + { + + } + else + { + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[2], CL_FALSE, 0, this->inner_results_[i+1].height()*N_*sizeof(Dtype), (Dtype*)(this->inner_results_[i+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + } + } + + + clFinish(this->device_->DeviceCmdQueA()); + clFinish(this->device_->DeviceCmdQueB()); + + //clean + for(int i = 0; i < blocks_num; i++) + { + clReleaseKernel(kernel[i]); + } + clReleaseProgram(program); +} +template +void +MulStreamTask::PostCompute() +{ +} + + +template class MulStreamTask; +template class MulStreamTask; +template class MulStreamTask; +template class MulStreamTask; diff --git a/src/get/MulTask.cpp b/src/get/MulTask.cpp new file mode 100644 index 0000000..8247e2b --- /dev/null +++ b/src/get/MulTask.cpp @@ -0,0 +1,143 @@ +/************************************************************************* + > File Name: CommonTask.cpp + > Author: crows + > Mail: 136211494@qq.com + > Created Time: Tue 24 Mar 2015 02:55:48 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"BaseTask.h" +#include +#include +#include"opencl_z.h" + +#define GLOBAL_WORKGROUP_HEIGHT 4096 +#define GLOBAL_WORKGROUP_WIDTH 4096 +#define LOCAL_WORKGROUP_HEIGHT 16 +#define LOCAL_WORKGROUP_WIDTH 16 + +template +void +MulTask::PreCompute() +{ + //set device buffer + cl_mem temp_buffer; + cl_int status; + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, M_*K_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 1"); + + this->device_buffers_.push_back(temp_buffer); + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, K_*N_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 2"); + + this->device_buffers_.push_back(temp_buffer); + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, M_*N_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 3"); + + this->device_buffers_.push_back(temp_buffer); + + //copy data from host to device + cout<<"A address : " <datas_[0].host_data() <datas_[1].host_data() <device_->DeviceCmdQueA(),this->device_buffers_[0], CL_TRUE, 0, M_*K_*sizeof(Dtype), (Dtype*)(this->datas_[0].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[1], CL_TRUE, 0, K_*N_*sizeof(Dtype), (Dtype*)(this->datas_[1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); +} + +template +void +MulTask::Compute() +{ + //read program source and create kernel + cl_int status; + + char* program_source = NULL; + + if (typeid(Dtype) == typeid(int)) + program_source = cl_readSource("/home/zy/GET/src/get/mul_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + program_source = cl_readSource("/home/zy/GET/src/get/mul_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + program_source = cl_readSource("/home/zy/GET/src/get/mul_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + program_source = cl_readSource("/home/zy/GET/src/get/mul_task_double.cl"); + + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); + cl_pStatus(status, "clCreateProgramWithSource"); + + cl_device_id temp_device = this->device_->DeviceID(); + status = clBuildProgram(program, 1, &temp_device, NULL, NULL, NULL); + if (status != CL_SUCCESS) + { + char buildLog[16384]; + clGetProgramBuildInfo(program, temp_device, CL_PROGRAM_BUILD_LOG, sizeof(buildLog), buildLog, NULL); + cout << "Error in Kernel : "< GLOBAL_WORKGROUP_HEIGHT) + thread_division_row = (M_ + GLOBAL_WORKGROUP_HEIGHT - 1) / GLOBAL_WORKGROUP_HEIGHT; + + if (N_ > GLOBAL_WORKGROUP_WIDTH) + thread_division_col = (N_ + GLOBAL_WORKGROUP_WIDTH - 1) / GLOBAL_WORKGROUP_WIDTH; + + //set kernel arg + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &(this->device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel, 2, sizeof(cl_mem), &(this->device_buffers_[2])); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel, 3, sizeof(cl_uint), &M_); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel, 4, sizeof(cl_uint), &K_); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel, 5, sizeof(cl_uint), &N_); + cl_pStatus(status, "clSetKernelArg_5"); + status = clSetKernelArg(kernel, 6, sizeof(cl_uint), &thread_division_row); + cl_pStatus(status, "clSetKernelArg_6"); + status = clSetKernelArg(kernel, 7, sizeof(cl_uint), &thread_division_col); + cl_pStatus(status, "clSetKernelArg_7"); + + //start computing + size_t global_worksize[2] = {GLOBAL_WORKGROUP_HEIGHT , GLOBAL_WORKGROUP_WIDTH}; + size_t local_worksize[2] = {LOCAL_WORKGROUP_HEIGHT, LOCAL_WORKGROUP_WIDTH}; + + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueA(), kernel , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + cout<<"task start on GPU"<device_->DeviceCmdQueA()); +} + +template +void +MulTask::PostCompute() +{ + //copy memory from device to host + cl_int status; + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[2], CL_TRUE, 0, M_*N_*sizeof(Dtype), (Dtype*)(this->results_[0].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + // clFinish(this->device_->DeviceCmdQueA()); + // clReleaseMemObject(this->device_buffers_[0]); + // clReleaseMemObject(this->device_buffers_[1]); + // clReleaseMemObject(this->device_buffers_[2]); +} + +template class MulTask; +template class MulTask; +template class MulTask; +template class MulTask; + diff --git a/src/get/MulTaskDispatcher.cpp b/src/get/MulTaskDispatcher.cpp new file mode 100644 index 0000000..13f0aa8 --- /dev/null +++ b/src/get/MulTaskDispatcher.cpp @@ -0,0 +1,215 @@ +/************************************************************************* + > File Name: MulTaskDispatcher.cpp + > Author: + > Mail: + > Created Time: Mon 13 Apr 2015 03:43:08 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"TaskDispatcher.h" + +template +void +MulTaskDispatcher::PreTaskDispatch() +{ + //create input blobs and output blobs + GET::TaskParam_DataPosition source_pos = this->task_param_.source_pos(); + int input_blob_num = 0; + if (source_pos == 1) + input_blob_num = this->task_param_.sourcem_size(); + else if (source_pos == 0) + input_blob_num = this->task_param_.sourcef_size(); + + cout <<"input_blob_num = "<task_param_.result_pos(); + int output_blob_num = 0; + if (result_pos == 1) + output_blob_num = this->task_param_.resultm_size(); + else if (result_pos == 0) + output_blob_num = this->task_param_.resultf_size(); + + DataBlob blob; + //prepare A + blob.ReshapeLike(1, 1, mul_param_.m(), mul_param_.k()); + this->datablobs_input_.push_back(blob); + //prepare B + blob.ReshapeLike(1, 1, mul_param_.k(), mul_param_.n()); + this->datablobs_input_.push_back(blob); + //prepare input + this->PrepareBlob(input_blob_num, source_pos,(this->datablobs_input_), SOURCE); + //prepare output + blob.ReshapeLike(1, 1, mul_param_.m(), mul_param_.n()); + this->datablobs_output_.push_back(blob); + this->PrepareBlob(output_blob_num, result_pos,(this->datablobs_output_), RESULT); + +} + +template +void +MulTaskDispatcher::TaskDispatch() +{ + //get device : mem_used + long mem_used = 0; + for(int i = 0; i< this->datablobs_input_.size(); i++) + mem_used += this->datablobs_input_[i].count() * sizeof(Dtype); + for(int i = 0; idatablobs_output_.size(); i++) + mem_used += this->datablobs_output_[i].count() * sizeof(Dtype); + this->mem_used_ =mem_used; + + bool device_enough = false; + while (!device_enough) + { + BaseDevice* device; + if ((device = this->device_manager_->GetAvailableDevice()) != NULL) + { + this->devices_needed_.push_back(device); + if(mem_used < (device->GlobalMemory())) + { + device_enough = true; + } + else + { + mem_used -= device->GlobalMemory(); + } + } + else + { + break; + } + } + + cout << "Need Devices = "<< this->devices_needed_.size() <devices_needed_.size()); + //compute the weight for load banlancing + unsigned int height_each = 0; + unsigned int offset = 0; + int denum = this->devices_needed_.size(); + unsigned long sum_gl_mem = 0; + float left_part =0; + float weight[denum]; + + for(int i = 0; i devices_needed_[i]->GlobalMemory(); + } + + for(int i =0; idevices_needed_[i]->GlobalMemory())/((float)sum_gl_mem); + left_part += weight[i]; + } + weight[denum-1] = 1 - left_part; + + unsigned int height_current = 0; + + int i= 0; + for( i = 0; i < this->devices_needed_.size()-1;i++) + { + + height_each = mul_param_.m() * weight[i]; + height_current += height_each; + MulTask* temp_task = new MulTask(this->devices_needed_[i]); + temp_task->SetParams(1, height_each, mul_param_.k(), mul_param_.n()); + //create inner + DataBlob inner; + //prepare A inner + inner.ReshapeLike(1, 1, height_each, mul_param_.k()); + inner.CopyFromMemory((void*)((this->datablobs_input_[0].host_data())+offset*mul_param_.k())); + temp_task->AddtoDatas(inner); + //prepare B inner + inner.ReshapeLike(1, 1, mul_param_.k(), mul_param_.n()); + inner.CopyFromMemory((void*)((this->datablobs_input_[1].host_data()))); + temp_task->AddtoDatas(inner); + //create inner_results + //prepare C inner + inner.ReshapeLike(1, 1, height_each, mul_param_.k()); + inner.CopyFromMemory((void*)((this->datablobs_output_[0].host_data())+offset*mul_param_.n())); + temp_task->AddtoResults(inner); + + this->ordinary_tasks_.push_back(temp_task); + cout<<"create " << i << " multask" <* temp_task = new MulTask(this->devices_needed_[i]); + temp_task->SetParams(1, height_each, mul_param_.k(), mul_param_.n()); + //create inner + DataBlob inner; + inner.ReshapeLike(1, 1, height_each, mul_param_.k()); + inner.CopyFromMemory((void*)((this->datablobs_input_[0].host_data())+offset*mul_param_.k())); + temp_task->AddtoDatas(inner); + inner.ReshapeLike(1, 1, mul_param_.k(), mul_param_.n()); + inner.CopyFromMemory((void*)((this->datablobs_input_[1].host_data()))); + temp_task->AddtoDatas(inner); + //create inner_results + inner.ReshapeLike(1, 1, height_each, mul_param_.n()); + inner.CopyFromMemory((void*)((this->datablobs_output_[0].host_data())+offset*mul_param_.n())); + temp_task->AddtoResults(inner); + + //attach the blob to the task + + this->ordinary_tasks_.push_back(temp_task); + + this->process_type_ = ORDINARY; + } + else + { + // stream task + //decide to compute on one or more devices + + //compute on the device with the biggest global memory + unsigned long max_glmem = 0; + BaseDevice* stream_device = NULL ; + for(int i = 0; i < this->devices_needed_.size(); i++) + { + if(max_glmem < this->devices_needed_[i]->GlobalMemory()) + { + stream_device = this->devices_needed_[i]; + max_glmem = this->devices_needed_[i]->GlobalMemory(); + } + } + cout<< "stream device determined"<::iterator l_it = this->devices_needed_.begin(); + + while(l_it != this->devices_needed_.end()) + { + if((*l_it) != stream_device) + { + this->device_manager_->FreeDevice((*l_it)->DeviceIndex()); + l_it = this->devices_needed_.erase(l_it); + } + else + { + l_it++; + } + } + + MulStreamTask* temp_adst = new MulStreamTask(stream_device); + temp_adst->SetParams(1, mul_param_.m(), mul_param_.k(), mul_param_.n()); + temp_adst->AddtoDatas(this->datablobs_input_[0]); + temp_adst->AddtoDatas(this->datablobs_input_[1]); + temp_adst->AddtoResults(this->datablobs_output_[0]); + this->stream_tasks_.push_back(temp_adst); + + cout<< "create stream tasks" <process_type_ =STREAM; + + + } + +} + +template class MulTaskDispatcher; +template class MulTaskDispatcher; +template class MulTaskDispatcher; +template class MulTaskDispatcher; + diff --git a/src/get/PoolStreamTask.cpp b/src/get/PoolStreamTask.cpp new file mode 100644 index 0000000..16308dd --- /dev/null +++ b/src/get/PoolStreamTask.cpp @@ -0,0 +1,248 @@ +/************************************************************************* + > File Name: PoolStreamTask.cpp + > Author: + > Mail: + > Created Time: Mon 20 Apr 2015 05:14:02 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"StreamTask.h" +#include"opencl_z.h" +#include + +#define GLOBAL_WORKGROUP_HEIGHT 4096 +#define GLOBAL_WORKGROUP_WIDTH 4096 +#define LOCAL_WORKGROUP_HEIGHT 16 +#define LOCAL_WORKGROUP_WIDTH 16 + + + +template +PoolStreamTask::~PoolStreamTask() +{ + +} + +template +void +PoolStreamTask :: PreCompute() +{ + //reshape the block for addstream + + int stride = stride_h_; + int width_filter = kernel_w_; + int input_height = data_h_; + int input_width = data_w_; + int output_height = output_h_; + int output_width = output_w_; + long total_device_size = this->device_->GlobalMemory() * 0.6; + DataBlob *input_matrix,*filter,*output_matrix; + Dtype *host_data_inputmatrix,*host_data_outputmatrix; + host_data_inputmatrix = (Dtype *)this->datas_[0].host_data(); + host_data_outputmatrix = (Dtype *)this->results_[0].host_data(); + int height_per_inputmatrix = (total_device_size/sizeof(Dtype) - width_filter * width_filter)/(input_width + output_width); + int height_per_outputmatrix; + int heightgoback = height_per_inputmatrix - (height_per_inputmatrix - width_filter + stride)/stride * stride; + int blocks_num = 0; + while(input_height > width_filter) + { + //filter = data_input_[1]; + height_per_inputmatrix = height_per_inputmatrix > input_height?input_height:height_per_inputmatrix; + height_per_outputmatrix = (height_per_inputmatrix - width_filter + stride)/stride; + DataBlob input_matrix((void *)host_data_inputmatrix,1,1,height_per_inputmatrix,input_width); + DataBlob output_matrix((void *)host_data_outputmatrix,1,1,height_per_outputmatrix,output_width); + host_data_inputmatrix += (height_per_inputmatrix - heightgoback)*input_width; + host_data_outputmatrix += height_per_outputmatrix* output_width; + input_height = input_height - height_per_inputmatrix + heightgoback; + this->inner_datas_.push_back(input_matrix); + this->inner_results_.push_back(output_matrix); + blocks_num++; + } + this->blocks_num_ = blocks_num; + cl_int datasize_i = this->inner_datas_[0].height()*input_width*sizeof(Dtype); + cl_int datasize_o = this->inner_results_[0].height()*this->inner_results_[0].width()*sizeof(Dtype); + + + //datasize = this->device_->GlobalMemory() / 3; + + cl_mem temp_buffer; + cl_int status; + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, datasize_i , NULL, &status); + cl_pStatus(status, "clCreateBuffer 1"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, datasize_o, NULL, &status); + cl_pStatus(status, "clCreateBuffer 3"); + this->device_buffers_.push_back(temp_buffer); + + /* + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_each *width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 4"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_each *width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 5"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, height_each*width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 6"); + this->device_buffers_.push_back(temp_buffer); +*/ + +} +template +void +PoolStreamTask::Compute() +{ + cl_int status; + + char* program_source = NULL; + + if (typeid(Dtype) == typeid(int)) + program_source = cl_readSource("/home/zy/GET/src/get/pool_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + program_source = cl_readSource("/home/zy/GET/src/get/pool_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + program_source = cl_readSource("/home/zy/GET/src/get/pool_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + program_source = cl_readSource("/home/zy/GET/src/get/pool_task_double.cl"); + + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); + cl_pStatus(status, "clCreateProgramWithSource"); + + cl_device_id temp_device = this->device_->DeviceID(); + status = clBuildProgram(program, 1, &temp_device, NULL, NULL, NULL); + cl_pStatus(status, "clBuildProgram"); + + + int padding_pixels = (int)(kernel_w_ / 2) * 2; + + unsigned int global_itemh = RoundUp(this->inner_datas_[0].height() - padding_pixels , LOCAL_WORKGROUP_HEIGHT); + unsigned int global_itemw = RoundUp(this->inner_datas_[0].width() - padding_pixels, LOCAL_WORKGROUP_WIDTH); + size_t global_worksize[2] = {global_itemh , global_itemw}; + size_t local_worksize[2] = {LOCAL_WORKGROUP_HEIGHT, LOCAL_WORKGROUP_WIDTH}; + + cl_uint local_height = LOCAL_WORKGROUP_HEIGHT + padding_pixels; + cl_uint local_width = LOCAL_WORKGROUP_WIDTH + padding_pixels; + + size_t local_mem = local_height * local_width * sizeof(Dtype); + + int blocks_num = this->blocks_num_; + cl_kernel kernel[blocks_num]; + + int i; + cout <<"blocks_num : "<inner_datas_[0].height(); + int in_width_each = this->inner_datas_[0].width(); + int out_height_each = this->inner_results_[0].height(); + int out_width_each = this->inner_results_[0].width(); + for(i = 0; i < blocks_num-1 ; i++) + { + kernel[i] = clCreateKernel(program, "poolMatrix", &status); + cl_pStatus(status, "clCreateKernel"); + status = clSetKernelArg(kernel[i], 0, sizeof(cl_mem), &(this->device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel[i], 1, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel[i], 2, sizeof(cl_int), &out_height_each); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel[i], 3, sizeof(cl_int), &out_width_each); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel[i], 4, sizeof(cl_int), &in_height_each); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel[i], 5, sizeof(cl_int), &in_width_each); + cl_pStatus(status, "clSetKernelArg_5"); + status = clSetKernelArg(kernel[i], 6, sizeof(cl_uint), &kernel_w_); + cl_pStatus(status, "clSetKernelArg_6"); + status = clSetKernelArg(kernel[i], 7, local_mem, NULL); + cl_pStatus(status, "clSetKernelArg_7"); + status = clSetKernelArg(kernel[i], 8, sizeof(cl_uint), &local_height); + cl_pStatus(status, "clSetKernelArg_8"); + status = clSetKernelArg(kernel[i], 9, sizeof(cl_uint), &local_width); + cl_pStatus(status, "clSetKernelArg_9"); + + } + + int in_height_left = this->inner_datas_[i].height(); + int in_width = this->inner_datas_[i].width(); + int out_height = this->inner_results_[i].height(); + int out_width = this->inner_results_[i].width(); + + kernel[i] = clCreateKernel(program, "poolMatrix", &status); + cl_pStatus(status, "clCreateKernel"); + status = clSetKernelArg(kernel[i], 0, sizeof(cl_mem), &(this->device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel[i], 1, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel[i], 2, sizeof(cl_int), &out_height); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel[i], 3, sizeof(cl_int), &out_width); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel[i], 4, sizeof(cl_int), &in_height_left); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel[i], 5, sizeof(cl_int), &in_width); + cl_pStatus(status, "clSetKernelArg_5"); + status = clSetKernelArg(kernel[i], 6, sizeof(cl_uint), &kernel_w_); + cl_pStatus(status, "clSetKernelArg_6"); + status = clSetKernelArg(kernel[i], 7, local_mem, NULL); + cl_pStatus(status, "clSetKernelArg_7"); + status = clSetKernelArg(kernel[i], 8, sizeof(cl_uint), &local_height); + cl_pStatus(status, "clSetKernelArg_8"); + status = clSetKernelArg(kernel[i], 9, sizeof(cl_uint), &local_width); + cl_pStatus(status, "clSetKernelArg_9"); + + + for(int i = 0; i < blocks_num ; i+=2) + { + //cout<<"address : "<inner_datas_[2*i].host_data()<device_->DeviceCmdQueA(),this->device_buffers_[0], CL_FALSE, 0, this->inner_datas_[i].height()*data_w_*sizeof(Dtype), (Dtype*)(this->inner_datas_[i].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueA(), kernel[i] , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + + if (i+1 == blocks_num) + { + + } + else + { + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[0], CL_FALSE, 0, this->inner_datas_[(i+1)].height()*data_w_*sizeof(Dtype), (Dtype*)(this->inner_datas_[(i+1)].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueB(), kernel[i+1] , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + } + + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[1], CL_FALSE, 0, this->inner_results_[i].height()*output_w_*sizeof(Dtype), (Dtype*)(this->inner_results_[i].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + if (i+1 == blocks_num) + { + + } + else + { + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[1], CL_FALSE, 0, this->inner_results_[i+1].height()*output_w_*sizeof(Dtype), (Dtype*)(this->inner_results_[i+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + } + } + + + clFinish(this->device_->DeviceCmdQueA()); + clFinish(this->device_->DeviceCmdQueB()); + + //clean + for(int i = 0; i < blocks_num; i++) + { + clReleaseKernel(kernel[i]); + } + clReleaseProgram(program); +} +template +void +PoolStreamTask::PostCompute() +{ +} + + +template class PoolStreamTask; +template class PoolStreamTask; +template class PoolStreamTask; +template class PoolStreamTask; diff --git a/src/get/PoolTask.cpp b/src/get/PoolTask.cpp new file mode 100644 index 0000000..0952d7c --- /dev/null +++ b/src/get/PoolTask.cpp @@ -0,0 +1,140 @@ +/************************************************************************* + > File Name: ReLUTask.cpp + > Author: crows + > Mail: 136211494@qq.com + > Created Time: Tue 24 Mar 2015 02:55:48 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"BaseTask.h" +#include +#include"opencl_z.h" +#include + +#define GLOBAL_WORKGROUP_HEIGHT 4096 +#define GLOBAL_WORKGROUP_WIDTH 4096 +#define LOCAL_WORKGROUP_HEIGHT 16 +#define LOCAL_WORKGROUP_WIDTH 16 + +template +void +PoolTask::PreCompute() +{ + //set device buffer + cl_mem temp_buffer; + cl_int status; + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, data_h_*data_w_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 1"); + + this->device_buffers_.push_back(temp_buffer); + + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, output_h_*output_w_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 2"); + + this->device_buffers_.push_back(temp_buffer); + + //copy data from host to device + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[0], CL_TRUE, 0, data_h_*data_w_*sizeof(Dtype), (Dtype*)(this->datas_[0].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + +} + +template +void +PoolTask::Compute() +{ + //read program source and create kernel + cl_int status; + + char* program_source =NULL; + + if (typeid(Dtype) == typeid(int)) + program_source = cl_readSource("/home/zy/GET/src/get/pool_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + program_source = cl_readSource("/home/zy/GET/src/get/pool_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + program_source = cl_readSource("/home/zy/GET/src/get/pool_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + program_source = cl_readSource("/home/zy/GET/src/get/pool_task_double.cl"); + + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); + cl_pStatus(status, "clCreateProgramWithSource"); + + cl_device_id temp_device = this->device_->DeviceID(); + status = clBuildProgram(program, 1, &temp_device, NULL, NULL, NULL); + if (status != CL_SUCCESS) + { + char buildLog[16384]; + clGetProgramBuildInfo(program, temp_device, CL_PROGRAM_BUILD_LOG, sizeof(buildLog), buildLog, NULL); + cout << "Error in Kernel : "<device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel, 2, sizeof(cl_int), &output_h_); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel, 3, sizeof(cl_int), &output_w_); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel, 4, sizeof(cl_int), &data_h_); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel, 5, sizeof(cl_int), &data_w_); + cl_pStatus(status, "clSetKernelArg_5"); + status = clSetKernelArg(kernel, 6, sizeof(cl_uint), &kernel_w_); + cl_pStatus(status, "clSetKernelArg_6"); + status = clSetKernelArg(kernel, 7, local_mem, NULL); + cl_pStatus(status, "clSetKernelArg_7"); + status = clSetKernelArg(kernel, 8, sizeof(cl_uint), &local_height); + cl_pStatus(status, "clSetKernelArg_8"); + status = clSetKernelArg(kernel, 9, sizeof(cl_uint), &local_width); + cl_pStatus(status, "clSetKernelArg_9"); + //start computing + + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueA(), kernel , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + cout<<"task start on GPU"<device_->DeviceCmdQueA()); +} + +template +void +PoolTask::PostCompute() +{ + //copy memory from device to host + cl_int status; + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[1], CL_TRUE, 0, output_h_*output_w_*sizeof(Dtype), (Dtype*)(this->results_[0].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + clFinish(this->device_->DeviceCmdQueA()); + // clReleaseMemObject(this->device_buffers_[0]); + // clReleaseMemObject(this->device_buffers_[1]); + // clReleaseMemObject(this->device_buffers_[2]); +} + +template class PoolTask; +template class PoolTask; +template class PoolTask; +template class PoolTask; + diff --git a/src/get/PoolTaskDispatcher.cpp b/src/get/PoolTaskDispatcher.cpp new file mode 100644 index 0000000..b806efb --- /dev/null +++ b/src/get/PoolTaskDispatcher.cpp @@ -0,0 +1,220 @@ +/************************************************************************* + > File Name: PoolTaskDispatcher.cpp + > Author: + > Mail: + > Created Time: Mon 13 Apr 2015 03:43:08 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"TaskDispatcher.h" + +template +void +PoolTaskDispatcher::PreTaskDispatch() +{ + //create input blobs and output blobs + GET::TaskParam_DataPosition source_pos = this->task_param_.source_pos(); + int input_blob_num = 0; + if (source_pos == 1) + input_blob_num = this->task_param_.sourcem_size(); + else if (source_pos == 0) + input_blob_num = this->task_param_.sourcef_size(); + + cout <<"input_blob_num = "<task_param_.result_pos(); + int output_blob_num = 0; + if (result_pos == 1) + output_blob_num = this->task_param_.resultm_size(); + else if (result_pos == 0) + output_blob_num = this->task_param_.resultf_size(); + + DataBlob blob; + //prepare A + blob.ReshapeLike(1, 1, pool_param_.data_h(), pool_param_.data_w()); + this->datablobs_input_.push_back(blob); + //prepare B + //prepare input + this->PrepareBlob(input_blob_num, source_pos,(this->datablobs_input_), SOURCE); + //prepare output + unsigned int output_height = (pool_param_.data_h() - pool_param_.kernel_h()) / pool_param_.stride_h() + 1; + unsigned int output_width = (pool_param_.data_w() - pool_param_.kernel_w()) / pool_param_.stride_w() + 1; + blob.ReshapeLike(1, 1, output_height, output_width); + this->datablobs_output_.push_back(blob); + this->PrepareBlob(output_blob_num, result_pos,(this->datablobs_output_), RESULT); + +} + +template +void +PoolTaskDispatcher::TaskDispatch() +{ + //get device : mem_used + long mem_used = 0; + for(int i = 0; i< this->datablobs_input_.size(); i++) + mem_used += this->datablobs_input_[i].count() * sizeof(Dtype); + for(int i = 0; idatablobs_output_.size(); i++) + mem_used += this->datablobs_output_[i].count() * sizeof(Dtype); + this->mem_used_ =mem_used; + + bool device_enough = false; + while (!device_enough) + { + BaseDevice* device; + if ((device = this->device_manager_->GetAvailableDevice()) != NULL) + { + this->devices_needed_.push_back(device); + if(mem_used < (device->GlobalMemory())) + { + device_enough = true; + } + else + { + mem_used -= device->GlobalMemory(); + } + } + else + { + break; + } + } + + cout << "Need Devices = "<< this->devices_needed_.size() <devices_needed_.size()); + //compute the weight for load banlancing + unsigned int offset = 0; + int denum = this->devices_needed_.size(); + unsigned long sum_gl_mem = 0; + float left_part =0; + float weight[denum]; + + for(int i = 0; i devices_needed_[i]->GlobalMemory(); + } + + for(int i =0; idevices_needed_[i]->GlobalMemory())/((float)sum_gl_mem); + left_part += weight[i]; + } + weight[denum-1] = 1 - left_part; + + unsigned int height_current = 0; + int heightgoback = 0; + int width_filter = pool_param_.kernel_w(); + int stride = pool_param_.stride_h(); + int i= 0; + int height_each = pool_param_.data_h() * weight[0]; + int height_out_each = 0; + int input_width = pool_param_.data_w(); + int output_width = (input_width - width_filter)/stride + 1; + Dtype* input_data = (Dtype*)this->datablobs_input_[0].host_data(); + Dtype* output_data = (Dtype*)this->datablobs_output_[0].host_data(); + for( i = 0; i < this->devices_needed_.size()-1;i++) + { + + heightgoback = height_each - (height_each - width_filter + stride)/stride * stride; + height_current += (height_each - heightgoback); + PoolTask* temp_task = new PoolTask(this->devices_needed_[i]); + temp_task->SetParams(1, height_each, pool_param_.data_w(), pool_param_.kernel_h(), pool_param_.kernel_w(), pool_param_.stride_h(), pool_param_.stride_w(), pool_param_.pad_h(), pool_param_.pad_w()); + //create inner + //prepare in + DataBlob inner_i((void*)input_data, 1, 1 , height_each, input_width); + temp_task->AddtoDatas(inner_i); + //prepare filter + //create inner_results + //prepare out + height_out_each = (height_each - width_filter + stride)/stride; + DataBlob inner_o((void*)output_data,1, 1, height_out_each, output_width); + temp_task->AddtoResults(inner_o); + + this->ordinary_tasks_.push_back(temp_task); + cout<<"create " << i << " convtask" <* temp_task = new PoolTask(this->devices_needed_[i]); + temp_task->SetParams(1, height_each, pool_param_.data_w(), pool_param_.kernel_h(), pool_param_.kernel_w(), pool_param_.stride_h(), pool_param_.stride_w(), pool_param_.pad_h(), pool_param_.pad_w()); + //create inner + //prepare in + DataBlob inner_i((void*)input_data, 1, 1 , height_each, input_width); + temp_task->AddtoDatas(inner_i); + //prepare filter + //create inner_results + //prepare out + height_out_each = (height_each - width_filter + stride)/stride; + DataBlob inner_o((void*)output_data,1, 1, height_out_each, output_width); + temp_task->AddtoResults(inner_o); + + //attach the blob to the task + + this->ordinary_tasks_.push_back(temp_task); + + this->process_type_ = ORDINARY; + } + else + { + // stream task + //decide to compute on one or more devices + + //compute on the device with the biggest global memory + unsigned long max_glmem = 0; + BaseDevice* stream_device = NULL ; + for(int i = 0; i < this->devices_needed_.size(); i++) + { + if(max_glmem < this->devices_needed_[i]->GlobalMemory()) + { + stream_device = this->devices_needed_[i]; + max_glmem = this->devices_needed_[i]->GlobalMemory(); + } + } + cout<< "stream device determined"<::iterator l_it = this->devices_needed_.begin(); + + while(l_it != this->devices_needed_.end()) + { + if((*l_it) != stream_device) + { + this->device_manager_->FreeDevice((*l_it)->DeviceIndex()); + l_it = this->devices_needed_.erase(l_it); + } + else + { + l_it++; + } + } + + PoolStreamTask* temp_adst = new PoolStreamTask(stream_device); + temp_adst->SetParams(1, pool_param_.data_h(), pool_param_.data_w(), pool_param_.kernel_h(), pool_param_.kernel_w(), pool_param_.stride_h(), pool_param_.stride_w(), pool_param_.pad_h(), pool_param_.pad_w()); + temp_adst->AddtoDatas(this->datablobs_input_[0]); + temp_adst->AddtoResults(this->datablobs_output_[0]); + this->stream_tasks_.push_back(temp_adst); + + cout<< "create stream tasks" <process_type_ =STREAM; + + + } + +} + +template class PoolTaskDispatcher; +template class PoolTaskDispatcher; +template class PoolTaskDispatcher; +template class PoolTaskDispatcher; + diff --git a/src/get/ReLUStreamTask.cpp b/src/get/ReLUStreamTask.cpp new file mode 100644 index 0000000..3709013 --- /dev/null +++ b/src/get/ReLUStreamTask.cpp @@ -0,0 +1,226 @@ +/************************************************************************* + > File Name: AddStreamTask.cpp + > Author: + > Mail: + > Created Time: Mon 20 Apr 2015 05:14:02 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"StreamTask.h" +#include"opencl_z.h" +#include + +#define GLOBAL_WORKGROUP_HEIGHT 4096 +#define GLOBAL_WORKGROUP_WIDTH 4096 +#define LOCAL_WORKGROUP_HEIGHT 16 +#define LOCAL_WORKGROUP_WIDTH 16 + + +template +ReLUStreamTask::~ReLUStreamTask() +{ + +} + +template +void +ReLUStreamTask :: PreCompute() +{ + //reshape the block for addstream + long mem_used = 0; + mem_used += this->datas_[0].count() * sizeof(Dtype); + + int temp_num = 0; + + temp_num = (mem_used + this->data_per_block_ - 1) / this->data_per_block_; + + //create inner blocks + unsigned int height_each = height_ / temp_num; + unsigned int height_left = 0; + unsigned int offset = 0; + for(int i = 0; i < temp_num - 1 ; i++) + { + DataBlob inner; + + inner.ReshapeLike(1, 1, height_each, width_); + inner.CopyFromMemory(this->datas_[0].host_data()+offset); + this->inner_datas_.push_back(inner); + + inner.CopyFromMemory(this->results_[0].host_data()+offset); + this->inner_results_.push_back(inner); + + offset += height_each*width_*sizeof(Dtype); + height_left += height_each; + } + + height_left = height_ - height_left; + DataBlob inner; + inner.ReshapeLike(1, 1, height_left, width_); + inner.CopyFromMemory(this->datas_[0].host_data()+offset); + this->inner_datas_.push_back(inner); + + inner.CopyFromMemory(this->results_[0].host_data()+offset); + this->inner_results_.push_back(inner); + + this->blocks_num_ = temp_num; + //create the buffer + cl_int status; + cl_mem temp_buffer; + + cl_int datasize; + + datasize = this->device_->GlobalMemory() / 3; + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, datasize , NULL, &status); + cl_pStatus(status, "clCreateBuffer 1"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, datasize, NULL, &status); + cl_pStatus(status, "clCreateBuffer 2"); + this->device_buffers_.push_back(temp_buffer); + + /* + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_each *width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 4"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_each *width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 5"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, height_each*width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 6"); + this->device_buffers_.push_back(temp_buffer); +*/ + +} +template +void +ReLUStreamTask::Compute() +{ + cl_int status; + + char* program_source =NULL; + + if (typeid(Dtype) == typeid(int)) + program_source = cl_readSource("/home/zy/GET/src/get/ReLU_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + program_source = cl_readSource("/home/zy/GET/src/get/ReLU_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + program_source = cl_readSource("/home/zy/GET/src/get/ReLU_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + program_source = cl_readSource("/home/zy/GET/src/get/ReLU_task_double.cl"); + + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); + cl_pStatus(status, "clCreateProgramWithSource"); + + cl_device_id temp_device = this->device_->DeviceID(); + status = clBuildProgram(program, 1, &temp_device, NULL, NULL, NULL); + cl_pStatus(status, "clBuildProgram"); + + cl_uint thread_division_row = 1; + cl_uint thread_division_col = 1; + int height_each = this->inner_datas_[0].height(); + + if (height_each > GLOBAL_WORKGROUP_HEIGHT) + thread_division_row = (height_each + GLOBAL_WORKGROUP_HEIGHT- 1 ) / GLOBAL_WORKGROUP_HEIGHT; + + if (width_ > GLOBAL_WORKGROUP_WIDTH) + thread_division_col = (width_ + GLOBAL_WORKGROUP_WIDTH -1 ) / GLOBAL_WORKGROUP_WIDTH; + + size_t global_worksize[2] = {GLOBAL_WORKGROUP_HEIGHT, GLOBAL_WORKGROUP_WIDTH}; + size_t local_worksize[2] = {LOCAL_WORKGROUP_HEIGHT, LOCAL_WORKGROUP_WIDTH}; + + int blocks_num = this->blocks_num_; + cl_kernel kernel[blocks_num]; + + int i; + cout <<"blocks_num : "<device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel[i], 1, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel[i], 2, sizeof(cl_int), &width_); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel[i], 3, sizeof(cl_int), &height_each); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel[i], 4, sizeof(cl_uint), &thread_division_row); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel[i], 5, sizeof(cl_uint), &thread_division_col); + cl_pStatus(status, "clSetKernelArg_5"); + } + + int height_left = this->inner_datas_[blocks_num - 1].height(); + kernel[i] = clCreateKernel(program, "ReLUMatrix", &status); + cl_pStatus(status, "clCreateKernel"); + status = clSetKernelArg(kernel[i], 0, sizeof(cl_mem), &(this->device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel[i], 1, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel[i], 2, sizeof(cl_int), &width_); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel[i], 3, sizeof(cl_int), &height_left); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel[i], 4, sizeof(cl_uint), &thread_division_row); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel[i], 5, sizeof(cl_uint), &thread_division_col); + cl_pStatus(status, "clSetKernelArg_5"); + + + + for(int i = 0; i < blocks_num ; i+=2) + { + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[0], CL_FALSE, 0, this->inner_datas_[i].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_datas_[i].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueA(), kernel[i] , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + + if (i+1 == blocks_num) + { + + } + else + { + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[0], CL_FALSE, 0, this->inner_datas_[i+1].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_datas_[i+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueB(), kernel[i+1] , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + } + + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[1], CL_FALSE, 0, this->inner_results_[i].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_results_[i].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + if (i+1 == blocks_num) + { + + } + else + { + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[1], CL_FALSE, 0, this->inner_results_[i+1].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_results_[i+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + } + } + + + clFinish(this->device_->DeviceCmdQueA()); + clFinish(this->device_->DeviceCmdQueB()); + + //clean + for(int i = 0; i < blocks_num; i++) + { + clReleaseKernel(kernel[i]); + } + clReleaseProgram(program); +} +template +void +ReLUStreamTask::PostCompute() +{ +} + + +template class ReLUStreamTask; +template class ReLUStreamTask; +template class ReLUStreamTask; +template class ReLUStreamTask; diff --git a/src/get/ReLUTask.cpp b/src/get/ReLUTask.cpp new file mode 100644 index 0000000..84ad9cc --- /dev/null +++ b/src/get/ReLUTask.cpp @@ -0,0 +1,131 @@ +/************************************************************************* + > File Name: ReLUTask.cpp + > Author: crows + > Mail: 136211494@qq.com + > Created Time: Tue 24 Mar 2015 02:55:48 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"BaseTask.h" +#include +#include"opencl_z.h" +#include + +#define GLOBAL_WORKGROUP_HEIGHT 4096 +#define GLOBAL_WORKGROUP_WIDTH 4096 +#define LOCAL_WORKGROUP_HEIGHT 16 +#define LOCAL_WORKGROUP_WIDTH 16 + +template +void +ReLUTask::PreCompute() +{ + //set device buffer + cl_mem temp_buffer; + cl_int status; + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_*width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 1"); + + this->device_buffers_.push_back(temp_buffer); + + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, height_*width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 2"); + + this->device_buffers_.push_back(temp_buffer); + + //copy data from host to device + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[0], CL_TRUE, 0, height_*width_*sizeof(Dtype), (Dtype*)(this->datas_[0].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + +} + +template +void +ReLUTask::Compute() +{ + //read program source and create kernel + cl_int status; + + char* program_source =NULL; + + if (typeid(Dtype) == typeid(int)) + program_source = cl_readSource("/home/zy/GET/src/get/ReLU_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + program_source = cl_readSource("/home/zy/GET/src/get/ReLU_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + program_source = cl_readSource("/home/zy/GET/src/get/ReLU_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + program_source = cl_readSource("/home/zy/GET/src/get/ReLU_task_double.cl"); + + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); + cl_pStatus(status, "clCreateProgramWithSource"); + + cl_device_id temp_device = this->device_->DeviceID(); + status = clBuildProgram(program, 1, &temp_device, NULL, NULL, NULL); + if (status != CL_SUCCESS) + { + char buildLog[16384]; + clGetProgramBuildInfo(program, temp_device, CL_PROGRAM_BUILD_LOG, sizeof(buildLog), buildLog, NULL); + cout << "Error in Kernel : "< GLOBAL_WORKGROUP_HEIGHT) + thread_division_row = (height_ + GLOBAL_WORKGROUP_HEIGHT - 1) / GLOBAL_WORKGROUP_HEIGHT; + + if (width_ > GLOBAL_WORKGROUP_WIDTH) + thread_division_col = (width_ + GLOBAL_WORKGROUP_WIDTH - 1) / GLOBAL_WORKGROUP_WIDTH; + + //set kernel arg + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &(this->device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel, 2, sizeof(cl_int), &height_); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel, 3, sizeof(cl_int), &width_); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel, 4, sizeof(cl_uint), &thread_division_row); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel, 5, sizeof(cl_uint), &thread_division_col); + cl_pStatus(status, "clSetKernelArg_5"); + + //start computing + size_t global_worksize[2] = {GLOBAL_WORKGROUP_HEIGHT , GLOBAL_WORKGROUP_WIDTH}; + size_t local_worksize[2] = {LOCAL_WORKGROUP_HEIGHT, LOCAL_WORKGROUP_WIDTH}; + + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueA(), kernel , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + cout<<"task start on GPU"<device_->DeviceCmdQueA()); +} + +template +void +ReLUTask::PostCompute() +{ + //copy memory from device to host + cl_int status; + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[1], CL_TRUE, 0, height_*width_*sizeof(Dtype), (Dtype*)(this->results_[0].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + clFinish(this->device_->DeviceCmdQueA()); + // clReleaseMemObject(this->device_buffers_[0]); + // clReleaseMemObject(this->device_buffers_[1]); + // clReleaseMemObject(this->device_buffers_[2]); +} + +template class ReLUTask; +template class ReLUTask; +template class ReLUTask; +template class ReLUTask; + diff --git a/src/get/ReLUTaskDispatcher.cpp b/src/get/ReLUTaskDispatcher.cpp new file mode 100644 index 0000000..023c0eb --- /dev/null +++ b/src/get/ReLUTaskDispatcher.cpp @@ -0,0 +1,198 @@ +/************************************************************************* + > File Name: ReLUTaskDispatcher.cpp + > Author: + > Mail: + > Created Time: Mon 13 Apr 2015 03:43:08 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"TaskDispatcher.h" + +template +void +ReLUTaskDispatcher::PreTaskDispatch() +{ + //create input blobs and output blobs + GET::TaskParam_DataPosition source_pos = this->task_param_.source_pos(); + int input_blob_num = 0; + if (source_pos == 1) + input_blob_num = this->task_param_.sourcem_size(); + else if (source_pos == 0) + input_blob_num = this->task_param_.sourcef_size(); + + cout <<"input_blob_num = "<task_param_.result_pos(); + int output_blob_num = 0; + if (result_pos == 1) + output_blob_num = this->task_param_.resultm_size(); + else if (result_pos == 0) + output_blob_num = this->task_param_.resultf_size(); + + DataBlob blob; + blob.ReshapeLike(1, 1, relu_param_.height(), relu_param_.width()); + //prepare input + this->datablobs_input_.push_back(blob); + this->PrepareBlob(input_blob_num, source_pos,(this->datablobs_input_), SOURCE); + //prepare output + this->datablobs_output_.push_back(blob); + this->PrepareBlob(output_blob_num, result_pos,(this->datablobs_output_), RESULT); + +} + +template +void +ReLUTaskDispatcher::TaskDispatch() +{ + //get device : mem_used + long mem_used = 0; + for(int i = 0; i< this->datablobs_input_.size(); i++) + mem_used += this->datablobs_input_[i].count() * sizeof(Dtype); + for(int i = 0; idatablobs_output_.size(); i++) + mem_used += this->datablobs_output_[i].count() * sizeof(Dtype); + this->mem_used_ =mem_used; + + bool device_enough = false; + while (!device_enough) + { + BaseDevice* device; + if ((device = this->device_manager_->GetAvailableDevice()) != NULL) + { + this->devices_needed_.push_back(device); + if(mem_used < (device->GlobalMemory())) + { + device_enough = true; + } + else + { + mem_used -= device->GlobalMemory(); + } + } + else + { + break; + } + } + + cout << "Need Devices = "<< this->devices_needed_.size() <devices_needed_.size()); + //compute the weight for load banlancing + unsigned int height_each = 0; + unsigned int offset = 0; + int denum = this->devices_needed_.size(); + unsigned long sum_gl_mem = 0; + float left_part =0; + float weight[denum]; + + for(int i = 0; i devices_needed_[i]->GlobalMemory(); + } + + for(int i =0; idevices_needed_[i]->GlobalMemory())/((float)sum_gl_mem); + left_part += weight[i]; + } + weight[denum-1] = 1 - left_part; + + unsigned int height_current = 0; + + int i= 0; + for( i = 0; i < this->devices_needed_.size()-1;i++) + { + + height_each = relu_param_.height() * weight[i]; + height_current += height_each; + ReLUTask* temp_task = new ReLUTask(this->devices_needed_[i]); + temp_task->SetParams(1, height_each, relu_param_.width()); + //create inner + DataBlob inner; + inner.ReshapeLike(1, 1, height_each, relu_param_.width()); + inner.CopyFromMemory((void*)((this->datablobs_input_[0].host_data())+offset)); + temp_task->AddtoDatas(inner); + //create inner_results + inner.CopyFromMemory((void*)((this->datablobs_output_[0].host_data())+offset)); + temp_task->AddtoResults(inner); + + this->ordinary_tasks_.push_back(temp_task); + cout<<"create " << i << " relutask" <* temp_task = new ReLUTask(this->devices_needed_[i]); + temp_task->SetParams(1, height_each, relu_param_.width()); + //create inner + DataBlob inner; + inner.ReshapeLike(1, 1, height_each, relu_param_.width()); + inner.CopyFromMemory((void*)((this->datablobs_input_[0].host_data())+offset)); + temp_task->AddtoDatas(inner); + //create inner_results + inner.CopyFromMemory((void*)((this->datablobs_output_[0].host_data())+offset)); + temp_task->AddtoResults(inner); + + //attach the blob to the task + + this->ordinary_tasks_.push_back(temp_task); + + this->process_type_ = ORDINARY; + } + else + { + // stream task + //decide to compute on one or more devices + + //compute on the device with the biggest global memory + unsigned long max_glmem = 0; + BaseDevice* stream_device = NULL ; + for(int i = 0; i < this->devices_needed_.size(); i++) + { + if(max_glmem < this->devices_needed_[i]->GlobalMemory()) + { + stream_device = this->devices_needed_[i]; + max_glmem = this->devices_needed_[i]->GlobalMemory(); + } + } + cout<< "stream device determined"<::iterator l_it = this->devices_needed_.begin(); + + while(l_it != this->devices_needed_.end()) + { + if((*l_it) != stream_device) + { + this->device_manager_->FreeDevice((*l_it)->DeviceIndex()); + l_it = this->devices_needed_.erase(l_it); + } + else + { + l_it++; + } + } + + ReLUStreamTask* temp_adst = new ReLUStreamTask(stream_device); + temp_adst->SetParams(1, relu_param_.height(), relu_param_.width()); + temp_adst->AddtoDatas(this->datablobs_input_[0]); + temp_adst->AddtoResults(this->datablobs_output_[0]); + this->stream_tasks_.push_back(temp_adst); + + cout<< "create stream tasks" <process_type_ =STREAM; + + + } + +} + +template class ReLUTaskDispatcher; +template class ReLUTaskDispatcher; +template class ReLUTaskDispatcher; +template class ReLUTaskDispatcher; + diff --git a/src/get/ReLU_task_double.cl b/src/get/ReLU_task_double.cl new file mode 100644 index 0000000..b43872f --- /dev/null +++ b/src/get/ReLU_task_double.cl @@ -0,0 +1,23 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel +void ReLUMatrix(__global double* Input, + __global double* Output, + int height, + int width, + unsigned int divT_r, + unsigned int divT_c + ) +{ + int col_t = get_global_id(1)*divT_c; + int row_t = get_global_id(0)*divT_r; + int iterations = divT_c*divT_r; + int index = row_t*width+col_t; + if ((row_t < height)&&(col_t < width)) + { + for(int i = 0 ; i < iterations ; i++) + { + Output[index+i] = Input[index+i] > 0 ? Input[index+i] : 0 ; + } + } +} + diff --git a/src/get/ReLU_task_float.cl b/src/get/ReLU_task_float.cl new file mode 100644 index 0000000..046ecfd --- /dev/null +++ b/src/get/ReLU_task_float.cl @@ -0,0 +1,23 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel +void ReLUMatrix(__global float* Input, + __global float* Output, + int height, + int width, + unsigned int divT_r, + unsigned int divT_c + ) +{ + int col_t = get_global_id(1)*divT_c; + int row_t = get_global_id(0)*divT_r; + int iterations = divT_c*divT_r; + int index = row_t*width+col_t; + if ((row_t < height)&&(col_t < width)) + { + for(int i = 0 ; i < iterations ; i++) + { + Output[index+i] = Input[index+i] > 0 ? Input[index+i] : 0 ; + } + } +} + diff --git a/src/get/ReLU_task_int.cl b/src/get/ReLU_task_int.cl new file mode 100644 index 0000000..f01c9bb --- /dev/null +++ b/src/get/ReLU_task_int.cl @@ -0,0 +1,23 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel +void ReLUMatrix(__global int* Input, + __global int* Output, + int height, + int width, + unsigned int divT_r, + unsigned int divT_c + ) +{ + int col_t = get_global_id(1)*divT_c; + int row_t = get_global_id(0)*divT_r; + int iterations = divT_c*divT_r; + int index = row_t*width+col_t; + if ((row_t < height)&&(col_t < width)) + { + for(int i = 0 ; i < iterations ; i++) + { + Output[index+i] = Input[index+i] > 0 ? Input[index+i] : 0 ; + } + } +} + diff --git a/src/get/ReLU_task_long.cl b/src/get/ReLU_task_long.cl new file mode 100644 index 0000000..a4664ed --- /dev/null +++ b/src/get/ReLU_task_long.cl @@ -0,0 +1,23 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel +void ReLUMatrix(__global long* Input, + __global long* Output, + int height, + int width, + unsigned int divT_r, + unsigned int divT_c + ) +{ + int col_t = get_global_id(1)*divT_c; + int row_t = get_global_id(0)*divT_r; + int iterations = divT_c*divT_r; + int index = row_t*width+col_t; + if ((row_t < height)&&(col_t < width)) + { + for(int i = 0 ; i < iterations ; i++) + { + Output[index+i] = Input[index+i] > 0 ? Input[index+i] : 0 ; + } + } +} + diff --git a/src/get/SigmoidStreamTask.cpp b/src/get/SigmoidStreamTask.cpp new file mode 100644 index 0000000..eae612e --- /dev/null +++ b/src/get/SigmoidStreamTask.cpp @@ -0,0 +1,226 @@ +/************************************************************************* + > File Name: AddStreamTask.cpp + > Author: + > Mail: + > Created Time: Mon 20 Apr 2015 05:14:02 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"StreamTask.h" +#include"opencl_z.h" +#include + +#define GLOBAL_WORKGROUP_HEIGHT 4096 +#define GLOBAL_WORKGROUP_WIDTH 4096 +#define LOCAL_WORKGROUP_HEIGHT 16 +#define LOCAL_WORKGROUP_WIDTH 16 + + +template +SigmoidStreamTask::~SigmoidStreamTask() +{ + +} + +template +void +SigmoidStreamTask :: PreCompute() +{ + //reshape the block for addstream + long mem_used = 0; + mem_used += this->datas_[0].count() * sizeof(Dtype); + + int temp_num = 0; + + temp_num = (mem_used + this->data_per_block_ - 1) / this->data_per_block_; + + //create inner blocks + unsigned int height_each = height_ / temp_num; + unsigned int height_left = 0; + unsigned int offset = 0; + for(int i = 0; i < temp_num - 1 ; i++) + { + DataBlob inner; + + inner.ReshapeLike(1, 1, height_each, width_); + inner.CopyFromMemory(this->datas_[0].host_data()+offset); + this->inner_datas_.push_back(inner); + + inner.CopyFromMemory(this->results_[0].host_data()+offset); + this->inner_results_.push_back(inner); + + offset += height_each*width_*sizeof(Dtype); + height_left += height_each; + } + + height_left = height_ - height_left; + DataBlob inner; + inner.ReshapeLike(1, 1, height_left, width_); + inner.CopyFromMemory(this->datas_[0].host_data()+offset); + this->inner_datas_.push_back(inner); + + inner.CopyFromMemory(this->results_[0].host_data()+offset); + this->inner_results_.push_back(inner); + + this->blocks_num_ = temp_num; + //create the buffer + cl_int status; + cl_mem temp_buffer; + + cl_int datasize; + + datasize = this->device_->GlobalMemory() / 3; + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, datasize , NULL, &status); + cl_pStatus(status, "clCreateBuffer 1"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, datasize, NULL, &status); + cl_pStatus(status, "clCreateBuffer 2"); + this->device_buffers_.push_back(temp_buffer); + + /* + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_each *width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 4"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_each *width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 5"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, height_each*width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 6"); + this->device_buffers_.push_back(temp_buffer); +*/ + +} +template +void +SigmoidStreamTask::Compute() +{ + cl_int status; + + char* program_source =NULL; + + if (typeid(Dtype) == typeid(int)) + program_source = cl_readSource("/home/zy/GET/src/get/sig_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + program_source = cl_readSource("/home/zy/GET/src/get/sig_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + program_source = cl_readSource("/home/zy/GET/src/get/sig_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + program_source = cl_readSource("/home/zy/GET/src/get/sig_task_double.cl"); + + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); + cl_pStatus(status, "clCreateProgramWithSource"); + + cl_device_id temp_device = this->device_->DeviceID(); + status = clBuildProgram(program, 1, &temp_device, NULL, NULL, NULL); + cl_pStatus(status, "clBuildProgram"); + + cl_uint thread_division_row = 1; + cl_uint thread_division_col = 1; + int height_each = this->inner_datas_[0].height(); + + if (height_each > GLOBAL_WORKGROUP_HEIGHT) + thread_division_row = (height_each + GLOBAL_WORKGROUP_HEIGHT- 1 ) / GLOBAL_WORKGROUP_HEIGHT; + + if (width_ > GLOBAL_WORKGROUP_WIDTH) + thread_division_col = (width_ + GLOBAL_WORKGROUP_WIDTH -1 ) / GLOBAL_WORKGROUP_WIDTH; + + size_t global_worksize[2] = {GLOBAL_WORKGROUP_HEIGHT, GLOBAL_WORKGROUP_WIDTH}; + size_t local_worksize[2] = {LOCAL_WORKGROUP_HEIGHT, LOCAL_WORKGROUP_WIDTH}; + + int blocks_num = this->blocks_num_; + cl_kernel kernel[blocks_num]; + + int i; + cout <<"blocks_num : "<device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel[i], 1, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel[i], 2, sizeof(cl_int), &width_); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel[i], 3, sizeof(cl_int), &height_each); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel[i], 4, sizeof(cl_uint), &thread_division_row); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel[i], 5, sizeof(cl_uint), &thread_division_col); + cl_pStatus(status, "clSetKernelArg_5"); + } + + int height_left = this->inner_datas_[blocks_num - 1].height(); + kernel[i] = clCreateKernel(program, "sigMatrix", &status); + cl_pStatus(status, "clCreateKernel"); + status = clSetKernelArg(kernel[i], 0, sizeof(cl_mem), &(this->device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel[i], 1, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel[i], 2, sizeof(cl_int), &width_); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel[i], 3, sizeof(cl_int), &height_left); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel[i], 4, sizeof(cl_uint), &thread_division_row); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel[i], 5, sizeof(cl_uint), &thread_division_col); + cl_pStatus(status, "clSetKernelArg_5"); + + + + for(int i = 0; i < blocks_num ; i+=2) + { + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[0], CL_FALSE, 0, this->inner_datas_[i].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_datas_[i].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueA(), kernel[i] , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + + if (i+1 == blocks_num) + { + + } + else + { + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[0], CL_FALSE, 0, this->inner_datas_[i+1].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_datas_[i+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueB(), kernel[i+1] , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + } + + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[1], CL_FALSE, 0, this->inner_results_[i].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_results_[i].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + if (i+1 == blocks_num) + { + + } + else + { + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[1], CL_FALSE, 0, this->inner_results_[i+1].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_results_[i+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + } + } + + + clFinish(this->device_->DeviceCmdQueA()); + clFinish(this->device_->DeviceCmdQueB()); + + //clean + for(int i = 0; i < blocks_num; i++) + { + clReleaseKernel(kernel[i]); + } + clReleaseProgram(program); +} +template +void +SigmoidStreamTask::PostCompute() +{ +} + + +template class SigmoidStreamTask; +template class SigmoidStreamTask; +template class SigmoidStreamTask; +template class SigmoidStreamTask; diff --git a/src/get/SigmoidTask.cpp b/src/get/SigmoidTask.cpp new file mode 100644 index 0000000..539ff51 --- /dev/null +++ b/src/get/SigmoidTask.cpp @@ -0,0 +1,131 @@ +/************************************************************************* + > File Name: SigmoidTask.cpp + > Author: crows + > Mail: 136211494@qq.com + > Created Time: Tue 24 Mar 2015 02:55:48 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"BaseTask.h" +#include +#include"opencl_z.h" +#include + +#define GLOBAL_WORKGROUP_HEIGHT 4096 +#define GLOBAL_WORKGROUP_WIDTH 4096 +#define LOCAL_WORKGROUP_HEIGHT 16 +#define LOCAL_WORKGROUP_WIDTH 16 + +template +void +SigmoidTask::PreCompute() +{ + //set device buffer + cl_mem temp_buffer; + cl_int status; + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_*width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 1"); + + this->device_buffers_.push_back(temp_buffer); + + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, height_*width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 2"); + + this->device_buffers_.push_back(temp_buffer); + + //copy data from host to device + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[0], CL_TRUE, 0, height_*width_*sizeof(Dtype), (Dtype*)(this->datas_[0].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + +} + +template +void +SigmoidTask::Compute() +{ + //read program source and create kernel + cl_int status; + + char* program_source =NULL; + + if (typeid(Dtype) == typeid(int)) + program_source = cl_readSource("/home/zy/GET/src/get/sig_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + program_source = cl_readSource("/home/zy/GET/src/get/sig_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + program_source = cl_readSource("/home/zy/GET/src/get/sig_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + program_source = cl_readSource("/home/zy/GET/src/get/sig_task_double.cl"); + + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); + cl_pStatus(status, "clCreateProgramWithSource"); + + cl_device_id temp_device = this->device_->DeviceID(); + status = clBuildProgram(program, 1, &temp_device, NULL, NULL, NULL); + if (status != CL_SUCCESS) + { + char buildLog[16384]; + clGetProgramBuildInfo(program, temp_device, CL_PROGRAM_BUILD_LOG, sizeof(buildLog), buildLog, NULL); + cout << "Error in Kernel : "< GLOBAL_WORKGROUP_HEIGHT) + thread_division_row = (height_ + GLOBAL_WORKGROUP_HEIGHT - 1) / GLOBAL_WORKGROUP_HEIGHT; + + if (width_ > GLOBAL_WORKGROUP_WIDTH) + thread_division_col = (width_ + GLOBAL_WORKGROUP_WIDTH - 1) / GLOBAL_WORKGROUP_WIDTH; + + //set kernel arg + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &(this->device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel, 2, sizeof(cl_int), &height_); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel, 3, sizeof(cl_int), &width_); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel, 4, sizeof(cl_uint), &thread_division_row); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel, 5, sizeof(cl_uint), &thread_division_col); + cl_pStatus(status, "clSetKernelArg_5"); + + //start computing + size_t global_worksize[2] = {GLOBAL_WORKGROUP_HEIGHT , GLOBAL_WORKGROUP_WIDTH}; + size_t local_worksize[2] = {LOCAL_WORKGROUP_HEIGHT, LOCAL_WORKGROUP_WIDTH}; + + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueA(), kernel , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + cout<<"task start on GPU"<device_->DeviceCmdQueA()); +} + +template +void +SigmoidTask::PostCompute() +{ + //copy memory from device to host + cl_int status; + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[1], CL_TRUE, 0, height_*width_*sizeof(Dtype), (Dtype*)(this->results_[0].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + clFinish(this->device_->DeviceCmdQueA()); + // clReleaseMemObject(this->device_buffers_[0]); + // clReleaseMemObject(this->device_buffers_[1]); + // clReleaseMemObject(this->device_buffers_[2]); +} + +template class SigmoidTask; +template class SigmoidTask; +template class SigmoidTask; +template class SigmoidTask; + diff --git a/src/get/SigmoidTaskDispatcher.cpp b/src/get/SigmoidTaskDispatcher.cpp new file mode 100644 index 0000000..d993f9c --- /dev/null +++ b/src/get/SigmoidTaskDispatcher.cpp @@ -0,0 +1,198 @@ +/************************************************************************* + > File Name: SigmoidTaskDispatcher.cpp + > Author: + > Mail: + > Created Time: Mon 13 Apr 2015 03:43:08 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"TaskDispatcher.h" + +template +void +SigmoidTaskDispatcher::PreTaskDispatch() +{ + //create input blobs and output blobs + GET::TaskParam_DataPosition source_pos = this->task_param_.source_pos(); + int input_blob_num = 0; + if (source_pos == 1) + input_blob_num = this->task_param_.sourcem_size(); + else if (source_pos == 0) + input_blob_num = this->task_param_.sourcef_size(); + + cout <<"input_blob_num = "<task_param_.result_pos(); + int output_blob_num = 0; + if (result_pos == 1) + output_blob_num = this->task_param_.resultm_size(); + else if (result_pos == 0) + output_blob_num = this->task_param_.resultf_size(); + + DataBlob blob; + blob.ReshapeLike(1, 1, sig_param_.height(), sig_param_.width()); + //prepare input + this->datablobs_input_.push_back(blob); + this->PrepareBlob(input_blob_num, source_pos,(this->datablobs_input_), SOURCE); + //prepare output + this->datablobs_output_.push_back(blob); + this->PrepareBlob(output_blob_num, result_pos,(this->datablobs_output_), RESULT); + +} + +template +void +SigmoidTaskDispatcher::TaskDispatch() +{ + //get device : mem_used + long mem_used = 0; + for(int i = 0; i< this->datablobs_input_.size(); i++) + mem_used += this->datablobs_input_[i].count() * sizeof(Dtype); + for(int i = 0; idatablobs_output_.size(); i++) + mem_used += this->datablobs_output_[i].count() * sizeof(Dtype); + this->mem_used_ =mem_used; + + bool device_enough = false; + while (!device_enough) + { + BaseDevice* device; + if ((device = this->device_manager_->GetAvailableDevice()) != NULL) + { + this->devices_needed_.push_back(device); + if(mem_used < (device->GlobalMemory())) + { + device_enough = true; + } + else + { + mem_used -= device->GlobalMemory(); + } + } + else + { + break; + } + } + + cout << "Need Devices = "<< this->devices_needed_.size() <devices_needed_.size()); + //compute the weight for load banlancing + unsigned int height_each = 0; + unsigned int offset = 0; + int denum = this->devices_needed_.size(); + unsigned long sum_gl_mem = 0; + float left_part =0; + float weight[denum]; + + for(int i = 0; i devices_needed_[i]->GlobalMemory(); + } + + for(int i =0; idevices_needed_[i]->GlobalMemory())/((float)sum_gl_mem); + left_part += weight[i]; + } + weight[denum-1] = 1 - left_part; + + unsigned int height_current = 0; + + int i= 0; + for( i = 0; i < this->devices_needed_.size()-1;i++) + { + + height_each = sig_param_.height() * weight[i]; + height_current += height_each; + SigmoidTask* temp_task = new SigmoidTask(this->devices_needed_[i]); + temp_task->SetParams(1, height_each, sig_param_.width()); + //create inner + DataBlob inner; + inner.ReshapeLike(1, 1, height_each, sig_param_.width()); + inner.CopyFromMemory((void*)((this->datablobs_input_[0].host_data())+offset)); + temp_task->AddtoDatas(inner); + //create inner_results + inner.CopyFromMemory((void*)((this->datablobs_output_[0].host_data())+offset)); + temp_task->AddtoResults(inner); + + this->ordinary_tasks_.push_back(temp_task); + cout<<"create " << i << " relutask" <* temp_task = new SigmoidTask(this->devices_needed_[i]); + temp_task->SetParams(1, height_each, sig_param_.width()); + //create inner + DataBlob inner; + inner.ReshapeLike(1, 1, height_each, sig_param_.width()); + inner.CopyFromMemory((void*)((this->datablobs_input_[0].host_data())+offset)); + temp_task->AddtoDatas(inner); + //create inner_results + inner.CopyFromMemory((void*)((this->datablobs_output_[0].host_data())+offset)); + temp_task->AddtoResults(inner); + + //attach the blob to the task + + this->ordinary_tasks_.push_back(temp_task); + + this->process_type_ = ORDINARY; + } + else + { + // stream task + //decide to compute on one or more devices + + //compute on the device with the biggest global memory + unsigned long max_glmem = 0; + BaseDevice* stream_device = NULL ; + for(int i = 0; i < this->devices_needed_.size(); i++) + { + if(max_glmem < this->devices_needed_[i]->GlobalMemory()) + { + stream_device = this->devices_needed_[i]; + max_glmem = this->devices_needed_[i]->GlobalMemory(); + } + } + cout<< "stream device determined"<::iterator l_it = this->devices_needed_.begin(); + + while(l_it != this->devices_needed_.end()) + { + if((*l_it) != stream_device) + { + this->device_manager_->FreeDevice((*l_it)->DeviceIndex()); + l_it = this->devices_needed_.erase(l_it); + } + else + { + l_it++; + } + } + + SigmoidStreamTask* temp_adst = new SigmoidStreamTask(stream_device); + temp_adst->SetParams(1, sig_param_.height(), sig_param_.width()); + temp_adst->AddtoDatas(this->datablobs_input_[0]); + temp_adst->AddtoResults(this->datablobs_output_[0]); + this->stream_tasks_.push_back(temp_adst); + + cout<< "create stream tasks" <process_type_ =STREAM; + + + } + +} + +template class SigmoidTaskDispatcher; +template class SigmoidTaskDispatcher; +template class SigmoidTaskDispatcher; +template class SigmoidTaskDispatcher; + diff --git a/src/get/SubStreamTask.cpp b/src/get/SubStreamTask.cpp new file mode 100644 index 0000000..b0d8298 --- /dev/null +++ b/src/get/SubStreamTask.cpp @@ -0,0 +1,238 @@ +/************************************************************************* + > File Name: AddStreamTask.cpp + > Author: + > Mail: + > Created Time: Mon 20 Apr 2015 05:14:02 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"StreamTask.h" +#include"opencl_z.h" + +#define GLOBAL_WORKGROUP_HEIGHT 4096 +#define GLOBAL_WORKGROUP_WIDTH 4096 +#define LOCAL_WORKGROUP_HEIGHT 16 +#define LOCAL_WORKGROUP_WIDTH 16 + + +template +SubStreamTask::~SubStreamTask() +{ + +} + +template +void +SubStreamTask :: PreCompute() +{ + //reshape the block for addstream + long mem_used = 0; + mem_used += this->datas_[0].count() * sizeof(Dtype); + + int temp_num = 0; + + temp_num = (mem_used + this->data_per_block_ - 1) / this->data_per_block_; + + //create inner blocks + unsigned int height_each = height_ / temp_num; + unsigned int height_left = 0; + unsigned int offset = 0; + for(int i = 0; i < temp_num - 1 ; i++) + { + DataBlob inner; + + inner.ReshapeLike(1, 1, height_each, width_); + inner.CopyFromMemory(this->datas_[0].host_data()+offset); + this->inner_datas_.push_back(inner); + inner.CopyFromMemory(this->datas_[1].host_data()+offset); + this->inner_datas_.push_back(inner); + + inner.CopyFromMemory(this->results_[0].host_data()+offset); + this->inner_results_.push_back(inner); + + offset += height_each*width_*sizeof(Dtype); + height_left += height_each; + } + + height_left = height_ - height_left; + DataBlob inner; + inner.ReshapeLike(1, 1, height_left, width_); + inner.CopyFromMemory(this->datas_[0].host_data()+offset); + this->inner_datas_.push_back(inner); + inner.CopyFromMemory(this->datas_[1].host_data()+offset); + this->inner_datas_.push_back(inner); + + inner.CopyFromMemory(this->results_[0].host_data()+offset); + this->inner_results_.push_back(inner); + + this->blocks_num_ = temp_num; + //create the buffer + cl_int status; + cl_mem temp_buffer; + + cl_int datasize; + + datasize = this->device_->GlobalMemory() / 3; + + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, datasize , NULL, &status); + cl_pStatus(status, "clCreateBuffer 1"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, datasize, NULL, &status); + cl_pStatus(status, "clCreateBuffer 2"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, datasize, NULL, &status); + cl_pStatus(status, "clCreateBuffer 3"); + this->device_buffers_.push_back(temp_buffer); + + /* + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_each *width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 4"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_READ_ONLY, height_each *width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 5"); + this->device_buffers_.push_back(temp_buffer); + temp_buffer = clCreateBuffer(this->device_->DeviceContext(), CL_MEM_WRITE_ONLY, height_each*width_*sizeof(Dtype), NULL, &status); + cl_pStatus(status, "clCreateBuffer 6"); + this->device_buffers_.push_back(temp_buffer); +*/ + +} +template +void +SubStreamTask::Compute() +{ + cl_int status; + + if (typeid(Dtype) == typeid(int)) + char* program_source = cl_readSource("sub_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + char* program_source = cl_readSource("sub_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + char* program_source = cl_readSource("sub_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + char* program_source = cl_readSource("sub_task_double.cl"); + + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); + cl_pStatus(status, "clCreateProgramWithSource"); + + cl_device_id temp_device = this->device_->DeviceID(); + status = clBuildProgram(program, 1, &temp_device, NULL, NULL, NULL); + cl_pStatus(status, "clBuildProgram"); + + cl_uint thread_division_row = 1; + cl_uint thread_division_col = 1; + int height_each = this->inner_datas_[0].height(); + + if (height_each > GLOBAL_WORKGROUP_HEIGHT) + thread_division_row = (height_each + GLOBAL_WORKGROUP_HEIGHT- 1 ) / GLOBAL_WORKGROUP_HEIGHT; + + if (width_ > GLOBAL_WORKGROUP_WIDTH) + thread_division_col = (width_ + GLOBAL_WORKGROUP_WIDTH -1 ) / GLOBAL_WORKGROUP_WIDTH; + + size_t global_worksize[2] = {GLOBAL_WORKGROUP_HEIGHT, GLOBAL_WORKGROUP_WIDTH}; + size_t local_worksize[2] = {LOCAL_WORKGROUP_HEIGHT, LOCAL_WORKGROUP_WIDTH}; + + int blocks_num = this->blocks_num_; + cl_kernel kernel[blocks_num]; + + int i; + cout <<"blocks_num : "<device_buffers_[2])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel[i], 1, sizeof(cl_int), &width_); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel[i], 2, sizeof(cl_int), &height_each); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel[i], 3, sizeof(cl_uint), &thread_division_row); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel[i], 4, sizeof(cl_uint), &thread_division_col); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel[i], 5, sizeof(cl_mem), &(this->device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_5"); + status = clSetKernelArg(kernel[i], 6, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_6"); + } + + int height_left = this->inner_datas_[blocks_num - 1].height(); + kernel[i] = clCreateKernel(program, "subMatrix", &status); + cl_pStatus(status, "clCreateKernel"); + status = clSetKernelArg(kernel[i], 0, sizeof(cl_mem), &(this->device_buffers_[2])); + cl_pStatus(status, "clSetKernelArg_0"); + status = clSetKernelArg(kernel[i], 1, sizeof(cl_int), &width_); + cl_pStatus(status, "clSetKernelArg_1"); + status = clSetKernelArg(kernel[i], 2, sizeof(cl_int), &height_left); + cl_pStatus(status, "clSetKernelArg_2"); + status = clSetKernelArg(kernel[i], 3, sizeof(cl_uint), &thread_division_row); + cl_pStatus(status, "clSetKernelArg_3"); + status = clSetKernelArg(kernel[i], 4, sizeof(cl_uint), &thread_division_col); + cl_pStatus(status, "clSetKernelArg_4"); + status = clSetKernelArg(kernel[i], 5, sizeof(cl_mem), &(this->device_buffers_[0])); + cl_pStatus(status, "clSetKernelArg_5"); + status = clSetKernelArg(kernel[i], 6, sizeof(cl_mem), &(this->device_buffers_[1])); + cl_pStatus(status, "clSetKernelArg_6"); + + + + for(int i = 0; i < blocks_num ; i+=2) + { + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[0], CL_FALSE, 0, this->inner_datas_[2*i].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_datas_[2*i].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[1], CL_FALSE, 0, this->inner_datas_[2*i+1].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_datas_[2*i+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueA(), kernel[i] , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + + if (i+1 == blocks_num) + { + + } + else + { + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[0], CL_FALSE, 0, this->inner_datas_[2*(i+1)].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_datas_[2*(i+1)].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueWriteBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[1], CL_FALSE, 0, this->inner_datas_[2*(i+1)+1].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_datas_[2*(i+1)+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueWriteBuffer"); + status = clEnqueueNDRangeKernel(this->device_->DeviceCmdQueB(), kernel[i+1] , 2, NULL , global_worksize, local_worksize, 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueNDRangeKernel"); + } + + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueA(),this->device_buffers_[2], CL_FALSE, 0, this->inner_results_[i].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_results_[i].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + if (i+1 == blocks_num) + { + + } + else + { + status = clEnqueueReadBuffer(this->device_->DeviceCmdQueB(),this->device_buffers_[2], CL_FALSE, 0, this->inner_results_[i+1].height()*width_*sizeof(Dtype), (Dtype*)(this->inner_results_[i+1].host_data()), 0, NULL, NULL ); + cl_pStatus(status, "clEnqueueReadBuffer"); + } + } + + + clFinish(this->device_->DeviceCmdQueA()); + clFinish(this->device_->DeviceCmdQueB()); + + //clean + for(int i = 0; i < blocks_num; i++) + { + clReleaseKernel(kernel[i]); + } + clReleaseProgram(program); +} +template +void +SubStreamTask::PostCompute() +{ +} + + +template class SubStreamTask; +template class SubStreamTask; +template class SubStreamTask; +template class SubStreamTask; diff --git a/src/get/SubTask.cpp b/src/get/SubTask.cpp index 623386c..442a332 100644 --- a/src/get/SubTask.cpp +++ b/src/get/SubTask.cpp @@ -7,9 +7,10 @@ #include using namespace std; -#include"SubTask.h" +#include"BaseTask.h" #include #include"opencl_z.h" +#include #define GLOBAL_WORKGROUP_HEIGHT 4096 #define GLOBAL_WORKGROUP_WIDTH 4096 @@ -55,7 +56,18 @@ SubTask::Compute() { //read program source and create kernel cl_int status; - char* program_source = cl_readSource("/home/zy/GET/src/get/sub_task.cl"); + + char * program_source = NULL; + + if (typeid(Dtype) == typeid(int)) + program_source = cl_readSource("/home/zy/GET/src/get/sub_task_int.cl"); + if (typeid(Dtype) == typeid(long)) + program_source = cl_readSource("/home/zy/GET/src/get/sub_task_long.cl"); + if (typeid(Dtype) == typeid(float)) + program_source = cl_readSource("/home/zy/GET/src/get/sub_task_float.cl"); + if (typeid(Dtype) == typeid(double)) + program_source = cl_readSource("/home/zy/GET/src/get/sub_task_double.cl"); + cl_program program = clCreateProgramWithSource(this->device_->DeviceContext(), 1, (const char**)&program_source, NULL, &status); cl_pStatus(status, "clCreateProgramWithSource"); diff --git a/src/get/SubTaskDispatcher.cpp b/src/get/SubTaskDispatcher.cpp new file mode 100644 index 0000000..30ecd31 --- /dev/null +++ b/src/get/SubTaskDispatcher.cpp @@ -0,0 +1,204 @@ +/************************************************************************* + > File Name: SubTaskDispatcher.cpp + > Author: + > Mail: + > Created Time: Mon 13 Apr 2015 03:43:08 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"TaskDispatcher.h" + +template +void +SubTaskDispatcher::PreTaskDispatch() +{ + //create input blobs and output blobs + GET::TaskParam_DataPosition source_pos = this->task_param_.source_pos(); + int input_blob_num = 0; + if (source_pos == 1) + input_blob_num = this->task_param_.sourcem_size(); + else if (source_pos == 0) + input_blob_num = this->task_param_.sourcef_size(); + + cout <<"input_blob_num = "<task_param_.result_pos(); + int output_blob_num = 0; + if (result_pos == 1) + output_blob_num = this->task_param_.resultm_size(); + else if (result_pos == 0) + output_blob_num = this->task_param_.resultf_size(); + + DataBlob blob; + blob.ReshapeLike(1, 1, sub_param_.height(), sub_param_.width()); + //prepare input + this->datablobs_input_.push_back(blob); + this->datablobs_input_.push_back(blob); + this->PrepareBlob(input_blob_num, source_pos,(this->datablobs_input_), SOURCE); + //prepare output + this->datablobs_output_.push_back(blob); + this->PrepareBlob(output_blob_num, result_pos,(this->datablobs_output_), RESULT); + +} + +template +void +SubTaskDispatcher::TaskDispatch() +{ + //get device : mem_used + long mem_used = 0; + for(int i = 0; i< this->datablobs_input_.size(); i++) + mem_used += this->datablobs_input_[i].count() * sizeof(Dtype); + for(int i = 0; idatablobs_output_.size(); i++) + mem_used += this->datablobs_output_[i].count() * sizeof(Dtype); + this->mem_used_ =mem_used; + + bool device_enough = false; + while (!device_enough) + { + BaseDevice* device; + if ((device = this->device_manager_->GetAvailableDevice()) != NULL) + { + this->devices_needed_.push_back(device); + if(mem_used < (device->GlobalMemory())) + { + device_enough = true; + } + else + { + mem_used -= device->GlobalMemory(); + } + } + else + { + break; + } + } + + cout << "Need Devices = "<< this->devices_needed_.size() <devices_needed_.size()); + //compute the weight for load banlancing + unsigned int height_each = 0; + unsigned int offset = 0; + int denum = this->devices_needed_.size(); + unsigned long sum_gl_mem = 0; + float left_part =0; + float weight[denum]; + + for(int i = 0; i devices_needed_[i]->GlobalMemory(); + } + + for(int i =0; idevices_needed_[i]->GlobalMemory())/((float)sum_gl_mem); + left_part += weight[i]; + } + weight[denum-1] = 1 - left_part; + + unsigned int height_current = 0; + + int i= 0; + for( i = 0; i < this->devices_needed_.size()-1;i++) + { + + height_each = sub_param_.height() * weight[i]; + height_current += height_each; + SubTask* temp_task = new SubTask(this->devices_needed_[i]); + temp_task->SetParams(1, height_each, sub_param_.width()); + //create inner + DataBlob inner; + inner.ReshapeLike(1, 1, height_each, sub_param_.width()); + inner.CopyFromMemory((void*)((this->datablobs_input_[0].host_data())+offset)); + temp_task->AddtoDatas(inner); + inner.CopyFromMemory((void*)((this->datablobs_input_[1].host_data())+offset)); + temp_task->AddtoDatas(inner); + //create inner_results + inner.CopyFromMemory((void*)((this->datablobs_output_[0].host_data())+offset)); + temp_task->AddtoResults(inner); + + this->ordinary_tasks_.push_back(temp_task); + cout<<"create " << i << " subtask" <* temp_task = new SubTask(this->devices_needed_[i]); + temp_task->SetParams(1, height_each, sub_param_.width()); + //create inner + DataBlob inner; + inner.ReshapeLike(1, 1, height_each, sub_param_.width()); + inner.CopyFromMemory((void*)((this->datablobs_input_[0].host_data())+offset)); + temp_task->AddtoDatas(inner); + inner.CopyFromMemory((void*)((this->datablobs_input_[1].host_data())+offset)); + temp_task->AddtoDatas(inner); + //create inner_results + inner.CopyFromMemory((void*)((this->datablobs_output_[0].host_data())+offset)); + temp_task->AddtoResults(inner); + + //attach the blob to the task + + this->ordinary_tasks_.push_back(temp_task); + + this->process_type_ = ORDINARY; + } + else + { + // stream task + //decide to compute on one or more devices + + //compute on the device with the biggest global memory + unsigned long max_glmem = 0; + BaseDevice* stream_device = NULL ; + for(int i = 0; i < this->devices_needed_.size(); i++) + { + if(max_glmem < this->devices_needed_[i]->GlobalMemory()) + { + stream_device = this->devices_needed_[i]; + max_glmem = this->devices_needed_[i]->GlobalMemory(); + } + } + cout<< "stream device determined"<::iterator l_it = this->devices_needed_.begin(); + + while(l_it != this->devices_needed_.end()) + { + if((*l_it) != stream_device) + { + this->device_manager_->FreeDevice((*l_it)->DeviceIndex()); + l_it = this->devices_needed_.erase(l_it); + } + else + { + l_it++; + } + } + + SubStreamTask* temp_adst = new SubStreamTask(stream_device); + temp_adst->SetParams(1, sub_param_.height(), sub_param_.width()); + temp_adst->AddtoDatas(this->datablobs_input_[0]); + temp_adst->AddtoDatas(this->datablobs_input_[1]); + temp_adst->AddtoResults(this->datablobs_output_[0]); + this->stream_tasks_.push_back(temp_adst); + + cout<< "create stream tasks" <process_type_ =STREAM; + + + } + +} + +template class SubTaskDispatcher; +template class SubTaskDispatcher; +template class SubTaskDispatcher; +template class SubTaskDispatcher; + diff --git a/src/get/TaskDispatcher.cpp b/src/get/TaskDispatcher.cpp index 2c88da9..68bc0c6 100644 --- a/src/get/TaskDispatcher.cpp +++ b/src/get/TaskDispatcher.cpp @@ -10,7 +10,7 @@ using namespace std; #include"TaskDispatcher.h" #include #include - +#include template TaskDispatcher:: ~TaskDispatcher() @@ -90,14 +90,23 @@ template void TaskDispatcher::ComputeAll() { + int thread_error; switch (process_type_) { case ORDINARY : { - for(int i = 0; i < ordinary_tasks_.size(); i++) + int tasks_num = ordinary_tasks_.size(); + pthread_t tid[tasks_num]; + for(int i = 0; i < tasks_num; i++) { cout<<"task "<< i << " begin" <TaskOn(); + thread_error = pthread_create(&tid[i], NULL, BaseTaskOn,(void*)ordinary_tasks_[i]); + + } + + for(int i = 0; i < tasks_num; i++ ) + { + pthread_join(tid[i], NULL); } break; } diff --git a/src/get/TaskManager.cpp b/src/get/TaskManager.cpp index d0b67c9..ed1fe1a 100644 --- a/src/get/TaskManager.cpp +++ b/src/get/TaskManager.cpp @@ -8,16 +8,38 @@ #include using namespace std; + +template +TaskDispacher* +TaskManager::GetTaskDispatcher(GET::TaskParam param) +{ + GET::TaskParam_TaskType type = param.type(); + switch(type) + { + case TaskParam_TaskType_ADD : return new AddTaskDispatcher(param, &device_manager_); + case TaskParam_TaskType_SUB : return new SubTaskDispatcher(param, &device_manager_); + case TaskParam_TaskType_MULTI : return new MulTaskDispatcher(param, &device_manager_); + case TaskParam_TaskType_CONVOLUTION : return new ConvTaskDispatcher(param, &device_manager_); + case TaskParam_TaskType_POOL : return new PoolTaskDispatcher(param, &device_manager_); + default: + { + cout<<"Unknown Computing Type"< int TaskManager::TaskRequestLocal(GET::TaskParam param) { - TaskDispatcher task(param, &device_manager_); + TaskDispatcher* task = GetTaskDispatcher(); tasks_.insert(make_pair(++task_num_, task)); //tasks_params_.insert(make_pair(task_num_, param)); tasks_status_.insert(make_pair(task_num_, TASKWAIT)); //send the task to threadpool if there is any available thread cout << "TaskRequest get" < int TaskManager::TaskRequestLocal(const char* filename) { - + } template @@ -54,6 +76,12 @@ TaskManager::Init() //Init the main Epoll } +template +int +TaskManager::TaskOn() +{ +} + template class TaskManager; template class TaskManager; template class TaskManager; diff --git a/src/get/add_task.cl b/src/get/add_task_double.cl similarity index 100% rename from src/get/add_task.cl rename to src/get/add_task_double.cl diff --git a/src/get/add_task_float.cl b/src/get/add_task_float.cl new file mode 100644 index 0000000..d4f2e8a --- /dev/null +++ b/src/get/add_task_float.cl @@ -0,0 +1,28 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel +void addMatrix(__global float* OutputC, + int width, + int height, + unsigned int divT_r, + unsigned int divT_c, + __global float* InputA, + __global float* InputB) +{ + int col_t = get_global_id(1)*divT_c; + int row_t = get_global_id(0)*divT_r; + + if ((row_t < height)&&(col_t < width)) + { + for(int row = row_t; row < row_t+divT_r; row++) + { + for (int col = col_t; col< col_t+divT_c; col++) + { + if ((row < height)&&(col < width)) + { + OutputC[row*width+col] = InputA[row*width+col]+InputB[row*width+col]; + } + } + } + } + +} diff --git a/src/get/add_task_int.cl b/src/get/add_task_int.cl new file mode 100644 index 0000000..89cdad2 --- /dev/null +++ b/src/get/add_task_int.cl @@ -0,0 +1,28 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel +void addMatrix(__global int* OutputC, + int width, + int height, + unsigned int divT_r, + unsigned int divT_c, + __global int* InputA, + __global int* InputB) +{ + int col_t = get_global_id(1)*divT_c; + int row_t = get_global_id(0)*divT_r; + + if ((row_t < height)&&(col_t < width)) + { + for(int row = row_t; row < row_t+divT_r; row++) + { + for (int col = col_t; col< col_t+divT_c; col++) + { + if ((row < height)&&(col < width)) + { + OutputC[row*width+col] = InputA[row*width+col]+InputB[row*width+col]; + } + } + } + } + +} diff --git a/src/get/add_task_long.cl b/src/get/add_task_long.cl new file mode 100644 index 0000000..ab760c9 --- /dev/null +++ b/src/get/add_task_long.cl @@ -0,0 +1,28 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel +void addMatrix(__global long* OutputC, + int width, + int height, + unsigned int divT_r, + unsigned int divT_c, + __global long* InputA, + __global long* InputB) +{ + int col_t = get_global_id(1)*divT_c; + int row_t = get_global_id(0)*divT_r; + + if ((row_t < height)&&(col_t < width)) + { + for(int row = row_t; row < row_t+divT_r; row++) + { + for (int col = col_t; col< col_t+divT_c; col++) + { + if ((row < height)&&(col < width)) + { + OutputC[row*width+col] = InputA[row*width+col]+InputB[row*width+col]; + } + } + } + } + +} diff --git a/src/get/conv_task_double.cl b/src/get/conv_task_double.cl new file mode 100644 index 0000000..a2911bf --- /dev/null +++ b/src/get/conv_task_double.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel +void convMatrix( + __global double* imageIn, + __global double* imageOut, + __constant double* filter, + unsigned int rowsout, + unsigned int colsout, + unsigned int rowsin, + unsigned int colsin, + unsigned int filterWidth, + __local double* localImage, + unsigned int localHeight, + unsigned int localWidth) +{ + int filterRadius = (filterWidth/2); + int padding = filterRadius * 2; + + int groupStartCol = get_group_id(1)*get_local_size(1); + int groupStartRow = get_group_id(0)*get_local_size(0); + + int localCol = get_local_id(1); + int localRow = get_local_id(0); + + int globalCol = groupStartCol + localCol; + int globalRow = groupStartRow + localRow; + + for(int i = localRow; i < localHeight; i += get_local_size(0)) + { + int curRow = groupStartRow+i; + for(int j = localCol; j < localWidth; j += get_local_size(1)) + { + int curCol = groupStartCol+j; + if(curRow < rowsin && curCol < colsin) + { + localImage[i*localWidth + j] = imageIn[curRow*colsin+curCol]; + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + + if(globalRow < rowsout && globalCol < colsout) + { + float sum = 0.0f; + int filterIdx = 0; + for(int i = localRow; i < localRow+filterWidth; i++) + { + int offset = i*localWidth; + for(int j = localCol; j < localCol+filterWidth; j++) + { + sum += localImage[offset+j] * filter[filterIdx++]; + } + } + imageOut[(globalRow)*colsout +(globalCol)] = sum; + } + +} diff --git a/src/get/conv_task_float.cl b/src/get/conv_task_float.cl new file mode 100644 index 0000000..0464fb3 --- /dev/null +++ b/src/get/conv_task_float.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel +void convMatrix( + __global float* imageIn, + __global float* imageOut, + __constant float* filter, + unsigned int rowsout, + unsigned int colsout, + unsigned int rowsin, + unsigned int colsin, + unsigned int filterWidth, + __local float* localImage, + unsigned int localHeight, + unsigned int localWidth) +{ + int filterRadius = (filterWidth/2); + int padding = filterRadius * 2; + + int groupStartCol = get_group_id(1)*get_local_size(1); + int groupStartRow = get_group_id(0)*get_local_size(0); + + int localCol = get_local_id(1); + int localRow = get_local_id(0); + + int globalCol = groupStartCol + localCol; + int globalRow = groupStartRow + localRow; + + for(int i = localRow; i < localHeight; i += get_local_size(0)) + { + int curRow = groupStartRow+i; + for(int j = localCol; j < localWidth; j += get_local_size(1)) + { + int curCol = groupStartCol+j; + if(curRow < rowsin && curCol < colsin) + { + localImage[i*localWidth + j] = imageIn[curRow*colsin+curCol]; + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + + if(globalRow < rowsout && globalCol < colsout) + { + float sum = 0.0f; + int filterIdx = 0; + for(int i = localRow; i < localRow+filterWidth; i++) + { + int offset = i*localWidth; + for(int j = localCol; j < localCol+filterWidth; j++) + { + sum += localImage[offset+j] * filter[filterIdx++]; + } + } + imageOut[(globalRow)*colsout +(globalCol)] = sum; + } + +} diff --git a/src/get/conv_task_int.cl b/src/get/conv_task_int.cl new file mode 100644 index 0000000..668a2ab --- /dev/null +++ b/src/get/conv_task_int.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel +void convMatrix( + __global int* imageIn, + __global int* imageOut, + __constant int* filter, + unsigned int rowsout, + unsigned int colsout, + unsigned int rowsin, + unsigned int colsin, + unsigned int filterWidth, + __local int* localImage, + unsigned int localHeight, + unsigned int localWidth) +{ + int filterRadius = (filterWidth/2); + int padding = filterRadius * 2; + + int groupStartCol = get_group_id(1)*get_local_size(1); + int groupStartRow = get_group_id(0)*get_local_size(0); + + int localCol = get_local_id(1); + int localRow = get_local_id(0); + + int globalCol = groupStartCol + localCol; + int globalRow = groupStartRow + localRow; + + for(int i = localRow; i < localHeight; i += get_local_size(0)) + { + int curRow = groupStartRow+i; + for(int j = localCol; j < localWidth; j += get_local_size(1)) + { + int curCol = groupStartCol+j; + if(curRow < rowsin && curCol < colsin) + { + localImage[i*localWidth + j] = imageIn[curRow*colsin+curCol]; + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + + if(globalRow < rowsout && globalCol < colsout) + { + float sum = 0.0f; + int filterIdx = 0; + for(int i = localRow; i < localRow+filterWidth; i++) + { + int offset = i*localWidth; + for(int j = localCol; j < localCol+filterWidth; j++) + { + sum += localImage[offset+j] * filter[filterIdx++]; + } + } + imageOut[(globalRow)*colsout +(globalCol)] = sum; + } + +} diff --git a/src/get/conv_task_long.cl b/src/get/conv_task_long.cl new file mode 100644 index 0000000..1f8f242 --- /dev/null +++ b/src/get/conv_task_long.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel +void convMatrix( + __global long* imageIn, + __global long* imageOut, + __constant long* filter, + unsigned int rowsout, + unsigned int colsout, + unsigned int rowsin, + unsigned int colsin, + unsigned int filterWidth, + __local long* localImage, + unsigned int localHeight, + unsigned int localWidth) +{ + int filterRadius = (filterWidth/2); + int padding = filterRadius * 2; + + int groupStartCol = get_group_id(1)*get_local_size(1); + int groupStartRow = get_group_id(0)*get_local_size(0); + + int localCol = get_local_id(1); + int localRow = get_local_id(0); + + int globalCol = groupStartCol + localCol; + int globalRow = groupStartRow + localRow; + + for(int i = localRow; i < localHeight; i += get_local_size(0)) + { + int curRow = groupStartRow+i; + for(int j = localCol; j < localWidth; j += get_local_size(1)) + { + int curCol = groupStartCol+j; + if(curRow < rowsin && curCol < colsin) + { + localImage[i*localWidth + j] = imageIn[curRow*colsin+curCol]; + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + + if(globalRow < rowsout && globalCol < colsout) + { + float sum = 0.0f; + int filterIdx = 0; + for(int i = localRow; i < localRow+filterWidth; i++) + { + int offset = i*localWidth; + for(int j = localCol; j < localCol+filterWidth; j++) + { + sum += localImage[offset+j] * filter[filterIdx++]; + } + } + imageOut[(globalRow)*colsout +(globalCol)] = sum; + } + +} diff --git a/src/get/mul_task_double.cl b/src/get/mul_task_double.cl new file mode 100644 index 0000000..4409c37 --- /dev/null +++ b/src/get/mul_task_double.cl @@ -0,0 +1,36 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel +void mulMatrix( + __global double* InputA, + __global double* InputB, + __global double* OutputC, + unsigned int M, + unsigned int K, + unsigned int N, + unsigned int divT_r, + unsigned int divT_c) +{ + int col_t = get_global_id(1)*divT_c; + int row_t = get_global_id(0)*divT_r; + double sum = 0.0f; + + if ((row_t < M)&&(col_t < N)) + { + for(int row = row_t; row < row_t+divT_r; row++) + { + for (int col = col_t; col< col_t+divT_c; col++) + { + if ((row < M)&&(col #include using namespace std; -#include "AddTaskDispatcher.h" +#include "TaskDispatcher.h" #include #include int main(int argc, char** argv) diff --git a/src/test/test_convdispatcher.cpp b/src/test/test_convdispatcher.cpp new file mode 100644 index 0000000..9e8673c --- /dev/null +++ b/src/test/test_convdispatcher.cpp @@ -0,0 +1,98 @@ +/************************************************************************* + > File Name: test_adddispatcher.cpp + > Author: + > Mail: + > Created Time: Wed 15 Apr 2015 10:32:17 PM EDT + ************************************************************************/ + +#include +#include +#include +using namespace std; +#include "TaskDispatcher.h" +#include +#include +int main(int argc, char** argv) +{ + DeviceManager dm; + + dm.Init(); + + cout << "Number of Devices = "<(in); + uint64_t b = reinterpret_cast(f); + uint64_t c = reinterpret_cast(out); + + + task_param.add_sourcem(a); + task_param.add_sourcem(b); + task_param.add_resultm(c); + + + cout<< "SOURCE num = "<set_data_h(n); + pConvParam->set_data_w(n); + pConvParam->set_filter_h(3); + pConvParam->set_filter_w(3); + + + ConvTaskDispatcher ctd(task_param, &dm); + + ctd.TaskOn(); +/* + int count = 0; + int answer = 0 ; + for (int i = 0; i < K ; i++) + answer += A[i]*B[i]; + for(int i = 0; i File Name: test_addtask.cpp + > Author: + > Mail: + > Created Time: Tue 14 Apr 2015 02:43:26 AM EDT + ************************************************************************/ + +#include +#include +using namespace std; +#include"BaseTask.h" +#include"StreamTask.h" +#include"DeviceManager.h" +#include + +int +main(int argc ,char** argv) +{ + DeviceManager dm; + + dm.Init(); + + BaseDevice* device; + + device = dm.GetAvailableDevice(); + + ConvStreamTask task(device); + + int n = atoi(argv[1]); + int f = 3; + int inh = n; + int inw = n; + int outh = inh - f + 1; + int outw = inw - f + 1; + + int* in = (int*)malloc(inh*inw*sizeof(int)); + int filter[9] = { + 1 , 1 , 1 , + 1 , 1 , 1 , + 1 , 1 , 1 + }; + int* out = (int*)malloc(outh*outw*sizeof(int)); + + + + for(int i = 0; i < inh*inw ; i++) + { + in[i] = 1 ; + } + for(int i = 0; i < outh*outw; i++) + { + out[i] = 0; + } + + DataBlob data[3]; + + data[0].ReshapeLike(1 , 1, inh, inw); + data[1].ReshapeLike(1 , 1, f , f); + data[2].ReshapeLike(1 , 1, outh, outw); + + data[0].CopyFromMemory((void *)in); + data[1].CopyFromMemory((void *)filter); + data[2].CopyFromMemory((void *)out); + + task.SetParams(1, inh , inw, f, f, 1, 1, 0, 0); + task.AddtoDatas(data[0]); + task.AddtoDatas(data[1]); + task.AddtoResults(data[2]); + + task.TaskOn(); + + cout<<"task finished"< File Name: test_addtask.cpp + > Author: + > Mail: + > Created Time: Tue 14 Apr 2015 02:43:26 AM EDT + ************************************************************************/ + +#include +#include +using namespace std; +#include"BaseTask.h" +#include"DeviceManager.h" +#include + +int +main(int argc ,char** argv) +{ + DeviceManager dm; + + dm.Init(); + + BaseDevice* device; + + device = dm.GetAvailableDevice(); + + ConvTask task(device); + + int n = atoi(argv[1]); + int f = 3; + int inh = n; + int inw = n; + int outh = inh - f + 1; + int outw = inw - f + 1; + + int* in = (int*)malloc(inh*inw*sizeof(int)); + int filter[9] = { + 0 , 0 , 0 , + 0 , 1 , 0 , + 0 , 0 , 0 + }; + int* out = (int*)malloc(outh*outw*sizeof(int)); + + + + for(int i = 0; i < inh*inw ; i++) + { + if ( i % 2 == 0 ) + in[i] = 1 ; + else + in[i] = 2 ; + } + for(int i = 0; i < outh*outw; i++) + { + out[i] = 0; + } + + DataBlob data[3]; + + data[0].ReshapeLike(1 , 1, inh, inw); + data[1].ReshapeLike(1 , 1, f , f); + data[2].ReshapeLike(1 , 1, outh, outw); + + data[0].CopyFromMemory((void *)in); + data[1].CopyFromMemory((void *)filter); + data[2].CopyFromMemory((void *)out); + + task.SetParams(1, inh , inw, f, f, 1, 1, 0, 0); + task.AddtoDatas(data[0]); + task.AddtoDatas(data[1]); + task.AddtoResults(data[2]); + + pthread_t tid; + pthread_create(&tid, NULL, BaseTaskOn,(void*)&task); + pthread_join(tid,NULL); + cout<<"task finished"< File Name: test_adddispatcher.cpp + > Author: + > Mail: + > Created Time: Wed 15 Apr 2015 10:32:17 PM EDT + ************************************************************************/ + +#include +#include +#include +using namespace std; +#include "TaskDispatcher.h" +#include +#include +int main(int argc, char** argv) +{ + DeviceManager dm; + + dm.Init(); + + cout << "Number of Devices = "<(A); + uint64_t b = reinterpret_cast(B); + uint64_t c = reinterpret_cast(C); + + + task_param.add_sourcem(a); + task_param.add_sourcem(b); + task_param.add_resultm(c); + + + cout<< "SOURCE num = "<set_m(M); + pMulParam->set_k(K); + pMulParam->set_n(N); + + MulTaskDispatcher mtd(task_param, &dm); + + mtd.TaskOn(); + + int count = 0; + int answer = 0 ; + for (int i = 0; i < K ; i++) + answer += A[i]*B[i]; + for(int i = 0; i File Name: test_addstream.cpp + > Author: + > Mail: + > Created Time: Fri 24 Apr 2015 11:36:19 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"StreamTask.h" +#include"DeviceManager.h" +#include"DataBlob.h" +#include +#include + +int +main(int main , char** argv) +{ + int n = atoi(argv[1]); + int elements = n*n; + + int* A = (int *)malloc(elements*sizeof(int)); + int* B = (int *)malloc(elements*sizeof(int)); + int* C = (int *)malloc(elements*sizeof(int)); + + cout << "A address : "< data[3]; + + for(int i = 0 ; i < 3 ;i++) + data[i].ReshapeLike(1,1,n,n); + + data[0].CopyFromMemory((void*)A); + data[1].CopyFromMemory((void*)B); + data[2].CopyFromMemory((void*)C); + + BaseDevice* device = dm.GetAvailableDevice(); + + MulStreamTask mulst(device); + + mulst.SetParams(1, n, n, n); + mulst.AddtoDatas(data[0]); + mulst.AddtoDatas(data[1]); + mulst.AddtoResults(data[2]); + + mulst.TaskOn(); + int cnt = 0; + + int answer = 0; + for(int i = 0; i< n ;i++) + answer+= A[i]*B[i]; + for(int i = 0; i < elements ; i++) + { + if (C[i] != answer) + { + cnt++; + if (cnt % 50 == 0) + cout< File Name: test_addtask.cpp + > Author: + > Mail: + > Created Time: Tue 14 Apr 2015 02:43:26 AM EDT + ************************************************************************/ + +#include +#include +using namespace std; +#include"BaseTask.h" +#include"DeviceManager.h" +#include + +int +main(int argc ,char** argv) +{ + DeviceManager dm; + + dm.Init(); + + BaseDevice* device; + + device = dm.GetAvailableDevice(); + + MulTask task(device); + + int n = atoi(argv[1]); + + int h = n; + int w = n; + + int* A = (int*)malloc(h*w*sizeof(int)); + int* B = (int*)malloc(h*w*sizeof(int)); + int* C = (int*)malloc(h*w*sizeof(int)); + + for(int i = 0; i < h*w ; i++) + { + A[i] = 2; + B[i] = 1; + C[i] = 0; + } + + DataBlob data[3]; + + for(int i = 0; i < 3 ; i++) + data[i].ReshapeLike(1 , 1, h, w); + + data[0].CopyFromMemory((void *)A); + data[1].CopyFromMemory((void *)B); + data[2].CopyFromMemory((void *)C); + + task.SetParams(1, n, n ,n); + task.AddtoDatas(data[0]); + task.AddtoDatas(data[1]); + task.AddtoResults(data[2]); + + pthread_t tid; + pthread_create(&tid, NULL, BaseTaskOn,(void*)&task); + pthread_join(tid,NULL); + cout<<"task finished"< File Name: test_adddispatcher.cpp + > Author: + > Mail: + > Created Time: Wed 15 Apr 2015 10:32:17 PM EDT + ************************************************************************/ + +#include +#include +#include +using namespace std; +#include "TaskDispatcher.h" +#include +#include +int main(int argc, char** argv) +{ + DeviceManager dm; + + dm.Init(); + + cout << "Number of Devices = "<(in); + uint64_t b = reinterpret_cast(out); + + + task_param.add_sourcem(a); + task_param.add_resultm(b); + + + cout<< "SOURCE num = "<set_data_h(h); + pPoolParam->set_data_w(w); + pPoolParam->set_kernel_h(4); + pPoolParam->set_kernel_w(4); + pPoolParam->set_stride_h(4); + pPoolParam->set_stride_w(4); + + PoolTaskDispatcher rlu(task_param, &dm); + + rlu.TaskOn(); + + int count = 0; + + if (count == 0) + cout << "Right Answer" < File Name: test_addstream.cpp + > Author: + > Mail: + > Created Time: Fri 24 Apr 2015 11:36:19 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"StreamTask.h" +#include"DeviceManager.h" +#include"DataBlob.h" +#include +#include + +int +main(int main , char** argv) +{ + int n = atoi(argv[1]); + int elements = n*n; + + int* IN = (int *)malloc(elements*sizeof(int)); + int* OUT = (int *)malloc(elements*sizeof(int)); + + + for(int i = 0; i < elements ; i++) + { + IN[i] = 1; + OUT[i] = 0; + } + + DeviceManager dm; + dm.Init(); + + DataBlob data[2]; + + for(int i = 0 ; i < 2 ;i++) + data[i].ReshapeLike(1,1,n,n); + + data[0].CopyFromMemory((void*)IN); + data[1].CopyFromMemory((void*)OUT); + + BaseDevice* device = dm.GetAvailableDevice(); + + PoolStreamTask relust(device); + + relust.SetParams(1, n, n, 4, 4, 4, 4, 0 , 0); + relust.AddtoDatas(data[0]); + relust.AddtoResults(data[1]); + + relust.TaskOn(); + + cout<<"Right Answer "< File Name: test_addtask.cpp + > Author: + > Mail: + > Created Time: Tue 14 Apr 2015 02:43:26 AM EDT + ************************************************************************/ + +#include +#include +using namespace std; +#include"BaseTask.h" +#include"DeviceManager.h" +#include + +int +main(int argc ,char** argv) +{ + DeviceManager dm; + + dm.Init(); + + BaseDevice* device; + + device = dm.GetAvailableDevice(); + + PoolTask task(device); + + int n = atoi(argv[1]); + + int h = n; + int w = n; + int outh = h / 4; + int outw = w / 4; + + int* IN = (int*)malloc(h*w*sizeof(int)); + int* OUT = (int*)malloc(h*w*sizeof(int)); + + for(int i = 0; i < h*w ; i++) + { + if (i % 2 == 0) + IN[i] = 1; + else + IN[i] = 3 ; + } + + for (int i = 0; i< outh*outw; i++) + { + OUT[i] = 0; + } + DataBlob data[2]; + + for(int i = 0; i < 2 ; i++) + data[i].ReshapeLike(1 , 1, h, w); + + data[0].CopyFromMemory((void *)IN); + data[1].CopyFromMemory((void *)OUT); + + task.SetParams(1, h, w, 4, 4 ,4, 4, 0, 0); + task.AddtoDatas((data[0])); + task.AddtoResults((data[1])); + + + pthread_t tid; + pthread_create(&tid, NULL, BaseTaskOn,(void*)&task); + pthread_join(tid,NULL); + + for(int i =0 ; i < outh; i++) + { + for(int j = 0; j < outw ; j++) + cout<< OUT[i * outw + j] << " "; + cout << endl; + } + + cout<<"task finished"< File Name: test_adddispatcher.cpp + > Author: + > Mail: + > Created Time: Wed 15 Apr 2015 10:32:17 PM EDT + ************************************************************************/ + +#include +#include +#include +using namespace std; +#include "TaskDispatcher.h" +#include +#include +int main(int argc, char** argv) +{ + DeviceManager dm; + + dm.Init(); + + cout << "Number of Devices = "<(in); + uint64_t b = reinterpret_cast(out); + + + task_param.add_sourcem(a); + task_param.add_resultm(b); + + + cout<< "SOURCE num = "<set_height(h); + pReLUParam->set_width(w); + + ReLUTaskDispatcher rlu(task_param, &dm); + + rlu.TaskOn(); + + int count = 0; + + if (count == 0) + cout << "Right Answer" < File Name: test_addstream.cpp + > Author: + > Mail: + > Created Time: Fri 24 Apr 2015 11:36:19 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"StreamTask.h" +#include"DeviceManager.h" +#include"DataBlob.h" +#include +#include + +int +main(int main , char** argv) +{ + int n = atoi(argv[1]); + int elements = n*n; + + int* IN = (int *)malloc(elements*sizeof(int)); + int* OUT = (int *)malloc(elements*sizeof(int)); + + + for(int i = 0; i < elements ; i++) + { + IN[i] = 1; + OUT[i] = 0; + } + + DeviceManager dm; + dm.Init(); + + DataBlob data[2]; + + for(int i = 0 ; i < 2 ;i++) + data[i].ReshapeLike(1,1,n,n); + + data[0].CopyFromMemory((void*)IN); + data[1].CopyFromMemory((void*)OUT); + + BaseDevice* device = dm.GetAvailableDevice(); + + ReLUStreamTask relust(device); + + relust.SetParams(1, n, n); + relust.AddtoDatas(data[0]); + relust.AddtoResults(data[1]); + + relust.TaskOn(); + + cout<<"Right Answer "< File Name: test_addtask.cpp + > Author: + > Mail: + > Created Time: Tue 14 Apr 2015 02:43:26 AM EDT + ************************************************************************/ + +#include +#include +using namespace std; +#include"BaseTask.h" +#include"DeviceManager.h" +#include + +int +main(int argc ,char** argv) +{ + DeviceManager dm; + + dm.Init(); + + BaseDevice* device; + + device = dm.GetAvailableDevice(); + + ReLUTask task(device); + + int n = atoi(argv[1]); + + int h = n; + int w = n; + + int* IN = (int*)malloc(h*w*sizeof(int)); + int* OUT = (int*)malloc(h*w*sizeof(int)); + + for(int i = 0; i < h*w ; i++) + { + IN[i] = 1; + OUT[i] = 0; + } + + DataBlob data[2]; + + for(int i = 0; i < 2 ; i++) + data[i].ReshapeLike(1 , 1, h, w); + + data[0].CopyFromMemory((void *)IN); + data[1].CopyFromMemory((void *)OUT); + + task.SetParams(1, h, w); + task.AddtoDatas((data[0])); + task.AddtoResults((data[1])); + + + pthread_t tid; + pthread_create(&tid, NULL, BaseTaskOn,(void*)&task); + pthread_join(tid,NULL); + + cout<<"task finished"< File Name: test_adddispatcher.cpp + > Author: + > Mail: + > Created Time: Wed 15 Apr 2015 10:32:17 PM EDT + ************************************************************************/ + +#include +#include +#include +using namespace std; +#include "TaskDispatcher.h" +#include +#include +int main(int argc, char** argv) +{ + DeviceManager dm; + + dm.Init(); + + cout << "Number of Devices = "<(in); + uint64_t b = reinterpret_cast(out); + + + task_param.add_sourcem(a); + task_param.add_resultm(b); + + + cout<< "SOURCE num = "<set_height(h); + pSigmoidParam->set_width(w); + + SigmoidTaskDispatcher rlu(task_param, &dm); + + rlu.TaskOn(); + + int count = 0; + + if (count == 0) + cout << "Right Answer" < File Name: test_addstream.cpp + > Author: + > Mail: + > Created Time: Fri 24 Apr 2015 11:36:19 AM EDT + ************************************************************************/ + +#include +using namespace std; +#include"StreamTask.h" +#include"DeviceManager.h" +#include"DataBlob.h" +#include +#include + +int +main(int main , char** argv) +{ + int n = atoi(argv[1]); + int elements = n*n; + + int* IN = (int *)malloc(elements*sizeof(int)); + int* OUT = (int *)malloc(elements*sizeof(int)); + + + for(int i = 0; i < elements ; i++) + { + IN[i] = 1; + OUT[i] = 0; + } + + DeviceManager dm; + dm.Init(); + + DataBlob data[2]; + + for(int i = 0 ; i < 2 ;i++) + data[i].ReshapeLike(1,1,n,n); + + data[0].CopyFromMemory((void*)IN); + data[1].CopyFromMemory((void*)OUT); + + BaseDevice* device = dm.GetAvailableDevice(); + + SigmoidStreamTask relust(device); + + relust.SetParams(1, n, n); + relust.AddtoDatas(data[0]); + relust.AddtoResults(data[1]); + + relust.TaskOn(); + + cout<<"Right Answer "< File Name: test_addtask.cpp + > Author: + > Mail: + > Created Time: Tue 14 Apr 2015 02:43:26 AM EDT + ************************************************************************/ + +#include +#include +using namespace std; +#include"BaseTask.h" +#include"DeviceManager.h" +#include + +int +main(int argc ,char** argv) +{ + DeviceManager dm; + + dm.Init(); + + BaseDevice* device; + + device = dm.GetAvailableDevice(); + + SigmoidTask task(device); + + int n = atoi(argv[1]); + + int h = n; + int w = n; + + int* IN = (int*)malloc(h*w*sizeof(int)); + int* OUT = (int*)malloc(h*w*sizeof(int)); + + for(int i = 0; i < h*w ; i++) + { + IN[i] = 1; + OUT[i] = 0; + } + + DataBlob data[2]; + + for(int i = 0; i < 2 ; i++) + data[i].ReshapeLike(1 , 1, h, w); + + data[0].CopyFromMemory((void *)IN); + data[1].CopyFromMemory((void *)OUT); + + task.SetParams(1, h, w); + task.AddtoDatas((data[0])); + task.AddtoResults((data[1])); + + + pthread_t tid; + pthread_create(&tid, NULL, BaseTaskOn,(void*)&task); + pthread_join(tid,NULL); + + cout<<"task finished"< using namespace std; #include"BaseTask.h" -#include"CommonTask.h" #include"DeviceManager.h" int @@ -55,7 +54,7 @@ main(int argc ,char** argv) task.AddtoDatas(data[1]); task.AddtoResults(data[2]); - task.TaskOn(); + BaseTaskOn((void *)&task); cout<<"task finished"<