function [nodeBel, edgeBel] = DPMP_Infer_JunctionMax(nodePot,edgePot,edgeStruct,ordering)
% DPMP_INFER_JUNCTIONMAX - MaxSum variant of exact Junction Tree inference.
%
%   [nodeBel, edgeBel] = DPMP_Infer_JunctionMax(nodePot, edgePot, edgeStruct)
%   [nodeBel, edgeBel] = DPMP_Infer_JunctionMax(..., ordering)
%
% Builds a junction tree with DPMP_makeJunctionTree, passes max-sum messages
% between its cliques (leaves inward, then back outward), and reads
% max-marginals off the resulting clique beliefs. Potentials and messages are
% combined by addition and maximization, so they are presumably log-domain
% scores.
%
% Inputs:
%   nodePot    - node potentials; nodeBel is allocated with the same size
%                (presumably nNodes-by-maxStates)
%   edgePot    - edge potentials; edgeBel is allocated with the same size and
%                indexed as edgeBel(s1,s2,e)
%   edgeStruct - struct with fields nNodes, nEdges, edgeEnds (nEdges-by-2)
%                and nStates (number of states of each node)
%   ordering   - (optional) node ordering forwarded to DPMP_makeJunctionTree;
%                defaults to 1:nNodes
%
% Outputs:
%   nodeBel - nodeBel(n,s) is the maximum clique belief over configurations
%             with node n in state s, read from the lowest-indexed clique that
%             contains n. Each clique belief is shifted so its maximum is 0.
%             Entries with s > nStates(n) are left at 0.
%   edgeBel - (computed only when requested) edgeBel(s1,s2,e) is the analogous
%             maximum with both endpoints of edge e fixed, read from the
%             lowest-indexed clique containing both endpoints.
%
% J. Pacheco, 2014
% Based on code by M. Schmidt
%

  nNodes = double( edgeStruct.nNodes );
  nEdges = double( edgeStruct.nEdges );
  edgeEnds = edgeStruct.edgeEnds;
  nStates = edgeStruct.nStates;

  if nargin < 4
    ordering = 1:nNodes;
  end

  %% Build Junction Tree
  [cliques, cliqueEdges, cliquePot, separators] = ...
    DPMP_makeJunctionTree( nodePot, edgePot, edgeStruct, ordering );
  nCliques = length(cliques);
  nCliqueEdges = size(cliqueEdges,1);
  % Adjacency list of the clique graph: the clique edges incident to clique c
  % are E(V(c):V(c+1)-1). The sent/waiting flags below are indexed by position
  % in this list, i.e. one flag per (clique, incident edge) direction.
  [V,E] = UGM_makeEdgeVE(cliqueEdges,nCliques);

  % Count number of neighbors of each clique
  nNeighbors = zeros(nCliques,1);
  for e = 1:nCliqueEdges
    nNeighbors(cliqueEdges(e,1)) = nNeighbors(cliqueEdges(e,1))+1;
    nNeighbors(cliqueEdges(e,2)) = nNeighbors(cliqueEdges(e,2))+1;
  end

  %% Message passing schedule
  % Leaves seed the queue.
  % - A dequeued clique that still waits on a neighbor (expected to be exactly
  %   one) sends its inward message to that neighbor.
  % - A dequeued clique that waits on nobody sends every message it has not
  %   sent yet.
  % nNeighbors is decremented only on inward messages, so nNeighbors(c)
  % counts the neighbors c still expects an inward message from.
  Q = find(nNeighbors == 1);

  sent = zeros(nCliqueEdges*2,1);
  waiting = ones(nCliqueEdges*2,1);
  messages = cell(nCliqueEdges*2,1);
  while ~isempty(Q)
    c = Q(1);
    Q = Q(2:end);

    wait = waiting(V(c):V(c+1)-1);
    sending = sent(V(c):V(c+1)-1);

    nWaiting = sum(wait==1);

    if nWaiting == 0
      % Send final messages. The transpose makes the loop iterate over
      % indices: a FOR over a column vector would run once with the whole
      % column.
      for sendEdge = [double(V(c))+find(sending==0)-1]'
        sent(sendEdge) = 1;
        [messages,waiting,nei] = send(c,sendEdge,cliques,cliquePot,messages,waiting,nStates,cliqueEdges,V,E,separators);
        if nNeighbors(nei) == 1 || nNeighbors(nei) == 0
          Q = [Q;nei];
        end
      end
    else
      remainingEdge = V(c)+find(wait==1)-1;
      sent(remainingEdge) = 1;
      [messages,waiting,nei] = send(c,remainingEdge,cliques,cliquePot,messages,waiting,nStates,cliqueEdges,V,E,separators);
      nNeighbors(nei) = nNeighbors(nei)-1;
      if nNeighbors(nei) == 1 || nNeighbors(nei) == 0
        Q = [Q;nei];
      end
    end
  end

  %% Compute cliqueBel: clique potential plus every incoming message
  cliqueBel = cell(nCliques,1);
  for c = 1:nCliques
    nodes = cliques{c};
    ind = cell(length(nodes),1);
    for nodeInd = 1:length(nodes)
      ind{nodeInd} = 1:nStates(nodes(nodeInd));
    end

    cb = cliquePot{c};
    edges = E(V(c):V(c+1)-1);
    for e = edges(:)'
      % A message over an empty separator would only add a constant to cb,
      % which the normalization after this loop removes. Skip it; the
      % enumeration loop below also assumes at least one separator node.
      if isempty(separators{e})
        continue;
      end

      % messages{e} flows cliqueEdges(e,1) -> cliqueEdges(e,2);
      % messages{e+nCliqueEdges} flows the opposite way.
      if c == cliqueEdges(e,2)
        msg = messages{e};
      else
        msg = messages{e+nCliqueEdges};
      end

      ind_sub = ind;
      sepLength = length(separators{e});
      sep = zeros(sepLength,1);
      s = cell(sepLength,1);
      for n = 1:sepLength
        s{n,1} = 1;
        % Dimension of cb that corresponds to the n-th separator node
        sep(n) = find(nodes==separators{e}(n));
      end
      % Enumerate every joint separator state s{:} (first separator node
      % varies fastest) and add msg(s{:}) to the matching slice of cb.
      while 1
        for nodeInd = 1:sepLength
          ind_sub{sep(nodeInd)} = s{nodeInd};
        end
        cb(ind_sub{:}) = cb(ind_sub{:}) + msg(s{:});

        % Advance to the next joint separator state
        for nodeInd = 1:sepLength
          s{nodeInd} = s{nodeInd} + 1;
          if s{nodeInd} <= nStates(separators{e}(nodeInd))
            break;
          else
            s{nodeInd} = 1;
          end
        end
        % The last separator node wrapped back to 1: all states visited
        if nodeInd == sepLength && s{end} == 1
          break;
        end
      end
    end
    % Shift so the best clique configuration scores 0
    cb = cb - max( cb(:) );
    cliqueBel{c} = cb;
  end

  %% Compute nodeBel (max over all other clique variables)
  nodeBel = zeros(size(nodePot));
  nodeBelMissing = ones(nNodes,1);
  for c = 1:nCliques
    cb = cliqueBel{c};

    nodes = cliques{c};
    ind = cell(length(nodes),1);
    for nodeInd = 1:length(nodes)
      ind{nodeInd} = 1:nStates(nodes(nodeInd));
    end

    for nodeInd = 1:length(nodes)
      n = nodes(nodeInd);
      if nodeBelMissing(n)
        nodeBelMissing(n) = 0;

        ind_sub = ind;
        for s = 1:nStates(n)
          ind_sub{nodeInd} = s;
          slice = cb(ind_sub{:});
          nodeBel(n,s) = max(slice(:));
        end
      end
    end
  end

  %% Compute edgeBel (max over all other clique variables)
  if nargout > 1
    edgeBel = zeros(size(edgePot));
    edgeBelMissing = ones(nEdges,1);
    for c = 1:nCliques
      cb = cliqueBel{c};

      nodes = cliques{c};
      ind = cell(length(nodes),1);
      for nodeInd = 1:length(nodes)
        ind{nodeInd} = 1:nStates(nodes(nodeInd));
      end

      for e = 1:nEdges
        n1 = edgeEnds(e,1);
        n2 = edgeEnds(e,2);
        n1Ind = find(n1==nodes);
        n2Ind = find(n2==nodes);
        if edgeBelMissing(e) && ~isempty(n1Ind) && ~isempty(n2Ind)
          edgeBelMissing(e) = 0;

          ind_sub = ind;
          for s1 = 1:nStates(n1)
            for s2 = 1:nStates(n2)
              ind_sub{n1Ind} = s1;
              ind_sub{n2Ind} = s2;
              slice = cb(ind_sub{:});
              edgeBel(s1,s2,e) = max(slice(:));
            end
          end
        end
      end
    end
  end
