Skip to content

Conversation

bob80905
Copy link
Contributor

@bob80905 bob80905 commented Sep 5, 2025

Adds WaveActiveMax tests.
Fixes #122

Copy link
Collaborator

@spall spall left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of saying 'tracked by' say 'bug' or 'unimplemented' and please add a comment with bug/unimplemented + issue for the XFAILS that don't have one.

half4 v = In[tid.x];

half s1 = WaveActiveMax( v.x );
half s2 = tid.x < 3 ? WaveActiveMax( v.x ) : 0;
Copy link

@tex3d tex3d Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not obvious, but this is dependent on short-circuiting of ternary operator, introduced in HLSL 2021. It obfuscates the control flow a bit, which could be confusing for some at first read.

Another approach would be to group assignments under explicit control flow blocks. You could even utilize arrays to reduce the duplicated code. I believe the following is equivalent in functionality:

    half s1[4] = (half[4])0;
    half2 v2[4] = (half2[4])0;
    half3 v3[4] = (half3[4])0;
    half4 v4[4] = (half4[4])0;

    for (int i = 0; i < 4; i++) {
        if (tid.x <= i) {
            s1[i] = WaveActiveMax( v.x );
            v2[i] = WaveActiveMax( v.xy );
            v3[i] = WaveActiveMax( v.xyz );
            v4[i] = WaveActiveMax( v );
        }
    }

    Out1[tid.x].x = s1[tid.x];
    Out2[tid.x].xy = v2[tid.x];
    Out3[tid.x].xyz = v3[tid.x];
    Out4[tid.x] = v4[tid.x];

This seems easier to follow. It might also catch implementations that might apply illegal control flow optimizations impacting wave ops.

Written this way, I notice that it is a bit of an odd approach with the arrays. While we write to local arrays on each thread, we only ever output the array element corresponding to thread id on each thread. It seems you could do away with the local arrays altogether. Like this:

    half s1 = 0;
    half2 v2 = 0;
    half3 v3 = 0;
    half4 v4 = 0;

    // Reverse order allows thread local values to end up
    // with max value for all threads <= tid.x.
    for (int i = 4; i > 0; i--) {
        if (tid.x < i) {
            s1 = WaveActiveMax( v.x );
            v2 = WaveActiveMax( v.xy );
            v3 = WaveActiveMax( v.xyz );
            v4 = WaveActiveMax( v );
        }
    }

    Out1[tid.x].x = s1;
    Out2[tid.x].xy = v2;
    Out3[tid.x].xyz = v3;
    Out4[tid.x] = v4;

With this, thread 0 should end up with max of just thread 0, thread 1 will be max of thread 0 and thread 1, and so on with thread 3 being the max of threads 0, 1, 2, and 3. Each thread overwrites the local values until it reaches the final iteration the thread participates in (max of all thread values up to and including this thread).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another limitation to this approach is that we aren't really verifying that prior threads are getting the same max value as later threads. To do so simply, we might have to expand the outputs verified to write each result for each thread, instead of just the one corresponding to that thread (using the original or local array version I suggested first). This would require 4 times the output, but would be straightforward copy-paste or script work to extend expected outputs.

Like this:

    half s1[4] = (half[4])0;
    half2 v2[4] = (half2[4])0;
    half3 v3[4] = (half3[4])0;
    half4 v4[4] = (half4[4])0;

    for (int i = 0; i < 4; i++) {
        if (tid.x <= i) {
            s1[i] = WaveActiveMax( v.x );
            v2[i] = WaveActiveMax( v.xy );
            v3[i] = WaveActiveMax( v.xyz );
            v4[i] = WaveActiveMax( v );
        }
    }

    // Output all results for each thread to verify max broadcast
    for (int i = 0; i < 4; i++) {
        Out1[i * 4 + tid.x].x = s1[i];
        Out2[i * 4 + tid.x].xy = v2[i];
        Out3[i * 4 + tid.x].xyz = v3[i];
        Out4[i * 4 + tid.x] = v4[i];
    }

