The A* (read: A star) algorithm is a pathfinding algorithm that finds a shortest path from location A to location B, where different steps can have different costs.

It consists of a couple of parts that are mostly reusable as-is or with small adjustments.

Here are notes on how to implement it for a 2D grid search in Python.

The main algorithm:

import math
from collections import defaultdict, namedtuple
from typing import List, Tuple

Coordinate = namedtuple('Coordinate', ['x', 'y'])
 
def find_shortest_path(
    start: Coordinate, target: Coordinate, walls: List[Coordinate]
) -> Tuple[List[Coordinate], int] | None:
    """A* path finding to find the shortest path from start to target."""
    open_set = set()
    open_set.add(start)
    came_from: dict[Coordinate, Coordinate] = {}
 
    g_score: dict[Coordinate, int] = defaultdict(constant_factory(math.inf))
    g_score[start] = 0
 
    f_score: dict[Coordinate, int] = defaultdict(constant_factory(math.inf))
    f_score[start] = distance_to_target(start, target)
 
    while open_set:
        current = cheapest_by_f_score(f_score, open_set)
        if current == target:
            score = g_score[current]
            return (
                reconstruct_path(came_from, current),
                score,
            )
        open_set.remove(current)
 
        for neighbour in get_valid_neighbours(current, walls):
            tentative_g_score = g_score[current] + 1
 
            if tentative_g_score < g_score[neighbour]:
                came_from[neighbour] = current
                g_score[neighbour] = tentative_g_score
                f_score[neighbour] = tentative_g_score + distance_to_target(
                    neighbour, target
                )
                open_set.add(neighbour)
 
    return None

The algorithm receives the start location, target location and some sort of list or dictionary that represents the grid. In this example, it’s a list of coordinates that cannot be entered.

Open Set

In the beginning, an open_set is initialised as a set with the starting location.

The open set keeps track of every coordinate we have discovered as a neighbour but not yet processed.

During execution, we take a look at the currently cheapest one in open_set and add its accessible neighbouring locations to the set whenever we have found a cheaper way to reach them.

open_set = set()
open_set.add(start)

Namedtuples and sets

An easy bug to introduce with sets and tuples is to initialise the set by passing a single tuple to set():

open_set = set(start) 

The problem is that start, which we think of as a Coordinate, is also a tuple (x, y). When you initialise a set with an iterable, set() iterates over it, so you get the set {x, y} instead of the set {(x, y)}.

I have been bitten by this so many times.
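To make the difference concrete, here is a small sketch comparing the buggy call with two safe alternatives. The coordinate values are arbitrary examples chosen for illustration.

```python
from collections import namedtuple

Coordinate = namedtuple("Coordinate", ["x", "y"])
start = Coordinate(3, 7)  # arbitrary example values

# set() iterates over its argument, so the namedtuple is unpacked
# into its fields: the result holds the integers 3 and 7.
unpacked = set(start)

# A set literal keeps the coordinate as a single element.
as_literal = {start}

# Wrapping the coordinate in a list first also works, because set()
# then iterates over the list, not over the tuple.
from_list = set([start])
```

Both safe variants produce a set with exactly one Coordinate in it, which is what the main algorithm relies on.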

Came from

To be able to trace back the path we’ve taken, we keep track of each position and its previous position in a dictionary called came_from.

came_from: dict[Coordinate, Coordinate] = {}

Every time we take a look at a neighbouring tile, we record where we entered it from and once we find the target location, we reconstruct the path with a helper function:

def reconstruct_path(came_from: dict, position: Coordinate) -> List[Coordinate]:
    path = [position]
    while position in came_from:
        position = came_from[position]
        path.append(position)
    return list(reversed(path))
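As a quick illustration of how the walk back works, here is the same helper run on a hand-written came_from dictionary. The three coordinates are made up for the example; in a real search the dictionary is filled in by the main loop.

```python
from collections import namedtuple

Coordinate = namedtuple("Coordinate", ["x", "y"])


def reconstruct_path(came_from, position):
    # Follow the "previous position" links until we hit a position
    # that has no entry, which is the start.
    path = [position]
    while position in came_from:
        position = came_from[position]
        path.append(position)
    return list(reversed(path))


# Hypothetical result of a search that went (0,0) -> (1,0) -> (1,1).
came_from = {
    Coordinate(1, 0): Coordinate(0, 0),
    Coordinate(1, 1): Coordinate(1, 0),
}

path = reconstruct_path(came_from, Coordinate(1, 1))
```

The start never gets an entry in came_from, and that is what ends the loop.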

G & F Scores

g_score and f_score are two dictionaries. g_score tracks the cheapest known cost of reaching a location from the start. f_score is our estimate of the total cost of a path from the start to the target through that location: its g_score plus the heuristic estimate of the remaining distance.

We initialise both of them as defaultdicts whose default value is math.inf. This guarantees that any real value we calculate is smaller than the default, so a location we haven’t seen yet always loses the comparison.

g_score: dict[Coordinate, int] = defaultdict(constant_factory(math.inf))
g_score[start] = 0
 
f_score: dict[Coordinate, int] = defaultdict(constant_factory(math.inf))
f_score[start] = distance_to_target(start, target)

To initialise a defaultdict with a default that is not what a type’s constructor returns (like 0 for int), we can use a constant factory:

def constant_factory(x):
    return lambda: x

It’s a function that returns another function which, when called, always returns the value that was passed to constant_factory.
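Here is a short sketch of how such a defaultdict behaves. The tuple keys are placeholders for coordinates. One detail worth knowing: reading a missing key with square brackets also inserts it into the dictionary, while .get() does not.

```python
import math
from collections import defaultdict


def constant_factory(x):
    return lambda: x


g_score = defaultdict(constant_factory(math.inf))
g_score[(0, 0)] = 0

# A key we have never written reads as infinity...
unseen = g_score[(5, 5)]

# ...and the lookup itself stored that default under the key.
was_inserted = (5, 5) in g_score

# .get() bypasses the factory, so nothing is inserted here.
peeked = g_score.get((9, 9), math.inf)
still_missing = (9, 9) not in g_score

# Any real cost compares as smaller than the default.
real_cost_wins = 1_000_000 < unseen
```

Writing defaultdict(lambda: math.inf) directly should behave the same way; the named helper mostly makes the intent easier to read.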

Main logic

The main loop has two parts.

First, we find the position we estimate to be the cheapest to reach the target from out of all the positions we haven’t checked yet:

while open_set:
	current = cheapest_by_f_score(f_score, open_set)

We find the cheapest one by going through all the ones in open_set and checking if they are cheaper than the currently cheapest:

def cheapest_by_f_score(
    f_score: dict[Coordinate, int],
    open_set: set[Coordinate],
) -> Coordinate:
    first = open_set.pop()
    cheapest = f_score[first]
    cheapest_pos = first
    for pos in open_set:
        if (score := f_score[pos]) < cheapest:
            cheapest_pos = pos
            cheapest = score
    open_set.add(first)
 
    return cheapest_pos
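The same lookup can probably be written more compactly with the built-in min() and a key function. This is a sketch of an alternative, not a change to the function above; the scores and coordinates are invented for the example.

```python
import math
from collections import defaultdict, namedtuple

Coordinate = namedtuple("Coordinate", ["x", "y"])


def cheapest_by_f_score(f_score, open_set):
    # min() scans the set once and returns the position whose
    # f_score is lowest; the set itself is left untouched.
    return min(open_set, key=lambda pos: f_score[pos])


f_score = defaultdict(lambda: math.inf)
f_score[Coordinate(0, 0)] = 4
f_score[Coordinate(1, 0)] = 3

# Coordinate(2, 2) has no score yet, so it counts as infinity.
open_set = {Coordinate(0, 0), Coordinate(1, 0), Coordinate(2, 2)}

cheapest = cheapest_by_f_score(f_score, open_set)
```

When several positions share the lowest score, which one is returned depends on set iteration order. That is also true of the loop version, so neither should be relied on for tie-breaking.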

Sidenote on priority queues

An alternative is to use a priority queue, which always gives us the entry with the lowest estimated cost first:

import heapq
 
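# Initialise with the start position, keyed by its f_score
open_set = [(f_score[start], start)]
 
# Get the current value
_, current = heapq.heappop(open_set)
 
# Add a neighbour into the queue, prioritised by f_score (not g_score)
heapq.heappush(open_set, (f_score[neighbour], neighbour))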

With this approach, we don’t need to search open_set manually for the smallest entry, which saves a few calculations. Note that, unlike a set, a list can hold duplicate items, so the same position may end up in the queue more than once. We need to account for that, either by not pushing duplicates or by skipping entries that are already outdated when they are popped.
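For completeness, here is a minimal sketch of what the whole search could look like with heapq. It is one possible variant, not the implementation used in the rest of these notes. The function name find_shortest_path_heap, the 5x5 grid bounds and the example walls are assumptions made for the sketch. Duplicates are handled by storing the g_score in each heap entry and skipping entries that are worse than the best known score.

```python
import heapq
import math
from collections import defaultdict, namedtuple

Coordinate = namedtuple("Coordinate", ["x", "y"])

# Assumed inclusive grid bounds for this sketch.
MIN_X, MAX_X, MIN_Y, MAX_Y = 0, 4, 0, 4


def distance_to_target(a, b):
    return abs(a.x - b.x) + abs(a.y - b.y)


def get_valid_neighbours(current, walls):
    candidates = [
        Coordinate(current.x, current.y - 1),
        Coordinate(current.x, current.y + 1),
        Coordinate(current.x - 1, current.y),
        Coordinate(current.x + 1, current.y),
    ]
    return [
        c for c in candidates
        if c not in walls
        and MIN_X <= c.x <= MAX_X
        and MIN_Y <= c.y <= MAX_Y
    ]


def find_shortest_path_heap(start, target, walls):
    g_score = defaultdict(lambda: math.inf)
    g_score[start] = 0
    came_from = {}
    # Heap entries are (f_score, g_score, position); heapq orders by f_score.
    heap = [(distance_to_target(start, target), 0, start)]

    while heap:
        _, g, current = heapq.heappop(heap)
        if g > g_score[current]:
            continue  # stale duplicate: a cheaper route was found later
        if current == target:
            path = [current]
            while current in came_from:
                current = came_from[current]
                path.append(current)
            return path[::-1], g
        for neighbour in get_valid_neighbours(current, walls):
            tentative = g + 1
            if tentative < g_score[neighbour]:
                came_from[neighbour] = current
                g_score[neighbour] = tentative
                f = tentative + distance_to_target(neighbour, target)
                heapq.heappush(heap, (f, tentative, neighbour))
    return None


# A wall segment forces a detour around x == 1.
walls = {Coordinate(1, 0), Coordinate(1, 1), Coordinate(1, 2)}
result = find_shortest_path_heap(Coordinate(0, 0), Coordinate(2, 0), walls)

# The corner (4, 4) is sealed off, so no path exists.
sealed = {Coordinate(3, 4), Coordinate(4, 3)}
blocked = find_shortest_path_heap(Coordinate(0, 0), Coordinate(4, 4), sealed)
```

Skipping stale entries on pop keeps the code simple, at the price of a slightly larger heap than a version that avoids pushing duplicates in the first place.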

The rest of the example is implemented with open_set as a set.

Once we find the current position to look at, we check if we have reached the target. If we have, we reconstruct the path and return it alongside the score (the cost of reaching there which is the g_score of that position):

	if current == target:
		score = g_score[current]
		return (
			reconstruct_path(came_from, current),
			score,
		)

If we’re not yet at the target, we mark the current one as processed (by removing it from open_set) and continue by looking at all of the possible neighbours we can move into.

For each neighbour, we calculate its score. In this example, every step costs 1, so the cost is the cost of getting to the current position + 1. If your system has different costs depending on where you’re moving (like turning costing extra or each location having a different movement cost), you’d take it into account here.

open_set.remove(current)
 
for neighbour in get_valid_neighbours(current, walls):
	tentative_g_score = g_score[current] + 1
 
	if tentative_g_score < g_score[neighbour]:
		came_from[neighbour] = current
		g_score[neighbour] = tentative_g_score
		f_score[neighbour] = tentative_g_score + distance_to_target(
			neighbour, target
		)
		open_set.add(neighbour)

If this score is cheaper than what we previously calculated for the neighbour (either because this is the first time we run into it and the default is infinity, or because we reached it earlier by a more expensive route), we record this path as the winning one, update the cost, update the f_score (the estimated total cost of reaching the target through this neighbour) and add the neighbour to the collection of locations to check next.
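If steps don’t all cost 1, the only line that needs to change is the one computing tentative_g_score. A minimal sketch, assuming a hypothetical terrain_cost dictionary and a hypothetical step_cost helper (neither is part of the code above):

```python
from collections import namedtuple

Coordinate = namedtuple("Coordinate", ["x", "y"])

# Hypothetical terrain map: the cost of stepping onto a tile.
# Tiles that are not listed cost 1.
terrain_cost = {Coordinate(1, 0): 5}


def step_cost(current, neighbour):
    # Only the destination matters in this sketch; a rule such as
    # "turning costs extra" would also look at current.
    return terrain_cost.get(neighbour, 1)


current = Coordinate(0, 0)
g_current = 3  # pretend it cost 3 to reach current

# Replaces "g_score[current] + 1" in the main loop.
tentative_expensive = g_current + step_cost(current, Coordinate(1, 0))
tentative_plain = g_current + step_cost(current, Coordinate(0, 1))
```

As long as every step costs at least 1, the Manhattan distance still never overestimates the remaining cost, so the same heuristic can be kept.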

How you define the valid neighbours for a location is up to the specific task at hand. Here, in a 2D grid with walls and cardinal movement, it looks like this:

def get_valid_neighbours(
    current: Coordinate, walls: List[Coordinate]
) -> List[Coordinate]:
    neighbours = [
        Coordinate(current.x, current.y - 1),
        Coordinate(current.x, current.y + 1),
        Coordinate(current.x - 1, current.y),
        Coordinate(current.x + 1, current.y),
    ]
 
    return [
        neighbour
        for neighbour in neighbours
        if neighbour not in walls
        and neighbour.x >= MIN_X
        and neighbour.x <= MAX_X
        and neighbour.y >= MIN_Y
        and neighbour.y <= MAX_Y
    ]

We first create a list of the neighbours in all four cardinal directions and then filter out the walls and everything that would fall outside our grid. Sometimes the grid is surrounded by walls, in which case the MIN/MAX comparisons can be skipped.
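One thing to watch: with walls as a list, every "not in walls" check scans the whole list. Converting it to a set once, before the search, makes each check a hash lookup. The sketch below shows that, with assumed 3x3 grid bounds and a made-up wall; the chained comparisons are just a shorter way of writing the same bounds checks.

```python
from collections import namedtuple

Coordinate = namedtuple("Coordinate", ["x", "y"])

# Assumed inclusive bounds for a 3x3 grid.
MIN_X, MAX_X, MIN_Y, MAX_Y = 0, 2, 0, 2


def get_valid_neighbours(current, walls):
    neighbours = [
        Coordinate(current.x, current.y - 1),
        Coordinate(current.x, current.y + 1),
        Coordinate(current.x - 1, current.y),
        Coordinate(current.x + 1, current.y),
    ]
    return [
        n for n in neighbours
        if n not in walls
        and MIN_X <= n.x <= MAX_X
        and MIN_Y <= n.y <= MAX_Y
    ]


wall_list = [Coordinate(1, 0)]
wall_set = set(wall_list)  # convert once, outside the search loop

corner = get_valid_neighbours(Coordinate(0, 0), wall_set)
centre = get_valid_neighbours(Coordinate(1, 1), wall_set)
```

From the corner only one move is left, because two candidates fall outside the grid and one is a wall.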

To calculate the f_score, we need a heuristic function that estimates the cost to reach the goal. In our case of a 2D grid where we can move in the 4 cardinal directions, this cost can be estimated as the Manhattan (taxicab) distance:

def distance_to_target(start: Coordinate, target: Coordinate) -> int:
    return abs(start.x - target.x) + abs(start.y - target.y)

As long as this estimate never overestimates the real cost (this is called being admissible), A* is guaranteed to find the cheapest route to the target. Sometimes this heuristic function is easy to come up with; sometimes it can be tough to figure out a good one.
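One way to convince yourself that a heuristic is admissible on a particular grid is to compare it with the true distances, which a plain breadth-first search from the target can compute when every step costs 1. This is a sketch for checking small test grids, not something the search itself needs; the grid size, the walls and the true_distances helper are assumptions for the example.

```python
from collections import deque, namedtuple

Coordinate = namedtuple("Coordinate", ["x", "y"])

SIZE = 5  # assumed grid: x and y both range over 0..4


def distance_to_target(start, target):
    return abs(start.x - target.x) + abs(start.y - target.y)


def true_distances(target, walls):
    """Breadth-first search outward from the target, one unit per step."""
    dist = {target: 0}
    queue = deque([target])
    while queue:
        cur = queue.popleft()
        for dx, dy in ((0, -1), (0, 1), (-1, 0), (1, 0)):
            nxt = Coordinate(cur.x + dx, cur.y + dy)
            if (
                0 <= nxt.x < SIZE
                and 0 <= nxt.y < SIZE
                and nxt not in walls
                and nxt not in dist
            ):
                dist[nxt] = dist[cur] + 1
                queue.append(nxt)
    return dist


walls = {Coordinate(1, 0), Coordinate(1, 1), Coordinate(1, 2)}
target = Coordinate(2, 0)
dist = true_distances(target, walls)

# Admissible: the estimate is never larger than the real distance.
admissible = all(
    distance_to_target(pos, target) <= real for pos, real in dist.items()
)
```

On this grid the estimate from (0, 0) is 2 while the real distance is 8 because of the wall, so the heuristic underestimates, which is allowed.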

If we have processed every reachable location without finding a route to the target, we return None to tell the caller that no path was discovered.