function [msg, iters, logP, msg_argmax] = DPMP_MaxSum(nodePot,edgePot,edgeStruct,stepsize,sched)
% DPMP_MAXSUM - Max-Sum (log-domain max-product) message passing
%
%   [msg, iters, logP, msg_argmax] = DPMP_MaxSum(nodePot, edgePot, ...
%       edgeStruct, stepsize, sched)
%
%   Inputs:
%     nodePot    - nNodes x maxState matrix.  Entries are ADDED to incoming
%                  messages, i.e. they are used as log-potentials.
%     edgePot    - maxState x maxState x nEdges array.  edgePot(i,j,e) is
%                  added for state i of edgeEnds(e,1) and state j of
%                  edgeEnds(e,2) (log domain).
%     edgeStruct - struct; this function reads edgeEnds (nEdges x 2),
%                  nStates, maxIter and convTol, and passes it on to
%                  UGM_getEdges / DPMP_makeFwdBwdSched / DPMP_getLogBeliefs /
%                  DPMP_getLabelProb.
%     stepsize   - (optional) damping weight used on every iteration after
%                  the first:  msg = stepsize*new + (1-stepsize)*previous,
%                  where "previous" is the message from the prior iteration.
%                  Defaults to 1 (no damping) when omitted or empty.
%     sched      - (optional) vector of directed-edge indices in 1..2*nEdges
%                  giving the message update order within one iteration.
%                  Row or column vectors are accepted.  Defaults to
%                  DPMP_makeFwdBwdSched(edgeStruct) when omitted or empty.
%
%   Outputs:
%     msg        - maxState x 2*nEdges log-messages.  Column e (e<=nEdges)
%                  is the message edgeEnds(e,1) => edgeEnds(e,2); column
%                  e+nEdges is edgeEnds(e,2) => edgeEnds(e,1).  Rows beyond
%                  the receiving node's nStates are left at zero.
%     iters      - number of iterations performed (0 when there are no
%                  edges or edgeStruct.maxIter < 1).
%     logP       - column vector, one entry per iteration: the value
%                  DPMP_getLabelProb returns for the per-node argmax of
%                  DPMP_getLogBeliefs.  Only computed when requested
%                  (nargout > 2); otherwise [].
%     msg_argmax - same layout as msg; for each state of the receiving node,
%                  the index of the sending node's state attaining the max.
%
%   J. Pacheco 2014
%   Based on UGM implementation by M. Schmidt
%

  % Optional arguments
  if nargin<4 || isempty(stepsize)
    stepsize = 1.0;
  end
  if nargin<5 || isempty(sched)
    sched = DPMP_makeFwdBwdSched(edgeStruct);
  end
  % A MATLAB for-loop iterates over COLUMNS; force a row vector so that a
  % column-vector schedule is not consumed as one single "step".
  sched = sched(:)';

  % Init. stuff
  maxState = size(nodePot,2);
  nEdges = size(edgePot,3);
  edgeEnds = edgeStruct.edgeEnds;
  nStates = double(edgeStruct.nStates);

  % Init. messages to uniform (in log domain)
  msg_old = zeros(maxState,nEdges*2);
  msg = zeros(maxState,nEdges*2);
  msg_argmax = zeros(maxState,nEdges*2);
  for e = 1:nEdges
    n1 = edgeEnds(e,1);
    n2 = edgeEnds(e,2);
    msg(1:nStates(n2),e) = -log(nStates(n2)); % Message from n1 => n2
    msg(1:nStates(n1),e+nEdges) = -log(nStates(n1)); % Message from n2 => n1
  end
  logP = [];
  iters = 0;

  % Nothing to pass if there are no edges
  if nEdges == 0
    return;
  end

  % Update Messages
  for it = 1:edgeStruct.maxIter
    iters = it;
    % No damping on the first sweep: msg_old does not hold real messages yet
    if it==1, damp = 1.0; else damp = stepsize; end

    for eDirIdx = sched

      % Undirected edge e and sending node n for this directed index
      if eDirIdx <= nEdges
        e = eDirIdx;
        n = edgeEnds(e,1);
      else
        e = eDirIdx - nEdges;
        n = edgeEnds(e,2);
      end
      n1 = edgeEnds(e,1);
      n2 = edgeEnds(e,2);

      % Orient the edge potential as (receiver states) x (sender states)
      if n == n2
        pot_ij = edgePot(1:nStates(n1),1:nStates(n2),e);
      else
        pot_ij = edgePot(1:nStates(n1),1:nStates(n2),e)';
      end

      % temp = node potential of n + all incoming msgs except from receiver.
      % edges(:)' guarantees one loop pass per edge, whatever orientation
      % UGM_getEdges returns (a column would otherwise run only once).
      edges = UGM_getEdges(n,edgeStruct);
      temp = nodePot(n,1:nStates(n))';
      for e2 = edges(:)'
        if e2 ~= e
          if n == edgeEnds(e2,2)
            temp = temp + msg(1:nStates(n),e2);
          else
            temp = temp + msg(1:nStates(n),e2+nEdges);
          end
        end
      end

      % Max-sum update: maximize over sender states, remember the maximizer
      msgMat = bsxfun(@plus, pot_ij, temp');
      [newm, newm_argmax] = max(msgMat, [], 2);

      % Normalize (max = 0), then damp against previous iteration's message
      newm = newm - max(newm);
      if n == n2
        nRecv = nStates(n1);
      else
        nRecv = nStates(n2);
      end
      msg(1:nRecv,eDirIdx) = damp * newm + (1-damp) * msg_old(1:nRecv,eDirIdx);
      msg_argmax(1:nRecv,eDirIdx) = newm_argmax;
    end

    % Log probability of the current decoded labeling (only if requested)
    if nargout>2
      nodeBel = DPMP_getLogBeliefs(msg, nodePot, edgePot, edgeStruct);
      [~,mapState] = max(nodeBel,[],2);
      this_logP = DPMP_getLabelProb(mapState, nodePot, edgePot, edgeStruct );
      logP = cat(1,logP,this_logP);
    end

    % Check convergence (skipped on the first iteration, msg_old is zeros)
    msgDiff = max( abs( exp(msg(:)) - exp(msg_old(:)) ) );
    if it>1 && msgDiff < edgeStruct.convTol, break; end

    % save old message
    msg_old = msg;
  end
