/* -*- mode:c;c-file-style:"k&r";c-basic-offset:2;tab-width:2;indent-tabs-mode:nil;rm-trailing-spaces:t -*- */

#include <math.h>
#include "mex.h"
#include "UGM_common.h"
#include "DPMP_common.h"

/* getLogBeliefs : Compute (non-reweighted) log-beliefs from messages.
 *
 * For every node n and every state s < nStates[n]:
 *   nodeBel[n + nNodes*s] = nodePot[n + nNodes*s]
 *                           + sum over edges e incident to n of the message
 *                             into n along e.
 * Messages are stored column-major in a maxState x 2*nEdges array:
 * column e holds the message into edgeEnds(e,2), column e + nEdges holds the
 * message into edgeEnds(e,1). Entries of nodeBel with s >= nStates[n] are
 * left untouched.
 *
 * nodeBel   output, nNodes x maxState (column-major)
 * msg       maxState x 2*nEdges messages (column-major)
 * nodePot   nNodes x maxState node potentials (column-major)
 * edgePot   unused; kept so the signature stays unchanged
 * nStates   number of states of each node
 * V, E      1-based adjacency: the (1-based) ids of edges incident to node n
 *           are E[V[n]-1] .. E[V[n+1]-2]
 * edgeEnds  nEdges x 2 endpoint table (column-major), 1-based node ids
 */
void
getLogBeliefs(
  double * nodeBel,
  const double * msg,
  const double * nodePot,
  const double * edgePot,
  const int nNodes,
  const int nEdges,
  const int maxState,
  const int * nStates, 
  const int * V,
  const int * E,
  const int * edgeEnds
  )
{
  int n, s, Vind;

  (void) edgePot; /* beliefs depend only on node potentials and messages */

  for (n = 0; n < nNodes; n++) {
    /* start from the node potential */
    for (s = 0; s < nStates[n]; s++)
      nodeBel[n + (long long) nNodes * s] = nodePot[n + (long long) nNodes * s];

    /* add the incoming message along every incident edge */
    for (Vind = V[n] - 1; Vind < V[n + 1] - 1; Vind++) {
      const int e = E[Vind] - 1;
      /* n is the second endpoint -> message sits in column e,
         otherwise it is the first endpoint -> column e + nEdges */
      const long long col =
        (n == edgeEnds[e + nEdges] - 1) ? (long long) e : (long long) e + nEdges;
      const double *in = msg + (long long) maxState * col;

      for (s = 0; s < nStates[n]; s++)
        nodeBel[n + (long long) nNodes * s] += in[s];
    }
  }
}



