gravatar
kpeter (Peter Kovacs)
kpeter@inf.elte.hu
Minor improvements in CostScaling (#417)
0 1 0
default
1 file changed with 49 insertions and 37 deletions:
↑ Collapse diff ↑
Ignore white space 48 line context
... ...
@@ -554,60 +554,67 @@
554 554
    ///
555 555
    /// \return <tt>(*this)</tt>
556 556
    ///
557 557
    /// \see reset(), run()
558 558
    CostScaling& resetParams() {
559 559
      for (int i = 0; i != _res_node_num; ++i) {
560 560
        _supply[i] = 0;
561 561
      }
562 562
      int limit = _first_out[_root];
563 563
      for (int j = 0; j != limit; ++j) {
564 564
        _lower[j] = 0;
565 565
        _upper[j] = INF;
566 566
        _scost[j] = _forward[j] ? 1 : -1;
567 567
      }
568 568
      for (int j = limit; j != _res_arc_num; ++j) {
569 569
        _lower[j] = 0;
570 570
        _upper[j] = INF;
571 571
        _scost[j] = 0;
572 572
        _scost[_reverse[j]] = 0;
573 573
      }
574 574
      _have_lower = false;
575 575
      return *this;
576 576
    }
577 577

	
578
    /// \brief Reset all the parameters that have been given before.
578
    /// \brief Reset the internal data structures and all the parameters
579
    /// that have been given before.
579 580
    ///
580
    /// This function resets all the paramaters that have been given
581
    /// before using functions \ref lowerMap(), \ref upperMap(),
582
    /// \ref costMap(), \ref supplyMap(), \ref stSupply().
581
    /// This function resets the internal data structures and all the
582
    /// paramaters that have been given before using functions \ref lowerMap(),
583
    /// \ref upperMap(), \ref costMap(), \ref supplyMap(), \ref stSupply().
583 584
    ///
584
    /// It is useful for multiple run() calls. If this function is not
585
    /// used, all the parameters given before are kept for the next
586
    /// \ref run() call.
587
    /// However, the underlying digraph must not be modified after this
588
    /// class have been constructed, since it copies and extends the graph.
585
    /// It is useful for multiple \ref run() calls. By default, all the given
586
    /// parameters are kept for the next \ref run() call, unless
587
    /// \ref resetParams() or \ref reset() is used.
588
    /// If the underlying digraph was also modified after the construction
589
    /// of the class or the last \ref reset() call, then the \ref reset()
590
    /// function must be used, otherwise \ref resetParams() is sufficient.
591
    ///
592
    /// See \ref resetParams() for examples.
593
    ///
589 594
    /// \return <tt>(*this)</tt>
595
    ///
596
    /// \see resetParams(), run()
590 597
    CostScaling& reset() {
591 598
      // Resize vectors
592 599
      _node_num = countNodes(_graph);
593 600
      _arc_num = countArcs(_graph);
594 601
      _res_node_num = _node_num + 1;
595 602
      _res_arc_num = 2 * (_arc_num + _node_num);
596 603
      _root = _node_num;
597 604

	
598 605
      _first_out.resize(_res_node_num + 1);
599 606
      _forward.resize(_res_arc_num);
600 607
      _source.resize(_res_arc_num);
601 608
      _target.resize(_res_arc_num);
602 609
      _reverse.resize(_res_arc_num);
603 610

	
604 611
      _lower.resize(_res_arc_num);
605 612
      _upper.resize(_res_arc_num);
606 613
      _scost.resize(_res_arc_num);
607 614
      _supply.resize(_res_node_num);
608 615

	
609 616
      _res_cap.resize(_res_arc_num);
610 617
      _cost.resize(_res_arc_num);
611 618
      _pi.resize(_res_node_num);
612 619
      _excess.resize(_res_node_num);
613 620
      _next_out.resize(_res_node_num);
... ...
@@ -869,220 +876,225 @@
869 876
        for (int a = _first_out[_root]; a != _res_arc_num; ++a) {
870 877
          int u = _target[a];
871 878
          int ra = _reverse[a];
872 879
          _res_cap[a] = -_sum_supply + 1;
873 880
          _res_cap[ra] = -_excess[u];
874 881
          _cost[a] = 0;
875 882
          _cost[ra] = 0;
876 883
          _excess[u] = 0;
877 884
        }
878 885
      } else {
879 886
        for (ArcIt a(_graph); a != INVALID; ++a) {
880 887
          Value fa = flow[a];
881 888
          _res_cap[_arc_idf[a]] = cap[a] - fa;
882 889
          _res_cap[_arc_idb[a]] = fa;
883 890
        }
884 891
        for (int a = _first_out[_root]; a != _res_arc_num; ++a) {
885 892
          int ra = _reverse[a];
886 893
          _res_cap[a] = 0;
887 894
          _res_cap[ra] = 0;
888 895
          _cost[a] = 0;
889 896
          _cost[ra] = 0;
890 897
        }
891 898
      }
892 899

	
893
      return OPTIMAL;
894
    }
895

	
896
    // Execute the algorithm and transform the results
