Skip to content

Commit 9a4f586

Browse files
authored
Introduce lut encoder (#49)
* introduce lut encoder * change clamp to zero to clamp to edge
1 parent 0ca79ef commit 9a4f586

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
//
2+
// LookUpTableEncoder.swift
3+
// Alloy
4+
//
5+
// Created by Andrey Volodin on 29.10.2019.
6+
//
7+
8+
import Metal
9+
10+
final public class LookUpTableEncoder {
11+
12+
// MARK: - Properties
13+
14+
public let pipelineState: MTLComputePipelineState
15+
private let deviceSupportsNonuniformThreadgroups: Bool
16+
17+
// MARK: - Life Cycle
18+
19+
public convenience init(context: MTLContext) throws {
20+
guard let library = context.shaderLibrary(for: type(of: self))
21+
else { throw CommonErrors.metalInitializationFailed }
22+
try self.init(library: library)
23+
}
24+
25+
public init(library: MTLLibrary) throws {
26+
self.deviceSupportsNonuniformThreadgroups = library.device.supports(feature: .nonUniformThreadgroups)
27+
let constantValues = MTLFunctionConstantValues()
28+
constantValues.set(self.deviceSupportsNonuniformThreadgroups,
29+
at: 0)
30+
let functionName = type(of: self).functionName
31+
self.pipelineState = try library.computePipelineState(function: functionName,
32+
constants: constantValues)
33+
}
34+
35+
// MARK: - Encode
36+
37+
public func encode(sourceTexture: MTLTexture,
38+
outputTexture: MTLTexture,
39+
lut: MTLTexture,
40+
intensity: Float,
41+
in commandBuffer: MTLCommandBuffer) {
42+
commandBuffer.compute { encoder in
43+
encoder.label = "Look Up Table Encoder"
44+
self.encode(sourceTexture: sourceTexture,
45+
outputTexture: outputTexture,
46+
lut: lut,
47+
intensity: intensity,
48+
using: encoder)
49+
}
50+
}
51+
52+
public func encode(sourceTexture: MTLTexture,
53+
outputTexture: MTLTexture,
54+
lut: MTLTexture,
55+
intensity: Float,
56+
using encoder: MTLComputeCommandEncoder) {
57+
encoder.set(textures: [sourceTexture, outputTexture, lut])
58+
encoder.set(intensity, at: 0)
59+
60+
if self.deviceSupportsNonuniformThreadgroups {
61+
encoder.dispatch2d(state: pipelineState,
62+
exactly: outputTexture.size)
63+
} else {
64+
encoder.dispatch2d(state: pipelineState,
65+
covering: outputTexture.size)
66+
}
67+
}
68+
69+
public static let functionName = "lookUpTable"
70+
}

Alloy/Shaders/Shaders.metal

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,3 +713,30 @@ vertex float4 simpleVertex(constant float4* vertices [[ buffer(0) ]],
713713
fragment float4 plainColorFragment(constant float4& pointColor [[ buffer(0) ]]) {
714714
return pointColor;
715715
}
716+
717+
// MARK: Look up table
718+
719+
kernel void lookUpTable(texture2d<float, access::read> source [[ texture(0) ]],
720+
texture2d<float, access::write> destination [[ texture(1) ]],
721+
texture3d<float, access::sample> lut [[ texture(2) ]],
722+
constant float& intensity [[ buffer(0) ]],
723+
uint2 position [[thread_position_in_grid]]) {
724+
const ushort2 textureSize = ushort2(destination.get_width(),
725+
destination.get_height());
726+
checkPosition(position, textureSize, deviceSupportsNonuniformThreadgroups);
727+
728+
constexpr sampler s(coord::normalized,
729+
address::clamp_to_edge,
730+
filter::linear);
731+
732+
// read original color
733+
float4 sourceColor = source.read(position);
734+
735+
// use it to sample target color
736+
sourceColor.rgb = mix(sourceColor.rgb,
737+
lut.sample(s, sourceColor.rgb).rgb,
738+
intensity);
739+
740+
// write it to destination texture
741+
destination.write(sourceColor, position);
742+
}

0 commit comments

Comments
 (0)