LeetCode 124. Binary Tree Maximum Path Sum
I normally do not enjoy the LeetCode grind, but I had fun understanding the solution to this problem, as it has an enlightening gotcha.
Here is the link to it: https://leetcode.com/problems/binary-tree-maximum-path-sum
In short, the problem asks the following. In a binary tree, a path is a sequence of adjacent nodes that contains at least one node and in which no node appears more than once. The sum of a path is the sum of the values of its nodes. What's the highest path sum you can get?
For the tree

```
  1
 / \
2  -3
```

the following are all the paths: [1], [2], [-3], [1,2], [1,-3], but also [2,1,-3]. (I am considering the paths [2,1] and [1,2] to be the same.) So there are two things to notice: paths do not need to contain the root, and paths can travel upwards.
I thought about how to solve it for a bit, and then looked up the solution. The idea is pretty simple once you realize that, when doing DFS by recursion, you should be looking at the tree from the bottom up.
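Bottom up here amounts to post-order DFS: a node is handled only after the recursive calls on both of its children have returned. The tiny sketch below records the order in which the nodes of the example tree are finished. It uses a hypothetical Node class and a hypothetical finish_order function.

```python
from typing import List, Optional


class Node:
    # hypothetical minimal tree node, just for this sketch
    def __init__(self, val: int, left: Optional["Node"] = None, right: Optional["Node"] = None):
        self.val = val
        self.left = left
        self.right = right


def finish_order(node: Optional[Node], out: List[int]) -> List[int]:
    # post-order: recurse into both children first, then handle the node itself
    if node is not None:
        finish_order(node.left, out)
        finish_order(node.right, out)
        out.append(node.val)
    return out


tree = Node(1, Node(2), Node(-3))
print(finish_order(tree, []))  # [2, -3, 1]
```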
Taking the tree above again, let's start from node 2. It has no children, so the best thing we can do is the path [2], for a max of 2. The same goes for -3, with the path [-3] and a max of -3. Once DFS moves up a level, we land on node 1. Here we have several options:
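- stand still, i.e. the path [1], for a sum of 1
- consider 1 as part of a path going left, i.e. [1,2], for a sum of 3
- consider 1 as part of a path going right, i.e. [1,-3], for a sum of -2
- consider 1 as part of a path going from left to right, i.e. [2,1,-3], for a sum of 0

The best of these is [1,2], with a sum of 3. Together with [2] and [-3] from the leaves, that covers every path in the tree, so 3 is the answer.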
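Here are the same four options written as arithmetic. The symbols are my own shorthand: $v$ is the value of the current node, and $g_L$, $g_R$ are the values coming back from the left and right child. At this level those are simply the leaves' own values, 2 and -3.

$$
\begin{aligned}
\text{stand} &= v = 1 \\
\text{left} &= v + g_L = 1 + 2 = 3 \\
\text{right} &= v + g_R = 1 + (-3) = -2 \\
\text{bridge} &= g_L + v + g_R = 2 + 1 + (-3) = 0 \\
\text{best at node } 1 &= \max(1,\ 3,\ -2,\ 0) = 3
\end{aligned}
$$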
At this point I thought I had understood the solution, so I went and wrote this up.
```python
class Node:
    # minimal binary tree node (LeetCode's own template calls it TreeNode)
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_path(root: Node) -> int:
    # initialize the best value to be the lowest number
    # (the tree could contain all negative values,
    # so we can't just initialize as 0)
    # I'm using an array as a cheap way to bind a global variable
    # perhaps the better way is to use python's nonlocal
    # (global would not work for a variable local to max_path)
    best_global = [float('-inf')]

    def helper(node: Node) -> int:
        # node is None, return 0
        if not node:
            return 0
        # compute the possible scenarios
        stand = node.val
        go_left = stand + helper(node.left)
        go_right = stand + helper(node.right)
        bridge = go_left + go_right - stand
        # the value of bridge (ie a path that goes left and right) is
        # helper(node.left) + stand + helper(node.right)
        # = go_left - stand + stand + go_right - stand
        # = go_left + go_right - stand
        local_best = max(stand, go_left, go_right, bridge)
        best = best_global.pop()
        best = max(best, local_best)
        best_global.append(best)
        # return local_best to continue the recursion
        return local_best

    helper(root)
    best = best_global.pop()
    return best
```
I was confident about this code, as it's pretty intuitive. However, it contains a serious error. Can you spot it?
Here is a tree where it breaks down.
```
    5
   /
  4
 / \
1   2
```
What's the best path sum we can do here? Well, there's [5,4,1] for a total of 10, and there's [5,4,2] for a total of 11, which is the best we can do. However, the code above will return 12. Why?
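The mistake is in the recursion step, viz. in the return value of the helper function. At node 4, local_best is 7, which is the value of bridge, corresponding to the path [1,4,2]. Hence, at node 5, go_left is 5 + 7 = 12, and the helper function happily returns 12. But this would correspond to the "path" [5,4,1,2], which is not allowed! Nodes 1 and 2 are not adjacent, so the only way to walk it is 5, 4, 1, 4, 2, which visits node 4 twice. The value a helper call returns to its parent must be the sum of a path that ends at the current node, because the parent can only extend a path at one of its endpoints. A bridge bends at the current node and has both endpoints below it, so it can be recorded as a candidate answer but never passed upwards.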
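To watch the bridge leak upwards, here is a self-contained sketch. It repeats the first helper's arithmetic under the hypothetical name buggy_gain, with its own minimal Node class, and logs what every node hands back to its parent on the tree above.

```python
from typing import Dict, Optional


class Node:
    # hypothetical minimal tree node, just for this sketch
    def __init__(self, val: int, left: Optional["Node"] = None, right: Optional["Node"] = None):
        self.val = val
        self.left = left
        self.right = right


def buggy_gain(node: Optional[Node], returned: Dict[int, int]) -> int:
    # same arithmetic as the first helper: the bridge value is allowed into the return
    if node is None:
        return 0
    left = buggy_gain(node.left, returned)
    right = buggy_gain(node.right, returned)
    stand = node.val
    result = max(stand, stand + left, stand + right, stand + left + right)
    returned[node.val] = result  # keyed by value: fine here because the values are distinct
    return result


tree = Node(5, Node(4, Node(1), Node(2)))
returned: Dict[int, int] = {}
buggy_gain(tree, returned)
print(returned)  # {1: 1, 2: 2, 4: 7, 5: 12}
```

Node 4 passes up 7, the sum of the bridge [1,4,2], and node 5 simply adds its own 5 on top, even though no real path is worth 12.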
Thankfully the fix is simple.
```python
def max_path(root: Node) -> int:
    best_global = [float('-inf')]

    def helper(node: Node) -> int:
        if not node:
            return 0
        stand = node.val
        go_left = stand + helper(node.left)
        go_right = stand + helper(node.right)
        bridge = go_left + go_right - stand
        # for recursion to be correct it should
        # exclude bridge from the returned max value
        recursion_best = max(stand, go_left, go_right)
        local_best = max(recursion_best, bridge)
        best = best_global.pop()
        best = max(best, local_best)
        best_global.append(best)
        # return recursion_best to continue the recursion
        return recursion_best

    helper(root)
    best = best_global.pop()
    return best
```
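As a sanity check, the sketch below rewrites the fix with nonlocal instead of the one-element list. It then compares that version against a brute-force search over all paths on a few hundred random small trees. This is my own addition, not the original solution. The names max_path_nonlocal, brute_force and random_tree are hypothetical.

```python
import random
from typing import Dict, List, Optional


class Node:
    # hypothetical minimal tree node, just for this sketch
    def __init__(self, val: int, left: Optional["Node"] = None, right: Optional["Node"] = None):
        self.val = val
        self.left = left
        self.right = right


def max_path_nonlocal(root: Node) -> int:
    # same recursion as the fixed version, with nonlocal instead of the list trick
    best = float("-inf")

    def helper(node: Optional[Node]) -> int:
        nonlocal best
        if node is None:
            return 0
        stand = node.val
        go_left = stand + helper(node.left)
        go_right = stand + helper(node.right)
        bridge = go_left + go_right - stand
        recursion_best = max(stand, go_left, go_right)  # paths that end at this node
        best = max(best, recursion_best, bridge)        # paths that may also bend here
        return recursion_best

    helper(root)
    return best


def brute_force(root: Node) -> int:
    # try every path: start anywhere, walk the undirected tree, never step back
    adj: Dict[Node, List[Node]] = {}

    def link(node: Node, parent: Optional[Node]) -> None:
        adj[node] = [n for n in (parent, node.left, node.right) if n is not None]
        for child in (node.left, node.right):
            if child is not None:
                link(child, node)

    link(root, None)
    best = float("-inf")

    def walk(node: Node, prev: Optional[Node], total: int) -> None:
        nonlocal best
        best = max(best, total)
        for nxt in adj[node]:
            if nxt is not prev:  # in a tree, not stepping back keeps the walk a path
                walk(nxt, node, total + nxt.val)

    for start in adj:
        walk(start, None, start.val)
    return best


def random_tree(size: int, rng: random.Random) -> Optional[Node]:
    # random shape, values in [-10, 10]
    if size == 0:
        return None
    left_size = rng.randint(0, size - 1)
    return Node(rng.randint(-10, 10),
                random_tree(left_size, rng),
                random_tree(size - 1 - left_size, rng))


rng = random.Random(0)
for _ in range(500):
    tree = random_tree(rng.randint(1, 8), rng)
    assert max_path_nonlocal(tree) == brute_force(tree)
print("fixed version agrees with brute force")
```

The nonlocal form avoids the pop/append dance and reads closer to the reasoning. The randomized comparison is cheap insurance against exactly the kind of gotcha this problem hides.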
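So, in a nutshell, the latter is the solution to the "max_path" problem, while the former is the solution to the "max_subtree" problem, i.e. finding the largest sum of a connected group of nodes, which, unlike a path, is allowed to branch.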
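The sketch below backs up that last claim. It is again my own addition, with hypothetical names max_subtree, brute_force_subtree and random_tree. It reuses the first helper's arithmetic and compares it against a brute force over every connected set of nodes. The check works because max(v, v + a, v + b, v + a + b) = v + max(0, a) + max(0, b), which is exactly the best connected group hanging from a node.

```python
import itertools
import random
from typing import List, Optional, Tuple


class Node:
    # hypothetical minimal tree node, just for this sketch
    def __init__(self, val: int, left: Optional["Node"] = None, right: Optional["Node"] = None):
        self.val = val
        self.left = left
        self.right = right


def max_subtree(root: Node) -> int:
    # the first helper's recursion, with nonlocal instead of the list trick
    best = float("-inf")

    def helper(node: Optional[Node]) -> int:
        nonlocal best
        if node is None:
            return 0
        stand = node.val
        go_left = stand + helper(node.left)
        go_right = stand + helper(node.right)
        bridge = go_left + go_right - stand
        local_best = max(stand, go_left, go_right, bridge)
        best = max(best, local_best)
        return local_best  # the bridge is allowed to go upwards here

    helper(root)
    return best


def brute_force_subtree(root: Node) -> int:
    # collect all nodes and parent-child edges
    nodes: List[Node] = []
    edges: List[Tuple[Node, Node]] = []
    stack = [root]
    while stack:
        node = stack.pop()
        nodes.append(node)
        for child in (node.left, node.right):
            if child is not None:
                edges.append((node, child))
                stack.append(child)
    best = float("-inf")
    for k in range(1, len(nodes) + 1):
        for subset in itertools.combinations(nodes, k):
            chosen = set(subset)
            inside = sum(1 for a, b in edges if a in chosen and b in chosen)
            # a set of k tree nodes is connected exactly when it has k - 1 internal edges
            if inside == k - 1:
                best = max(best, sum(n.val for n in subset))
    return best


def random_tree(size: int, rng: random.Random) -> Optional[Node]:
    # random shape, values in [-10, 10]
    if size == 0:
        return None
    left_size = rng.randint(0, size - 1)
    return Node(rng.randint(-10, 10),
                random_tree(left_size, rng),
                random_tree(size - 1 - left_size, rng))


rng = random.Random(0)
for _ in range(300):
    tree = random_tree(rng.randint(1, 8), rng)
    assert max_subtree(tree) == brute_force_subtree(tree)
print(max_subtree(Node(5, Node(4, Node(1), Node(2)))))  # 12
```

On the counterexample tree, the whole tree {5, 4, 1, 2} is a valid connected group, so 12 is the right answer to that other question. The brute force checks all 2^n subsets, so it is only meant for tiny trees.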
All right, back to the grind.