/* DPMP_Infer_MaxSum : Run Max-Sum inference.
 *
 * Loopy max-sum belief propagation (messages are combined additively and
 * maximized, i.e. log domain) with a caller-supplied directed-message
 * schedule and damping.
 *
 * Call shape (from the prhs/plhs usage below):
 *   [msg, iter, logP] = DPMP_Infer_MaxSum(nodePot, edgePot, edgeEnds, nStates,
 *                                         V, E, maxIter, convTol, stepsize, sched)
 *
 * Inputs (indices stored in edgeEnds, V, E and sched are 1-based):
 *   nodePot  double, nNodes x maxState
 *   edgePot  double, indexed as edgePot(s1, s2, e) with s1/s2 the states of
 *            edgeEnds(e,1)/edgeEnds(e,2), maxState x maxState x nEdges
 *   edgeEnds int32, nEdges x 2 endpoint table
 *   nStates  int32, number of states of each node
 *   V, E     int32, edges incident to node n are E(V(n) : V(n+1)-1)
 *   maxIter  int32 scalar, maximum number of sweeps over sched
 *   convTol  numeric scalar; stop when max |exp(msg) - exp(oldMsg)| < convTol
 *   stepsize numeric scalar; damping factor used from the second sweep on
 *            (new = stepsize*computed + (1-stepsize)*previous)
 *   sched    int32, 2*nEdges directed-edge ids: k <= nEdges sends along edge k
 *            from edgeEnds(k,1) to edgeEnds(k,2); k > nEdges sends along edge
 *            k-nEdges in the reverse direction
 *
 * Outputs:
 *   msg   maxState x 2*nEdges; column e is the message into edgeEnds(e,2),
 *         column e+nEdges the message into edgeEnds(e,1)
 *   iter  int32, 0-based index of the sweep at which convergence was
 *         detected, or maxIter if it never converged
 *   logP  maxIter x 1; logP(i) is getLabelProb() of the labelling decoded by
 *         getMAPlabel() after sweep i; entries from the converged sweep on
 *         stay 0
 *
 * NOTE(review): node ids in edgeEnds/V/E and nStates values (<= maxState)
 * are assumed valid; only sizes and sched entries are validated here.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {

  /* Variables */
  int n, s, e, e2, n1, n2, Vind2, s1, s2, maxIter, nOut, schedIdx, eDirIdx;
  int *edgeEnds, *nStates, *V, *E, *iter, *sched, *mapState;
  long long maxState, nNodes, nEdges, nMsgEntries, k;
  double *nodePot, *edgePot, *msg, *oldMsgs, *tmp, *nodeBel, *logP, *out;
  const double *prev, *pot;
  double z, best, msg_diff, damp, convTol, stepsize;
  mxArray *msgArr, *iterArr, *logPArr;

  /* Check argument counts before touching prhs/plhs */
  if (nrhs < 10)
    mexErrMsgTxt("Expected 10 inputs: nodePot, edgePot, edgeEnds, nStates, V, E, maxIter, convTol, stepsize, sched");
  if (nlhs > 3)
    mexErrMsgTxt("At most 3 outputs: msg, iter, logP");

  /* Check input types */
  if (!mxIsDouble(prhs[0])) mexErrMsgTxt("nodePot must be double");
  if (!mxIsDouble(prhs[1])) mexErrMsgTxt("edgePot must be double");
  if (!mxIsClass(prhs[2], "int32")) mexErrMsgTxt("edgeEnds must be int32");
  if (!mxIsClass(prhs[3], "int32")) mexErrMsgTxt("nStates must be int32");
  if (!mxIsClass(prhs[4], "int32")) mexErrMsgTxt("V must be int32");
  if (!mxIsClass(prhs[5], "int32")) mexErrMsgTxt("E must be int32");
  if (!mxIsClass(prhs[6], "int32")) mexErrMsgTxt("maxIter must be int32");
  if (!mxIsClass(prhs[9], "int32")) mexErrMsgTxt("sched must be int32");
  if (mxIsEmpty(prhs[6]) || mxIsEmpty(prhs[7]) || mxIsEmpty(prhs[8]))
    mexErrMsgTxt("maxIter, convTol and stepsize must be non-empty scalars");
  if (!mxIsNumeric(prhs[7]) || !mxIsNumeric(prhs[8]))
    mexErrMsgTxt("convTol and stepsize must be numeric");

  /* Compute Sizes */
  nNodes = mxGetDimensions(prhs[0])[0];
  maxState = mxGetDimensions(prhs[0])[1];
  nEdges = mxGetDimensions(prhs[2])[0];

  /* Check input sizes so the indexing below stays in bounds */
  if (mxGetNumberOfElements(prhs[1]) < (size_t) (maxState * maxState * nEdges))
    mexErrMsgTxt("edgePot must have maxState*maxState*nEdges elements");
  if (mxGetNumberOfElements(prhs[2]) < (size_t) (2 * nEdges))
    mexErrMsgTxt("edgeEnds must be nEdges x 2");
  if (mxGetNumberOfElements(prhs[3]) < (size_t) nNodes)
    mexErrMsgTxt("nStates must have nNodes elements");
  if (nNodes > 0 && mxGetNumberOfElements(prhs[4]) < (size_t) (nNodes + 1))
    mexErrMsgTxt("V must have nNodes+1 elements");
  if (mxGetNumberOfElements(prhs[9]) < (size_t) (2 * nEdges))
    mexErrMsgTxt("sched must have 2*nEdges elements");

  /* Input (mxGetData is the correct accessor for non-double storage) */
  nodePot = mxGetPr(prhs[0]);
  edgePot = mxGetPr(prhs[1]);
  edgeEnds = (int *) mxGetData(prhs[2]);
  nStates = (int *) mxGetData(prhs[3]);
  V = (int *) mxGetData(prhs[4]);
  E = (int *) mxGetData(prhs[5]);
  maxIter = ((int *) mxGetData(prhs[6]))[0];
  convTol = mxGetScalar(prhs[7]);
  stepsize = mxGetScalar(prhs[8]);
  sched = (int *) mxGetData(prhs[9]);

  if (maxIter < 0)
    mexErrMsgTxt("maxIter must be non-negative");
  /* sched entries are used directly as directed-edge indices */
  for (schedIdx = 0; schedIdx < 2 * nEdges; schedIdx++)
    if (sched[schedIdx] < 1 || sched[schedIdx] > 2 * nEdges)
      mexErrMsgTxt("sched entries must lie in 1..2*nEdges");

  /* Output arrays; assigned to plhs only if the caller asked for them */
  msgArr = mxCreateDoubleMatrix(maxState, nEdges * 2, mxREAL);
  msg = mxGetPr(msgArr);
  iterArr = mxCreateNumericMatrix(1, 1, mxINT32_CLASS, mxREAL);
  iter = (int *) mxGetData(iterArr);
  logPArr = mxCreateDoubleMatrix(maxIter, 1, mxREAL);
  logP = mxGetPr(logPArr);

  /* Allocate Memory */
  nMsgEntries = maxState * nEdges * 2;
  oldMsgs = mxCalloc(nMsgEntries, sizeof(double));
  tmp = mxCalloc(maxState, sizeof(double));
  nodeBel = mxMalloc(maxState * nNodes * sizeof(double));
  mapState = mxCalloc(nNodes, sizeof(int));

  /* Run Loopy BP Iterations */
  for (*iter = 0; *iter < maxIter; (*iter)++) {
    /* First sweep is undamped (oldMsgs is still all zeros) */
    damp = (*iter == 0) ? 1.0 : stepsize;

    /* process message schedule (in place: later messages see earlier ones) */
    for (schedIdx = 0; schedIdx < 2 * nEdges; ++schedIdx) {
      eDirIdx = sched[schedIdx] - 1;

      /* decode directed edge: n is the sending node */
      if (eDirIdx < nEdges) {
        e = eDirIdx;
        n = edgeEnds[e] - 1;
      } else {
        e = eDirIdx - nEdges;
        n = edgeEnds[e + nEdges] - 1;
      }
      n1 = edgeEnds[e] - 1;
      n2 = edgeEnds[e + nEdges] - 1;

      /* tmp = nodePot(n,:) + all messages into n except the one along e */
      for (s = 0; s < nStates[n]; s++)
        tmp[s] = nodePot[n + nNodes * s];
      for (Vind2 = V[n] - 1; Vind2 < V[n + 1] - 1; Vind2++) {
        const double *in;
        e2 = E[Vind2] - 1;
        if (e2 == e)
          continue;
        in = (n == edgeEnds[e2 + nEdges] - 1) ? msg + maxState * e2
                                              : msg + maxState * (e2 + nEdges);
        for (s = 0; s < nStates[n]; s++)
          tmp[s] += in[s];
      }

      /* Add edge potential and maximize over the sender's states */
      pot = edgePot + maxState * maxState * e;
      if (n == n2) {
        /* n2 -> n1: stored in column e + nEdges, indexed by s1 */
        out = msg + maxState * (e + nEdges);
        prev = oldMsgs + maxState * (e + nEdges);
        nOut = nStates[n1];
        for (s1 = 0; s1 < nStates[n1]; s1++) {
          best = -INFINITY;
          for (s2 = 0; s2 < nStates[n2]; s2++)
            best = fmax(best, tmp[s2] + pot[s1 + maxState * s2]);
          out[s1] = best;
        }
      } else {
        /* n1 -> n2: stored in column e, indexed by s2 */
        out = msg + maxState * e;
        prev = oldMsgs + maxState * e;
        nOut = nStates[n2];
        for (s2 = 0; s2 < nStates[n2]; s2++) {
          best = -INFINITY;
          for (s1 = 0; s1 < nStates[n1]; s1++)
            best = fmax(best, tmp[s1] + pot[s1 + maxState * s2]);
          out[s2] = best;
        }
      }

      /* Normalize by subtracting maxarray(...) (presumably the maximum
         entry), then damp toward the previous sweep's message */
      z = maxarray(out, nOut, nOut, 1);
      for (s = 0; s < nOut; s++)
        out[s] = damp * (out[s] - z) + (1 - damp) * prev[s];
    }

    /* check convergence on exp(msg) and remember this sweep's messages;
       scanned contiguously (the max does not depend on order) */
    msg_diff = -INFINITY;
    for (k = 0; k < nMsgEntries; k++) {
      msg_diff = fmax(msg_diff, absDif(exp(msg[k]), exp(oldMsgs[k])));
      oldMsgs[k] = msg[k];
    }
    if (msg_diff < convTol)
      break;

    /* compute log-probability of the current MAP decoding */
    getLogBeliefs(nodeBel, msg, nodePot, edgePot, nNodes, nEdges, maxState,
                  nStates, V, E, edgeEnds);
    getMAPlabel(mapState, nodeBel, nNodes, nStates);
    logP[*iter] = getLabelProb(mapState, nodePot, edgePot, edgeEnds, nNodes,
                               maxState, nEdges);
  }

  /* Free memory */
  mxFree(oldMsgs);
  mxFree(tmp);
  mxFree(nodeBel);
  mxFree(mapState);

  /* plhs only has room for the outputs the caller requested */
  plhs[0] = msgArr;
  if (nlhs > 1)
    plhs[1] = iterArr;
  else
    mxDestroyArray(iterArr);
  if (nlhs > 2)
    plhs[2] = logPArr;
  else
    mxDestroyArray(logPArr);
}
