Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion musa_ext/kernels/array/musa_empty_tensor_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
if ((t.dtype() == DT_INT32 && t.scalar<int32_t>()() == -1) ||
(t.dtype() == DT_INT64 && t.scalar<int64_t>()() == -1)) {
*out = PartialTensorShape();
return Status::OK();
return ::tsl::OkStatus();
}
return errors::InvalidArgument(
"The only valid scalar shape tensor is the fully unknown shape "
Expand Down
2 changes: 1 addition & 1 deletion musa_ext/kernels/array/musa_fill_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ inline Status MusaFillCall(mTensor* out_mt, T value, OpKernelContext* context) {
return errors::Internal("mudnn run op error!");
}

return Status::OK();
return ::tsl::OkStatus();
}

struct SetZeroFunctor {
Expand Down
2 changes: 1 addition & 1 deletion musa_ext/kernels/array/musa_fill_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct is_any<T, First, Rest...>
// return errors::Internal("mtdnn run op error!");
// }

// return Status::OK();
// return ::tsl::OkStatus();
// }

} // namespace
Expand Down
22 changes: 11 additions & 11 deletions musa_ext/kernels/array/musa_tensordot_bias_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ REGISTER_OP("MusaTensorDotBias")

if (!c->RankKnown(a_shape) || !c->RankKnown(b_shape)) {
c->set_output(0, c->UnknownShape());
return Status::OK();
return ::tsl::OkStatus();
}

std::vector<int> axes_a, axes_b;
Expand Down Expand Up @@ -92,7 +92,7 @@ REGISTER_OP("MusaTensorDotBias")
}

c->set_output(0, c->MakeShape(output_dims));
return Status::OK();
return ::tsl::OkStatus();
});

// =============================================================================
Expand Down Expand Up @@ -198,7 +198,7 @@ Status ComputeTensorDotDims(const TensorShape& a_shape,
}
}

return Status::OK();
return ::tsl::OkStatus();
}

template <typename T>
Expand Down Expand Up @@ -345,7 +345,7 @@ class MusaTensorDotBiasOp : public MusaOpKernel {
static_cast<int>(status));
}

return Status::OK();
return ::tsl::OkStatus();
}

Status DoTensorDotBias(OpKernelContext* ctx, const Tensor& a, const Tensor& b,
Expand Down Expand Up @@ -391,11 +391,11 @@ class MusaTensorDotBiasOp : public MusaOpKernel {
if (!output->CopyFrom(matmul_temp, output_shape)) {
return errors::Internal("Failed to reshape matmul result to output");
}
return Status::OK();
return ::tsl::OkStatus();
}

TF_RETURN_IF_ERROR(DoMatMulWithBias(ctx, a_2d, b_2d, bias_2d, &matmul_view));
return Status::OK();
return ::tsl::OkStatus();
}

// 准备张量:transpose + reshape 为 2D
Expand Down Expand Up @@ -434,7 +434,7 @@ class MusaTensorDotBiasOp : public MusaOpKernel {
}
}

return Status::OK();
return ::tsl::OkStatus();
}

// 准备 bias 张量:reshape 为 1D 向量 (N 维度)
Expand All @@ -448,15 +448,15 @@ class MusaTensorDotBiasOp : public MusaOpKernel {
if (!output->CopyFrom(bias, bias.shape())) {
return errors::Internal("Failed to copy bias");
}
return Status::OK();
return ::tsl::OkStatus();
}

// 如果 bias 是标量,需要 broadcast 到向量
if (bias.NumElements() == 1) {
TensorShape target_shape({n_dim});
TF_RETURN_IF_ERROR(ctx->allocate_temp(bias.dtype(), target_shape, output));
// TODO: 实现标量到向量的 broadcast
return Status::OK();
return ::tsl::OkStatus();
}

// 如果 bias 是 2D [M, N],需要提取 N 维度作为向量
Expand All @@ -465,7 +465,7 @@ class MusaTensorDotBiasOp : public MusaOpKernel {
TensorShape target_shape({n_dim});
TF_RETURN_IF_ERROR(ctx->allocate_temp(bias.dtype(), target_shape, output));
// TODO: 实现从 2D 提取 1D 的逻辑
return Status::OK();
return ::tsl::OkStatus();
}

// 尝试直接 reshape
Expand All @@ -476,7 +476,7 @@ class MusaTensorDotBiasOp : public MusaOpKernel {
" to ", target_shape.DebugString());
}

return Status::OK();
return ::tsl::OkStatus();
}
};

Expand Down
14 changes: 7 additions & 7 deletions musa_ext/kernels/array/musa_tensordot_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ REGISTER_OP("MusaTensorDot")

if (!c->RankKnown(a_shape) || !c->RankKnown(b_shape)) {
c->set_output(0, c->UnknownShape());
return Status::OK();
return ::tsl::OkStatus();
}

