代码随想录学习 54day 图论 A star算法

avatar
作者
猴君
阅读量:1

A * 算法精讲 (A star算法)

卡码网:126. 骑士的攻击  题目描述  在象棋中,马和象的移动规则分别是“马走日”和“象走田”。现给定骑士的起始坐标和目标坐标,要求根据骑士的移动规则,计算从起点到达目标点所需的最短步数。  棋盘大小 1000 x 1000(棋盘的 x 和 y 坐标均在 [1, 1000] 区间内,包含边界)  输入描述  第一行包含一个整数 n,表示测试用例的数量。  接下来的 n 行,每行包含四个整数 a1, a2, b1, b2,分别表示骑士的起始位置 (a1, a2) 和目标位置 (b1, b2)。  输出描述  输出共 n 行,每行输出一个整数,表示骑士从起点到目标点的最短路径长度。  输入示例  6 5 2 5 4 1 1 2 2 1 1 8 8 1 1 8 7 2 1 3 3 4 6 4 6 输出示例  2 4 6 5 1 0 

思路

我们看到这道题目的第一个想法就是广搜,这也是最经典的广搜类型题目。  这里我直接给出广搜的C++代码: 

code c++ 超时

#include<iostream> #include<queue> #include<string.h> using namespace std; int moves[1001][1001]; int dir[8][2]={-2,-1,-2,1,-1,2,1,2,2,1,2,-1,1,-2,-1,-2}; void bfs(int a1,int a2, int b1, int b2) { 	queue<int> q; 	q.push(a1); 	q.push(a2); 	while(!q.empty()) 	{ 		int m=q.front(); q.pop(); 		int n=q.front(); q.pop(); 		if(m == b1 && n == b2) 		break; 		for(int i=0;i<8;i++) 		{ 			int mm=m + dir[i][0]; 			int nn=n + dir[i][1]; 			if(mm < 1 || mm > 1000 || nn < 1 || nn > 1000) 			continue; 			if(!moves[mm][nn]) 			{ 				moves[mm][nn]=moves[m][n]+1; 				q.push(mm); 				q.push(nn); 			} 		} 	} }  int main() {     int n, a1, a2, b1, b2;     cin >> n;     while (n--) {         cin >> a1 >> a2 >> b1 >> b2;         memset(moves,0,sizeof(moves)); 		bfs(a1, a2, b1, b2); 		cout << moves[b1][b2] << endl; 	} 	return 0; } 

code python 1

from collections import deque def bfs(points):     start = points[:2]     end = points[-2:]     if start == end:return 0     grid = [[0  for _ in range(1001)] for _ in range(1001)]     visited = [[False for _ in range(1001)] for _ in range(1001)]     que = deque([start])     visited[start[0]][start[1]] = True      dir = [[-1, -2],[ -2, -1], [-2, 1], [-1, 2], [1, 2], [2, 1],[ 2, -1], [1, -2]]     counts = 0     while que:         counts += 1         lens = len(que)         for _ in range(lens):             point = que.popleft()             # print('\n', point)             for i in range(len(dir)):                 nextx = point[0] + dir[i][0]                 nexty = point[1] + dir[i][1]                 if visited[nextx][nexty] or nextx < 1 or nexty < 1 or nextx > 1000 or nexty > 1000:continue                 else:                     visited[nextx][nexty] = True                     grid[nextx][nexty] = counts                     # print(nextx, nexty, counts)                 if nextx == end[0] and nexty == end[1]:                     return counts                 else:                     que.append([nextx, nexty])    from collections import deque print('xxxxxxxx') memsets = [[5, 2, 5, 4], [1, 1, 2, 2], [1, 1, 8, 8], [1, 1, 8, 7], [2, 1, 3, 3], [4, 6, 4, 6]]  for points in memsets:     res = bfs(points)     print(points, res) 

code python 2

