From 6bd00879eec56c2753d84708f551557a2684904b Mon Sep 17 00:00:00 2001 From: Michael Bien Date: Mon, 11 Jul 2011 19:36:33 +0200 Subject: redesigned CLCommandQueuePool. --- .../opencl/util/concurrent/CLCommandQueuePool.java | 185 ++++++++++----------- .../opencl/util/concurrent/CLQueueContext.java | 14 +- .../util/concurrent/CLQueueContextFactory.java | 51 ------ src/com/jogamp/opencl/util/concurrent/CLTask.java | 23 ++- .../util/concurrent/CLTaskCompletionService.java | 8 +- .../opencl/util/concurrent/CLMultiContextTest.java | 40 +++-- 6 files changed, 147 insertions(+), 174 deletions(-) delete mode 100644 src/com/jogamp/opencl/util/concurrent/CLQueueContextFactory.java diff --git a/src/com/jogamp/opencl/util/concurrent/CLCommandQueuePool.java b/src/com/jogamp/opencl/util/concurrent/CLCommandQueuePool.java index eac3dc13..12bfba82 100644 --- a/src/com/jogamp/opencl/util/concurrent/CLCommandQueuePool.java +++ b/src/com/jogamp/opencl/util/concurrent/CLCommandQueuePool.java @@ -9,7 +9,10 @@ import com.jogamp.opencl.CLResource; import com.jogamp.opencl.util.CLMultiContext; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.BlockingQueue; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; @@ -30,63 +33,46 @@ import java.util.concurrent.TimeoutException; * instead of {@link Callable}s and provides a per-queue context for resource sharing across all tasks of one queue. * @author Michael Bien */ -public class CLCommandQueuePool implements CLResource { +public class CLCommandQueuePool implements CLResource { - private List contexts; private ThreadPoolExecutor excecutor; private FinishAction finishAction = FinishAction.DO_NOTHING; private boolean released; + private final List queues; - private CLCommandQueuePool(CLQueueContextFactory factory, Collection queues) { - this.contexts = initContexts(queues, factory); + private CLCommandQueuePool(Collection queues) { + this.queues = new ArrayList(queues); initExecutor(); } - private List initContexts(Collection queues, CLQueueContextFactory factory) { - List newContexts = new ArrayList(queues.size()); - - int index = 0; - for (CLCommandQueue queue : queues) { - - CLQueueContext old = null; - if(this.contexts != null && !this.contexts.isEmpty()) { - old = this.contexts.get(index++); - old.release(); - } - - newContexts.add(factory.setup(queue, old)); - } - return newContexts; - } - private void initExecutor() { BlockingQueue queue = new LinkedBlockingDeque(); - QueueThreadFactory factory = new QueueThreadFactory(contexts); - int size = contexts.size(); + QueueThreadFactory factory = new QueueThreadFactory(queues); + int size = queues.size(); this.excecutor = new CLThreadPoolExecutor(size, size, 0L, TimeUnit.MILLISECONDS, queue, factory); } - public static CLCommandQueuePool create(CLQueueContextFactory factory, CLMultiContext mc, CLCommandQueue.Mode... modes) { - return create(factory, mc.getDevices(), modes); + public static CLCommandQueuePool create(CLMultiContext mc, CLCommandQueue.Mode... modes) { + return create(mc.getDevices(), modes); } - public static CLCommandQueuePool create(CLQueueContextFactory factory, Collection devices, CLCommandQueue.Mode... modes) { + public static CLCommandQueuePool create(Collection devices, CLCommandQueue.Mode... modes) { List queues = new ArrayList(devices.size()); for (CLDevice device : devices) { queues.add(device.createCommandQueue(modes)); } - return create(factory, queues); + return create(queues); } - public static CLCommandQueuePool create(CLQueueContextFactory factory, Collection queues) { - return new CLCommandQueuePool(factory, queues); + public static CLCommandQueuePool create(Collection queues) { + return new CLCommandQueuePool(queues); } /** * Submits this task to the pool for execution returning its {@link Future}. * @see ExecutorService#submit(java.util.concurrent.Callable) */ - public Future submit(CLTask task) { + public Future submit(CLTask task) { return excecutor.submit(wrapTask(task)); } @@ -94,9 +80,9 @@ public class CLCommandQueuePool implements CLResource * Submits all tasks to the pool for execution and returns their {@link Future}. * Calls {@link #submit(com.jogamp.opencl.util.concurrent.CLTask)} for every task. */ - public List> submitAll(Collection> tasks) { + public List> submitAll(Collection> tasks) { List> futures = new ArrayList>(tasks.size()); - for (CLTask task : tasks) { + for (CLTask task : tasks) { futures.add(submit(task)); } return futures; @@ -106,8 +92,8 @@ public class CLCommandQueuePool implements CLResource * Submits all tasks to the pool for immediate execution (blocking) and returns their {@link Future} holding the result. * @see ExecutorService#invokeAll(java.util.Collection) */ - public List> invokeAll(Collection> tasks) throws InterruptedException { - List> wrapper = wrapTasks(tasks); + public List> invokeAll(Collection> tasks) throws InterruptedException { + List> wrapper = wrapTasks(tasks); return excecutor.invokeAll(wrapper); } @@ -115,8 +101,8 @@ public class CLCommandQueuePool implements CLResource * Submits all tasks to the pool for immediate execution (blocking) and returns their {@link Future} holding the result. * @see ExecutorService#invokeAll(java.util.Collection, long, java.util.concurrent.TimeUnit) */ - public List> invokeAll(Collection> tasks, long timeout, TimeUnit unit) throws InterruptedException { - List> wrapper = wrapTasks(tasks); + public List> invokeAll(Collection> tasks, long timeout, TimeUnit unit) throws InterruptedException { + List> wrapper = wrapTasks(tasks); return excecutor.invokeAll(wrapper, timeout, unit); } @@ -125,13 +111,13 @@ public class CLCommandQueuePool implements CLResource * All other unfinished but started tasks are cancelled. * @see ExecutorService#invokeAny(java.util.Collection) */ - public R invokeAny(Collection> tasks) throws InterruptedException, ExecutionException { - List> wrapper = wrapTasks(tasks); + public R invokeAny(Collection> tasks) throws InterruptedException, ExecutionException { + List> wrapper = wrapTasks(tasks); return excecutor.invokeAny(wrapper); } - /*public*/ CLTask takeCLTask() throws InterruptedException { - return ((CLFutureTask)excecutor.getQueue().take()).getCLTask(); + /*public*/ CLTask takeCLTask() throws InterruptedException { + return ((CLFutureTask)excecutor.getQueue().take()).getCLTask(); } /** @@ -139,47 +125,32 @@ public class CLCommandQueuePool implements CLResource * All other unfinished but started tasks are cancelled. * @see ExecutorService#invokeAny(java.util.Collection, long, java.util.concurrent.TimeUnit) */ - public R invokeAny(Collection> tasks, long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { - List> wrapper = wrapTasks(tasks); + public R invokeAny(Collection> tasks, long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + List> wrapper = wrapTasks(tasks); return excecutor.invokeAny(wrapper, timeout, unit); } - TaskWrapper wrapTask(CLTask task) { + TaskWrapper wrapTask(CLTask task) { return new TaskWrapper(task, finishAction); } - private List> wrapTasks(Collection> tasks) { - List> wrapper = new ArrayList>(tasks.size()); - for (CLTask task : tasks) { + private List> wrapTasks(Collection> tasks) { + List> wrapper = new ArrayList>(tasks.size()); + for (CLTask task : tasks) { if(task == null) { throw new NullPointerException("at least one task was null"); } - wrapper.add(new TaskWrapper(task, finishAction)); + wrapper.add(new TaskWrapper((CLTask)task, finishAction)); } return wrapper; } - - /** - * Switches the context of all queues - this operation can be expensive. - * Blocks until all tasks finish and sets up a new context for all queues. - * @return this - */ - public CLCommandQueuePool switchContext(CLQueueContextFactory factory) { - - excecutor.shutdown(); - finishQueues(); // just to be sure - - contexts = initContexts(getQueues(), factory); - initExecutor(); - return this; - } /** * Calls {@link CLCommandQueue#flush()} on all queues. */ public void flushQueues() { - for (CLQueueContext context : contexts) { - context.queue.flush(); + for (CLCommandQueue queue : queues) { + queue.flush(); } } @@ -187,8 +158,8 @@ public class CLCommandQueuePool implements CLResource * Calls {@link CLCommandQueue#finish()} on all queues. */ public void finishQueues() { - for (CLQueueContext context : contexts) { - context.queue.finish(); + for (CLCommandQueue queue : queues) { + queue.finish(); } } @@ -202,16 +173,11 @@ public class CLCommandQueuePool implements CLResource throw new RuntimeException(getClass().getSimpleName()+" already released"); } released = true; - excecutor.shutdownNow(); + excecutor.shutdownNow(); // threads will cleanup CL resources on exit try { excecutor.awaitTermination(Long.MAX_VALUE, TimeUnit.MILLISECONDS); } catch (InterruptedException ex) { throw new RuntimeException(ex); - }finally{ - for (CLQueueContext context : contexts) { - context.queue.finish().release(); - context.release(); - } } } @@ -223,18 +189,14 @@ public class CLCommandQueuePool implements CLResource * Returns the command queues used in this pool. */ public List getQueues() { - List queues = new ArrayList(contexts.size()); - for (CLQueueContext context : contexts) { - queues.add(context.queue); - } - return queues; + return Collections.unmodifiableList(queues); } /** * Returns the size of this pool (number of command queues). */ public int getPoolSize() { - return contexts.size(); + return queues.size(); } /** @@ -284,16 +246,16 @@ public class CLCommandQueuePool implements CLResource @Override public String toString() { - return getClass().getSimpleName()+" [queues: "+contexts.size()+" on finish: "+finishAction+"]"; + return getClass().getSimpleName()+" [queues: "+getPoolSize()+" on finish: "+finishAction+"]"; } private static class QueueThreadFactory implements ThreadFactory { - private final List context; + private final List queues; private int index; - private QueueThreadFactory(List queues) { - this.context = queues; + private QueueThreadFactory(List queues) { + this.queues = queues; this.index = 0; } @@ -303,7 +265,7 @@ public class CLCommandQueuePool implements CLResource SecurityManager sm = System.getSecurityManager(); ThreadGroup group = (sm != null) ? sm.getThreadGroup() : Thread.currentThread().getThreadGroup(); - CLQueueContext queue = context.get(index); + CLCommandQueue queue = queues.get(index); QueueThread thread = new QueueThread(group, runnable, queue, index++); thread.setDaemon(true); @@ -313,27 +275,52 @@ public class CLCommandQueuePool implements CLResource } private static class QueueThread extends Thread { - private final CLQueueContext context; - public QueueThread(ThreadGroup group, Runnable runnable, CLQueueContext context, int index) { - super(group, runnable, "queue-worker-thread-"+index+"["+context+"]"); - this.context = context; + + private final CLCommandQueue queue; + private final Map contextMap; + + public QueueThread(ThreadGroup group, Runnable runnable, CLCommandQueue queue, int index) { + super(group, runnable, "queue-worker-thread-"+index+"["+queue+"]"); + this.queue = queue; + this.contextMap = new HashMap(); + } + + @Override + public void run() { + super.run(); + //release threadlocal contexts + queue.finish(); + for (CLQueueContext context : contextMap.values()) { + context.release(); + } } + } - private static class TaskWrapper implements Callable { + private static class TaskWrapper implements Callable { - private final CLTask task; + private final CLTask task; private final FinishAction mode; - private TaskWrapper(CLTask task, FinishAction mode) { + private TaskWrapper(CLTask task, FinishAction mode) { this.task = task; this.mode = mode; } @Override public R call() throws Exception { - CLQueueContext context = ((QueueThread)Thread.currentThread()).context; - R result = task.execute((C)context); + + QueueThread thread = (QueueThread)Thread.currentThread(); + + final Object key = task.getContextKey(); + + CLQueueContext context = thread.contextMap.get(key); + if(context == null) { + context = task.createQueueContext(thread.queue); + thread.contextMap.put(key, context); + } + + R result = task.execute(context); if(mode.equals(FinishAction.FLUSH)) { context.queue.flush(); }else if(mode.equals(FinishAction.FINISH)) { @@ -344,16 +331,16 @@ public class CLCommandQueuePool implements CLResource } - private static class CLFutureTask extends FutureTask { + private static class CLFutureTask extends FutureTask { - private final TaskWrapper wrapper; + private final TaskWrapper wrapper; - public CLFutureTask(TaskWrapper wrapper) { + public CLFutureTask(TaskWrapper wrapper) { super(wrapper); this.wrapper = wrapper; } - public CLTask getCLTask() { + public CLTask getCLTask() { return wrapper.task; } @@ -366,9 +353,9 @@ public class CLCommandQueuePool implements CLResource } @Override - protected RunnableFuture newTaskFor(Callable callable) { - TaskWrapper wrapper = (TaskWrapper)callable; - return new CLFutureTask(wrapper); + protected RunnableFuture newTaskFor(Callable callable) { + TaskWrapper wrapper = (TaskWrapper)callable; + return new CLFutureTask(wrapper); } } diff --git a/src/com/jogamp/opencl/util/concurrent/CLQueueContext.java b/src/com/jogamp/opencl/util/concurrent/CLQueueContext.java index 9f92b9a3..93c2d226 100644 --- a/src/com/jogamp/opencl/util/concurrent/CLQueueContext.java +++ b/src/com/jogamp/opencl/util/concurrent/CLQueueContext.java @@ -40,17 +40,21 @@ public abstract class CLQueueContext implements CLResource { * A simple queue context holding a precompiled program and its kernels. * @author Michael Bien */ - public static class CLSimpleQueueContext extends CLQueueContext { + public static class CLSingleProgramQueueContext extends CLQueueContext { public final CLProgram program; public final Map kernels; - public CLSimpleQueueContext(CLCommandQueue queue, CLProgram program) { + public CLSingleProgramQueueContext(CLCommandQueue queue, CLProgram program) { super(queue); this.program = program; this.kernels = program.createCLKernels(); } + public CLSingleProgramQueueContext(CLCommandQueue queue, String... source) { + this(queue, queue.getContext().createProgram(source).build()); + } + public Map getKernels() { return kernels; } @@ -65,7 +69,11 @@ public abstract class CLQueueContext implements CLResource { @Override public void release() { - program.release(); + synchronized(program) { + if(!program.isReleased()) { + program.release(); + } + } } @Override diff --git a/src/com/jogamp/opencl/util/concurrent/CLQueueContextFactory.java b/src/com/jogamp/opencl/util/concurrent/CLQueueContextFactory.java deleted file mode 100644 index 58f389bf..00000000 --- a/src/com/jogamp/opencl/util/concurrent/CLQueueContextFactory.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Created onSaturday, May 07 2011 00:40 - */ -package com.jogamp.opencl.util.concurrent; - -import com.jogamp.opencl.CLCommandQueue; -import com.jogamp.opencl.CLProgram; -import com.jogamp.opencl.util.concurrent.CLQueueContext.CLSimpleQueueContext; - -/** - * Creates {@link CLQueueContext}s. - * @author Michael Bien - */ -public abstract class CLQueueContextFactory { - - /** - * Creates a new queue context for the given queue. - * @param old the old context or null. - */ - public abstract C setup(CLCommandQueue queue, CLQueueContext old); - - - /** - * Creates a simple context factory producing single program contexts. - * @param source sourcecode of a OpenCL program. - */ - public static CLSimpleContextFactory createSimple(String source) { - return new CLSimpleContextFactory(source); - } - - /** - * Creates {@link CLSimpleQueueContext}s containing a precompiled program. - * @author Michael Bien - */ - public static class CLSimpleContextFactory extends CLQueueContextFactory { - - private final String source; - - public CLSimpleContextFactory(String source) { - this.source = source; - } - - @Override - public CLSimpleQueueContext setup(CLCommandQueue queue, CLQueueContext old) { - CLProgram program = queue.getContext().createProgram(source).build(queue.getDevice()); - return new CLSimpleQueueContext(queue, program); - } - - } - -} diff --git a/src/com/jogamp/opencl/util/concurrent/CLTask.java b/src/com/jogamp/opencl/util/concurrent/CLTask.java index 0cfd24a5..04d433c8 100644 --- a/src/com/jogamp/opencl/util/concurrent/CLTask.java +++ b/src/com/jogamp/opencl/util/concurrent/CLTask.java @@ -3,16 +3,35 @@ */ package com.jogamp.opencl.util.concurrent; +import com.jogamp.opencl.CLCommandQueue; + /** * A task executed on a command queue. * @author Michael Bien */ -public interface CLTask { +public abstract class CLTask { + + + /** + * Creates a CLQueueContext for this task. A context may contain static resources + * like OpenCL program binaries or pre allocated buffers. A context can be used by an group + * of tasks identified by a common context key ({@link #getContextKey()}). This method + * won't be called if a context was already created by an previously executed task with the + * same context key as this task. + */ + public abstract C createQueueContext(CLCommandQueue queue); + + /** + * Returns the context key for this task. Default implementation returns {@link #getClass()}. + */ + public Object getContextKey() { + return getClass(); + } /** * Runs the task on a queue and returns a result. */ - R execute(C context); + public abstract R execute(C context); } diff --git a/src/com/jogamp/opencl/util/concurrent/CLTaskCompletionService.java b/src/com/jogamp/opencl/util/concurrent/CLTaskCompletionService.java index d1d26824..630ee1c7 100644 --- a/src/com/jogamp/opencl/util/concurrent/CLTaskCompletionService.java +++ b/src/com/jogamp/opencl/util/concurrent/CLTaskCompletionService.java @@ -15,7 +15,7 @@ import java.util.concurrent.TimeUnit; * @see CompletionService * @author Michael Bien */ -public class CLTaskCompletionService { +public class CLTaskCompletionService { private final ExecutorCompletionService service; private final CLCommandQueuePool pool; @@ -25,7 +25,7 @@ public class CLTaskCompletionService { * task execution and a LinkedBlockingQueue with the capacity of {@link Integer#MAX_VALUE} * as a completion queue. */ - public CLTaskCompletionService(CLCommandQueuePool pool) { + public CLTaskCompletionService(CLCommandQueuePool pool) { this.service = new ExecutorCompletionService(pool.getExcecutor()); this.pool = pool; } @@ -34,7 +34,7 @@ public class CLTaskCompletionService { * Creates an CLTaskCompletionService using the supplied pool for base * task execution the supplied queue as its completion queue. */ - public CLTaskCompletionService(CLCommandQueuePool pool, BlockingQueue queue) { + public CLTaskCompletionService(CLCommandQueuePool pool, BlockingQueue queue) { this.service = new ExecutorCompletionService(pool.getExcecutor(), queue); this.pool = pool; } @@ -44,7 +44,7 @@ public class CLTaskCompletionService { * results of the task. Upon completion, this task may be taken or polled. * @see CompletionService#submit(java.util.concurrent.Callable) */ - public Future submit(CLTask task) { + public Future submit(CLTask task) { return service.submit(pool.wrapTask(task)); } diff --git a/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java b/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java index 1b2575f5..48425d5e 100644 --- a/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java +++ b/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java @@ -10,8 +10,7 @@ import com.jogamp.opencl.CLContext; import com.jogamp.opencl.CLDevice; import com.jogamp.opencl.CLKernel; import com.jogamp.opencl.CLPlatform; -import com.jogamp.opencl.util.concurrent.CLQueueContext.CLSimpleQueueContext; -import com.jogamp.opencl.util.concurrent.CLQueueContextFactory.CLSimpleContextFactory; +import com.jogamp.opencl.util.concurrent.CLQueueContext.CLSingleProgramQueueContext; import java.nio.IntBuffer; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; @@ -69,16 +68,23 @@ public class CLMultiContextTest { + " array[index]++; \n" + "} \n"; - private final class CLTestTask implements CLTask { + private final class CLTestTask extends CLTask { private final IntBuffer data; + private final String source; - public CLTestTask(IntBuffer buffer) { + public CLTestTask(String source, IntBuffer buffer) { this.data = buffer; + this.source = source; } @Override - public IntBuffer execute(CLSimpleQueueContext qc) { + public CLSingleProgramQueueContext createQueueContext(CLCommandQueue queue) { + return new CLSingleProgramQueueContext(queue, source); + } + + @Override + public IntBuffer execute(CLSingleProgramQueueContext qc) { CLCommandQueue queue = qc.getQueue(); CLContext context = qc.getCLContext(); @@ -110,14 +116,19 @@ public class CLMultiContextTest { return data; } + @Override + public Object getContextKey() { + return source.hashCode(); + } + } - private List createTasks(IntBuffer data, int taskCount, int slice) { + private List createTasks(String source, IntBuffer data, int taskCount, int slice) { List tasks = new ArrayList(taskCount); for (int i = 0; i < taskCount; i++) { IntBuffer subBuffer = Buffers.slice(data, i*slice, slice); assertEquals(slice, subBuffer.capacity()); - tasks.add(new CLTestTask(subBuffer)); + tasks.add(new CLTestTask(source, subBuffer)); } return tasks; } @@ -129,8 +140,7 @@ public class CLMultiContextTest { try { - CLSimpleContextFactory factory = CLQueueContextFactory.createSimple(programSource); - CLCommandQueuePool pool = CLCommandQueuePool.create(factory, mc); + CLCommandQueuePool pool = CLCommandQueuePool.create(mc); assertTrue(pool.getPoolSize() > 0); @@ -139,7 +149,7 @@ public class CLMultiContextTest { final int taskCount = pool.getPoolSize() * tasksPerQueue; IntBuffer data = Buffers.newDirectIntBuffer(slice*taskCount); - List tasks = createTasks(data, taskCount, slice); + List tasks = createTasks(programSource, data, taskCount, slice); out.println("invoking "+tasks.size()+" tasks on "+pool.getPoolSize()+" queues"); @@ -166,8 +176,8 @@ public class CLMultiContextTest { checkBuffer(3, data); // switching contexts using different program - factory = CLQueueContextFactory.createSimple(programSource.replaceAll("\\+\\+", "--")); - pool.switchContext(factory); + final String decrementProgramSource = programSource.replaceAll("\\+\\+", "--"); + tasks = createTasks(decrementProgramSource, data, taskCount, slice); List> results2 = pool.invokeAll(tasks); assertNotNull(results2); checkBuffer(2, data); @@ -176,7 +186,7 @@ public class CLMultiContextTest { // we wait only for completion of a subset of tasks. // submit any data = Buffers.newDirectIntBuffer(slice*taskCount); - tasks = createTasks(data, taskCount, slice); + tasks = createTasks(decrementProgramSource, data, taskCount, slice); IntBuffer ret1 = pool.invokeAny(tasks); assertNotNull(ret1); @@ -185,9 +195,9 @@ public class CLMultiContextTest { // completionservice take/any test data = Buffers.newDirectIntBuffer(slice*taskCount); - tasks = createTasks(data, taskCount, slice); + tasks = createTasks(decrementProgramSource, data, taskCount, slice); - CLTaskCompletionService service = new CLTaskCompletionService(pool); + CLTaskCompletionService service = new CLTaskCompletionService(pool); for (CLTestTask task : tasks) { service.submit(task); } -- cgit v1.2.3