import ast
|
import sys
|
import numpy as np
|
import json
|
import time
|
import redis
|
from collections import deque, defaultdict
|
|
##########################
|
# 工具函数
|
##########################
|
def convert_to_float_array(str_array):
|
"""
|
将字符串或可迭代对象转换为浮点型数组。
|
"""
|
if str_array == "-":
|
return np.array([], dtype=float)
|
if isinstance(str_array, str):
|
return np.array(ast.literal_eval(str_array), dtype=float)
|
elif isinstance(str_array, list) or isinstance(str_array, np.ndarray):
|
return np.array(str_array, dtype=float)
|
return np.array([], dtype=float)
|
|
def initWaveMatrix(codeMatrix):
|
"""
|
根据 codeMatrix 初始化 waveMatrix。
|
若格子为 'NONE',则 waveMatrix[x][y] = "-"
|
否则为 "[]"
|
"""
|
waveMatrix = np.empty_like(codeMatrix, dtype=object)
|
for x in range(codeMatrix.shape[0]):
|
for y in range(codeMatrix.shape[1]):
|
if codeMatrix[x][y] == 'NONE':
|
waveMatrix[x][y] = "-"
|
else:
|
waveMatrix[x][y] = "[]"
|
return waveMatrix
|
|
def mergeWave(originWave, vehicle):
|
"""
|
originWave 是 waveMatrix[x][y] 中存储的字符串(内部是一个 JSON 数组),
|
将新车辆 vehicle 合并进去(用 set 去重),
|
并返回更新后的 JSON 字符串。
|
"""
|
try:
|
set_data = set(ast.literal_eval(originWave))
|
except (ValueError, SyntaxError):
|
set_data = set()
|
set_data.add(vehicle)
|
return json.dumps(list(set_data))
|
|
def convert_to_structured_array(dynamicMatrix):
|
"""
|
将 dynamicMatrix(嵌套列表,里边是 dict)转换为 numpy 结构化数组,方便之后进行掩码筛选。
|
"""
|
# 定义结构化数组的 dtype
|
dtype = [('serial', int), ('vehicle', 'U2'), ('time', int)]
|
structured_list = []
|
for row in dynamicMatrix:
|
for d in row:
|
serial = d.get('serial', 0)
|
vehicle = d.get('vehicle', '0')
|
time_val = d.get('time', 0)
|
structured_list.append((serial, vehicle, time_val))
|
# 转换为结构化数组,并重塑为原始二维形状
|
structured_array = np.array(structured_list, dtype=dtype)
|
return structured_array.reshape(len(dynamicMatrix), -1)
|
|
|
##########################
|
# 多源 BFS 核心函数
|
##########################
|
def process_dynamic_matrix_multi_source_bfs(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen):
|
"""
|
使用多源 BFS,一次性将所有含有车辆的格子作为起点同时进入队列。
|
- 如果格子为 'NONE',则沿同一方向继续直线扩散;
|
- 若遇到非 'NONE' 的格子,则进行欧几里得距离判定 <= radiusLen 时可到达,并且从此处再次朝四方向扩散(方向重置为 None)。
|
"""
|
# 1. 转换 dynamicMatrix 为结构化数组
|
dynamicMatrix = convert_to_structured_array(dynamicMatrix)
|
rows, cols = dynamicMatrix.shape
|
|
# 2. 初始化 waveMatrix(最终要存储回 Redis 的矩阵)
|
waveMatrix = initWaveMatrix(codeMatrix)
|
|
# 3. 建立一个与 waveMatrix 等大小的 2D 数组 waveSets,用于保存车辆集合(用 set 存储)
|
waveSets = np.empty((rows, cols), dtype=object)
|
for i in range(rows):
|
for j in range(cols):
|
waveSets[i, j] = set()
|
|
# 4. 准备 BFS 队列,队列元素格式: (x, y, direction, vehicle)
|
queue = deque()
|
|
# 5. 找到所有有车辆的格子,统一入队列
|
mask = (dynamicMatrix['vehicle'] != '0') & (dynamicMatrix['vehicle'] != '-1')
|
x_indices, y_indices = np.where(mask)
|
for x, y in zip(x_indices, y_indices):
|
v = dynamicMatrix[x][y]['vehicle']
|
waveSets[x, y].add(v)
|
queue.append((x, y, None, v)) # 起始时 direction=None
|
|
# 6. 建立 visited 来避免重复访问
|
# 因为一个位置 (x, y) 可能被多个车辆用多种方向访问,需要记录 (x, y, direction, vehicle)
|
visited = set()
|
|
# 7. BFS 主循环
|
while queue:
|
x, y, direction, vehicle = queue.popleft()
|
|
if (x, y, direction, vehicle) in visited:
|
continue
|
visited.add((x, y, direction, vehicle))
|
|
# 判断下一个要扩展的方向
|
if direction is None:
|
# 无方向时,四个方向同时扩散
|
neighbors_info = [
|
(x + 1, y, 'right'),
|
(x - 1, y, 'left'),
|
(x, y + 1, 'down'),
|
(x, y - 1, 'up')
|
]
|
else:
|
# 有方向时,只往同一个方向扩展
|
if direction == 'right':
|
neighbors_info = [(x + 1, y, 'right')]
|
elif direction == 'left':
|
neighbors_info = [(x - 1, y, 'left')]
|
elif direction == 'down':
|
neighbors_info = [(x, y + 1, 'down')]
|
elif direction == 'up':
|
neighbors_info = [(x, y - 1, 'up')]
|
else:
|
neighbors_info = []
|
|
# 遍历邻居
|
for nx, ny, ndir in neighbors_info:
|
# 边界检查
|
if nx < 0 or nx >= rows or ny < 0 or ny >= cols:
|
continue
|
|
neighbor_code = codeMatrix[nx, ny]
|
|
if neighbor_code == 'NONE':
|
# 如果是 'NONE',则沿当前方向继续扩散
|
if vehicle not in waveSets[nx, ny]:
|
waveSets[nx, ny].add(vehicle)
|
queue.append((nx, ny, ndir, vehicle))
|
else:
|
# 非 'NONE',需要用欧几里得距离判定
|
c1 = convert_to_float_array(cdaMatrix[x, y]) # 当前坐标
|
c2 = convert_to_float_array(cdaMatrix[nx, ny]) # 邻居坐标
|
if c1.size < 2 or c2.size < 2:
|
continue
|
dist_sqr = (c1[0] - c2[0])**2 + (c1[1] - c2[1])**2
|
if dist_sqr <= radiusLen**2:
|
# 可以到达,则把车辆加入 waveSets
|
if vehicle not in waveSets[nx, ny]:
|
waveSets[nx, ny].add(vehicle)
|
# 方向重置为 None,表示四向扩散
|
queue.append((nx, ny, None, vehicle))
|
|
# 8. BFS 完成后,将 waveSets 的信息写入 waveMatrix(字符串形式)
|
for i in range(rows):
|
for j in range(cols):
|
if waveSets[i, j]:
|
origin_str = waveMatrix[i][j]
|
for v in waveSets[i, j]:
|
origin_str = mergeWave(origin_str, v)
|
waveMatrix[i][j] = origin_str
|
|
return waveMatrix
|
|
|
##########################
|
# 主函数入口
|
##########################
|
def main():
|
global radiusLen, codeMatrix, cdaMatrix, waveMatrix # 声明为全局变量
|
|
if len(sys.argv) != 6:
|
print("Usage: python script.py <radiusLen> <redisHost> <redisPwd> <redisPort> <redisIdx>")
|
sys.exit(1)
|
|
radiusLenStr = sys.argv[1]
|
try:
|
radiusLen = float(radiusLenStr)
|
except ValueError:
|
print("Error: radiusLen must be a float.")
|
sys.exit(1)
|
|
redisHost = sys.argv[2]
|
redisPwd = sys.argv[3]
|
redisPort = sys.argv[4]
|
redisIdx = sys.argv[5]
|
|
startTime = time.perf_counter()
|
|
try:
|
# 1) 连接 Redis
|
pool = redis.ConnectionPool(host=redisHost, port=int(redisPort), password=redisPwd, db=int(redisIdx))
|
r = redis.Redis(connection_pool=pool)
|
|
# 2) 获取并加载 codeMatrix
|
codeMatrixStr = r.get('KV.AGV_MAP_ASTAR_CODE_FLAG.1')
|
if codeMatrixStr is None:
|
print("Error: 'KV.AGV_MAP_ASTAR_CODE_FLAG.1' not found in Redis.")
|
sys.exit(1)
|
codeMatrix = np.array(json.loads(codeMatrixStr.decode('utf-8')), dtype=str)
|
|
# 3) 获取并加载 cdaMatrix
|
cdaMatrixStr = r.get('KV.AGV_MAP_ASTAR_CDA_FLAG.1')
|
if cdaMatrixStr is None:
|
print("Error: 'KV.AGV_MAP_ASTAR_CDA_FLAG.1' not found in Redis.")
|
sys.exit(1)
|
cdaMatrix = np.array(json.loads(cdaMatrixStr.decode('utf-8')), dtype=object)
|
|
# 4) 获取并加载 dynamicMatrix
|
dynamicMatrixStr = r.get('KV.AGV_MAP_ASTAR_DYNAMIC_FLAG.1')
|
if dynamicMatrixStr is None:
|
print("Error: 'KV.AGV_MAP_ASTAR_DYNAMIC_FLAG.1' not found in Redis.")
|
sys.exit(1)
|
dynamicMatrix = np.array(json.loads(dynamicMatrixStr.decode('utf-8')), dtype=object)
|
|
# 5) 使用多源 BFS 计算 waveMatrix
|
waveMatrix = process_dynamic_matrix_multi_source_bfs(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen)
|
|
# 6) 将 waveMatrix 转为 JSON 并写回 Redis
|
waveMatrixList = waveMatrix.tolist()
|
waveMatrixJsonStr = json.dumps(waveMatrixList)
|
r.set("KV.AGV_MAP_ASTAR_WAVE_FLAG.1", waveMatrixJsonStr)
|
|
end = time.perf_counter()
|
print(f"程序运行时间为: {end - startTime} Seconds")
|
print("1")
|
except Exception as e:
|
print(f"An error occurred: {e}")
|
sys.exit(1)
|
|
if __name__ == "__main__":
|
main()
|