summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/com/mbien/opencl/CLBuffer.java70
-rw-r--r--src/com/mbien/opencl/CLContext.java61
-rw-r--r--src/com/mbien/opencl/CLDevice.java14
-rw-r--r--src/com/mbien/opencl/CLKernel.java93
-rw-r--r--src/com/mbien/opencl/CLPlatform.java22
-rw-r--r--src/com/mbien/opencl/CLProgram.java77
6 files changed, 328 insertions, 9 deletions
diff --git a/src/com/mbien/opencl/CLBuffer.java b/src/com/mbien/opencl/CLBuffer.java
new file mode 100644
index 00000000..0f6e34a4
--- /dev/null
+++ b/src/com/mbien/opencl/CLBuffer.java
@@ -0,0 +1,70 @@
+package com.mbien.opencl;
+
+import java.nio.ByteBuffer;
+import static com.mbien.opencl.CLException.*;
+
+/**
+ *
+ * @author Michael Bien
+ */
+public class CLBuffer {
+
+ public final ByteBuffer buffer;
+ public final long bufferID;
+
+ private final CLContext context;
+ private final CL cl;
+
+ CLBuffer(CLContext context, int flags, ByteBuffer directBuffer) {
+
+ if(!directBuffer.isDirect())
+ throw new IllegalArgumentException("buffer is not a direct buffer");
+
+ this.buffer = directBuffer;
+ this.context = context;
+ this.cl = context.cl;
+
+ int[] intArray = new int[1];
+
+ this.bufferID = cl.clCreateBuffer(context.contextID, flags, directBuffer.capacity(), null, intArray, 0);
+
+ checkForError(intArray[0], "can not create cl buffer");
+
+ }
+
+ public CLBuffer release() {
+ cl.clReleaseMemObject(bufferID);
+ context.bufferReleased(this);
+ return this;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ final CLBuffer other = (CLBuffer) obj;
+ if (this.buffer != other.buffer && (this.buffer == null || !this.buffer.equals(other.buffer))) {
+ return false;
+ }
+ if (this.context.contextID != other.context.contextID) {
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ int hash = 3;
+ hash = 29 * hash + (this.buffer != null ? this.buffer.hashCode() : 0);
+ hash = 29 * hash + (int) (this.context.contextID ^ (this.context.contextID >>> 32));
+ return hash;
+ }
+
+
+
+
+}
diff --git a/src/com/mbien/opencl/CLContext.java b/src/com/mbien/opencl/CLContext.java
index 601a2b79..9cb649d7 100644
--- a/src/com/mbien/opencl/CLContext.java
+++ b/src/com/mbien/opencl/CLContext.java
@@ -22,6 +22,7 @@ public final class CLContext {
private CLDevice[] devices;
private final List<CLProgram> programs;
+ private final List<CLBuffer> buffers;
static{
System.loadLibrary("gluegen-rt");
@@ -32,6 +33,7 @@ public final class CLContext {
private CLContext(long contextID) {
this.contextID = contextID;
this.programs = new ArrayList<CLProgram>();
+ this.buffers = new ArrayList<CLBuffer>();
}
/**
@@ -70,14 +72,32 @@ public final class CLContext {
return program;
}
+ public CLBuffer createBuffer(int flags, ByteBuffer directBuffer) {
+ CLBuffer buffer = new CLBuffer(this, flags, directBuffer);
+ buffers.add(buffer);
+ return buffer;
+ }
+
void programReleased(CLProgram program) {
programs.remove(program);
}
+ void bufferReleased(CLBuffer buffer) {
+ buffers.remove(buffer);
+ }
+
/**
* Releases the context and all resources.
*/
public CLContext release() {
+
+ //release all resources
+ while(!programs.isEmpty())
+ programs.get(0).release();
+
+ while(!buffers.isEmpty())
+ buffers.get(0).release();
+
int ret = cl.clReleaseContext(contextID);
checkForError(ret, "error releasing context");
return this;
@@ -86,10 +106,17 @@ public final class CLContext {
/**
* Returns a read only view of all programs associated with this context.
*/
- public List<CLProgram> getPrograms() {
+ public List<CLProgram> getCLPrograms() {
return Collections.unmodifiableList(programs);
}
+ /**
+ * Returns a read only view of all buffers associated with this context.
+ */
+ public List<CLBuffer> getCLBuffers() {
+ return Collections.unmodifiableList(buffers);
+ }
+
/**
* Gets the device with maximal FLOPS from this context.
@@ -157,7 +184,7 @@ public final class CLContext {
return devices;
}
- CLDevice getCLDevices(long dID) {
+ CLDevice getCLDevice(long dID) {
CLDevice[] deviceArray = getCLDevices();
for (int i = 0; i < deviceArray.length; i++) {
if(dID == deviceArray[i].deviceID)
@@ -198,4 +225,34 @@ public final class CLContext {
}
+ @Override
+ public String toString() {
+ return "CLContext [id: " + contextID
+ + " #devices: " + getCLDevices().length
+ + "]";
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ final CLContext other = (CLContext) obj;
+ if (this.contextID != other.contextID) {
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ int hash = 7;
+ hash = 23 * hash + (int) (this.contextID ^ (this.contextID >>> 32));
+ return hash;
+ }
+
+
}
diff --git a/src/com/mbien/opencl/CLDevice.java b/src/com/mbien/opencl/CLDevice.java
index d9f643ce..dcc9ee97 100644
--- a/src/com/mbien/opencl/CLDevice.java
+++ b/src/com/mbien/opencl/CLDevice.java
@@ -191,9 +191,17 @@ public final class CLDevice {
@Override
public boolean equals(Object obj) {
- if(obj != null && obj instanceof CLDevice)
- return ((CLDevice)obj).deviceID == deviceID;
- return false;
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ final CLDevice other = (CLDevice) obj;
+ if (this.deviceID != other.deviceID) {
+ return false;
+ }
+ return true;
}
@Override
diff --git a/src/com/mbien/opencl/CLKernel.java b/src/com/mbien/opencl/CLKernel.java
new file mode 100644
index 00000000..be5e03b6
--- /dev/null
+++ b/src/com/mbien/opencl/CLKernel.java
@@ -0,0 +1,93 @@
+package com.mbien.opencl;
+
+import com.sun.gluegen.runtime.BufferFactory;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import static com.mbien.opencl.CLException.*;
+
+/**
+ *
+ * @author Michael Bien
+ */
+public class CLKernel {
+
+ public final long kernelID;
+ public final String name;
+
+ private final CLProgram program;
+ private final CL cl;
+
+ CLKernel(CLProgram program, long id) {
+ this.kernelID = id;
+ this.program = program;
+ this.cl = program.context.cl;
+
+ long[] longArray = new long[1];
+
+ int ret = cl.clGetKernelInfo(kernelID, CL.CL_KERNEL_FUNCTION_NAME, 0, null, longArray, 0);
+ checkForError(ret, "error while asking for kernel function name");
+
+ ByteBuffer bb = ByteBuffer.allocate((int)longArray[0]).order(ByteOrder.nativeOrder());
+
+ ret = cl.clGetKernelInfo(kernelID, CL.CL_KERNEL_FUNCTION_NAME, bb.capacity(), bb, null, 0);
+ checkForError(ret, "error while asking for kernel function name");
+
+ this.name = new String(bb.array(), 0, (int)longArray[0]).trim();
+
+ }
+
+ public CLKernel setArg(int argumentIndex, int argumentSize, CLBuffer value) {
+ int ret = cl.clSetKernelArg(kernelID, argumentIndex, argumentSize, wrapLong(value.bufferID));
+ checkForError(ret, "error on clSetKernelArg");
+ return this;
+ }
+
+ public CLKernel setArg(int argumentIndex, int argumentSize, long value) {
+ int ret = cl.clSetKernelArg(kernelID, argumentIndex, argumentSize, wrapLong(value));
+ checkForError(ret, "error on clSetKernelArg");
+ return this;
+ }
+
+ private final ByteBuffer wrapLong(long value) {
+ return (ByteBuffer) BufferFactory.newDirectByteBuffer(8).putLong(value).rewind();
+ }
+
+ public CLKernel release() {
+ cl.clReleaseKernel(kernelID);
+ program.kernelReleased(this);
+ return this;
+ }
+
+ @Override
+ public String toString() {
+ return "CLKernel [id: " + kernelID
+ + " name: " + name+"]";
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ final CLKernel other = (CLKernel) obj;
+ if (this.kernelID != other.kernelID) {
+ return false;
+ }
+ if (!this.program.equals(other.program)) {
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ int hash = 7;
+ hash = 43 * hash + (int) (this.kernelID ^ (this.kernelID >>> 32));
+ hash = 43 * hash + (this.program != null ? this.program.hashCode() : 0);
+ return hash;
+ }
+
+}
diff --git a/src/com/mbien/opencl/CLPlatform.java b/src/com/mbien/opencl/CLPlatform.java
index a717564a..d9f8dd25 100644
--- a/src/com/mbien/opencl/CLPlatform.java
+++ b/src/com/mbien/opencl/CLPlatform.java
@@ -11,7 +11,7 @@ public final class CLPlatform {
/**
* OpenCL platform id for this platform.
*/
- public final long platformID;
+ public final long platformID;
private final CL cl;
@@ -94,7 +94,27 @@ public final class CLPlatform {
+" version:"+getVersion()+"]";
}
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ final CLPlatform other = (CLPlatform) obj;
+ if (this.platformID != other.platformID) {
+ return false;
+ }
+ return true;
+ }
+ @Override
+ public int hashCode() {
+ int hash = 7;
+ hash = 71 * hash + (int) (this.platformID ^ (this.platformID >>> 32));
+ return hash;
+ }
}
diff --git a/src/com/mbien/opencl/CLProgram.java b/src/com/mbien/opencl/CLProgram.java
index e93e7246..73fe8cac 100644
--- a/src/com/mbien/opencl/CLProgram.java
+++ b/src/com/mbien/opencl/CLProgram.java
@@ -2,6 +2,10 @@ package com.mbien.opencl;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
import static com.mbien.opencl.CLException.*;
/**
@@ -10,9 +14,12 @@ import static com.mbien.opencl.CLException.*;
*/
public class CLProgram {
- private final CLContext context;
- private final long programID;
+ public final CLContext context;
+ public final long programID;
+
private final CL cl;
+
+ private final Map<String, CLKernel> kernels;
public enum Status {
@@ -53,6 +60,8 @@ public class CLProgram {
this.cl = context.cl;
this.context = context;
+ this.kernels = new HashMap<String, CLKernel>();
+
int[] intArray = new int[1];
// Create the program
programID = cl.clCreateProgramWithSource(contextID, 1, new String[] {src}, new long[]{src.length()}, 0, intArray, 0);
@@ -92,10 +101,46 @@ public class CLProgram {
}
/**
+ * Returns all kernels of this program in a unmodifiable view of a map with the kernel function names as keys.
+ */
+ public Map<String, CLKernel> getCLKernels() {
+
+ if(kernels.isEmpty()) {
+
+ int[] intArray = new int[1];
+ int ret = cl.clCreateKernelsInProgram(programID, 0, null, 0, intArray, 0);
+ checkForError(ret, "can not create kernels for program");
+
+ long[] kernelIDs = new long[intArray[0]];
+ ret = cl.clCreateKernelsInProgram(programID, kernelIDs.length, kernelIDs, 0, null, 0);
+ checkForError(ret, "can not create kernels for program");
+
+ for (int i = 0; i < intArray[0]; i++) {
+ CLKernel kernel = new CLKernel(this, kernelIDs[i]);
+ kernels.put(kernel.name, kernel);
+ }
+ }
+
+ return Collections.unmodifiableMap(kernels);
+ }
+
+ void kernelReleased(CLKernel kernel) {
+ this.kernels.remove(kernel.name);
+ }
+
+ /**
* Releases this program.
* @return this
*/
public CLProgram release() {
+
+ if(!kernels.isEmpty()) {
+ String[] names = kernels.keySet().toArray(new String[kernels.size()]);
+ for (String name : names) {
+ kernels.get(name).release();
+ }
+ }
+
int ret = cl.clReleaseProgram(programID);
checkForError(ret, "can not release program");
context.programReleased(this);
@@ -119,7 +164,7 @@ public class CLProgram {
int count = bb.capacity() / 8; // TODO sizeof cl_device
CLDevice[] devices = new CLDevice[count];
for (int i = 0; i < count; i++) {
- devices[i] = context.getCLDevices(bb.getLong());
+ devices[i] = context.getCLDevice(bb.getLong());
}
return devices;
@@ -194,4 +239,30 @@ public class CLProgram {
return bb.getInt();
}
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ final CLProgram other = (CLProgram) obj;
+ if (this.programID != other.programID) {
+ return false;
+ }
+ if (!this.context.equals(other.context)) {
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ int hash = 7;
+ hash = 37 * hash + (this.context != null ? this.context.hashCode() : 0);
+ hash = 37 * hash + (int) (this.programID ^ (this.programID >>> 32));
+ return hash;
+ }
+
}