summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael Bien <[email protected]>2011-05-09 03:00:55 +0200
committerMichael Bien <[email protected]>2011-05-09 03:00:55 +0200
commitc59bc50229181ab9cb0e5012d7bb5caf2faa781f (patch)
tree62230d2d14861c14814d6bfcc98b7ee2e7c170fc
parentdedded707fc70fda3e40cf963d208202f8d6c42b (diff)
concurrent utils bugfixes and improvements.
- more utility methods - generics fixes - basic junit test for CLCommandQueuePool - javadoc and argument validation
-rw-r--r--src/com/jogamp/opencl/util/CLMultiContext.java11
-rw-r--r--src/com/jogamp/opencl/util/concurrent/CLCommandQueuePool.java69
-rw-r--r--src/com/jogamp/opencl/util/concurrent/CLQueueContext.java9
-rw-r--r--src/com/jogamp/opencl/util/concurrent/CLTask.java4
-rw-r--r--test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java107
5 files changed, 168 insertions, 32 deletions
diff --git a/src/com/jogamp/opencl/util/CLMultiContext.java b/src/com/jogamp/opencl/util/CLMultiContext.java
index f588fcef..f74c0a35 100644
--- a/src/com/jogamp/opencl/util/CLMultiContext.java
+++ b/src/com/jogamp/opencl/util/CLMultiContext.java
@@ -41,6 +41,13 @@ public class CLMultiContext implements CLResource {
* Creates a multi context with all devices of the specified platforms and types.
*/
public static CLMultiContext create(CLPlatform[] platforms, CLDevice.Type... types) {
+
+ if(platforms == null) {
+ throw new NullPointerException("platform list was null");
+ }else if(platforms.length == 0) {
+ throw new IllegalArgumentException("platform list was empty");
+ }
+
List<CLDevice> devices = new ArrayList<CLDevice>();
for (CLPlatform platform : platforms) {
devices.addAll(asList(platform.listCLDevices(types)));
@@ -54,6 +61,10 @@ public class CLMultiContext implements CLResource {
*/
public static CLMultiContext create(Collection<CLDevice> devices) {
+ if(devices.isEmpty()) {
+ throw new IllegalArgumentException("device list was empty");
+ }
+
Map<CLPlatform, List<CLDevice>> platformDevicesMap = filterPlatformConflicts(devices);
// create contexts
diff --git a/src/com/jogamp/opencl/util/concurrent/CLCommandQueuePool.java b/src/com/jogamp/opencl/util/concurrent/CLCommandQueuePool.java
index b80f09e6..a6bbe4d0 100644
--- a/src/com/jogamp/opencl/util/concurrent/CLCommandQueuePool.java
+++ b/src/com/jogamp/opencl/util/concurrent/CLCommandQueuePool.java
@@ -15,6 +15,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
/**
* A multithreaded fixed size pool of OpenCL command queues.
@@ -23,7 +24,7 @@ import java.util.concurrent.ThreadFactory;
* instead of {@link Callable}s.
* @author Michael Bien
*/
-public class CLCommandQueuePool implements CLResource {
+public class CLCommandQueuePool<C extends CLQueueContext> implements CLResource {
private List<CLQueueContext> contexts;
private ExecutorService excecutor;
@@ -55,11 +56,11 @@ public class CLCommandQueuePool implements CLResource {
this.excecutor = Executors.newFixedThreadPool(contexts.size(), new QueueThreadFactory(contexts));
}
- public static CLCommandQueuePool create(CLQueueContextFactory factory, CLMultiContext mc, CLCommandQueue.Mode... modes) {
+ public static <C extends CLQueueContext> CLCommandQueuePool<C> create(CLQueueContextFactory<C> factory, CLMultiContext mc, CLCommandQueue.Mode... modes) {
return create(factory, mc.getDevices(), modes);
}
- public static CLCommandQueuePool create(CLQueueContextFactory factory, Collection<CLDevice> devices, CLCommandQueue.Mode... modes) {
+ public static <C extends CLQueueContext> CLCommandQueuePool<C> create(CLQueueContextFactory<C> factory, Collection<CLDevice> devices, CLCommandQueue.Mode... modes) {
List<CLCommandQueue> queues = new ArrayList<CLCommandQueue>(devices.size());
for (CLDevice device : devices) {
queues.add(device.createCommandQueue(modes));
@@ -67,21 +68,43 @@ public class CLCommandQueuePool implements CLResource {
return create(factory, queues);
}
- public static CLCommandQueuePool create(CLQueueContextFactory factory, Collection<CLCommandQueue> queues) {
+ public static <C extends CLQueueContext> CLCommandQueuePool create(CLQueueContextFactory<C> factory, Collection<CLCommandQueue> queues) {
return new CLCommandQueuePool(factory, queues);
}
- public <R> Future<R> submit(CLTask<R> task) {
+ /**
+ * @see ExecutorService#submit(java.util.concurrent.Callable)
+ */
+ public <R> Future<R> submit(CLTask<? extends C, R> task) {
return excecutor.submit(new TaskWrapper(task, finishAction));
}
- public <R> List<Future<R>> invokeAll(Collection<CLTask<R>> tasks) throws InterruptedException {
- List<TaskWrapper<R>> wrapper = new ArrayList<TaskWrapper<R>>(tasks.size());
- for (CLTask<R> task : tasks) {
- wrapper.add(new TaskWrapper<R>(task, finishAction));
- }
+ /**
+ * @see ExecutorService#invokeAll(java.util.Collection)
+ */
+ public <R> List<Future<R>> invokeAll(Collection<? extends CLTask<? super C, R>> tasks) throws InterruptedException {
+ List<TaskWrapper<C, R>> wrapper = wrapTasks(tasks);
return excecutor.invokeAll(wrapper);
}
+
+ /**
+ * @see ExecutorService#invokeAll(java.util.Collection, long, java.util.concurrent.TimeUnit)
+ */
+ public <R> List<Future<R>> invokeAll(Collection<? extends CLTask<? super C, R>> tasks, long timeout, TimeUnit unit) throws InterruptedException {
+ List<TaskWrapper<C, R>> wrapper = wrapTasks(tasks);
+ return excecutor.invokeAll(wrapper, timeout, unit);
+ }
+
+ private <R> List<TaskWrapper<C, R>> wrapTasks(Collection<? extends CLTask<? super C, R>> tasks) {
+ List<TaskWrapper<C, R>> wrapper = new ArrayList<TaskWrapper<C, R>>(tasks.size());
+ for (CLTask<? super C, R> task : tasks) {
+ if(task == null) {
+ throw new NullPointerException("at least one task was null");
+ }
+ wrapper.add(new TaskWrapper<C, R>(task, finishAction));
+ }
+ return wrapper;
+ }
/**
* Switches the context of all queues - this operation can be expensive.
@@ -171,35 +194,41 @@ public class CLCommandQueuePool implements CLResource {
this.index = 0;
}
- public synchronized Thread newThread(Runnable r) {
+ public synchronized Thread newThread(Runnable runnable) {
+
+ SecurityManager sm = System.getSecurityManager();
+ ThreadGroup group = (sm != null)? sm.getThreadGroup() : Thread.currentThread().getThreadGroup();
+
CLQueueContext queue = context.get(index);
- return new QueueThread(queue, index++);
+ QueueThread thread = new QueueThread(group, runnable, queue, index++);
+ thread.setDaemon(true);
+
+ return thread;
}
}
private static class QueueThread extends Thread {
private final CLQueueContext context;
- public QueueThread(CLQueueContext context, int index) {
- super("queue-worker-thread-"+index+"["+context+"]");
+ public QueueThread(ThreadGroup group, Runnable runnable, CLQueueContext context, int index) {
+ super(group, runnable, "queue-worker-thread-"+index+"["+context+"]");
this.context = context;
- this.setDaemon(true);
}
}
- private static class TaskWrapper<T> implements Callable<T> {
+ private static class TaskWrapper<C extends CLQueueContext, R> implements Callable<R> {
- private final CLTask<T> task;
+ private final CLTask<? super C, R> task;
private final FinishAction mode;
- public TaskWrapper(CLTask<T> task, FinishAction mode) {
+ public TaskWrapper(CLTask<? super C, R> task, FinishAction mode) {
this.task = task;
this.mode = mode;
}
- public T call() throws Exception {
+ public R call() throws Exception {
CLQueueContext context = ((QueueThread)Thread.currentThread()).context;
- T result = task.run(context);
+ R result = task.execute((C)context);
if(mode.equals(FinishAction.FLUSH)) {
context.queue.flush();
}else if(mode.equals(FinishAction.FINISH)) {
diff --git a/src/com/jogamp/opencl/util/concurrent/CLQueueContext.java b/src/com/jogamp/opencl/util/concurrent/CLQueueContext.java
index 3956f93d..11b86889 100644
--- a/src/com/jogamp/opencl/util/concurrent/CLQueueContext.java
+++ b/src/com/jogamp/opencl/util/concurrent/CLQueueContext.java
@@ -4,6 +4,7 @@
package com.jogamp.opencl.util.concurrent;
import com.jogamp.opencl.CLCommandQueue;
+import com.jogamp.opencl.CLContext;
import com.jogamp.opencl.CLKernel;
import com.jogamp.opencl.CLProgram;
import com.jogamp.opencl.CLResource;
@@ -24,6 +25,10 @@ public abstract class CLQueueContext implements CLResource {
return queue;
}
+ public CLContext getCLContext() {
+ return queue.getContext();
+ }
+
public static class CLSimpleQueueContext extends CLQueueContext {
public final CLProgram program;
@@ -39,6 +44,10 @@ public abstract class CLQueueContext implements CLResource {
return kernels;
}
+ public CLKernel getKernel(String name) {
+ return kernels.get(name);
+ }
+
public CLProgram getProgram() {
return program;
}
diff --git a/src/com/jogamp/opencl/util/concurrent/CLTask.java b/src/com/jogamp/opencl/util/concurrent/CLTask.java
index ff0f7614..0cfd24a5 100644
--- a/src/com/jogamp/opencl/util/concurrent/CLTask.java
+++ b/src/com/jogamp/opencl/util/concurrent/CLTask.java
@@ -8,11 +8,11 @@ package com.jogamp.opencl.util.concurrent;
* A task executed on a command queue.
* @author Michael Bien
*/
-public interface CLTask<R> {
+public interface CLTask<C extends CLQueueContext, R> {
/**
* Runs the task on a queue and returns a result.
*/
- R run(CLQueueContext context);
+ R execute(C context);
}
diff --git a/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java b/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java
index 7a1ed7aa..e5bcb1c5 100644
--- a/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java
+++ b/test/com/jogamp/opencl/util/concurrent/CLMultiContextTest.java
@@ -3,14 +3,24 @@
*/
package com.jogamp.opencl.util.concurrent;
+import com.jogamp.common.nio.Buffers;
+import com.jogamp.opencl.CLBuffer;
+import com.jogamp.opencl.CLCommandQueue;
import com.jogamp.opencl.CLContext;
import com.jogamp.opencl.CLDevice;
+import com.jogamp.opencl.CLKernel;
import com.jogamp.opencl.CLPlatform;
+import com.jogamp.opencl.util.concurrent.CLQueueContext.CLSimpleQueueContext;
import com.jogamp.opencl.util.concurrent.CLQueueContextFactory.CLSimpleContextFactory;
+import java.nio.IntBuffer;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
import org.junit.Rule;
import org.junit.rules.MethodRule;
import org.junit.rules.Timeout;
import com.jogamp.opencl.util.CLMultiContext;
+import java.nio.Buffer;
+import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
@@ -52,27 +62,97 @@ public class CLMultiContextTest {
}
private final static String programSource =
- "kernel void vectorAdd(global const int* a, global const int* b, global int* c, int iNumElements) { \n"
- + " int iGID = get_global_id(0); \n"
- + " if (iGID >= iNumElements) { \n"
- + " return; \n"
- + " } \n"
- + " c[iGID] = a[iGID] + b[iGID]; \n"
- + "} \n";
+ "kernel void increment(global int* array, int numElements) { \n"
+ + " int index = get_global_id(0); \n"
+ + " if (index >= numElements) { \n"
+ + " return; \n"
+ + " } \n"
+ + " array[index]++; \n"
+ + "} \n";
+
+ private final class CLTestTask implements CLTask<CLSimpleQueueContext, Buffer> {
+
+ private final Buffer data;
+
+ public CLTestTask(Buffer buffer) {
+ this.data = buffer;
+ }
+
+ public Buffer execute(CLSimpleQueueContext qc) {
+
+ CLCommandQueue queue = qc.getQueue();
+ CLContext context = qc.getCLContext();
+ CLKernel kernel = qc.getKernel("increment");
+
+ CLBuffer<Buffer> buffer = null;
+ try{
+ buffer = context.createBuffer(data);
+ int gws = buffer.getCLCapacity();
+
+ kernel.putArg(buffer).putArg(gws).rewind();
+
+ queue.putWriteBuffer(buffer, true);
+ queue.put1DRangeKernel(kernel, 0, gws, 0);
+ queue.putReadBuffer(buffer, true);
+ }finally{
+ if(buffer != null) {
+ buffer.release();
+ }
+ }
+
+ return data;
+ }
+
+ }
@Test
- public void commandQueuePoolTest() {
+ public void commandQueuePoolTest() throws InterruptedException, ExecutionException {
CLMultiContext mc = CLMultiContext.create(CLPlatform.listCLPlatforms());
try {
CLSimpleContextFactory factory = CLQueueContextFactory.createSimple(programSource);
- CLCommandQueuePool pool = CLCommandQueuePool.create(factory, mc);
+ CLCommandQueuePool<CLSimpleQueueContext> pool = CLCommandQueuePool.create(factory, mc);
assertTrue(pool.getSize() > 0);
+
+ final int slice = 64;
+ final int tasksPerQueue = 10;
+ final int taskCount = pool.getSize() * tasksPerQueue;
+
+ IntBuffer data = Buffers.newDirectIntBuffer(slice*taskCount);
+
+ List<CLTestTask> tasks = new ArrayList<CLTestTask>(taskCount);
+
+ for (int i = 0; i < taskCount; i++) {
+ IntBuffer subBuffer = Buffers.slice(data, i*slice, slice);
+ assertEquals(slice, subBuffer.capacity());
+ tasks.add(new CLTestTask(subBuffer));
+ }
+
+ out.println("invoking "+tasks.size()+" tasks on "+pool.getSize()+" queues");
+
+ pool.invokeAll(tasks);
+ checkBuffer(1, data);
+
+
+ for (CLTestTask task : tasks) {
+ pool.submit(task).get();
+ }
+ checkBuffer(2, data);
+
+
+ List<Future<Buffer>> futures = new ArrayList<Future<Buffer>>(taskCount);
+ for (CLTestTask task : tasks) {
+ futures.add(pool.submit(task));
+ }
+ for (Future<Buffer> future : futures) {
+ future.get();
+ }
+ checkBuffer(3, data);
- pool.switchContext(factory);
+// pool.switchContext(factory);
pool.release();
}finally{
@@ -80,4 +160,11 @@ public class CLMultiContextTest {
}
}
+ private void checkBuffer(int expected, IntBuffer data) {
+ while(data.hasRemaining()) {
+ assertEquals(expected, data.get());
+ }
+ data.rewind();
+ }
+
}