@@ -0,0 +1,178 @@
#--- source.hlsl
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of these HLSL sources appear to be identical except the base type used. Is there any way to reference a shared source file and use compilation arguments, like -D TYPE=half instead?

- Name: In
Format: Float16
Stride: 8
# 1, 10, 100, 1000, 2, 20, 200, 2000, 3, 30, 300, 3000, 4, 40, 400, 4000
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All tests appear to use the same whole-number test values, which always increase in value for higher thread ids. I feel like this could miss implementation errors like:

  • implicit casting for wave op to int (no fractional values)
  • implicit casting for wave op to different bit-size (no values requiring selected bit size to accurately express)
  • not handling/preserving denorms when required (for half and double, as well as float when denorm mode is preserve)
  • mishandling of negative values
  • mishandling of INF/-INF
    • I believe inf/-inf should be reliably handled for this op, but could be wrong
  • just returning the value from the highest active thread index (bad implementation)

Copy link

@tex3d tex3d Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional note: the comment with the values is helpful, but would be even more helpful if formatted so you could line up values compared across threads, like:

    # x,  y,   z,    w
    # 1, 10, 100, 1000, # thread 0
    # 2, 20, 200, 2000, # thread 1
    # 3, 30, 300, 3000, # thread 2
    # 4, 40, 400, 4000  # thread 3

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally, more sets of values could be tested with an outer loop around the value set, if that's desired given my feedback on the limitations of this chosen set of values.

Copy link

@tex3d tex3d Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option for selecting sets of threads could have been an input mask input set instead of always using max of values from threads [0,n] for each n=[0,3]. That could look like:

    # 13 active mask sets for threads 0, 1, 2, 3:
    # 1 1 1 1
    # 1 0 0 0
    # 0 1 0 0
    # 0 0 1 0
    # 0 0 0 1
    # 0 1 1 1
    # 1 0 1 1
    # 1 1 0 1
    # 1 1 1 0
    # 1 1 0 0
    # 0 0 1 1
    # 0 1 1 0
    # 1 0 0 1

An updated shader that could work with this and multiple value sets:

#define VALUE_SETS 2
#define NUM_MASKS 13
#define NUM_THREADS 4

struct MaskStruct {
    int mask[NUM_THREADS];
};

StructuredBuffer<half4> In  : register(t0);
StructuredBuffer<MaskStruct> Masks  : register(t1);
RWStructuredBuffer<half4> Out1 : register(u2); // test scalar
RWStructuredBuffer<half4> Out2 : register(u3); // test half2
RWStructuredBuffer<half4> Out3 : register(u4); // test half3
RWStructuredBuffer<half4> Out4 : register(u5); // test half4
RWStructuredBuffer<half4> Out5 : register(u6); // constant folding

[numthreads(NUM_THREADS,1,1)]
void main(uint3 tid : SV_GroupThreadID)
{
    for (int ValueSet = 0; ValueSet < VALUE_SETS; ValueSet++) {
        const uint ValueSetOffset = ValueSet * NUM_MASKS * NUM_THREADS;
        half4 v = In[ValueSet * NUM_THREADS + tid.x];
        for (int MaskIdx = 0; MaskIdx < NUM_MASKS; MaskIdx++) {
            const uint OutIdx = ValueSetOffset + MaskIdx * NUM_THREADS + tid.x;
            if (Masks[MaskIdx].mask[tid.x]) {
                Out1[OutIdx].x = WaveActiveMax( v.x );
                Out2[OutIdx].xy = WaveActiveMax( v.xy );
                Out3[OutIdx].xyz = WaveActiveMax( v.xyz );
                Out4[OutIdx] = WaveActiveMax( v );
            }
        }
    }

    // constant folding case
    Out5[0] = WaveActiveMax(half4(1,2,3,4));
}

See: https://www.godbolt.org/z/P1r3E869h

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add test for WaveActiveMax
3 participants