function [msg, iters, msg_argmax, logP, logPbound] = DPMP_TRW(nodePot,edgePot,edgeStruct,stepsize,sched,mu)
% DPMP_TRW - Tree-reweighted max-product message passing (log domain).
%
%   [msg, iters, msg_argmax, logP, logPbound] = ...
%       DPMP_TRW(nodePot, edgePot, edgeStruct, stepsize, sched, mu)
%
%   Inputs:
%     nodePot    - nNodes x maxState node log-potentials; row n uses entries
%                  1:edgeStruct.nStates(n).
%     edgePot    - maxState x maxState x nEdges edge log-potentials, indexed
%                  (state of edgeEnds(e,1), state of edgeEnds(e,2), e).
%     edgeStruct - UGM-style edge structure. Fields read here: edgeEnds,
%                  nStates, maxIter, convTol (UGM_getEdges and the
%                  belief/bound helpers may read more).
%     stepsize   - weight given to the newly computed message on sweeps > 1
%                  (1 = no damping). The first sweep is always undamped.
%     sched      - vector of directed-edge indices in 1:2*nEdges giving the
%                  update order within each sweep. Index e <= nEdges is the
%                  message edgeEnds(e,1) -> edgeEnds(e,2); index e+nEdges is
%                  the reverse direction. Row or column vectors are accepted.
%     mu         - nEdges-vector of edge appearance probabilities, or a
%                  scalar used for every edge.
%
%   Outputs:
%     msg        - maxState x 2*nEdges log-messages. The valid entries of
%                  each message are shifted so that their maximum is 0.
%     iters      - number of sweeps performed (0 if edgeStruct.maxIter < 1).
%     msg_argmax - maxState x 2*nEdges. For each receiver state, the index of
%                  the maximizing sender state in the latest update.
%     logP       - (optional) per-sweep DPMP_getLabelProb of the labeling
%                  that maximizes the node beliefs.
%     logPbound  - (optional) per-sweep DPMP_computeTRWbound value.
%
%   NOTE(review): incoming messages are read from the live msg array, so a
%   message updated earlier in a sweep is already used by later updates in
%   the same sweep. Only damping uses the previous sweep's messages. Confirm
%   this serial behaviour is intended (the original header said "Parallel").
%
%   J. Pacheco 2014
%   Based on UGM implementation by M. Schmidt
%

% Init. stuff
[~,maxState] = size(nodePot);
edgeEnds = edgeStruct.edgeEnds;
nStates = edgeStruct.nStates;
nEdges = size(edgePot,3);

% A for-loop over a column vector runs once with the whole column as the
% loop index, so force a row to visit each schedule entry in turn.
sched = reshape(sched, 1, []);

% Allow one edge appearance probability shared by every edge.
if isscalar(mu)
  mu = repmat(mu, nEdges, 1);
end

% Init. messages (entries beyond nStates of the receiving node stay 0)
msg_old = zeros(maxState,nEdges*2);
msg = zeros(maxState,nEdges*2);
msg_argmax = zeros(maxState,nEdges*2);
logP = [];
logPbound = [];
msgDiff = Inf;   % keeps the final report defined even if no sweep runs

% Update Messages
for iters = 1:edgeStruct.maxIter
  % The first sweep is undamped; msg_old is still all zeros then.
  if iters==1, damp = 1.0; else damp = stepsize; end
  for eDirIdx = sched

    % Map the directed index to its undirected edge e and sending node n
    if eDirIdx <= nEdges
      e = eDirIdx;
      n = edgeEnds(e,1);
    else
      e = eDirIdx - nEdges;
      n = edgeEnds(e,2);
    end
    edges = UGM_getEdges(n,edgeStruct);

    % Get end nodes
    n1 = edgeEnds(e,1);
    n2 = edgeEnds(e,2);

    % Orient the edge potential as (receiver states) x (sender states)
    if n == n2
      pot_ij = edgePot(1:nStates(n1),1:nStates(n2),e);
      dst = n1;
    else
      pot_ij = edgePot(1:nStates(n1),1:nStates(n2),e)';
      dst = n2;
    end

    % Scale the edge potential by 1/(edge appearance probability)
    pot_ij = (1/mu(e)) * pot_ij;

    % temp = node potential of n
    %        + sum over other incident edges e2 of mu(e2) * (incoming msg)
    %        - (1 - mu(e)) * (msg arriving at n from the receiver along e)
    temp = nodePot(n,1:nStates(n))';
    for e2 = reshape(edges, 1, [])
      if n == edgeEnds(e2,2)
        incoming = msg(1:nStates(n),e2);
      else
        incoming = msg(1:nStates(n),e2+nEdges);
      end
      if e ~= e2
        temp = temp + mu(e2)*incoming;
      else
        temp = temp - (1-mu(e2))*incoming;
      end
    end

    % Max-sum: for each receiver state, maximize over the sender's states
    msgMat = bsxfun(@plus, pot_ij, temp');
    [newm, newm_argmax] = max(msgMat, [], 2);

    % Damp against the previous sweep's message, then shift its max to 0
    k = nStates(dst);
    damped = damp * newm + (1-damp) * msg_old(1:k,eDirIdx);
    msg(1:k,eDirIdx) = damped - max(damped);
    msg_argmax(1:k,eDirIdx) = newm_argmax;
  end

  % Optionally record the decoded labeling's score and the TRW bound
  if nargout>3
    if nargout>4
      [nodeBel, edgeBel] = DPMP_getTRWLogBeliefs(msg, nodePot, edgePot, edgeStruct, mu, false);
      thisLogPbound = DPMP_computeTRWbound( nodeBel, edgeBel, edgeStruct, mu );
      logPbound = cat(1,logPbound,thisLogPbound);
    else
      nodeBel = DPMP_getTRWLogBeliefs(msg, nodePot, edgePot, edgeStruct, mu);
    end
    [~,mapState] = max(nodeBel,[],2);
    this_logP = DPMP_getLabelProb(mapState, nodePot, edgePot, edgeStruct );
    logP = cat(1,logP,this_logP);
  end

  % Check convergence: largest absolute change of exp(msg) over this sweep.
  % Never stop after sweep 1, since it was compared with all-zero messages.
  msgDiff = max( abs( exp(msg(:) ) - exp(msg_old(:) ) ) );
  if iters>1 && msgDiff < edgeStruct.convTol, break; end

  % Save old message
  msg_old = msg;
end
% A for-loop that never executes leaves its index empty
if isempty(iters), iters = 0; end
fprintf('msgDiff: %e\n', msgDiff);
