function  [nodeBel, edgeBel, logZ, messages] = UGM_Infer_Tree(nodePot, edgePot, edgeStruct, maximize, logspace)
% UGM_Infer_Tree  Beliefs and Bethe logZ from belief-propagation messages.
%
% Obtains messages from UGM_TreeBP, combines them with the node potentials
% to form normalized node beliefs, optionally forms normalized edge beliefs,
% and optionally evaluates the negative Bethe free energy.
%
% INPUT
% nodePot(node,class)       - node potentials (log-potentials if logspace)
% edgePot(class,class,edge) - edge potentials (log-potentials if logspace),
%                             where e is referenced by V,E (must be the same
%                             between feature engine and inference engine)
% edgeStruct                - struct with fields edgeEnds, nStates, V, E
% maximize                  - forwarded unchanged to UGM_TreeBP (default 0)
% logspace                  - if true, potentials, messages and returned
%                             beliefs are combined with plus/minus and
%                             normalized with logsumexp; otherwise with
%                             times/rdivide and sum (default 0)
%
% OUTPUT
% nodeBel(node,class)  - marginal beliefs (log-beliefs if logspace)
% edgeBel(class,class,e) - pairwise beliefs (log-beliefs if logspace);
%                          only computed when requested
% logZ                 - negative of free energy; only computed when requested
% messages             - messages exactly as returned by UGM_TreeBP.
%                        Column e is combined into the belief of
%                        edgeEnds(e,2), column e+nEdges into edgeEnds(e,1).
%
% Entries beyond nStates(n) in the belief arrays are left at zero padding.

% Backward-compatible defaults for the two mode flags
if nargin < 4 || isempty(maximize)
  maximize = 0;
end
if nargin < 5 || isempty(logspace)
  logspace = 0;
end

[nNodes,maxState] = size(nodePot);
nEdges = size(edgePot,3);
edgeEnds = edgeStruct.edgeEnds;
nStates = edgeStruct.nStates;
V = edgeStruct.V;
E = edgeStruct.E;

% accumulation operators
% NOTE(review): logsumexp is an external helper; it is called below on both a
% row vector and a column vector, so it is assumed to reduce a vector of
% either orientation to a scalar -- confirm against the implementation on path.
if logspace
  opFun = @plus;
  opFunInv = @minus;
  opNorm = @logsumexp;
  zeroMass = -inf;          % log-domain representation of a zero message
  toLog = @(x) x;           % potentials are already logarithms
else
  opFun = @times;
  opFunInv = @rdivide;
  opNorm = @sum;
  zeroMass = 0;
  toLog = @log;
end

% Compute Messages
messages = UGM_TreeBP(nodePot,edgePot,edgeStruct,maximize,logspace);

% Compute nodeBel
% Preallocate so the output exists even when there are no nodes and so the
% array is not grown inside the loop.
nodeBel = zeros(nNodes,maxState);
for n = 1:nNodes
  s = 1:nStates(n);
  bel = nodePot(n,s);

  % fold in the message arriving over every edge incident to n
  edges = E(V(n):V(n+1)-1);
  for e = edges(:)'
    if n == edgeEnds(e,2)
      bel = opFun(bel,messages(s,e)');
    else
      bel = opFun(bel,messages(s,e+nEdges)');
    end
  end

  % normalize
  nodeBel(n,s) = opFunInv(bel, opNorm(bel));
end

% Compute edge beliefs
if nargout > 1
   % Removing a zero-mass message from a belief is 0/0 (or -inf - -inf in
   % log space) and would produce NaN. Replacing the divisor by +inf makes
   % the result zero mass instead. This is done on a private copy so the
   % messages returned to the caller are not altered, and the sentinel that
   % is replaced depends on the domain: 0 in linear space, -inf in log space
   % (a log-message equal to 0 is an ordinary value and must be kept).
   divMsgs = messages;
   divMsgs(divMsgs == zeroMass) = inf;

   edgeBel = zeros(maxState,maxState,nEdges);

   for e = 1:nEdges
      n1 = edgeEnds(e,1);
      n2 = edgeEnds(e,2);
      s1 = 1:nStates(n1);
      s2 = 1:nStates(n2);

      % construct product(sum) of incoming messages by removing
      % messages between n1 & n2 from node beliefs
      belN1 = opFunInv(nodeBel(n1,s1)',divMsgs(s1,e+nEdges));
      belN2 = opFunInv(nodeBel(n2,s2)',divMsgs(s2,e));

      % combine aggregate messages with edge potential
      b_unary = bsxfun(opFun, belN1, belN2');
      eb = opFun(b_unary, edgePot(s1,s2,e));

      % normalize
      norm_const = opNorm(eb(:));
      edgeBel(s1,s2,e) = opFunInv(eb, norm_const);
   end
end

if nargout > 2
   % Compute Bethe free energy
   % (Z could also be computed as normalizing constant for any node in the tree
   %    if unnormalized messages are used)
   Energy1 = 0; Energy2 = 0; Entropy1 = 0; Entropy2 = 0;

   % Beliefs are needed in the linear domain; eps keeps log() finite when a
   % belief is exactly zero.
   if logspace
      nodeBelTmp = exp( nodeBel ) + eps;
      edgeBelTmp = exp( edgeBel ) + eps;
   else
      nodeBelTmp = nodeBel+eps;
      edgeBelTmp = edgeBel+eps;
   end

   % compute unary terms
   for n = 1:nNodes
      s = 1:nStates(n);
      edges = E(V(n):V(n+1)-1);
      nNbrs = length(edges);
      nb = nodeBelTmp(n,s);

      % Node Entropy, weighted by (number of incident edges - 1)
      Entropy1 = Entropy1 + (nNbrs-1)*sum(nb.*log(nb));

      % Node Energy. toLog uses log-potentials as given in log space rather
      % than exp() followed by log(), which overflows/underflows for
      % large-magnitude log-potentials.
      Energy1 = Energy1 - sum(nb.*toLog(nodePot(n,s)));
   end

   % compute pairwise terms
   for e = 1:nEdges
      n1 = edgeEnds(e,1);
      n2 = edgeEnds(e,2);
      s1 = 1:nStates(n1);
      s2 = 1:nStates(n2);

      % Pairwise Entropy
      eb = edgeBelTmp(s1,s2,e);
      Entropy2 = Entropy2 - sum(eb(:).*log(eb(:)));

      % Pairwise Energy
      logEp = toLog(edgePot(s1,s2,e));
      Energy2 = Energy2 - sum(eb(:).*logEp(:));
   end
   F = (Energy1+Energy2) - (Entropy1+Entropy2);
   logZ = -F;
end

end
