Union-find, also known as the disjoint-set data structure, keeps track of a collection of items split into non-overlapping groups. It provides two main operations: Find and Union.
Find:
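The Find operation returns the representative (root) of the group an item belongs to. Two items are therefore in the same group exactly when Find returns the same root for both. For example, you might use Find to determine whether two nodes are connected in a graph, or whether two nodes have the same root in a tree. The implementation below also performs path compression: every node visited on the way to the root is re-pointed directly at the root, which keeps later lookups fast.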
# f[x]    : parent of x (every element starts as its own root)
# find(x) : highest parent (root) of x
f = list(range(n))


def find(x):
    """Return the root of x's group, compressing the path to it.

    Every node visited on the way up is re-pointed directly at the root, so
    later lookups on the same nodes take a single step.  The walk is
    iterative: a recursive version can exceed Python's default recursion
    limit when unions have built a long parent chain.
    """
    root = x
    while f[root] != root:
        root = f[root]
    # Second pass: point every node on the path straight at the root.
    while f[x] != root:
        nxt = f[x]
        f[x] = root
        x = nxt
    return root
Union:
The Union operation merges the groups containing two items into a single group by pointing the root of the first item's group at the root of the second's. For example, you might use Union to merge two connected components in a graph, or to merge two subtrees into a single tree.
def union(x, y):
    """Connect x to y: merge the group containing x into the group of y.

    The root of x's group is re-pointed at the root of y's group.  Uses the
    shared parent list ``f`` and the ``find`` helper defined elsewhere in
    this module.
    """
    root_y = find(y)
    root_x = find(x)
    f[root_x] = root_y
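Find and Union are closely related, and they are usually used together in algorithms that involve connectivity. Below is a collection of LeetCode union-find problems with solutions to help you understand the topic better. The solutions rely on names that LeetCode's Python environment imports automatically (`List` from `typing`, `defaultdict` from `collections`, `combinations` from `itertools`); add those imports if you run them elsewhere. Happy coding!😄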
https://leetcode.com/problems/redundant-connection/
class Solution:
    def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
        """Return the edge whose removal turns the graph back into a tree.

        The graph is a tree on nodes 1..n (n == len(edges)) plus one extra
        edge.  Edges are added in input order; the first one whose endpoints
        already share a root closes the cycle and is returned.
        """
        n = len(edges)
        f = list(range(n + 1))  # index 0 unused: node labels start at 1

        def find(x):
            # Iterative two-pass path compression; a recursive find can exceed
            # Python's recursion limit when unions build a long parent chain.
            root = x
            while f[root] != root:
                root = f[root]
            while f[x] != root:
                nxt = f[x]
                f[x] = root
                x = nxt
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for u, v in edges:
            if find(u) == find(v):
                return [u, v]
            union(u, v)
https://leetcode.com/problems/graph-valid-tree/
class Solution:
    def validTree(self, n: int, edges: List[List[int]]) -> bool:
        """Return True if nodes 0..n-1 and ``edges`` form exactly one tree.

        Fewer than n - 1 edges cannot connect n nodes.  Otherwise the graph
        is a tree iff no edge joins two already-connected nodes: an acyclic
        graph with n - 1 edges on n nodes is connected, and more than n - 1
        edges always contain a cycle.
        """
        if len(edges) < n - 1:
            return False
        f = list(range(n))

        def find(x):
            # Iterative two-pass path compression; a recursive find can exceed
            # Python's recursion limit when unions build a long parent chain.
            root = x
            while f[root] != root:
                root = f[root]
            while f[x] != root:
                nxt = f[x]
                f[x] = root
                x = nxt
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for u, v in edges:
            if find(u) == find(v):
                return False  # this edge closes a cycle
            union(u, v)
        return True
https://leetcode.com/problems/count-unreachable-pairs-of-nodes-in-an-undirected-graph/
class Solution:
    def countPairs(self, n: int, edges: List[List[int]]) -> int:
        """Count pairs of nodes (0..n-1) that have no path between them."""
        f = list(range(n))

        def find(x):
            # Iterative two-pass path compression; a recursive find can exceed
            # Python's recursion limit when unions build a long parent chain.
            root = x
            while f[root] != root:
                root = f[root]
            while f[x] != root:
                nxt = f[x]
                f[x] = root
                x = nxt
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for u, v in edges:
            union(u, v)
        sizes = defaultdict(int)  # root -> component size
        for i in range(n):
            sizes[find(i)] += 1
        # Pair each component with all components seen before it, so every
        # cross-component pair is counted exactly once.
        ans = seen = 0
        for size in sizes.values():
            ans += seen * size
            seen += size
        return ans
https://leetcode.com/problems/connecting-cities-with-minimum-cost/
class Solution:
    def minimumCost(self, n: int, connections: List[List[int]]) -> int:
        """Return the minimum cost to connect cities 1..n, or -1 if impossible.

        Kruskal's algorithm: take connections [u, v, cost] in order of
        increasing cost and keep each one that joins two separate components.
        ``connections`` itself is left unmodified.
        """
        f = list(range(n))  # city c is stored at index c - 1

        def find(x):
            # Iterative two-pass path compression; a recursive find can exceed
            # Python's recursion limit when unions build a long parent chain.
            root = x
            while f[root] != root:
                root = f[root]
            while f[x] != root:
                nxt = f[x]
                f[x] = root
                x = nxt
            return root

        def union(x, y):
            f[find(x)] = find(y)

        ans = used = 0
        # sorted() rather than list.sort() so the caller's list is not reordered.
        for u, v, w in sorted(connections, key=lambda c: c[2]):
            if find(u - 1) != find(v - 1):
                union(u - 1, v - 1)
                ans += w
                used += 1
                if used == n - 1:
                    break  # spanning tree complete
        # A spanning tree over n cities has exactly n - 1 edges.
        return ans if used == n - 1 else -1
https://leetcode.com/problems/most-stones-removed-with-same-row-or-column/
class Solution:
    def removeStones(self, stones: List[List[int]]) -> int:
        """Return the maximum number of stones that can be removed.

        Each row and each column is a union-find node, and a stone at (r, c)
        joins row r with column c.  Every connected component can be reduced
        to a single stone, so the answer is len(stones) minus the number of
        components.  Rows and columns are keyed as ("r", r) and ("c", c), so
        they never collide whatever the coordinate range.  A numeric offset
        such as c + 10005 breaks once coordinates exceed it.
        """
        f = {}

        def find(x):
            # Iterative two-pass path compression (no recursion-depth limit).
            root = x
            while f[root] != root:
                root = f[root]
            while f[x] != root:
                nxt = f[x]
                f[x] = root
                x = nxt
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for r, c in stones:
            f[("r", r)] = ("r", r)
            f[("c", c)] = ("c", c)
        for r, c in stones:
            union(("r", r), ("c", c))
        components = sum(node == parent for node, parent in f.items())
        return len(stones) - components
https://leetcode.com/problems/minimize-hamming-distance-after-swap-operations/
class Solution:
    def minimumHammingDistance(self, source: List[int], target: List[int], allowedSwaps: List[List[int]]) -> int:
        """Return the minimum Hamming distance after any number of allowed swaps.

        Swaps compose transitively, so indices in the same union-find
        component can be permuted freely.  Within a component, each value
        that target needs but source cannot supply costs one mismatch.
        """
        n = len(source)
        f = list(range(n))

        def find(x):
            # Iterative two-pass path compression; a recursive find can exceed
            # Python's recursion limit when unions build a long parent chain.
            root = x
            while f[root] != root:
                root = f[root]
            while f[x] != root:
                nxt = f[x]
                f[x] = root
                x = nxt
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for u, v in allowedSwaps:
            union(u, v)
        groups = defaultdict(list)  # root -> indices in that component
        for i in range(n):
            groups[find(i)].append(i)
        ans = 0
        for members in groups.values():
            # +1 for each value target needs, -1 for each value source offers.
            balance = defaultdict(int)
            for i in members:
                balance[target[i]] += 1
                balance[source[i]] -= 1
            # Each unmatched position shows up once as a surplus and once as a
            # deficit, hence the halving.
            ans += sum(abs(b) for b in balance.values()) // 2
        return ans
https://leetcode.com/problems/satisfiability-of-equality-equations/
class Solution:
    def equationsPossible(self, equations: List[str]) -> bool:
        """Decide whether every equation can hold simultaneously.

        Each equation is read as: variable at index 0, operator starting at
        index 1 ("=" for equality, "!" for inequality), variable at index 3.
        All equalities are merged first; the system is satisfiable unless
        some inequality relates two variables that ended up in one group.
        """
        parent = {}
        for eq in equations:
            parent[eq[0]] = eq[0]
            parent[eq[3]] = eq[3]

        def root_of(v):
            # Path halving: each visited node skips to its grandparent.
            while parent[v] != v:
                parent[v] = parent[parent[v]]
                v = parent[v]
            return v

        for eq in equations:
            if eq[1] == "=":
                parent[root_of(eq[0])] = root_of(eq[3])

        return all(
            root_of(eq[0]) != root_of(eq[3])
            for eq in equations
            if eq[1] == "!"
        )
https://leetcode.com/problems/sentence-similarity-ii/
class Solution:
    def areSentencesSimilarTwo(self, sentence1: List[str], sentence2: List[str], similarPairs: List[List[str]]) -> bool:
        """Return True if the sentences are word-by-word similar.

        Similarity is transitive, so the words in ``similarPairs`` are grouped
        with union-find.  Identical words always match; otherwise both words
        must appear in some pair and share a group.
        """
        if len(sentence1) != len(sentence2):
            return False
        f = {}

        def find(x):
            # Iterative two-pass path compression; a recursive find can exceed
            # Python's recursion limit when unions build a long parent chain.
            root = x
            while f[root] != root:
                root = f[root]
            while f[x] != root:
                nxt = f[x]
                f[x] = root
                x = nxt
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for u, v in similarPairs:
            f[u] = u
            f[v] = v
        for u, v in similarPairs:
            union(u, v)
        for a, b in zip(sentence1, sentence2):
            if a == b:
                continue  # a word matches itself even if it is in no pair
            if a not in f or b not in f or find(a) != find(b):
                return False
        return True
https://leetcode.com/problems/min-cost-to-connect-all-points/
class Solution:
    def minCostConnectPoints(self, points: List[List[int]]) -> int:
        """Return the minimum total Manhattan distance connecting all points.

        Kruskal's algorithm: consider every pair of points in order of
        increasing distance and keep an edge whenever it joins two different
        components.  Points are identified by their index in ``points``
        (cheaper and less fragile than keying them by ``str(point)``).
        """
        n = len(points)
        edges = sorted(
            (abs(x1 - x2) + abs(y1 - y2), i, j)
            for (i, (x1, y1)), (j, (x2, y2)) in combinations(enumerate(points), 2)
        )
        f = list(range(n))

        def find(x):
            # Iterative two-pass path compression (no recursion-depth limit).
            root = x
            while f[root] != root:
                root = f[root]
            while f[x] != root:
                nxt = f[x]
                f[x] = root
                x = nxt
            return root

        def union(x, y):
            f[find(x)] = find(y)

        ans = used = 0
        for w, i, j in edges:
            if used == n - 1:
                break  # spanning tree complete; remaining edges are useless
            if find(i) != find(j):
                union(i, j)
                ans += w
                used += 1
        return ans
https://leetcode.com/problems/similar-string-groups/
class Solution:
    def numSimilarGroups(self, strs: List[str]) -> int:
        """Count the groups formed by linking similar strings.

        Two strings count as similar when a position-by-position comparison
        (via zip) finds at most two mismatches.  Similarity is made
        transitive with union-find, and the number of distinct roots is the
        number of groups.
        """
        parent = {word: word for word in strs}

        def root(word):
            # Path halving: each visited node skips to its grandparent.
            while parent[word] != word:
                parent[word] = parent[parent[word]]
                word = parent[word]
            return word

        def similar(a, b):
            mismatches = 0
            for ca, cb in zip(a, b):
                if ca != cb:
                    mismatches += 1
                    if mismatches > 2:
                        return False
            return True

        for a, b in combinations(strs, 2):
            if similar(a, b):
                parent[root(a)] = root(b)
        return len({root(word) for word in strs})
https://leetcode.com/problems/accounts-merge/
class Solution:
    def accountsMerge(self, accounts: List[List[str]]) -> List[List[str]]:
        """Merge accounts that share at least one email address.

        Each account is [name, email1, email2, ...].  Every email is unioned
        with the first email of its account, so accounts sharing any email
        end up in one group.  Each merged account is returned as the name
        followed by its emails in sorted order.
        """
        f = {}
        emailToName = {}

        def find(x):
            # Iterative two-pass path compression; a recursive find can exceed
            # Python's recursion limit when accounts chain their emails.
            root = x
            while f[root] != root:
                root = f[root]
            while f[x] != root:
                nxt = f[x]
                f[x] = root
                x = nxt
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for account in accounts:
            name = account[0]
            for email in account[1:]:
                if email not in f:
                    f[email] = email
                union(email, account[1])
                emailToName[email] = name
        groups = defaultdict(list)  # root email -> all emails in its group
        for email in f:
            groups[find(email)].append(email)
        return [[emailToName[root]] + sorted(emails) for root, emails in groups.items()]
https://leetcode.com/problems/lexicographically-smallest-equivalent-string/
class Solution:
    def smallestEquivalentString(self, s1: str, s2: str, baseStr: str) -> str:
        """Rewrite baseStr using the smallest character of each equivalence class.

        s1[i] and s2[i] are declared equivalent for each position zip pairs
        up, and equivalence is transitive.  Union keeps the smaller root as
        the class representative, so the root is always the class minimum.
        This works for any characters; a 'z' sentinel default would return
        'z' for classes whose members all sort after 'z'.  Characters of
        baseStr that never appear in s1/s2 are kept unchanged.
        """
        parent = {}

        def root(c):
            # Path halving: each visited node skips to its grandparent.
            while parent[c] != c:
                parent[c] = parent[parent[c]]
                c = parent[c]
            return c

        for a, b in zip(s1, s2):
            parent.setdefault(a, a)
            parent.setdefault(b, b)
            ra, rb = root(a), root(b)
            # Attach the larger root under the smaller one.
            if ra < rb:
                parent[rb] = ra
            else:
                parent[ra] = rb
        return "".join(root(c) if c in parent else c for c in baseStr)
https://leetcode.com/problems/smallest-string-with-swaps/
class Solution:
    def smallestStringWithSwaps(self, s: str, pairs: List[List[int]]) -> str:
        """Return the lexicographically smallest string reachable via the swaps.

        Indices linked, directly or transitively, by ``pairs`` can be permuted
        arbitrarily.  Each component's characters are therefore placed in
        sorted order over its indices.
        """
        n = len(s)
        f = list(range(n))

        def find(x):
            # Iterative two-pass path compression; a recursive find can exceed
            # Python's recursion limit when unions build a long parent chain.
            root = x
            while f[root] != root:
                root = f[root]
            while f[x] != root:
                nxt = f[x]
                f[x] = root
                x = nxt
            return root

        def union(x, y):
            f[find(x)] = find(y)

        for u, v in pairs:
            union(u, v)
        groups = defaultdict(list)  # root -> indices, appended in increasing order
        for i in range(n):
            groups[find(i)].append(i)
        ans = [""] * n
        for idx in groups.values():
            for i, ch in zip(idx, sorted(s[j] for j in idx)):
                ans[i] = ch
        return "".join(ans)
https://leetcode.com/problems/synonymous-sentences/
class Solution:
    def generateSentences(self, synonyms: List[List[str]], text: str) -> List[str]:
        """List every sentence obtainable by replacing words with synonyms.

        Synonymy is transitive, so words are grouped with union-find.  Each
        word of ``text`` that appears in some synonym pair may be replaced by
        any member of its group, itself included.  All combinations are
        returned in sorted order.
        """
        parent = {}

        def root(w):
            # Path halving: each visited node skips to its grandparent.
            while parent[w] != w:
                parent[w] = parent[parent[w]]
                w = parent[w]
            return w

        for a, b in synonyms:
            parent.setdefault(a, a)
            parent.setdefault(b, b)
            parent[root(a)] = root(b)

        groups = defaultdict(list)  # root -> every word in that synonym group
        for w in parent:
            groups[root(w)].append(w)

        # Grow the sentences one word at a time, branching on each option.
        sentences = [[]]
        for word in text.split():
            options = groups[root(word)] if word in parent else [word]
            sentences = [prefix + [choice] for prefix in sentences for choice in options]
        return sorted(" ".join(words) for words in sentences)