function [nodeBel, x, stats, edgeBel] = DPMP_Infer_GPMP( x0, edgeStruct, funHandle, opt, sched, rho )
% DPMP_INFER_GPMP - Greedy Particle Max-Product inference.
%
% Each outer iteration keeps the current MAP particle of every node in
% particle slot 1, refills the remaining slots with draws from
% funHandle.proposal, re-evaluates the model potentials, runs max-sum
% message passing (LBP or TRW, chosen by opt.msgPass) and extracts a MAP
% labeling from the resulting node beliefs.
%
% INPUTS:
%   x0         - initial particles; the particle index is the 2nd dimension
%                (x(:,1,:) holds the particle kept from the previous MAP).
%   edgeStruct - graph structure; fields read here: nStates, edgeEnds, V, E,
%                maxIter, convTol, useMex.
%   funHandle  - struct of handles: proposal(x, nNew, edgeStruct) returning
%                nNew new particles, and funEvalModel(x, edgeStruct)
%                returning [nodePot, edgePot].
%   opt        - options struct: maxIters (>= 1), maxItersLBP, verbose,
%                display_iter, stepsize, doTieRes, msgPass ('lbp' | 'trw',
%                case-insensitive).
%   sched      - forwarded unchanged to the message-passing routine
%                (presumably the message update schedule - confirm).
%   rho        - TRW parameter forwarded to the TRW routines (presumably edge
%                appearance probabilities); only used when msgPass is 'trw'
%                and may be omitted for 'lbp'.
%
% OUTPUTS:
%   nodeBel - node log-beliefs computed in the final iteration.
%   x       - particles used in the final iteration.
%   stats   - struct of per-iteration statistics (column vectors):
%               logPall   - traces returned by the message-passing routine,
%                           concatenated over outer iterations
%               logP      - DPMP_getLabelProb of the MAP labeling
%               logPbound - TRW bound traces ('trw' only)
%               iters     - message-passing iterations per outer iteration
%               unique    - number of nodes without ties in the MAP labeling
%               tWall     - message-passing wall time per outer iteration
%   edgeBel - edge log-beliefs; only computed when this output is requested.
%
% REFERENCES:
% Peng, J. and Hazan, T. and McAllester, D. and Urtasun, R.,
%   "Convex max-product algorithms for continuous {MRF}s with applications to
%   protein folding", ICML 2011
%
% J. Pacheco 2014
%

  % rho is only needed for TRW; allow LBP callers to omit it
  if nargin < 6, rho = []; end

  % unpack stuff
  nParticles = double(edgeStruct.nStates);
  [ maxIters, maxItersLBP, verbose, display_iter, stepsize, doTieRes ] = deal( ...
    opt.maxIters, opt.maxItersLBP, opt.verbose, opt.display_iter, ...
    opt.stepsize, opt.doTieRes );
  msgPass = lower(opt.msgPass);
  x = x0;

  % validate options before doing any (potentially expensive) model work
  if ~ismember(msgPass, {'lbp','trw'})
    error('Unrecognized message passing method %s.', opt.msgPass);
  end
  if maxIters < 1
    % without at least one iteration there are no messages/beliefs to return
    error('DPMP_Infer_GPMP: opt.maxIters must be at least 1 (got %g).', maxIters);
  end

  % init stats
  stats.logPall = [];
  stats.logP = [];
  if strcmp(msgPass,'trw'), stats.logPbound = []; end
  stats.iters = [];
  stats.unique = [];
  stats.tWall = [];

  % Init display. NOTE(review): nothing is drawn into this figure here; it is
  % still created to preserve the original side effect.
  if display_iter
    figure('InvertHardcopy','off','Color',[1 1 1]);
  end

  % Main Loop
  for gpmp_iters=1:maxIters
    if verbose, fprintf('Iter %d: \n', gpmp_iters); end
    t_start_pmp = tic();

    % resample particles: keep the previous MAP particle in slot 1 and
    % propose nParticles-1 new ones around it
    if gpmp_iters>1
      x_best = DPMP_getLabelParticles( xMAP, x );
      x = zeros(size(x));
      x(:, 1, :) = x_best;
      x(:,2:end,:) = funHandle.proposal(x(:, 1, :), nParticles-1, edgeStruct);
    end
    [ nodePot, edgePot ] = funHandle.funEvalModel( x, edgeStruct );

    % run Max-Sum
    t_start = tic;
    switch msgPass
      case 'lbp'
        if verbose, fprintf('\tRunning LBP...'); end
        if edgeStruct.useMex
          [msg, lbp_iters, logPall] = DPMP_MaxSumC(...
            nodePot,edgePot,int32(edgeStruct.edgeEnds),int32(edgeStruct.nStates),...
            int32(edgeStruct.V),int32(edgeStruct.E),int32(edgeStruct.maxIter),...
            edgeStruct.convTol,stepsize,int32(sched));
          % mex returns a preallocated trace; keep only the used entries
          logPall = logPall(1:lbp_iters);
        else
          [msg, lbp_iters, logPall] = ...
            DPMP_MaxSum(nodePot, edgePot, edgeStruct, stepsize, sched);
        end
      case 'trw'
        if verbose, fprintf('\tRunning TRW...'); end
        if edgeStruct.useMex
          [msg, lbp_iters, logPall, logPbound] = DPMP_TRW_C(...
            nodePot,edgePot,int32(edgeStruct.edgeEnds),int32(edgeStruct.nStates),...
            int32(edgeStruct.V),int32(edgeStruct.E),int32(edgeStruct.maxIter),...
            edgeStruct.convTol,stepsize,int32(sched),rho);
          % mex returns preallocated traces; keep only the used entries
          logPall = logPall(1:lbp_iters);
          logPbound = logPbound(1:lbp_iters);
        else
          [msg, lbp_iters, ~, logPall, logPbound] = ...
            DPMP_TRW(nodePot, edgePot, edgeStruct, stepsize, sched, rho);
        end
        stats.logPbound = cat(1,stats.logPbound,logPbound);
    end
    t_stop = toc( t_start );
    stats.tWall = cat(1,stats.tWall,t_stop);
    stats.iters = cat(1,stats.iters,lbp_iters);
    stats.logPall = cat(1,stats.logPall,logPall);

    % output stats
    if verbose
      if lbp_iters==maxItersLBP, fprintf('done.  LBP did not converge!');
      else fprintf('done %d iterations.', lbp_iters);
      end
      fprintf(' (%0.1fs)\n', t_stop);
    end

    % get MAP labeling
    nodeBel = local_getBeliefs(msgPass, msg, nodePot, edgePot, edgeStruct, rho);
    [xMAP, nTies] = DPMP_getMAPLabel(doTieRes, nodeBel, nodePot, edgePot, edgeStruct);
    % scalar local (not logPall(end)) so an empty message-passing trace
    % cannot cause an invalid-index error
    logP = DPMP_getLabelProb(xMAP, nodePot, edgePot, edgeStruct );
    stats.logP = cat(1,stats.logP,logP);
    stats.unique = cat(1,stats.unique,numel(xMAP)-nTies);
    if verbose
      fprintf('\t%d ties of %d nodes.\n', nTies, numel(xMAP));
    end

    % end iteration
    t_stop_pmp = toc( t_start_pmp );
    if verbose
      fprintf('\tdone (%0.3fs)\n', t_stop_pmp);
    end

  end

  % Node beliefs from the last iteration are already the final ones. The MRF
  % might have many edges, so edge beliefs are only computed if asked for.
  if nargout>3
    [nodeBel, edgeBel] = local_getBeliefs(msgPass, msg, nodePot, edgePot, edgeStruct, rho);
  end
end

function [nodeBel, edgeBel] = local_getBeliefs(msgPass, msg, nodePot, edgePot, edgeStruct, rho)
% LOCAL_GETBELIEFS - Node log-beliefs (and edge log-beliefs when a second
% output is requested) for message-passing variant msgPass ('lbp' | 'trw',
% already lower-cased). rho is only forwarded for 'trw'.
  switch msgPass
    case 'lbp'
      if nargout>1
        [nodeBel, edgeBel] = DPMP_getLogBeliefs(msg, nodePot, edgePot, edgeStruct);
      else
        nodeBel = DPMP_getLogBeliefs(msg, nodePot, edgePot, edgeStruct);
      end
    case 'trw'
      if nargout>1
        [nodeBel, edgeBel] = DPMP_getTRWLogBeliefs(msg, nodePot, edgePot, edgeStruct, rho);
      else
        nodeBel = DPMP_getTRWLogBeliefs(msg, nodePot, edgePot, edgeStruct, rho);
      end
  end
end

