#!perl

use strict;
use warnings;
use OpenGL::Modern qw(
  glewCreateContext glewInit glpSetAutoCheckErrors glewDestroyContext
  glpErrorString glGetError glGetString
  glGenTextures_p glBindTexture glDeleteTextures_p
  glTexImage2D_c
  glGetIntegerv_p
  glCreateShader glDeleteShader glShaderSource_p glCompileShader
  glAttachShader glDetachShader
  glGetShaderiv_p glGetShaderInfoLog_p
  glCreateProgram glDeleteProgram glLinkProgram glUseProgram
  glGetProgramiv_p glGetProgramInfoLog_p
  glGenFramebuffers_p glBindFramebuffer glDeleteFramebuffers_p
  glFramebufferTexture glCheckFramebufferStatus
  glTransformFeedbackVaryings_p
  glGenBuffers_p glDeleteBuffers_p glBindBuffer glBufferData_c
  glGetBufferSubData_c
  glGenVertexArrays_p glDeleteVertexArrays_p glBindVertexArray
  glEnableVertexAttribArray glDisableVertexAttribArray
  glVertexAttribPointer_c glGetAttribLocation
  glEnable glBindBufferBase
  glBeginTransformFeedback glEndTransformFeedback
  glDrawArrays glFinish
  GL_FRAMEBUFFER GL_COLOR_ATTACHMENT0 GL_FRAMEBUFFER_COMPLETE
  GL_TEXTURE_2D GL_R32F GL_RED GL_FLOAT
  GL_VERSION GLEW_OK
  GL_COMPILE_STATUS GL_LINK_STATUS GL_FALSE
  GL_INTERLEAVED_ATTRIBS
  GL_VERTEX_SHADER GL_FRAGMENT_SHADER
  GL_ARRAY_BUFFER GL_STREAM_DRAW GL_STREAM_READ GL_POINTS
  GL_RASTERIZER_DISCARD GL_TRANSFORM_FEEDBACK_BUFFER
);
use PDL;

# Run a piece of code once, report how long it took, and hand back its result.
#
# Arguments: a code block (or code reference) and a label for the report.
# The code is invoked with no arguments, in scalar context. One line of the
# form "<label>: <elapsed> ms" is printed to the currently selected handle.
# Returns the (scalar) value the code returned. Exceptions from the code
# propagate to the caller and no timing line is printed in that case.
sub with_time (&$) {
  my ($code, $label) = @_;
  require Time::HiRes;  # loaded lazily, only when something is actually timed
  my @started = Time::HiRes::gettimeofday();
  my $ret = $code->();
  # The label is passed as data rather than spliced into the format string,
  # so a '%' inside it is printed literally instead of being parsed as a
  # conversion specifier.
  printf "%s: %g ms\n", $label, Time::HiRes::tv_interval(\@started) * 1000;
  $ret;
}

# Find a coarse upper bound by repeated doubling.
#
# Arguments: a predicate (code block or code reference) and a start value.
# Returns undef when the predicate rejects the start value. Otherwise the
# value is doubled for as long as the predicate accepts the doubled value,
# and the last accepted value is returned.
#
# The search also stops when doubling no longer changes the value (a start
# value of 0, or a value that has saturated at infinity); without that guard
# an always-true predicate would loop forever.
sub binescalate (&$) {
  my ($f, $val) = @_;
  return undef if !$f->($val);
  while (1) {
    my $next = $val * 2;
    last if $next == $val;  # no progress possible: 0 or +/-infinity
    last if !$f->($next);
    $val = $next;
  }
  $val;
}

# Bisect the interval [$lo, $hi] for the point where a predicate flips.
#
# Arguments: predicate, lower bound, upper bound, tolerance.
# The predicate is evaluated at both bounds first; if it gives the same
# verdict at both there is no flip to find and undef is returned. Otherwise
# the interval is halved until its width is at most the tolerance, and the
# bound at which the predicate holds is returned.
sub binsearch (&$$$) {
  my ($pred, $lo, $hi, $tolerance) = @_;
  my $lo_ok = !!$pred->($lo);
  my $hi_ok = !!$pred->($hi);
  return undef if $lo_ok == $hi_ok;
  while (1) {
    my $span = $hi - $lo;
    if ($span <= $tolerance) {
      return $hi_ok ? $hi : $lo;
    }
    my $mid = $lo + $span / 2;
    my $mid_ok = !!$pred->($mid);
    # The two bounds always carry opposite verdicts, so the midpoint replaces
    # whichever bound it agrees with and the flip stays inside the interval.
    if ($mid_ok == $hi_ok) {
      $hi = $mid;
    }
    else {
      $lo = $mid;
    }
  }
}

