package com.mbien.opencl; import com.mbien.opencl.util.CLUtil; import com.jogamp.gluegen.runtime.Buffers; import com.jogamp.gluegen.runtime.Platform; import com.jogamp.gluegen.runtime.Int64Buffer; import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; import static com.mbien.opencl.CLException.*; import static com.mbien.opencl.CL.*; /** * High level abstraction for an OpenCL Kernel. * A kernel is a function declared in a program. A kernel is identified by the kernel qualifier * applied to any function in a program. A kernel object encapsulates the specific kernel * function declared in a program and the argument values to be used when executing this * kernel function. * CLKernel is not threadsafe. * @see CLProgram#createCLKernel(java.lang.String) * @see CLProgram#createCLKernels() * @author Michael Bien */ public class CLKernel extends CLObject implements CLResource, Cloneable { public final String name; public final int numArgs; private final CLProgram program; private final ByteBuffer buffer; private int argIndex; private boolean force32BitArgs; CLKernel(CLProgram program, long id) { super(program.getContext(), id); this.program = program; this.buffer = Buffers.newDirectByteBuffer(8); Int64Buffer size = Int64Buffer.allocateDirect(1); // get function name int ret = cl.clGetKernelInfo(ID, CL_KERNEL_FUNCTION_NAME, 0, null, size); checkForError(ret, "error while asking for kernel function name"); ByteBuffer bb = ByteBuffer.allocateDirect((int)size.get(0)).order(ByteOrder.nativeOrder()); ret = cl.clGetKernelInfo(ID, CL_KERNEL_FUNCTION_NAME, bb.capacity(), bb, null); checkForError(ret, "error while asking for kernel function name"); this.name = CLUtil.clString2JavaString(bb, bb.capacity()); // get number of arguments ret = cl.clGetKernelInfo(ID, CL_KERNEL_NUM_ARGS, bb.capacity(), bb, null); checkForError(ret, "error while asking for number of function arguments."); numArgs = bb.getInt(0); } // public CLKernel putArg(Buffer value) { // setArg(argIndex++, value); // return this; // } public CLKernel putArg(CLMemory value) { setArg(argIndex++, value); return this; } public CLKernel putArg(int value) { setArg(argIndex++, value); return this; } public CLKernel putArg(long value) { setArg(argIndex++, value); return this; } public CLKernel putArg(float value) { setArg(argIndex++, value); return this; } public CLKernel putArg(double value) { setArg(argIndex++, value); return this; } public CLKernel putNullArg(int size) { setNullArg(argIndex++, size); return this; } public CLKernel putArgs(CLMemory... values) { setArgs(argIndex, values); argIndex += values.length; return this; } public CLKernel rewind() { argIndex = 0; return this; } // public CLKernel setArg(int argumentIndex, Buffer value) { // setArgument(argumentIndex, CLMemory.sizeOfBufferElem(value)*value.capacity(), value); // return this; // } public CLKernel setArg(int argumentIndex, CLMemory value) { setArgument(argumentIndex, Platform.is32Bit()?4:8, wrap(value.ID)); return this; } public CLKernel setArg(int argumentIndex, int value) { setArgument(argumentIndex, 4, wrap(value)); return this; } public CLKernel setArg(int argumentIndex, long value) { if(force32BitArgs) { setArgument(argumentIndex, 4, wrap((int)value)); }else{ setArgument(argumentIndex, 8, wrap(value)); } return this; } public CLKernel setArg(int argumentIndex, float value) { setArgument(argumentIndex, 4, wrap(value)); return this; } public CLKernel setArg(int argumentIndex, double value) { if(force32BitArgs) { setArgument(argumentIndex, 4, wrap((float)value)); }else{ setArgument(argumentIndex, 8, wrap(value)); } return this; } public CLKernel setNullArg(int argumentIndex, int size) { setArgument(argumentIndex, size, null); return this; } public CLKernel setArgs(CLMemory... values) { setArgs(0, values); return this; } private void setArgs(int startIndex, CLMemory... values) { for (int i = 0; i < values.length; i++) { setArg(i+startIndex, values[i]); } } private void setArgument(int argumentIndex, int size, Buffer value) { if(argumentIndex >= numArgs || argumentIndex < 0) { throw new IndexOutOfBoundsException("kernel "+ toString() +" has "+numArgs+ " arguments, can not set argument with index "+argumentIndex); } if(!program.isExecutable()) { throw new IllegalStateException("can not set program" + " arguments for a not executable program. "+program); } int ret = cl.clSetKernelArg(ID, argumentIndex, size, value); checkForError(ret, "error on clSetKernelArg"); } /** * Forces double and long arguments to be passed as float and int to the OpenCL kernel. * This can be used in applications which want to mix kernels with different floating point precision. */ public CLKernel setForce32BitArgs(boolean force) { this.force32BitArgs = force; return this; } public CLProgram getProgram() { return program; } /** * @see #setForce32BitArgs(boolean) */ public boolean isForce32BitArgsEnabled() { return force32BitArgs; } private Buffer wrap(float value) { return buffer.putFloat(value).rewind(); } private Buffer wrap(double value) { return buffer.putDouble(value).rewind(); } private Buffer wrap(int value) { return buffer.putInt(value).rewind(); } private Buffer wrap(long value) { return buffer.putLong(value).rewind(); } /** * Returns the amount of local memory in bytes being used by a kernel. * This includes local memory that may be needed by an implementation to execute the kernel, * variables declared inside the kernel with the __local address qualifier and local memory * to be allocated for arguments to the kernel declared as pointers with the __local address * qualifier and whose size is specified with clSetKernelArg. * If the local memory size, for any pointer argument to the kernel declared with * the __local address qualifier, is not specified, its size is assumed to be 0. */ public long getLocalMemorySize(CLDevice device) { return getWorkGroupInfo(device, CL_KERNEL_LOCAL_MEM_SIZE); } /** * Returns the work group size for this kernel on the given device. * This provides a mechanism for the application to query the work-group size * that can be used to execute a kernel on a specific device given by device. * The OpenCL implementation uses the resource requirements of the kernel * (register usage etc.) to determine what this work-group size should be. */ public long getWorkGroupSize(CLDevice device) { return getWorkGroupInfo(device, CL_KERNEL_WORK_GROUP_SIZE); } /** * Returns the work-group size specified by the __attribute__((reqd_work_gr oup_size(X, Y, Z))) qualifier in kernel sources. * If the work-group size is not specified using the above attribute qualifier new long[]{(0, 0, 0)} is returned. */ public long[] getCompileWorkGroupSize(CLDevice device) { int ret = cl.clGetKernelWorkGroupInfo(ID, device.ID, CL_KERNEL_COMPILE_WORK_GROUP_SIZE, 8*3, buffer, null); checkForError(ret, "error while asking for clGetKernelWorkGroupInfo"); return new long[] { buffer.getLong(0), buffer.getLong(1), buffer.getLong(2) }; } private long getWorkGroupInfo(CLDevice device, int flag) { int ret = cl.clGetKernelWorkGroupInfo(ID, device.ID, flag, 8, buffer, null); checkForError(ret, "error while asking for clGetKernelWorkGroupInfo"); return buffer.getLong(0); } /** * Releases all resources of this kernel from its context. */ public void release() { int ret = cl.clReleaseKernel(ID); program.onKernelReleased(this); checkForError(ret, "can not release kernel"); } public void close() { release(); } @Override public String toString() { return "CLKernel [id: " + ID + " name: " + name+"]"; } @Override public boolean equals(Object obj) { if (obj == null) { return false; } if (getClass() != obj.getClass()) { return false; } final CLKernel other = (CLKernel) obj; if (this.ID != other.ID) { return false; } if (!this.program.equals(other.program)) { return false; } return true; } @Override public int hashCode() { int hash = 7; hash = 43 * hash + (int) (this.ID ^ (this.ID >>> 32)); hash = 43 * hash + (this.program != null ? this.program.hashCode() : 0); return hash; } /** * Returns a new instance of this kernel with uninitialized arguments. */ @Override public CLKernel clone() { return program.createCLKernel(name).setForce32BitArgs(force32BitArgs); } }