from collections import deque def bfs(points):     start = points[:2]     end = points[-2:]     if start == end:return 0     grid = [[-1  for _ in range(1001)] for _ in range(1001)]     que = deque([start])     grid[start[0]][start[1]] = 0      dir = [[-1, -2],[ -2, -1], [-2, 1], [-1, 2], [1, 2], [2, 1],[ 2, -1], [1, -2]]     counts = 0     while que:         counts += 1         lens = len(que)         for _ in range(lens):             point = que.popleft()             # print('\n', point)             for i in range(len(dir)):                 nextx = point[0] + dir[i][0]                 nexty = point[1] + dir[i][1]                 if grid[nextx][nexty]!=-1 or nextx < 1 or nexty < 1 or nextx > 1000 or nexty > 1000:continue                 else:                     grid[nextx][nexty] = counts                     # print(nextx, nexty, counts)                 if nextx == end[0] and nexty == end[1]:                     return counts                 else:                     que.append([nextx, nexty])    from collections import deque print('xxxxxxxx') memsets = [[5, 2, 5, 4], [1, 1, 2, 2], [1, 1, 8, 8], [1, 1, 8, 7], [2, 1, 3, 3], [4, 6, 4, 6]]  for points in memsets:     res = bfs(points)     print(points, res) 

code python 3

## 20240717 code from collections import deque  def bfs(points, grid):     start = points[:2]     end = points[-2:]     # grid = [[-1 for _ in range(1001)] for _ in range(1001)]     grid[start[0]][start[1]] = 0     offsets = [[-2, -1], [-2, 1], [2, -1], [2, 1], [1, -2], [1, 2], [-1, -2], [-1, 2]]  # 8个可以移动的offset     ## 初始化, 进入队列     que = deque([start])     counts = 0  # 统计走了几步到达目标点     while que:         node = que.popleft()  # 取出元素         counts += 1         if node[0] == end[0]  and node[1] == end[1]:             print(f'counts"{counts}')             break  # 找到终点了  因为是先赋值, 后判断         for i in range(8):    # 8个可以移动的offset             nextx = node[0] + offsets[i][0]             nexty = node[1] + offsets[i][1]              # 索引的合法性进行判断             if nextx < 1 or nexty < 1 or nextx > 1000 or nexty > 1000: continue             # 如果grid没有被访问过             if grid[nextx][nexty] == -1:                 # 数值在node 的 基础上 + 1                 grid[nextx][nexty] = grid[node[0]][node[1]] + 1                 # 该位置放入 que  广度搜索                 que.append([nextx, nexty])   memsets = [[5, 2, 5, 4], [1, 1, 2, 2], [1, 1, 8, 8], [1, 1, 8, 7], [2, 1, 3, 3], [4, 6, 4, 6]] for points in memsets:     grid = [[-1 for _ in range(1001)] for _ in range(1001)]     bfs(points, grid)     print(points, grid[points[-2]][points[-1]]) # 终点的值  # 但是这个代码的计算量很大,  """ counts"12 [5, 2, 5, 4] 2 counts"35 [1, 1, 2, 2] 4 counts"93 [1, 1, 8, 8] 6 counts"78 [1, 1, 8, 7] 5 counts"3 [2, 1, 3, 3] 1 counts"1 [4, 6, 4, 6] 0 """ 
提交后,大家会发现,超时了。  因为本题地图足够大,且 n 也有可能很大,导致有非常多的查询。  我们来看一下广搜的搜索过程,如图,红色是起点,绿色是终点,黄色是要遍历的点,最后从 起点 找到 达到终点的最短路径是棕色。  可以看出 广搜中,做了很多无用的遍历, 黄色的格子是广搜遍历到的点。  这里我们能不能让遍历方向,向这终点的方向去遍历呢?  这样我们就可以避免很多无用遍历。 

Astar

