diff options
author | Michael Bien <[email protected]> | 2011-07-07 23:32:28 +0200 |
---|---|---|
committer | Michael Bien <[email protected]> | 2011-07-07 23:32:28 +0200 |
commit | 4fe7110357d2631960e23861a3221489d313c467 (patch) | |
tree | 40f1ff4ddef2cd151e22b7d1c496ce3f6b4d9e76 | |
parent | 25b19e390a0a6a3cae8d129a579f16ffd5d4e2e5 (diff) |
CLKernel utility methods for setting vector arguments with up to 4 elements + test.
-rw-r--r-- | src/com/jogamp/opencl/CLKernel.java | 291 | ||||
-rw-r--r-- | test/com/jogamp/opencl/CLProgramTest.java | 74 | ||||
-rw-r--r-- | test/com/jogamp/opencl/TestUtils.java | 21 |
3 files changed, 372 insertions, 14 deletions
diff --git a/src/com/jogamp/opencl/CLKernel.java b/src/com/jogamp/opencl/CLKernel.java index 8a3a44b9..7d5751fd 100644 --- a/src/com/jogamp/opencl/CLKernel.java +++ b/src/com/jogamp/opencl/CLKernel.java @@ -45,7 +45,24 @@ import static com.jogamp.common.os.Platform.*; * applied to any function in a program. A kernel object encapsulates the specific <code>kernel</code> * function declared in a program and the argument values to be used when executing this * <code>kernel</code> function. - * CLKernel is not threadsafe. + * <p> + * Example: + * <pre> + * CLKernel addKernel = program.createCLKernel("add"); + * addKernel.setArgs(clBufferA, clBufferB); + * ... + * queue.putEnqueue1DKernel(addKernel, 0, clBufferA.getSize(), 0); + * </pre> + * CLKernel provides utility methods for setting vector types (float4, int2...) with up to 4 elements. Larger + * vectors like float16 can be set using {@link #setArg(int, java.nio.Buffer)}. + * + * Arguments pointing to {@link CLBuffer}s or {@link CLImage}s can be set using {@link #setArg(int, com.jogamp.opencl.CLMemory) } + * or its relative putArg(..) methods. + * </p> + * <p> + * CLKernel is not threadsafe. However it is perfectly safe to create a new instance of a CLKernel for every + * involved Thread. + * </p> * @see CLProgram#createCLKernel(java.lang.String) * @see CLProgram#createCLKernels() * @author Michael Bien @@ -71,7 +88,7 @@ public class CLKernel extends CLObjectResource implements Cloneable { super(program.getContext(), id); this.program = program; - this.buffer = Buffers.newDirectByteBuffer((is32Bit()?4:8)*3); + this.buffer = Buffers.newDirectByteBuffer(8*4); binding = program.getPlatform().getKernelBinding(); @@ -99,10 +116,10 @@ public class CLKernel extends CLObjectResource implements Cloneable { } -// public CLKernel putArg(Buffer value) { -// setArg(argIndex++, value); -// return this; -// } + public CLKernel putArg(Buffer value) { + setArg(argIndex++, value); + return this; + } public CLKernel putArg(CLMemory<?> value) { setArg(argIndex, value); @@ -116,30 +133,120 @@ public class CLKernel extends CLObjectResource implements Cloneable { return this; } + public CLKernel putArg(short x, short y) { + setArg(argIndex, x, y); + argIndex++; + return this; + } + + public CLKernel putArg(short x, short y, short z) { + setArg(argIndex, x, y, z); + argIndex++; + return this; + } + + public CLKernel putArg(short x, short y, short z, short w) { + setArg(argIndex, x, y, z, w); + argIndex++; + return this; + } + public CLKernel putArg(int value) { setArg(argIndex, value); argIndex++; return this; } + public CLKernel putArg(int x, int y) { + setArg(argIndex, x, y); + argIndex++; + return this; + } + + public CLKernel putArg(int x, int y, int z) { + setArg(argIndex, x, y, z); + argIndex++; + return this; + } + + public CLKernel putArg(int x, int y, int z, int w) { + setArg(argIndex, x, y, z, w); + argIndex++; + return this; + } + public CLKernel putArg(long value) { setArg(argIndex, value); argIndex++; return this; } + public CLKernel putArg(long x, long y) { + setArg(argIndex, x, y); + argIndex++; + return this; + } + + public CLKernel putArg(long x, long y, long z) { + setArg(argIndex, x, y, z); + argIndex++; + return this; + } + + public CLKernel putArg(long x, long y, long z, long w) { + setArg(argIndex, x, y, z, w); + argIndex++; + return this; + } + public CLKernel putArg(float value) { setArg(argIndex, value); argIndex++; return this; } + public CLKernel putArg(float x, float y) { + setArg(argIndex, x, y); + argIndex++; + return this; + } + + public CLKernel putArg(float x, float y, float z) { + setArg(argIndex, x, y, z); + argIndex++; + return this; + } + + public CLKernel putArg(float x, float y, float z, float w) { + setArg(argIndex, x, y, z, w); + argIndex++; + return this; + } + public CLKernel putArg(double value) { setArg(argIndex, value); argIndex++; return this; } + public CLKernel putArg(double x, double y) { + setArg(argIndex, x, y); + argIndex++; + return this; + } + + public CLKernel putArg(double x, double y, double z) { + setArg(argIndex, x, y, z); + argIndex++; + return this; + } + + public CLKernel putArg(double x, double y, double z, double w) { + setArg(argIndex, x, y, z, w); + argIndex++; + return this; + } + public CLKernel putNullArg(int size) { setNullArg(argIndex, size); argIndex++; @@ -167,10 +274,13 @@ public class CLKernel extends CLObjectResource implements Cloneable { return argIndex; } -// public CLKernel setArg(int argumentIndex, Buffer value) { -// setArgument(argumentIndex, CLMemory.sizeOfBufferElem(value)*value.capacity(), value); -// return this; -// } + public CLKernel setArg(int argumentIndex, Buffer value) { + if(!value.isDirect()) { + throw new IllegalArgumentException("buffer must be direct."); + } + setArgument(argumentIndex, Buffers.sizeOfBufferElem(value)*value.remaining(), value); + return this; + } public CLKernel setArg(int argumentIndex, CLMemory<?> value) { setArgument(argumentIndex, is32Bit()?4:8, wrap(value.ID)); @@ -182,11 +292,41 @@ public class CLKernel extends CLObjectResource implements Cloneable { return this; } + public CLKernel setArg(int argumentIndex, short x, short y) { + setArgument(argumentIndex, 2*2, wrap(x, y)); + return this; + } + + public CLKernel setArg(int argumentIndex, short x, short y, short z) { + setArgument(argumentIndex, 2*3, wrap(x, y, z)); + return this; + } + + public CLKernel setArg(int argumentIndex, short x, short y, short z, short w) { + setArgument(argumentIndex, 2*4, wrap(x, y, z, w)); + return this; + } + public CLKernel setArg(int argumentIndex, int value) { setArgument(argumentIndex, 4, wrap(value)); return this; } + public CLKernel setArg(int argumentIndex, int x, int y) { + setArgument(argumentIndex, 4*2, wrap(x, y)); + return this; + } + + public CLKernel setArg(int argumentIndex, int x, int y, int z) { + setArgument(argumentIndex, 4*3, wrap(x, y, z)); + return this; + } + + public CLKernel setArg(int argumentIndex, int x, int y, int z, int w) { + setArgument(argumentIndex, 4*4, wrap(x, y, z, w)); + return this; + } + public CLKernel setArg(int argumentIndex, long value) { if(force32BitArgs) { setArgument(argumentIndex, 4, wrap((int)value)); @@ -196,11 +336,53 @@ public class CLKernel extends CLObjectResource implements Cloneable { return this; } + public CLKernel setArg(int argumentIndex, long x, long y) { + if(force32BitArgs) { + setArgument(argumentIndex, 4*2, wrap((int)x, (int)y)); + }else{ + setArgument(argumentIndex, 8*2, wrap(x, y)); + } + return this; + } + + public CLKernel setArg(int argumentIndex, long x, long y, long z) { + if(force32BitArgs) { + setArgument(argumentIndex, 4*3, wrap((int)x, (int)y, (int)z)); + }else{ + setArgument(argumentIndex, 8*3, wrap(x, y, z)); + } + return this; + } + + public CLKernel setArg(int argumentIndex, long x, long y, long z, long w) { + if(force32BitArgs) { + setArgument(argumentIndex, 4*4, wrap((int)x, (int)y, (int)z, (int)w)); + }else{ + setArgument(argumentIndex, 8*4, wrap(x, y, z, w)); + } + return this; + } + public CLKernel setArg(int argumentIndex, float value) { setArgument(argumentIndex, 4, wrap(value)); return this; } + public CLKernel setArg(int argumentIndex, float x, float y) { + setArgument(argumentIndex, 4*2, wrap(x, y)); + return this; + } + + public CLKernel setArg(int argumentIndex, float x, float y, float z) { + setArgument(argumentIndex, 4*3, wrap(x, y, z)); + return this; + } + + public CLKernel setArg(int argumentIndex, float x, float y, float z, float w) { + setArgument(argumentIndex, 4*4, wrap(x, y, z, w)); + return this; + } + public CLKernel setArg(int argumentIndex, double value) { if(force32BitArgs) { setArgument(argumentIndex, 4, wrap((float)value)); @@ -210,6 +392,33 @@ public class CLKernel extends CLObjectResource implements Cloneable { return this; } + public CLKernel setArg(int argumentIndex, double x, double y) { + if(force32BitArgs) { + setArgument(argumentIndex, 4*2, wrap((float)x, (float)y)); + }else{ + setArgument(argumentIndex, 8*2, wrap(x, y)); + } + return this; + } + + public CLKernel setArg(int argumentIndex, double x, double y, double z) { + if(force32BitArgs) { + setArgument(argumentIndex, 4*3, wrap((float)x, (float)y, (float)z)); + }else{ + setArgument(argumentIndex, 8*3, wrap(x, y, z)); + } + return this; + } + + public CLKernel setArg(int argumentIndex, double x, double y, double z, double w) { + if(force32BitArgs) { + setArgument(argumentIndex, 4*4, wrap((float)x, (float)y, (float)z, (float)w)); + }else{ + setArgument(argumentIndex, 8*4, wrap(x, y, z, w)); + } + return this; + } + public CLKernel setNullArg(int argumentIndex, int size) { setArgument(argumentIndex, size, null); return this; @@ -238,6 +447,8 @@ public class CLKernel extends CLObjectResource implements Cloneable { setArg(i, (Float)value); }else if(value instanceof Double) { setArg(i, (Double)value); + }else if(value instanceof Buffer) { + setArg(i, (Buffer)value); }else{ throw new IllegalArgumentException(value + " is not a valid argument."); } @@ -291,22 +502,82 @@ public class CLKernel extends CLObjectResource implements Cloneable { return buffer.putFloat(0, value); } + private Buffer wrap(float a, float b) { + return buffer.putFloat(0, a).putFloat(4, b); + } + + private Buffer wrap(float a, float b, float c) { + return buffer.putFloat(0, a).putFloat(4, b).putFloat(8, c); + } + + private Buffer wrap(float a, float b, float c, float d) { + return buffer.putFloat(0, a).putFloat(4, b).putFloat(8, c).putFloat(12, d); + } + private Buffer wrap(double value) { return buffer.putDouble(0, value); } + private Buffer wrap(double a, double b) { + return buffer.putDouble(0, a).putDouble(8, b); + } + + private Buffer wrap(double a, double b, double c) { + return buffer.putDouble(0, a).putDouble(8, b).putDouble(16, c); + } + + private Buffer wrap(double a, double b, double c, double d) { + return buffer.putDouble(0, a).putDouble(8, b).putDouble(16, c).putDouble(24, d); + } + private Buffer wrap(short value) { return buffer.putShort(0, value); } + private Buffer wrap(short a, short b) { + return buffer.putShort(0, a).putShort(2, b); + } + + private Buffer wrap(short a, short b, short c) { + return buffer.putShort(0, a).putShort(2, b).putShort(4, c); + } + + private Buffer wrap(short a, short b, short c, short d) { + return buffer.putShort(0, a).putShort(2, b).putShort(4, c).putShort(6, d); + } + private Buffer wrap(int value) { return buffer.putInt(0, value); } + private Buffer wrap(int a, int b) { + return buffer.putInt(0, a).putInt(4, b); + } + + private Buffer wrap(int a, int b, int c) { + return buffer.putInt(0, a).putInt(4, b).putInt(8, c); + } + + private Buffer wrap(int a, int b, int c, int d) { + return buffer.putInt(0, a).putInt(4, b).putInt(8, c).putInt(12, d); + } + private Buffer wrap(long value) { return buffer.putLong(0, value); } + private Buffer wrap(long a, long b) { + return buffer.putLong(0, a).putLong(8, b); + } + + private Buffer wrap(long a, long b, long c) { + return buffer.putLong(0, a).putLong(8, b).putLong(16, c); + } + + private Buffer wrap(long a, long b, long c, long d) { + return buffer.putLong(0, a).putLong(8, b).putLong(16, c).putLong(24, d); + } + /** * Returns the amount of local memory in bytes being used by a kernel. * This includes local memory that may be needed by an implementation to execute the kernel, diff --git a/test/com/jogamp/opencl/CLProgramTest.java b/test/com/jogamp/opencl/CLProgramTest.java index d083c770..3c8ef8ba 100644 --- a/test/com/jogamp/opencl/CLProgramTest.java +++ b/test/com/jogamp/opencl/CLProgramTest.java @@ -28,6 +28,7 @@ package com.jogamp.opencl; +import com.jogamp.common.nio.Buffers; import com.jogamp.opencl.util.CLBuildConfiguration; import com.jogamp.opencl.util.CLProgramConfiguration; import com.jogamp.opencl.CLProgram.Status; @@ -39,7 +40,9 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.nio.FloatBuffer; import java.util.Map; +import java.util.Random; import java.util.concurrent.CountDownLatch; import org.junit.Rule; import org.junit.Test; @@ -313,7 +316,76 @@ public class CLProgramTest { context.release(); } - } + } + + @Test + public void kernelVectorArgsTest() { + + String source = + "kernel void vector(global float * out,\n" + + " const float v1,\n" + + " const float2 v2,\n" + + "// const float3 v3,\n" // nv does not support float3 + + " const float4 v4,\n" + + " const float8 v8) {\n" + + " out[0] = v1;\n" + + + " out[1] = v2.x;\n" + + " out[2] = v2.y;\n" + + + " out[3] = v4.x;\n" + + " out[4] = v4.y;\n" + + " out[5] = v4.z;\n" + + " out[6] = v4.w;\n" + + + " out[ 7] = v8.s0;\n" + + " out[ 8] = v8.s1;\n" + + " out[ 9] = v8.s2;\n" + + " out[10] = v8.s3;\n" + + " out[11] = v8.s4;\n" + + " out[12] = v8.s5;\n" + + " out[13] = v8.s6;\n" + + " out[14] = v8.s7;\n" + + "}\n"; + + CLContext context = CLContext.create(); + + try{ + CLProgram program = context.createProgram(source).build(); + CLKernel kernel = program.createCLKernel("vector"); + + CLBuffer<FloatBuffer> buffer = context.createFloatBuffer(15, CLBuffer.Mem.WRITE_ONLY); + + final int seed = 7; + Random rnd = new Random(seed); + + kernel.putArg(buffer); + kernel.putArg(rnd.nextFloat()); + kernel.putArg(rnd.nextFloat(), rnd.nextFloat()); +// kernel.putArg(rnd.nextFloat(), rnd.nextFloat(), rnd.nextFloat()); // nv does not support float3 + kernel.putArg(rnd.nextFloat(), rnd.nextFloat(), rnd.nextFloat(), rnd.nextFloat()); + kernel.putArg(TestUtils.fillBuffer(Buffers.newDirectFloatBuffer(8), seed)); + + CLCommandQueue queue = context.getMaxFlopsDevice().createCommandQueue(); + queue.putTask(kernel).putReadBuffer(buffer, true); + + FloatBuffer out = buffer.getBuffer(); + + rnd = new Random(seed); + for(int i = 0; i < 7; i++) { + assertEquals(rnd.nextFloat(), out.get(), 0.01f); + } + + rnd = new Random(seed); + for(int i = 0; i < 8; i++) { + assertEquals(rnd.nextFloat(), out.get(), 0.01f); + } + + }finally{ + context.release(); + } + + } @Test public void createAllKernelsTest() { diff --git a/test/com/jogamp/opencl/TestUtils.java b/test/com/jogamp/opencl/TestUtils.java index bf1fd153..efe6855e 100644 --- a/test/com/jogamp/opencl/TestUtils.java +++ b/test/com/jogamp/opencl/TestUtils.java @@ -29,6 +29,7 @@ package com.jogamp.opencl; import java.nio.ByteBuffer; +import java.nio.FloatBuffer; import java.util.Random; import static java.lang.System.*; @@ -44,7 +45,7 @@ public class TestUtils { final static int NUM_ELEMENTS = 10000000; - public static final void fillBuffer(ByteBuffer buffer, int seed) { + public static ByteBuffer fillBuffer(ByteBuffer buffer, int seed) { Random rnd = new Random(seed); @@ -52,9 +53,23 @@ public class TestUtils { buffer.putInt(rnd.nextInt()); buffer.rewind(); + + return buffer; } - public static final int roundUp(int groupSize, int globalSize) { + public static FloatBuffer fillBuffer(FloatBuffer buffer, int seed) { + + Random rnd = new Random(seed); + + while(buffer.remaining() != 0) + buffer.put(rnd.nextFloat()); + + buffer.rewind(); + + return buffer; + } + + public static int roundUp(int groupSize, int globalSize) { int r = globalSize % groupSize; if (r == 0) { return globalSize; @@ -63,7 +78,7 @@ public class TestUtils { } } - public static final void checkIfEqual(ByteBuffer a, ByteBuffer b, int elements) { + public static void checkIfEqual(ByteBuffer a, ByteBuffer b, int elements) { for(int i = 0; i < elements; i++) { int aVal = a.getInt(); int bVal = b.getInt(); |