diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/com/mbien/opencl/CLBuffer.java | 70 | ||||
-rw-r--r-- | src/com/mbien/opencl/CLContext.java | 61 | ||||
-rw-r--r-- | src/com/mbien/opencl/CLDevice.java | 14 | ||||
-rw-r--r-- | src/com/mbien/opencl/CLKernel.java | 93 | ||||
-rw-r--r-- | src/com/mbien/opencl/CLPlatform.java | 22 | ||||
-rw-r--r-- | src/com/mbien/opencl/CLProgram.java | 77 |
6 files changed, 328 insertions, 9 deletions
diff --git a/src/com/mbien/opencl/CLBuffer.java b/src/com/mbien/opencl/CLBuffer.java new file mode 100644 index 00000000..0f6e34a4 --- /dev/null +++ b/src/com/mbien/opencl/CLBuffer.java @@ -0,0 +1,70 @@ +package com.mbien.opencl; + +import java.nio.ByteBuffer; +import static com.mbien.opencl.CLException.*; + +/** + * + * @author Michael Bien + */ +public class CLBuffer { + + public final ByteBuffer buffer; + public final long bufferID; + + private final CLContext context; + private final CL cl; + + CLBuffer(CLContext context, int flags, ByteBuffer directBuffer) { + + if(!directBuffer.isDirect()) + throw new IllegalArgumentException("buffer is not a direct buffer"); + + this.buffer = directBuffer; + this.context = context; + this.cl = context.cl; + + int[] intArray = new int[1]; + + this.bufferID = cl.clCreateBuffer(context.contextID, flags, directBuffer.capacity(), null, intArray, 0); + + checkForError(intArray[0], "can not create cl buffer"); + + } + + public CLBuffer release() { + cl.clReleaseMemObject(bufferID); + context.bufferReleased(this); + return this; + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + final CLBuffer other = (CLBuffer) obj; + if (this.buffer != other.buffer && (this.buffer == null || !this.buffer.equals(other.buffer))) { + return false; + } + if (this.context.contextID != other.context.contextID) { + return false; + } + return true; + } + + @Override + public int hashCode() { + int hash = 3; + hash = 29 * hash + (this.buffer != null ? this.buffer.hashCode() : 0); + hash = 29 * hash + (int) (this.context.contextID ^ (this.context.contextID >>> 32)); + return hash; + } + + + + +} diff --git a/src/com/mbien/opencl/CLContext.java b/src/com/mbien/opencl/CLContext.java index 601a2b79..9cb649d7 100644 --- a/src/com/mbien/opencl/CLContext.java +++ b/src/com/mbien/opencl/CLContext.java @@ -22,6 +22,7 @@ public final class CLContext { private CLDevice[] devices; private final List<CLProgram> programs; + private final List<CLBuffer> buffers; static{ System.loadLibrary("gluegen-rt"); @@ -32,6 +33,7 @@ public final class CLContext { private CLContext(long contextID) { this.contextID = contextID; this.programs = new ArrayList<CLProgram>(); + this.buffers = new ArrayList<CLBuffer>(); } /** @@ -70,14 +72,32 @@ public final class CLContext { return program; } + public CLBuffer createBuffer(int flags, ByteBuffer directBuffer) { + CLBuffer buffer = new CLBuffer(this, flags, directBuffer); + buffers.add(buffer); + return buffer; + } + void programReleased(CLProgram program) { programs.remove(program); } + void bufferReleased(CLBuffer buffer) { + buffers.remove(buffer); + } + /** * Releases the context and all resources. */ public CLContext release() { + + //release all resources + while(!programs.isEmpty()) + programs.get(0).release(); + + while(!buffers.isEmpty()) + buffers.get(0).release(); + int ret = cl.clReleaseContext(contextID); checkForError(ret, "error releasing context"); return this; @@ -86,10 +106,17 @@ public final class CLContext { /** * Returns a read only view of all programs associated with this context. */ - public List<CLProgram> getPrograms() { + public List<CLProgram> getCLPrograms() { return Collections.unmodifiableList(programs); } + /** + * Returns a read only view of all buffers associated with this context. + */ + public List<CLBuffer> getCLBuffers() { + return Collections.unmodifiableList(buffers); + } + /** * Gets the device with maximal FLOPS from this context. @@ -157,7 +184,7 @@ public final class CLContext { return devices; } - CLDevice getCLDevices(long dID) { + CLDevice getCLDevice(long dID) { CLDevice[] deviceArray = getCLDevices(); for (int i = 0; i < deviceArray.length; i++) { if(dID == deviceArray[i].deviceID) @@ -198,4 +225,34 @@ public final class CLContext { } + @Override + public String toString() { + return "CLContext [id: " + contextID + + " #devices: " + getCLDevices().length + + "]"; + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + final CLContext other = (CLContext) obj; + if (this.contextID != other.contextID) { + return false; + } + return true; + } + + @Override + public int hashCode() { + int hash = 7; + hash = 23 * hash + (int) (this.contextID ^ (this.contextID >>> 32)); + return hash; + } + + } diff --git a/src/com/mbien/opencl/CLDevice.java b/src/com/mbien/opencl/CLDevice.java index d9f643ce..dcc9ee97 100644 --- a/src/com/mbien/opencl/CLDevice.java +++ b/src/com/mbien/opencl/CLDevice.java @@ -191,9 +191,17 @@ public final class CLDevice { @Override public boolean equals(Object obj) { - if(obj != null && obj instanceof CLDevice) - return ((CLDevice)obj).deviceID == deviceID; - return false; + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + final CLDevice other = (CLDevice) obj; + if (this.deviceID != other.deviceID) { + return false; + } + return true; } @Override diff --git a/src/com/mbien/opencl/CLKernel.java b/src/com/mbien/opencl/CLKernel.java new file mode 100644 index 00000000..be5e03b6 --- /dev/null +++ b/src/com/mbien/opencl/CLKernel.java @@ -0,0 +1,93 @@ +package com.mbien.opencl; + +import com.sun.gluegen.runtime.BufferFactory; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import static com.mbien.opencl.CLException.*; + +/** + * + * @author Michael Bien + */ +public class CLKernel { + + public final long kernelID; + public final String name; + + private final CLProgram program; + private final CL cl; + + CLKernel(CLProgram program, long id) { + this.kernelID = id; + this.program = program; + this.cl = program.context.cl; + + long[] longArray = new long[1]; + + int ret = cl.clGetKernelInfo(kernelID, CL.CL_KERNEL_FUNCTION_NAME, 0, null, longArray, 0); + checkForError(ret, "error while asking for kernel function name"); + + ByteBuffer bb = ByteBuffer.allocate((int)longArray[0]).order(ByteOrder.nativeOrder()); + + ret = cl.clGetKernelInfo(kernelID, CL.CL_KERNEL_FUNCTION_NAME, bb.capacity(), bb, null, 0); + checkForError(ret, "error while asking for kernel function name"); + + this.name = new String(bb.array(), 0, (int)longArray[0]).trim(); + + } + + public CLKernel setArg(int argumentIndex, int argumentSize, CLBuffer value) { + int ret = cl.clSetKernelArg(kernelID, argumentIndex, argumentSize, wrapLong(value.bufferID)); + checkForError(ret, "error on clSetKernelArg"); + return this; + } + + public CLKernel setArg(int argumentIndex, int argumentSize, long value) { + int ret = cl.clSetKernelArg(kernelID, argumentIndex, argumentSize, wrapLong(value)); + checkForError(ret, "error on clSetKernelArg"); + return this; + } + + private final ByteBuffer wrapLong(long value) { + return (ByteBuffer) BufferFactory.newDirectByteBuffer(8).putLong(value).rewind(); + } + + public CLKernel release() { + cl.clReleaseKernel(kernelID); + program.kernelReleased(this); + return this; + } + + @Override + public String toString() { + return "CLKernel [id: " + kernelID + + " name: " + name+"]"; + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + final CLKernel other = (CLKernel) obj; + if (this.kernelID != other.kernelID) { + return false; + } + if (!this.program.equals(other.program)) { + return false; + } + return true; + } + + @Override + public int hashCode() { + int hash = 7; + hash = 43 * hash + (int) (this.kernelID ^ (this.kernelID >>> 32)); + hash = 43 * hash + (this.program != null ? this.program.hashCode() : 0); + return hash; + } + +} diff --git a/src/com/mbien/opencl/CLPlatform.java b/src/com/mbien/opencl/CLPlatform.java index a717564a..d9f8dd25 100644 --- a/src/com/mbien/opencl/CLPlatform.java +++ b/src/com/mbien/opencl/CLPlatform.java @@ -11,7 +11,7 @@ public final class CLPlatform { /** * OpenCL platform id for this platform. */ - public final long platformID; + public final long platformID; private final CL cl; @@ -94,7 +94,27 @@ public final class CLPlatform { +" version:"+getVersion()+"]"; } + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + final CLPlatform other = (CLPlatform) obj; + if (this.platformID != other.platformID) { + return false; + } + return true; + } + @Override + public int hashCode() { + int hash = 7; + hash = 71 * hash + (int) (this.platformID ^ (this.platformID >>> 32)); + return hash; + } } diff --git a/src/com/mbien/opencl/CLProgram.java b/src/com/mbien/opencl/CLProgram.java index e93e7246..73fe8cac 100644 --- a/src/com/mbien/opencl/CLProgram.java +++ b/src/com/mbien/opencl/CLProgram.java @@ -2,6 +2,10 @@ package com.mbien.opencl; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; import static com.mbien.opencl.CLException.*; /** @@ -10,9 +14,12 @@ import static com.mbien.opencl.CLException.*; */ public class CLProgram { - private final CLContext context; - private final long programID; + public final CLContext context; + public final long programID; + private final CL cl; + + private final Map<String, CLKernel> kernels; public enum Status { @@ -53,6 +60,8 @@ public class CLProgram { this.cl = context.cl; this.context = context; + this.kernels = new HashMap<String, CLKernel>(); + int[] intArray = new int[1]; // Create the program programID = cl.clCreateProgramWithSource(contextID, 1, new String[] {src}, new long[]{src.length()}, 0, intArray, 0); @@ -92,10 +101,46 @@ public class CLProgram { } /** + * Returns all kernels of this program in a unmodifiable view of a map with the kernel function names as keys. + */ + public Map<String, CLKernel> getCLKernels() { + + if(kernels.isEmpty()) { + + int[] intArray = new int[1]; + int ret = cl.clCreateKernelsInProgram(programID, 0, null, 0, intArray, 0); + checkForError(ret, "can not create kernels for program"); + + long[] kernelIDs = new long[intArray[0]]; + ret = cl.clCreateKernelsInProgram(programID, kernelIDs.length, kernelIDs, 0, null, 0); + checkForError(ret, "can not create kernels for program"); + + for (int i = 0; i < intArray[0]; i++) { + CLKernel kernel = new CLKernel(this, kernelIDs[i]); + kernels.put(kernel.name, kernel); + } + } + + return Collections.unmodifiableMap(kernels); + } + + void kernelReleased(CLKernel kernel) { + this.kernels.remove(kernel.name); + } + + /** * Releases this program. * @return this */ public CLProgram release() { + + if(!kernels.isEmpty()) { + String[] names = kernels.keySet().toArray(new String[kernels.size()]); + for (String name : names) { + kernels.get(name).release(); + } + } + int ret = cl.clReleaseProgram(programID); checkForError(ret, "can not release program"); context.programReleased(this); @@ -119,7 +164,7 @@ public class CLProgram { int count = bb.capacity() / 8; // TODO sizeof cl_device CLDevice[] devices = new CLDevice[count]; for (int i = 0; i < count; i++) { - devices[i] = context.getCLDevices(bb.getLong()); + devices[i] = context.getCLDevice(bb.getLong()); } return devices; @@ -194,4 +239,30 @@ public class CLProgram { return bb.getInt(); } + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + final CLProgram other = (CLProgram) obj; + if (this.programID != other.programID) { + return false; + } + if (!this.context.equals(other.context)) { + return false; + } + return true; + } + + @Override + public int hashCode() { + int hash = 7; + hash = 37 * hash + (this.context != null ? this.context.hashCode() : 0); + hash = 37 * hash + (int) (this.programID ^ (this.programID >>> 32)); + return hash; + } + } |