Astar 是一种 广搜的改良版。 有的是 Astar 是 dijkstra 的改良版。  其实只是场景不同而已 我们在搜索最短路的时候, 如果是无权图(边的权值都是1) 那就用广搜,代码简洁,时间效率和 dijkstra 差不多 (具体要取决于图的稠密)  如果是有权图(边有不同的权值),优先考虑 dijkstra。  而 Astar 关键在于 启发式函数, 也就是 影响 广搜或者 dijkstra 从 容器(队列)里取元素的优先顺序。  以下,我用BFS版本的A * 来进行讲解。  在BFS中,我们想搜索,从起点到终点的最短路径,要一层一层去遍历。  如果使用 A* 的话,其搜索过程是这样的,如图,图中着色的都是我们要遍历的点。  (上面两图中 最短路长度都是8,只是走的方式不同而已)  大家可以发现 BFS 是没有目的性的 一圈一圈去搜索, 而 A * 是有方向性的去搜索。  看出 A * 可以节省很多没有必要的遍历步骤。  为了让大家可以明显看到区别,我将 BFS 和 A * 制作成可视化动图,大家可以自己看看动图,效果更好。  地址:https://kamacoder.com/tools/knight.html  那么 A * 为什么可以有方向性的去搜索,它是如何知道方向呢?  其关键在于 启发式函数。  那么 启发式函数 落实到代码处,如果指引搜索的方向?  在本篇开篇中给出了BFS代码,指引 搜索的方向的关键代码在这里:  int m=q.front();q.pop();  int n=q.front();q.pop();  从队列里取出什么元素,接下来就是从哪里开始搜索。  所以 启发式函数 要影响的就是队列里元素的排序!  这是影响BFS搜索方向的关键。  对队列里节点进行排序,就需要给每一个节点权值,如何计算权值呢?  每个节点的权值为F,给出公式为:F = G + H  G:起点达到目前遍历节点的距离  F:目前遍历的节点到达终点的距离  起点达到目前遍历节点的距离 + 目前遍历的节点到达终点的距离 就是起点到达终点的距离。  本题的图是无权网格状,在计算两点距离通常有如下三种计算方式:  曼哈顿距离,计算方式: d = abs(x1-x2)+abs(y1-y2)  欧氏距离(欧拉距离) ,计算方式:d = sqrt( (x1-x2)^2 + (y1-y2)^2 )  切比雪夫距离,计算方式:d = max(abs(x1 - x2), abs(y1 - y2))  x1, x2 为起点坐标,y1, y2 为终点坐标 ,abs 为求绝对值,sqrt 为求开根号,  选择哪一种距离计算方式 也会导致 A * 算法的结果不同。  本题,采用欧拉距离才能最大程度体现 点与点之间的距离。  所以 使用欧拉距离计算 和 广搜搜出来的最短路的节点数是一样的。 (路径可能不同,但路径上的节点数是相同的)  我在制作动画演示的过程中,分别给出了曼哈顿、欧拉以及契比雪夫 三种计算方式下,A * 算法的寻路过程,大家可以自己看看看其区别。  动画地址:https://kamacoder.com/tools/knight.html  计算出来 F 之后,按照 F 的 大小,来选去出队列的节点。  可以使用 优先级队列 帮我们排好序,每次出队列,就是F最小的节点。  实现代码如下:(启发式函数 采用 欧拉距离计算方式) 

code c++

#include<iostream> #include<queue> #include<string.h> using namespace std; int moves[1001][1001]; int dir[8][2]={-2,-1,-2,1,-1,2,1,2,2,1,2,-1,1,-2,-1,-2}; int b1, b2; // F = G + H // G = 从起点到该节点路径消耗 // H = 该节点到终点的预估消耗  struct Knight{     int x,y;     int g,h,f;     bool operator < (const Knight & k) const{  // 重载运算符, 从小到大排序      return k.f < f;     } };  priority_queue<Knight> que;  int Heuristic(const Knight& k) { // 欧拉距离     return (k.x - b1) * (k.x - b1) + (k.y - b2) * (k.y - b2); // 统一不开根号,这样可以提高精度 } void astar(const Knight& k) {     Knight cur, next; 	que.push(k); 	while(!que.empty()) 	{ 		cur=que.top(); que.pop(); 		if(cur.x == b1 && cur.y == b2) 		break; 		for(int i = 0; i < 8; i++) 		{ 			next.x = cur.x + dir[i][0]; 			next.y = cur.y + dir[i][1]; 			if(next.x < 1 || next.x > 1000 || next.y < 1 || next.y > 1000) 			continue; 			if(!moves[next.x][next.y]) 			{ 				moves[next.x][next.y] = moves[cur.x][cur.y] + 1;                  // 开始计算F 				next.g = cur.g + 5; // 统一不开根号,这样可以提高精度,马走日,1 * 1 + 2 * 2 = 5                 next.h = Heuristic(next);                 next.f = next.g + next.h;                 que.push(next); 			} 		} 	} }  int main() {     int n, a1, a2;     cin >> n;     while (n--) {         cin >> a1 >> a2 >> b1 >> b2;         memset(moves,0,sizeof(moves));         Knight start;         start.x = a1;         start.y = a2;         start.g = 0;         start.h = Heuristic(start);         start.f = start.g + start.h; 		astar(start);         while(!que.empty()) que.pop(); // 队列清空 		cout << moves[b1][b2] << endl; 	} 	return 0; } 

