Skip to content

Commit 40624dd

Browse files
committed
ScalaCL: added customCode function to allow for execution of hand-written OpenCL kernels
1 parent 39bf97b commit 40624dd

File tree

6 files changed

+118
-5
lines changed

6 files changed

+118
-5
lines changed

Core/src/main/velocity/com/nativelibs4java/opencl/CLContext.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ public String toString() {
206206
public CLQueue createDefaultOutOfOrderQueueIfPossible() {
207207
try {
208208
return createDefaultOutOfOrderQueue();
209-
} catch (CLException.InvalidQueueProperties ex) {
209+
} catch (Throwable th) {//CLException.InvalidQueueProperties ex) {
210210
return createDefaultQueue();
211211
}
212212
}

ScalaCL/src/main/scala/scalacl/ScalaCL.scala

+13-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ package scalacl {
1212
override def canEqual(that: Any) = getClass.isInstance(that.asInstanceOf[AnyRef])
1313
}
1414
class Context(val context: CLContext, val queue: CLQueue) extends AbstractProduct {
15-
def this(context: CLContext) = this(context, context.createDefaultOutOfOrderQueueIfPossible())
15+
def this(context: CLContext) = this(context, context.createDefaultQueue())//createDefaultOutOfOrderQueueIfPossible())
1616
def this() = this(JavaCL.createBestContext(CLPlatform.DeviceFeature.OutOfOrderQueueSupport, CLPlatform.DeviceFeature.MaxComputeUnits))
1717

1818
def release = {
@@ -74,6 +74,18 @@ package object scalacl {
7474
val OutOfOrderQueueSupport = CLPlatform.DeviceFeature.OutOfOrderQueueSupport
7575
val MostImageFormats = CLPlatform.DeviceFeature.MostImageFormats
7676

77+
def customCode(
78+
source: String,
79+
compilerArguments: Array[String] = Array(),
80+
macros: Map[String, String] = Map()
81+
): CLCode = {
82+
new CLSimpleCode(
83+
Array(source),
84+
compilerArguments,
85+
macros
86+
)
87+
}
88+
7789
private[scalacl] def reuse[T](value: Any, create: => T): T =
7890
if (value != null && value.isInstanceOf[T])
7991
value.asInstanceOf[T]

ScalaCL/src/main/scala/scalacl/impl/CLCode.scala

+50
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,54 @@ trait CLCode {
6666

6767
override def hashCode = hc
6868
override def equals(o: Any) = o.isInstanceOf[CLCode] && strs.equals(o.asInstanceOf[CLCode].strs)
69+
70+
/**
71+
* Execute the code using the provided arguments (may be CLArray instances, CLGuardedBuffer instances, CLBuffer instances or primitive values), using the provided global sizes (and optional local group sizes) and respecting the execution order of arguments declared as "read from" and "written to".<br>
72+
* Example :
73+
* <code>
74+
import scalacl._
75+
implicit val context = Context.best(CPU)
76+
val n = 100
77+
val f = 0.5f
78+
val sinCosOutputs: CLArray[Float] = new CLArray[Float](2 * n)
79+
val sinCosCode = customCode("""
80+
__kernel void sinCos(__global float2* outputs, float f) {
81+
int i = get_global_id(0);
82+
float c, s = sincos(i * f, &c);
83+
outputs[i] = (float2)(s, c);
84+
}
85+
""")
86+
sinCosCode.execute(
87+
args = Array(sinCosOutputs, f),
88+
writes = Array(sinCosOutputs),
89+
globalSizes = Array(n)
90+
)
91+
val resCL = sinCosOutputs.toArray
92+
* </code>
93+
*/
94+
def execute(
95+
args: Array[Any],
96+
globalSizes: Array[Int],
97+
localSizes: Array[Int] = null,
98+
reads: Array[CLEventBoundContainer] = Array(),
99+
writes: Array[CLEventBoundContainer] = Array(),
100+
kernelName: String = null
101+
)(implicit context: Context): Unit = {
102+
val flatArgs: Array[Object] = args.flatMap(_ match {
103+
case b: CLArray[_] =>
104+
b.buffers.map(_.buffer.asInstanceOf[Object])
105+
case b: CLGuardedBuffer[_] =>
106+
Array[Object](b.buffer)
107+
case a =>
108+
Array[Object](a.asInstanceOf[Object])
109+
})
110+
val kernel = getKernel(context, kernelName)
111+
kernel.synchronized {
112+
//println("flatArgs = " + flatArgs.map(a => a + ": " + a.getClass.getSimpleName).mkString(", "))
113+
kernel.setArgs(flatArgs:_*)
114+
CLEventBound.syncBlock(CLEventBound.flatten(reads), CLEventBound.flatten(writes), evts => {
115+
kernel.enqueueNDRange(context.queue, globalSizes, localSizes, evts:_*)
116+
})
117+
}
118+
}
69119
}

ScalaCL/src/main/scala/scalacl/impl/CLEventBound.scala

+8-1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ trait CLEventBound extends CLEventBoundContainer {
6060
}
6161

6262
object CLEventBound {
63+
def flatten(containers: Array[CLEventBoundContainer]) =
64+
containers.flatMap(_.eventBoundComponents)
65+
6366
def syncBlock(reads: Array[CLEventBound], writes: Array[CLEventBound], action: Array[CLEvent] => CLEvent): CLEvent = {
6467

6568
def recursiveSync(ebs: List[(CLEventBound, Boolean)], evts: ArrayBuilder[CLEvent]): CLEvent = {
@@ -94,6 +97,10 @@ object CLEventBound {
9497
lb += ((eb, false))
9598
for (eb <- writes)
9699
lb += ((eb, true))
97-
recursiveSync(lb.result, Array.newBuilder[CLEvent])
100+
101+
if (lb.isEmpty)
102+
action(Array())
103+
else
104+
recursiveSync(lb.result, Array.newBuilder[CLEvent])
98105
}
99106
}

ScalaCL/src/main/scala/scalacl/impl/CLFunction.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ trait CLRunnable {
1717
if (dims.sum > 0) {
1818
lazy val defaultContainers = args collect { case c: CLEventBoundContainer => c }
1919
CLEventBound.syncBlock(
20-
Option(reads).getOrElse(defaultContainers).flatMap(_.eventBoundComponents),
21-
Option(writes).getOrElse(defaultContainers).flatMap(_.eventBoundComponents),
20+
CLEventBound.flatten(Option(reads).getOrElse(defaultContainers)),
21+
CLEventBound.flatten(Option(writes).getOrElse(defaultContainers)),
2222
evts => {
2323
run(dims = dims, args = args, eventsToWaitFor = evts)
2424
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package scalacl
2+
3+
import impl._
4+
5+
import com.nativelibs4java.opencl._
6+
import org.bridj.Pointer
7+
import org.bridj.Pointer._
8+
9+
import org.junit._
10+
import Assert._
11+
12+
import scala.math._
13+
14+
class CLCustomCodeTest {
15+
16+
@Test
17+
def testSinCos = {
18+
import scalacl._
19+
implicit val context = Context.best(CPU)
20+
val n = 100
21+
val f = 0.5f
22+
val sinCosOutputs: CLArray[Float] = new CLArray[Float](2 * n)
23+
val sinCosCode = customCode("""
24+
__kernel void sinCos(__global float2* outputs, float f) {
25+
int i = get_global_id(0);
26+
float c, s = sincos(i * f, &c);
27+
outputs[i] = (float2)(s, c);
28+
}
29+
""")
30+
sinCosCode.execute(
31+
args = Array(sinCosOutputs, f),
32+
writes = Array(sinCosOutputs),
33+
globalSizes = Array(n)
34+
)
35+
val resCL = sinCosOutputs.toArray
36+
37+
val resJava = (0 until n).flatMap(i => {
38+
val x = i * f
39+
Array(sin(x).toFloat, cos(x).toFloat)
40+
}).toArray
41+
42+
assertArrayEquals(resJava, resCL, 0.00001f)
43+
}
44+
}

0 commit comments

Comments
 (0)