Union-find, also known as the disjoint-set data structure, keeps track of a collection of items partitioned into non-overlapping groups. It provides two main operations: Find and Union.
Find:
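The Find operation returns the representative (root) of the group that an item belongs to. Two items are in the same group exactly when Find returns the same root for both. For example, you can use this to check whether two nodes are connected in a graph, or whether two nodes have the same root in a tree. The implementation below also performs path compression: every node visited on the way to the root is re-pointed directly at the root, which makes later lookups faster.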
# Parent table: f[x] is the parent of x, and every item starts as its own root.
# NOTE(review): n (the number of items) is assumed to be defined by the caller.
f = list(range(n))


def find(x):
    """Return the root of x's group, re-pointing x's ancestors at that root."""
    if f[x] == x:
        return x
    f[x] = find(f[x])  # path compression
    return f[x]
f[x]: the parent of x (x is a root exactly when f[x] == x)
find(x): the root of x's group, i.e. its topmost ancestor
Union:
The Union operation merges the groups containing two items into a single group. For example, you might use it to merge two connected components of a graph when an edge is added between them, or to merge two subtrees in a tree.
def union(x, y):
    """Merge x's group into y's group by hanging x's root under y's root."""
    root_y = find(y)
    f[find(x)] = root_y
union(x, y): merge x's group into y's group by making the root of x point to the root of y
Find and Union are closely related, and they are almost always used together in algorithms that involve connectivity. Below is a collection of LeetCode union-find problems with solutions to help you understand the topic better. Happy coding!😄
https://leetcode.com/problems/redundant-connection/
class Solution:
    def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
        """Return the first edge (in input order) whose endpoints are already connected.

        Assumes node labels are in 1..len(edges), since the parent table has
        len(edges) + 1 slots. Returns None if no edge closes a cycle.
        """
        n = len(edges)
        f = list(range(n + 1))

        def find(x):
            # Iterative rather than recursive: union() does no balancing, so a
            # parent chain can grow as long as the input, and a recursive find
            # would exceed Python's recursion limit on long chains.
            root = x
            while f[root] != root:
                root = f[root]
            # Path compression: re-point every node on the path at the root.
            while f[x] != root:
                f[x], x = root, f[x]
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for u, v in edges:
            if find(u) == find(v):
                return [u, v]
            union(u, v)
https://leetcode.com/problems/graph-valid-tree/
class Solution:
    def validTree(self, n: int, edges: List[List[int]]) -> bool:
        """Return True if nodes 0..n-1 and the undirected edges form one tree.

        Fewer than n - 1 edges cannot connect n nodes. Otherwise the graph is a
        tree iff no edge closes a cycle: any edge beyond n - 1 necessarily does.
        """
        if len(edges) < n - 1:
            return False
        f = list(range(n))

        def find(x):
            # Iterative rather than recursive: union() does no balancing, so a
            # parent chain can grow as long as the input, and a recursive find
            # would exceed Python's recursion limit on long chains.
            root = x
            while f[root] != root:
                root = f[root]
            # Path compression: re-point every node on the path at the root.
            while f[x] != root:
                f[x], x = root, f[x]
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for u, v in edges:
            if find(u) == find(v):
                return False  # this edge closes a cycle
            union(u, v)
        return True
https://leetcode.com/problems/count-unreachable-pairs-of-nodes-in-an-undirected-graph/
class Solution:
    def countPairs(self, n: int, edges: List[List[int]]) -> int:
        """Count unordered pairs of nodes (0..n-1) that lie in different components."""
        f = list(range(n))

        def find(x):
            # Iterative rather than recursive: union() does no balancing, so a
            # parent chain can grow as long as the input, and a recursive find
            # would exceed Python's recursion limit on long chains.
            root = x
            while f[root] != root:
                root = f[root]
            # Path compression: re-point every node on the path at the root.
            while f[x] != root:
                f[x], x = root, f[x]
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for u, v in edges:
            union(u, v)
        mp = defaultdict(int)  # root -> component size
        for i in range(n):
            mp[find(i)] += 1
        # Every node of a newly visited component pairs with every node in the
        # components already visited.
        ans, seen = 0, 0
        for size in mp.values():
            ans += seen * size
            seen += size
        return ans
https://leetcode.com/problems/connecting-cities-with-minimum-cost/
class Solution:
    def minimumCost(self, n: int, connections: List[List[int]]) -> int:
        """Return the cost of a minimum spanning tree over cities 1..n (Kruskal).

        Each connection is [u, v, w] with 1-based city labels. Returns -1 if
        the connections cannot link all n cities.
        """
        f = list(range(n))

        def find(x):
            # Iterative rather than recursive: union() does no balancing, so a
            # parent chain can grow as long as the input, and a recursive find
            # would exceed Python's recursion limit on long chains.
            root = x
            while f[root] != root:
                root = f[root]
            # Path compression: re-point every node on the path at the root.
            while f[x] != root:
                f[x], x = root, f[x]
            return root

        def union(x, y):
            f[find(x)] = find(y)

        ans = 0
        components = n
        # sorted() instead of list.sort(): do not reorder the caller's list.
        for u, v, w in sorted(connections, key=lambda c: c[2]):
            if find(u - 1) != find(v - 1):
                union(u - 1, v - 1)
                ans += w
                components -= 1
        return ans if components == 1 else -1
https://leetcode.com/problems/most-stones-removed-with-same-row-or-column/
class Solution:
    def removeStones(self, stones: List[List[int]]) -> int:
        """Return the maximum number of stones removable (one per shared row/column).

        Rows and columns are the union-find nodes and each stone joins its row
        to its column. One stone must remain per connected component, so the
        answer is len(stones) - (number of components).
        """
        n = len(stones)
        # Tagged keys keep rows and columns apart for any coordinate values
        # (a fixed numeric offset collides once coordinates reach the offset).
        f = {}

        def find(x):
            # Iterative to avoid Python's recursion limit on long parent chains.
            root = x
            while f[root] != root:
                root = f[root]
            # Path compression: re-point every node on the path at the root.
            while f[x] != root:
                f[x], x = root, f[x]
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for r, c in stones:
            f[('r', r)] = ('r', r)
            f[('c', c)] = ('c', c)
        for r, c in stones:
            union(('r', r), ('c', c))
        return n - sum(key == parent for key, parent in f.items())
https://leetcode.com/problems/minimize-hamming-distance-after-swap-operations/
class Solution:
    def minimumHammingDistance(self, source: List[int], target: List[int], allowedSwaps: List[List[int]]) -> int:
        """Return the minimum Hamming distance between source and target after swaps.

        Indices linked (directly or transitively) by allowedSwaps form a group
        whose values can be permuted freely, so only each group's multiset of
        values matters.
        """
        n = len(source)
        f = list(range(n))

        def find(x):
            # Iterative rather than recursive: union() does no balancing, so a
            # parent chain can grow as long as the input, and a recursive find
            # would exceed Python's recursion limit on long chains.
            root = x
            while f[root] != root:
                root = f[root]
            # Path compression: re-point every node on the path at the root.
            while f[x] != root:
                f[x], x = root, f[x]
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for u, v in allowedSwaps:
            union(u, v)
        g = defaultdict(list)  # root -> indices in that group
        for i in range(n):
            g[find(i)].append(i)
        ans = 0
        for idx in g.values():
            # +1 per target value, -1 per source value: the total absolute
            # imbalance is twice the number of positions that cannot match.
            freq = defaultdict(int)
            for i in idx:
                freq[target[i]] += 1
                freq[source[i]] -= 1
            ans += sum(abs(c) for c in freq.values()) // 2
        return ans
https://leetcode.com/problems/satisfiability-of-equality-equations/
class Solution:
    def equationsPossible(self, equations: List[str]) -> bool:
        """Decide whether all "x==y" / "x!=y" equations can hold at once.

        Each equation is four characters: variable, '=' or '!', '=', variable.
        All equalities are merged first; then any inequality whose two sides
        share a root is a contradiction.
        """
        parent = {}
        for eq in equations:
            parent[eq[0]] = eq[0]
            parent[eq[3]] = eq[3]

        def root(x):
            if parent[x] == x:
                return x
            parent[x] = root(parent[x])
            return parent[x]

        for eq in equations:
            if eq[1] == "=":
                parent[root(eq[0])] = root(eq[3])
        return not any(
            eq[1] == "!" and root(eq[0]) == root(eq[3]) for eq in equations
        )
https://leetcode.com/problems/sentence-similarity-ii/
class Solution:
    def areSentencesSimilarTwo(self, sentence1: List[str], sentence2: List[str], similarPairs: List[List[str]]) -> bool:
        """Return True if the sentences match word-by-word under transitive similarity.

        Identical words always match. Otherwise both words must appear in
        similarPairs and belong to the same union-find group.
        """
        leader = {}
        if len(sentence1) != len(sentence2):
            return False

        def find(word):
            if leader[word] == word:
                return word
            leader[word] = find(leader[word])
            return leader[word]

        for a, b in similarPairs:
            leader[a] = a
            leader[b] = b
        for a, b in similarPairs:
            leader[find(a)] = find(b)
        for a, b in zip(sentence1, sentence2):
            if a == b:
                continue
            if a not in leader or b not in leader:
                return False
            if find(a) != find(b):
                return False
        return True
https://leetcode.com/problems/min-cost-to-connect-all-points/
class Solution:
    def minCostConnectPoints(self, points: List[List[int]]) -> int:
        """Return the minimum total Manhattan distance connecting all points (Kruskal).

        Every pair of points is a candidate edge. Points are keyed by their
        string form in the union-find table.
        """
        edges = sorted(
            [abs(p[0] - q[0]) + abs(p[1] - q[1]), str(p), str(q)]
            for p, q in combinations(points, 2)
        )
        parent = {}
        for _, a, b in edges:
            parent[a] = a
            parent[b] = b

        def find(x):
            if parent[x] == x:
                return x
            parent[x] = find(parent[x])
            return parent[x]

        total = 0
        for cost, a, b in edges:
            root_a, root_b = find(a), find(b)
            if root_a == root_b:
                continue  # already connected; this edge would form a cycle
            parent[root_a] = root_b
            total += cost
        return total
https://leetcode.com/problems/similar-string-groups/
class Solution:
    def numSimilarGroups(self, strs: List[str]) -> int:
        """Count groups of strings connected by similarity.

        Two strings are treated as similar when they differ in at most two
        positions (compared pairwise via zip).
        """
        leader = {word: word for word in strs}

        def root(word):
            if leader[word] == word:
                return word
            leader[word] = root(leader[word])
            return leader[word]

        for a, b in combinations(strs, 2):
            mismatches = sum(1 for x, y in zip(a, b) if x != y)
            if mismatches <= 2:
                leader[root(a)] = root(b)
        return len({root(word) for word in strs})
https://leetcode.com/problems/accounts-merge/
class Solution:
    def accountsMerge(self, accounts: List[List[str]]) -> List[List[str]]:
        """Merge accounts that share at least one email.

        Each account is [name, email, ...]. Every email is unioned with the
        first email of its account. Each merged group is returned as
        [name, *sorted emails].
        """
        def find(x):
            if parent[x] == x:
                return x
            parent[x] = find(parent[x])
            return parent[x]

        parent, owner = {}, {}
        for account in accounts:
            name = account[0]
            for mail in account[1:]:
                parent.setdefault(mail, mail)
                parent[find(mail)] = find(account[1])
                owner[mail] = name
        groups = defaultdict(list)
        for mail in parent:
            groups[find(mail)].append(mail)
        merged = []
        for root, members in groups.items():
            merged.append([owner[root]] + sorted(members))
        return merged
https://leetcode.com/problems/lexicographically-smallest-equivalent-string/
class Solution:
    def smallestEquivalentString(self, s1: str, s2: str, baseStr: str) -> str:
        """Replace each character of baseStr with the smallest equivalent character.

        s1[i] and s2[i] are equivalent, and equivalence is transitive.
        Characters that never occur in s1/s2 are left unchanged.
        """
        parent = {}

        def find(x):
            if parent[x] == x:
                return x
            parent[x] = find(parent[x])
            return parent[x]

        for a, b in zip(s1, s2):
            parent.setdefault(a, a)
            parent.setdefault(b, b)
            parent[find(a)] = find(b)
        # Smallest character per root; 'z' is the starting value for lowercase input.
        smallest = defaultdict(lambda: 'z')
        for ch in parent:
            r = find(ch)
            smallest[r] = min(smallest[r], ch)
        return ''.join(smallest[find(ch)] if ch in parent else ch for ch in baseStr)
https://leetcode.com/problems/smallest-string-with-swaps/
class Solution:
    def smallestStringWithSwaps(self, s: str, pairs: List[List[int]]) -> str:
        """Return the lexicographically smallest string reachable via the index swaps.

        Indices connected through pairs can be permuted arbitrarily, so each
        group's characters are sorted and written back into its indices in
        ascending index order.
        """
        n = len(s)
        f = list(range(n))

        def find(x):
            # Iterative rather than recursive: union() does no balancing, so a
            # parent chain can grow as long as the input, and a recursive find
            # would exceed Python's recursion limit on long chains.
            root = x
            while f[root] != root:
                root = f[root]
            # Path compression: re-point every node on the path at the root.
            while f[x] != root:
                f[x], x = root, f[x]
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for u, v in pairs:
            union(u, v)
        mp = defaultdict(list)  # root -> indices, appended in ascending order
        for i in range(n):
            mp[find(i)].append(i)
        ans = [''] * n
        for idx in mp.values():
            chars = sorted(s[i] for i in idx)
            for i, ch in zip(idx, chars):
                ans[i] = ch
        return ''.join(ans)
https://leetcode.com/problems/synonymous-sentences/
class Solution:
    def generateSentences(self, synonyms: List[List[str]], text: str) -> List[str]:
        """Return every sentence obtainable by swapping words for synonyms, sorted.

        Synonymy is transitive (union-find groups). Each word of text that has
        synonyms is expanded into every member of its group.
        """
        parent = {}

        def find(x):
            if parent[x] == x:
                return x
            parent[x] = find(parent[x])
            return parent[x]

        for a, b in synonyms:
            parent.setdefault(a, a)
            parent.setdefault(b, b)
            parent[find(a)] = find(b)
        groups = defaultdict(list)  # root -> all words in that synonym group
        for word in parent:
            groups[find(word)].append(word)

        words = text.split()
        sentences = [list(words)]
        for pos, word in enumerate(words):
            if word not in parent:
                continue
            options = groups[find(word)]
            expanded = []
            for partial in sentences:
                for choice in options:
                    partial[pos] = choice
                    expanded.append(list(partial))
            sentences = expanded
        return sorted(' '.join(ws) for ws in sentences)