复杂度分析

A * 算法的时间复杂度 其实是不好去量化的,因为他取决于 启发式函数怎么写。  最坏情况下,A * 退化成广搜,算法的时间复杂度 是 O(n * 2),n 为节点数量。  最佳情况,是从起点直接到终点,时间复杂度为 O(dlogd),d 为起点到终点的深度。  因为在搜索的过程中也需要堆排序,所以是 O(dlogd)。  实际上 A * 的时间复杂度是介于 最优 和最坏 情况之间, 可以 非常粗略的认为 A * 算法的时间复杂度是 O(nlogn) ,n 为节点数量。  A * 算法的空间复杂度 O(b ^ d) ,d 为起点到终点的深度,b 是 图中节点间的连接数量,本题因为是无权网格图,所以 节点间连接数量为 4

拓展

如果本题大家使用 曼哈顿距离 或者 切比雪夫距离 计算的话,可以提交试一试,有的最短路结果是并不是最短的。  原因也是 曼哈顿 和 切比雪夫这两种计算方式在 本题的网格地图中,都没有体现出点到点的真正距离!  可能有些录友找到类似的题目,例如 poj 2243 (opens new window),使用 曼哈顿距离 提交也过了, 那是因为题目中的地图太小了,仅仅是一张 8 * 8的地图,根本看不出来 不同启发式函数写法的区别。  A * 算法 并不是一个明确的最短路算法,A * 算法搜的路径如何,完全取决于 启发式函数怎么写。  A * 算法并不能保证一定是最短路,因为在设计 启发式函数的时候,要考虑 时间效率与准确度之间的一个权衡。  虽然本题中,A * 算法得到是最短路,也是因为本题 启发式函数 和 地图结构都是最简单的。  例如在游戏中,在地图很大、不同路径权值不同、有障碍 且多个游戏单位在地图中寻路的情况,如果要计算准确最短路,耗时很大,会给玩家一种卡顿的感觉。  而真实玩家在玩游戏的时候,并不要求一定是最短路,次短路也是可以的 (玩家不一定能感受出来,及时感受出来也不是很在意),只要奔着目标走过去 大体就可以接受。  所以 在游戏开发设计中,保证运行效率的情况下,A * 算法中的启发式函数 设计往往不是最短路,而是接近最短路的 次短路设计。  大家如果玩 LOL,或者 王者荣耀 可以回忆一下:如果 从很远的地方点击 让英雄直接跑过去 是 跑的路径是不靠谱的,所以玩家们才会在 距离英雄尽可能近的位置去点击 让英雄跑过去。 

A * 的缺点

大家看上述 A * 代码的时候,可以看到 我们想 队列里添加了很多节点,但真正从队列里取出来的 仅仅是 靠启发式函数判断 距离终点最近的节点。  相对了 普通BFS,A * 算法只从 队列里取出 距离终点最近的节点。  那么问题来了,A * 在一次路径搜索中,大量不需要访问的节点都在队列里,会造成空间的过度消耗。  IDA * 算法 对这一空间增长问题进行了优化,关于 IDA * 算法,本篇不再做讲解,感兴趣的录友可以自行找资料学习。  另外还有一种场景 是 A * 解决不了的。  如果题目中,给出 多个可能的目标,然后在这多个目标中 选择最近的目标,这种 A * 就不擅长了, A * 只擅长给出明确的目标 然后找到最短路径。  如果是多个目标找最近目标(特别是潜在目标数量很多的时候),可以考虑 Dijkstra ,BFS 或者 Floyd。  