897
    void start(Method method) {
898
      // Maximum path length for partial augment
899
      const int MAX_PATH_LENGTH = 4;
900

	
901 900
      // Initialize data structures for buckets
902 901
      _max_rank = _alpha * _res_node_num;
903 902
      _buckets.resize(_max_rank);
904 903
      _bucket_next.resize(_res_node_num + 1);
905 904
      _bucket_prev.resize(_res_node_num + 1);
906 905
      _rank.resize(_res_node_num + 1);
907 906

	
908
      // Execute the algorithm
907
      return OPTIMAL;
908
    }
909

	
910
    // Execute the algorithm and transform the results
911
    void start(Method method) {
912
      const int MAX_PARTIAL_PATH_LENGTH = 4;
913

	
909 914
      switch (method) {
910 915
        case PUSH:
911 916
          startPush();
912 917
          break;
913 918
        case AUGMENT:
914 919
          startAugment(_res_node_num - 1);
915 920
          break;
916 921
        case PARTIAL_AUGMENT:
917
          startAugment(MAX_PATH_LENGTH);
922
          startAugment(MAX_PARTIAL_PATH_LENGTH);
918 923
          break;
919 924
      }
920 925

	
921 926
      // Compute node potentials for the original costs
922 927
      _arc_vec.clear();
923 928
      _cost_vec.clear();
924 929
      for (int j = 0; j != _res_arc_num; ++j) {
925 930
        if (_res_cap[j] > 0) {
926 931
          _arc_vec.push_back(IntPair(_source[j], _target[j]));
927 932
          _cost_vec.push_back(_scost[j]);
928 933
        }
929 934
      }
930 935
      _sgr.build(_res_node_num, _arc_vec.begin(), _arc_vec.end());
931 936

	
932 937
      typename BellmanFord<StaticDigraph, LargeCostArcMap>
933 938
        ::template SetDistMap<LargeCostNodeMap>::Create bf(_sgr, _cost_map);
934 939
      bf.distMap(_pi_map);
935 940
      bf.init(0);
936 941
      bf.start();
937 942

	
938 943
      // Handle non-zero lower bounds
939 944
      if (_have_lower) {
940 945
        int limit = _first_out[_root];
941 946
        for (int j = 0; j != limit; ++j) {
942 947
          if (!_forward[j]) _res_cap[j] += _lower[j];
943 948
        }
944 949
      }
945 950
    }
946 951

	
947 952
    // Initialize a cost scaling phase
