package com.jogamp.opencl.demos.fft; import com.jogamp.opencl.CLBuffer; import com.jogamp.opencl.CLCommandQueue; import com.jogamp.opencl.CLContext; import com.jogamp.opencl.CLDevice; import com.jogamp.opencl.CLKernel; import com.jogamp.opencl.CLMemory.Mem; import com.jogamp.opencl.CLPlatform; import com.jogamp.opencl.CLProgram; import com.jogamp.opencl.demos.fft.CLFFTPlan.InvalidContextException; import java.awt.BorderLayout; import java.awt.Dimension; import java.awt.Graphics; import java.awt.GridBagConstraints; import java.awt.GridBagLayout; import java.awt.Insets; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.image.BufferedImage; import java.awt.image.DataBufferByte; import java.awt.image.DataBufferInt; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.util.logging.Level; import java.util.logging.Logger; import javax.imageio.ImageIO; import javax.swing.BoxLayout; import javax.swing.ButtonGroup; import javax.swing.JButton; import javax.swing.JFileChooser; import javax.swing.JFrame; import javax.swing.JLabel; import javax.swing.JOptionPane; import javax.swing.JPanel; import javax.swing.JSlider; import javax.swing.JToggleButton; import javax.swing.SwingUtilities; import javax.swing.event.ChangeEvent; import javax.swing.event.ChangeListener; /** * Perform some user-controllable blur on an image. * @author notzed */ public class BlurTest implements Runnable, ChangeListener, ActionListener { public static void main(String[] args) { SwingUtilities.invokeLater(new BlurTest()); } boolean demo = false; // must be power of 2 and width must be multiple of 64 int width = 512; int height = 512; BufferedImage src; BufferedImage psf; BufferedImage dst; PaintView left; ImageView right; // JSlider sizex; JSlider sizey; JSlider angle; // JToggleButton blurButton; JToggleButton drawButton; public void run() { try { initCL(); } catch (Exception x) { System.out.println("failed to init cl"); x.printStackTrace(); System.exit(1); } JFileChooser fc = new JFileChooser(); BufferedImage img = null; while (img == null) { try { File file = null; if (true) { fc.setDialogTitle("Select Image File"); fc.setPreferredSize(new Dimension(500, 600)); if (fc.showOpenDialog(null) == JFileChooser.APPROVE_OPTION) { file = fc.getSelectedFile(); } else { System.exit(0); } } else { file = new File("/home/notzed/cat0.jpg"); } img = ImageIO.read(file); if (img == null) { JOptionPane.showMessageDialog(null, "Couldn't load file"); } } catch (IOException x) { JOptionPane.showMessageDialog(null, "Couldn't load file"); } } src = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB); dst = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB); psf = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY); // Ensure loaded image is in known format and size Graphics g = src.createGraphics(); g.drawImage(img, (width - img.getWidth()) / 2, (height - img.getHeight()) / 2, null); g.dispose(); JFrame win = new JFrame("Blur Demo"); win.setDefaultCloseOperation(win.EXIT_ON_CLOSE); JPanel main = new JPanel(); main.setLayout(new BorderLayout()); JPanel controls = new JPanel(); controls.setLayout(new GridBagLayout()); GridBagConstraints c0 = new GridBagConstraints(); c0.gridx = 0; c0.anchor = GridBagConstraints.BASELINE_LEADING; c0.ipadx = 3; c0.insets = new Insets(1, 2, 1, 2); controls.add(new JLabel("Width"), c0); controls.add(new JLabel("Height"), c0); GridBagConstraints c2 = (GridBagConstraints) c0.clone(); c2.gridx = 2; controls.add(new JLabel("Angle"), c2); c0 = (GridBagConstraints) c0.clone(); c0.gridx = 1; c0.weightx = 1; c0.fill = GridBagConstraints.HORIZONTAL; sizex = new JSlider(100, 5000, 1000); sizey = new JSlider(100, 5000, 100); controls.add(sizex, c0); controls.add(sizey, c0); c2 = (GridBagConstraints) c0.clone(); c2.gridx = 3; angle = new JSlider(0, (int) (Math.PI * 1000)); controls.add(angle, c2); sizex.addChangeListener(this); sizey.addChangeListener(this); angle.addChangeListener(this); JPanel buttons = new JPanel(); controls.add(buttons, c2); JButton b; b = new JButton("Clear"); buttons.add(b); b.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { doclear(); } }); ButtonGroup opt = new ButtonGroup(); JToggleButton tb; blurButton = new JToggleButton("Blur"); opt.add(blurButton); buttons.add(blurButton); blurButton.addActionListener(this); drawButton = new JToggleButton("Draw"); opt.add(drawButton); buttons.add(drawButton); drawButton.addActionListener(this); JPanel imgs = new JPanel(); imgs.setLayout(new BoxLayout(imgs, BoxLayout.X_AXIS)); left = new PaintView(this, psf); right = new ImageView(dst); imgs.add(left); imgs.add(right); main.add(controls, BorderLayout.NORTH); main.add(imgs, BorderLayout.CENTER); win.getContentPane().add(main); win.pack(); win.setVisible(true); // pre-load and transform src, since that wont change loadSource(src); blurButton.doClick(); } public void stateChanged(ChangeEvent e) { if (drawButton.isSelected()) { recalc(); } else { double w = sizex.getValue() / 100.0; double h = sizey.getValue() / 100.0; double a = angle.getValue() / 1000.0; Graphics g = psf.createGraphics(); g.clearRect(0, 0, width, height); g.dispose(); left.drawDot(w, h, a); } } public void actionPerformed(ActionEvent e) { stateChanged(null); } private void doclear() { Graphics g = psf.createGraphics(); g.clearRect(0, 0, width, height); g.dispose(); left.repaint(); recalc(); } private void dorecalc() { loadPSF(psf); // convolve each plane in freq domain convolve(aCBuffer, psfBuffer, aGBuffer); convolve(rCBuffer, psfBuffer, rGBuffer); convolve(gCBuffer, psfBuffer, gGBuffer); convolve(bCBuffer, psfBuffer, bGBuffer); // convert back to spatial domain fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Inverse, aGBuffer, aBuffer, null, null); fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Inverse, rGBuffer, rBuffer, null, null); fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Inverse, gGBuffer, gBuffer, null, null); fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Inverse, bGBuffer, bBuffer, null, null); // while gpu is running, calculate energy of psf float scale; long total = 0; DataBufferByte pd = (DataBufferByte) psf.getRaster().getDataBuffer(); byte[] data = pd.getData(); for (int i = 0; i < data.length; i++) { total += data[i] & 0xff; } scale = 255.0f / total / width / height; getDestination(argbBuffer, aBuffer, rBuffer, gBuffer, bBuffer, scale); // drop back to java, slow-crappy-method q.putReadBuffer(argbBuffer, true); DataBufferInt db = (DataBufferInt) dst.getRaster().getDataBuffer(); argbBuffer.getBuffer().position(0); argbBuffer.getBuffer().get(db.getData()); argbBuffer.getBuffer().position(0); right.repaint(); } Runnable later; void recalc() { if (later == null) { later = new Runnable() { public void run() { later = null; dorecalc(); } }; SwingUtilities.invokeLater(later); } } CLContext cl; CLCommandQueue q; CLProgram prog; CLKernel kImg2Planes; CLKernel kPlanes2Img; CLKernel kGrey2Plane; CLKernel kConvolve; CLKernel kDeconvolve; CLFFTPlan fft; CLBuffer argbBuffer; CLBuffer greyBuffer; CLBuffer aBuffer; CLBuffer rBuffer; CLBuffer gBuffer; CLBuffer bBuffer; CLBuffer aCBuffer; CLBuffer rCBuffer; CLBuffer gCBuffer; CLBuffer bCBuffer; CLBuffer aGBuffer; CLBuffer rGBuffer; CLBuffer gGBuffer; CLBuffer bGBuffer; CLBuffer psfBuffer; CLBuffer tmpBuffer; // CLKernel fft512; void initCL() throws InvalidContextException { // search a platform with a GPU CLPlatform[] platforms = CLPlatform.listCLPlatforms(); CLDevice gpu = null; for (CLPlatform platform : platforms) { gpu = platform.getMaxFlopsDevice(CLDevice.Type.GPU); if(gpu != null) { break; } } cl = CLContext.create(gpu); q = cl.getDevices()[0].createCommandQueue(); prog = cl.createProgram(img2Planes + planes2Img + convolve + grey2Plane + deconvolve); prog.build("-cl-mad-enable"); kImg2Planes = prog.createCLKernel("img2planes"); kPlanes2Img = prog.createCLKernel("planes2img"); kGrey2Plane = prog.createCLKernel("grey2plane"); kConvolve = prog.createCLKernel("convolve"); kDeconvolve = prog.createCLKernel("deconvolve"); argbBuffer = cl.createIntBuffer(width * height, Mem.READ_WRITE); greyBuffer = cl.createByteBuffer(width * height, Mem.READ_WRITE); aBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); rBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); gBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); bBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); psfBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); tmpBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); aCBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); rCBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); gCBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); bCBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); aGBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); rGBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); gGBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); bGBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE); if (false) { try { CLProgram p = cl.createProgram(new FileInputStream("/home/notzed/cl/fft-512.cl")); p.build(); fft512 = p.createCLKernel("fft0"); } catch (IOException ex) { Logger.getLogger(BlurTest.class.getName()).log(Level.SEVERE, null, ex); } } else { fft = new CLFFTPlan(cl, new int[]{width, height}, CLFFTPlan.CLFFTDataFormat.InterleavedComplexFormat); } //fft.dumpPlan(null); } void loadSource(BufferedImage src) { DataBufferInt sb = (DataBufferInt) src.getRaster().getDataBuffer(); argbBuffer.getBuffer().position(0); argbBuffer.getBuffer().put(sb.getData()); argbBuffer.getBuffer().position(0); q.putWriteBuffer(argbBuffer, false); kImg2Planes.setArg(0, argbBuffer); kImg2Planes.setArg(1, 0); kImg2Planes.setArg(2, width); kImg2Planes.setArg(3, aBuffer); kImg2Planes.setArg(4, rBuffer); kImg2Planes.setArg(5, gBuffer); kImg2Planes.setArg(6, bBuffer); kImg2Planes.setArg(7, 0); kImg2Planes.setArg(8, width); q.put2DRangeKernel(kImg2Planes, 0, 0, width, height, 64, 1); q.finish(); fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, aBuffer, aCBuffer, null, null); fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, rBuffer, rCBuffer, null, null); fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, gBuffer, gCBuffer, null, null); fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, bBuffer, bCBuffer, null, null); } void loadPSF(BufferedImage psf) { assert (psf.getType() == BufferedImage.TYPE_BYTE_GRAY); DataBufferByte pb = (DataBufferByte) psf.getRaster().getDataBuffer(); greyBuffer.getBuffer().position(0); greyBuffer.getBuffer().put(pb.getData()); greyBuffer.getBuffer().position(0); q.putWriteBuffer(greyBuffer, false); kGrey2Plane.setArg(0, greyBuffer); kGrey2Plane.setArg(1, 0); kGrey2Plane.setArg(2, width); kGrey2Plane.setArg(3, tmpBuffer); kGrey2Plane.setArg(4, 0); kGrey2Plane.setArg(5, width); q.put2DRangeKernel(kGrey2Plane, 0, 0, width, height, 64, 1); if (true) { fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, tmpBuffer, psfBuffer, null, null); } else if (true) { fft512.setArg(0, tmpBuffer); fft512.setArg(1, psfBuffer); fft512.setArg(2, -1); fft512.setArg(3, height); //q.put1DRangeKernel(fft512, 0,height*64, 64); q.put2DRangeKernel(fft512, 0, 0, height * 64, 1, 64, 1); System.out.println("running kernel " + 64 * height + ", " + 64); } } // g = f x h void convolve(CLBuffer h, CLBuffer f, CLBuffer g) { kConvolve.setArg(0, h); kConvolve.setArg(1, f); kConvolve.setArg(2, g); kConvolve.setArg(3, width); q.put2DRangeKernel(kConvolve, 0, 0, width, height, 64, 1); } // g = h*conj(f) / (abs(f)^2 + k) void deconvolve(CLBuffer h, CLBuffer f, CLBuffer g, float k) { kDeconvolve.setArg(0, h); kDeconvolve.setArg(1, f); kDeconvolve.setArg(2, g); kDeconvolve.setArg(3, width); kDeconvolve.setArg(4, k); q.put2DRangeKernel(kDeconvolve, 0, 0, width, height, 64, 1); } void getDestination(CLBuffer dst, CLBuffer a, CLBuffer r, CLBuffer g, CLBuffer b, float scale) { kPlanes2Img.setArg(0, dst); kPlanes2Img.setArg(1, 0); kPlanes2Img.setArg(2, width); kPlanes2Img.setArg(3, a); kPlanes2Img.setArg(4, r); kPlanes2Img.setArg(5, g); kPlanes2Img.setArg(6, b); kPlanes2Img.setArg(7, 0); kPlanes2Img.setArg(8, width); kPlanes2Img.setArg(9, scale); q.put2DRangeKernel(kPlanes2Img, 0, 0, width, height, 64, 1); } // Convert packed ARGB byte image to planes of complex floats final String img2Planes = "kernel void img2planes(global const uchar4 *argb, int soff, int sstride," + " global float2 *a, global float2 *r, global float2 *g, global float2 *b, int doff, int dstride) {" + " int gx = get_global_id(0);" + " int gy = get_global_id(1);" + " uchar4 v = argb[soff+sstride*gy+gx];" + " float4 ff = convert_float4(v) * (float4)(1.0f/255);" + " doff += (dstride * gy + gx);" + " b[doff] = (float2){ ff.s0, 0 };\n" + " g[doff] = (float2){ ff.s1, 0 };" + " r[doff] = (float2){ ff.s2, 0 };" + " a[doff] = (float2){ ff.s3, 0 };\n" + "}\n\n"; // not the best implementation // this also performs an 'fftshift' final String grey2Plane = "kernel void grey2plane(global const uchar *src, int soff, int sstride," + " global float2 *dst, int doff, int dstride) {" + " int gx = get_global_id(0);" + " int gy = get_global_id(1);" + " uchar v = src[soff+sstride*gy+gx];" + " float ff = convert_float(v) * (1.0f/255);" // fftshift + " gx ^= get_global_size(0)>>1;" + " gy ^= get_global_size(1)>>1;" + " doff += (dstride * gy + gx);" + " dst[doff] = (float2) { ff, 0 };" + "}\n\n"; // This also does the 'fftscale' after the inverse fft. final String planes2Img = "kernel void planes2img(global uchar4 *argb, int soff, int sstride, const global float2 *a, const global float2 *r, const global float2 *g, const global float2 *b, int doff, int dstride, float scale) {" + " int gx = get_global_id(0);" + " int gy = get_global_id(1);" + " float4 fr, fi, fa;" + " float2 t;" + " doff += (dstride * gy + gx);" + " float2 s = (float2)scale;" + " t = b[doff]*s; fr.s0 = t.s0; fi.s0 = t.s1;" + " t = g[doff]*s; fr.s1 = t.s0; fi.s1 = t.s1;" + " t = r[doff]*s; fr.s2 = t.s0; fi.s2 = t.s1;" + " t = a[doff]*s; fr.s3 = t.s0; fi.s3 = t.s1;" + " fa = sqrt(fr*fr + fi*fi) * 255;" + " fa = clamp(fa, 0.0f, 255.0f);" + " argb[soff +sstride*gy+gx] = convert_uchar4(fa);" + "}\n\n"; final String convolve = "kernel void convolve(global const float2 *h, global const float2 *ff, global float2 *g, int stride) {" + " int gx = get_global_id(0);" + " int gy = get_global_id(1);" + " int off = stride * gy + gx;" + " float2 a = h[off];" + " float2 b = ff[off];" + " g[off] = (float2) { a.s0 * b.s0 - a.s1 * b.s1, a.s0 * b.s1 + a.s1 * b.s0 };" + "}\n\n"; final String deconvolve = "kernel void deconvolve(global const float2 *h, global const float2 *ff, global float2 *g, int stride, float k) {" + " int gx = get_global_id(0);" + " int gy = get_global_id(1);" + " int off = stride * gy + gx;" + " float2 a = h[off];" + " float2 b = ff[off];" + " float d = b.s0 * b.s0 + b.s1 * b.s1 + k;" + " b.s0 /= d;" + " b.s1 /= -d;" + " g[off] = (float2) { a.s0 * b.s0 - a.s1 * b.s1, a.s0 * b.s1 + a.s1 * b.s0 };" + "}\n\n"; }