Problem
You are given an integer n. There is an undirected graph with n nodes, numbered from 0 to n - 1. You are given a 2D integer array edges where edges[i] = [ai, bi] denotes that there exists an undirected edge connecting nodes ai and bi.
Return the number of pairs of different nodes that are unreachable from each other.
Solution
There are multiple ways to solve this problem, such as DFS or BFS. Either traversal yields the sets of connected nodes, and from those the number of pairs of unreachable nodes. However, there is another way to solve the problem: a data structure called a disjoint set.
If disjoint set is not a foreign concept to you, skip ahead to Approach.
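For comparison, the DFS route mentioned above could look roughly like the sketch below. It is not part of the original solution; the function name countPairsDfs is made up for illustration, and the explicit stack is just one way to avoid deep recursion.

```cpp
#include <vector>

// Hypothetical helper (not from the original post): counts unreachable pairs
// by collecting connected-component sizes with an iterative DFS.
long long countPairsDfs(int n, const std::vector<std::vector<int>>& edges) {
    // Build an adjacency list from the edge list.
    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }

    std::vector<bool> visited(n, false);
    long long unreachable = 0;
    long long nodesLeft = n;  // nodes not yet assigned to a counted component

    for (int start = 0; start < n; ++start) {
        if (visited[start]) continue;

        // Explicit stack instead of recursion, to avoid deep call stacks.
        long long size = 0;
        std::vector<int> stack{start};
        visited[start] = true;
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            ++size;
            for (int v : adj[u]) {
                if (!visited[v]) {
                    visited[v] = true;
                    stack.push_back(v);
                }
            }
        }

        // Every node in this component pairs with every node outside it
        // that has not been counted yet.
        nodesLeft -= size;
        unreachable += size * nodesLeft;
    }
    return unreachable;
}
```

For n = 7 and edges [0,2], [0,5], [2,4], [1,6], [5,4], the components have sizes 4, 2 and 1, so the function returns 4*3 + 2*1 + 1*0 = 14.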
Disjoint Set
Disjoint set, or union-find, is a data structure that has the following two operations:
find(i): Find the representing node of the set that contains node i.
union(i, j): Union (join) the two sets that contain node i and node j.
(BTW, if you forgot, union is a keyword in C/C++, so pick another function name when implementing the data structure.)
Below is a graphical representation of a disjoint set. Each set has a representing node (root node) which is colored blue. (graph stolen from Leetcode)
Each node knows its parent. Therefore, to find(i) the representing node, node i traverses upwards until it reaches the root node. union(i, j), on the other hand, sets the parent of one set's representing node to the other set's representing node. Therefore, if union(8, 6) is called, the following will happen:
To build the disjoint set, first initialize each node's parent to itself (which means that each node is in its own set). Then feed it the edges one by one and union the sets they connect.
See the implementation of disjoint set here.
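To make the build step concrete, here is a minimal sketch that works directly on a parent array. The names findRoot and buildParents are hypothetical and only serve the illustration.

```cpp
#include <vector>

// Illustrative sketch only; the names below are hypothetical.
// Follows parent pointers from node i up to its root.
int findRoot(const std::vector<int>& parent, int i) {
    while (parent[i] != i) i = parent[i];
    return i;
}

// Builds the parent array for n nodes by feeding in the edges one by one.
std::vector<int> buildParents(int n, const std::vector<std::vector<int>>& edges) {
    std::vector<int> parent(n);
    for (int i = 0; i < n; ++i) parent[i] = i;  // every node starts as its own root

    for (const auto& e : edges) {
        int rootA = findRoot(parent, e[0]);
        int rootB = findRoot(parent, e[1]);
        parent[rootB] = rootA;  // attach one root under the other
    }
    return parent;
}
```

With n = 4 and edges [0,1], [2,3], [1,3], the parent array goes from [0,1,2,3] to [0,0,2,3], then [0,0,2,2], and finally [0,0,0,2], so every node now reaches root 0.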
Approach
Here is my first attempt at solving this problem. Yeah … counting the pairs of unreachable nodes with nested for loops destroyed the time complexity: the loops check all n(n-1)/2 pairs, and each check calls find twice. Without a doubt, this solution timed out.
class Solution {
public:
    long long countPairs(int n, vector<vector<int>>& edges) {
        DisjointSet set(n);
        for (const vector<int>& edge: edges) {
            set.setUnion(edge[0], edge[1]);
        }
        long long unreachableCount = 0;
        for (int i = 0; i < n-1; i++) {
            for (int j = i+1; j < n; j++) {
                if (set.find(i) != set.find(j)) unreachableCount++;
            }
        }
        return unreachableCount;
    }
};
Implementation
On second thought, the pairs of unreachable nodes can be counted in one pass over the sets. For each set, subtract its size from the number of nodes left, then multiply the set size by that remainder (the nodes in sets not yet processed). Summing these products counts every unreachable pair exactly once.
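As a quick sanity check of the counting rule on its own, the sketch below applies it to a list of set sizes that is assumed to be already known. The helper name pairsFromSizes is hypothetical.

```cpp
#include <vector>

// Hypothetical helper: applies the "size times remaining nodes" rule to a
// list of component sizes that is assumed to be already known.
long long pairsFromSizes(const std::vector<int>& sizes) {
    long long nodesLeft = 0;
    for (int s : sizes) nodesLeft += s;  // total number of nodes

    long long pairs = 0;
    for (int s : sizes) {
        nodesLeft -= s;  // nodes in sets that have not been processed yet
        pairs += static_cast<long long>(s) * nodesLeft;
    }
    return pairs;
}
```

For sizes 4, 2 and 1 (seven nodes in total) this gives 4*3 + 2*1 + 1*0 = 14. Subtracting before multiplying is what keeps each pair from being counted twice.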
#include <unordered_map>
#include <vector>
using namespace std;

class DisjointSet {
public:
    // Every node starts as the root of its own set.
    DisjointSet(int n): parent(n) {
        for (int i = 0; i < n; i++) parent[i] = i;
    }
    // Walk up the parent pointers until the root (representing node) is reached.
    int find(int i) const {
        if (parent[i] != i) {
            return find(parent[i]);
        }
        return i;
    }
    // Attach the root of j's set under the root of i's set.
    void setUnion(int i, int j) {
        parent[find(j)] = find(i);
    }
private:
    vector<int> parent;
};
class Solution {
public:
    long long countPairs(int n, vector<vector<int>>& edges) {
        DisjointSet set(n);
        for (const vector<int>& edge: edges) {
            set.setUnion(edge[0], edge[1]);
        }
        long long unreachableCount = 0;
        long long nodesLeft = n;
        // Map each root to the number of nodes in its set.
        unordered_map<int, int> componentSize;
        for (int i = 0; i < n; i++) componentSize[set.find(i)]++;
        for (const auto& component: componentSize) {
            // Pair this set's nodes with all nodes in sets not yet processed.
            nodesLeft -= component.second;
            unreachableCount += component.second * nodesLeft;
        }
        return unreachableCount;
    }
};
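One caveat: the find above walks the whole parent chain on every call, so a long chain can make each call take linear time. A common refinement, which the solution above does not use, is path compression combined with union by size. Since the sizes are then tracked inside the structure, the hash map is no longer needed. The sketch below is one possible version under those assumptions; SizedDisjointSet, setSize and countPairsSized are hypothetical names.

```cpp
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical variant of DisjointSet (not the original class):
// path compression plus union by size.
class SizedDisjointSet {
public:
    explicit SizedDisjointSet(int n) : parent(n), size(n, 1) {
        std::iota(parent.begin(), parent.end(), 0);  // each node is its own root
    }

    int find(int i) {
        // First pass: locate the root.
        int root = i;
        while (parent[root] != root) root = parent[root];
        // Second pass: point every node on the path directly at the root.
        while (parent[i] != root) {
            int next = parent[i];
            parent[i] = root;
            i = next;
        }
        return root;
    }

    void setUnion(int i, int j) {
        int a = find(i), b = find(j);
        if (a == b) return;  // already in the same set
        if (size[a] < size[b]) std::swap(a, b);
        parent[b] = a;       // attach the smaller tree under the larger one
        size[a] += size[b];
    }

    // Size of the set that contains node i.
    int setSize(int i) { return size[find(i)]; }

private:
    std::vector<int> parent;
    std::vector<int> size;
};

long long countPairsSized(int n, const std::vector<std::vector<int>>& edges) {
    SizedDisjointSet set(n);
    for (const auto& e : edges) set.setUnion(e[0], e[1]);

    long long unreachable = 0;
    long long nodesLeft = n;
    for (int i = 0; i < n; ++i) {
        if (set.find(i) != i) continue;  // handle each set once, at its root
        nodesLeft -= set.setSize(i);
        unreachable += set.setSize(i) * nodesLeft;
    }
    return unreachable;
}
```

The counting loop is the same idea as before; it just visits roots directly instead of iterating over a map of root to size.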