Skip to content

Commit 2b7657c

Browse files
author
Jeff Hammond
committed
dunno
1 parent 60e32c8 commit 2b7657c

File tree

3 files changed

+35
-9
lines changed

3 files changed

+35
-9
lines changed

lapack/test_xgemm.c

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,17 @@
22
#include <stdlib.h>
33
#include <string.h>
44

5-
#ifdef MKL
6-
#include "mkl.h"
5+
#if defined(MKL)
6+
# include <mkl.h>
7+
# ifdef MKL_ILP64
8+
# error Use the MKL library for 32-bit integers!
9+
# endif
10+
#elif defined(ACCELERATE)
11+
/* The location of cblas.h is not in the system
12+
* include path when -framework Accelerate is provided. */
13+
# include <Accelerate/Accelerate.h>
714
#else
8-
#include "cblas.h"
15+
# include <cblas.h>
916
#endif
1017

1118
#if (__STDC_VERSION__ < 199901L)
@@ -18,6 +25,18 @@
1825
#define PDEBUG(fmt, ...)
1926
#endif
2027

28+
void test_sgemm(const int m, const int n, const int k)
29+
{
30+
31+
const double alpha = 1.0;
32+
const double beta = 1.0;
33+
const int lda = k;
34+
const int ldb = n;
35+
const int ldc = n;
36+
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
37+
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
38+
}
39+
2140
int main(int argc, char* argv[])
2241
{
2342
char type;
@@ -96,6 +115,9 @@ int main(int argc, char* argv[])
96115
printf("You have chosen an unsupported datatype...\n");
97116
goto input_error;
98117
}
118+
printf("C+=A*B with m=%d, n=%d, k=%d\n", m, n, k);
119+
120+
test_xgemm(type, m, n k);
99121

100122
printf("SUCCESS\n");
101123
return 0;
@@ -112,5 +134,6 @@ int main(int argc, char* argv[])
112134
goto fail;
113135

114136
fail:
137+
printf("FAILURE\n");
115138
return 1;
116139
}

simd/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ CXXFLAGS = -g -Wall -O3 -qopenmp -xCORE-AVX2
99
#CXXFLAGS = -g -Wall -O3 -qopenmp -xCORE-AVX512 -qopt-zmm-usage=high
1010
ASMFLAGS = -fsource-asm -fverbose-asm -fasm-blocks -fcode-asm
1111

12-
all: bfloat16 bfloat16.s
12+
all: bfloat16 bfloat16.s time_conversion
1313

1414
bfloat16: bfloat16.cc
1515
${CXX} ${CXXFLAGS} $< -o $@

simd/time_conversion.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,9 @@ int main(int argc, char* argv[])
2121

2222
uint16_t * ibf16 = new uint16_t[n];
2323
uint16_t * if16 = new uint16_t[n];
24-
float * f32 = new float[n];
2524
uint16_t * obf16 = new uint16_t[n];
2625
uint16_t * of16 = new uint16_t[n];
2726

28-
for (int i=0; i<n; i++) {
29-
f32[i] = (float)i;
30-
}
31-
3227
for (int i=0; i<n; ++i) {
3328
uint16_t j = i % 16384;
3429
ibf16[i] = j;
@@ -37,13 +32,21 @@ int main(int argc, char* argv[])
3732
of16[i] = j;
3833
}
3934

35+
#if defined(__AVX512F__) && defined(__AVX512BW__)
4036
for (int i=0; i<n; i+=8) {
4137
__m128i a = _mm_load_si128((__m128i*)&ibf16[i]);
4238
__m256i b = _mm256_cvtepu16_epi32(a);
4339
__m256i c = _mm256_slli_epi32(b,16);
4440

4541
_mm_store_si128((__m128i*)&obf16[i],c);
4642
}
43+
#else
44+
PRAGMA_SIMD
45+
for (int i=0; i<n; i++) {
46+
uint32_t u32 = ((uint32_t)bf[i]) << 16;
47+
float f32 = *(float*)&u32;
48+
}
49+
#endif
4750

4851
for (int i=0; i<n; i++) {
4952
std::cout << std::setw(10) << i << ": "

0 commit comments

Comments
 (0)