python 双端队列 例子1

class Knight:     def __init__(self,k, point):         self.k = k         self.point = point     def __lt__(self, other):         return self.k < other.k  from queue import PriorityQueue c1 = Knight(5, (2,4)) c2 = Knight(10, (3,7)) c3 = Knight(2, (2,1))  que = PriorityQueue() que.put(c1) que.put(c2) que.put(c3) while que:     c = que.get()     print(c.k, c.point) 

python 双端队列 例子2

## 20240716 class Knight:     def __init__(self,point, k):         self.k = k         self.point = point     def __lt__(self, other):         return self.k < other.k  from queue import PriorityQueue c1 = Knight(5, (2,4)) c2 = Knight(10, (3,7)) c3 = Knight(2, (2,1))  que = PriorityQueue() que.put(c1) que.put(c2) que.put(c3) while que:     c = que.get()     print(c.k, c.point)  # 两个例子一样的效果 

python 双端队列 例子2

## 20240717  from queue import PriorityQueue  class Knight:     def __init__(self, x, y, g, h, f):         self.x = x         self.y = y         self.g = g         self.h = h         self.f = f      def __lt__(self, other):         return self.f < other.f  # 重载运算符, 从小到大排序  def Heuristic(mid, end):    # 欧拉距离     k = Knight(mid[0], mid[1], 0,0,0)     b1, b2 = end     return (k.x - b1) * (k.x-b1) + (k.y - b2) * (k.y - b2)  # 统一不开根号,这样可以提高精度  c1 = Knight(1, 2, 3, 4, 5) c2 = Knight(6, 7, 83, 2, 30) c3 = Knight(4, 2, 3, 5, 20) c4 = Knight(5, 2, 3, 6, 5) c5 = Knight(6, 2, 3, 7, 10) c6 = Knight(7, 2, 3, 8, 8) que = PriorityQueue()  for v in [f'c{i}' for i in range(1,7)]:     # print(eval(v).f)     que.put(eval(v))  while not que.empty():     node = que.get()     print(node.x, node.y, node.g, node.h, node.f)     mid = [node.x, node.y]     end = [7, 7]     print('欧式距离',Heuristic(mid, end))  

Python code 优先级队列

## 20240717  from queue import PriorityQueue  que = PriorityQueue()   class Knight:     def __init__(self, x=0, y=0, g=0, h=0, f=0):         self.x = x         self.y = y         self.g = g         self.h = h         self.f = f      def __lt__(self, other):         return self.f < other.f  # 重载运算符, 从小到大排序   def Heuristic(k, end):  # 欧拉距离  这个公式是计算 mid 到达终点的距离, 不适合 起点 到达 mid.     return (k.x - end.x) * (k.x - end.x) + (k.y - end.y) * (k.y - end.y)  # 统一不开根号,这样可以提高精度   def Astartbfs(grid, start, end):     """     point: Knight的结构     grid: 统计走过的点需要的步数     """     offsets = [[-2, -1], [-2, 1], [2, -1], [2, 1], [1, -2], [1, 2], [-1, -2], [-1, 2]]  # 8个可以移动的offset     que.put(start)     counts = 0     while not que.empty():         counts += 1         cur = que.get()         # print(cur.x, cur.y, cur.g, cur.h, cur.f, grid[cur.x][cur.y], 'cur=========------------------------------------========')         if cur.x == end.x and cur.y == end.y:  # 找到了终点位置              print(f'到达终点, counts: {counts}')             break         for i in range(8):             next = Knight()             next.x = cur.x + offsets[i][0]             next.y = cur.y + offsets[i][1]             # 边界             if next.x < 1 or next.x > 1000 or next.y < 1 or next.y > 1000: continue             if grid[next.x][next.y] == -1:                 grid[next.x][next.y] = grid[cur.x][cur.y] + 1                  # 计算距离                 next.g = cur.g + 5   # 每走一步的距离一定与上一步 + 5, 而不是和原点  保证一个距离在同一个水平, 走一步之后都是在5的基础上 + (next, end)的距离                 # print(start.x,start.y, next.x,next.y,'----------------------------------------')                 next.h = Heuristic(next, end)                 next.f = next.g + next.h                 # print(next.x, next.y, next.g, next.h, next.f, grid[next.x][next.y], 'next=================')                 que.put(next)   def main():     memsets = [[5, 2, 5, 4], [1, 1, 2, 2], [1, 1, 8, 8], [1, 1, 8, 7], [2, 1, 3, 3], [4, 6, 4, 6]]     for points in memsets[:]:         grid = [[-1 for _ in range(1001)] for _ in range(1001)]         grid[points[0]][points[1]] = 0  # 起点的初始化          start = Knight(points[0], points[1])  # 初始化, x, y, g, h, f         end = Knight(points[-2], points[-1])         start.g = 0         start.h = Heuristic(start, end)         start.f = start.g + start.h         Astartbfs(grid, start, end)         print('输出的结果', grid[points[-2]][points[-1]])         print('\n')         while not que.empty():             que.get()  # 如果优先级队列非空, 则清空该队列 main() 