std::vector<int> axes_a, axes_b;
Expand Down Expand Up @@ -90,7 +90,7 @@ REGISTER_OP("MusaTensorDot")
}

c->set_output(0, c->MakeShape(output_dims));
return Status::OK();
return ::tsl::OkStatus();
});

// =============================================================================
Expand Down Expand Up @@ -196,7 +196,7 @@ Status ComputeTensorDotDims(const TensorShape& a_shape,
}
}

return Status::OK();
return ::tsl::OkStatus();
}

template <typename T>
Expand Down Expand Up @@ -320,7 +320,7 @@ class MusaTensorDotOp : public MusaOpKernel {
static_cast<int>(status));
}

return Status::OK();
return ::tsl::OkStatus();
}

Status DoTensorDot(OpKernelContext* ctx, const Tensor& a, const Tensor& b,
Expand Down Expand Up @@ -361,11 +361,11 @@ class MusaTensorDotOp : public MusaOpKernel {
if (!output->CopyFrom(matmul_temp, output_shape)) {
return errors::Internal("Failed to reshape matmul result to output");
}
return Status::OK();
return ::tsl::OkStatus();
}

TF_RETURN_IF_ERROR(DoMatMul(ctx, a_2d, b_2d, &matmul_view));
return Status::OK();
return ::tsl::OkStatus();
}

// 准备张量:transpose + reshape 为 2D
Expand Down Expand Up @@ -404,7 +404,7 @@ class MusaTensorDotOp : public MusaOpKernel {
}
}

return Status::OK();
return ::tsl::OkStatus();
}
};

Expand Down
2 changes: 1 addition & 1 deletion musa_ext/kernels/array/musa_tensorinteraction_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,6 @@ REGISTER_OP("MusaInteract")
::tensorflow::shape_inference::DimensionHandle n = c->Dim(input_shape, 1);

c->set_output(0, c->MakeShape({batch, n, n}));
return Status::OK();
return ::tsl::OkStatus();
});
}
8 changes: 4 additions & 4 deletions musa_ext/kernels/array/musa_tensorlist_fromtensor_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ Status TensorShapeFromTensorMusa(const Tensor& t, PartialTensorShape* out) {
if (t.dtype() == DT_INT32) {
out->Clear();
out->AddDim(static_cast<int64_t>(t.scalar<int32>()()));
return Status::OK();
return ::tsl::OkStatus();
} else if (t.dtype() == DT_INT64) {
out->Clear();
out->AddDim(static_cast<int64_t>(t.scalar<int64>()()));
return Status::OK();
return ::tsl::OkStatus();
} else {
return errors::InvalidArgument(
"element_shape must be int32 or int64, got ",
Expand All @@ -43,15 +43,15 @@ Status TensorShapeFromTensorMusa(const Tensor& t, PartialTensorShape* out) {
for (int i = 0; i < vec.size(); ++i) {
out->AddDim(static_cast<int64_t>(vec(i)));
}
return Status::OK();
return ::tsl::OkStatus();
}

if (t.dtype() == DT_INT64) {
auto vec = t.vec<int64>();
for (int i = 0; i < vec.size(); ++i) {
out->AddDim(static_cast<int64_t>(vec(i)));
}
return Status::OK();
return ::tsl::OkStatus();
}

return errors::InvalidArgument("element_shape must be int32 or int64, got ",
Expand Down
2 changes: 1 addition & 1 deletion musa_ext/kernels/array/musa_tensorlist_reserve_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Status TensorShapeFromTensorReserve(const Tensor& t, PartialTensorShape* out) {
if ((t.dtype() == DT_INT32 && t.scalar<int32_t>()() == -1) ||
(t.dtype() == DT_INT64 && t.scalar<int64_t>()() == -1)) {
*out = PartialTensorShape(); // Fully unknown shape
return Status::OK();
return ::tsl::OkStatus();
}
return errors::InvalidArgument(
"The only valid scalar shape tensor is the fully unknown shape "
Expand Down
2 changes: 1 addition & 1 deletion musa_ext/kernels/array/musa_tokenmixer_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ REGISTER_OP("MusaTokenMixer")
auto batch_dim = c->Dim(c->input(0), 0);
c->set_output(0, c->MakeShape({batch_dim, c->MakeDim(num_H),
c->MakeDim(num_T * d_k)}));
return Status::OK();
return ::tsl::OkStatus();
});

} // namespace tensorflow
2 changes: 1 addition & 1 deletion musa_ext/kernels/array/musa_transpose_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct TransposeFunctor {
if (::musa::dnn::Status::SUCCESS != pop.Run(h, out_mt, in_mt)) {
return errors::Internal("muDNN Permute Run failed!");
}
return Status::OK();
return ::tsl::OkStatus();
}
};

Expand Down
8 changes: 4 additions & 4 deletions musa_ext/kernels/array/musa_where_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct NumTrue {

if (input.size() == 0) {
*num_true_data = static_cast<TIndex>(0);
return Status::OK();
return ::tsl::OkStatus();
}

// Use the new LaunchIsNonZeroCount operator which directly counts
Expand All @@ -70,7 +70,7 @@ struct NumTrue {
musaGetErrorString(m_err));
}

return Status::OK();
return ::tsl::OkStatus();
}
};

