diff --git a/src/ComputeSharp.Shaders/Extensions/SyntaxNodeExtensions.cs b/src/ComputeSharp.Shaders/Extensions/SyntaxNodeExtensions.cs index 1aa1122d0..8628734b8 100644 --- a/src/ComputeSharp.Shaders/Extensions/SyntaxNodeExtensions.cs +++ b/src/ComputeSharp.Shaders/Extensions/SyntaxNodeExtensions.cs @@ -42,12 +42,14 @@ public static TRoot ReplaceType(this TRoot node, TypeSyntax type) where T /// The input to check and modify if needed /// The to use to load symbols for the input node /// The info on parsed static members, if any + /// The info on parsed static methods, if any /// A instance that is compatible with HLSL [Pure] - public static SyntaxNode ReplaceMember(this MemberAccessExpressionSyntax node, SemanticModel semanticModel, out (string Name, ReadableMember MemberInfo)? variable) + public static SyntaxNode ReplaceMember(this MemberAccessExpressionSyntax node, SemanticModel semanticModel, out (string Name, ReadableMember MemberInfo)? variable, out (string Name, MethodInfo MethodInfo)? method) { - // Set the variable to null, replace it later on if needed + // Set the out parameters to null, replace them later on if needed variable = null; + method = null; SymbolInfo containingMemberSymbolInfo; ISymbol? memberSymbol; @@ -115,37 +117,48 @@ public static SyntaxNode ReplaceMember(this MemberAccessExpressionSyntax node, S return SyntaxFactory.IdentifierName(expression).WithLeadingTrivia(node.GetLeadingTrivia()).WithTrailingTrivia(node.GetTrailingTrivia()); } - // Handle static fields as a special case + // Handle static members as a special case if (memberSymbol.IsStatic && ( memberSymbol.Kind == SymbolKind.Field || - memberSymbol.Kind == SymbolKind.Property)) + memberSymbol.Kind == SymbolKind.Property || + memberSymbol.Kind == SymbolKind.Method)) { // Get the containing type string typeFullname = memberSymbol.ContainingType.ToString(), assemblyFullname = memberSymbol.ContainingAssembly.ToString(); - Type fieldDeclaringType = Type.GetType($"{typeFullname}, {assemblyFullname}"); + Type memberDeclaringType = Type.GetType($"{typeFullname}, {assemblyFullname}"); - // Retrieve the field or property info - bool isReadonly; - ReadableMember memberInfo; - switch (memberSymbol.Kind) + // Static field or property + if (memberSymbol.Kind == SymbolKind.Field || memberSymbol.Kind == SymbolKind.Property) { - case SymbolKind.Field: - FieldInfo fieldInfo = fieldDeclaringType.GetField(memberSymbol.Name, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic); - isReadonly = fieldInfo.IsInitOnly; - memberInfo = fieldInfo; - break; - case SymbolKind.Property: - PropertyInfo propertyInfo = fieldDeclaringType.GetProperty(memberSymbol.Name, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic); - isReadonly = !propertyInfo.CanWrite; - memberInfo = propertyInfo; - break; - default: throw new InvalidOperationException($"Invalid symbol kind: {memberSymbol.Kind}"); + bool isReadonly; + ReadableMember memberInfo; + switch (memberSymbol.Kind) + { + case SymbolKind.Field: + FieldInfo fieldInfo = memberDeclaringType.GetField(memberSymbol.Name, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic); + isReadonly = fieldInfo.IsInitOnly; + memberInfo = fieldInfo; + break; + case SymbolKind.Property: + PropertyInfo propertyInfo = memberDeclaringType.GetProperty(memberSymbol.Name, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic); + isReadonly = !propertyInfo.CanWrite; + memberInfo = propertyInfo; + break; + default: throw new InvalidOperationException($"Invalid symbol kind: {memberSymbol.Kind}"); + } + + // Handle the loaded info + return ProcessStaticMember(node, memberInfo, isReadonly, ref variable); } - // Handle the loaded info - return ProcessStaticMember(node, memberInfo, isReadonly, ref variable); + // Static method + MethodInfo methodInfo = memberDeclaringType.GetMethod(memberSymbol.Name, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic); + string name = $"{methodInfo.DeclaringType.Name}_{methodInfo.Name}"; + method = (name, methodInfo); + + return SyntaxFactory.IdentifierName(name).WithLeadingTrivia(node.GetLeadingTrivia()).WithTrailingTrivia(node.GetTrailingTrivia()); } return node; diff --git a/src/ComputeSharp.Shaders/Renderer/Models/Functions/FunctionInfo.cs b/src/ComputeSharp.Shaders/Renderer/Models/Functions/FunctionInfo.cs new file mode 100644 index 000000000..3af27abed --- /dev/null +++ b/src/ComputeSharp.Shaders/Renderer/Models/Functions/FunctionInfo.cs @@ -0,0 +1,75 @@ +using System; +using System.Collections.Generic; +using ComputeSharp.Shaders.Extensions; + +namespace ComputeSharp.Shaders.Renderer.Models.Functions +{ + /// + /// A that contains info on a shader function + /// + internal class FunctionInfo + { + /// + /// Gets the return type of the current function in the C# source + /// + public string FunctionCsharpReturnType { get; } + + /// + /// Gets the fullname of the current function in the C# source + /// + public string FunctionFullName { get; } + + /// + /// Gets the parameters of the function in the C# source + /// + public string FunctionCsharpParameters { get; } + + /// + /// Gets the return type of the current function + /// + public string ReturnType { get; } + + /// + /// Gets the name of the current function + /// + public string FunctionName { get; } + + /// + /// Gets the list of parameters for the current function + /// + public IReadOnlyList ParametersList { get; } + + /// + /// Gets the body of the current function + /// + public string FunctionBody { get; } + + /// + /// Creates a new instance with the specified parameters + /// + /// The concrete return type for the function + /// The fullname of the function in the C# source + /// The parameters of the function in the C# source + /// The return type of the current function + /// The name of the current function + /// The function parameters, if any + /// The current function + public FunctionInfo( + Type functionCsharpType, + string functionFullname, + string functionCsharpParameters, + string returnType, + string functionName, + IReadOnlyList parameters, + string functionBody) + { + FunctionCsharpReturnType = functionCsharpType.ToFriendlyString(); + FunctionFullName = functionFullname; + FunctionCsharpParameters = functionCsharpParameters; + ReturnType = returnType; + FunctionName = functionName; + ParametersList = parameters; + FunctionBody = functionBody; + } + } +} diff --git a/src/ComputeSharp.Shaders/Renderer/Models/Functions/ParameterInfo.cs b/src/ComputeSharp.Shaders/Renderer/Models/Functions/ParameterInfo.cs new file mode 100644 index 000000000..f5cc5172a --- /dev/null +++ b/src/ComputeSharp.Shaders/Renderer/Models/Functions/ParameterInfo.cs @@ -0,0 +1,52 @@ +using System.Collections.Generic; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; + +namespace ComputeSharp.Shaders.Renderer.Models.Functions +{ + /// + /// A that contains info on a function parameter + /// + internal class ParameterInfo + { + /// + /// Gets the modifier to use for the current parameter + /// + public string ParameterModifier { get; } + + /// + /// Gets the type of the current parameter + /// + public string ParameterType { get; } + + /// + /// Gets the name to use for the current parameter + /// + public string ParameterName { get; } + + /// + /// Gets whether or not the current parameter is the last one for the parent function + /// + public bool IsLastParameter { get; } + + /// + /// Creates a new instance with the specified parameters + /// + /// The modifiers used in the current parameter + /// The type of the current parameter + /// The name of the current parameter + /// Indicates whether or not the current parameter is the last one + public ParameterInfo(IReadOnlyList parameterModifiers, string parameterType, string parameterName, bool last) + { + if (parameterModifiers.Count == 0) ParameterModifier = "in"; + else if (parameterModifiers.First().IsKind(SyntaxKind.OutKeyword)) ParameterModifier = "out"; + else if (parameterModifiers.Any(m => m.IsKind(SyntaxKind.RefKeyword)) && !parameterModifiers.Any(m => m.IsKind(SyntaxKind.ReadOnlyKeyword))) ParameterModifier = "inout"; + else ParameterModifier = "in"; + + ParameterType = parameterType; + ParameterName = parameterName; + IsLastParameter = last; + } + } +} diff --git a/src/ComputeSharp.Shaders/Renderer/Models/ShaderInfo.cs b/src/ComputeSharp.Shaders/Renderer/Models/ShaderInfo.cs index 8dc33e8be..306dd3f73 100644 --- a/src/ComputeSharp.Shaders/Renderer/Models/ShaderInfo.cs +++ b/src/ComputeSharp.Shaders/Renderer/Models/ShaderInfo.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using ComputeSharp.Shaders.Renderer.Models.Fields; using ComputeSharp.Shaders.Renderer.Models.Fields.Abstract; +using ComputeSharp.Shaders.Renderer.Models.Functions; #pragma warning disable CS8618 // Non-nullable field is uninitialized @@ -12,12 +13,12 @@ namespace ComputeSharp.Shaders.Renderer.Models internal sealed class ShaderInfo { /// - /// Gets the list of captured buffers being present in the current shader + /// Gets or sets the list of captured buffers being present in the current shader /// public IReadOnlyList BuffersList { get; set; } /// - /// Gets the list of captured variables being present in the current shader + /// Gets or sets the list of captured variables being present in the current shader /// public IReadOnlyList FieldsList { get; set; } @@ -45,5 +46,10 @@ internal sealed class ShaderInfo /// Gets or sets the shader body to compile /// public string ShaderBody { get; set; } + + /// + /// Gets or sets the list of static functions used by the shader + /// + public IReadOnlyList FunctionsList { get; set; } } } diff --git a/src/ComputeSharp.Shaders/Renderer/Templates/ShaderTemplate.mustache b/src/ComputeSharp.Shaders/Renderer/Templates/ShaderTemplate.mustache index 8375efb6c..14c5d0757 100644 --- a/src/ComputeSharp.Shaders/Renderer/Templates/ShaderTemplate.mustache +++ b/src/ComputeSharp.Shaders/Renderer/Templates/ShaderTemplate.mustache @@ -39,6 +39,13 @@ cbuffer _{{FieldName}} : register(b{{BufferIndex}}) {{/IsReadWriteBuffer}} {{/BuffersList}} +{{!Shader private functions}} +{{#FunctionsList}} +// {{FunctionCsharpReturnType}} {{FunctionFullName}}({{FunctionCsharpParameters}}) +{{ReturnType}} {{FunctionName}}({{#ParametersList}}{{ParameterModifier}} {{ParameterType}} {{ParameterName}}{{^IsLastParameter}}, {{/IsLastParameter}}{{/ParametersList}}) +{{FunctionBody}} +{{/FunctionsList}} + {{!Shader entry point}} // Shader body [Shader("compute")] diff --git a/src/ComputeSharp.Shaders/ShaderRunner.cs b/src/ComputeSharp.Shaders/ShaderRunner.cs index a2a31153e..2815ce8b1 100644 --- a/src/ComputeSharp.Shaders/ShaderRunner.cs +++ b/src/ComputeSharp.Shaders/ShaderRunner.cs @@ -76,7 +76,8 @@ public static void Run( NumThreadsY = threadsY, NumThreadsZ = threadsZ, ThreadsIdsVariableName = shaderLoader.ThreadsIdsVariableName, - ShaderBody = shaderLoader.MethodBody + ShaderBody = shaderLoader.MethodBody, + FunctionsList = shaderLoader.FunctionsList }; string shaderSource = ShaderRenderer.Instance.Render(shaderInfo); diff --git a/src/ComputeSharp.Shaders/Translation/Enums/MethodType.cs b/src/ComputeSharp.Shaders/Translation/Enums/MethodType.cs new file mode 100644 index 000000000..7df4f830a --- /dev/null +++ b/src/ComputeSharp.Shaders/Translation/Enums/MethodType.cs @@ -0,0 +1,18 @@ +namespace ComputeSharp.Shaders.Translation.Enums +{ + /// + /// An that indicates a type of method to decompile + /// + internal enum MethodType + { + /// + /// An instance method that belongs to a closure class + /// + Closure, + + /// + /// A standalone, static method + /// + Static + } +} diff --git a/src/ComputeSharp.Shaders/Translation/MethodDecompiler.cs b/src/ComputeSharp.Shaders/Translation/MethodDecompiler.cs index b96e6d2e3..a6a0b3dcc 100644 --- a/src/ComputeSharp.Shaders/Translation/MethodDecompiler.cs +++ b/src/ComputeSharp.Shaders/Translation/MethodDecompiler.cs @@ -9,6 +9,7 @@ using System.Text; using System.Text.RegularExpressions; using ComputeSharp.Shaders.Mappings; +using ComputeSharp.Shaders.Translation.Enums; using ICSharpCode.Decompiler; using ICSharpCode.Decompiler.CSharp; using ICSharpCode.Decompiler.Metadata; @@ -108,66 +109,49 @@ private static CSharpDecompiler CreateDecompiler(string assemblyPath) } /// - /// A used to preprocess the closure type declarations for both lambda expressions and local methods + /// Decompiles a method or the whole declaring type /// - private static readonly Regex ClosureTypeDeclarationRegex = new Regex(@"(?<=private sealed class )<\w*>[\w_]+", RegexOptions.Compiled); + /// The target to decompile + /// Specifies whether or not to force the decompilation of just the given method, even if not static + /// The decompiled source code + [Pure] + private string DecompileMethodOrDeclaringType(MethodInfo methodInfo, bool methodOnly = false) + { + // Get the handle of the containing type method + string assemblyPath = methodInfo.DeclaringType?.Assembly.Location ?? throw new InvalidOperationException(); + int metadataToken = methodInfo.IsStatic || methodOnly ? methodInfo.MetadataToken : methodInfo.DeclaringType.MetadataToken; + EntityHandle typeHandle = MetadataTokenHelpers.TryAsEntityHandle(metadataToken) ?? throw new InvalidOperationException(); - /// - /// A used to preprocess the entry point declaration for both lambda expressions and local methods - /// - private static readonly Regex LambdaMethodDeclarationRegex = new Regex(@"(?:private|internal) void <\w+>[\w_|]+(?=\()", RegexOptions.Compiled); + // Get or create a decompiler for the target assembly, and decompile the type + if (!Decompilers.TryGetValue(assemblyPath, out CSharpDecompiler decompiler)) + { + decompiler = CreateDecompiler(assemblyPath); + Decompilers.Add(assemblyPath, decompiler); + } + + return decompiler.DecompileAsString(typeHandle); + } /// /// Decompiles a target method and returns its and info /// /// The input to inspect + /// The type of method to decompile /// The root node for the syntax tree of the input method /// The semantic model for the input method - public void GetSyntaxTree(MethodInfo methodInfo, out MethodDeclarationSyntax rootNode, out SemanticModel semanticModel) + public void GetSyntaxTree(MethodInfo methodInfo, MethodType methodType, out MethodDeclarationSyntax rootNode, out SemanticModel semanticModel) { lock (Lock) { - // Get the handle of the containing type method - string assemblyPath = methodInfo.DeclaringType?.Assembly.Location ?? throw new InvalidOperationException(); - EntityHandle typeHandle = MetadataTokenHelpers.TryAsEntityHandle(methodInfo.DeclaringType.MetadataToken) ?? throw new InvalidOperationException(); - - // Get or create a decompiler for the target assembly, and decompile the type - if (!Decompilers.TryGetValue(assemblyPath, out CSharpDecompiler decompiler)) - { - decompiler = CreateDecompiler(assemblyPath); - Decompilers.Add(assemblyPath, decompiler); - } - - // Decompile the method source and fix the method declaration for local methods converted to lambdas - string - sourceCode = decompiler.DecompileAsString(typeHandle), - typeFixedCode = ClosureTypeDeclarationRegex.Replace(sourceCode, "Shader"), - methodFixedCode = LambdaMethodDeclarationRegex.Replace(typeFixedCode, m => $"// {m.Value}{Environment.NewLine} internal void Main"); - - // Workaround for some local methods not being decompiled correctly - if (!methodFixedCode.Contains("internal void Main")) + string sourceCode = methodType switch { - EntityHandle methodHandle = MetadataTokenHelpers.TryAsEntityHandle(methodInfo.MetadataToken) ?? throw new InvalidOperationException(); - string - methodOnlySourceCode = decompiler.DecompileAsString(methodHandle), - methodOnlyFixedSourceCode = LambdaMethodDeclarationRegex.Replace(methodOnlySourceCode, m => $"// {m.Value}{Environment.NewLine} internal void Main"), - methodOnlyIndentedSourceCode = $" {methodOnlyFixedSourceCode.Replace(Environment.NewLine, $"{Environment.NewLine} ")}"; - - int lastClosedBracketsIndex = methodFixedCode.LastIndexOf('}'); - methodFixedCode = methodFixedCode.Insert(lastClosedBracketsIndex, methodOnlyIndentedSourceCode); - } - - // Unwrap the nested fields - string unwrappedSourceCode = UnwrapSyntaxTree(methodFixedCode); - - // Remove the in keyword from the source - string inFixedSourceCode = Regex.Replace(unwrappedSourceCode, @"(? GetSyntaxTreeForClosureMethod(methodInfo), + MethodType.Static => GetSyntaxTreeForStaticMethod(methodInfo), + _ => throw new ArgumentOutOfRangeException(nameof(methodType), $"Invalid method type: {methodType}") + }; // Load the type syntax tree - SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(outFixedSourceCode); + SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(sourceCode); // Get the root node to return rootNode = syntaxTree.GetRoot().DescendantNodes().OfType().First(node => node.GetLeadingTrivia().ToFullString().Contains(methodInfo.Name)); @@ -178,6 +162,74 @@ public void GetSyntaxTree(MethodInfo methodInfo, out MethodDeclarationSyntax roo } } + /// + /// Decompiles a target closure method + /// + /// The input to inspect + [Pure] + private string GetSyntaxTreeForClosureMethod(MethodInfo methodInfo) + { + // Decompile the method source and fix the method declaration for local methods converted to lambdas + string + sourceCode = DecompileMethodOrDeclaringType(methodInfo), + typeFixedCode = ClosureTypeDeclarationRegex.Replace(sourceCode, "Shader"), + methodFixedCode = LambdaMethodDeclarationRegex.Replace(typeFixedCode, m => $"// {m.Value}{Environment.NewLine} internal void Main"); + + // Workaround for some local methods not being decompiled correctly + if (!methodFixedCode.Contains("internal void Main")) + { + string + methodOnlySourceCode = DecompileMethodOrDeclaringType(methodInfo, true), + methodOnlyFixedSourceCode = LambdaMethodDeclarationRegex.Replace(methodOnlySourceCode, m => $"// {m.Value}{Environment.NewLine} internal void Main"), + methodOnlyIndentedSourceCode = $" {methodOnlyFixedSourceCode.Replace(Environment.NewLine, $"{Environment.NewLine} ")}"; + + int lastClosedBracketsIndex = methodFixedCode.LastIndexOf('}'); + methodFixedCode = methodFixedCode.Insert(lastClosedBracketsIndex, methodOnlyIndentedSourceCode); + } + + // Unwrap the nested fields + string unwrappedSourceCode = UnwrapSyntaxTree(methodFixedCode); + + // Remove the in keyword from the source + string inFixedSourceCode = Regex.Replace(unwrappedSourceCode, @"(? + /// Decompiles a target method and returns its and info + /// + /// The input to inspect + [Pure] + private string GetSyntaxTreeForStaticMethod(MethodInfo methodInfo) + { + string + sourceCode = DecompileMethodOrDeclaringType(methodInfo), + prototype = sourceCode.Split(Environment.NewLine)[0], + commentedSourceCode = $"// {methodInfo.Name}{Environment.NewLine}{sourceCode}", + inFixedSourceCode = Regex.Replace(commentedSourceCode, @"(?().First(); + + return $"// {methodInfo.Name}{Environment.NewLine}{prototype}{Environment.NewLine}{block.ToFullString()}"; + } + + /// + /// A used to preprocess the closure type declarations for both lambda expressions and local methods + /// + private static readonly Regex ClosureTypeDeclarationRegex = new Regex(@"(?<=private sealed class )<\w*>[\w_]+", RegexOptions.Compiled); + + /// + /// A used to preprocess the entry point declaration for both lambda expressions and local methods + /// + private static readonly Regex LambdaMethodDeclarationRegex = new Regex(@"(?:private|internal) void <\w+>[\w_|]+(?=\()", RegexOptions.Compiled); + /// /// A used to find fields that represent nested closure types /// @@ -272,14 +324,14 @@ where argument.RefKindKeyword.IsKind(SyntaxKind.OutKeyword) && argument.Expression.IsKind(SyntaxKind.DeclarationExpression) let match = Regex.Match(argument.Expression.ToFullString(), @"([\w.]+) ([\w_]+)") let mappedType = HlslKnownTypes.GetMappedName(match.Groups[1].Value) - let declatation = $"{mappedType} {match.Groups[2].Value} = ({mappedType})0;" - select (match.Groups[1].Value, match.Groups[2].Value, declatation)).ToArray(); + let declaration = $"{mappedType} {match.Groups[2].Value} = ({mappedType})0;" + select declaration).ToArray(); // Insert the explicit declarations at the start of the method int start = rootNode.Body.ChildNodes().First().SpanStart; - foreach (var item in outs.Reverse()) + foreach (var declaration in outs.Reverse()) { - source = source.Insert(start, $"{item.declatation}{Environment.NewLine} "); + source = source.Insert(start, $"{declaration}{Environment.NewLine} "); } // Remove the out keyword from the source diff --git a/src/ComputeSharp.Shaders/Translation/ShaderLoader.cs b/src/ComputeSharp.Shaders/Translation/ShaderLoader.cs index 599ade23c..1c3c24ccf 100644 --- a/src/ComputeSharp.Shaders/Translation/ShaderLoader.cs +++ b/src/ComputeSharp.Shaders/Translation/ShaderLoader.cs @@ -5,13 +5,18 @@ using System.Reflection; using System.Text.RegularExpressions; using ComputeSharp.Graphics.Buffers.Abstract; +using ComputeSharp.Shaders.Extensions; using ComputeSharp.Shaders.Mappings; using ComputeSharp.Shaders.Renderer.Models.Fields; using ComputeSharp.Shaders.Renderer.Models.Fields.Abstract; +using ComputeSharp.Shaders.Renderer.Models.Functions; +using ComputeSharp.Shaders.Translation.Enums; using ComputeSharp.Shaders.Translation.Models; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Vortice.Direct3D12; +using ParameterInfo = ComputeSharp.Shaders.Renderer.Models.Functions.ParameterInfo; #pragma warning disable CS8618 // Non-nullable field is uninitialized @@ -137,6 +142,13 @@ public IEnumerable GetVariables(Action action) /// public string MethodBody { get; private set; } + private readonly List _FunctionsList = new List(); + + /// + /// Gets the collection of items for the shader + /// + public IReadOnlyList FunctionsList => _FunctionsList; + /// /// Loads and processes an input /// @@ -243,16 +255,22 @@ private void LoadFieldInfo(ReadableMember memberInfo, string? name = null, IRead private void LoadMethodSource() { // Decompile the shader method - MethodDecompiler.Instance.GetSyntaxTree(Action.Method, out MethodDeclarationSyntax root, out SemanticModel semanticModel); + MethodDecompiler.Instance.GetSyntaxTree(Action.Method, MethodType.Closure, out MethodDeclarationSyntax root, out SemanticModel semanticModel); // Rewrite the shader method (eg. to fix the type declarations) ShaderSyntaxRewriter syntaxRewriter = new ShaderSyntaxRewriter(semanticModel); root = (MethodDeclarationSyntax)syntaxRewriter.Visit(root); - // Register the captured static fields - foreach (var item in syntaxRewriter.StaticMembers) + // Register the captured static members + foreach (var member in syntaxRewriter.StaticMembers) { - LoadFieldInfo(item.Value, item.Key); + LoadFieldInfo(member.Value, member.Key); + } + + // Register the captured static methods + foreach (var method in syntaxRewriter.StaticMethods) + { + LoadStaticMethodSource(method.Key, method.Value); } // Get the thread ids identifier name and shader method body @@ -264,5 +282,59 @@ private void LoadMethodSource() MethodBody = MethodBody.TrimEnd('\n', '\r', ' '); MethodBody = HlslKnownKeywords.GetMappedText(MethodBody); } + + /// + /// Loads additional static methods used by the shader + /// + /// The HLSL name of the new method to load + /// The instance for the method to load + private void LoadStaticMethodSource(string name, MethodInfo methodInfo) + { + // Decompile the target method + MethodDecompiler.Instance.GetSyntaxTree(methodInfo, MethodType.Static, out MethodDeclarationSyntax root, out SemanticModel semanticModel); + + // Rewrite the method + ShaderSyntaxRewriter syntaxRewriter = new ShaderSyntaxRewriter(semanticModel); + root = (MethodDeclarationSyntax)syntaxRewriter.Visit(root); + + // Register the captured static members + foreach (var member in syntaxRewriter.StaticMembers) + { + LoadFieldInfo(member.Value, member.Key); + } + + // Register the captured static methods + foreach (var method in syntaxRewriter.StaticMethods) + { + LoadStaticMethodSource(method.Key, method.Value); + } + + // Get the function parameters + IReadOnlyList parameters = ( + from parameter in root.ParameterList.Parameters.Select((p, i) => (Node: p, Index: i)) + let modifiers = parameter.Node.Modifiers + let type = parameter.Node.Type.ToFullString() + let parameterName = parameter.Node.Identifier.ToFullString() + let last = parameter.Index == root.ParameterList.Parameters.Count - 1 + select new ParameterInfo(modifiers, type, parameterName, last)).ToArray(); + + // Get the function body + string body = root.Body.ToFullString(); + body = Regex.Replace(body, @"(?<=\W)(\d+)[fFdD]", m => m.Groups[1].Value); + body = body.TrimEnd('\n', '\r', ' '); + body = HlslKnownKeywords.GetMappedText(body); + + // Get the final function info instance + FunctionInfo functionInfo = new FunctionInfo( + methodInfo.ReturnType, + $"{methodInfo.DeclaringType.FullName}{Type.Delimiter}{methodInfo.Name}", + string.Join(", ", methodInfo.GetParameters().Select(p => $"{p.ParameterType.ToFriendlyString()} {p.Name}")), + root.ReturnType.ToFullString(), + name, + parameters, + body); + + _FunctionsList.Add(functionInfo); + } } } diff --git a/src/ComputeSharp.Shaders/Translation/ShaderSyntaxRewriter.cs b/src/ComputeSharp.Shaders/Translation/ShaderSyntaxRewriter.cs index 0cb3f121c..c92b2e374 100644 --- a/src/ComputeSharp.Shaders/Translation/ShaderSyntaxRewriter.cs +++ b/src/ComputeSharp.Shaders/Translation/ShaderSyntaxRewriter.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Reflection; using ComputeSharp.Shaders.Extensions; using ComputeSharp.Shaders.Translation.Models; using Microsoft.CodeAnalysis; @@ -30,6 +31,13 @@ internal class ShaderSyntaxRewriter : CSharpSyntaxRewriter /// public IReadOnlyDictionary StaticMembers => _StaticMembers; + private readonly Dictionary _StaticMethods = new Dictionary(); + + /// + /// Gets the mapping of captured static methods used by the target code + /// + public IReadOnlyDictionary StaticMethods => _StaticMethods; + /// public override SyntaxNode VisitParameter(ParameterSyntax node) { @@ -40,6 +48,19 @@ public override SyntaxNode VisitParameter(ParameterSyntax node) return node; } + /// + public override SyntaxNode VisitArgument(ArgumentSyntax node) + { + node = (ArgumentSyntax)base.VisitArgument(node); + + if (node.RefKindKeyword.IsKind(SyntaxKind.RefKeyword)) + { + node = node.WithRefKindKeyword(SyntaxFactory.Token(SyntaxKind.None)); + } + + return node; + } + /// public override SyntaxNode VisitCastExpression(CastExpressionSyntax node) { @@ -83,15 +104,33 @@ public override SyntaxNode VisitDefaultExpression(DefaultExpressionSyntax node) public override SyntaxNode VisitMemberAccessExpression(MemberAccessExpressionSyntax node) { node = (MemberAccessExpressionSyntax)base.VisitMemberAccessExpression(node); - SyntaxNode syntaxNode = node.ReplaceMember(SemanticModel, out var variable); + SyntaxNode syntaxNode = node.ReplaceMember(SemanticModel, out var variable, out var method); - // Register the captured member, if any + // Register the captured members, if any if (variable.HasValue && !_StaticMembers.ContainsKey(variable.Value.Name)) { _StaticMembers.Add(variable.Value.Name, variable.Value.MemberInfo); } + if (method.HasValue && !_StaticMethods.ContainsKey(method.Value.Name)) + { + _StaticMethods.Add(method.Value.Name, method.Value.MethodInfo); + } return syntaxNode; } + + /// + public override SyntaxNode VisitMethodDeclaration(MethodDeclarationSyntax node) + { + node = (MethodDeclarationSyntax)base.VisitMethodDeclaration(node); + + // Replace the return type node, if needed + if (!node.ReturnType.ToString().Equals("void")) + { + return node.ReplaceType(node.ReturnType); + } + + return node; + } } } diff --git a/tests/ComputeSharp.Tests/StaticMethodsTests.cs b/tests/ComputeSharp.Tests/StaticMethodsTests.cs new file mode 100644 index 000000000..38103f083 --- /dev/null +++ b/tests/ComputeSharp.Tests/StaticMethodsTests.cs @@ -0,0 +1,95 @@ +using System; +using System.Diagnostics.Contracts; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace ComputeSharp.Tests +{ + /// + /// A container for static methods to test + /// + public static class StaticMethodsContainer + { + [Pure] + public static float Square(float x) => x * x; + + [Pure] + public static Float4 Range(float x) => new Float4(x, x + 1, x + 2, x + 3); + + [Pure] + public static int Sum(Float4 x) => (int)(x.X + x.Y + x.Z + x.W); + + public static void Assign(int x, out int y) => y = x; + + public static void ReadAndSquare(ref int x) => x *= x; + } + + [TestClass] + [TestCategory("StaticMethodsTests")] + public class StaticMethodsTests + { + [TestMethod] + public void FloatToFloatFunc() + { + using ReadWriteBuffer buffer = Gpu.Default.AllocateReadWriteBuffer(1); + + Gpu.Default.For(1, id => buffer[0] = StaticMethodsContainer.Square(3)); + + float[] result = buffer.GetData(); + + Assert.IsTrue(MathF.Abs(result[0] - 9) < 0.0001f); + } + + [TestMethod] + public void FloatToFloat4Func() + { + using ReadWriteBuffer buffer = Gpu.Default.AllocateReadWriteBuffer(1); + + Gpu.Default.For(1, id => buffer[0] = StaticMethodsContainer.Range(3)); + + Float4[] result = buffer.GetData(); + + Assert.IsTrue(MathF.Abs(result[0].X - 3) < 0.0001f); + Assert.IsTrue(MathF.Abs(result[0].Y - 4) < 0.0001f); + Assert.IsTrue(MathF.Abs(result[0].Z - 5) < 0.0001f); + Assert.IsTrue(MathF.Abs(result[0].W - 6) < 0.0001f); + } + + [TestMethod] + public void Float4ToIntFunc() + { + using ReadWriteBuffer buffer = Gpu.Default.AllocateReadWriteBuffer(1); + + Gpu.Default.For(1, id => buffer[0] = StaticMethodsContainer.Sum(new Float4(1, 2, 3, 14))); + + int[] result = buffer.GetData(); + + Assert.IsTrue(result[0] == 20); + } + + [TestMethod] + public void IntToOutIntFunc() + { + using ReadWriteBuffer buffer = Gpu.Default.AllocateReadWriteBuffer(1); + + Gpu.Default.For(1, id => StaticMethodsContainer.Assign(7, out buffer[0])); + + int[] result = buffer.GetData(); + + Assert.IsTrue(result[0] == 7); + } + + [TestMethod] + public void IntToRefIntFunc() + { + int[] data = { 3 }; + using ReadWriteBuffer buffer = Gpu.Default.AllocateReadWriteBuffer(data); + + Gpu.Default.For(1, id => StaticMethodsContainer.ReadAndSquare(ref buffer[0])); + + buffer.GetData(data); + + Assert.IsTrue(data[0] == 9); + } + } +} +