/*
* JCuda - Java bindings for NVIDIA CUDA driver and runtime API
* http://www.jcuda.org
*
* Copyright 2009-2013 Marco Hutter - http://www.jcuda.org
*/
import static jcuda.driver.CUgraphicsMapResourceFlags.CU_GRAPHICS_MAP_RESOURCE_FLAGS_WRITE_DISCARD;
import static jcuda.driver.JCudaDriver.*;
import static org.lwjgl.opengl.GL11.*;
import static org.lwjgl.opengl.GL15.*;
import static org.lwjgl.opengl.GL20.*;
import static org.lwjgl.opengl.GL30.*;
import java.awt.*;
import java.awt.event.*;
import java.io.*;
import java.nio.*;
import java.util.Arrays;
import javax.swing.*;
import jcuda.*;
import jcuda.driver.*;
import org.lwjgl.LWJGLException;
import org.lwjgl.opengl.AWTGLCanvas;
/**
* This class demonstrates how to use the JCudaDriver GL bindings API
* to interact with LWJGL from http://www.lwjgl.org . It creates
* a vertex buffer object (VBO) consisting of a rectangular grid of
* points, and animates it with a sine wave.
*
* Pressing the 't' key will toggle between the CUDA computation and
* the Java computation mode.
*
* This sample actually uses the kernel that is created for the
* "Simple OpenGL" sample from the NVIDIA CUDA code samples web site.
*/
public class JCudaDriverLWJGLSample3
{
/**
 * Entry point for this sample. Schedules the construction of the
 * sample (and thus, of its GUI) on the event dispatch thread.
 *
 * @param args not used
 */
public static void main(String[] args)
{
    Runnable createSample = new Runnable()
    {
        @Override
        public void run()
        {
            new JCudaDriverLWJGLSample3();
        }
    };
    SwingUtilities.invokeLater(createSample);
}
/**
* The source code for the vertex shader. It transforms the input
* vertex ("inVertex") with the "modelviewMatrix" and the
* "projectionMatrix" uniforms. The "inColor" input is declared,
* but not used by the shader code.
*/
private static String vertexShaderSource =
"#version 150 core" + "\n" +
"in vec4 inVertex;" + "\n" +
"in vec3 inColor;" + "\n" +
"uniform mat4 modelviewMatrix;" + "\n" +
"uniform mat4 projectionMatrix;" + "\n" +
"void main(void)" + "\n" +
"{" + "\n" +
" gl_Position = " + "\n" +
" projectionMatrix * modelviewMatrix * inVertex;" + "\n" +
"}";
/**
* The source code for the fragment shader. It writes a constant,
* opaque red color for each fragment.
*/
private static String fragmentShaderSource =
"#version 150 core" + "\n" +
"out vec4 outColor;" + "\n" +
"void main(void)" + "\n" +
"{" + "\n" +
" outColor = vec4(1.0,0.0,0.0,1.0);" + "\n" +
"}";
/**
* The number of points of the mesh in x-direction.
* Should be a multiple of 8, because the CUDA kernel is launched
* with blocks of 8x8 threads and a grid size of meshWidth / 8.
*/
private static final int meshWidth = 8 * 64;
/**
* The number of points of the mesh in y-direction.
* Should be a multiple of 8, because the CUDA kernel is launched
* with blocks of 8x8 threads and a grid size of meshHeight / 8.
*/
private static final int meshHeight = 8 * 64;
/**
* The LWJGL canvas
*/
private AWTGLCanvas glComponent;
/**
* The VAO (vertex array object) identifier
*/
private int vertexArrayObject;
/**
* The VBO (vertex buffer object) identifier
*/
private int vertexBufferObject;
/**
* The Graphics resource associated with the VBO
*/
private CUgraphicsResource vboGraphicsResource;
/**
* The buffer that was returned by the most recent mapping of the
* VBO in Java mode. It is passed to the next glMapBuffer call.
*/
private ByteBuffer mappedBuffer;
/**
* The current animation state of the mesh. It is increased by
* a small amount each time that a frame is rendered.
*/
private float animationState = 0.0f;
/**
* The handle for the CUDA function of the kernel that is to be called
*/
private CUfunction function;
/**
* Whether the computation should be performed with CUDA or
* with Java. May be toggled by pressing the 't' key.
*/
private boolean useCUDA = true;
/**
* The ID of the OpenGL shader program
*/
private int shaderProgramID;
/**
* The translation in X-direction
*/
private float translationX = 0;
/**
* The translation in Y-direction
*/
private float translationY = 0;
/**
* The translation in Z-direction
*/
private float translationZ = -4;
/**
* The rotation about the X-axis, in degrees
*/
private float rotationX = 40;
/**
* The rotation about the Y-axis, in degrees
*/
private float rotationY = 30;
/**
* The current projection matrix
*/
float projectionMatrix[] = new float[16];
/**
* The direct buffer that is used for passing the projection
* matrix to the shader program
*/
private FloatBuffer projectionMatrixBuffer = createFloatBuffer(16);
/**
* The current modelview matrix
*/
float modelviewMatrix[] = new float[16];
/**
* The direct buffer that is used for passing the modelview
* matrix to the shader program
*/
private FloatBuffer modelviewMatrixBuffer = createFloatBuffer(16);
/**
* Step counter for FPS computation: The number of frames that
* have been rendered since the last update of the FPS display
*/
private int step = 0;
/**
* Time stamp for FPS computation, in nanoseconds. A value of -1
* indicates that no time stamp has been recorded yet.
*/
private long prevTimeNS = -1;
/**
* The main frame of the application
*/
private Frame frame;
/**
 * Inner class that handles the mouse interaction: Dragging with the
 * left button moves the object, dragging with the right button
 * rotates it, and the mouse wheel moves it along the z-axis.
 */
class MouseControl implements MouseMotionListener, MouseWheelListener
{
    /** The mouse position of the most recently handled event */
    private Point previousMousePosition = new Point();