948 953
    void initPhase() {
949 954
      // Saturate arcs not satisfying the optimality condition
950 955
      for (int u = 0; u != _res_node_num; ++u) {
951 956
        int last_out = _first_out[u+1];
952 957
        LargeCost pi_u = _pi[u];
953 958
        for (int a = _first_out[u]; a != last_out; ++a) {
954
          int v = _target[a];
955
          if (_res_cap[a] > 0 && _cost[a] + pi_u - _pi[v] < 0) {
956
            Value delta = _res_cap[a];
957
            _excess[u] -= delta;
958
            _excess[v] += delta;
959
            _res_cap[a] = 0;
960
            _res_cap[_reverse[a]] += delta;
959
          Value delta = _res_cap[a];
960
          if (delta > 0) {
961
            int v = _target[a];
962
            if (_cost[a] + pi_u - _pi[v] < 0) {
963
              _excess[u] -= delta;
964
              _excess[v] += delta;
965
              _res_cap[a] = 0;
966
              _res_cap[_reverse[a]] += delta;
967
            }
961 968
          }
962 969
        }
963 970
      }
964 971

	
965 972
      // Find active nodes (i.e. nodes with positive excess)
966 973
      for (int u = 0; u != _res_node_num; ++u) {
967 974
        if (_excess[u] > 0) _active_nodes.push_back(u);
968 975
      }
969 976

	
970 977
      // Initialize the next arcs
971 978
      for (int u = 0; u != _res_node_num; ++u) {
972 979
        _next_out[u] = _first_out[u];
973 980
      }
974 981
    }
975 982

	
976 983
    // Early termination heuristic
977 984
    bool earlyTermination() {
978 985
      const double EARLY_TERM_FACTOR = 3.0;
979 986

	
980 987
      // Build a static residual graph
981 988
      _arc_vec.clear();
982 989
      _cost_vec.clear();
983 990
      for (int j = 0; j != _res_arc_num; ++j) {
984 991
        if (_res_cap[j] > 0) {
985 992
          _arc_vec.push_back(IntPair(_source[j], _target[j]));
986 993
          _cost_vec.push_back(_cost[j] + 1);
987 994
        }
988 995
      }
989 996
      _sgr.build(_res_node_num, _arc_vec.begin(), _arc_vec.end());
990 997

	
991 998
      // Run Bellman-Ford algorithm to check if the current flow is optimal
992 999
      BellmanFord<StaticDigraph, LargeCostArcMap> bf(_sgr, _cost_map);
993 1000
      bf.init(0);
994 1001
      bool done = false;
995 1002
      int K = int(EARLY_TERM_FACTOR * std::sqrt(double(_res_node_num)));
996 1003
      for (int i = 0; i < K && !done; ++i) {
997 1004
        done = bf.processNextWeakRound();
998 1005
      }
999 1006
      return done;
1000 1007
    }
1001 1008

	
1002 1009
    // Global potential update heuristic
1003 1010
    void globalUpdate() {
1004
      int bucket_end = _root + 1;
1011
      const int bucket_end = _root + 1;
1005 1012

	
1006 1013
      // Initialize buckets
1007 1014
      for (int r = 0; r != _max_rank; ++r) {
1008 1015
        _buckets[r] = bucket_end;
1009 1016
      }
1010 1017
      Value total_excess = 0;
1018
      int b0 = bucket_end;
1011 1019
      for (int i = 0; i != _res_node_num; ++i) {
1012 1020
        if (_excess[i] < 0) {
1013 1021
          _rank[i] = 0;
1014
          _bucket_next[i] = _buckets[0];
1015
          _bucket_prev[_buckets[0]] = i;
1016
          _buckets[0] = i;
1022
          _bucket_next[i] = b0;
1023
          _bucket_prev[b0] = i;
1024
          b0 = i;
1017 1025
        } else {
1018 1026
          total_excess += _excess[i];
1019 1027
          _rank[i] = _max_rank;
1020 1028
        }
1021 1029
      }
1022 1030
      if (total_excess == 0) return;
1031
      _buckets[0] = b0;
1023 1032

	
1024 1033
      // Search the buckets
1025 1034
      int r = 0;
1026 1035
      for ( ; r != _max_rank; ++r) {
1027 1036
        while (_buckets[r] != bucket_end) {
1028 1037
          // Remove the first node from the current bucket
1029 1038
          int u = _buckets[r];
1030 1039
          _buckets[r] = _bucket_next[u];
1031 1040

	
1032 1041
          // Search the incomming arcs of u
1033 1042
          LargeCost pi_u = _pi[u];
1034 1043
          int last_out = _first_out[u+1];
1035 1044
          for (int a = _first_out[u]; a != last_out; ++a) {
1036 1045
            int ra = _reverse[a];
1037 1046
            if (_res_cap[ra] > 0) {
1038 1047
              int v = _source[ra];
1039 1048
              int old_rank_v = _rank[v];
1040 1049
              if (r < old_rank_v) {
1041 1050
                // Compute the new rank of v
1042 1051
                LargeCost nrc = (_cost[ra] + _pi[v] - pi_u) / _epsilon;
1043 1052
                int new_rank_v = old_rank_v;
1044
                if (nrc < LargeCost(_max_rank))
1045
                  new_rank_v = r + 1 + int(nrc);
1053
                if (nrc < LargeCost(_max_rank)) {
1054
                  new_rank_v = r + 1 + static_cast<int>(nrc);
1055
                }
1046 1056

	
1047 1057
                // Change the rank of v
1048 1058
                if (new_rank_v < old_rank_v) {
1049 1059
                  _rank[v] = new_rank_v;
1050 1060
                  _next_out[v] = _first_out[v];
1051 1061

	
1052 1062
                  // Remove v from its old bucket
1053 1063
                  if (old_rank_v < _max_rank) {
1054 1064
                    if (_buckets[old_rank_v] == v) {
1055 1065
                      _buckets[old_rank_v] = _bucket_next[v];
1056 1066
                    } else {
1057
                      _bucket_next[_bucket_prev[v]] = _bucket_next[v];
1058
                      _bucket_prev[_bucket_next[v]] = _bucket_prev[v];
1067
                      int pv = _bucket_prev[v], nv = _bucket_next[v];
1068
                      _bucket_next[pv] = nv;
1069
                      _bucket_prev[nv] = pv;
1059 1070
                    }
1060 1071
                  }
1061 1072

	
1062
                  // Insert v to its new bucket
1063
                  _bucket_next[v] = _buckets[new_rank_v];
1064
                  _bucket_prev[_buckets[new_rank_v]] = v;
1073
                  // Insert v into its new bucket
1074
                  int nv = _buckets[new_rank_v];
1075
                  _bucket_next[v] = nv;
1076
                  _bucket_prev[nv] = v;
1065 1077
                  _buckets[new_rank_v] = v;
1066 1078
                }
1067 1079
              }
1068 1080
            }
1069 1081
          }
1070 1082

	
1071 1083
          // Finish search if there are no more active nodes
1072 1084
          if (_excess[u] > 0) {
1073 1085
            total_excess -= _excess[u];
1074 1086
            if (total_excess <= 0) break;
1075 1087
          }
1076 1088
        }
1077 1089
        if (total_excess <= 0) break;
1078 1090
      }
1079 1091

	
1080 1092
      // Relabel nodes
1081 1093
      for (int u = 0; u != _res_node_num; ++u) {
1082 1094
        int k = std::min(_rank[u], r);
1083 1095
        if (k > 0) {
1084 1096
          _pi[u] -= _epsilon * k;
1085 1097
          _next_out[u] = _first_out[u];
1086 1098
        }
1087 1099
      }
1088 1100
    }
0 comments (0 inline)