# Report the versions of the interpreter and the two main modules in use.
print "Perl $^V OpenGL::Modern $OpenGL::Modern::VERSION PDL $PDL::VERSION\n";

# Create a GL context and initialise GLEW; abort unless each step returns
# GLEW_OK.
# 3.3 core so MacOS allows >2.1
# NOTE(review): the trailing arguments (1, 2) are presumably the profile mask
# and context flags - confirm against the OpenGL::Modern documentation.
glewCreateContext(3, 3, 1, 2) == GLEW_OK or die "glewCreateContext failed";
glewInit() == GLEW_OK or die "glewInit failed";
# NOTE(review): presumably makes GL errors raise Perl exceptions; buffer_alloc
# wraps its allocation in eval and inspects $@, which suggests so - confirm.
glpSetAutoCheckErrors(1);
print "OpenGL ", glGetString(GL_VERSION), "\n";
# Probe whether a GL array buffer of the requested size can be filled.
#
# Argument: the number of byte elements to allocate.
# A temporary buffer object is generated and bound, a zero-filled byte
# ndarray of that size is uploaded to it inside an eval, and the buffer is
# unbound and deleted again. Returns true when the upload did not die.
sub buffer_alloc {
  my ($nbytes) = @_;
  my $probe = glGenBuffers_p(1);
  glBindBuffer(GL_ARRAY_BUFFER, $probe);
  my $scratch = zeroes(byte, $nbytes);
  my $succeeded = do {
    eval {
      glBufferData_c(
        GL_ARRAY_BUFFER,
        $scratch->nbytes,
        $scratch->make_physical->address_data,
        GL_STREAM_DRAW,
      );
    };
    !$@;  # an empty $@ right after the eval means the upload went through
  };
  # Clean up regardless of the outcome so probes do not accumulate.
  glBindBuffer(GL_ARRAY_BUFFER, 0);
  glDeleteBuffers_p($probe);
  return $succeeded;
}
# Bisect between 1 and 1e7 byte elements, to a tolerance of 1, for the largest
# size buffer_alloc accepts. binsearch returns undef when both endpoints give
# the same verdict (e.g. every probed size allocates fine), in which case 1e7
# is used. The bisection midpoints can be fractional, hence the rounding to
# the nearest integer.
my $max_buffer = int(0.5 + (binsearch(\&buffer_alloc, 1, 1e7, 1)//1e7));
print "Max GPU buffer size, binsearched = $max_buffer\n";

# GLSL vertex shader source: writes the square of the per-vertex input
# "invalue" to the output "outvalue".
my $vertex_shader = <<'EOF';
#version 330
precision highp float;
in float invalue;
out float outvalue;

void main() {
  outvalue = pow(invalue, 2);
}
EOF

# Build a vertex-only program (no fragment source is passed). The callback is
# invoked by compile_program before linking and registers "outvalue" as the
# transform-feedback varying to capture.
my $program = compile_program($vertex_shader, undef, sub {
  glTransformFeedbackVaryings_p($_[0], GL_INTERLEAVED_ATTRIBS, "outvalue");
});
# Create and bind the vertex array object used for the attribute setup below.
my $vao = glGenVertexArrays_p(1);
glBindVertexArray($vao);

# Square float ndarray sized to fit the probed buffer: each float element
# takes 4 bytes, so the side length is sqrt(bytes / 4) = 0.5 * sqrt(bytes).
# Truncate *after* halving so the dimension is always a whole number
# (truncating the root first yields x.5 whenever the integer root is odd).
my ($xdim, $ydim) = (int(0.5 * sqrt($max_buffer))) x 2;
print "ndarray dim = $xdim\n";
my $p = sequence(float, $xdim, $ydim);

# Only a sparse sample is printed: the slice takes every $skip-th element
# along each dimension, i.e. roughly the first, middle and last.
my ($skip0, $skip1) = map int(($_-1) / 2), $xdim, $ydim;
my $slicearg = join ',', map '::'.$_, $skip0, $skip1;
print "Source data: ", $p->slice($slicearg);

# CPU reference result, timed for comparison with the GPU run.
my $p_cpu_squared = with_time { $p ** 2 } 'square CPU';
print "Squared on CPU: ", $p_cpu_squared->slice($slicearg);

# Pixel type / internal format / format used for the 1x1 texture created below.
my ($type, $internalformat, $format) = (GL_FLOAT, GL_R32F, GL_RED);

# Source buffer: upload the ndarray's data, timing the transfer.
my $input_buffer = glGenBuffers_p(1);
with_time {
my $input = $p;
glBindBuffer(GL_ARRAY_BUFFER, $input_buffer);
glBufferData_c(GL_ARRAY_BUFFER, $input->nbytes, $input->make_physical->address_data, GL_STREAM_DRAW);
glBindBuffer(GL_ARRAY_BUFFER, 0);
} 'setup src buffer';

# Destination buffer of the same byte size as the source. The data pointer is
# passed as 0 - presumably a null pointer, i.e. allocate storage without
# uploading anything (TODO confirm against the glBufferData_c binding).
my ($destBufferID) = glGenBuffers_p(1);
glBindBuffer(GL_ARRAY_BUFFER, $destBufferID);
glBufferData_c(GL_ARRAY_BUFFER, $p->nbytes, 0, GL_STREAM_READ);
glBindBuffer(GL_ARRAY_BUFFER, 0);

# A 1x1 single-channel float texture, attached as colour attachment 0 of a
# new framebuffer object. The framebuffer is left bound after this section.
my ($destTextureID) = glGenTextures_p(1);
glBindTexture(GL_TEXTURE_2D, $destTextureID);
glTexImage2D_c(GL_TEXTURE_2D, 0, $internalformat, 1, 1,
  0, $format, $type, 0);
glBindTexture(GL_TEXTURE_2D, 0);
my ($fbo_id) = glGenFramebuffers_p(1);
glBindFramebuffer(GL_FRAMEBUFFER, $fbo_id);
glFramebufferTexture(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, $destTextureID, 0);
# A false status is reported via the pending GL error; any other status than
# "complete" is reported by its enum name.
my $fbstat = glCheckFramebufferStatus(GL_FRAMEBUFFER);
die "FBO Status error: " . glpErrorString(glGetError()) if !$fbstat;
die "FBO Status: ".OpenGL::Modern::enum2name('FramebufferStatus', $fbstat)
  if $fbstat != GL_FRAMEBUFFER_COMPLETE;

# Run the squaring shader over every element via transform feedback, ten
# times, timing each pass separately. Rasterisation is discarded; the shader
# output is captured into $destBufferID at transform-feedback binding 0.
with_time {
  glBindBuffer(GL_ARRAY_BUFFER, $input_buffer);
  my $inputAttrib = glGetAttribLocation($program, "invalue");
  # A negative location means the attribute is not active in the program;
  # fail with a clear message instead of passing it on to GL.
  die "Vertex attribute 'invalue' not found in program" if $inputAttrib < 0;
  glEnableVertexAttribArray($inputAttrib);
  glVertexAttribPointer_c($inputAttrib, 1, GL_FLOAT, GL_FALSE, 0, 0);
  glUseProgram($program);
  glEnable(GL_RASTERIZER_DISCARD);
  glBindBufferBase(GL_TRANSFORM_FEEDBACK_BUFFER, 0, $destBufferID);
  glBeginTransformFeedback(GL_POINTS);
  glDrawArrays(GL_POINTS, 0, $p->nelem);
  glEndTransformFeedback();
  glUseProgram(0);
  # Disable the attribute array that was enabled above, not a fixed index 0.
  glDisableVertexAttribArray($inputAttrib);
  glBindBuffer(GL_ARRAY_BUFFER, 0);
  glFinish();  # wait for the GPU so the measured time covers the whole pass
} 'render' for 1..10;

# Read the captured results back into a fresh ndarray of the same shape as the
# input, timing the transfer, and print the same sparse sample as before.
# NOTE(review): relies on $destBufferID still being bound to
# GL_TRANSFORM_FEEDBACK_BUFFER from the render passes - confirm.
my $p2 = zeroes(float, $xdim, $ydim);
with_time {
glGetBufferSubData_c(GL_TRANSFORM_FEEDBACK_BUFFER, 0, $p2->nbytes, $p2->address_data);
} 'copy dest to CPU';
print "From GPU: ", $p2->slice($slicearg);

# Release every GL object the script created, then destroy the context.
# Runs at interpreter exit, including after an early die, so each handle is
# filtered through grep: handles that were never assigned are skipped.
END {
  glBindTexture(GL_TEXTURE_2D, 0);
  glDeleteTextures_p($_) for grep $_, $destTextureID;
  glBindFramebuffer(GL_FRAMEBUFFER, 0);
  glDeleteFramebuffers_p($_) for grep $_, $fbo_id;
  glBindVertexArray(0);
  glDeleteVertexArrays_p($_) for grep $_, $vao;
  glUseProgram(0);
  glDeleteProgram($_) for grep $_, $program;
  glBindBuffer(GL_ARRAY_BUFFER, 0);
  # Both the source and the transform-feedback destination buffer are freed;
  # the destination buffer was previously left undeleted.
  glDeleteBuffers_p($_) for grep $_, $input_buffer, $destBufferID;
  glewDestroyContext();
}

# Compile one GLSL shader stage.
#
# Arguments: the GL shader type constant and the GLSL source text.
# Returns the shader object on success. On a failed compile the shader is
# deleted and the sub dies with a headline naming the stage ("Vertex" for
# GL_VERTEX_SHADER, "Fragment" for anything else) followed by the info log.
sub compile_shader {
  my ($kind, $glsl) = @_;
  my $id = glCreateShader($kind);
  glShaderSource_p($id, $glsl);
  glCompileShader($id);
  return $id if glGetShaderiv_p($id, GL_COMPILE_STATUS) != GL_FALSE;

  # Fetch the log before deleting the shader it belongs to.
  my $stage = $kind == GL_VERTEX_SHADER ? "Vertex" : "Fragment";
  my $message = "$stage shader compilation failed!\n" . glGetShaderInfoLog_p($id);
  glDeleteShader($id);
  die $message;
}

# Compile the given shader sources and link them into a GL program.
#
# Arguments:
#   $vsrc    - vertex shader source (required)
#   $fsrc    - fragment shader source; a false value builds a vertex-only
#              program
#   $prelink - optional callback, called with the program object after the
#              shaders are attached and before linking
# Returns the linked program object. Dies with the compiler or linker message
# on failure; the shader objects are detached and deleted in either case once
# linking has been attempted.
sub compile_program {
  my ($vsrc, $fsrc, $prelink) = @_;
  my $vShader = compile_shader(GL_VERTEX_SHADER, $vsrc);
  my $fShader;
  if ($fsrc) {
    # $@ is only consulted when the eval actually ran and failed. Checking it
    # unconditionally would pick up a stale error left behind by an unrelated
    # earlier eval and abort a perfectly good vertex-only build.
    my $compiled = eval { $fShader = compile_shader(GL_FRAGMENT_SHADER, $fsrc); 1 };
    if (!$compiled) {
      my $err = $@ || "Fragment shader compilation failed\n";
      glDeleteShader($vShader);  # do not leak the already-compiled stage
      die $err;
    }
  }
  my $program = glCreateProgram();
  glAttachShader($program, $vShader);
  glAttachShader($program, $fShader) if $fsrc;
  $prelink->($program) if $prelink;
  glLinkProgram($program);
  my $status = glGetProgramiv_p($program, GL_LINK_STATUS);
  # The shader objects are no longer needed once linking has been attempted.
  glDetachShader($program, $vShader);
  glDetachShader($program, $fShader) if $fsrc;
  glDeleteShader($vShader);
  glDeleteShader($fShader) if $fsrc;
  if ($status == GL_FALSE) {
    my $str = "Program linker failed.\n";
    $str .= glGetProgramInfoLog_p($program);
    glDeleteProgram($program);
    die $str;
  }
  $program;
}
