4 changes: 2 additions & 2 deletions tests/cpp/operator/CMakeLists.txt
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -75,7 +75,7 @@ if(USE_CUDA)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX)
else()
target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX)
target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX hiprand)
endif()
target_compile_options(test_operator PRIVATE -O2 -fopenmp)

45 changes: 44 additions & 1 deletion tests/cpp/operator/test_cublaslt_gemm.cu
@@ -1,5 +1,5 @@
/*************************************************************************
* Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
@@ -635,4 +635,47 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite,
[](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) {
return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
});

TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) {
const size_t rows = 128;
const size_t cols = 256;
const size_t N = rows * cols;

test::Tensor t("fillUniform_regression_fp32",
std::vector<size_t>{rows, cols},
transformer_engine::DType::kFloat32,
/*rowwise=*/true,
/*columnwise=*/false);

// Tensor constructor initializes CPU mirror + device to zero.
// If GPU generation happens but CPU mirror is not updated,
// any later test::Tensor::from_cpu() will overwrite device back to zeros.
fillUniform(&t);

// Check the CPU mirror has *actual* generated values, not all zeros
const float* cpu = t.rowwise_cpu_dptr<float>();

bool any_nonzero = false;
for (size_t i = 0; i < N; ++i) {
  if (cpu[i] != 0.0f) {
    any_nonzero = true;
    break;
  }
}

ASSERT_TRUE(any_nonzero) << "CPU mirror is all zeros. "
<< "Likely GPU-generated data got overwritten by from_cpu().";

// Check device matches CPU mirror after fillUniform completes
std::vector<float> dev(N, 0.0f);
NVTE_CHECK_CUDA(cudaMemcpy(dev.data(),
t.rowwise_dptr(),
N * sizeof(float),
cudaMemcpyDeviceToHost));

for (size_t i = 0; i < N; ++i) {
ASSERT_EQ(dev[i], cpu[i]) << "Mismatch at i=" << i
<< " dev=" << dev[i] << " cpu=" << cpu[i];
}
}

#endif // __HIP_PLATFORM_AMD__
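
For readers outside the ROCm test harness, here is a host-only sketch of the overwrite hazard this regression test pins down. `FakeTensor`, `fill_device_only`, and the `from_cpu` member below are illustrative stand-ins modeled on the test's comments, not the real `test::Tensor` API:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative stand-in for a tensor with a device buffer and a CPU mirror.
struct FakeTensor {
  std::vector<float> device;  // pretend this lives on the GPU
  std::vector<float> mirror;  // zero-initialized CPU copy
  explicit FakeTensor(std::size_t n) : device(n, 0.0f), mirror(n, 0.0f) {}
  void from_cpu() { device = mirror; }  // host-to-device copy
};

// A fill that writes the device side only: the bug being guarded against.
void fill_device_only(FakeTensor& t) {
  for (float& v : t.device) v = 0.5f;
}

int main() {
  FakeTensor t(4);
  fill_device_only(t);  // device holds data, mirror is still all zeros
  t.from_cpu();         // the zero mirror silently overwrites the device
  assert(t.device[0] == 0.0f);  // generated data is gone
  // The fix: have the fill update the CPU mirror too, so a later
  // from_cpu() round-trips the generated data instead of zeros.
  return 0;
}
```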
157 changes: 138 additions & 19 deletions tests/cpp/test_common.cu
@@ -23,6 +23,8 @@
#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"

#include <curand.h>

namespace test {

size_t create_seed_from_tensor_name(const std::string& tensor_name) {
@@ -786,16 +788,9 @@ std::pair<double, double> getTolerances(const DType type) {
return {0, 0};
}

#ifndef __HIP_PLATFORM_AMD__
template <typename T>
void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
#ifdef __HIP_PLATFORM_AMD__
// TODO: Introduce a parallel RNG library (Random123, PCG, rocRAND)
std::uniform_real_distribution<> dis(-2.0, 1.0);
for (int i = 0; i < size; i++) {
data[i] = static_cast<T>(dis(*gen));
}
gen->discard(size);
#else
// Check how many RNG calls are required to generate one uniform random value
int rng_calls_per_val = 0;
{
@@ -825,28 +820,134 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
}
}
gen->discard(size * rng_calls_per_val);
}
#endif
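
The host path above keeps sibling `std::mt19937` streams aligned by discarding exactly as many raw engine calls as the distribution consumed. A standalone sketch of that counting-and-discard idea (illustrative; the harness's actual counting code is collapsed out of this diff):

```cpp
#include <cassert>
#include <cstddef>
#include <random>

int main() {
  // Count how many raw engine invocations one distribution draw consumes,
  // by advancing a fresh engine until its state matches the probed one.
  std::uniform_real_distribution<> dis(-2.0, 1.0);
  std::mt19937 probe(123);
  (void)dis(probe);
  std::size_t rng_calls_per_val = 0;
  for (std::mt19937 ref(123); ref != probe; ++rng_calls_per_val) {
    (void)ref();
  }

  // Drawing N values on one engine and discard(N * rng_calls_per_val)
  // on a twin engine must leave both in the same state.
  const std::size_t N = 1000;
  std::mt19937 a(42), b(42);
  for (std::size_t i = 0; i < N; ++i) (void)dis(a);
  b.discard(N * rng_calls_per_val);
  assert(a == b);
  return 0;
}
```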

#ifdef __HIP_PLATFORM_AMD__
template <typename T>
__global__ void affine_transform_and_cast(const float* __restrict__ in,
T* __restrict__ out, size_t n, double lo,
double hi) {
// Map uniform values in *in* from (0, 1] onto (*lo*, *hi*] and cast to type *T* for *out*.
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
out[idx] = static_cast<T>(lo + (hi - lo) * in[idx]);
}
}

template <typename T>
__global__ void apply_random_sign(T* __restrict__ data,
const float* __restrict__ signs,
size_t n) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
if (signs[idx] < 0.5f) {
data[idx] = static_cast<T>(-static_cast<float>(data[idx]));
}
}
}

template <typename T>
static void fillUniformLinearBufferDevice(T* dst_dev,
T* dst_cpu, // nullable
size_t N,
unsigned long long seed,
double lo, double hi,
bool random_sign=false) {
// Fill a linear device buffer with uniform randoms in (*lo*, *hi*]
// (curandGenerateUniform draws from (0, 1]) and cast them to *T*.
// Optionally mirror the result into a provided CPU pointer.
if (N == 0)
return;

float* tmp = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&tmp, N * sizeof(float)));

float* tmp_sign = nullptr;
if (random_sign) {
NVTE_CHECK_CUDA(cudaMalloc(&tmp_sign, N * sizeof(float)));
}

curandGenerator_t gen;
NVTE_CHECK(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_PHILOX4_32_10) == CURAND_STATUS_SUCCESS);
NVTE_CHECK(curandSetPseudoRandomGeneratorSeed(gen, seed) == CURAND_STATUS_SUCCESS);
NVTE_CHECK(curandGenerateUniform(gen, tmp, N) == CURAND_STATUS_SUCCESS);

if (random_sign) {
NVTE_CHECK(curandGenerateUniform(gen, tmp_sign, N) == CURAND_STATUS_SUCCESS);
}

dim3 block(256);
dim3 grid((N + block.x - 1) / block.x);

affine_transform_and_cast<T><<<grid, block, 0, 0>>>(tmp, dst_dev, N, lo, hi);

if (random_sign) {
apply_random_sign<T><<<grid, block, 0, 0>>>(dst_dev, tmp_sign, N);
}

NVTE_CHECK_CUDA(cudaGetLastError());

if (dst_cpu != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpy(dst_cpu, dst_dev, N * sizeof(T), cudaMemcpyDeviceToHost));
}

NVTE_CHECK(curandDestroyGenerator(gen) == CURAND_STATUS_SUCCESS);
NVTE_CHECK_CUDA(cudaFree(tmp));
if (tmp_sign)
NVTE_CHECK_CUDA(cudaFree(tmp_sign));
}

static void fillUniformTensorDevice(Tensor* t, double lo = -2.0,
double hi = 1.0, bool random_sign = false) {
void* dst_dev_void = t->rowwise() ? t->rowwise_dptr() : t->columnwise_dptr();
const auto shape = t->rowwise() ? (t->rowwise_shape()) : (t->columnwise_shape());
const size_t N = product(shape);

// per-tensor deterministic seed
const unsigned long long seed = static_cast<unsigned long long>(t->gen()());

TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, {
[Review comment from a collaborator]
T should either be a template parameter (with no TRANSFORMER_ENGINE_TYPE_SWITCH_ALL here), or the call should be moved out of TRANSFORMER_ENGINE_TYPE_SWITCH_ALL in fillUniform. (A sketch of the suggested restructuring appears after the #endif that closes this HIP-only section.)

[Reply from the author]
With the restructuring in bdb8349, I believe this comment is now addressed?

T* dst_dev = reinterpret_cast<T*>(dst_dev_void);
// Keep the CPU mirror in sync. We could use Tensor::to_cpu() here,
// but that does more than just copying the data.
T* dst_cpu = t->rowwise() ? t->rowwise_cpu_dptr<T>() : t->columnwise_cpu_dptr<T>();
fillUniformLinearBufferDevice(dst_dev, dst_cpu, N, seed, lo, hi, random_sign);
});
}
#endif
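
A minimal, self-contained sketch of the shape the review comment above asks for: one dtype dispatch in the caller, and a worker that is an ordinary template. The `TOY_TYPE_SWITCH` macro and `DType` enum are toy stand-ins for the transformer_engine ones:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

enum class DType { kFloat32, kFloat64 };

// Toy stand-in for TRANSFORMER_ENGINE_TYPE_SWITCH_ALL.
#define TOY_TYPE_SWITCH(dtype, T, ...) \
  switch (dtype) { \
    case DType::kFloat32: { using T = float; __VA_ARGS__ break; } \
    case DType::kFloat64: { using T = double; __VA_ARGS__ break; } \
  }

// The worker is an ordinary template: no type switch inside it.
template <typename T>
void fill_buffer(T* dst, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) dst[i] = static_cast<T>(i) / static_cast<T>(n);
}

// The caller owns the single dispatch point.
void fill(DType dtype, void* dst, std::size_t n) {
  TOY_TYPE_SWITCH(dtype, T, fill_buffer(static_cast<T*>(dst), n);)
}

int main() {
  std::vector<float> buf(8);
  fill(DType::kFloat32, buf.data(), buf.size());
  std::printf("%f\n", buf.back());
  return 0;
}
```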

void fillUniform(Tensor *t) {
if (t->rowwise()) {
const size_t size = product(t->rowwise_shape());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
{
#ifdef __HIP_PLATFORM_AMD__
fillUniformTensorDevice(t);
#else
T *data = t->rowwise_cpu_dptr<T>();
generate_data_uniformly(data, size, &(t->gen()));
#endif
}
);
} else {
const size_t size = product(t->columnwise_shape());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
{
#ifdef __HIP_PLATFORM_AMD__
fillUniformTensorDevice(t);
#else
T *data = t->columnwise_cpu_dptr<T>();
generate_data_uniformly(data, size, &(t->gen()));
#endif
}
);
}
#ifndef __HIP_PLATFORM_AMD__
// Data is already on device on AMDGPU
t->from_cpu();
#endif
std::uniform_real_distribution<> dis(-2.0, 1.0);
t->set_scale_inv(dis(t->gen()));
}
@@ -857,10 +958,18 @@ void fillCase_special(Tensor *t) {

if constexpr (Case == InputsFillCase::zeros) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
#ifdef __HIP_PLATFORM_AMD__
// Fill device and CPU mirror
void* dst_dev = t->rowwise_dptr();
NVTE_CHECK_CUDA(cudaMemset(dst_dev, 0, size * sizeof(InputType)));
InputType* dst_cpu = t->rowwise_cpu_dptr<InputType>();
std::fill_n(dst_cpu, size, static_cast<InputType>(0));
#else
InputType *data = t->rowwise_cpu_dptr<InputType>();
for (size_t i = 0; i < size; ++i) {
data[i] = static_cast<InputType>(0);
}
#endif
});
} else {
double minAbs = -2.0;
@@ -869,22 +978,32 @@
minAbs = Quantized_Limits<InputEncoding>::ranges[Case];
maxAbs = Quantized_Limits<InputEncoding>::ranges[Case + 1];
}
std::uniform_real_distribution<> dis(minAbs, maxAbs);
std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
InputType *data = t->rowwise_cpu_dptr<InputType>();
for (size_t idx = 0; idx < size; ++idx) {
const bool is_negative = (dis_sign(t->gen()) < 0.0);
double val = dis(t->gen());
if (is_negative) {
val = -val;
}
data[idx] = static_cast<InputType>(val);
}
#ifdef __HIP_PLATFORM_AMD__
const unsigned long long seed = static_cast<unsigned long long>(t->gen()());
InputType* dst_dev = static_cast<InputType*>(t->rowwise_dptr());
InputType* dst_cpu = static_cast<InputType*>(t->rowwise_cpu_dptr<InputType>());
fillUniformLinearBufferDevice(dst_dev, dst_cpu, size, seed,
minAbs, maxAbs, /*random_sign=*/true);
#else
std::uniform_real_distribution<> dis(minAbs, maxAbs);
std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
InputType *data = t->rowwise_cpu_dptr<InputType>();
for (size_t idx = 0; idx < size; ++idx) {
const bool is_negative = (dis_sign(t->gen()) < 0.0);
double val = dis(t->gen());
if (is_negative) {
val = -val;
}
data[idx] = static_cast<InputType>(val);
}
#endif
});
}
t->set_scale_inv(1.0);
#ifndef __HIP_PLATFORM_AMD__
t->from_cpu();
#endif
}
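
One subtlety in the HIP path above: the host code flips the sign when a draw from uniform(-1, 1) is negative, while fillUniformLinearBufferDevice with random_sign=true flips when a (0, 1] uniform falls below 0.5. Both are fair coin flips, which a quick host-side check (illustrative only) confirms:

```cpp
#include <cstdio>
#include <random>

int main() {
  std::mt19937 gen(2024);
  std::uniform_real_distribution<> host_draw(-1.0, 1.0);  // host-path rule
  std::uniform_real_distribution<> dev_draw(0.0, 1.0);    // stand-in for curand's (0, 1]

  const int N = 1000000;
  int host_neg = 0, dev_neg = 0;
  for (int i = 0; i < N; ++i) {
    if (host_draw(gen) < 0.0) ++host_neg;  // host: negative draw flips sign
    if (dev_draw(gen) < 0.5) ++dev_neg;    // device: draw below 0.5 flips sign
  }
  std::printf("flip rate, host: %.4f  device: %.4f\n",  // both approach 0.5
              host_neg / double(N), dev_neg / double(N));
  return 0;
}
```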

template <typename InputEncoding>
4 changes: 2 additions & 2 deletions tests/cpp/util/CMakeLists.txt
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -21,7 +21,7 @@ find_package(OpenMP REQUIRED)
if(USE_CUDA)
target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn OpenMP::OpenMP_CXX)
else()
target_link_libraries(test_util PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX)
target_link_libraries(test_util PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX hiprand)
endif()
target_compile_options(test_util PRIVATE -O2 -fopenmp)
