-
Notifications
You must be signed in to change notification settings - Fork 20
Add WaveActiveMax tests #429
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of saying 'tracked by' say 'bug' or 'unimplemented' and please add a comment with bug/unimplemented + issue for the XFAILS that don't have one.
half4 v = In[tid.x]; | ||
|
||
half s1 = WaveActiveMax( v.x ); | ||
half s2 = tid.x < 3 ? WaveActiveMax( v.x ) : 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not obvious, but this is dependent on short-circuiting of ternary operator, introduced in HLSL 2021. It obfuscates the control flow a bit, which could be confusing for some at first read.
Another approach would be to group assignments under explicit control flow blocks. You could even utilize arrays to reduce the duplicated code. I believe the following is equivalent in functionality:
half s1[4] = (half[4])0;
half2 v2[4] = (half2[4])0;
half3 v3[4] = (half3[4])0;
half4 v4[4] = (half4[4])0;
for (int i = 0; i < 4; i++) {
if (tid.x <= i) {
s1[i] = WaveActiveMax( v.x );
v2[i] = WaveActiveMax( v.xy );
v3[i] = WaveActiveMax( v.xyz );
v4[i] = WaveActiveMax( v );
}
}
Out1[tid.x].x = s1[tid.x];
Out2[tid.x].xy = v2[tid.x];
Out3[tid.x].xyz = v3[tid.x];
Out4[tid.x] = v4[tid.x];
This seems easier to follow. It might also catch implementations that might apply illegal control flow optimizations impacting wave ops.
Written this way, I notice that it is a bit of an odd approach with the arrays. While we write to local arrays on each thread, we only ever output the array element corresponding to thread id on each thread. It seems you could do away with the local arrays altogether. Like this:
half s1 = 0;
half2 v2 = 0;
half3 v3 = 0;
half4 v4 = 0;
// Reverse order allows thread local values to end up
// with max value for all threads <= tid.x.
for (int i = 4; i > 0; i--) {
if (tid.x < i) {
s1 = WaveActiveMax( v.x );
v2 = WaveActiveMax( v.xy );
v3 = WaveActiveMax( v.xyz );
v4 = WaveActiveMax( v );
}
}
Out1[tid.x].x = s1;
Out2[tid.x].xy = v2;
Out3[tid.x].xyz = v3;
Out4[tid.x] = v4;
With this, thread 0 should end up with max of just thread 0, thread 1 will be max of thread 0 and thread 1, and so on with thread 3 being the max of threads 0, 1, 2, and 3. Each thread overwrites the local values until it reaches the final iteration the thread participates in (max of all thread values up to and including this thread).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another limitation to this approach is that we aren't really verifying that prior threads are getting the same max value as later threads. To do so simply, we might have to expand the outputs verified to write each result for each thread, instead of just the one corresponding to that thread (using the original or local array version I suggested first). This would require 4 times the output, but would be straightforward copy-paste or script work to extend expected outputs.
Like this:
half s1[4] = (half[4])0;
half2 v2[4] = (half2[4])0;
half3 v3[4] = (half3[4])0;
half4 v4[4] = (half4[4])0;
for (int i = 0; i < 4; i++) {
if (tid.x <= i) {
s1[i] = WaveActiveMax( v.x );
v2[i] = WaveActiveMax( v.xy );
v3[i] = WaveActiveMax( v.xyz );
v4[i] = WaveActiveMax( v );
}
}
// Output all results for each thread to verify max broadcast
for (int i = 0; i < 4; i++) {
Out1[i * 4 + tid.x].x = s1[i];
Out2[i * 4 + tid.x].xy = v2[i];
Out3[i * 4 + tid.x].xyz = v3[i];
Out4[i * 4 + tid.x] = v4[i];
}
@@ -0,0 +1,178 @@ | |||
#--- source.hlsl |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All of these HLSL sources appear to be identical except the base type used. Is there any way to reference a shared source file and use compilation arguments, like -D TYPE=half
instead?
- Name: In | ||
Format: Float16 | ||
Stride: 8 | ||
# 1, 10, 100, 1000, 2, 20, 200, 2000, 3, 30, 300, 3000, 4, 40, 400, 4000 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All tests appear to use the same whole-number test values, which always increase in value for higher thread ids. I feel like this could miss implementation errors like:
- implicit casting for wave op to int (no fractional values)
- implicit casting for wave op to different bit-size (no values requiring selected bit size to accurately express)
- not handling/preserving denorms when required (for half and double, as well as float when denorm mode is preserve)
- mishandling of negative values
- mishandling of INF/-INF
- I believe inf/-inf should be reliably handled for this op, but could be wrong
- just returning the value from the highest active thread index (bad implementation)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional note: the comment with the values is helpful, but would be even more helpful if formatted so you could line up values compared across threads, like:
# x, y, z, w
# 1, 10, 100, 1000, # thread 0
# 2, 20, 200, 2000, # thread 1
# 3, 30, 300, 3000, # thread 2
# 4, 40, 400, 4000 # thread 3
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additionally, more sets of values could be tested with an outer loop around the value set, if that's desired given my feedback on the limitations of this chosen set of values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another option for selecting sets of threads could have been an input mask input set instead of always using max of values from threads [0,n] for each n=[0,3]. That could look like:
# 13 active mask sets for threads 0, 1, 2, 3:
# 1 1 1 1
# 1 0 0 0
# 0 1 0 0
# 0 0 1 0
# 0 0 0 1
# 0 1 1 1
# 1 0 1 1
# 1 1 0 1
# 1 1 1 0
# 1 1 0 0
# 0 0 1 1
# 0 1 1 0
# 1 0 0 1
An updated shader that could work with this and multiple value sets:
#define VALUE_SETS 2
#define NUM_MASKS 13
#define NUM_THREADS 4
struct MaskStruct {
int mask[NUM_THREADS];
};
StructuredBuffer<half4> In : register(t0);
StructuredBuffer<MaskStruct> Masks : register(t1);
RWStructuredBuffer<half4> Out1 : register(u2); // test scalar
RWStructuredBuffer<half4> Out2 : register(u3); // test half2
RWStructuredBuffer<half4> Out3 : register(u4); // test half3
RWStructuredBuffer<half4> Out4 : register(u5); // test half4
RWStructuredBuffer<half4> Out5 : register(u6); // constant folding
[numthreads(NUM_THREADS,1,1)]
void main(uint3 tid : SV_GroupThreadID)
{
for (int ValueSet = 0; ValueSet < VALUE_SETS; ValueSet++) {
const uint ValueSetOffset = ValueSet * NUM_MASKS * NUM_THREADS;
half4 v = In[ValueSet * NUM_THREADS + tid.x];
for (int MaskIdx = 0; MaskIdx < NUM_MASKS; MaskIdx++) {
const uint OutIdx = ValueSetOffset + MaskIdx * NUM_THREADS + tid.x;
if (Masks[MaskIdx].mask[tid.x]) {
Out1[OutIdx].x = WaveActiveMax( v.x );
Out2[OutIdx].xy = WaveActiveMax( v.xy );
Out3[OutIdx].xyz = WaveActiveMax( v.xyz );
Out4[OutIdx] = WaveActiveMax( v );
}
}
}
// constant folding case
Out5[0] = WaveActiveMax(half4(1,2,3,4));
}
Adds WaveActiveMax tests.
Fixes #122