summaryrefslogtreecommitdiffstats
path: root/src/com/jogamp/opencl/util/CLMultiContext.java
blob: 789ed0f5d1cdd15c22a437baa658474d72b44f83 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
/*
 * Created on Thursday, April 28 2011 22:10
 */
package com.jogamp.opencl.util;

import com.jogamp.opencl.CLContext;
import com.jogamp.opencl.CLDevice;
import com.jogamp.opencl.CLPlatform;
import com.jogamp.opencl.CLResource;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import static java.util.Arrays.*;
import static com.jogamp.opencl.CLDevice.Type.*;

/**
 * Utility for organizing multiple {@link CLContext}s.
 * <p>
 * Instances are obtained through the static {@code create(...)} and {@link #wrap(CLContext...)}
 * factory methods. When created from devices, one context is created per device.
 *
 * @author Michael Bien
 */
public class CLMultiContext implements CLResource {

    private final List<CLContext> contexts;
    private boolean released;

    private CLMultiContext() {
        contexts = new ArrayList<CLContext>();
    }

    /**
     * Creates a multi context with all devices of the specified platforms.
     *
     * @param platforms the platforms whose devices should be used
     * @return a new multi context
     * @throws NullPointerException if {@code platforms} is null
     * @throws IllegalArgumentException if {@code platforms} is empty or contains no devices
     */
    @SuppressWarnings("unchecked")
    public static CLMultiContext create(CLPlatform... platforms) {
        return create(platforms, CLDeviceFilters.type(ALL));
    }

    /**
     * Creates a multi context with all matching devices of the specified platforms.
     *
     * @param platforms the platforms whose devices should be used
     * @param filters device filters passed to {@link CLPlatform#listCLDevices}
     * @return a new multi context
     * @throws NullPointerException if {@code platforms} is null
     * @throws IllegalArgumentException if {@code platforms} is empty or no device matches
     */
    public static CLMultiContext create(CLPlatform[] platforms, Filter<CLDevice>... filters) {
        if (platforms == null) {
            // same message as the Collection overload instead of an anonymous NPE from asList
            throw new NullPointerException("platform list was null");
        }
        return create(Arrays.asList(platforms), filters);
    }

    /**
     * Creates a multi context with all matching devices of the specified platforms.
     *
     * @param platforms the platforms whose devices should be used
     * @param filters device filters passed to {@link CLPlatform#listCLDevices}
     * @return a new multi context
     * @throws NullPointerException if {@code platforms} is null
     * @throws IllegalArgumentException if {@code platforms} is empty or no device matches
     */
    public static CLMultiContext create(Collection<CLPlatform> platforms, Filter<CLDevice>... filters) {

        if (platforms == null) {
            throw new NullPointerException("platform list was null");
        } else if (platforms.isEmpty()) {
            throw new IllegalArgumentException("platform list was empty");
        }

        List<CLDevice> devices = new ArrayList<CLDevice>();
        for (CLPlatform platform : platforms) {
            devices.addAll(asList(platform.listCLDevices(filters)));
        }
        return create(devices);
    }

    /**
     * Creates a multi context with the specified devices.
     * The devices don't have to be from the same platform. If the same device (identified by
     * name) is reachable through several platforms, only the first platform's instance is used.
     * <p>
     * If creating any of the contexts fails, all contexts created so far are released before
     * the failure is propagated.
     *
     * @param devices the devices to create contexts for
     * @return a new multi context holding one context per (non-conflicting) device
     * @throws NullPointerException if {@code devices} is null
     * @throws IllegalArgumentException if {@code devices} is empty
     */
    public static CLMultiContext create(Collection<? extends CLDevice> devices) {

        if (devices == null) {
            throw new NullPointerException("device list was null");
        } else if (devices.isEmpty()) {
            throw new IllegalArgumentException("device list was empty");
        }

        Map<CLPlatform, List<CLDevice>> platformDevicesMap = filterPlatformConflicts(devices);

        CLMultiContext mc = new CLMultiContext();
        boolean complete = false;
        try {
            for (List<CLDevice> platformDevices : platformDevicesMap.values()) {
                // one context per device to workaround driver bugs
                for (CLDevice device : platformDevices) {
                    mc.contexts.add(CLContext.create(device));
                }
            }
            complete = true;
        } finally {
            if (!complete) {
                // a context creation failed: the caller never receives mc, so release what was
                // created so far instead of leaking it; the original exception keeps propagating
                mc.releaseAfterFailedCreation();
            }
        }

        return mc;
    }