code 的步数统计

到达终点, counts: 6 输出的结果 2 到达终点, counts: 18 输出的结果 4 到达终点, counts: 16 输出的结果 6 到达终点, counts: 10 输出的结果 5 到达终点, counts: 2 输出的结果 1 到达终点, counts: 1 输出的结果 0  明显小于未优化之前的步数  

python code3

# 把昨天写的小垃圾修正一下 昨天主要的错误是: k1 = grid[nextx][nexty]  * 5 k2 = Heuristic([nextx, nexty], end) que.put(Knight(k1 + k2, [nextx, nexty]))  昨天写成了 k1 = Heuristic(start, [nextx, nexty]) + 5 k2 = Heuristic([nextx, nexty], end) que.put(Knight(k1 + k2, [nextx, nexty]))  没有理解 k1 的作用,k1 最大的作用是 走了 1步则k1=0+5, 走了2步, k1 = 0 + 5 + 5, 每多走一步, 一定是在上一步的k1的基础上递增。而不是计算start和当前步的欧式距离 ,欧式距离只适用于: 当前步到达终点位置的计算。  
from queue import PriorityQueue class Knight:     def __init__(self,k, point):         self.k = k         self.point = point     def __lt__(self, other):         return self.k < other.k  def Heuristic(point1, point2):     return (point1[0] - point2[0]) * (point1[0] - point2[0]) + (point1[1] - point2[1]) * (point1[1] - point2[1])  def bfs(points):     start = points[:2]     end = points[-2:]     grid = [[-1  for _ in range(1001)] for _ in range(1001)]     que = PriorityQueue()     k1 = Heuristic(start, start) + 5     k2 = Heuristic(start, end)     que.put(Knight(k1 + k2, start))     grid[start[0]][start[1]] = 0      dir = [[-1, -2],[ -2, -1], [-2, 1], [-1, 2], [1, 2], [2, 1],[ 2, -1], [1, -2]]     while not que.empty():         point = que.get().point         # print('\n', point)         if point[0] == end[0] and point[1] == end[1]:return grid[point[0]][point[1]]         for i in range(8):             nextx = point[0] + dir[i][0]             nexty = point[1] + dir[i][1]             if nextx < 1 or nexty < 1 or nextx > 1000 or nexty > 1000:continue             if grid[nextx][nexty] == -1:                 grid[nextx][nexty] = grid[point[0]][point[1]] + 1                 # print(nextx, nexty, counts)                 k1 = grid[nextx][nexty]  * 5                 k2 = Heuristic([nextx, nexty], end)                 que.put(Knight(k1 + k2, [nextx, nexty]))  from collections import deque print('xxxxxxxx') memsets = [[5, 2, 5, 4], [1, 1, 2, 2], [1, 1, 8, 8], [1, 1, 8, 7], [2, 1, 3, 3], [4, 6, 4, 6]]  for points in memsets:     res = bfs(points)     print(points, res) 

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!