diff options
author | Michael Bien <[email protected]> | 2011-05-09 03:00:55 +0200 |
---|---|---|
committer | Michael Bien <[email protected]> | 2011-05-09 03:00:55 +0200 |
commit | c59bc50229181ab9cb0e5012d7bb5caf2faa781f (patch) | |
tree | 62230d2d14861c14814d6bfcc98b7ee2e7c170fc | |
parent | dedded707fc70fda3e40cf963d208202f8d6c42b (diff) |
concurrent utils bugfixes and improvements.
- more utility methods
- generics fixes
- basic junit test for CLCommandQueuePool
- javadoc and argument validation
5 files changed, 168 insertions, 32 deletions
diff --git a/src/com/jogamp/opencl/util/CLMultiContext.java b/src/com/jogamp/opencl/util/CLMultiContext.java index f588fcef..f74c0a35 100644 --- a/src/com/jogamp/opencl/util/CLMultiContext.java +++ b/src/com/jogamp/opencl/util/CLMultiContext.java @@ -41,6 +41,13 @@ public class CLMultiContext implements CLResource { * Creates a multi context with all devices of the specified platforms and types. */ public static CLMultiContext create(CLPlatform[] platforms, CLDevice.Type... types) { + + if(platforms == null) { + throw new NullPointerException("platform list was null"); + }else if(platforms.length == 0) { + throw new IllegalArgumentException("platform list was empty"); + } + List<CLDevice> devices = new ArrayList<CLDevice>(); for (CLPlatform platform : platforms) { devices.addAll(asList(platform.listCLDevices(types))); @@ -54,6 +61,10 @@ public class CLMultiContext implements CLResource { */ public static CLMultiContext create(Collection<CLDevice> devices) { + if(devices.isEmpty()) { + throw new IllegalArgumentException("device list was empty"); + } + Map<CLPlatform, List<CLDevice>> platformDevicesMap = filterPlatformConflicts(devices); // create contexts diff --git a/src/com/jogamp/opencl/util/concurrent/CLCommandQueuePool.java b/src/com/jogamp/opencl/util/concurrent/CLCommandQueuePool.java index b80f09e6..a6bbe4d0 100644 --- a/src/com/jogamp/opencl/util/concurrent/CLCommandQueuePool.java +++ b/src/com/jogamp/opencl/util/concurrent/CLCommandQueuePool.java @@ -15,6 +15,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; /** * A multithreaded fixed size pool of OpenCL command queues. @@ -23,7 +24,7 @@ import java.util.concurrent.ThreadFactory; * instead of {@link Callable}s. * @author Michael Bien */ -public class CLCommandQueuePool implements CLResource { +public class CLCommandQueuePool<C extends CLQueueContext> implements CLResource { private List<CLQueueContext> contexts; private ExecutorService excecutor; @@ -55,11 +56,11 @@ public class CLCommandQueuePool implements CLResource { this.excecutor = Executors.newFixedThreadPool(contexts.size(), new QueueThreadFactory(contexts)); } - public static CLCommandQueuePool create(CLQueueContextFactory factory, CLMultiContext mc, CLCommandQueue.Mode... modes) { + public static <C extends CLQueueContext> CLCommandQueuePool<C> create(CLQueueContextFactory<C> factory, CLMultiContext mc, CLCommandQueue.Mode... modes) { return create(factory, mc.getDevices(), modes); } - public static CLCommandQueuePool create(CLQueueContextFactory factory, Collection<CLDevice> devices, CLCommandQueue.Mode... modes) { + public static <C extends CLQueueContext> CLCommandQueuePool<C> create(CLQueueContextFactory<C> factory, Collection<CLDevice> devices, CLCommandQueue.Mode... modes) { List<CLCommandQueue> queues = new ArrayList<CLCommandQueue>(devices.size()); for (CLDevice device : devices) { queues.add(device.createCommandQueue(modes)); @@ -67,21 +68,43 @@ public class CLCommandQueuePool implements CLResource { return create(factory, queues); } - public static CLCommandQueuePool create(CLQueueContextFactory factory, Collection<CLCommandQueue> queues) { + public static <C extends CLQueueContext> CLCommandQueuePool create(CLQueueContextFactory<C> factory, Collection<CLCommandQueue> queues) { return new CLCommandQueuePool(factory, queues); } - public <R> Future<R> submit(CLTask<R> task) { + /** + * @see ExecutorService#submit(java.util.concurrent.Callable) + */ + public <R> Future<R> submit(CLTask<? extends C, R> task) { return excecutor.submit(new TaskWrapper(task, finishAction)); } - public <R> List<Future<R>> invokeAll(Collection<CLTask<R>> tasks) throws InterruptedException { - List<TaskWrapper<R>> wrapper = new ArrayList<TaskWrapper<R>>(tasks.size()); - for (CLTask<R> task : tasks) { - wrapper.add(new TaskWrapper<R>(task, finishAction)); - } + /** + * @see ExecutorService#invokeAll(java.util.Collection) + */ + public <R> List<Future<R>> invokeAll(Collection<? extends CLTask<? super C, R>> tasks) throws InterruptedException { + List<TaskWrapper<C, R>> wrapper = wrapTasks(tasks); return excecutor.invokeAll(wrapper); } + + /** + * @see ExecutorService#invokeAll(java.util.Collection, long, java.util.concurrent.TimeUnit) + */ + public <R> List<Future<R>> invokeAll(Collection<? extends CLTask<? super C, R>> tasks, long timeout, TimeUnit unit) throws InterruptedException { + List<TaskWrapper<C, R>> wrapper = wrapTasks(tasks); + return excecutor.invokeAll(wrapper, timeout, unit); + } + + private <R> List<TaskWrapper<C, R>> wrapTasks(Collection<? extends CLTask<? super C, R>> tasks) { + List<TaskWrapper<C, R>> wrapper = new ArrayList<TaskWrapper<C, R>>(tasks.size()); + for (CLTask<? super C, R> task : tasks) { + if(task == null) { + throw new NullPointerException("at least one task was null"); + } + wrapper.add(new TaskWrapper<C, R>(task, finishAction)); + } + return wrapper; + } /** * Switches the context of all queues - this operation can be expensive. @@ -171,35 +194,41 @@ public class CLCommandQueuePool implements CLResource { this.index = 0; } - public synchronized Thread newThread(Runnable r) { + public synchronized Thread newThread(Runnable runnable) { + + SecurityManager sm = System.getSecurityManager(); + ThreadGroup group = (sm != null)? sm.getThreadGroup() : Thread.currentThread().getThreadGroup(); + CLQueueContext queue = context.get(index); - return new QueueThread(queue, index++); + QueueThread thread = new QueueThread(group, runnable, queue, index++); + thread.setDaemon(true); + + return thread; } } private static class QueueThread extends Thread { private final CLQueueContext context; - public QueueThread(CLQueueContext context, int index) { - super("queue-worker-thread-"+index+"["+context+"]"); + public QueueThread(ThreadGroup group, Runnable runnable, CLQueueContext context, int index) { + super(group, runnable, "queue-worker-thread-"+index+"["+context+"]"); this.context = context; - this.setDaemon(true); } } - private static class TaskWrapper<T> implements Callable<T> { + private static class TaskWrapper<C extends CLQueueContext, R> implements Callable<R> { - private final CLTask<T> task; + private final CLTask<? super C, R> task; private final FinishAction mode; - public TaskWrapper(CLTask<T> task, FinishAction mode) { + public TaskWrapper(CLTask<? super C, R> task, FinishAction mode) { this.task = task; this.mode = mode; } - public T call() throws Exception { + public R call() throws Exception { CLQueueContext context = ((QueueThread)Thread.currentThread()).context; - T result = task.run(context); + R result = task.execute((C)context); if(mode.equals(FinishAction.FLUSH)) { context.queue.flush(); }else if(mode.equals(FinishAction.FINISH)) { diff --git a/src/com/jogamp/opencl/util/concurrent/CLQueueContext.java b/src/com/jogamp/opencl/util/concurrent/CLQueueContext.java index 3956f93d..11b86889 100644 --- a/src/com/jogamp/opencl/util/concurrent/CLQueueContext.java +++ b/src/com/jogamp/opencl/util/concurrent/CLQueueContext.java @@ -4,6 +4,7 @@ package com.jogamp.opencl.util.concurrent; import com.jogamp.opencl.CLCommandQueue; +import com.jogamp.opencl.CLContext; import com.jogamp.opencl.CLKernel; import com.jogamp.opencl.CLProgram; import com.jogamp.opencl.CLResource; @@ -24,6 +25,10 @@ public abstract class CLQueueContext implements CLResource { return queue; } + public CLContext getCLContext() { + return queue.getContext(); + } + public static class CLSimpleQueueContext extends CLQueueContext { public final CLProgram program; @@ -39,6 +44,10 @@ public abstract class CLQueueContext implements CLResource { return kernels; } + public CLKernel getKernel(String name) { + return kernels.get(name); + } + public CLProgram getProgram() { return program; } diff --git a/src/com/jogamp/opencl/util/concurrent/CLTask.java b/src/com/jogamp/opencl/util/concurrent/CLTask.java index ff0f7614..0cfd24a5 100644 --- a/src/com/jogamp/opencl/util/concurrent/CLTask.java +++ b/src/com/jogamp/opencl/util/concurrent/CLTask.java @@ -8,11 +8,11 @@ package com.jogamp.opencl.util.concurrent; * A task executed on a command queue. * @author Michael Bien */ -public interface CLTask<R> { +public interface CLTask<C extends CLQueueContext, R> { /** * Runs the task on a queue and returns a result. */ - R run(CLQueueContext context); + R execute(C context); } diff --git a/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java b/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java index 7a1ed7aa..e5bcb1c5 100644 --- a/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java +++ b/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java @@ -3,14 +3,24 @@ */ package com.jogamp.opencl.util.concurrent; +import com.jogamp.common.nio.Buffers; +import com.jogamp.opencl.CLBuffer; +import com.jogamp.opencl.CLCommandQueue; import com.jogamp.opencl.CLContext; import com.jogamp.opencl.CLDevice; +import com.jogamp.opencl.CLKernel; import com.jogamp.opencl.CLPlatform; +import com.jogamp.opencl.util.concurrent.CLQueueContext.CLSimpleQueueContext; import com.jogamp.opencl.util.concurrent.CLQueueContextFactory.CLSimpleContextFactory; +import java.nio.IntBuffer; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; import org.junit.Rule; import org.junit.rules.MethodRule; import org.junit.rules.Timeout; import com.jogamp.opencl.util.CLMultiContext; +import java.nio.Buffer; +import java.util.ArrayList; import java.util.List; import org.junit.Test; @@ -52,27 +62,97 @@ public class CLMultiContextTest { } private final static String programSource = - "kernel void vectorAdd(global const int* a, global const int* b, global int* c, int iNumElements) { \n" - + " int iGID = get_global_id(0); \n" - + " if (iGID >= iNumElements) { \n" - + " return; \n" - + " } \n" - + " c[iGID] = a[iGID] + b[iGID]; \n" - + "} \n"; + "kernel void increment(global int* array, int numElements) { \n" + + " int index = get_global_id(0); \n" + + " if (index >= numElements) { \n" + + " return; \n" + + " } \n" + + " array[index]++; \n" + + "} \n"; + + private final class CLTestTask implements CLTask<CLSimpleQueueContext, Buffer> { + + private final Buffer data; + + public CLTestTask(Buffer buffer) { + this.data = buffer; + } + + public Buffer execute(CLSimpleQueueContext qc) { + + CLCommandQueue queue = qc.getQueue(); + CLContext context = qc.getCLContext(); + CLKernel kernel = qc.getKernel("increment"); + + CLBuffer<Buffer> buffer = null; + try{ + buffer = context.createBuffer(data); + int gws = buffer.getCLCapacity(); + + kernel.putArg(buffer).putArg(gws).rewind(); + + queue.putWriteBuffer(buffer, true); + queue.put1DRangeKernel(kernel, 0, gws, 0); + queue.putReadBuffer(buffer, true); + }finally{ + if(buffer != null) { + buffer.release(); + } + } + + return data; + } + + } @Test - public void commandQueuePoolTest() { + public void commandQueuePoolTest() throws InterruptedException, ExecutionException { CLMultiContext mc = CLMultiContext.create(CLPlatform.listCLPlatforms()); try { CLSimpleContextFactory factory = CLQueueContextFactory.createSimple(programSource); - CLCommandQueuePool pool = CLCommandQueuePool.create(factory, mc); + CLCommandQueuePool<CLSimpleQueueContext> pool = CLCommandQueuePool.create(factory, mc); assertTrue(pool.getSize() > 0); + + final int slice = 64; + final int tasksPerQueue = 10; + final int taskCount = pool.getSize() * tasksPerQueue; + + IntBuffer data = Buffers.newDirectIntBuffer(slice*taskCount); + + List<CLTestTask> tasks = new ArrayList<CLTestTask>(taskCount); + + for (int i = 0; i < taskCount; i++) { + IntBuffer subBuffer = Buffers.slice(data, i*slice, slice); + assertEquals(slice, subBuffer.capacity()); + tasks.add(new CLTestTask(subBuffer)); + } + + out.println("invoking "+tasks.size()+" tasks on "+pool.getSize()+" queues"); + + pool.invokeAll(tasks); + checkBuffer(1, data); + + + for (CLTestTask task : tasks) { + pool.submit(task).get(); + } + checkBuffer(2, data); + + + List<Future<Buffer>> futures = new ArrayList<Future<Buffer>>(taskCount); + for (CLTestTask task : tasks) { + futures.add(pool.submit(task)); + } + for (Future<Buffer> future : futures) { + future.get(); + } + checkBuffer(3, data); - pool.switchContext(factory); +// pool.switchContext(factory); pool.release(); }finally{ @@ -80,4 +160,11 @@ public class CLMultiContextTest { } } + private void checkBuffer(int expected, IntBuffer data) { + while(data.hasRemaining()) { + assertEquals(expected, data.get()); + } + data.rewind(); + } + } |