25
From DFS to Topological Sort
Recently I revisited this LeetCode question: https://leetcode.com/problems/alien-dictionary/solution/. I first learned about it a year and a half ago when I was a data science fellow at Insight, in New York. It's a fun problem, though a terrible interview question. I think it would be much better if you could at least assume the alien dictionary has no inconsistencies (i.e. that an order can always be extracted).
Depth-first search is a standard method to explore a graph. For example, let's say we want to print the value of all the nodes in a graph. Here is a(n incorrect) recursive implementation.
def dfs(node):
    print(node.val)
    for child in node.children:
        dfs(child)
There are two big issues with this code. The first is that it might get stuck in a loop. For the graph 1->2->3->1, calling dfs(node_1) will recurse forever (in practice, until Python raises a RecursionError). Let's fix this by keeping track of which nodes have already been visited by the DFS.
def dfs(node):
    if node in visited:
        return None
    print(node.val)
    visited.add(node)
    for child in node.children:
        dfs(child)
visited = set()
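To actually run the visited-set version we need a node type, which the post leaves implicit. Here is a minimal sketch assuming a hypothetical `Node` dataclass with `val` and `children` attributes; the helper `dfs_values` (also a name chosen for illustration) collects values into a list instead of printing them, so the result is easy to check.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so nodes can go in a set
class Node:
    """Hypothetical node type: a value plus a list of child nodes."""
    val: int
    children: list["Node"] = field(default_factory=list)


def dfs_values(start: Node) -> list[int]:
    """Return node values in the order the visited-set DFS first reaches them."""
    visited: set[Node] = set()
    seen_order: list[int] = []

    def dfs(node: Node) -> None:
        if node in visited:  # already explored: stop here, this is what breaks cycles
            return
        visited.add(node)
        seen_order.append(node.val)
        for child in node.children:
            dfs(child)

    dfs(start)
    return seen_order


# Build the cycle 1 -> 2 -> 3 -> 1 from the text.
n1, n2, n3 = Node(1), Node(2), Node(3)
n1.children.append(n2)
n2.children.append(n3)
n3.children.append(n1)

print(dfs_values(n1))  # [1, 2, 3]
```

Starting from n2 instead gives [2, 3, 1]. Without the `visited` check, the same call would keep going around the cycle until Python raises `RecursionError`.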
The second issue is that it has no awareness of the whole graph. For the graph 1 -> 2 -> 3, calling dfs(2) will miss node 1. So let's assume we have a dictionary edges, where edges[node] lists all the children of node, and every node is a key (nodes without children map to an empty list).
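In practice a graph often arrives as a list of (parent, child) pairs rather than as a ready-made dictionary. A sketch of turning such pairs into the edges format, making sure leaves get a key too; the name `build_edges`, the pair-list input and the optional `nodes` argument for isolated nodes are assumptions for illustration.

```python
def build_edges(pairs, nodes=()):
    """Turn (parent, child) pairs into {node: [children]}.

    `nodes` lets callers include isolated nodes that appear in no pair.
    Every node ends up as a key, so edges[node] never raises KeyError.
    """
    edges = {}
    for node in nodes:
        edges.setdefault(node, [])
    for parent, child in pairs:
        edges.setdefault(parent, []).append(child)
        edges.setdefault(child, [])  # leaves get an empty child list
    return edges


print(build_edges([(1, 2), (2, 3)]))  # {1: [2], 2: [3], 3: []}
```

Keys appear in the order nodes are first seen, which matters later because the traversal loops over edges.keys() in that order.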
def print_nodes(edges):
    def dfs(node):
        if node in visited:
            return None
        print(node)
        visited.add(node)
        for child in edges[node]:
            dfs(child)
    visited = set()
    for node in edges.keys():
        dfs(node)
For the graph
    1
   /
  2
 / \
3   4
(where I assume the edges all point down) the function above will print the nodes in the order 1,2,3,4 (assuming node 1 is the first key of edges and 3 comes before 4 in edges[node_2]). This would be a "pre-order" way of printing the graph. The "post-order" way would be
def print_nodes_post(edges):
    def dfs(node):
        if node in visited:
            return None
        visited.add(node)
        for child in edges[node]:
            dfs(child)
        print(node)
    visited = set()
    for node in edges.keys():
        dfs(node)
yielding 3,4,2,1.
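Collecting the nodes into lists instead of printing them makes the two orders easy to compare. A sketch (the function name `pre_and_post_order` is just for illustration), run on the tree above with its edges written as a dictionary:

```python
def pre_and_post_order(edges):
    """Return (pre-order, post-order) lists for a DFS over every node in `edges`."""
    visited = set()
    pre, post = [], []

    def dfs(node):
        if node in visited:
            return
        visited.add(node)
        pre.append(node)  # work done before the children
        for child in edges[node]:
            dfs(child)
        post.append(node)  # work done after all children are finished

    for node in edges:
        dfs(node)
    return pre, post


# The tree from the text: 1 -> 2, 2 -> 3, 2 -> 4.
tree = {1: [2], 2: [3, 4], 3: [], 4: []}
print(pre_and_post_order(tree))  # ([1, 2, 3, 4], [3, 4, 2, 1])
```

The key order matters: with {2: [3, 4], 1: [2], 3: [], 4: []} the pre-order becomes [2, 3, 4, 1], while the post-order stays [3, 4, 2, 1].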
Obviously printing out the nodes in a graph is a silly task (we could just directly use edges.keys()). The point is that DFS lets us explore a whole graph and perform work as we move along, either before or after a node's children are explored.
def dfs(node):
    if node in visited:
        return None
    visited.add(node)
    do_some_work_pre_children(node)
    for child in edges[node]:
        dfs(child)
    do_some_work_post_children(node)
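The template can be made concrete by passing the two pieces of work in as functions. A sketch, where `dfs_with_hooks`, `pre` and `post` are hypothetical names; as an example of work done along the way, it records the discovery and finish times of each node, using one counter that ticks once when a node is entered and once when it is left.

```python
from itertools import count


def dfs_with_hooks(edges, pre=lambda node: None, post=lambda node: None):
    """Visit every node of `edges` once, calling pre(node) on entry and post(node) on exit."""
    visited = set()

    def dfs(node):
        if node in visited:
            return
        visited.add(node)
        pre(node)
        for child in edges[node]:
            dfs(child)
        post(node)

    for node in edges:
        dfs(node)


tree = {1: [2], 2: [3, 4], 3: [], 4: []}
ticks = count()  # shared clock: 0, 1, 2, ...
discovered, finished = {}, {}


def on_enter(node):
    discovered[node] = next(ticks)


def on_exit(node):
    finished[node] = next(ticks)


dfs_with_hooks(tree, pre=on_enter, post=on_exit)
print(discovered)  # {1: 0, 2: 1, 3: 2, 4: 4}
print(finished)    # {3: 3, 4: 5, 2: 6, 1: 7}
```

Node 2's interval (1 to 6) encloses those of 3 and 4: a node only finishes after everything explored below it has finished.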
Given a graph, let's write a function to detect if it has any cycles. Here is an (incorrect) approach.
def is_acyclic(edges):
    def dfs(node):
        if node in visited:
            # cycle detected?
            return False
        visited.add(node)
        for child in edges[node]:
            if dfs(child) is False:
                return False
        return True
    visited = set()
    for node in edges.keys():
        if node not in visited:
            if dfs(node) is False:
                return False
    return True
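The main thought process here is that if, upon calling dfs, the node is already in visited, then a cycle must be present. However, this is deeply flawed. For example, take the graph
    1
   / \
  2   3
   \ /
    4
(where again all the edges point downwards). No matter which node the DFS starts from, the function returns False, i.e. it claims the graph contains a cycle. Starting from node 1, the DFS reaches 4 through 2 and then reaches it again through 3. Starting from node 2 (or 3, or 4), a later DFS from another node runs into nodes the first one already visited. Having been visited before only means a node is reachable along more than one path, not that it lies on a cycle.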
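To fix this, we add a second set, visited_dfs, which only holds the nodes on the current recursion path, i.e. nodes whose dfs call has started but not yet returned. Reaching a node that is in visited_dfs means we have followed edges back to one of our own ancestors, which is exactly a cycle.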
def is_acyclic(edges):
    def dfs(node):
        if node in visited:
            return True
        visited.add(node)
        visited_dfs.add(node)
        for child in edges[node]:
            if child in visited_dfs:
                return False
            if dfs(child) is False:
                return False
        # after we explored node and all its children
        # we can remove the temporary marker
        # if we don't, we'll run into the same issue
        # as in the naive approach
        visited_dfs.remove(node)
        return True
    visited = set()
    visited_dfs = set()
    for node in edges.keys():
        if dfs(node) is False:
            return False
    return True
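The two sets encode three possible states for a node: never reached, currently on the recursion path (in visited_dfs), or fully explored (in visited but no longer in visited_dfs). A common equivalent formulation stores that state in a single dictionary. A sketch, with the names `UNSEEN`, `ON_PATH`, `DONE` and `is_acyclic_states` chosen for illustration; a cycle exists exactly when we reach a node that is still on the path.

```python
UNSEEN, ON_PATH, DONE = 0, 1, 2


def is_acyclic_states(edges):
    """Cycle check for a graph given as {node: [children]}, one state per node."""
    state = {node: UNSEEN for node in edges}

    def dfs(node):
        state[node] = ON_PATH  # plays the role of visited_dfs
        for child in edges[node]:
            if state[child] == ON_PATH:  # edge back into the current path: cycle
                return False
            if state[child] == UNSEEN and not dfs(child):
                return False
        state[node] = DONE  # fully explored: safe to reach again later
        return True

    for node in edges:
        if state[node] == UNSEEN and not dfs(node):
            return False
    return True


diamond = {1: [2, 3], 2: [4], 3: [4], 4: []}
print(is_acyclic_states(diamond))                    # True
print(is_acyclic_states({1: [2], 2: [3], 3: [1]}))   # False
```

Reaching a `DONE` node, as happens with node 4 in the diamond, is harmless; only `ON_PATH` signals a cycle.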
Let's say you have a graph which represents certain tasks to be executed (e.g. ETL tasks for an orchestrator like Airflow or Prefect, or a computation graph for a neural network). Edges in this graph represent dependencies between tasks. Because of these dependencies, we can't just execute the tasks in a random order. It would be nice if we could extract a total order of the tasks compatible with the edges, i.e. if you can go from node a to node b, we should have a < b in the order.
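The compatibility condition is easy to check mechanically, and it is enough to look at direct edges: if a reaches b along a path, every edge on that path puts its source first, and "comes before" is transitive. A sketch with a hypothetical helper `is_compatible`; the brute-force check at the end tries every ordering of a three-node cycle.

```python
from itertools import permutations


def is_compatible(edges, order):
    """True if every edge a -> b in `edges` has a before b in `order`."""
    position = {node: i for i, node in enumerate(order)}
    return all(
        position[parent] < position[child]
        for parent, children in edges.items()
        for child in children
    )


chain = {1: [2], 2: [3], 3: []}
print(is_compatible(chain, [1, 2, 3]))  # True
print(is_compatible(chain, [2, 1, 3]))  # False

# No ordering of a cycle works: every permutation breaks some edge.
cycle = {1: [2], 2: [3], 3: [1]}
print(any(is_compatible(cycle, list(p)) for p in permutations(cycle)))  # False
```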
We can't do this for all graphs. Whatever order we choose for the nodes of 1->2->3->1, it will be incompatible with the graph structure. However, if we have a DAG (directed acyclic graph), i.e. a graph without cycles, then we can. One way to come up with such an order is via DFS. For the graph
    0
    |
    1
   / \
  2   3
   \ /
    4      5
(node 5 has no edges), [0,1,2,3,4,5], [0,1,3,2,4,5] and [0,5,1,2,3,4] are all acceptable orders. The idea is to explore the graph using DFS, add nodes to a list in post-order, and reverse the list at the end. For example, let's start by exploring the graph above from node 2. This yields the list [4,2]. Continuing from node 1, we get [4,2,3,1]. Continuing with node 5 gives [4,2,3,1,5], and finally adding 0 gives [4,2,3,1,5,0]. Reversing this list we get [0,5,1,3,2,4], which is a valid total order. This works because, in an acyclic graph, a node is appended only after every node reachable from it has already been appended, so after reversing, each node comes before everything it can reach.
def topological_sort(edges):
    """Assumes graph is acyclic."""
    def dfs(node):
        if node in visited:
            return None
        visited.add(node)
        for child in edges[node]:
            dfs(child)
        order.append(node)
    order = []
    visited = set()
    for node in edges.keys():
        dfs(node)
    order.reverse()
    return order
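Python limits recursion depth (1000 frames by default in CPython), so the recursive version above fails on a long enough chain of dependencies. A sketch of the same reversed post-order idea with an explicit stack instead of recursion; the name `topological_sort_iterative` is mine, and like the recursive version it assumes the graph is acyclic. Using the key order 2, 1, 5, 0, 3, 4 reproduces the walkthrough above.

```python
def topological_sort_iterative(edges):
    """Reverse DFS post-order without recursion. Assumes the graph is acyclic."""
    visited = set()
    order = []
    for root in edges:
        if root in visited:
            continue
        visited.add(root)
        # Each stack entry is (node, iterator over its remaining children).
        stack = [(root, iter(edges[root]))]
        while stack:
            node, children = stack[-1]
            for child in children:
                if child not in visited:
                    visited.add(child)
                    stack.append((child, iter(edges[child])))
                    break  # descend into child before the node's other children
            else:
                # No unvisited children left: node is finished (post-order).
                stack.pop()
                order.append(node)
    order.reverse()
    return order


graph = {2: [4], 1: [2, 3], 5: [], 0: [1], 3: [4], 4: []}
print(topological_sort_iterative(graph))  # [0, 5, 1, 3, 2, 4]
```

A chain of 5000 nodes, which would exceed the default recursion limit in the recursive version, is handled without trouble here.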
Of course, if the graph is not guaranteed to be acyclic, we can combine the two operations, detecting cycles and building the order in the same DFS.
def topological_sort_if_acyclic(edges):
    """Return an order if acyclic, else return []"""
    def dfs(node):
        if node in visited:
            return True
        visited.add(node)
        visited_dfs.add(node)
        for child in edges[node]:
            if child in visited_dfs:
                return False
            if dfs(child) is False:
                return False
        visited_dfs.remove(node)
        order.append(node)
        return True
    order = []
    visited = set()
    visited_dfs = set()
    for node in edges.keys():
        if dfs(node) is False:
            return []
    order.reverse()
    return order
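To close the loop with the Alien Dictionary problem from the beginning: each pair of adjacent words gives at most one edge, from the first differing letter of the earlier word to the corresponding letter of the later word, and the answer is a topological order of the letters. A sketch under my own reading of the problem: return "" when the input is inconsistent, which includes a cycle between letters and a longer word listed before its own prefix. The name `alien_order` is hypothetical.

```python
def alien_order(words):
    """Return the letters in an order consistent with the sorted `words`, or "" if none exists."""
    # Every letter is a node, even if it never appears in a constraint.
    edges = {ch: [] for word in words for ch in word}
    for earlier, later in zip(words, words[1:]):
        for a, b in zip(earlier, later):
            if a != b:
                edges[a].append(b)  # a comes before b in the alien alphabet
                break
        else:
            # No differing letter: the earlier word must not be longer ("abc" before "ab").
            if len(earlier) > len(later):
                return ""

    order, visited, on_path = [], set(), set()

    def dfs(node):
        """Post-order DFS that returns False if it finds a cycle."""
        if node in visited:
            return True
        visited.add(node)
        on_path.add(node)
        for child in edges[node]:
            if child in on_path or not dfs(child):
                return False
        on_path.remove(node)
        order.append(node)
        return True

    for node in edges:
        if not dfs(node):
            return ""
    return "".join(reversed(order))


print(alien_order(["wrt", "wrf", "er", "ett", "rftt"]))  # wertf
print(repr(alien_order(["z", "x", "z"])))               # '' (z < x and x < z)
```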