#include <iostream>
#include <set>
#include <vector>
#include <iterator>

#include <cmath>
#include <cstdlib>

#include <lemon/smart_graph.h>
#include <lemon/min_cost_arborescence.h>

#include <lemon/graph_utils.h>
#include <lemon/time_measure.h>

#include <lemon/tolerance.h>

#include "test_tools.h"

using namespace lemon;
using namespace std;

// Fixed test instance: a directed graph with `n` nodes and `e` edges.
// Edge i goes from node sources[i] to node targets[i] and costs costs[i].
// The instance deliberately contains self-loops (4->4, 2->2, 1->1),
// parallel edges (0->3 twice, 9->8 twice, 1->1 twice), and nodes that are
// not reachable from node 0 (2, 4, 5, 6, 7), so the algorithm must cope
// with a partially reachable graph.
const int n = 10;
const int e = 22;

// Index (into the node list built in main) of the arborescence root.
const int sourceNode = 0;

// Tail node index of each edge; every entry must lie in [0, n).
const int sources[e] = {
  1, 0, 2, 4, 4, 3, 9, 8, 9, 8,
  4, 2, 0, 6, 4, 1, 7, 2, 8, 6,
  1, 0
};

// Head node index of each edge; every entry must lie in [0, n).
const int targets[e] = {
  8, 3, 1, 1, 4, 9, 8, 1, 8, 0,
  3, 2, 1, 3, 1, 1, 2, 6, 3, 9,
  1, 3
};

// Cost of each edge, parallel to sources/targets.
const double costs[e] = {
  107.444, 70.3069, 46.0496, 28.3962, 91.4325,
  76.9443, 61.986, 39.3754, 74.9575, 39.3153,
  45.7094, 34.6184, 100.156, 95.726, 22.3429,
  31.587, 51.6972, 29.6773, 115.038, 32.4137,
  60.0038, 40.1237
};



int main() {
  srand(time(0));
  typedef SmartGraph Graph;
  GRAPH_TYPEDEFS(Graph);

  typedef Graph::EdgeMap<double> CostMap;

  Graph graph;
  CostMap cost(graph);
  vector<Node> nodes;
  
  for (int i = 0; i < n; ++i) {
    nodes.push_back(graph.addNode());
  }

  for (int i = 0; i < e; ++i) {
    Edge edge = graph.addEdge(nodes[sources[i]], nodes[targets[i]]);
    cost[edge] = costs[i];
  }

  Node source = nodes[sourceNode];

  MinCostArborescence<Graph, CostMap> mca(graph, cost);
  mca.run(source);

  vector<pair<double, set<Node> > > dualSolution(mca.dualSize());

  for (int i = 0; i < mca.dualSize(); ++i) {
    dualSolution[i].first = mca.dualValue(i);
    for (MinCostArborescence<Graph, CostMap>::DualIt it(mca, i); 
         it != INVALID; ++it) {
      dualSolution[i].second.insert(it);
    }
  }

  Tolerance<double> tolerance;

  for (EdgeIt it(graph); it != INVALID; ++it) {
    if (mca.reached(graph.source(it))) {
      double sum = 0.0;
      for (int i = 0; i < (int)dualSolution.size(); ++i) {
        if (dualSolution[i].second.find(graph.target(it)) 
            != dualSolution[i].second.end() &&
            dualSolution[i].second.find(graph.source(it)) 
            == dualSolution[i].second.end()) {
          sum += dualSolution[i].first;
        }
      }
      if (mca.arborescence(it)) {
        check(!tolerance.less(sum, cost[it]), "INVALID DUAL");
      }
      check(!tolerance.less(cost[it], sum), "INVALID DUAL");
    }
  }


  check(!tolerance.different(mca.dualValue(), mca.arborescenceValue()),
               "INVALID DUAL");


  check(mca.reached(source), "INVALID ARBORESCENCE");
  for (EdgeIt it(graph); it != INVALID; ++it) {
    check(!mca.reached(graph.source(it)) || 
                 mca.reached(graph.target(it)), "INVALID ARBORESCENCE");
  }

  for (NodeIt it(graph); it != INVALID; ++it) {
    if (!mca.reached(it)) continue;
    int cnt = 0;
    for (InEdgeIt jt(graph, it); jt != INVALID; ++jt) {
      if (mca.arborescence(jt)) {
        ++cnt;
      }
    }
    check((it == source ? cnt == 0 : cnt == 1), "INVALID ARBORESCENCE");
  }
  
  return 0;
}