    @Override
    public void mouseDragged(MouseEvent e)
    {
        Point currentPosition = e.getPoint();
        int dx = currentPosition.x - previousMousePosition.x;
        int dy = currentPosition.y - previousMousePosition.y;

        int modifiers = e.getModifiersEx();
        boolean leftButtonDown =
            (modifiers & MouseEvent.BUTTON1_DOWN_MASK) != 0;
        boolean rightButtonDown =
            (modifiers & MouseEvent.BUTTON3_DOWN_MASK) != 0;

        if (leftButtonDown)
        {
            // The left button moves the object
            translationX += dx / 100.0f;
            translationY -= dy / 100.0f;
        }
        else if (rightButtonDown)
        {
            // The right button rotates the object
            rotationX += dy;
            rotationY += dx;
        }
        previousMousePosition = currentPosition;
        updateModelviewMatrix();
    }

    @Override
    public void mouseMoved(MouseEvent e)
    {
        previousMousePosition = e.getPoint();
    }

    @Override
    public void mouseWheelMoved(MouseWheelEvent e)
    {
        // The wheel translates the object along the Z-axis
        float deltaZ = e.getWheelRotation() * 0.25f;
        translationZ += deltaZ;
        previousMousePosition = e.getPoint();
        updateModelviewMatrix();
    }
}
/**
 * Inner class that handles the keyboard interaction: Typing 't'
 * toggles between the CUDA and the Java computation mode.
 */
class KeyboardControl extends KeyAdapter
{
    @Override
    public void keyTyped(KeyEvent e)
    {
        if (e.getKeyChar() != 't')
        {
            return;
        }
        useCUDA = !useCUDA;
    }
}
/**
 * Creates a new JCudaDriverLWJGLSample3: Creates the GL canvas,
 * attaches the mouse and keyboard controls to it, and shows it
 * in a new frame.
 */
public JCudaDriverLWJGLSample3()
{
    // Create the GL component and define its size
    createCanvas();
    glComponent.setPreferredSize(new Dimension(800, 800));

    // Attach the mouse and keyboard controls
    MouseControl mouse = new MouseControl();
    glComponent.addMouseMotionListener(mouse);
    glComponent.addMouseWheelListener(mouse);
    glComponent.addKeyListener(new KeyboardControl());
    updateModelviewMatrix();

    // Create the main frame. Closing it terminates the application.
    WindowAdapter exitOnClose = new WindowAdapter()
    {
        @Override
        public void windowClosing(WindowEvent e)
        {
            System.exit(0);
        }
    };
    frame = new JFrame("JCuda / LWJGL interaction sample");
    frame.addWindowListener(exitOnClose);
    frame.setLayout(new BorderLayout());
    frame.add(glComponent, BorderLayout.CENTER);
    frame.pack();
    frame.setVisible(true);

    // The canvas needs the focus in order to receive key events
    glComponent.requestFocus();
}
/**
 * Create the AWTGLCanvas, and start a daemon thread that
 * constantly triggers a repaint of the canvas.
 *
 * @throws RuntimeException If the canvas cannot be created
 */
private void createCanvas()
{
    try
    {
        glComponent = new AWTGLCanvas()
        {
            // Whether init() was already called
            private boolean initialized = false;

            // The canvas size for which the view was set up
            private Dimension previousSize = null;

            @Override
            public void paintGL()
            {
                if (!initialized)
                {
                    init();
                    glComponent.setVSyncEnabled(false);
                    initialized = true;
                }
                // Update the view whenever the canvas size changed
                if (previousSize == null || !previousSize.equals(getSize()))
                {
                    previousSize = getSize();
                    setupView();
                }
                render();
                try
                {
                    swapBuffers();
                }
                catch (LWJGLException e)
                {
                    throw new RuntimeException(
                        "Could not swap buffers", e);
                }
            }
        };
    }
    catch (LWJGLException e)
    {
        throw new RuntimeException(
            "Could not create canvas", e);
    }
    glComponent.setFocusable(true);

    // Create the thread that triggers a repaint of the component
    Thread thread = new Thread(new Runnable()
    {
        @Override
        public void run()
        {
            while (true)
            {
                glComponent.repaint();
                try
                {
                    Thread.sleep(1);
                }
                catch (InterruptedException e)
                {
                    // Restore the interrupt status and stop the thread.
                    // Continuing the loop with the interrupt flag set
                    // would cause each subsequent sleep call to throw
                    // immediately, resulting in a busy loop.
                    Thread.currentThread().interrupt();
                    return;
                }
            }
        }
    });
    thread.setDaemon(true);
    thread.start();
}
/**
 * Recompute the modelview matrix from the current rotation
 * angles and the current translation
 */
private void updateModelviewMatrix()
{
    float rotation[] = multiply(rotationX(rotationX), rotationY(rotationY));
    float translation[] =
        translation(translationX, translationY, translationZ);
    modelviewMatrix = multiply(rotation, translation);
}
/**
* Called to initialize the drawing and CUDA. This is called once
* from the paintGL method of the canvas, before the first frame
* is rendered.
*/
public void init()
{
// Perform the default GL initialization
glEnable(GL_DEPTH_TEST);
glClearColor(0.0f, 0.0f, 0.0f, 1.0f);
// Initialize the shaders
initShaders();
// Set up the viewport and projection matrix
setupView();
// Initialize JCuda
initJCuda();
// Initialize the OpenGL VBO and register it as a CUDA graphics
// resource. This uses the shader program and the CUDA context
// that have been created above.
initVBO();
}
/**
 * Initialize the shader program: The vertex shader and the
 * fragment shader are compiled and attached, and the resulting
 * program is linked and validated.
 */
private void initShaders()
{
    shaderProgramID = glCreateProgram();
    attachCompiledShader(GL_VERTEX_SHADER, vertexShaderSource);
    attachCompiledShader(GL_FRAGMENT_SHADER, fragmentShaderSource);
    glLinkProgram(shaderProgramID);
    glValidateProgram(shaderProgramID);
}

/**
 * Creates a shader of the given type from the given source code,
 * compiles it, attaches it to the shader program, and then calls
 * glDeleteShader for it.
 *
 * @param shaderType The shader type, e.g. GL_VERTEX_SHADER
 * @param source The source code of the shader
 */
private void attachCompiledShader(int shaderType, String source)
{
    int shaderID = glCreateShader(shaderType);
    glShaderSource(shaderID, toByteBuffer(source));
    glCompileShader(shaderID);
    glAttachShader(shaderProgramID, shaderID);
    glDeleteShader(shaderID);
}
/**
 * Initialize the JCudaDriver: Creates a context for GL interoperation
 * on device 0, and loads the kernel function from the PTX file.
 * Note that this has to be done from the same thread that will later
 * use the JCudaDriver API.
 *
 * @throws RuntimeException If the PTX file cannot be created
 */
private void initJCuda()
{
    JCudaDriver.setExceptionsEnabled(true);

    // Set up the driver API, and create a GL context for device 0
    cuInit(0);
    CUdevice device = new CUdevice();
    cuDeviceGet(device, 0);
    CUcontext context = new CUcontext();
    cuGLCtxCreate(context, 0, device);

    // Make sure that the PTX file containing the kernel exists
    String ptxFileName;
    try
    {
        ptxFileName = preparePtxFile("simpleGL_kernel.cu");
    }
    catch (IOException e)
    {
        System.err.println("Could not create PTX file");
        throw new RuntimeException("Could not create PTX file", e);
    }

    // Load the module from the PTX file
    CUmodule module = new CUmodule();
    cuModuleLoad(module, ptxFileName);

    // Look up the kernel function in the module. This function is
    // later launched in the runCuda method, for each frame that is
    // rendered in CUDA mode.
    function = new CUfunction();
    cuModuleGetFunction(function, module, "_Z6kernelP6float4jjf");
}
/**
 * Create the vertex array object (VAO) and the vertex buffer object
 * (VBO) that stores the vertex positions, and register the VBO as
 * a CUDA graphics resource.
 */
private void initVBO()
{
    // Create and bind the vertex array object
    vertexArrayObject = glGenVertexArrays();
    glBindVertexArray(vertexArrayObject);

    // Create the vertex buffer object
    vertexBufferObject = glGenBuffers();

    // Allocate the memory of the vertex buffer object: Each
    // point of the mesh consists of 4 float values
    glBindBuffer(GL_ARRAY_BUFFER, vertexBufferObject);
    int size = meshWidth * meshHeight * 4 * Sizeof.FLOAT;
    glBufferData(GL_ARRAY_BUFFER, size, GL_DYNAMIC_DRAW);

    // Initialize the attribute location of the input
    // vertices for the shader program
    int location = glGetAttribLocation(shaderProgramID, "inVertex");
    glVertexAttribPointer(location, 4, GL_FLOAT, false, 0, 0);
    glEnableVertexAttribArray(location);

    // Register the vertexBufferObject for use with CUDA. The flags
    // of the registration are register flags (and not the flags
    // that are used for mapping a resource).
    vboGraphicsResource = new CUgraphicsResource();
    cuGraphicsGLRegisterBuffer(
        vboGraphicsResource, vertexBufferObject,
        CUgraphicsRegisterFlags.CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD);
}
/**
 * Set up a default view: The viewport covers the whole canvas, and
 * the projection matrix is a perspective projection that matches
 * the aspect ratio of the canvas.
 */
private void setupView()
{
    int width = glComponent.getWidth();
    int height = glComponent.getHeight();
    glViewport(0, 0, width, height);

    // The canvas may have a height of 0. Clamp the height that is used
    // for the aspect ratio, to avoid a division by zero that would
    // result in an infinite or NaN aspect ratio.
    float aspect = (float) width / Math.max(1, height);
    projectionMatrix = perspective(50, aspect, 0.1f, 100.0f);
}
/**
 * Called when the canvas is to be displayed. Computes the new vertex
 * positions (with CUDA or with Java, depending on the current mode),
 * renders the mesh as points, updates the FPS information, and
 * advances the animation state.
 */
public void render()
{
    if (useCUDA)
    {
        // Run the CUDA kernel to generate new vertex positions.
        runCuda();
    }
    else
    {
        // Run the Java method to generate new vertex positions.
        runJava();
    }

    glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

    // Activate the shader program, and pass the current
    // matrices to it
    glUseProgram(shaderProgramID);
    setMatrixUniform(
        "projectionMatrix", projectionMatrix, projectionMatrixBuffer);
    setMatrixUniform(
        "modelviewMatrix", modelviewMatrix, modelviewMatrixBuffer);

    // Render the VBO
    glBindBuffer(GL_ARRAY_BUFFER, vertexBufferObject);
    glDrawArrays(GL_POINTS, 0, meshWidth * meshHeight);

    updateFpsTitle();
    animationState += 0.01;
}

/**
 * Passes the given matrix to the uniform with the given name
 * in the shader program.
 *
 * @param uniformName The name of the uniform in the shader program
 * @param matrix The matrix, as an array with 16 elements
 * @param buffer The buffer that is used for passing the matrix to GL
 */
private void setMatrixUniform(
    String uniformName, float matrix[], FloatBuffer buffer)
{
    int location = glGetUniformLocation(shaderProgramID, uniformName);
    buffer.rewind();
    buffer.put(matrix);
    buffer.rewind();
    glUniformMatrix4(location, false, buffer);
}

/**
 * Counts the current frame, and updates the FPS information in
 * the title of the main frame when more than one second has
 * passed since the last update.
 */
private void updateFpsTitle()
{
    step++;
    long currentTime = System.nanoTime();
    if (prevTimeNS == -1)
    {
        prevTimeNS = currentTime;
    }
    long diff = currentTime - prevTimeNS;
    if (diff > 1e9)
    {
        // Frames per second: The number of frames, divided
        // by the elapsed time in seconds
        double fps = step / (diff / 1e9);
        String t = "JCuda / LWJGL interaction sample - ";
        t += useCUDA?"JCuda":"Java";
        t += " mode: "+String.format("%.2f", fps)+" FPS";
        frame.setTitle(t);
        prevTimeNS = currentTime;
        step = 0;
    }
}
/**
 * Run the CUDA computation to create new vertex positions
 * inside the vertexBufferObject: The VBO is mapped, the kernel
 * is launched on the mapped memory, and the VBO is unmapped.
 */
private void runCuda()
{
    // Map the vertexBufferObject for writing from CUDA, and obtain
    // the pointer to the beginning of the memory area of the VBO.
    CUgraphicsResource resources[] = { vboGraphicsResource };
    cuGraphicsMapResources(1, resources, null);
    CUdeviceptr basePointer = new CUdeviceptr();
    long mappedSize[] = new long[1];
    cuGraphicsResourceGetMappedPointer(
        basePointer, mappedSize, vboGraphicsResource);

    // Set up the kernel parameters: A pointer to an array of pointers
    // which point to the actual values, namely the base pointer of the
    // geometry data, the mesh width, the mesh height, and the current
    // animation state.
    Pointer geometryParameter = Pointer.to(basePointer);
    Pointer widthParameter = Pointer.to(new int[]{meshWidth});
    Pointer heightParameter = Pointer.to(new int[]{meshHeight});
    Pointer timeParameter = Pointer.to(new float[]{animationState});
    Pointer kernelParameters = Pointer.to(
        geometryParameter, widthParameter,
        heightParameter, timeParameter);

    // Launch the kernel with blocks of 8x8 threads, and a grid
    // that covers the whole mesh
    final int blockSize = 8;
    cuLaunchKernel(function,
        meshWidth / blockSize, meshHeight / blockSize, 1, // Grid dimension
        blockSize, blockSize, 1, // Block dimension
        0, null, // Shared memory size and stream
        kernelParameters, null // Kernel- and extra parameters
    );
    cuCtxSynchronize();

    // Unmap the buffer object
    cuGraphicsUnmapResources(1, resources, null);
}
/**
 * Run the Java computation to create new vertex positions
 * inside the vertexBufferObject: The VBO is mapped into a buffer,
 * the sine wave pattern is written into this buffer, and the
 * VBO is unmapped.
 */
private void runJava()
{
    glBindBuffer(GL_ARRAY_BUFFER, vertexBufferObject);
    mappedBuffer =
        glMapBuffer(GL_ARRAY_BUFFER, GL_READ_WRITE, mappedBuffer);
    mappedBuffer.order(ByteOrder.nativeOrder());
    FloatBuffer vertices = mappedBuffer.asFloatBuffer();

    final float freq = 4.0f;
    for (int x = 0; x < meshWidth; x++)
    {
        // The u-coordinate in [-1,1) and the sine term of the
        // wave pattern are the same for all points of a column
        float u = x / (float) meshWidth;
        u = u * 2.0f - 1.0f;
        float sine = (float) Math.sin(u * freq + animationState);

        for (int y = 0; y < meshHeight; y++)
        {
            // The v-coordinate in [-1,1)
            float v = y / (float) meshHeight;
            v = v * 2.0f - 1.0f;

            // The height of the simple sine wave pattern
            float w =
                sine * (float) Math.cos(v * freq + animationState) * 0.5f;

            // Write the output vertex (u, w, v, 1)
            int index = 4 * (y * meshWidth + x);
            vertices.put(index, u);
            vertices.put(index + 1, w);
            vertices.put(index + 2, v);
            vertices.put(index + 3, 1);
        }
    }
    glUnmapBuffer(GL_ARRAY_BUFFER);
    glBindBuffer(GL_ARRAY_BUFFER, 0);
}
/**
 * The extension of the given file name is replaced with "ptx".
 * If the given file name does not have an extension, then ".ptx"
 * is appended to it. If the file with the resulting name does not
 * exist, it is compiled from the given file using NVCC. The name
 * of the PTX file is returned.
 *
 * @param cuFileName The name of the .CU file
 * @return The name of the PTX file
 * @throws IOException If an I/O error occurs, if the input file does
 * not exist, or if the NVCC process reports an error
 */
private static String preparePtxFile(String cuFileName) throws IOException
{
    // Only a dot in the last path element starts an extension.
    int dotIndex = cuFileName.lastIndexOf('.');
    int separatorIndex = Math.max(
        cuFileName.lastIndexOf('/'), cuFileName.lastIndexOf('\\'));
    String baseName = cuFileName;
    if (dotIndex > separatorIndex)
    {
        baseName = cuFileName.substring(0, dotIndex);
    }
    String ptxFileName = baseName + ".ptx";
    File ptxFile = new File(ptxFileName);
    if (ptxFile.exists())
    {
        return ptxFileName;
    }

    File cuFile = new File(cuFileName);
    if (!cuFile.exists())
    {
        throw new IOException("Input file not found: "+cuFileName);
    }
    String modelString = "-m"+System.getProperty("sun.arch.data.model");
    String command =
        "nvcc " + modelString + " -ptx "+
        cuFile.getPath()+" -o "+ptxFileName;
    System.out.println("Executing\n"+command);

    // Pass the arguments individually, so that file names that contain
    // whitespace are not split into multiple arguments. The error stream
    // is merged into the output stream: Reading the two streams one
    // after the other could block forever when the process fills the
    // buffer of the stream that is not being read yet.
    ProcessBuilder processBuilder = new ProcessBuilder(
        "nvcc", modelString, "-ptx", cuFile.getPath(), "-o", ptxFileName);
    processBuilder.redirectErrorStream(true);
    Process process = processBuilder.start();

    String outputMessage;
    InputStream processOutput = process.getInputStream();
    try
    {
        outputMessage = new String(toByteArray(processOutput));
    }
    finally
    {
        processOutput.close();
    }

    int exitValue = 0;
    try
    {
        exitValue = process.waitFor();
    }
    catch (InterruptedException e)
    {
        Thread.currentThread().interrupt();
        throw new IOException(
            "Interrupted while waiting for nvcc output", e);
    }
    if (exitValue != 0)
    {
        System.out.println("nvcc process exitValue "+exitValue);
        System.out.println("outputMessage:\n"+outputMessage);
        throw new IOException(
            "Could not create .ptx file: "+outputMessage);
    }
    System.out.println("Finished creating PTX file");
    return ptxFileName;
}
/**
 * Reads the given InputStream until its end is reached, and
 * returns all data that was read as a byte array
 *
 * @param inputStream The input stream to read
 * @return The byte array containing the data from the input stream
 * @throws IOException If an I/O error occurs
 */
private static byte[] toByteArray(InputStream inputStream)
    throws IOException
{
    ByteArrayOutputStream collected = new ByteArrayOutputStream();
    byte chunk[] = new byte[8192];
    for (int count = inputStream.read(chunk);
         count != -1;
         count = inputStream.read(chunk))
    {
        collected.write(chunk, 0, count);
    }
    return collected.toByteArray();
}
//=== Helper functions for buffers ========================================
/**
 * Creates a direct float buffer with native byte order that
 * has a capacity of the given number of float elements
 *
 * @param size The number of elements
 * @return The buffer
 */
private static FloatBuffer createFloatBuffer(int size)
{
    ByteBuffer bytes = ByteBuffer.allocateDirect(size * 4);
    bytes.order(ByteOrder.nativeOrder());
    return bytes.asFloatBuffer();
}
/**
 * Creates a direct byte buffer with native byte order that
 * has a capacity of the given number of bytes
 *
 * @param size The number of elements
 * @return The buffer
 */
private static ByteBuffer createByteBuffer(int size)
{
    ByteBuffer buffer = ByteBuffer.allocateDirect(size);
    buffer.order(ByteOrder.nativeOrder());
    return buffer;
}
/**
 * Converts the given String into a native ByteBuffer. The buffer
 * contains the bytes of the string, followed by a single 0-byte,
 * and its position is 0.
 *
 * @param s The string to convert
 * @return The buffer
 */
private static ByteBuffer toByteBuffer(String s)
{
    byte stringBytes[] = s.getBytes();
    ByteBuffer result = createByteBuffer(stringBytes.length + 1);
    result.put(stringBytes).put((byte)0);
    result.rewind();
    return result;
}
//=== Helper functions for matrix operations ==============================
/**
 * Helper method that creates a perspective matrix. If the given
 * parameters do not describe a valid projection (that is, when
 * the near and far plane are equal, or the sine of half the fov
 * or the aspect ratio are 0), then the identity matrix is returned.
 *
 * @param fovy The fov in y-direction, in degrees
 * @param aspect The aspect ratio
 * @param zNear The near clipping plane
 * @param zFar The far clipping plane
 * @return A perspective matrix
 */
private static float[] perspective(
    float fovy, float aspect, float zNear, float zFar)
{
    float halfAngle = (float)Math.toRadians(fovy / 2);
    float depth = zFar - zNear;
    float sinHalf = (float)Math.sin(halfAngle);
    boolean valid = (depth != 0) && (sinHalf != 0) && (aspect != 0);
    float result[] = identity();
    if (!valid)
    {
        return result;
    }
    float cotangent = (float)Math.cos(halfAngle) / sinHalf;
    result[0] = cotangent / aspect;
    result[5] = cotangent;
    result[10] = -(zFar + zNear) / depth;
    result[11] = -1;
    result[14] = -2 * zNear * zFar / depth;
    result[15] = 0;
    return result;
}
/**
 * Creates a new 4x4 identity matrix, as an array with 16 elements
 *
 * @return An identity matrix
 */
private static float[] identity()
{
    // A new array is filled with zeros. The diagonal
    // elements are at the indices 0, 5, 10 and 15.
    float result[] = new float[16];
    for (int i = 0; i < 4; i++)
    {
        result[i * 5] = 1.0f;
    }
    return result;
}
/**
 * Multiplies the given 4x4 matrices, which are given as arrays
 * with 16 elements, and returns the result as a new array
 *
 * @param m0 The first matrix
 * @param m1 The second matrix
 * @return The product m0*m1
 */
private static float[] multiply(float m0[], float m1[])
{
    float product[] = new float[16];
    for (int index = 0; index < 16; index++)
    {
        int rowStart = (index / 4) * 4;
        int column = index % 4;
        // Start with the first term and add the remaining ones in
        // order, so that the summation order is always the same
        float sum = m0[rowStart] * m1[column];
        for (int k = 1; k < 4; k++)
        {
            sum += m0[rowStart + k] * m1[column + 4 * k];
        }
        product[index] = sum;
    }
    return product;
}
/**
 * Creates a translation matrix: An identity matrix that contains
 * the given translation at the indices 12, 13 and 14
 *
 * @param x The x translation
 * @param y The y translation
 * @param z The z translation
 * @return A translation matrix
 */
private static float[] translation(float x, float y, float z)
{
    return new float[]
    {
        1.0f, 0.0f, 0.0f, 0.0f,
        0.0f, 1.0f, 0.0f, 0.0f,
        0.0f, 0.0f, 1.0f, 0.0f,
        x,    y,    z,    1.0f
    };
}
/**
 * Creates a matrix describing a rotation around the x-axis
 *
 * @param angleDeg The rotation angle, in degrees
 * @return The rotation matrix
 */
private static float[] rotationX(float angleDeg)
{
    float radians = (float)Math.toRadians(angleDeg);
    float c = (float)Math.cos(radians);
    float s = (float)Math.sin(radians);
    return new float[]
    {
        1.0f, 0.0f, 0.0f, 0.0f,
        0.0f, c,    s,    0.0f,
        0.0f, -s,   c,    0.0f,
        0.0f, 0.0f, 0.0f, 1.0f
    };
}
/**
 * Creates a matrix describing a rotation around the y-axis
 *
 * @param angleDeg The rotation angle, in degrees
 * @return The rotation matrix
 */
private static float[] rotationY(float angleDeg)
{
    float radians = (float)Math.toRadians(angleDeg);
    float c = (float)Math.cos(radians);
    float s = (float)Math.sin(radians);
    return new float[]
    {
        c,    0.0f, -s,   0.0f,
        0.0f, 1.0f, 0.0f, 0.0f,
        s,    0.0f, c,    0.0f,
        0.0f, 0.0f, 0.0f, 1.0f
    };
}
}