26 December 2022

Introduction to Union-Find Algorithm

Union-find, also known as the disjoint-set data structure, keeps track of a collection of items partitioned into non-overlapping groups. It provides two main operations: Find and Union.

Find:

The Find operation returns the representative (root) of the group an item belongs to. Two items are in the same group exactly when Find returns the same root for both, so you might use Find to check whether two nodes are connected in a graph, or whether two nodes have the same root in a tree.

f = list(range(n))  # f[x] is the parent of x; initially every item is its own root
def find(x):
    """Return the root of x's group, re-pointing every node on the path at it.

    Written iteratively: a recursive `f[x] = find(f[x])` recurses once per
    parent link, so a long chain of parents exceeds Python's default
    recursion limit (1000) and raises RecursionError.
    """
    root = x
    while f[root] != root:  # climb to the root
        root = f[root]
    while f[x] != root:  # path compression: link each visited node to the root
        f[x], x = root, f[x]
    return root

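f[x]: parent of x (initially x itself, so every item starts in its own group)
find(x): root (representative) of x's group; every node on the path to the root is re-pointed directly at the root (path compression)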
Union:

The Union operation is used to merge two groups of items. For example, you might use the Union operation to merge two connected components in a graph, or to merge two subtrees in a tree.

def union(x, y):
    """Merge the group of x into the group of y.

    The root of x's group is re-parented under the root of y's group.
    """
    y_root = find(y)
    x_root = find(x)
    f[x_root] = y_root

union(x, y): merge x's group into y's group by making the root of x point to the root of y

Find and Union are closely related and are usually used together in algorithms that involve connectivity. Below is a list of LeetCode union-find problems with my solutions to help you better understand the topic. Happy coding!😄

https://leetcode.com/problems/redundant-connection/

class Solution:
  def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
      """Return the first edge, in input order, whose endpoints are already connected.

      For a tree on nodes 1..n plus one extra edge, that edge closes the only
      cycle and is the last cycle edge in the input. If no edge closes a
      cycle the loop falls through and None is returned.
      """
      n = len(edges)
      f = list(range(n + 1))  # nodes are labelled 1..n; index 0 is unused

      def find(x):
          # Iterative find with path compression: a recursive version recurses
          # once per parent link and can exceed Python's recursion limit.
          root = x
          while f[root] != root:
              root = f[root]
          while f[x] != root:
              f[x], x = root, f[x]
          return root

      for u, v in edges:
          root_u, root_v = find(u), find(v)
          if root_u == root_v:
              return [u, v]
          f[root_u] = root_v  # union: attach u's tree under v's root

https://leetcode.com/problems/graph-valid-tree/

class Solution:
  def validTree(self, n: int, edges: List[List[int]]) -> bool:
      """Return True if the undirected graph on nodes 0..n-1 is a tree.

      Fewer than n - 1 edges cannot connect n nodes. Otherwise, any edge whose
      endpoints are already connected closes a cycle. If every edge joins two
      separate components, the graph is an acyclic graph with at least
      n - 1 edges, hence exactly n - 1 edges and connected: a tree.
      """
      if len(edges) < n - 1:
          return False

      f = list(range(n))

      def find(x):
          # Iterative find with path compression: a recursive version recurses
          # once per parent link and can exceed Python's recursion limit.
          root = x
          while f[root] != root:
              root = f[root]
          while f[x] != root:
              f[x], x = root, f[x]
          return root

      for u, v in edges:
          root_u, root_v = find(u), find(v)
          if root_u == root_v:
              return False  # cycle
          f[root_u] = root_v
      return True

https://leetcode.com/problems/count-unreachable-pairs-of-nodes-in-an-undirected-graph/

class Solution:
  def countPairs(self, n: int, edges: List[List[int]]) -> int:
      """Count unordered node pairs (i, j) that lie in different components."""
      f = list(range(n))

      def find(x):
          # Iterative find with path compression: a recursive version recurses
          # once per parent link and can exceed Python's recursion limit.
          root = x
          while f[root] != root:
              root = f[root]
          while f[x] != root:
              f[x], x = root, f[x]
          return root

      for u, v in edges:
          f[find(u)] = find(v)

      sizes = defaultdict(int)  # component root -> number of nodes
      for i in range(n):
          sizes[find(i)] += 1

      # Every node of a component pairs with every node counted before it.
      ans = seen = 0
      for size in sizes.values():
          ans += seen * size
          seen += size
      return ans

https://leetcode.com/problems/connecting-cities-with-minimum-cost/

class Solution:
  def minimumCost(self, n: int, connections: List[List[int]]) -> int:
      """Return the minimum cost to connect cities 1..n, or -1 if impossible.

      Kruskal's algorithm: take connections cheapest first and keep each one
      that joins two different components.
      """
      f = list(range(n))  # city c is stored at index c - 1

      def find(x):
          # Iterative find with path compression: a recursive version recurses
          # once per parent link and can exceed Python's recursion limit.
          root = x
          while f[root] != root:
              root = f[root]
          while f[x] != root:
              f[x], x = root, f[x]
          return root

      components = n
      ans = 0
      # sorted() instead of list.sort() so the caller's list is not reordered.
      for u, v, w in sorted(connections, key=lambda c: c[2]):
          root_u, root_v = find(u - 1), find(v - 1)
          if root_u != root_v:
              f[root_u] = root_v
              ans += w
              components -= 1
              if components == 1:
                  break  # spanning tree complete; remaining edges cannot help
      return ans if components == 1 else -1

https://leetcode.com/problems/most-stones-removed-with-same-row-or-column/

class Solution:
  def removeStones(self, stones: List[List[int]]) -> int:
      """Return the maximum number of stones that can be removed.

      Rows and columns are graph nodes and each stone joins its row to its
      column. Each connected component can be reduced to a single stone, so
      the answer is the number of stones minus the number of components.
      """
      f = {}

      def find(x):
          # Iterative find with path compression: a recursive version recurses
          # once per parent link and can exceed Python's recursion limit.
          root = x
          while f[root] != root:
              root = f[root]
          while f[x] != root:
              f[x], x = root, f[x]
          return root

      for x, y in stones:
          # Tagged keys: a row index can never collide with a column index,
          # whatever the coordinate range.
          row, col = ('r', x), ('c', y)
          f.setdefault(row, row)
          f.setdefault(col, col)
          f[find(row)] = find(col)

      components = sum(key == parent for key, parent in f.items())
      return len(stones) - components

https://leetcode.com/problems/minimize-hamming-distance-after-swap-operations/

class Solution:
  def minimumHammingDistance(self, source: List[int], target: List[int], allowedSwaps: List[List[int]]) -> int:
      """Return the minimum Hamming distance between source and target after swaps.

      Indices connected through allowedSwaps form groups whose values can be
      rearranged freely, so each group contributes the number of target
      values it cannot supply from its own source values.
      """
      n = len(source)
      f = list(range(n))

      def find(x):
          # Iterative find with path compression: a recursive version recurses
          # once per parent link and can exceed Python's recursion limit.
          root = x
          while f[root] != root:
              root = f[root]
          while f[x] != root:
              f[x], x = root, f[x]
          return root

      for u, v in allowedSwaps:
          f[find(u)] = find(v)

      groups = defaultdict(list)  # root -> indices in that group
      for i in range(n):
          groups[find(i)].append(i)

      ans = 0
      for indices in groups.values():
          # +1 per value target needs, -1 per value source offers; the total
          # absolute imbalance is twice the number of unmatched positions.
          balance = defaultdict(int)
          for i in indices:
              balance[target[i]] += 1
              balance[source[i]] -= 1
          ans += sum(abs(c) for c in balance.values()) // 2
      return ans

https://leetcode.com/problems/satisfiability-of-equality-equations/

class Solution:
  def equationsPossible(self, equations: List[str]) -> bool:
      """Return True if every "x==y" and "x!=y" equation can hold at once.

      Each equation string has its variables at positions 0 and 3 and its
      operator marker ("=" or "!") at position 1.
      """
      parent = {}
      for eq in equations:
          parent[eq[0]] = eq[0]
          parent[eq[3]] = eq[3]

      def root(var):
          if parent[var] != var:
              parent[var] = root(parent[var])
          return parent[var]

      equal = [eq for eq in equations if eq[1] == "="]
      unequal = [eq for eq in equations if eq[1] == "!"]

      # First merge everything the equalities force together...
      for eq in equal:
          right_root = root(eq[3])
          parent[root(eq[0])] = right_root

      # ...then no inequality may relate two variables of the same group.
      return all(root(eq[0]) != root(eq[3]) for eq in unequal)

https://leetcode.com/problems/sentence-similarity-ii/

class Solution:
  def areSentencesSimilarTwo(self, sentence1: List[str], sentence2: List[str], similarPairs: List[List[str]]) -> bool:
      """Return True if the sentences are word-by-word similar.

      Similarity is transitive through similarPairs; identical words are
      always similar, even when they appear in no pair.
      """
      if len(sentence1) != len(sentence2):
          return False

      f = {}

      def find(x):
          # Iterative find with path compression: a recursive version recurses
          # once per parent link and can exceed Python's recursion limit.
          root = x
          while f[root] != root:
              root = f[root]
          while f[x] != root:
              f[x], x = root, f[x]
          return root

      for u, v in similarPairs:
          f.setdefault(u, u)
          f.setdefault(v, v)
          f[find(u)] = find(v)

      for a, b in zip(sentence1, sentence2):
          if a == b:
              continue
          if a not in f or b not in f or find(a) != find(b):
              return False
      return True

https://leetcode.com/problems/min-cost-to-connect-all-points/

class Solution:
  def minCostConnectPoints(self, points: List[List[int]]) -> int:
      """Return the minimum total Manhattan distance connecting all points.

      Kruskal's algorithm over the complete graph, using point indices as
      union-find nodes rather than string-converted coordinates.
      """
      n = len(points)
      edges = sorted(
          (abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1]), i, j)
          for i, j in combinations(range(n), 2)
      )

      f = list(range(n))

      def find(x):
          # Iterative find with path compression: a recursive version recurses
          # once per parent link and can exceed Python's recursion limit.
          root = x
          while f[root] != root:
              root = f[root]
          while f[x] != root:
              f[x], x = root, f[x]
          return root

      ans, needed = 0, n - 1  # a spanning tree on n points has n - 1 edges
      for w, u, v in edges:
          if needed <= 0:
              break
          root_u, root_v = find(u), find(v)
          if root_u != root_v:
              f[root_u] = root_v
              ans += w
              needed -= 1
      return ans

https://leetcode.com/problems/similar-string-groups/

class Solution:
  def numSimilarGroups(self, strs: List[str]) -> int:
      """Count groups of strings linked by similarity.

      Two strings are similar when they differ in at most two positions;
      groups are the connected components of that relation.
      """
      parent = {word: word for word in strs}

      def root(word):
          if parent[word] != word:
              parent[word] = root(parent[word])
          return parent[word]

      def similar(a, b):
          mismatches = 0
          for ca, cb in zip(a, b):
              if ca != cb:
                  mismatches += 1
          return mismatches <= 2

      # Visit every unordered pair (i < j) in input order.
      for i, first in enumerate(strs):
          for second in strs[i + 1:]:
              if similar(first, second):
                  second_root = root(second)
                  parent[root(first)] = second_root

      roots = {root(word) for word in strs}
      return len(roots)

https://leetcode.com/problems/accounts-merge/

class Solution:
  def accountsMerge(self, accounts: List[List[str]]) -> List[List[str]]:
      """Merge accounts that share at least one email.

      Each input account is [name, email, ...]. Returns one
      [name, *sorted_emails] entry per merged group.
      """
      f, email_to_name = {}, {}

      def find(x):
          # Iterative find with path compression: a recursive version recurses
          # once per parent link and can exceed Python's recursion limit.
          root = x
          while f[root] != root:
              root = f[root]
          while f[x] != root:
              f[x], x = root, f[x]
          return root

      for name, *emails in accounts:
          for email in emails:
              f.setdefault(email, email)
              # Link every email of this account to the account's first email.
              f[find(email)] = find(emails[0])
              email_to_name[email] = name

      groups = defaultdict(list)  # root email -> all emails in that group
      for email in f:
          groups[find(email)].append(email)

      return [[email_to_name[root]] + sorted(members) for root, members in groups.items()]

https://leetcode.com/problems/lexicographically-smallest-equivalent-string/

class Solution:
  def smallestEquivalentString(self, s1: str, s2: str, baseStr: str) -> str:
      """Rewrite baseStr using the smallest letter equivalent to each character.

      s1[i] and s2[i] are declared equivalent; equivalence is transitive.
      Characters that never appear in s1 or s2 are kept unchanged.
      """
      parent = {}

      def root(ch):
          if parent[ch] != ch:
              parent[ch] = root(parent[ch])
          return parent[ch]

      for a, b in zip(s1, s2):
          parent.setdefault(a, a)
          parent.setdefault(b, b)
          b_root = root(b)
          parent[root(a)] = b_root

      smallest = {}  # group root -> smallest member seen (starting from 'z')
      for ch in parent:
          r = root(ch)
          smallest[r] = min(smallest.get(r, 'z'), ch)

      return ''.join(smallest[root(ch)] if ch in parent else ch for ch in baseStr)

https://leetcode.com/problems/smallest-string-with-swaps/

class Solution:
  def smallestStringWithSwaps(self, s: str, pairs: List[List[int]]) -> str:
      """Return the lexicographically smallest string reachable via the swaps.

      Indices connected through pairs can be permuted freely, so each group's
      characters are sorted and placed back into its indices in order.
      """
      n = len(s)
      f = list(range(n))

      def find(x):
          # Iterative find with path compression: a recursive version recurses
          # once per parent link and can exceed Python's recursion limit.
          root = x
          while f[root] != root:
              root = f[root]
          while f[x] != root:
              f[x], x = root, f[x]
          return root

      for u, v in pairs:
          f[find(u)] = find(v)

      groups = defaultdict(list)  # root -> indices, appended in increasing order
      for i in range(n):
          groups[find(i)].append(i)

      ans = [''] * n
      for indices in groups.values():
          for i, ch in zip(indices, sorted(s[j] for j in indices)):
              ans[i] = ch
      return ''.join(ans)

https://leetcode.com/problems/synonymous-sentences/

class Solution:
  def generateSentences(self, synonyms: List[List[str]], text: str) -> List[str]:
      """Return every sentence obtainable by synonym substitution, sorted.

      Synonymy is transitive: words linked through any chain of pairs share a
      group, and each occurrence of a grouped word may become any member of
      its group.
      """
      parent = {}

      def root(word):
          if parent[word] != word:
              parent[word] = root(parent[word])
          return parent[word]

      for a, b in synonyms:
          parent.setdefault(a, a)
          parent.setdefault(b, b)
          b_root = root(b)
          parent[root(a)] = b_root

      members = defaultdict(list)  # group root -> words in that group
      for word in parent:
          members[root(word)].append(word)

      words = text.split()
      sentences = [list(words)]
      for pos, word in enumerate(words):
          if word not in parent:
              continue
          choices = members[root(word)]
          # Expand every partial sentence with each choice at this position.
          sentences = [
              sentence[:pos] + [choice] + sentence[pos + 1:]
              for sentence in sentences
              for choice in choices
          ]
      return sorted(' '.join(sentence) for sentence in sentences)