Implementing the A* Pathfinding Algorithm in XNA

Most videogames with moving characters rely on pathfinding to move their on-screen avatars. Any time you tell a character to move to some distant spot, and that character must avoid obstacles along the way, an algorithm runs to determine the most efficient path from point A to point B. If you've ever played a real-time strategy game like StarCraft, a 4X game like Civilization, or a strategy RPG like Disgaea, you've seen pathfinding in action.

For my upcoming SRPG, Armadillo, I needed to implement a pathfinding algorithm for the battle scenes. When the player tells a character to move from point A to point B, that character needs to route around walls and enemies to get there. Characters walking through walls would not look good at all.

The A* Search Algorithm

One very common pathfinding algorithm is the A* search algorithm (pronounced "A-star"). It dates back to 1968 and is still widely used by game developers to find the best path between two points. A* is well suited to video games for a few reasons:

  • The A* algorithm operates on a generic node system. Rather than being tied to a specific layout, such as a square grid like a chess board or a hexagonal grid like the one in Chinese checkers, A* is only concerned with nodes, their neighbors, and how far away each neighbor is. This means the same algorithm can be used in many games, including ones with highly irregular node networks, such as Crusader Kings II.
  • A* is fast. It uses a best-first search: it estimates the distance remaining from each candidate node to the destination and explores the most promising routes first. This usually means it examines far fewer nodes than a breadth-first or depth-first search. A breadth-first search scans every node within a set distance, then gradually expands that distance, which wastes a lot of time searching in completely wrong directions. A depth-first search follows one path until it hits a dead end, which wastes time going down rabbit holes that could easily be discounted once it is clear they lead away from the destination; the first path it finds is also not necessarily the shortest.
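  • A* is complete, meaning it will always find a path if one exists. As long as the distance estimate never overestimates the true remaining distance, the path it returns is also the shortest one.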
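The first two bullets can be made concrete with a small, self-contained sketch. It is not the Grid code from this post: it is a hypothetical console-style C# example, and it assumes .NET 6 or later for PriorityQueue<TElement, TPriority>, which the XNA-era framework does not have. The A* method only ever asks for a node's neighbors, the cost to reach each one, and a distance estimate, so the node type can be anything. A plain breadth-first search sits alongside it so the two can be compared by counting how many nodes each one processes. The names SearchDemo, AStar, Bfs, OpenGridNeighbors, and Manhattan are all illustrative.

```csharp
using System;
using System.Collections.Generic;

// Illustrative sketch only (hypothetical names, .NET 6+ for PriorityQueue).
// The searches never look inside TNode: they only call 'neighbors' and 'heuristic'.
public static class SearchDemo
{
    // Best-first A*: always processes the open node with the lowest g + h.
    // Returns the cost of the best path (or -1 if there is none) and reports how
    // many nodes were processed. Assumes a consistent heuristic, so a node that
    // has been processed once never needs to be processed again.
    public static int AStar<TNode>(
        TNode start, TNode goal,
        Func<TNode, IEnumerable<(TNode Node, int Cost)>> neighbors,
        Func<TNode, int> heuristic,
        out int expanded) where TNode : notnull
    {
        var g = new Dictionary<TNode, int> { [start] = 0 }; // best known cost from start
        var closed = new HashSet<TNode>();
        var open = new PriorityQueue<TNode, int>();          // priority = g + h
        open.Enqueue(start, heuristic(start));
        expanded = 0;

        while (open.TryDequeue(out var current, out _))
        {
            if (!closed.Add(current)) continue; // stale entry for an already processed node
            expanded++;
            if (EqualityComparer<TNode>.Default.Equals(current, goal)) return g[current];

            foreach (var (next, cost) in neighbors(current))
            {
                int tentative = g[current] + cost;
                if (g.TryGetValue(next, out int known) && tentative >= known) continue;
                g[next] = tentative;
                open.Enqueue(next, tentative + heuristic(next));
            }
        }
        return -1;
    }

    // Breadth-first search for comparison: processes nodes in order of step count,
    // ignoring both step costs and where the goal is. Returns the number of steps.
    public static int Bfs<TNode>(
        TNode start, TNode goal,
        Func<TNode, IEnumerable<(TNode Node, int Cost)>> neighbors,
        out int expanded) where TNode : notnull
    {
        var steps = new Dictionary<TNode, int> { [start] = 0 };
        var queue = new Queue<TNode>();
        queue.Enqueue(start);
        expanded = 0;

        while (queue.Count > 0)
        {
            var current = queue.Dequeue();
            expanded++;
            if (EqualityComparer<TNode>.Default.Equals(current, goal)) return steps[current];

            foreach (var (next, _) in neighbors(current))
            {
                if (steps.ContainsKey(next)) continue;
                steps[next] = steps[current] + 1;
                queue.Enqueue(next);
            }
        }
        return -1;
    }

    // Up, right, down, left neighbors inside an obstacle-free width x height grid.
    public static IEnumerable<((int X, int Y) Node, int Cost)> OpenGridNeighbors(
        (int X, int Y) p, int width, int height)
    {
        foreach (var (dx, dy) in new[] { (0, -1), (1, 0), (0, 1), (-1, 0) })
        {
            int x = p.X + dx, y = p.Y + dy;
            if (x >= 0 && x < width && y >= 0 && y < height) yield return ((x, y), 1);
        }
    }

    // x-distance + y-distance, the same estimate the Pathfind method in this post uses.
    public static int Manhattan((int X, int Y) a, (int X, int Y) b) =>
        Math.Abs(a.X - b.X) + Math.Abs(a.Y - b.Y);
}
```

On an empty 20×20 grid, searching from (0, 10) to (19, 10), the A* version processes only the 20 cells along that row, because every cell off the row has a higher predicted total, while the breadth-first version works through well over 200 cells before reaching the destination. Passing string node names with an adjacency table and a heuristic of 0 works just as well, which is the generic-node point from the first bullet.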

Setting Up the Node Graph

To use the A* search algorithm, the first step is to create a graph of nodes to be analyzed. For a basic A* implementation, I will just use a square grid, where each cell is a node and its neighbors are the cells directly above, below, left, and right of it. I'm going to use the same Grid object that I discussed in my last post on sandbagging:

public struct Grid
{
    public Rectangle Size;
    
    public byte[,] Weight;

    public Grid(int x, int y, byte defaultValue = 0)
    {
        Size = new Rectangle(0, 0, x, y);
        Weight = new byte[x, y];

        for(var i = 0; i < x; i++)
        {
            for (var j = 0; j < y; j++)
            {
                Weight[i, j] = defaultValue;
            }
        }
    }
}

This is a lightweight object that contains a rectangular grid of square cells. Each cell has a weight indicating whether or not it is accessible: a value of 0 marks a cell as inaccessible, and any non-zero value marks it as accessible. Since defaultValue is 0 unless specified, a freshly constructed grid is entirely inaccessible until cells are marked otherwise. For Armadillo, I generate a grid from the current battleboard, taking into account walls, obstacles, and enemies. This grid represents where a specified character is able to move.
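As an illustration of how such a grid might be filled in, here is a hypothetical helper (MapLoader.ParseMap is not part of the Grid code) that turns an ASCII battleboard into a weight array laid out the same way as Grid.Weight, indexed [x, y]. It assumes '#' marks a wall and 'E' an enemy, both impassable (weight 0), and treats every other character as passable (weight 1). It uses a plain byte[,] rather than XNA types so it runs on its own; the array has the same shape and meaning as Grid.Weight.

```csharp
using System;

// Hypothetical helper: converts rows of an ASCII map into a weight array indexed
// [x, y], like Grid.Weight. '#' (wall) and 'E' (enemy) become 0 (inaccessible);
// every other character becomes 1 (accessible).
public static class MapLoader
{
    public static byte[,] ParseMap(string[] rows)
    {
        int height = rows.Length;
        int width = rows[0].Length;
        var weight = new byte[width, height];

        for (int y = 0; y < height; y++)
        {
            if (rows[y].Length != width)
            {
                throw new ArgumentException("all rows must have the same length");
            }

            for (int x = 0; x < width; x++)
            {
                char c = rows[y][x];
                weight[x, y] = (byte)(c == '#' || c == 'E' ? 0 : 1);
            }
        }

        return weight;
    }
}
```

For example, ParseMap(new[] { "#####", "#..E#", "#####" }) yields a 5×3 array in which only cells (1, 1) and (2, 1) are accessible. Surrounding the board with walls like this also guarantees that every accessible cell's four neighbors lie inside the array.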

Implementing A*

Now that I have a node graph to analyze, I can implement the A* search algorithm on it. The algorithm will be a method within the Grid object that takes two parameters: start and end. This keeps it extremely generic; it is only concerned with where the path starts, where it ends, and which cells it can move through.

The A* search algorithm is fairly straightforward. It takes a best-first approach, always examining the most promising node next, one node at a time. For each neighbor of the current node, it calculates the distance from the start to that neighbor through the current node, then estimates the distance from that neighbor to the end. Adding the two gives a predicted length for a complete path through that neighbor. This heuristic estimate is what makes A* really shine for games, since it is what drives the best-first approach: by estimating the remaining distance, A* can prioritize the paths that look best. Once it has fully processed a node, it moves on to the unprocessed node with the lowest predicted path length.
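To make the scoring concrete, here is a small worked example using the same quantities as the code: g(n) is what the code stores in currentDistance, h(n) is the x-distance plus y-distance estimate, and their sum f(n) is predictedDistance. The coordinates are made up for illustration: the start is (0, 0), the end e is (4, 0), and the current node c is (1, 0), one step from the start.

```latex
\begin{aligned}
f(n) &= g(n) + h(n), \qquad h(n) = |n_x - e_x| + |n_y - e_y| \\[4pt]
e &= (4, 0), \quad c = (1, 0), \quad g(c) = 1 \\
n_1 = (2, 0):\quad f(n_1) &= (1 + 1) + (|2 - 4| + |0 - 0|) = 2 + 2 = 4 \\
n_2 = (1, 1):\quad f(n_2) &= (1 + 1) + (|1 - 4| + |1 - 0|) = 2 + 4 = 6
\end{aligned}
```

The other two directions from c lead off the grid or back to the start. Both remaining neighbors are one step further from the start, but (2, 0) heads toward the end, so its predicted total is lower and A* processes it first.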

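Once the end node is reached, the function walks backwards from the end to the start, following the record of which node each node was reached from. Every explored node that is not on that chain is discarded, and the result is a list of Point objects containing the best path from start to finish.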
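The back-tracking step can also be written as a loop instead of the recursion used in the full listing, which keeps stack depth constant no matter how long the path is. This is a minimal sketch under stated assumptions: PathUtil is a hypothetical name, and value tuples stand in for XNA's Point so the snippet compiles without XNA. Like the original, it relies on the start node being the only node on the path with no cameFrom entry.

```csharp
using System.Collections.Generic;

// Hypothetical iterative version of the path reconstruction step.
public static class PathUtil
{
    // Follow cameFrom links backwards from the end node until reaching a node with
    // no recorded origin (the start), then reverse so the path runs start -> end.
    // Explored nodes that are not on this chain are simply never visited.
    public static List<(int X, int Y)> ReconstructPath(
        Dictionary<(int X, int Y), (int X, int Y)> cameFrom, (int X, int Y) end)
    {
        var path = new List<(int X, int Y)> { end };
        var current = end;

        while (cameFrom.TryGetValue(current, out var previous))
        {
            path.Add(previous);
            current = previous;
        }

        path.Reverse();
        return path;
    }
}
```

Given links (1, 0) → (0, 0), (2, 0) → (1, 0), (2, 1) → (2, 0), plus an unrelated explored node (5, 5) → (4, 5), reconstructing from (2, 1) returns (0, 0), (1, 0), (2, 0), (2, 1); the unrelated entry is ignored.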

Here is the full code for the A* algorithm in XNA:

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Xna.Framework;

public struct Grid
{
    public Rectangle Size;
    
    public byte[,] Weight;
 
    public Grid(int x, int y, byte defaultValue = 0)
    {
        Size = new Rectangle(0, 0, x, y);
        Weight = new byte[x, y];
 
        for(var i = 0; i < x; i++)
        {
            for (var j = 0; j < y; j++)
            {
                Weight[i, j] = defaultValue;
            }
        }
    }

    public List<Point> Pathfind(Point start, Point end)
    {
        // nodes that have already been analyzed and have a path from the start to them
        var closedSet = new HashSet<Point>();
        // nodes that have been identified as a neighbor of an analyzed node, but have 
        // yet to be fully analyzed
        var openSet = new List<Point> { start };
        // a dictionary identifying the optimal origin point to each node. this is used 
        // to back-track from the end to find the optimal path
        var cameFrom = new Dictionary<Point, Point>();
        // a dictionary indicating how far each analyzed node is from the start
        var currentDistance = new Dictionary<Point, int>();
        // a dictionary indicating how far it is expected to reach the end, if the path 
        // travels through the specified node. 
        var predictedDistance = new Dictionary<Point, float>();
    
        // initialize the start node as having a distance of 0, and an estimated distance 
        // of y-distance + x-distance, which is the length of the shortest possible path 
        // on an open square grid that doesn't allow for diagonal movement
        currentDistance.Add(start, 0);
        predictedDistance.Add(
            start, 
            Math.Abs(start.X - end.X) + Math.Abs(start.Y - end.Y)
        );
    
        // if there are any unanalyzed nodes, process them
        while (openSet.Count > 0)
        {
            // get the node with the lowest estimated cost to finish
            var current = (
                from p in openSet orderby predictedDistance[p] ascending select p
            ).First();
    
            // if it is the finish, return the path
            if (current.X == end.X && current.Y == end.Y)
            {
                // generate the found path
                return ReconstructPath(cameFrom, end);
            }
    
            // move current node from open to closed
            openSet.Remove(current);
            closedSet.Add(current);
    
            // process each valid node around the current node
            foreach (var neighbor in GetNeighborNodes(current))
            {
                // every step costs 1; a cell's weight only decides whether it is passable
                var tempCurrentDistance = currentDistance[current] + 1;
    
                // if we already know a faster way to this neighbor, use that route and 
                // ignore this one
                if (closedSet.Contains(neighbor) 
                    && tempCurrentDistance >= currentDistance[neighbor])
                {
                    continue;
                }
    
                // if this neighbor is not waiting in the open set yet, or if this route 
                // is faster than the one already recorded for it, store this route
                if (!openSet.Contains(neighbor) 
                    || tempCurrentDistance < currentDistance[neighbor])
                {
                    // the indexer adds the entry if it is missing, or overwrites it
                    cameFrom[neighbor] = current;
    
                    currentDistance[neighbor] = tempCurrentDistance;
                    predictedDistance[neighbor] = 
                        currentDistance[neighbor] 
                        + Math.Abs(neighbor.X - end.X) 
                        + Math.Abs(neighbor.Y - end.Y);
    
                    // if this is a new node, add it to processing
                    if (!openSet.Contains(neighbor))
                    {
                        openSet.Add(neighbor);
                    }
                }
            }
        }
    
        // unable to figure out a path, abort.
        throw new Exception(
            string.Format(
                "unable to find a path between {0},{1} and {2},{3}", 
                start.X, start.Y, 
                end.X, end.Y
            )
        );
    }
    
    /// <summary>
    /// Return a list of accessible nodes neighboring a specified node
    /// </summary>
    /// <param name="node">The center node to be analyzed.</param>
    /// <returns>A list of nodes neighboring the node that are accessible.</returns>
    private IEnumerable<Point> GetNeighborNodes(Point node)
    {
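        var nodes = new List<Point>();
    
        // up
        if (IsAccessible(node.X, node.Y - 1))
        {
            nodes.Add(new Point(node.X, node.Y - 1));
        }
    
        // right
        if (IsAccessible(node.X + 1, node.Y))
        {
            nodes.Add(new Point(node.X + 1, node.Y));
        }
    
        // down
        if (IsAccessible(node.X, node.Y + 1))
        {
            nodes.Add(new Point(node.X, node.Y + 1));
        }
    
        // left
        if (IsAccessible(node.X - 1, node.Y))
        {
            nodes.Add(new Point(node.X - 1, node.Y));
        }
    
        return nodes;
    }
    
    /// <summary>
    /// Return true if the specified cell lies inside the grid and has a non-zero weight.
    /// Checking the bounds first keeps cells on the edge of the grid from indexing 
    /// outside the Weight array.
    /// </summary>
    private bool IsAccessible(int x, int y)
    {
        return x >= 0 && x < Weight.GetLength(0)
            && y >= 0 && y < Weight.GetLength(1)
            && Weight[x, y] > 0;
    }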
    
    /// <summary>
    /// Follow the origin links recorded by the Pathfind function back to the start 
    /// and return the path that ends at current.
    /// </summary>
    /// <param name="cameFrom">A dictionary mapping each node to the node it was reached from.</param>
    /// <param name="current">The destination node being sought out.</param>
    /// <returns>The shortest path from the start to the destination node.</returns>
    private List<Point> ReconstructPath(Dictionary<Point, Point> cameFrom, Point current)
    {
        if (!cameFrom.ContainsKey(current))
        {
            return new List<Point> { current };
        }
    
        var path = ReconstructPath(cameFrom, cameFrom[current]);
        path.Add(current);
        return path;
    }
}