Two Pointer Technique — Deep Dive

Mathematical foundation

The two pointer technique works on problems where the search space can be pruned by a monotonic relationship. For converging pointers on a sorted array:

Given sorted array a[0] ≤ a[1] ≤ ... ≤ a[n-1] and target T:

  • If a[L] + a[R] < T, then a[L] + a[j] ≤ a[L] + a[R] < T for every j ≤ R, so every pair (L, j) can be eliminated → move L right.
  • If a[L] + a[R] > T, then a[i] + a[R] ≥ a[L] + a[R] > T for every i ≥ L, so every pair (i, R) can be eliminated → move R left.

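Each move discards every remaining pair that involves the pointer being moved (up to O(n) pairs at once), so the pointers meet after at most n - 1 moves. This reduces the O(n²) scan over all pairs to O(n).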
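
The argument above can be sketched as a short function. This is a minimal illustration, assuming the input is sorted in non-decreasing order; the name two_sum_sorted is hypothetical, and it returns the first matching index pair it encounters, or None.

```python
def two_sum_sorted(a, target):
    """Return indices (L, R) with a[L] + a[R] == target, or None.

    Assumes `a` is sorted in non-decreasing order.
    """
    left, right = 0, len(a) - 1
    while left < right:
        s = a[left] + a[right]
        if s == target:
            return left, right
        if s < target:
            # a[left] + a[j] <= s < target for every j <= right:
            # no remaining partner works for `left`, so drop it.
            left += 1
        else:
            # a[i] + a[right] >= s > target for every i >= left:
            # no remaining partner works for `right`, so drop it.
            right -= 1
    return None

print(two_sum_sorted([1, 2, 4, 7, 11, 15], 15))  # (2, 4)
```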

Dutch National Flag problem

Partition an array into three groups using three pointers. This is the foundation of quicksort’s three-way partition.

def sort_colors(nums):
    """Sort array of 0s, 1s, and 2s in-place."""
    low = 0          # boundary for 0s
    mid = 0          # current element
    high = len(nums) - 1  # boundary for 2s

    while mid <= high:
        if nums[mid] == 0:
            nums[low], nums[mid] = nums[mid], nums[low]
            low += 1
            mid += 1
        elif nums[mid] == 1:
            mid += 1
        else:  # nums[mid] == 2
            nums[mid], nums[high] = nums[high], nums[mid]
            high -= 1
            # Note: mid does NOT advance — the swapped element needs checking

    return nums

print(sort_colors([2, 0, 2, 1, 1, 0]))  # [0, 0, 1, 1, 2, 2]

Invariants maintained:

  • nums[0:low] contains only 0s
  • nums[low:mid] contains only 1s
  • nums[high+1:] contains only 2s
  • nums[mid:high+1] is unprocessed

Time: O(n) single pass. Space: O(1).
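
To make the invariants concrete, the variant below re-checks them after every iteration. It is a debugging sketch only: the names sort_colors_checked and check are illustrative, and because each check scans the list, it runs in O(n²) rather than O(n).

```python
def sort_colors_checked(nums):
    """Dutch National Flag partition that asserts its invariants each step."""
    low, mid, high = 0, 0, len(nums) - 1

    def check():
        # Reads the current low/mid/high from the enclosing scope.
        assert all(x == 0 for x in nums[:low])       # nums[0:low] is all 0s
        assert all(x == 1 for x in nums[low:mid])    # nums[low:mid] is all 1s
        assert all(x == 2 for x in nums[high + 1:])  # nums[high+1:] is all 2s

    check()
    while mid <= high:
        if nums[mid] == 0:
            nums[low], nums[mid] = nums[mid], nums[low]
            low += 1
            mid += 1
        elif nums[mid] == 1:
            mid += 1
        else:
            nums[mid], nums[high] = nums[high], nums[mid]
            high -= 1
        check()
    return nums

print(sort_colors_checked([2, 0, 2, 1, 1, 0]))  # [0, 0, 1, 1, 2, 2]
```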

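Why mid does not advance when swapping with high: the element swapped in from high has not been examined yet, so it could be 0, 1, or 2 and must be checked on the next iteration. The element swapped in from low has already been examined: when low < mid it is a 1 (by the invariant that nums[low:mid] contains only 1s), and when low == mid the swap is a no-op and the element is the 0 just found. Either way, mid can safely advance.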
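
The same partition drives three-way quicksort: elements less than, equal to, and greater than a pivot play the roles of 0s, 1s, and 2s. Below is a minimal sketch, assuming a random pivot choice (a common option, not something the partition requires); quicksort_3way is a hypothetical name. Only the < and > regions are recursed into, so all keys equal to the pivot are finished in a single pass.

```python
import random

def quicksort_3way(nums, lo=0, hi=None):
    """In-place quicksort built on a Dutch-National-Flag style partition."""
    if hi is None:
        hi = len(nums) - 1
    if lo >= hi:
        return nums
    pivot = nums[random.randint(lo, hi)]  # random pivot (an assumption)
    low, mid, high = lo, lo, hi
    while mid <= high:
        if nums[mid] < pivot:        # plays the role of a 0
            nums[low], nums[mid] = nums[mid], nums[low]
            low += 1
            mid += 1
        elif nums[mid] == pivot:     # plays the role of a 1
            mid += 1
        else:                        # plays the role of a 2
            nums[mid], nums[high] = nums[high], nums[mid]
            high -= 1
    # Now nums[lo:low] < pivot, nums[low:high+1] == pivot, nums[high+1:hi+1] > pivot.
    quicksort_3way(nums, lo, low - 1)
    quicksort_3way(nums, high + 1, hi)
    return nums

print(quicksort_3way([3, 1, 2, 3, 3, 0, 5]))  # [0, 1, 2, 3, 3, 3, 5]
```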

Trapping Rain Water

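Common approaches include precomputed prefix/suffix maximum arrays, a monotonic stack, and two pointers. The first two need O(n) extra space; the two-pointer solution runs in O(n) time with O(1) extra space.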

def trap(height):
    if not height:
        return 0

    left, right = 0, len(height) - 1
    left_max, right_max = height[left], height[right]
    water = 0

    while left < right:
        if left_max < right_max:
            left += 1
            left_max = max(left_max, height[left])
            water += left_max - height[left]
        else:
            right -= 1
            right_max = max(right_max, height[right])
            water += right_max - height[right]

    return water

print(trap([0, 1, 0, 2, 1, 0, 1, 3, 2, 1, 2, 1]))  # 6

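Why it works: the water above position i is min(max_left, max_right) - height[i], where max_left and max_right are the tallest bars at or to the left and at or to the right of i. When left_max < right_max, the position the left pointer moves to still has the bar at right somewhere to its right, so its max_right is at least right_max (and at least its own height), which is at least the updated left_max. Its water level is therefore the updated left_max, whatever lies between the pointers, and the loop can add left_max - height[left] without knowing the exact max_right. The case left_max ≥ right_max is symmetric, so each step settles one position in O(1) time.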
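
For comparison, the formula can be applied directly with precomputed prefix and suffix maxima, the O(n)-space approach. The sketch below (trap_prefix is an illustrative name) computes exactly the min(max_left, max_right) - height[i] terms that the two-pointer loop settles one at a time without the arrays.

```python
def trap_prefix(height):
    """Trapped water via precomputed prefix/suffix maxima (O(n) extra space)."""
    n = len(height)
    if n == 0:
        return 0
    left_max = [0] * n   # left_max[i] = max(height[0..i])
    right_max = [0] * n  # right_max[i] = max(height[i..n-1])
    left_max[0] = height[0]
    for i in range(1, n):
        left_max[i] = max(left_max[i - 1], height[i])
    right_max[n - 1] = height[n - 1]
    for i in range(n - 2, -1, -1):
        right_max[i] = max(right_max[i + 1], height[i])
    # Water above each bar: the lower of the two maxima minus the bar's height.
    return sum(min(left_max[i], right_max[i]) - height[i] for i in range(n))

print(trap_prefix([0, 1, 0, 2, 1, 0, 1, 3, 2, 1, 2, 1]))  # 6
```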

4Sum

Generalize 3Sum with an additional loop:

def four_sum(nums, target):
    nums.sort()
    n = len(nums)
    result = []

    for i in range(n - 3):
        if i > 0 and nums[i] == nums[i-1]:
            continue
        # Early termination: even the smallest 4-sum using nums[i] exceeds target
        if nums[i] + nums[i+1] + nums[i+2] + nums[i+3] > target:
            break
        # Skip: if largest possible sum with nums[i] is too small
        if nums[i] + nums[n-1] + nums[n-2] + nums[n-3] < target:
            continue

        for j in range(i + 1, n - 2):
            if j > i + 1 and nums[j] == nums[j-1]:
                continue
            if nums[i] + nums[j] + nums[j+1] + nums[j+2] > target:
                break
            if nums[i] + nums[j] + nums[n-1] + nums[n-2] < target:
                continue

            left, right = j + 1, n - 1
            while left < right:
                total = nums[i] + nums[j] + nums[left] + nums[right]
                if total == target:
                    result.append([nums[i], nums[j], nums[left], nums[right]])
                    while left < right and nums[left] == nums[left + 1]:
                        left += 1
                    while left < right and nums[right] == nums[right - 1]:
                        right -= 1
                    left += 1
                    right -= 1
                elif total < target:
                    left += 1
                else:
                    right -= 1

    return result

print(four_sum([1, 0, -1, 0, -2, 2], 0))
# [[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]

Time: O(n³) in the worst case. The early-termination checks do not change that bound, but they can skip whole ranges of i and j when the remaining candidates are all too large or all too small.

Generalization to K-Sum: For K-Sum, use K-2 nested loops + two pointers, giving O(n^(K-1)) time.
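
A recursive sketch of that generalization: each level fixes one element and recurses with K - 1, until K = 2 is handled by the converging two-pointer scan. It is illustrative (k_sum is a hypothetical name), assumes K ≥ 2, and omits the min/max pruning that four_sum uses, so it has the same O(n^(K-1)) worst case but may be slower in practice.

```python
def k_sum(nums, target, k):
    """Return all unique k-element combinations (as sorted lists) summing to target."""
    nums = sorted(nums)  # sorted copy; the caller's list is left unchanged

    def search(start, need, remaining):
        # All `need`-element combinations from nums[start:] summing to `remaining`.
        if need == 2:
            # Base case: converging two-pointer scan.
            pairs = []
            left, right = start, len(nums) - 1
            while left < right:
                s = nums[left] + nums[right]
                if s < remaining:
                    left += 1
                elif s > remaining:
                    right -= 1
                else:
                    pairs.append([nums[left], nums[right]])
                    left += 1
                    right -= 1
                    # Skip repeated left values so each pair is reported once.
                    while left < right and nums[left] == nums[left - 1]:
                        left += 1
            return pairs

        results = []
        # Leave at least need - 1 elements after position i.
        for i in range(start, len(nums) - need + 1):
            if i > start and nums[i] == nums[i - 1]:
                continue  # same first value would repeat the same combinations
            for rest in search(i + 1, need - 1, remaining - nums[i]):
                results.append([nums[i]] + rest)
        return results

    if k < 2 or len(nums) < k:
        return []
    return search(0, k, target)

print(k_sum([1, 0, -1, 0, -2, 2], 0, 4))
# [[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]
```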

Linked list intersection

Find the node where two singly linked lists intersect.

def get_intersection_node(headA, headB):
    if not headA or not headB:
        return None

    a, b = headA, headB

    while a is not b:
        a = a.next if a else headB
        b = b.next if b else headA

    return a  # either intersection node or None

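Why it works: Let list A have length m + c and list B have length n + c, where c is the length of the shared tail. Pointer a traverses A then B, covering m + c + n nodes before it reaches the intersection; pointer b traverses B then A, covering n + c + m nodes. Both cover the same distance (each also spends one step on None when it switches lists), so they arrive at the intersection on the same step. If m = n, they meet even earlier, during the first pass.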

If there is no intersection (c = 0), each pointer walks both lists once, m + n nodes in total, and both become None on the same step. The loop then exits with a is b (both None), and the function returns None.
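
The same equal-distance idea can be made explicit by measuring both lengths first, then advancing the longer list's pointer by |(m + c) - (n + c)| = |m - n| nodes so that both pointers are the same distance from the end. The sketch below assumes a minimal ListNode class with val and next attributes (defined here for illustration; the original code only relies on .next), and the helper names are hypothetical. It makes extra passes to compute lengths but never switches lists.

```python
class ListNode:
    def __init__(self, val, next=None):
        self.val = val
        self.next = next

def length(node):
    n = 0
    while node:
        n += 1
        node = node.next
    return n

def get_intersection_by_length(headA, headB):
    """Align the two lists by length, then walk them in lockstep."""
    lenA, lenB = length(headA), length(headB)
    a, b = headA, headB
    # Skip the longer list's extra prefix (range() is empty for negative counts).
    for _ in range(lenA - lenB):
        a = a.next
    for _ in range(lenB - lenA):
        b = b.next
    # Both pointers are now the same number of nodes from the end.
    while a is not b:
        a = a.next
        b = b.next
    return a  # intersection node, or None if the lists never meet

shared = ListNode(8, ListNode(4, ListNode(5)))
headA = ListNode(4, ListNode(1, shared))
headB = ListNode(5, ListNode(6, ListNode(1, shared)))
print(get_intersection_by_length(headA, headB).val)  # 8
```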

Cycle detection: finding the cycle start

Floyd’s algorithm finds not just whether a cycle exists, but where it starts:

def detect_cycle(head):
    slow = fast = head

    # Phase 1: Detect cycle
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next
        if slow is fast:
            break
    else:  # runs only if the loop ended without break, i.e. fast reached the end
        return None  # no cycle

    # Phase 2: Find cycle start
    slow = head
    while slow is not fast:
        slow = slow.next
        fast = fast.next  # now both move at same speed

    return slow  # cycle start node

Mathematical proof: Let the distance from head to cycle start be a, cycle length be L, and the distance from cycle start to meeting point be b. At the meeting point:

  • Slow has traveled: a + b
  • Fast has traveled: a + b + kL (for some integer k ≥ 1)
  • Fast travels twice as far as slow: 2(a + b) = a + b + kL, so a + b = kL, and therefore a = kL - b

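Now start one pointer at head (a steps from the cycle start) and another at the meeting point (L - b steps from the cycle start), and advance both one node at a time. Since a = kL - b = (k-1)L + (L - b), after a steps the head pointer reaches the cycle start, while the meeting-point pointer has walked L - b steps to the cycle start and then (k-1) full laps, so it is at the cycle start as well. The two cannot meet earlier, because until step a the head pointer is still on the non-cyclic prefix, which the other pointer never visits.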
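
The derivation can be checked numerically. The sketch below builds a list with a tail of a nodes leading into a cycle of L nodes, runs Phase 1 to find the meeting point, and asserts both that a + b = kL and that Phase 2 reaches the cycle start after exactly a steps. ListNode, build_rho, floyd_meeting, and check_floyd are hypothetical helpers written only for this check.

```python
class ListNode:
    def __init__(self, val):
        self.val = val
        self.next = None

def build_rho(a, L):
    """Build a tail of a nodes followed by a cycle of L >= 1 nodes.

    Returns (head, cycle_start).
    """
    nodes = [ListNode(i) for i in range(a + L)]
    for node, nxt in zip(nodes, nodes[1:]):
        node.next = nxt
    nodes[-1].next = nodes[a]  # close the cycle at position a
    return nodes[0], nodes[a]

def floyd_meeting(head):
    """Phase 1 on a list known to have a cycle: return (meeting node, slow's steps)."""
    slow = fast = head
    steps = 0
    while True:
        slow = slow.next
        fast = fast.next.next
        steps += 1
        if slow is fast:
            return slow, steps

def check_floyd(a, L):
    head, cycle_start = build_rho(a, L)
    meet, slow_dist = floyd_meeting(head)  # slow_dist = a + b
    b = slow_dist - a
    k, rem = divmod(slow_dist, L)
    assert rem == 0 and k >= 1             # a + b = kL
    # Phase 2: one pointer from head, one from the meeting point, same speed.
    p, q, steps = head, meet, 0
    while p is not q:
        p, q, steps = p.next, q.next, steps + 1
    assert p is cycle_start and steps == a  # they meet at the start after a steps
    return b, k

print(check_floyd(3, 5))  # (2, 1)
print(check_floyd(7, 3))  # (2, 3)
```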

Reverse pairs counting

Count pairs (i, j) where i < j and nums[i] > 2 × nums[j]. This combines merge sort with two pointers:

def reverse_pairs(nums):
    def merge_count(arr, left, right):
        if left >= right:
            return 0

        mid = (left + right) // 2
        count = merge_count(arr, left, mid) + merge_count(arr, mid + 1, right)

        # Count reverse pairs with two pointers
        j = mid + 1
        for i in range(left, mid + 1):
            while j <= right and arr[i] > 2 * arr[j]:
                j += 1
            count += j - (mid + 1)

        # Standard merge
        temp = []
        i, j = left, mid + 1
        while i <= mid and j <= right:
            if arr[i] <= arr[j]:
                temp.append(arr[i])
                i += 1
            else:
                temp.append(arr[j])
                j += 1
        temp.extend(arr[i:mid+1])
        temp.extend(arr[j:right+1])
        arr[left:right+1] = temp

        return count

    return merge_count(nums, 0, len(nums) - 1)

print(reverse_pairs([1, 3, 2, 3, 1]))  # 2

Time: O(n log n). The two-pointer counting step is O(n) per merge level, and there are O(log n) levels.
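
When testing the merge-sort version, an O(n²) brute-force count is a convenient reference on small inputs. The helper below (reverse_pairs_brute is an illustrative name) checks every pair (i, j) with i < j directly; comparing its result with reverse_pairs on random lists is a quick sanity check.

```python
def reverse_pairs_brute(nums):
    """O(n^2) reference: count pairs i < j with nums[i] > 2 * nums[j] directly."""
    n = len(nums)
    return sum(
        1
        for i in range(n)
        for j in range(i + 1, n)
        if nums[i] > 2 * nums[j]
    )

print(reverse_pairs_brute([1, 3, 2, 3, 1]))  # 2
```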

Two pointers on matrices

Search a sorted matrix

Each row is sorted left to right and each column is sorted top to bottom (both ascending). Determine whether a target value is present:

def search_matrix(matrix, target):
    if not matrix:
        return False

    row, col = 0, len(matrix[0]) - 1  # start top-right

    while row < len(matrix) and col >= 0:
        if matrix[row][col] == target:
            return True
        elif matrix[row][col] < target:
            row += 1     # eliminate this row
        else:
            col -= 1     # eliminate this column

    return False

Time: O(m + n) where m = rows, n = columns. Each step eliminates a row or column.
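
To see the O(m + n) bound in action, the variant below also reports where the target was found and how many cells were examined. It is a sketch (search_matrix_steps is a hypothetical name) that follows the same top-right staircase walk: each examined cell either ends the search or moves the pointer down one row or left one column, so at most m + n cells are visited.

```python
def search_matrix_steps(matrix, target):
    """Staircase search from the top-right corner.

    Returns (position or None, number of cells examined), where position
    is a (row, col) tuple.
    """
    if not matrix or not matrix[0]:
        return None, 0
    rows, cols = len(matrix), len(matrix[0])
    row, col = 0, cols - 1
    examined = 0
    while row < rows and col >= 0:
        examined += 1
        value = matrix[row][col]
        if value == target:
            return (row, col), examined
        if value < target:
            row += 1  # cells left of col in this row are <= value < target
        else:
            col -= 1  # cells below row in this column are >= value > target
    return None, examined

grid = [
    [1, 4, 7, 11, 15],
    [2, 5, 8, 12, 19],
    [3, 6, 9, 16, 22],
    [10, 13, 14, 17, 24],
    [18, 21, 23, 26, 30],
]
print(search_matrix_steps(grid, 5))   # ((1, 1), 5)
print(search_matrix_steps(grid, 20))  # (None, 9)
```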

Performance patterns

| Problem | Brute force | Two pointers | Speedup |
|---|---|---|---|
| Two Sum (sorted) | O(n²) | O(n) | Factor of n |
| 3Sum | O(n³) | O(n²) | Factor of n |
| Remove duplicates | O(n) extra space | O(1) space | Space savings |
| Cycle detection | O(n) space (hash set) | O(1) space | Space savings |
| Trapping rain water | O(n) space (prefix arrays) | O(1) space | Space savings |
| Matrix search | O(mn) or O(m log n) | O(m + n) | Significant |

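Two pointers help in one of two ways: they remove a factor of n from the running time (O(n²) to O(n) for sorted Two Sum, O(n³) to O(n²) for 3Sum), or they keep the running time and cut extra space from O(n) to O(1) (cycle detection, trapping rain water). In both cases the gain is asymptotic, not just a constant-factor speedup.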
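
As an example of the space-saving case, removing duplicates from a sorted list needs only a read pointer and a write pointer instead of an auxiliary set. This sketch assumes the sorted-array version of the problem (the table does not say which variant it means), and remove_duplicates is an illustrative name; it returns the count k of unique values and leaves them in nums[:k].

```python
def remove_duplicates(nums):
    """Compact the unique values of a sorted list to its front, in place.

    Returns k, the number of unique values; nums[:k] holds them in order.
    """
    if not nums:
        return 0
    write = 1  # nums[:write] holds the unique values seen so far
    for read in range(1, len(nums)):
        if nums[read] != nums[write - 1]:
            nums[write] = nums[read]
            write += 1
    return write

nums = [0, 0, 1, 1, 1, 2, 2, 3, 3, 4]
k = remove_duplicates(nums)
print(k, nums[:k])  # 5 [0, 1, 2, 3, 4]
```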

One thing to remember: Two pointers is fundamentally about pruning the search space. Each pointer movement eliminates candidates that provably cannot be part of the answer. The key to applying it correctly is identifying the invariant that makes this pruning safe — usually sorted order, a monotonic relationship, or a structural property of the data.

