193 lines
7.9 KiB
C++
193 lines
7.9 KiB
C++
|
#ifndef HUNGARIAN_METHOD_HPP
|
||
|
#define HUNGARIAN_METHOD_HPP
|
||
|
//#include "Common.hpp"
|
||
|
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
|
||
|
|
||
|
/// A function object which calculates the maximum-weighted bipartite matching between
/// two sets via the Hungarian method (Kuhn-Munkres, labeling + slack variant).
///
/// Only the top-left logicalSize x logicalSize part of the N x N weight matrix is read.
/// The computed matching is stored in the object and can be queried with matching()
/// and inverseMatching() until the next call; since every call rewrites the members,
/// one instance must not be used from several threads at once.
///
/// \tparam N the compile-time capacity (maximum number of elements per set)
template <int N=20>
class HungarianMethod {

public :

    static const int MAX_SIZE = N;

private:

    int n;              // logical problem size: n elements in each set
    int max_match;      // number of pairs in the current (partial) matching
    double lx[N];       // labels (potentials) of the X part
    double ly[N];       // labels (potentials) of the Y part
    int xy[N];          // xy[x]: Y vertex matched with x, or -1 if x is free
    int yx[N];          // yx[y]: X vertex matched with y, or -1 if y is free
    bool S[N];          // S[x]: x belongs to the current alternating tree
    bool T[N];          // T[y]: y belongs to the current alternating tree
    double slack[N];    // slack[y] = min over x in S of lx[x] + ly[y] - cost[x][y]
    int slackx[N];      // an x in S attaining slack[y] (a vertex index, hence int)
    int prev[N];        // prev[x]: X vertex before x in the alternating tree, -2 for the root

    /// Step 0: builds the initial feasible labeling lx[x] = max_y cost[x][y], ly[y] = 0.
    /// Every edge then satisfies lx[x] + ly[y] >= cost[x][y], and every row has at least
    /// one tight edge (also when all of its weights are negative).
    void init_labels(const double cost[N][N])
    {
        memset(lx, 0, sizeof(lx));
        memset(ly, 0, sizeof(ly));
        for (int x = 0; x < n; x++) {
            lx[x] = cost[x][0];
            for (int y = 1; y < n; y++)
                lx[x] = std::max(lx[x], cost[x][y]);
        }
    }

    /// Steps 1-3: runs one phase per missing pair until the matching is perfect.
    /// Each phase finds an augmenting path and flips it, growing the matching by one.
    /// Iterative on purpose: a recursive formulation keeps one frame (with an int[N]
    /// queue) per phase on the stack.
    void augment(const double cost[N][N])
    {
        while (max_match < n) {
            int x = -1, y = -1;
            find_augmenting_path(cost, x, y);
            flip_path(x, y);
            max_match++;
        }
    }

    /// Grows an alternating tree from a free X vertex, improving the labeling whenever
    /// the equality graph is exhausted, until a free Y vertex is reached.
    /// On return ex/ey are the last X vertex and the free Y vertex of the augmenting
    /// path; the rest of the path is recorded in prev[] and xy[].
    /// Precondition: max_match < n, so a free X vertex exists.
    void find_augmenting_path(const double cost[N][N], int& ex, int& ey)
    {
        int q[N];            // BFS queue of X vertices
        int wr = 0, rd = 0;  // write / read positions in q

        memset(S, false, sizeof(S));
        memset(T, false, sizeof(T));
        memset(prev, -1, sizeof(prev));

        int root = 0;
        while (root < n && xy[root] != -1)
            root++;
        assert(root < n);
        q[wr++] = root;
        prev[root] = -2;
        S[root] = true;

        for (int y = 0; y < n; y++) {
            slack[y] = lx[root] + ly[y] - cost[root][y];
            slackx[y] = root;
        }

        while (true) {
            // BFS along edges of the equality graph (lx[x] + ly[y] == cost[x][y]).
            while (rd < wr) {
                const int x = q[rd++];
                for (int y = 0; y < n; y++) {
                    if (cost[x][y] != lx[x] + ly[y] || T[y])
                        continue;
                    if (yx[y] == -1) {      // free Y vertex: augmenting path found
                        ex = x;
                        ey = y;
                        return;
                    }
                    T[y] = true;            // extend the tree by (x, y) and (y, yx[y])
                    q[wr++] = yx[y];
                    add_to_tree(yx[y], x, cost);
                }
            }

            // No augmenting path in the current equality graph: improve the labeling.
            update_labels();
            wr = rd = 0;

            // Edges (slackx[y], y) whose slack dropped to zero are now tight; add them
            // to the tree, or stop if they reach a free Y vertex.
            for (int y = 0; y < n; y++) {
                if (T[y] || slack[y] != 0)
                    continue;
                if (yx[y] == -1) {
                    ex = slackx[y];
                    ey = y;
                    return;
                }
                T[y] = true;
                if (!S[yx[y]]) {
                    q[wr++] = yx[y];
                    add_to_tree(yx[y], slackx[y], cost);
                }
            }
        }
    }

