Data Structures
Union Find
Overview
Union Find (also called Disjoint Set Union or DSU) is a data structure that tracks elements partitioned into disjoint sets. It efficiently supports two operations:
- Find: Determine which set an element belongs to
- Union: Merge two sets into one
Use Cases: Detecting cycles in undirected graphs, finding connected components, Kruskal's minimum spanning tree (MST) algorithm, and dynamic connectivity problems where edges are only added, never removed.
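Kruskal's MST algorithm shows the typical usage pattern: sort the edges by weight, then keep an edge only when it joins two different sets.

The sketch below makes a few assumptions:
- Edges are given as `(weight, u, v)` tuples on vertices `0..n-1`. This input format is an assumption.
- `kruskal_mst_weight` is a hypothetical helper name.
- `find()` uses path halving, a one-pass variant of path compression, so it stays iterative.

```python
def kruskal_mst_weight(n: int, edges: list[tuple[int, int, int]]) -> int:
    """Total weight of a minimum spanning forest (hypothetical helper).

    Assumes edges are (weight, u, v) tuples on vertices 0..n-1.
    """
    parent = list(range(n))

    def find(x: int) -> int:
        # Path halving: point each visited node at its grandparent while walking up.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total = 0
    for w, u, v in sorted(edges):  # lightest edges first
        ru, rv = find(u), find(v)
        if ru != rv:  # the edge joins two different components, so keep it
            parent[ru] = rv
            total += w
    return total


edges = [(1, 0, 1), (2, 1, 2), (3, 0, 2), (4, 2, 3)]
print(kruskal_mst_weight(4, edges))  # 7: edge (3, 0, 2) is skipped because it would close a cycle
```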
Complexity
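| Operation | Without Rank | With Rank + Path Compression |
|---|---|---|
| Find | $O(n)$ worst case | $O(\alpha(n))$ amortized |
| Union | $O(n)$ worst case | $O(\alpha(n))$ amortized |

Here $\alpha(n)$ is the inverse Ackermann function. It grows so slowly that $\alpha(n) \le 4$ for any practical input size, so the bound is effectively constant.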
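The $O(n)$ entry comes from tree height: without rank, a sequence of unions can build one long chain.

The sketch below compares tree depth with and without union by rank on the input `union(i, i + 1)`, which happens to be a worst case for the naive rule. Assumptions:
- `find_root`, `naive_union`, `ranked_union` and `max_depth` are hypothetical helper names.
- Path compression is deliberately left out so the tree shapes stay visible.

```python
def find_root(parent: list[int], x: int) -> int:
    """Follow parent links to the root (no path compression, so shapes are preserved)."""
    while parent[x] != x:
        x = parent[x]
    return x


def max_depth(parent: list[int]) -> int:
    """Longest chain of parent links from any node up to its root."""
    best = 0
    for start in range(len(parent)):
        node, d = start, 0
        while parent[node] != node:
            node = parent[node]
            d += 1
        best = max(best, d)
    return best


def naive_union(parent: list[int], x: int, y: int) -> None:
    """Neither optimization: x's root always goes under y's root."""
    rx, ry = find_root(parent, x), find_root(parent, y)
    if rx != ry:
        parent[rx] = ry


def ranked_union(parent: list[int], rank: list[int], x: int, y: int) -> None:
    """Union by rank: the lower-rank root goes under the higher-rank root."""
    rx, ry = find_root(parent, x), find_root(parent, y)
    if rx == ry:
        return
    if rank[rx] < rank[ry]:
        rx, ry = ry, rx  # make rx the root with the larger (or equal) rank
    parent[ry] = rx
    if rank[rx] == rank[ry]:
        rank[rx] += 1


n = 1024
naive = list(range(n))
ranked, rank = list(range(n)), [1] * n
for i in range(n - 1):
    naive_union(naive, i, i + 1)
    ranked_union(ranked, rank, i, i + 1)

print(max_depth(naive))   # 1023: a single chain 0 -> 1 -> ... -> 1023
print(max_depth(ranked))  # 1: every node hangs directly off node 0
```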
Optimizations
- Path Compression: During `find()`, make every node on the search path point directly to the root
- Union by Rank: Attach the root of the lower-rank (shorter) tree under the root of the higher-rank tree
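A quick way to see path compression at work is to call `find()` once on a hand-built chain and then inspect the parent list. The sketch below is a standalone recursive `find` on a plain list, written in the same style as the `find()` in the Implementation section. The function name and the example chain are illustrative only.

```python
def find(parent: list[int], x: int) -> int:
    # Recursive find with path compression: once the call returns,
    # every node visited on the way up points directly at the root.
    if parent[x] != x:
        parent[x] = find(parent, parent[x])
    return parent[x]


# A hand-built chain 4 -> 3 -> 2 -> 1 -> 0, where 0 is the root.
parent = [0, 0, 1, 2, 3]
print(find(parent, 4))  # 0
print(parent)           # [0, 0, 0, 0, 0]: the whole path now points at the root
```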
Implementation
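Path compression only:

```python
class UnionFind:
    def __init__(self, size: int) -> None:
        # Every element starts as the root of its own singleton set.
        self.root = list(range(size))

    def find(self, x: int) -> int:
        # Path compression: x (and, through the recursion, every node
        # above it) ends up pointing directly at the root.
        rx = self.root[x]
        if rx != x:
            rx = self.find(rx)
            self.root[x] = rx
        return rx

    def union(self, x: int, y: int) -> bool:
        rx = self.find(x)
        ry = self.find(y)
        if rx != ry:
            self.root[rx] = ry  # hang x's tree under y's root
            return True
        return False  # already in the same set
```

Path compression + union by rank:

```python
class UnionFind:
    def __init__(self, size: int) -> None:
        self.root = list(range(size))
        # rank is an upper bound on tree height; singletons start at 1.
        self.rank = [1] * size

    def find(self, x: int) -> int:
        # Path compression, as in the version above.
        rx = self.root[x]
        if rx != x:
            rx = self.find(rx)
            self.root[x] = rx
        return rx

    def union(self, x: int, y: int) -> bool:
        rx = self.find(x)
        ry = self.find(y)
        if rx != ry:
            # Attach the lower-rank root under the higher-rank root.
            if self.rank[rx] < self.rank[ry]:
                self.root[rx] = ry
            elif self.rank[rx] > self.rank[ry]:
                self.root[ry] = rx
            else:
                # Equal ranks: rx becomes the root and its tree grows by one level.
                self.root[ry] = rx
                self.rank[rx] += 1
            return True
        return False  # already in the same set
```

Example Usage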
```python
# Initialize with 5 elements (0-4)
uf = UnionFind(5)

# Union operations
uf.union(0, 1)  # Connect 0 and 1
uf.union(1, 2)  # Connect 1 and 2
uf.union(3, 4)  # Connect 3 and 4

# Find operations
uf.find(0)  # Returns the root of the set containing 0
uf.find(2)  # Returns the same root as find(0)
uf.find(3)  # Returns a different root

# Check whether two elements are connected
uf.find(0) == uf.find(2)  # True (same set)
uf.find(0) == uf.find(3)  # False (different sets)
```

Common Patterns
Cycle Detection
```python
def has_cycle(n: int, edges: list[list[int]]) -> bool:
    # Undirected graph on vertices 0..n-1.
    uf = UnionFind(n)
    for u, v in edges:
        if not uf.union(u, v):  # u and v already connected: this edge closes a cycle
            return True
    return False
```

Count Components
```python
def count_components(n: int, edges: list[list[int]]) -> int:
    uf = UnionFind(n)
    for u, v in edges:
        uf.union(u, v)
    # Each component has exactly one root, so count the distinct roots.
    return len({uf.find(i) for i in range(n)})
```

Key Points
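- Well suited to connectivity problems on undirected graphs
- Nearly constant amortized time per operation, $O(\alpha(n))$, with both optimizations
- `union()` returns `False` if the elements are already connected (useful for cycle detection)
- Path compression is performed inside `find()`; `union()` benefits only because it calls `find()` on both arguments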
Best Practice
In competitive programming, use the version with both union by rank and path compression. Together they give $O(\alpha(n))$ amortized time per operation.
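The recursive `find()` in the Implementation section recurses once per tree level. With union by rank the trees stay shallow. In the path-compression-only version, however, unions such as `union(0, 1)`, `union(1, 2)`, ... build a chain thousands of levels deep. The first `find()` on that chain can then exceed CPython's default recursion limit of 1000 (see `sys.getrecursionlimit()`).

A loop-based two-pass `find()` sidesteps that. The sketch below is one way to write it, combined with union by rank. `IterativeUnionFind` is a hypothetical name, not part of the original notes.

```python
class IterativeUnionFind:
    """Union by rank + path compression with a loop-based find() (hypothetical name)."""

    def __init__(self, size: int) -> None:
        self.root = list(range(size))
        self.rank = [1] * size

    def find(self, x: int) -> int:
        # Pass 1: walk up to the root.
        root = x
        while self.root[root] != root:
            root = self.root[root]
        # Pass 2: path compression, repointing every node on the path at the root.
        while self.root[x] != root:
            parent = self.root[x]
            self.root[x] = root
            x = parent
        return root

    def union(self, x: int, y: int) -> bool:
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False  # already in the same set
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx  # make rx the root with the larger (or equal) rank
        self.root[ry] = rx
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        return True


uf = IterativeUnionFind(3)
print(uf.union(0, 1))  # True
print(uf.union(1, 0))  # False: already connected
```

Because no call stack is involved, `find()` works the same on a tree of any depth. It also compresses the whole path in one call, just like the recursive version.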