From DFS to Topological Sort

Recently I revisited this LeetCode question: https://leetcode.com/problems/alien-dictionary/solution/. I first came across it a year and a half ago when I was a data science fellow at Insight in New York. It's a fun problem, though a terrible interview question. I think it would be much better if you could at least assume the alien dictionary has no inconsistencies (i.e. that an order can always be extracted).

DFS

Depth-first search is a standard method to explore a graph. For example, let's say we want to print the value of every node in a graph. Here is an (incorrect) recursive implementation.

def dfs(node):
    print(node.val)
    for child in node.children:
        dfs(child)

There are two big issues with this code. The first is that it might get stuck in a loop. For the graph 1->2->3->1, calling dfs(node_1) will never terminate on its own (in Python it recurses until it hits the recursion limit and raises a RecursionError). Let's fix this by keeping track of which nodes have already been visited by the DFS.

def dfs(node):
    if node in visited:
        return None
    print(node.val)
    visited.add(node)
    for child in node.children:
        dfs(child)

visited = set()

The second issue is that it has no awareness of the whole graph. For the graph 1 -> 2 -> 3, calling dfs(node_2) will miss node 1. So let's assume we have a dictionary edges, where edges[node] lists all the children of node and every node appears as a key (a node without children maps to an empty list).

def print_nodes(edges):
    def dfs(node):
        if node in visited:
            return None
        print(node.val)
        visited.add(node)
        for child in edges[node]:
            dfs(child)

    visited = set()
    for node in edges.keys():
        dfs(node)

For the graph

    1
   /
  2
 / \
3   4

(where I assume the edges are all pointing down) the function above will print the nodes in the order 1,2,3,4 (assuming 3 comes before 4 in edges[node_2]). This would be a "pre-order" way of printing the graph. The "post-order" way would be

def print_nodes_post(edges):
    def dfs(node):
        if node in visited:
            return None
        visited.add(node)
        for child in edges[node]:
            dfs(child)
        print(node.val)

    visited = set()
    for node in edges.keys():
        dfs(node)

yielding 3,4,2,1.
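To make the two orders easy to compare, here is a minimal sketch that collects them into lists instead of printing. The function name traversal_orders is made up for this example, and the nodes are plain integers rather than objects with a val attribute.

```python
def traversal_orders(edges):
    """Return (pre_order, post_order) lists for a DFS over the whole graph."""
    pre_order, post_order = [], []
    visited = set()

    def dfs(node):
        if node in visited:
            return
        visited.add(node)
        pre_order.append(node)    # recorded before exploring the children
        for child in edges[node]:
            dfs(child)
        post_order.append(node)   # recorded after exploring the children

    for node in edges:
        dfs(node)
    return pre_order, post_order


# The example graph: 1 -> 2, 2 -> 3, 2 -> 4 (leaves map to an empty list).
edges = {1: [2], 2: [3, 4], 3: [], 4: []}
pre_order, post_order = traversal_orders(edges)
print(pre_order)   # [1, 2, 3, 4]
print(post_order)  # [3, 4, 2, 1]
```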

Obviously printing out the nodes in a graph is a silly task (we could just directly use edges.keys()). The point is that DFS allows us to explore a whole graph, and we can perform tasks as we move along.

def dfs(node):
    if node in visited:
        return None
    visited.add(node)
    do_some_work_pre_children(node)
    for child in edges[node]:
        dfs(child)
    do_some_work_post_children(node)

Detecting cycles

Given a graph, let's write a function to detect if it has any cycles. Here is an (incorrect) approach.

def is_acyclic(edges):
    def dfs(node):
        if node in visited:
            # cycle detected?
            return False

        visited.add(node)
        for child in edges[node]:
            if dfs(child) is False:
                return False
        return True

    visited = set()
    for node in edges.keys():
        if node not in visited:
            if dfs(node) is False:
                return False
    return True

The main thought process here is that if upon calling dfs the node is already in visited then a cycle must be present. However this is deeply flawed. For example with the graph

  1
 / \
2   3
 \ /
  4

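(where again all the edges are pointing downwards) the function returns False, i.e. it reports a cycle, even though the graph has none. Starting from node 1, the DFS reaches node 4 through node 2, then reaches it again through node 3 and finds it already in visited. Starting from node 2 (or 3, or 4) fails too: the outer loop later calls dfs on node 1, which immediately runs into a child that has already been visited.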
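The failure is easy to reproduce. Below is a compact restatement of the naive check (renamed is_acyclic_naive for this sketch), run on acyclic graphs only. The start node is decided by the key order of the dictionary, which is an assumption of this example.

```python
def is_acyclic_naive(edges):
    """The flawed check: treats any revisit of a node as a cycle."""
    visited = set()

    def dfs(node):
        if node in visited:
            return False
        visited.add(node)
        return all(dfs(child) for child in edges[node])

    return all(dfs(node) for node in edges if node not in visited)


# The diamond graph from above, with node 2 as the first key.
diamond_from_2 = {2: [4], 1: [2, 3], 3: [4], 4: []}
print(is_acyclic_naive(diamond_from_2))  # False, although there is no cycle

# Even the chain 1 -> 2 -> 3 gives an answer that depends on the start node.
chain_from_1 = {1: [2], 2: [3], 3: []}
chain_from_2 = {2: [3], 1: [2], 3: []}
print(is_acyclic_naive(chain_from_1))  # True
print(is_acyclic_naive(chain_from_2))  # False
```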

To fix this, we add a second set, visited_dfs, which only holds the nodes on the current recursion path. A cycle exists exactly when a child is already on that path.

def is_acyclic(edges):
    def dfs(node):
        if node in visited:
            return True

        visited.add(node)
        visited_dfs.add(node)
        for child in edges[node]:
            if child in visited_dfs:
                return False
            if dfs(child) is False:
                return False

        # after we explored node and all its children
        # we can remove the temporary marker
        # if we don't we'll run into the same issue
        # as in the naive approach
        visited_dfs.remove(node)
        return True

    visited = set()
    visited_dfs = set()
    for node in edges.keys():
        if dfs(node) is False:
            return False
    return True
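The same bookkeeping is often written with a single state per node instead of two sets: unvisited, on the current path, or fully explored. The sketch below is one possible variant, not a replacement for the code above. The names WHITE, GRAY, BLACK and is_acyclic_colors are chosen for this example, and it assumes every node appears as a key in edges.

```python
WHITE, GRAY, BLACK = 0, 1, 2  # unvisited, on the current path, fully explored


def is_acyclic_colors(edges):
    color = {node: WHITE for node in edges}

    def dfs(node):
        color[node] = GRAY
        for child in edges[node]:
            if color[child] == GRAY:   # child is an ancestor: cycle found
                return False
            if color[child] == WHITE and not dfs(child):
                return False
        color[node] = BLACK            # plays the role of visited_dfs.remove(node)
        return True

    return all(dfs(node) for node in edges if color[node] == WHITE)


diamond = {1: [2, 3], 2: [4], 3: [4], 4: []}
cycle = {1: [2], 2: [3], 3: [1]}
print(is_acyclic_colors(diamond))  # True
print(is_acyclic_colors(cycle))    # False
```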

Topological Sort

Let's say you have a graph which represents certain tasks to be executed (e.g. ETL tasks for an orchestrator like Airflow or Prefect, or a computation graph for a neural network). Edges in this graph represent a dependency between two tasks. Because of these dependencies, we can't just execute these tasks in an arbitrary order. It would be nice if we could extract a total order of the tasks that is compatible with the edges, i.e. if you can go from node a to node b, we should have a < b in the order.

We can't do this for all graphs. Whatever order we choose for the nodes of 1->2->3->1, it will be incompatible with the graph structure. However, if we have a DAG (directed acyclic graph), i.e. a directed graph without cycles, then we can. One way to come up with such an order is via DFS. For the graph

  0
  |
  1
 / \
2   3
 \ /
  4

5

[0,1,2,3,4,5], [0,1,3,2,4,5], [0,5,1,2,3,4] are all acceptable orders. The idea is to explore the graph using DFS and to add nodes to the list in post-order. For example, let's start by exploring the graph above from node 2. This will yield the list [4,2]. Continuing from node 1, we will have [4,2,3,1]. Continue with node 5 to get [4,2,3,1,5], and finally add 0 to get [4,2,3,1,5,0]. Post-order places every node after all of its descendants, so the reversed list places every node before them. Reversing this list we get [0,5,1,3,2,4], which is a valid total order.

def topological_sort(edges):
    """Assumes graph is acyclic."""
    def dfs(node):
        if node in visited:
            return None
        visited.add(node)
        for child in edges[node]:
            dfs(child)
        order.append(node)

    order = []
    visited = set()
    for node in edges.keys():
        dfs(node)

    order.reverse()
    return order
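A quick way to convince yourself that an order is compatible with the edges is to check that every edge points forward in it. For an acyclic graph, checking the direct edges is enough, since reachability follows by transitivity. The helper below is a sketch with a made-up name (is_valid_order), using integer nodes for the example graph.

```python
def is_valid_order(edges, order):
    """Check that every edge parent -> child goes forward in `order`."""
    position = {node: i for i, node in enumerate(order)}
    if len(position) != len(order) or set(position) != set(edges):
        return False  # duplicates, missing nodes, or nodes not in the graph
    return all(
        position[parent] < position[child]
        for parent, children in edges.items()
        for child in children
    )


# The example graph: 0 -> 1, 1 -> 2, 1 -> 3, 2 -> 4, 3 -> 4, and the isolated node 5.
edges = {0: [1], 1: [2, 3], 2: [4], 3: [4], 4: [], 5: []}
print(is_valid_order(edges, [0, 5, 1, 3, 2, 4]))  # True
print(is_valid_order(edges, [4, 2, 3, 1, 5, 0]))  # False: the post-order before reversing
```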

Combining the two

Of course, if your graph is not guaranteed to be acyclic, we can perform both operations in a single pass.

def topological_sort_if_acyclic(edges):
    """Return an order if acyclic, else return []"""
    def dfs(node):
        if node in visited:
            return True

        visited.add(node)
        visited_dfs.add(node)

        for child in edges[node]:
            if child in visited_dfs:
                return False
            if dfs(child) is False:
                return False
        visited_dfs.remove(node)
        order.append(node)
        return True

    order = []
    visited = set()
    visited_dfs = set()
    for node in edges.keys():
        if dfs(node) is False:
            return []

    order.reverse()
    return order
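As a cross-check, Python 3.9+ ships a topological sorter in the standard library module graphlib. The sketch below wraps it so that it has the same contract as the function above. The wrapper name topological_sort_stdlib is made up for this example. Note that graphlib expects each node together with its predecessors, which is the opposite direction of the edges dictionary used here, and that it may return a different (but still valid) order than the DFS version.

```python
from graphlib import CycleError, TopologicalSorter  # Python 3.9+


def topological_sort_stdlib(edges):
    """Return an order if acyclic, else return []."""
    sorter = TopologicalSorter()
    for parent, children in edges.items():
        sorter.add(parent)              # make sure isolated nodes are included
        for child in children:
            sorter.add(child, parent)   # signature is add(node, *predecessors)
    try:
        return list(sorter.static_order())
    except CycleError:
        return []


dag = {0: [1], 1: [2, 3], 2: [4], 3: [4], 4: [], 5: []}
cycle = {1: [2], 2: [3], 3: [1]}
print(topological_sort_stdlib(dag))    # some valid order, parents before children
print(topological_sort_stdlib(cycle))  # []
```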