Expand Down Expand Up @@ -99,7 +99,7 @@ struct Where {
typename TTypes<T, NDIM>::ConstTensor input,
typename TTypes<TIndex>::Matrix output) {
if (output.dimension(0) == 0) {
return Status::OK();
return ::tsl::OkStatus();
}

musaStream_t stream = GetMusaStreamByCtx(ctx);
Expand Down Expand Up @@ -171,7 +171,7 @@ struct Where {
LaunchPropagateWhereIndicesKernel<NDIM, TIndex>(
output_rows, strides.data(), selected_indices, output.data(), stream);

return Status::OK();
return ::tsl::OkStatus();
}
};

Expand Down
8 changes: 4 additions & 4 deletions musa_ext/kernels/math/musa_Maxpool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Status PermuteTensorOnMusa(OpKernelContext* ctx, const Tensor& input,
static_cast<int>(status));
}

return Status::OK();
return ::tsl::OkStatus();
}

Status ComputeOutputAndPadding2D(int64_t in_h, int64_t in_w, int64_t window_h,
Expand All @@ -65,7 +65,7 @@ Status ComputeOutputAndPadding2D(int64_t in_h, int64_t in_w, int64_t window_h,
*pad_bottom = 0;
*pad_left = 0;
*pad_right = 0;
return Status::OK();
return ::tsl::OkStatus();
}

if (padding == Padding::SAME) {
Expand All @@ -81,7 +81,7 @@ Status ComputeOutputAndPadding2D(int64_t in_h, int64_t in_w, int64_t window_h,
*pad_bottom = static_cast<int>(pad_h - *pad_top);
*pad_left = static_cast<int>(pad_w / 2);
*pad_right = static_cast<int>(pad_w - *pad_left);
return Status::OK();
return ::tsl::OkStatus();
}

return errors::InvalidArgument(
Expand Down Expand Up @@ -119,7 +119,7 @@ Status RunMusaMaxPool(OpKernelContext* ctx, const Tensor& input, Tensor* output,
static_cast<int>(status));
}

return Status::OK();
return ::tsl::OkStatus();
}

} // namespace
Expand Down
4 changes: 2 additions & 2 deletions musa_ext/kernels/math/musa_add_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ Status ConfigureBroadcastView(const Tensor& tensor,
// Express TensorFlow-style broadcast as a muDNN tensor view by keeping the
// output dims and setting broadcasted axes to stride 0.
if (SameShape(tensor, output_shape) || output_shape.dims() == 0) {
return Status::OK();
return ::tsl::OkStatus();
}

const int input_rank = tensor.dims();
Expand Down Expand Up @@ -265,7 +265,7 @@ Status ConfigureBroadcastView(const Tensor& tensor,
return errors::Internal("MUSA Add SetNdInfo failed. Status: ",
static_cast<int>(status));
}
return Status::OK();
return ::tsl::OkStatus();
}

} // namespace
Expand Down
2 changes: 1 addition & 1 deletion musa_ext/kernels/math/musa_cast_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ static Status CastFunctor(OpKernelContext* ctx, const mTensor& input_mt,
return errors::Internal("CastTensor Run failed. Status: ",
static_cast<int>(status));
}
return Status::OK();
return ::tsl::OkStatus();
}

} // namespace musa
Expand Down
6 changes: 3 additions & 3 deletions musa_ext/kernels/math/musa_clip_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ REGISTER_OP("MusaClip")
if (!c->RankKnown(x_shape) || !c->RankKnown(lo_shape) ||
!c->RankKnown(hi_shape)) {
c->set_output(0, c->UnknownShape());
return Status::OK();
return ::tsl::OkStatus();
}

auto BroadcastTwoShapes =
Expand Down Expand Up @@ -136,14 +136,14 @@ REGISTER_OP("MusaClip")

std::reverse(dims.begin(), dims.end());
*out = c->MakeShape(dims);
return Status::OK();
return ::tsl::OkStatus();
};

TF_RETURN_IF_ERROR(BroadcastTwoShapes(x_shape, lo_shape, &x_lo_shape));
TF_RETURN_IF_ERROR(BroadcastTwoShapes(x_lo_shape, hi_shape, &out_shape));

c->set_output(0, out_shape);
return Status::OK();
return ::tsl::OkStatus();
});

} // namespace tensorflow
Expand Down
Loading
Loading