import ast import sys import numpy as np import json import time import redis from collections import deque, defaultdict ########################## # 工具函数 ########################## def convert_to_float_array(str_array): """ 将字符串或可迭代对象转换为浮点型数组。 """ if str_array == "-": return np.array([], dtype=float) if isinstance(str_array, str): return np.array(ast.literal_eval(str_array), dtype=float) elif isinstance(str_array, list) or isinstance(str_array, np.ndarray): return np.array(str_array, dtype=float) return np.array([], dtype=float) def initWaveMatrix(codeMatrix): """ 根据 codeMatrix 初始化 waveMatrix。 若格子为 'NONE',则 waveMatrix[x][y] = "-" 否则为 "[]" """ waveMatrix = np.empty_like(codeMatrix, dtype=object) for x in range(codeMatrix.shape[0]): for y in range(codeMatrix.shape[1]): if codeMatrix[x][y] == 'NONE': waveMatrix[x][y] = "-" else: waveMatrix[x][y] = "[]" return waveMatrix def mergeWave(originWave, vehicle): """ originWave 是 waveMatrix[x][y] 中存储的字符串(内部是一个 JSON 数组), 将新车辆 vehicle 合并进去(用 set 去重), 并返回更新后的 JSON 字符串。 """ try: set_data = set(ast.literal_eval(originWave)) except (ValueError, SyntaxError): set_data = set() set_data.add(vehicle) return json.dumps(list(set_data)) def convert_to_structured_array(dynamicMatrix): """ 将 dynamicMatrix(嵌套列表,里边是 dict)转换为 numpy 结构化数组,方便之后进行掩码筛选。 """ # 定义结构化数组的 dtype dtype = [('serial', int), ('vehicle', 'U2'), ('time', int)] structured_list = [] for row in dynamicMatrix: for d in row: serial = d.get('serial', 0) vehicle = d.get('vehicle', '0') time_val = d.get('time', 0) structured_list.append((serial, vehicle, time_val)) # 转换为结构化数组,并重塑为原始二维形状 structured_array = np.array(structured_list, dtype=dtype) return structured_array.reshape(len(dynamicMatrix), -1) ########################## # 多源 BFS 核心函数 ########################## def process_dynamic_matrix_multi_source_bfs(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen): """ 使用多源 BFS,一次性将所有含有车辆的格子作为起点同时进入队列。 - 如果格子为 'NONE',则沿同一方向继续直线扩散; - 若遇到非 'NONE' 的格子,则进行欧几里得距离判定 <= radiusLen 时可到达,并且从此处再次朝四方向扩散(方向重置为 None)。 """ # 1. 转换 dynamicMatrix 为结构化数组 dynamicMatrix = convert_to_structured_array(dynamicMatrix) rows, cols = dynamicMatrix.shape # 2. 初始化 waveMatrix(最终要存储回 Redis 的矩阵) waveMatrix = initWaveMatrix(codeMatrix) # 3. 建立一个与 waveMatrix 等大小的 2D 数组 waveSets,用于保存车辆集合(用 set 存储) waveSets = np.empty((rows, cols), dtype=object) for i in range(rows): for j in range(cols): waveSets[i, j] = set() # 4. 准备 BFS 队列,队列元素格式: (x, y, direction, vehicle) queue = deque() # 5. 找到所有有车辆的格子,统一入队列 mask = (dynamicMatrix['vehicle'] != '0') & (dynamicMatrix['vehicle'] != '-1') x_indices, y_indices = np.where(mask) for x, y in zip(x_indices, y_indices): v = dynamicMatrix[x][y]['vehicle'] waveSets[x, y].add(v) queue.append((x, y, None, v)) # 起始时 direction=None # 6. 建立 visited 来避免重复访问 # 因为一个位置 (x, y) 可能被多个车辆用多种方向访问,需要记录 (x, y, direction, vehicle) visited = set() # 7. BFS 主循环 while queue: x, y, direction, vehicle = queue.popleft() if (x, y, direction, vehicle) in visited: continue visited.add((x, y, direction, vehicle)) # 判断下一个要扩展的方向 if direction is None: # 无方向时,四个方向同时扩散 neighbors_info = [ (x + 1, y, 'right'), (x - 1, y, 'left'), (x, y + 1, 'down'), (x, y - 1, 'up') ] else: # 有方向时,只往同一个方向扩展 if direction == 'right': neighbors_info = [(x + 1, y, 'right')] elif direction == 'left': neighbors_info = [(x - 1, y, 'left')] elif direction == 'down': neighbors_info = [(x, y + 1, 'down')] elif direction == 'up': neighbors_info = [(x, y - 1, 'up')] else: neighbors_info = [] # 遍历邻居 for nx, ny, ndir in neighbors_info: # 边界检查 if nx < 0 or nx >= rows or ny < 0 or ny >= cols: continue neighbor_code = codeMatrix[nx, ny] if neighbor_code == 'NONE': # 如果是 'NONE',则沿当前方向继续扩散 if vehicle not in waveSets[nx, ny]: waveSets[nx, ny].add(vehicle) queue.append((nx, ny, ndir, vehicle)) else: # 非 'NONE',需要用欧几里得距离判定 c1 = convert_to_float_array(cdaMatrix[x, y]) # 当前坐标 c2 = convert_to_float_array(cdaMatrix[nx, ny]) # 邻居坐标 if c1.size < 2 or c2.size < 2: continue dist_sqr = (c1[0] - c2[0])**2 + (c1[1] - c2[1])**2 if dist_sqr <= radiusLen**2: # 可以到达,则把车辆加入 waveSets if vehicle not in waveSets[nx, ny]: waveSets[nx, ny].add(vehicle) # 方向重置为 None,表示四向扩散 queue.append((nx, ny, None, vehicle)) # 8. BFS 完成后,将 waveSets 的信息写入 waveMatrix(字符串形式) for i in range(rows): for j in range(cols): if waveSets[i, j]: origin_str = waveMatrix[i][j] for v in waveSets[i, j]: origin_str = mergeWave(origin_str, v) waveMatrix[i][j] = origin_str return waveMatrix ########################## # 主函数入口 ########################## def main(): global radiusLen, codeMatrix, cdaMatrix, waveMatrix # 声明为全局变量 if len(sys.argv) != 6: print("Usage: python script.py ") sys.exit(1) radiusLenStr = sys.argv[1] try: radiusLen = float(radiusLenStr) except ValueError: print("Error: radiusLen must be a float.") sys.exit(1) redisHost = sys.argv[2] redisPwd = sys.argv[3] redisPort = sys.argv[4] redisIdx = sys.argv[5] startTime = time.perf_counter() try: # 1) 连接 Redis pool = redis.ConnectionPool(host=redisHost, port=int(redisPort), password=redisPwd, db=int(redisIdx)) r = redis.Redis(connection_pool=pool) # 2) 获取并加载 codeMatrix codeMatrixStr = r.get('KV.AGV_MAP_ASTAR_CODE_FLAG.1') if codeMatrixStr is None: print("Error: 'KV.AGV_MAP_ASTAR_CODE_FLAG.1' not found in Redis.") sys.exit(1) codeMatrix = np.array(json.loads(codeMatrixStr.decode('utf-8')), dtype=str) # 3) 获取并加载 cdaMatrix cdaMatrixStr = r.get('KV.AGV_MAP_ASTAR_CDA_FLAG.1') if cdaMatrixStr is None: print("Error: 'KV.AGV_MAP_ASTAR_CDA_FLAG.1' not found in Redis.") sys.exit(1) cdaMatrix = np.array(json.loads(cdaMatrixStr.decode('utf-8')), dtype=object) # 4) 获取并加载 dynamicMatrix dynamicMatrixStr = r.get('KV.AGV_MAP_ASTAR_DYNAMIC_FLAG.1') if dynamicMatrixStr is None: print("Error: 'KV.AGV_MAP_ASTAR_DYNAMIC_FLAG.1' not found in Redis.") sys.exit(1) dynamicMatrix = np.array(json.loads(dynamicMatrixStr.decode('utf-8')), dtype=object) # 5) 使用多源 BFS 计算 waveMatrix waveMatrix = process_dynamic_matrix_multi_source_bfs(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen) # 6) 将 waveMatrix 转为 JSON 并写回 Redis waveMatrixList = waveMatrix.tolist() waveMatrixJsonStr = json.dumps(waveMatrixList) r.set("KV.AGV_MAP_ASTAR_WAVE_FLAG.1", waveMatrixJsonStr) end = time.perf_counter() print(f"程序运行时间为: {end - startTime} Seconds") print("1") except Exception as e: print(f"An error occurred: {e}") sys.exit(1) if __name__ == "__main__": main()