/* * 20:48 Sunday, February 28 2010 */ package com.jogamp.opencl.demos.radixsort; import com.jogamp.opencl.CLBuffer; import com.jogamp.opencl.CLCommandQueue; import com.jogamp.opencl.CLContext; import com.jogamp.opencl.CLDevice; import com.jogamp.opencl.CLPlatform; import java.io.IOException; import java.nio.IntBuffer; import java.util.Random; import static com.jogamp.opencl.CLMemory.Mem.*; import static java.lang.System.*; import static com.jogamp.opencl.CLDevice.Type.*; /** * GPU radix sort demo. * @author Michael Bien */ public class RadixSortDemo { public RadixSortDemo() throws IOException { CLContext context = null; try{ //single GPU setup context = CLContext.create(CLPlatform.getDefault().getMaxFlopsDevice(GPU)); CLDevice device = context.getDevices()[0]; CLCommandQueue queue = device.createCommandQueue(); int maxValue = Integer.MAX_VALUE; int samples = 10; int[] workgroupSizes = new int[] {128, 256}; // make sure workgroup sizes don't exceed device maximum int maxWorkgroupSize = device.getMaxWorkGroupSize(); for( int i = 0; i < workgroupSizes.length; ++i ) { if( workgroupSizes[i] > maxWorkgroupSize ) throw new RuntimeException("Workgroup size " + workgroupSizes[i] + " greater than device max of "+ maxWorkgroupSize); } int[] runs = new int[] { 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608 }; for (int i = 0; i < workgroupSizes.length; i++) { int workgroupSize = workgroupSizes[i]; out.println("\n = = = workgroup size: "+workgroupSize+" = = = "); for(int run = 0; run < runs.length; run++) { if( workgroupSize==128 && runs[run] >= 8388608 || workgroupSize==256 && runs[run] <= 32768) { continue; // we can only sort up to 4MB with wg size of 128 } int numElements = runs[run]; CLBuffer array = context.createIntBuffer(numElements, READ_WRITE); out.print("array size: " + array.getCLSize()/1000000.0f+"MB; "); out.println("elements: " + array.getCLCapacity()/1000+"K"); fillBuffer(array, maxValue); RadixSort radixSort = new RadixSort(queue, numElements, workgroupSize); for(int a = 0; a < samples; a++) { queue.finish(); long time = nanoTime(); queue.putWriteBuffer(array, false); radixSort.sort(array, numElements, 32); queue.putReadBuffer(array, true); out.println("time: " + (nanoTime() - time)/1000000.0f+"ms"); } out.print("snapshot: "); printSnapshot(array.getBuffer(), 20); out.println("validating..."); checkIfSorted(array.getBuffer()); out.println("values sorted"); array.release(); radixSort.release(); } } }finally{ if(context != null) { context.release(); } } } private void fillBuffer(CLBuffer array, int maxValue) { Random random = new Random(42); for (int n = 0; n < array.getBuffer().capacity(); n++) { int rnd = random.nextInt(maxValue); array.getBuffer().put(n, rnd); } } private void printSnapshot(IntBuffer buffer, int snapshot) { for(int i = 0; i < snapshot; i++) out.print(buffer.get() + ", "); out.println("...; " + buffer.remaining() + " more"); buffer.rewind(); } private void checkIfSorted(IntBuffer keys) { for (int i = 1; i < keys.capacity(); i++) { if (keys.get(i - 1) > keys.get(i)) { throw new RuntimeException("not sorted "+ keys.get(i - 1) +" !> "+ keys.get(i)); } } } public static void main(String[] args) throws IOException { new RadixSortDemo(); } }