diff options
Diffstat (limited to 'test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java')
-rw-r--r-- | test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java | 107 |
1 files changed, 97 insertions, 10 deletions
diff --git a/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java b/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java index 7a1ed7aa..e5bcb1c5 100644 --- a/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java +++ b/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java @@ -3,14 +3,24 @@ */ package com.jogamp.opencl.util.concurrent; +import com.jogamp.common.nio.Buffers; +import com.jogamp.opencl.CLBuffer; +import com.jogamp.opencl.CLCommandQueue; import com.jogamp.opencl.CLContext; import com.jogamp.opencl.CLDevice; +import com.jogamp.opencl.CLKernel; import com.jogamp.opencl.CLPlatform; +import com.jogamp.opencl.util.concurrent.CLQueueContext.CLSimpleQueueContext; import com.jogamp.opencl.util.concurrent.CLQueueContextFactory.CLSimpleContextFactory; +import java.nio.IntBuffer; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; import org.junit.Rule; import org.junit.rules.MethodRule; import org.junit.rules.Timeout; import com.jogamp.opencl.util.CLMultiContext; +import java.nio.Buffer; +import java.util.ArrayList; import java.util.List; import org.junit.Test; @@ -52,27 +62,97 @@ public class CLMultiContextTest { } private final static String programSource = - "kernel void vectorAdd(global const int* a, global const int* b, global int* c, int iNumElements) { \n" - + " int iGID = get_global_id(0); \n" - + " if (iGID >= iNumElements) { \n" - + " return; \n" - + " } \n" - + " c[iGID] = a[iGID] + b[iGID]; \n" - + "} \n"; + "kernel void increment(global int* array, int numElements) { \n" + + " int index = get_global_id(0); \n" + + " if (index >= numElements) { \n" + + " return; \n" + + " } \n" + + " array[index]++; \n" + + "} \n"; + + private final class CLTestTask implements CLTask<CLSimpleQueueContext, Buffer> { + + private final Buffer data; + + public CLTestTask(Buffer buffer) { + this.data = buffer; + } + + public Buffer execute(CLSimpleQueueContext qc) { + + CLCommandQueue queue = qc.getQueue(); + CLContext context = qc.getCLContext(); + CLKernel kernel = qc.getKernel("increment"); + + CLBuffer<Buffer> buffer = null; + try{ + buffer = context.createBuffer(data); + int gws = buffer.getCLCapacity(); + + kernel.putArg(buffer).putArg(gws).rewind(); + + queue.putWriteBuffer(buffer, true); + queue.put1DRangeKernel(kernel, 0, gws, 0); + queue.putReadBuffer(buffer, true); + }finally{ + if(buffer != null) { + buffer.release(); + } + } + + return data; + } + + } @Test - public void commandQueuePoolTest() { + public void commandQueuePoolTest() throws InterruptedException, ExecutionException { CLMultiContext mc = CLMultiContext.create(CLPlatform.listCLPlatforms()); try { CLSimpleContextFactory factory = CLQueueContextFactory.createSimple(programSource); - CLCommandQueuePool pool = CLCommandQueuePool.create(factory, mc); + CLCommandQueuePool<CLSimpleQueueContext> pool = CLCommandQueuePool.create(factory, mc); assertTrue(pool.getSize() > 0); + + final int slice = 64; + final int tasksPerQueue = 10; + final int taskCount = pool.getSize() * tasksPerQueue; + + IntBuffer data = Buffers.newDirectIntBuffer(slice*taskCount); + + List<CLTestTask> tasks = new ArrayList<CLTestTask>(taskCount); + + for (int i = 0; i < taskCount; i++) { + IntBuffer subBuffer = Buffers.slice(data, i*slice, slice); + assertEquals(slice, subBuffer.capacity()); + tasks.add(new CLTestTask(subBuffer)); + } + + out.println("invoking "+tasks.size()+" tasks on "+pool.getSize()+" queues"); + + pool.invokeAll(tasks); + checkBuffer(1, data); + + + for (CLTestTask task : tasks) { + pool.submit(task).get(); + } + checkBuffer(2, data); + + + List<Future<Buffer>> futures = new ArrayList<Future<Buffer>>(taskCount); + for (CLTestTask task : tasks) { + futures.add(pool.submit(task)); + } + for (Future<Buffer> future : futures) { + future.get(); + } + checkBuffer(3, data); - pool.switchContext(factory); +// pool.switchContext(factory); pool.release(); }finally{ @@ -80,4 +160,11 @@ public class CLMultiContextTest { } } + private void checkBuffer(int expected, IntBuffer data) { + while(data.hasRemaining()) { + assertEquals(expected, data.get()); + } + data.rewind(); + } + } |