summaryrefslogtreecommitdiffstats
path: root/src/com/jogamp/opencl/demos/fft/BlurTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/com/jogamp/opencl/demos/fft/BlurTest.java')
-rw-r--r--src/com/jogamp/opencl/demos/fft/BlurTest.java521
1 files changed, 521 insertions, 0 deletions
diff --git a/src/com/jogamp/opencl/demos/fft/BlurTest.java b/src/com/jogamp/opencl/demos/fft/BlurTest.java
new file mode 100644
index 0000000..c08da43
--- /dev/null
+++ b/src/com/jogamp/opencl/demos/fft/BlurTest.java
@@ -0,0 +1,521 @@
+package com.jogamp.opencl.demos.fft;
+
+import com.jogamp.opencl.CLBuffer;
+import com.jogamp.opencl.CLCommandQueue;
+import com.jogamp.opencl.CLContext;
+import com.jogamp.opencl.CLDevice;
+import com.jogamp.opencl.CLKernel;
+import com.jogamp.opencl.CLMemory.Mem;
+import com.jogamp.opencl.CLPlatform;
+import com.jogamp.opencl.CLProgram;
+import com.jogamp.opencl.demos.fft.CLFFTPlan.InvalidContextException;
+import java.awt.BorderLayout;
+import java.awt.Dimension;
+import java.awt.Graphics;
+import java.awt.GridBagConstraints;
+import java.awt.GridBagLayout;
+import java.awt.Insets;
+import java.awt.event.ActionEvent;
+import java.awt.event.ActionListener;
+import java.awt.image.BufferedImage;
+import java.awt.image.DataBufferByte;
+import java.awt.image.DataBufferInt;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+import javax.imageio.ImageIO;
+import javax.swing.BoxLayout;
+import javax.swing.ButtonGroup;
+import javax.swing.JButton;
+import javax.swing.JFileChooser;
+import javax.swing.JFrame;
+import javax.swing.JLabel;
+import javax.swing.JOptionPane;
+import javax.swing.JPanel;
+import javax.swing.JSlider;
+import javax.swing.JToggleButton;
+import javax.swing.SwingUtilities;
+import javax.swing.event.ChangeEvent;
+import javax.swing.event.ChangeListener;
+
+/**
+ * Perform some user-controllable blur on an image.
+ * @author notzed
+ */
+public class BlurTest implements Runnable, ChangeListener, ActionListener {
+
+ public static void main(String[] args) {
+ SwingUtilities.invokeLater(new BlurTest());
+ }
+
+ boolean demo = false;
+ // must be power of 2 and width must be multiple of 64
+ int width = 512;
+ int height = 512;
+ BufferedImage src;
+ BufferedImage psf;
+ BufferedImage dst;
+ PaintView left;
+ ImageView right;
+ //
+ JSlider sizex;
+ JSlider sizey;
+ JSlider angle;
+ //
+ JToggleButton blurButton;
+ JToggleButton drawButton;
+
+ public void run() {
+ try {
+ initCL();
+ } catch (Exception x) {
+ System.out.println("failed to init cl");
+ x.printStackTrace();
+ System.exit(1);
+ }
+
+ JFileChooser fc = new JFileChooser();
+ BufferedImage img = null;
+
+ while (img == null) {
+ try {
+ File file = null;
+
+ if (true) {
+ fc.setDialogTitle("Select Image File");
+ fc.setPreferredSize(new Dimension(500, 600));
+ if (fc.showOpenDialog(null) == JFileChooser.APPROVE_OPTION) {
+ file = fc.getSelectedFile();
+ } else {
+ System.exit(0);
+ }
+
+ } else {
+ file = new File("/home/notzed/cat0.jpg");
+ }
+ img = ImageIO.read(file);
+ if (img == null) {
+ JOptionPane.showMessageDialog(null, "Couldn't load file");
+ }
+ } catch (IOException x) {
+ JOptionPane.showMessageDialog(null, "Couldn't load file");
+ }
+ }
+
+ src = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
+ dst = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
+ psf = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
+
+ // Ensure loaded image is in known format and size
+ Graphics g = src.createGraphics();
+ g.drawImage(img, (width - img.getWidth()) / 2, (height - img.getHeight()) / 2, null);
+ g.dispose();
+
+ JFrame win = new JFrame("Blur Demo");
+ win.setDefaultCloseOperation(win.EXIT_ON_CLOSE);
+
+ JPanel main = new JPanel();
+ main.setLayout(new BorderLayout());
+
+ JPanel controls = new JPanel();
+ controls.setLayout(new GridBagLayout());
+
+ GridBagConstraints c0 = new GridBagConstraints();
+ c0.gridx = 0;
+ c0.anchor = GridBagConstraints.BASELINE_LEADING;
+ c0.ipadx = 3;
+ c0.insets = new Insets(1, 2, 1, 2);
+
+ controls.add(new JLabel("Width"), c0);
+ controls.add(new JLabel("Height"), c0);
+
+ GridBagConstraints c2 = (GridBagConstraints) c0.clone();
+ c2.gridx = 2;
+ controls.add(new JLabel("Angle"), c2);
+
+ c0 = (GridBagConstraints) c0.clone();
+ c0.gridx = 1;
+ c0.weightx = 1;
+ c0.fill = GridBagConstraints.HORIZONTAL;
+ sizex = new JSlider(100, 5000, 1000);
+ sizey = new JSlider(100, 5000, 100);
+ controls.add(sizex, c0);
+ controls.add(sizey, c0);
+
+ c2 = (GridBagConstraints) c0.clone();
+ c2.gridx = 3;
+ angle = new JSlider(0, (int) (Math.PI * 1000));
+ controls.add(angle, c2);
+
+ sizex.addChangeListener(this);
+ sizey.addChangeListener(this);
+ angle.addChangeListener(this);
+
+ JPanel buttons = new JPanel();
+ controls.add(buttons, c2);
+ JButton b;
+ b = new JButton("Clear");
+ buttons.add(b);
+ b.addActionListener(new ActionListener() {
+
+ public void actionPerformed(ActionEvent e) {
+ doclear();
+ }
+ });
+ ButtonGroup opt = new ButtonGroup();
+ JToggleButton tb;
+ blurButton = new JToggleButton("Blur");
+ opt.add(blurButton);
+ buttons.add(blurButton);
+ blurButton.addActionListener(this);
+ drawButton = new JToggleButton("Draw");
+ opt.add(drawButton);
+ buttons.add(drawButton);
+ drawButton.addActionListener(this);
+
+ JPanel imgs = new JPanel();
+ imgs.setLayout(new BoxLayout(imgs, BoxLayout.X_AXIS));
+ left = new PaintView(this, psf);
+ right = new ImageView(dst);
+ imgs.add(left);
+ imgs.add(right);
+
+ main.add(controls, BorderLayout.NORTH);
+ main.add(imgs, BorderLayout.CENTER);
+ win.getContentPane().add(main);
+
+ win.pack();
+ win.setVisible(true);
+
+ // pre-load and transform src, since that wont change
+ loadSource(src);
+
+ blurButton.doClick();
+ }
+
+ public void stateChanged(ChangeEvent e) {
+ if (drawButton.isSelected()) {
+ recalc();
+ } else {
+ double w = sizex.getValue() / 100.0;
+ double h = sizey.getValue() / 100.0;
+ double a = angle.getValue() / 1000.0;
+
+ Graphics g = psf.createGraphics();
+
+ g.clearRect(0, 0, width, height);
+ g.dispose();
+
+ left.drawDot(w, h, a);
+ }
+ }
+
+ public void actionPerformed(ActionEvent e) {
+ stateChanged(null);
+ }
+
+ private void doclear() {
+ Graphics g = psf.createGraphics();
+
+ g.clearRect(0, 0, width, height);
+ g.dispose();
+ left.repaint();
+ recalc();
+ }
+
+ private void dorecalc() {
+ loadPSF(psf);
+
+ // convolve each plane in freq domain
+ convolve(aCBuffer, psfBuffer, aGBuffer);
+ convolve(rCBuffer, psfBuffer, rGBuffer);
+ convolve(gCBuffer, psfBuffer, gGBuffer);
+ convolve(bCBuffer, psfBuffer, bGBuffer);
+
+ // convert back to spatial domain
+ fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Inverse, aGBuffer, aBuffer, null, null);
+ fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Inverse, rGBuffer, rBuffer, null, null);
+ fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Inverse, gGBuffer, gBuffer, null, null);
+ fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Inverse, bGBuffer, bBuffer, null, null);
+
+ // while gpu is running, calculate energy of psf
+ float scale;
+
+ long total = 0;
+ DataBufferByte pd = (DataBufferByte) psf.getRaster().getDataBuffer();
+ byte[] data = pd.getData();
+ for (int i = 0; i < data.length; i++) {
+ total += data[i] & 0xff;
+ }
+ scale = 255.0f / total / width / height;
+
+ getDestination(argbBuffer, aBuffer, rBuffer, gBuffer, bBuffer, scale);
+
+ // drop back to java, slow-crappy-method
+ q.putReadBuffer(argbBuffer, true);
+ DataBufferInt db = (DataBufferInt) dst.getRaster().getDataBuffer();
+ argbBuffer.getBuffer().position(0);
+ argbBuffer.getBuffer().get(db.getData());
+ argbBuffer.getBuffer().position(0);
+ right.repaint();
+ }
+ Runnable later;
+
+ void recalc() {
+ if (later == null) {
+ later = new Runnable() {
+
+ public void run() {
+ later = null;
+ dorecalc();
+ }
+ };
+ SwingUtilities.invokeLater(later);
+ }
+ }
+ CLContext cl;
+ CLCommandQueue q;
+ CLProgram prog;
+ CLKernel kImg2Planes;
+ CLKernel kPlanes2Img;
+ CLKernel kGrey2Plane;
+ CLKernel kConvolve;
+ CLKernel kDeconvolve;
+ CLFFTPlan fft;
+ CLBuffer<IntBuffer> argbBuffer;
+ CLBuffer<ByteBuffer> greyBuffer;
+ CLBuffer<FloatBuffer> aBuffer;
+ CLBuffer<FloatBuffer> rBuffer;
+ CLBuffer<FloatBuffer> gBuffer;
+ CLBuffer<FloatBuffer> bBuffer;
+ CLBuffer<FloatBuffer> aCBuffer;
+ CLBuffer<FloatBuffer> rCBuffer;
+ CLBuffer<FloatBuffer> gCBuffer;
+ CLBuffer<FloatBuffer> bCBuffer;
+ CLBuffer<FloatBuffer> aGBuffer;
+ CLBuffer<FloatBuffer> rGBuffer;
+ CLBuffer<FloatBuffer> gGBuffer;
+ CLBuffer<FloatBuffer> bGBuffer;
+ CLBuffer<FloatBuffer> psfBuffer;
+ CLBuffer<FloatBuffer> tmpBuffer;
+ //
+ CLKernel fft512;
+
+ void initCL() throws InvalidContextException {
+
+ // search a platform with a GPU
+ CLPlatform[] platforms = CLPlatform.listCLPlatforms();
+ CLDevice gpu = null;
+ for (CLPlatform platform : platforms) {
+ gpu = platform.getMaxFlopsDevice(CLDevice.Type.GPU);
+ if(gpu != null) {
+ break;
+ }
+ }
+
+ cl = CLContext.create(gpu);
+
+ q = cl.getDevices()[0].createCommandQueue();
+
+ prog = cl.createProgram(img2Planes + planes2Img + convolve + grey2Plane + deconvolve);
+ prog.build("-cl-mad-enable");
+
+ kImg2Planes = prog.createCLKernel("img2planes");
+ kPlanes2Img = prog.createCLKernel("planes2img");
+ kGrey2Plane = prog.createCLKernel("grey2plane");
+ kConvolve = prog.createCLKernel("convolve");
+ kDeconvolve = prog.createCLKernel("deconvolve");
+
+ argbBuffer = cl.createIntBuffer(width * height, Mem.READ_WRITE);
+ greyBuffer = cl.createByteBuffer(width * height, Mem.READ_WRITE);
+ aBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+ rBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+ gBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+ bBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+ psfBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+ tmpBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+
+ aCBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+ rCBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+ gCBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+ bCBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+
+ aGBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+ rGBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+ gGBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+ bGBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
+ if (false) {
+ try {
+ CLProgram p = cl.createProgram(new FileInputStream("/home/notzed/cl/fft-512.cl"));
+ p.build();
+ fft512 = p.createCLKernel("fft0");
+ } catch (IOException ex) {
+ Logger.getLogger(BlurTest.class.getName()).log(Level.SEVERE, null, ex);
+ }
+ } else {
+ fft = new CLFFTPlan(cl, new int[]{width, height}, CLFFTPlan.CLFFTDataFormat.InterleavedComplexFormat);
+ }
+ //fft.dumpPlan(null);
+ }
+
+ void loadSource(BufferedImage src) {
+ DataBufferInt sb = (DataBufferInt) src.getRaster().getDataBuffer();
+
+ argbBuffer.getBuffer().position(0);
+ argbBuffer.getBuffer().put(sb.getData());
+ argbBuffer.getBuffer().position(0);
+ q.putWriteBuffer(argbBuffer, false);
+
+ kImg2Planes.setArg(0, argbBuffer);
+ kImg2Planes.setArg(1, 0);
+ kImg2Planes.setArg(2, width);
+ kImg2Planes.setArg(3, aBuffer);
+ kImg2Planes.setArg(4, rBuffer);
+ kImg2Planes.setArg(5, gBuffer);
+ kImg2Planes.setArg(6, bBuffer);
+ kImg2Planes.setArg(7, 0);
+ kImg2Planes.setArg(8, width);
+ q.put2DRangeKernel(kImg2Planes, 0, 0, width, height, 64, 1);
+ q.finish();
+
+ fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, aBuffer, aCBuffer, null, null);
+ fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, rBuffer, rCBuffer, null, null);
+ fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, gBuffer, gCBuffer, null, null);
+ fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, bBuffer, bCBuffer, null, null);
+ }
+
+ void loadPSF(BufferedImage psf) {
+ assert (psf.getType() == BufferedImage.TYPE_BYTE_GRAY);
+ DataBufferByte pb = (DataBufferByte) psf.getRaster().getDataBuffer();
+
+ greyBuffer.getBuffer().position(0);
+ greyBuffer.getBuffer().put(pb.getData());
+ greyBuffer.getBuffer().position(0);
+ q.putWriteBuffer(greyBuffer, false);
+
+ kGrey2Plane.setArg(0, greyBuffer);
+ kGrey2Plane.setArg(1, 0);
+ kGrey2Plane.setArg(2, width);
+ kGrey2Plane.setArg(3, tmpBuffer);
+ kGrey2Plane.setArg(4, 0);
+ kGrey2Plane.setArg(5, width);
+ q.put2DRangeKernel(kGrey2Plane, 0, 0, width, height, 64, 1);
+
+ if (true) {
+ fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, tmpBuffer, psfBuffer, null, null);
+ } else if (true) {
+ fft512.setArg(0, tmpBuffer);
+ fft512.setArg(1, psfBuffer);
+ fft512.setArg(2, -1);
+ fft512.setArg(3, height);
+ //q.put1DRangeKernel(fft512, 0,height*64, 64);
+ q.put2DRangeKernel(fft512, 0, 0, height * 64, 1, 64, 1);
+ System.out.println("running kernel " + 64 * height + ", " + 64);
+ }
+ }
+
+ // g = f x h
+ void convolve(CLBuffer<FloatBuffer> h, CLBuffer<FloatBuffer> f, CLBuffer<FloatBuffer> g) {
+ kConvolve.setArg(0, h);
+ kConvolve.setArg(1, f);
+ kConvolve.setArg(2, g);
+ kConvolve.setArg(3, width);
+ q.put2DRangeKernel(kConvolve, 0, 0, width, height, 64, 1);
+ }
+
+ // g = h*conj(f) / (abs(f)^2 + k)
+ void deconvolve(CLBuffer<FloatBuffer> h, CLBuffer<FloatBuffer> f, CLBuffer<FloatBuffer> g, float k) {
+ kDeconvolve.setArg(0, h);
+ kDeconvolve.setArg(1, f);
+ kDeconvolve.setArg(2, g);
+ kDeconvolve.setArg(3, width);
+ kDeconvolve.setArg(4, k);
+ q.put2DRangeKernel(kDeconvolve, 0, 0, width, height, 64, 1);
+ }
+
+ void getDestination(CLBuffer<IntBuffer> dst, CLBuffer<FloatBuffer> a, CLBuffer<FloatBuffer> r, CLBuffer<FloatBuffer> g, CLBuffer<FloatBuffer> b, float scale) {
+ kPlanes2Img.setArg(0, dst);
+ kPlanes2Img.setArg(1, 0);
+ kPlanes2Img.setArg(2, width);
+ kPlanes2Img.setArg(3, a);
+ kPlanes2Img.setArg(4, r);
+ kPlanes2Img.setArg(5, g);
+ kPlanes2Img.setArg(6, b);
+ kPlanes2Img.setArg(7, 0);
+ kPlanes2Img.setArg(8, width);
+ kPlanes2Img.setArg(9, scale);
+ q.put2DRangeKernel(kPlanes2Img, 0, 0, width, height, 64, 1);
+ }
+ // Convert packed ARGB byte image to planes of complex floats
+ final String img2Planes =
+ "kernel void img2planes(global const uchar4 *argb, int soff, int sstride,"
+ + " global float2 *a, global float2 *r, global float2 *g, global float2 *b, int doff, int dstride) {"
+ + " int gx = get_global_id(0);"
+ + " int gy = get_global_id(1);"
+ + " uchar4 v = argb[soff+sstride*gy+gx];"
+ + " float4 ff = convert_float4(v) * (float4)(1.0f/255);"
+ + " doff += (dstride * gy + gx);"
+ + " b[doff] = (float2){ ff.s0, 0 };\n"
+ + " g[doff] = (float2){ ff.s1, 0 };"
+ + " r[doff] = (float2){ ff.s2, 0 };"
+ + " a[doff] = (float2){ ff.s3, 0 };\n"
+ + "}\n\n";
+ // not the best implementation
+ // this also performs an 'fftshift'
+ final String grey2Plane =
+ "kernel void grey2plane(global const uchar *src, int soff, int sstride,"
+ + " global float2 *dst, int doff, int dstride) {"
+ + " int gx = get_global_id(0);"
+ + " int gy = get_global_id(1);"
+ + " uchar v = src[soff+sstride*gy+gx];"
+ + " float ff = convert_float(v) * (1.0f/255);"
+ // fftshift
+ + " gx ^= get_global_size(0)>>1;"
+ + " gy ^= get_global_size(1)>>1;"
+ + " doff += (dstride * gy + gx);"
+ + " dst[doff] = (float2) { ff, 0 };"
+ + "}\n\n";
+ // This also does the 'fftscale' after the inverse fft.
+ final String planes2Img =
+ "kernel void planes2img(global uchar4 *argb, int soff, int sstride, const global float2 *a, const global float2 *r, const global float2 *g, const global float2 *b, int doff, int dstride, float scale) {"
+ + " int gx = get_global_id(0);"
+ + " int gy = get_global_id(1);"
+ + " float4 fr, fi, fa;"
+ + " float2 t;"
+ + " doff += (dstride * gy + gx);"
+ + " float2 s = (float2)scale;"
+ + " t = b[doff]*s; fr.s0 = t.s0; fi.s0 = t.s1;"
+ + " t = g[doff]*s; fr.s1 = t.s0; fi.s1 = t.s1;"
+ + " t = r[doff]*s; fr.s2 = t.s0; fi.s2 = t.s1;"
+ + " t = a[doff]*s; fr.s3 = t.s0; fi.s3 = t.s1;"
+ + " fa = sqrt(fr*fr + fi*fi) * 255;"
+ + " fa = clamp(fa, 0.0f, 255.0f);"
+ + " argb[soff +sstride*gy+gx] = convert_uchar4(fa);"
+ + "}\n\n";
+ final String convolve =
+ "kernel void convolve(global const float2 *h, global const float2 *ff, global float2 *g, int stride) {"
+ + " int gx = get_global_id(0);"
+ + " int gy = get_global_id(1);"
+ + " int off = stride * gy + gx;"
+ + " float2 a = h[off];"
+ + " float2 b = ff[off];"
+ + " g[off] = (float2) { a.s0 * b.s0 - a.s1 * b.s1, a.s0 * b.s1 + a.s1 * b.s0 };"
+ + "}\n\n";
+ final String deconvolve =
+ "kernel void deconvolve(global const float2 *h, global const float2 *ff, global float2 *g, int stride, float k) {"
+ + " int gx = get_global_id(0);"
+ + " int gy = get_global_id(1);"
+ + " int off = stride * gy + gx;"
+ + " float2 a = h[off];"
+ + " float2 b = ff[off];"
+ + " float d = b.s0 * b.s0 + b.s1 * b.s1 + k;"
+ + " b.s0 /= d;"
+ + " b.s1 /= -d;"
+ + " g[off] = (float2) { a.s0 * b.s0 - a.s1 * b.s1, a.s0 * b.s1 + a.s1 * b.s0 };"
+ + "}\n\n";
+}