    /// Inverts matched/unmatched edges along the augmenting path ending at (x, y),
    /// following prev[] back to the root (marked -2).
    void flip_path(int x, int y)
    {
        for (int cx = x, cy = y, ty; cx != -2; cx = prev[cx], cy = ty) {
            ty = xy[cx];
            yx[cy] = cx;
            xy[cx] = cy;
        }
    }

    /// Shifts labels by delta = min slack over Y vertices outside the tree:
    /// S labels decrease, T labels increase (tree edges stay tight) and the slack of
    /// every Y vertex outside the tree decreases, so at least one slack reaches zero.
    void update_labels()
    {
        double delta = std::numeric_limits<double>::max();
        for (int y = 0; y < n; y++)
            if (!T[y])
                delta = std::min(delta, slack[y]);
        for (int x = 0; x < n; x++)
            if (S[x]) lx[x] -= delta;
        for (int y = 0; y < n; y++) {
            if (T[y]) ly[y] += delta;
            else slack[y] -= delta;
        }
    }

    /// Adds X vertex x to the tree, reached from X vertex prevx through the Y vertex
    /// matched with x (edges (prevx, xy[x]) and (xy[x], x)), and refreshes the slacks.
    void add_to_tree(int x, int prevx, const double cost[N][N])
    {
        S[x] = true;
        prev[x] = prevx;    // needed to walk the path back when augmenting
        for (int y = 0; y < n; y++) {
            const double s = lx[x] + ly[y] - cost[x][y];
            if (s < slack[y]) {
                slack[y] = s;
                slackx[y] = x;
            }
        }
    }

public:

    /// Computes the maximum-weight perfect matching of two sets given their weight matrix.
    /// See the matching() method to get the computed match result.
    /// \param cost a matrix of two sets I,J where cost[i][j] is the weight of edge i->j
    /// \param logicalSize the number of elements in both I and J (0 <= logicalSize <= N)
    /// \returns the total weight of the best (maximum-weight) matching
    inline double operator()(const double cost[N][N], int logicalSize)
    {
        n = logicalSize;
        assert(n >= 0 && n <= N);

        max_match = 0;
        memset(xy, -1, sizeof(xy));
        memset(yx, -1, sizeof(yx));

        init_labels(cost);   // step 0
        augment(cost);       // steps 1-3

        double ret = 0;
        for (int x = 0; x < n; x++)
            ret += cost[x][xy[x]];
        return ret;
    }

    /// Gets the matching element in 2nd set of the ith element in the first set
    /// \param i the index of the ith element in the first set (passed in operator())
    /// \returns an index j, denoting the matched jth element of the 2nd set
    inline int matching(int i) const {
        return xy[i];
    }

    /// Gets the matching element in 1st set of the jth element in the 2nd set
    /// \param j the index of the jth element in the 2nd set (passed in operator())
    /// \returns an index i, denoting the matched ith element of the 1st set
    /// \note inverseMatching(matching(i)) == i
    inline int inverseMatching(int j) const {
        return yx[j];
    }

};
|
||
|
|
||
|
|
||
|
#endif
|