LeetCode 124. Binary Tree Maximum Path Sum

I normally do not enjoy the LeetCode grind, but I had fun understanding the solution to this problem because it has an enlightening gotcha.
Here is the link to it: https://leetcode.com/problems/binary-tree-maximum-path-sum

In short, the problem asks the following. In a binary tree, a path is a non-empty sequence of adjacent nodes in which no node appears twice. The sum of a path is the sum of the values of its nodes. What is the highest path sum you can get?

For the tree

  1
 / \
2   -3

the following are all the paths: [1], [2], [-3], [1,2], [1,-3], and also [2,1,-3] (I am treating [2,1] and [1,2] as the same path). So, two things to notice: paths do not need to contain the root, and paths can travel upwards and then back down, as [2,1,-3] does.
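To double-check the list above, here is a brute-force sketch (my own addition, not part of the solution) that treats the tree as an undirected graph, grows a path from every node in every direction, and keeps each path in one orientation only. The Node class is a minimal stand-in I'm assuming for LeetCode's tree node, and all_paths is a hypothetical helper name.

```python
class Node:
    # minimal stand-in for LeetCode's tree node (assumed shape)
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def all_paths(root):
    """Every simple path as a list of values; [a, b] and [b, a] count once."""
    neighbours = {}  # node -> adjacent nodes (its parent and children)
    order = {}       # node -> preorder index, used to pick one direction

    def collect(node, parent):
        if node is None:
            return
        order[node] = len(order)
        neighbours[node] = []
        if parent is not None:
            neighbours[node].append(parent)
            neighbours[parent].append(node)
        collect(node.left, node)
        collect(node.right, node)

    collect(root, None)
    paths = []

    def extend(path):
        # keep the path only in the direction that starts at the earlier node
        if order[path[0]] <= order[path[-1]]:
            paths.append([n.val for n in path])
        for nxt in neighbours[path[-1]]:
            if nxt not in path:  # never revisit a node
                extend(path + [nxt])

    for start in order:
        extend([start])
    return paths


tree = Node(1, Node(2), Node(-3))
paths = all_paths(tree)
print(paths)                       # [[1], [1, 2], [1, -3], [2], [2, 1, -3], [-3]]
print(max(sum(p) for p in paths))  # 3, from [1, 2]
```

This is exponential in general and only meant for sanity checks on tiny trees.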

I thought about how to solve it for a bit, and then looked up the solution. The idea is pretty simple once you realize that, when doing DFS by recursion, you should look at the tree from the bottom up: each node is handled only after both of its children.

Taking the tree above again, let's start from node 2. It has no children, so the best we can do is the path [2], for a max of 2. The same goes for -3. Once DFS moves up a level, we land at node 1. Here we have several options.

  • stand still, i.e. the path [1], for a max of 1
  • consider 1 as part of a path going left, i.e. [1,2], for a max of 3
  • consider 1 as part of a path going right, i.e. [1,-3], for a max of -2
  • consider 1 as part of a path going from left to right, i.e. [2,1,-3], for a max of 0
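The four options are just arithmetic on the node's own value and the best values reported by its two children. A tiny sketch (options_at is a hypothetical name of my own) that reproduces the numbers for node 1:

```python
def options_at(val, left_best, right_best):
    """Candidate path sums at a node, given the best values of its children."""
    return {
        "stand": val,                            # the node on its own
        "go_left": val + left_best,              # extend the left child's path
        "go_right": val + right_best,            # extend the right child's path
        "bridge": left_best + val + right_best,  # left path, node, right path
    }


opts = options_at(1, left_best=2, right_best=-3)
print(opts)                # {'stand': 1, 'go_left': 3, 'go_right': -2, 'bridge': 0}
print(max(opts.values()))  # 3
```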

At this point I thought I had understood the solution, so I went and wrote this up.

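# Node stands in for LeetCode's TreeNode; this minimal definition
# is assumed so that the snippet runs on its own
class Node:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def max_path(root: Node) -> int:
    # initialize the best value to the lowest possible value
    # (the tree could contain only negative values,
    # so we can't just initialize it to 0)
    # I'm using a one-element list as a cheap way to share a variable
    # with the helper; perhaps the better way is Python's global / nonlocal
    best_global = [float('-inf')]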

    def helper(node: Node) -> int:
        # node is None, return 0
        if not node:
            return 0

        # compute the possible scenarios
        stand = node.val
        go_left = stand + helper(node.left)
        go_right = stand + helper(node.right)
        bridge = go_left + go_right - stand
        # the value of bridge (i.e. a path that goes left and right) is
        # helper(node.left) + stand + helper(node.right)
        # = go_left - stand + stand + go_right - stand
        # = go_left + go_right - stand

        local_best = max(stand, go_left, go_right, bridge)
        best_global[0] = max(best_global[0], local_best)

        # return local_best to continue the recursion
        return local_best

    helper(root)
    return best_global[0]

I was confident about this code, as it's pretty intuitive. However, it contains a serious error. Can you spot it?

Here is a tree where it breaks down.

    5
   /
  4
 / \
1   2

What's the best path sum we can get here? Well, there's [5,4,1] for a total of 10, and there's [5,4,2] for a total of 11. However, the code above returns 12. Why?
The mistake is in the recursion step, namely in the return value of the helper function. At node 4, local_best is 7, the value of bridge, corresponding to the path [1,4,2]. Hence, at node 5, go_left becomes 5 + 7 = 12, and the helper function will happily return 12. But this would correspond to the "path" [5,4,1,2], which is not allowed: a path cannot branch at node 4.
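To watch the bug happen, here is a condensed copy of the first version (using nonlocal instead of the one-element list) that also records what each helper call returns. The trace parameter and the buggy_max_path name are my own additions for illustration, and Node is again an assumed minimal stand-in for LeetCode's tree node.

```python
class Node:
    # minimal stand-in for LeetCode's tree node (assumed shape)
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def buggy_max_path(root, trace):
    """The first version, recording (node value, returned value) per call."""
    best = float("-inf")

    def helper(node):
        nonlocal best
        if node is None:
            return 0
        stand = node.val
        go_left = stand + helper(node.left)
        go_right = stand + helper(node.right)
        bridge = go_left + go_right - stand
        local_best = max(stand, go_left, go_right, bridge)
        best = max(best, local_best)
        trace.append((node.val, local_best))  # value returned to the caller
        return local_best  # BUG: may hand a branching "path" to the parent

    helper(root)
    return best


tree = Node(5, Node(4, Node(1), Node(2)))
trace = []
print(buggy_max_path(tree, trace))  # 12
print(trace)                        # [(1, 1), (2, 2), (4, 7), (5, 12)]
```

The third entry is the culprit: node 4 reports the bridge [1,4,2] upward, and node 5 extends it.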

Thankfully the fix is simple.

def max_path(root: Node) -> int:
    best_global = [float('-inf')]

    def helper(node: Node) -> int:
        if not node:
            return 0

        stand = node.val
        go_left = stand + helper(node.left)
        go_right = stand + helper(node.right)
        bridge = go_left + go_right - stand

        # a path that bridges through this node cannot be extended
        # by the parent, so exclude bridge from the returned value
        recursion_best = max(stand, go_left, go_right)
        local_best = max(recursion_best, bridge)
        best_global[0] = max(best_global[0], local_best)

        # return recursion_best to continue the recursion
        return recursion_best

    helper(root)
    return best_global[0]
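For comparison, here is the same fix written with nonlocal, as the comment in the first version suggests. It also folds stand, go_left and go_right into one expression by clamping each child's contribution at zero, since a branch with a negative sum is never worth taking: max(v, v + a, v + b) equals v + max(0, a, b). This is a sketch of my own, not the author's code; max_path_sum is a hypothetical name, Node is an assumed stand-in, and the tree is assumed non-empty (LeetCode guarantees at least one node).

```python
from __future__ import annotations

from typing import Optional


class Node:
    # minimal stand-in for LeetCode's tree node (assumed shape)
    def __init__(self, val: int, left: Optional[Node] = None,
                 right: Optional[Node] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def max_path_sum(root: Node) -> int:
    # assumes root is not None
    best = float("-inf")

    def gain(node: Optional[Node]) -> int:
        """Best sum of a path that starts at node and goes down one side."""
        nonlocal best
        if node is None:
            return 0
        left = max(gain(node.left), 0)    # drop a branch that only hurts
        right = max(gain(node.right), 0)
        # a path may bend at this node and use both branches...
        best = max(best, node.val + left + right)
        # ...but the parent can only extend a path that uses one of them
        return node.val + max(left, right)

    gain(root)
    return best


print(max_path_sum(Node(1, Node(2), Node(-3))))          # 3
print(max_path_sum(Node(5, Node(4, Node(1), Node(2)))))  # 11
print(max_path_sum(Node(-3, Node(-1))))                  # -1
```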

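So, in a nutshell, the second version is the solution to the "max_path" problem, while the first one solves a "max_subtree" problem: it finds the largest sum over any connected group of nodes, which, unlike a path, is allowed to branch.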
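Reading "max_subtree" as the largest sum over any connected group of nodes, here is a brute-force sketch to check that claim on a few small trees. It relies on the fact that k nodes of a tree are connected exactly when k - 1 tree edges join them. The function names are hypothetical, first_version is a condensed copy of the first max_path, and Node is an assumed minimal stand-in.

```python
from itertools import combinations


class Node:
    # minimal stand-in for LeetCode's tree node (assumed shape)
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def first_version(root):
    """The first max_path, condensed: bridge is never excluded."""
    best = float("-inf")

    def helper(node):
        nonlocal best
        if node is None:
            return 0
        stand = node.val
        go_left = stand + helper(node.left)
        go_right = stand + helper(node.right)
        local_best = max(stand, go_left, go_right, go_left + go_right - stand)
        best = max(best, local_best)
        return local_best

    helper(root)
    return best


def max_connected_subtree(root):
    """Brute force: best sum over all connected, non-empty sets of nodes."""
    nodes, edges = [], []

    def collect(node, parent):
        if node is None:
            return
        nodes.append(node)
        if parent is not None:
            edges.append((parent, node))
        collect(node.left, node)
        collect(node.right, node)

    collect(root, None)
    best = float("-inf")
    for size in range(1, len(nodes) + 1):
        for group in combinations(nodes, size):
            chosen = set(group)
            # k tree nodes are connected iff k - 1 tree edges lie inside them
            inner = sum(1 for a, b in edges if a in chosen and b in chosen)
            if inner == size - 1:
                best = max(best, sum(n.val for n in group))
    return best


trees = [
    Node(5, Node(4, Node(1), Node(2))),
    Node(1, Node(2), Node(-3)),
    Node(-1, Node(3, Node(4), Node(5)), Node(2)),
]
for t in trees:
    print(first_version(t), max_connected_subtree(t))
# 12 12
# 3 3
# 13 13
```

On the last tree the best genuine path is [4,3,5] with 12, while the whole tree sums to 13, so the two problems really do differ.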

All right, back to the grind.
