summaryrefslogtreecommitdiffstats
path: root/src/com/mbien/opencl/CLKernel.java
blob: a4ee9f8901a408ec25cbe5a0c909f3cb9f8dd4c8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package com.mbien.opencl;

import com.sun.gluegen.runtime.BufferFactory;
import com.sun.gluegen.runtime.CPU;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import static com.mbien.opencl.CLException.*;

/**
 *
 * @author Michael Bien
 */
public class CLKernel {

    /** Native kernel handle this object wraps. */
    public final long ID;

    /** Kernel function name, as reported by {@code CL_KERNEL_FUNCTION_NAME}. */
    public final String name;

    private final CLProgram program;
    private final CL cl;

    /**
     * Wraps an existing native kernel handle and queries its function name.
     *
     * @param program the program this kernel belongs to; its context supplies the CL binding
     * @param id the native kernel handle
     */
    CLKernel(CLProgram program, long id) {
        this.ID = id;
        this.program = program;
        this.cl = program.context.cl;

        long[] longArray = new long[1];

        // first query: only ask for the size (in bytes) of the name
        int ret = cl.clGetKernelInfo(ID, CL.CL_KERNEL_FUNCTION_NAME, 0, null, longArray, 0);
        checkForError(ret, "error while asking for kernel function name");

        ByteBuffer bb = ByteBuffer.allocate((int)longArray[0]).order(ByteOrder.nativeOrder());

        // second query: fetch the name itself into a buffer of exactly that size
        ret = cl.clGetKernelInfo(ID, CL.CL_KERNEL_FUNCTION_NAME, bb.capacity(), bb, null, 0);
        checkForError(ret, "error while asking for kernel function name");

        // Decode with a fixed charset so the result does not depend on the platform default.
        // trim() also strips any NUL terminator that was copied into the buffer.
        this.name = new String(bb.array(), 0, bb.capacity(), Charset.forName("UTF-8")).trim();
    }

    /**
     * Sets a memory-object argument of this kernel.
     *
     * @param argumentIndex index of the kernel argument
     * @param value the buffer whose handle is passed to the kernel
     * @return this kernel, for call chaining
     */
    public CLKernel setArg(int argumentIndex, CLBuffer<?> value) {
        long handle = value.ID;
        // The handle must be passed with the size of a native pointer. On 32-bit hosts the
        // handle is narrowed to an int first: handing over only the first 4 bytes of an
        // 8-byte buffer would send the high word (i.e. 0) on big-endian machines.
        Buffer arg = CPU.is32Bit() ? wrap((int) handle) : wrap(handle);
        return setArgument(argumentIndex, arg);
    }

    /**
     * Sets a 4-byte {@code int} argument of this kernel.
     *
     * @param argumentIndex index of the kernel argument
     * @param value the argument value
     * @return this kernel, for call chaining
     */
    public CLKernel setArg(int argumentIndex, int value) {
        return setArgument(argumentIndex, wrap(value));
    }

    /**
     * Sets an 8-byte {@code long} argument of this kernel.
     *
     * @param argumentIndex index of the kernel argument
     * @param value the argument value
     * @return this kernel, for call chaining
     */
    public CLKernel setArg(int argumentIndex, long value) {
        return setArgument(argumentIndex, wrap(value));
    }

    /**
     * Sets a 4-byte {@code float} argument of this kernel.
     *
     * @param argumentIndex index of the kernel argument
     * @param value the argument value
     * @return this kernel, for call chaining
     */
    public CLKernel setArg(int argumentIndex, float value) {
        return setArgument(argumentIndex, wrap(value));
    }

    /**
     * Sets an 8-byte {@code double} argument of this kernel.
     *
     * @param argumentIndex index of the kernel argument
     * @param value the argument value
     * @return this kernel, for call chaining
     */
    public CLKernel setArg(int argumentIndex, double value) {
        return setArgument(argumentIndex, wrap(value));
    }

    /**
     * Passes an already-wrapped argument to {@code clSetKernelArg}. The argument size is
     * taken from the buffer's capacity (in bytes), so it always matches the payload.
     */
    private CLKernel setArgument(int argumentIndex, Buffer value) {
        int ret = cl.clSetKernelArg(ID, argumentIndex, value.capacity(), value);
        checkForError(ret, "error on clSetKernelArg");
        return this;
    }

    /** Wraps a float in a new 4-byte direct buffer, rewound for reading. */
    private Buffer wrap(float value) {
        return BufferFactory.newDirectByteBuffer(4).putFloat(value).rewind();
    }

    /** Wraps a double in a new 8-byte direct buffer, rewound for reading. */
    private Buffer wrap(double value) {
        return BufferFactory.newDirectByteBuffer(8).putDouble(value).rewind();
    }

    /** Wraps an int in a new 4-byte direct buffer, rewound for reading. */
    private Buffer wrap(int value) {
        return BufferFactory.newDirectByteBuffer(4).putInt(value).rewind();
    }

    /** Wraps a long in a new 8-byte direct buffer, rewound for reading. */
    private Buffer wrap(long value) {
        return BufferFactory.newDirectByteBuffer(8).putLong(value).rewind();
    }

    /**
     * Releases all resources of this kernel from its context.
     * <p>
     * The owning program is notified before the native error code is checked, so the
     * program's bookkeeping is updated even if the native release reports an error.
     */
    public void release() {
        int ret = cl.clReleaseKernel(ID);
        program.onKernelReleased(this);
        checkForError(ret, "can not release kernel");
    }

    @Override
    public String toString() {
        return "CLKernel [id: " + ID
                      + " name: " + name+"]";
    }

    /**
     * Two kernels are equal when they wrap the same native handle and belong to equal programs.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null) {
            return false;
        }
        if (getClass() != obj.getClass()) {
            return false;
        }
        final CLKernel other = (CLKernel) obj;
        if (this.ID != other.ID) {
            return false;
        }
        if (!this.program.equals(other.program)) {
            return false;
        }
        return true;
    }

    @Override
    public int hashCode() {
        int hash = 7;
        hash = 43 * hash + (int) (this.ID ^ (this.ID >>> 32));
        hash = 43 * hash + (this.program != null ? this.program.hashCode() : 0);
        return hash;
    }

}