Kruskal算法需要用到并查集的知识,如果不了解并查集,可以先看我的另一篇博客:并查集 | Ephemeral-fever

Kruskal算法从边的角度求带权图的最小生成树,时间复杂度为O(eloge)。和Prim算法恰恰相反,Kruskal算法更适合于求边稀疏的带权图的最小生成树。

先来看几个概念:

  • 带权图:边赋以权值的图称为网或带权图,带权图的生成树也是带权的,生成树T各边的权值总和称为该树的权。
  • 最小生成树(MST):权值最小的生成树。
  • 最小生成树的性质:假设G=(V,E)是一个连通网,U是顶点集V的一个非空子集。若(u,v)是一条具有最小权值的边,其中u∈U,v∈V-U,则必存在一棵包含边(u,v)的最小生成树。

构造最小生成树必须解决下面两个问题:

  1. 尽可能选取权值小的边,但不能构成回路;
  2. 选取n-1条恰当的边以连通n个顶点;

于是,对于任意一个连通带权图的最小生成树来说,在要求总的权值最小的情况下,最直接的想法就是将图中的所有边按照权值大小进行升序排序,然后从小到大依次选择。

由于最小生成树本身是一棵生成树,所以需要时刻满足以下两点:

  • 生成树中任意顶点之间有且仅有一条通路,也就是说,生成树中不能存在回路
  • 对于具有 n 个顶点的连通图,其生成树中只能有 n-1 条边,这 n-1 条边连通着 n 个顶点

连接 n 个顶点在不产生回路的情况下,只需要 n-1 条边。

所以Kruskal算法的具体思路是:将所有边按照权值的大小进行升序排序,然后从小到大一一判断,条件为:如果这个边不会与之前选择的所有边组成回路,就可以作为最小生成树的一部分;反之,舍去。直到具有 n 个顶点的连通带权图筛选出来 n-1 条边为止。筛选出来的边和所有的顶点构成此连通图的最小生成树。

判断是否会产生回路的方法为:在初始状态下给每个顶点赋予不同的标记,这个标记其实就是该顶点所在树的根节点,初始状态每个顶点自成一树。对于遍历过程的每条边,其都有两个顶点,判断这两个顶点的标记是否一致,如果一致,说明它们本身就处在一棵树中,如果继续连接就会产生回路;如果不一致,说明它们之间还没有任何关系,可以连接。

具体实现如下:

#include <iostream>
#include <algorithm>

using namespace std;
/*
输入:
6 10
1 2 5
1 6 1
1 5 5
2 6 5
5 6 5
2 3 3
3 6 6
3 4 6
4 5 2
4 6 4
输出:
15
*/

const int maxn = 10001;
int n, m, total = 0, used = 0; //n个节点,m条边,total:最小生成树的权值,used:已经使用的边的数量
int pre[maxn]; //保存每个节点的前驱

struct node
{
int from, to, w; //分别为边的起点,终点,权值
} edge[maxn];

int find(int x) //并查集的查找操作
{
if (pre[x] == x)
return x;
return pre[x] = find(pre[x]);
}

void join(int x, int y) //并查集的合并操作
{
int fx = find(x), fy = find(y);
if (fx != fy)
pre[fx] = fy;
}

bool cmp(node x, node y)
{
return x.w < y.w;
}

int main()
{
cin >> n >> m;
//输入边的信息
for (int i = 0; i < m; i++)
cin >> edge[i].from >> edge[i].to >> edge[i].w;
//并查集的初始化
for (int i = 0; i <= n; i++)
pre[i] = i;
sort(edge, edge + m, cmp); //按边的权值排序
for (int i = 0; i < m; i++) //从小到大遍历每条边
{
if (used == n - 1) //n个节点的树有n-1条边
break;
if (find(edge[i].from) != find(edge[i].to))
{
join(edge[i].from, edge[i].to);
total += edge[i].w;
used++;
}
}
cout << total << endl; //输出生成树的最小权值
return 0;
}

python实现:

n, m = map(int, input().split())
edge = []
used = 0
total = 0
pre = [i for i in range(n+1)]


def myfind(x):
if pre[x] == x:
return x
pre[x] = myfind(pre[x])
return pre[x]


def myjoin(x, y):
fx, fy = myfind(x), myfind(y)
if fx != fy:
pre[fx] = fy


for i in range(m):
dic = {}
dic['from'], dic['to'], dic['w'] = map(int, input().split())
edge.append(dic)

edge.sort(key=lambda x: x['w'])

for elem in edge:
if used == n-1:
break
if myfind(elem['from']) != myfind(elem['to']):
myjoin(elem['from'], elem['to'])
total += elem['w']
used += 1
print(total)
'''
for i in edge:
print(i)
'''