end

%% Message passing function
function [messages,waiting,nei] = send(c,e,cliques,cliquePot,messages,waiting,nStates,edgeEnds,V,E,separators)
% SEND - Compute and store the max-sum message from clique c along one edge.
%
%   c          - index of the sending clique
%   e          - POSITION in the adjacency list E (within V(c):V(c+1)-1) of
%                the clique edge to send along, not the clique-edge index
%   cliques    - cliques{k} lists the node indices of clique k
%   cliquePot  - clique potentials; dimension i of cliquePot{k} corresponds
%                to node cliques{k}(i)
%   messages   - cell array of length 2*nEdges; messages{k} holds the message
%                from edgeEnds(k,1) to edgeEnds(k,2), and messages{k+nEdges}
%                the message in the opposite direction
%   waiting    - one flag per adjacency-list position; the receiver's entry
%                for this edge is cleared
%   nStates    - number of states of each node
%   edgeEnds   - nEdges-by-2 endpoints of the clique edges
%   V, E       - adjacency list: the edges of clique k are E(V(k):V(k+1)-1)
%   separators - separators{k} lists the nodes shared across clique edge k
%
% Returns the updated messages and waiting arrays and the index nei of the
% receiving clique. The new message is the maximum of the clique potential
% plus all other incoming messages over the non-separator variables, shifted
% so its maximum is 0.

  nEdges = size(edgeEnds,1);
  edge = E(e);
  if c == edgeEnds(edge,1)
    nei = edgeEnds(edge,2);
  else
    nei = edgeEnds(edge,1);
  end

  % The receiver's adjacency-list entry for this edge is no longer waiting
  for tmp = V(nei):V(nei+1)-1
    if tmp ~= e && E(tmp) == E(e)
      waiting(tmp) = 0;
    end
  end

  % From here on, e is the clique-edge index rather than a list position
  e = edge;

  nodes = cliques{c};
  ind = cell(length(nodes),1);
  for nodeInd = 1:length(nodes)
    ind{nodeInd} = 1:nStates(nodes(nodeInd));
  end

  % Sum the clique potential with all incoming messages except the one along e
  temp = cliquePot{c};
  neighbors = E(V(c):V(c+1)-1);
  for e2 = neighbors(:)'
    % A message over an empty separator only adds a constant to temp. The
    % normalization of newm removes that constant, so skip it; the
    % enumeration loop below also assumes at least one separator node.
    if e2 == e || isempty(separators{e2})
      continue;
    end
    if c == edgeEnds(e2,2)
      msg = messages{e2};
    else
      msg = messages{e2+nEdges};
    end

    ind_sub = ind;
    sepLength = length(separators{e2});
    sep = zeros(sepLength,1);
    s = cell(sepLength,1);
    for n = 1:sepLength
      s{n,1} = 1;
      % Dimension of temp that corresponds to the n-th separator node
      sep(n) = find(nodes==separators{e2}(n));
    end
    % Enumerate every joint separator state (first node varies fastest)
    while 1
      for nodeInd = 1:sepLength
        ind_sub{sep(nodeInd)} = s{nodeInd};
      end
      temp(ind_sub{:}) = temp(ind_sub{:}) + msg(s{:});

      for nodeInd = 1:sepLength
        s{nodeInd} = s{nodeInd} + 1;
        if s{nodeInd} <= nStates(separators{e2}(nodeInd))
          break;
        else
          s{nodeInd} = 1;
        end
      end
      % The last separator node wrapped back to 1: all states visited
      if nodeInd == sepLength && s{end} == 1
        break;
      end
    end
  end

  % Maximize over all variables except the separator set
  sepNodes = separators{e};
  sepLength = length(sepNodes);
  if sepLength == 0
    % Nothing is shared with the receiver: the message is a single value,
    % which the normalization would shift to 0
    newm = 0;
  else
    sep = zeros(sepLength,1);
    s = cell(sepLength,1);
    for n = 1:sepLength
      s{n,1} = 1;
      sep(n) = find(nodes==sepNodes(n));
    end
    % (:)' keeps the size vector a row whatever the orientation of nStates
    sepStates = nStates(sepNodes);
    newm = zeros([sepStates(:)' 1]);
    ind_sub = ind;
    while 1
      for nodeInd = 1:sepLength
        ind_sub{sep(nodeInd)} = s{nodeInd};
      end
      slice = temp(ind_sub{:});
      newm(s{:}) = max(slice(:));

      for nodeInd = 1:sepLength
        s{nodeInd} = s{nodeInd} + 1;
        if s{nodeInd} <= nStates(sepNodes(nodeInd))
          break;
        else
          s{nodeInd} = 1;
        end
      end
      if nodeInd == sepLength && s{end} == 1
        break;
      end
    end
    newm = newm - max( newm(:) );
  end

  % Store by direction: end 1 -> end 2 at e, end 2 -> end 1 at e+nEdges
  if c == edgeEnds(e,2)
    messages{e+nEdges} = newm;
  else
    messages{e} = newm;
  end
end