    /**
     * Creates a multi context with specified contexts.
     *
     * @param contexts the contexts to wrap; ownership passes to the returned multi context,
     *                 i.e. {@link #release()} releases them
     * @return a new multi context wrapping the given contexts
     */
    public static CLMultiContext wrap(CLContext... contexts) {
        CLMultiContext mc = new CLMultiContext();
        mc.contexts.addAll(asList(contexts));
        return mc;
    }

    /**
     * filter devices; don't allow the same device to be used in more than one platform.
     * example: a CPU available via the AMD and Intel SDKs shouldn't end up in two contexts.
     * <p>
     * The returned map iterates platforms in the order in which their first device appears
     * in {@code devices}, so context creation order is deterministic.
     */
    private static Map<CLPlatform, List<CLDevice>> filterPlatformConflicts(Collection<? extends CLDevice> devices) {

        // FIXME: devicename-platform is used as unique device identifier - replace if we have something better

        Map<CLPlatform, List<CLDevice>> filtered = new LinkedHashMap<CLPlatform, List<CLDevice>>();
        // device name -> platform that claimed this name first
        Map<String, CLPlatform> used = new HashMap<String, CLPlatform>();

        for (CLDevice device : devices) {

            String name = device.getName();
            CLPlatform platform = device.getPlatform();
            CLPlatform owner = used.get(name);

            if (owner != null && !platform.equals(owner)) {
                // device already claimed by another platform - skip the duplicate
                continue;
            }

            // same platform may legitimately expose several devices with the same name
            List<CLDevice> platformDevices = filtered.get(platform);
            if (platformDevices == null) {
                platformDevices = new ArrayList<CLDevice>();
                filtered.put(platform, platformDevices);
            }
            platformDevices.add(device);
            used.put(name, platform);
        }
        return filtered;
    }

    /**
     * Best-effort cleanup used when {@link #create(Collection)} fails part way through.
     * Releases every context created so far and ignores secondary failures so they don't
     * mask the exception that aborted the creation.
     */
    private void releaseAfterFailedCreation() {
        released = true;
        for (CLContext context : contexts) {
            try {
                context.release();
            } catch (RuntimeException ignored) {
                // intentionally ignored: the creation failure being propagated is the relevant error
            }
        }
        contexts.clear();
    }

    /**
     * Releases all contexts.
     * <p>
     * Every context is released even if releasing one of them fails; in that case the first
     * failure is rethrown after all contexts have been processed.
     *
     * @throws IllegalStateException if this multi context was already released
     * @see CLContext#release()
     */
    @Override
    public void release() {
        if (released) {
            throw new IllegalStateException(getClass().getSimpleName()+" already released");
        }
        released = true;

        RuntimeException firstFailure = null;
        for (CLContext context : contexts) {
            try {
                context.release();
            } catch (RuntimeException ex) {
                // keep going: one failing context must not leak the remaining ones
                if (firstFailure == null) {
                    firstFailure = ex;
                }
            }
        }
        contexts.clear();

        if (firstFailure != null) {
            throw firstFailure;
        }
    }

    /**
     * Returns a read-only view of the contexts managed by this multi context.
     * The view is empty after {@link #release()}.
     */
    public List<CLContext> getContexts() {
        return Collections.unmodifiableList(contexts);
    }

    /**
     * Returns a list containing all devices used in this multi context.
     * The returned list is a new, independent copy.
     */
    public List<CLDevice> getDevices() {
        List<CLDevice> devices = new ArrayList<CLDevice>();
        for (CLContext context : contexts) {
            devices.addAll(asList(context.getDevices()));
        }
        return devices;
    }

    /**
     * Returns true once {@link #release()} has been called (or creation failed and cleaned up).
     */
    @Override
    public boolean isReleased() {
        return released;
    }

    @Override
    public String toString() {
        return getClass().getSimpleName()+" [" + contexts.size()+" contexts, "
                                               + getDevices().size()+ " devices]";
    }

}