@@ -27,12 +27,12 @@ aclDataType type_mapping(ggml_type type) {
2727 * Transform ggml_tensor to acl_tensor. Note that ggml_tensor dimension order
2828 * is reversed compared to acl_tensor.
2929 *
30- * If bcast_ne and bcast_stride is nullptr, use ggml_tensor's ne and nb.
31- * otherwise, use bcast_ne bcast_stride , which means tensor dims should be
30+ * If bcast_ne and bcast_nb are nullptr, use ggml_tensor's ne and nb.
31+ * Otherwise, use bcast_ne and bcast_nb, which means tensor dims should be
3232 * changed to satisfy the broadcast. @sa: get_bcast_shape.
3333 */
3434aclTensor* create_acl_tensor (const ggml_tensor* tensor, int64_t * bcast_ne,
35- int64_t * bcast_stride , int64_t bcast_dims) {
35+ size_t * bcast_nb , int64_t bcast_dims, aclFormat format ) {
3636 size_t size = ggml_nbytes (tensor);
3737 void * deviceAddr = nullptr ;
3838
@@ -53,13 +53,13 @@ aclTensor* create_acl_tensor(const ggml_tensor* tensor, int64_t* bcast_ne,
5353 for (int i = 0 ; i < GGML_MAX_DIMS; i++) {
5454 acl_ne[i] = tensor->ne [i];
5555 // The step size of acl is in elements.
56- acl_stride[i] = tensor->nb [i] / tensor->nb [ 0 ] ;
56+ acl_stride[i] = tensor->nb [i] / ggml_type_size ( tensor->type ) ;
5757 }
5858 } else {
5959 // With bcast
6060 for (int i = 0 ; i < bcast_dims; i++) {
6161 acl_ne[i] = bcast_ne[i];
62- acl_stride[i] = bcast_stride [i] / tensor->nb [ 0 ] ;
62+ acl_stride[i] = bcast_nb [i] / ggml_type_size ( tensor->type ) ;
6363 }
6464 }
6565
@@ -69,13 +69,13 @@ aclTensor* create_acl_tensor(const ggml_tensor* tensor, int64_t* bcast_ne,
6969
7070 aclTensor* acl_tensor =
7171 aclCreateTensor (acl_ne, dims, type_mapping (tensor->type ), acl_stride, 0 ,
72- aclFormat::ACL_FORMAT_ND , acl_ne, dims, deviceAddr);
72+ format , acl_ne, dims, deviceAddr);
7373
7474 return acl_tensor;
7575}
7676
7777aclTensor* create_acl_tensor (void * data_ptr, aclDataType dtype, size_t type_size, int64_t * ne,
78- size_t * nb, int64_t dims) {
78+ size_t * nb, int64_t dims, aclFormat format ) {
7979
8080 int64_t tmp_ne[GGML_MAX_DIMS * 2 ];
8181 int64_t tmp_stride[GGML_MAX_DIMS * 2 ];
@@ -90,7 +90,7 @@ aclTensor* create_acl_tensor(void* data_ptr, aclDataType dtype, size_t type_size
9090
9191 aclTensor* acl_tensor =
9292 aclCreateTensor (tmp_ne, dims, dtype, tmp_stride, 0 ,
93- aclFormat::ACL_FORMAT_ND , tmp_ne, dims, data_ptr);
93+ format , tmp_ne, dims, data_ptr);
9494
9595 return acl_tensor;
9696}
@@ -132,26 +132,26 @@ aclTensor* create_acl_tensor(void* data_ptr, aclDataType dtype, size_t type_size
132132 */
133133int64_t get_bcast_shape (const ggml_tensor* src0, const ggml_tensor* src1,
134134 int64_t * bcast_ne_src0, int64_t * bcast_ne_src1,
135- int64_t * bcast_stride_src0 ,
136- int64_t * bcast_stride_src1 ) {
135+ size_t * bcast_nb_src0 ,
136+ size_t * bcast_nb_src1 ) {
137137 GGML_ASSERT (ggml_can_repeat (src1, src0));
138138 int bcast_dim_cnt = 0 ;
139139 for (int i = 0 ; i < GGML_MAX_DIMS; i++) {
140140 int64_t nr = src0->ne [i] / src1->ne [i];
141141 bcast_ne_src0[bcast_dim_cnt] = src0->ne [i] / nr;
142142 bcast_ne_src1[bcast_dim_cnt] = src1->ne [i];
143- bcast_stride_src0 [bcast_dim_cnt] = src0->nb [i];
144- bcast_stride_src1 [bcast_dim_cnt] = src1->nb [i];
143+ bcast_nb_src0 [bcast_dim_cnt] = src0->nb [i];
144+ bcast_nb_src1 [bcast_dim_cnt] = src1->nb [i];
145145 bcast_dim_cnt++;
146146 if (nr != 1 ) {
147147 // Need to add an extra dim.
148148 bcast_ne_src0[bcast_dim_cnt] = nr;
149149 bcast_ne_src1[bcast_dim_cnt] = 1 ;
150- bcast_stride_src0 [bcast_dim_cnt] =
151- bcast_stride_src0 [bcast_dim_cnt - 1 ] *
150+ bcast_nb_src0 [bcast_dim_cnt] =
151+ bcast_nb_src0 [bcast_dim_cnt - 1 ] *
152152 bcast_ne_src0[bcast_dim_cnt - 1 ];
153- bcast_stride_src1 [bcast_dim_cnt] =
154- bcast_stride_src1 [bcast_dim_cnt - 1 ] *
153+ bcast_nb_src1 [bcast_dim_cnt] =
154+ bcast_nb_src1 [bcast_dim_cnt - 1 ] *
155155 bcast_ne_src1[bcast_dim_cnt - 1 ];
156156 bcast_dim_cnt++;
157157 }
0 commit comments