#
luxiaotao1123
2024-12-24 6542142bebf5246e9097fafc75e69227e9fdadbf
#
1个文件已添加
242 ■■■■■ 已修改文件
zy-acs-manager/src/main/resources/agv1.py 242 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
zy-acs-manager/src/main/resources/agv1.py
New file
@@ -0,0 +1,242 @@
import ast
import sys
import numpy as np
import json
import time
import redis
from collections import deque, defaultdict
##########################
# 工具函数
##########################
def convert_to_float_array(str_array):
    """
    将字符串或可迭代对象转换为浮点型数组。
    """
    if str_array == "-":
        return np.array([], dtype=float)
    if isinstance(str_array, str):
        return np.array(ast.literal_eval(str_array), dtype=float)
    elif isinstance(str_array, list) or isinstance(str_array, np.ndarray):
        return np.array(str_array, dtype=float)
    return np.array([], dtype=float)
def initWaveMatrix(codeMatrix):
    """
    根据 codeMatrix 初始化 waveMatrix。
    若格子为 'NONE',则 waveMatrix[x][y] = "-"
    否则为 "[]"
    """
    waveMatrix = np.empty_like(codeMatrix, dtype=object)
    for x in range(codeMatrix.shape[0]):
        for y in range(codeMatrix.shape[1]):
            if codeMatrix[x][y] == 'NONE':
                waveMatrix[x][y] = "-"
            else:
                waveMatrix[x][y] = "[]"
    return waveMatrix
def mergeWave(originWave, vehicle):
    """
    originWave 是 waveMatrix[x][y] 中存储的字符串(内部是一个 JSON 数组),
    将新车辆 vehicle 合并进去(用 set 去重),
    并返回更新后的 JSON 字符串。
    """
    try:
        set_data = set(ast.literal_eval(originWave))
    except (ValueError, SyntaxError):
        set_data = set()
    set_data.add(vehicle)
    return json.dumps(list(set_data))
def convert_to_structured_array(dynamicMatrix):
    """
    将 dynamicMatrix(嵌套列表,里边是 dict)转换为 numpy 结构化数组,方便之后进行掩码筛选。
    """
    # 定义结构化数组的 dtype
    dtype = [('serial', int), ('vehicle', 'U2'), ('time', int)]
    structured_list = []
    for row in dynamicMatrix:
        for d in row:
            serial = d.get('serial', 0)
            vehicle = d.get('vehicle', '0')
            time_val = d.get('time', 0)
            structured_list.append((serial, vehicle, time_val))
    # 转换为结构化数组,并重塑为原始二维形状
    structured_array = np.array(structured_list, dtype=dtype)
    return structured_array.reshape(len(dynamicMatrix), -1)
##########################
# 多源 BFS 核心函数
##########################
def process_dynamic_matrix_multi_source_bfs(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen):
    """
    使用多源 BFS,一次性将所有含有车辆的格子作为起点同时进入队列。
    - 如果格子为 'NONE',则沿同一方向继续直线扩散;
    - 若遇到非 'NONE' 的格子,则进行欧几里得距离判定 <= radiusLen 时可到达,并且从此处再次朝四方向扩散(方向重置为 None)。
    """
    # 1. 转换 dynamicMatrix 为结构化数组
    dynamicMatrix = convert_to_structured_array(dynamicMatrix)
    rows, cols = dynamicMatrix.shape
    # 2. 初始化 waveMatrix(最终要存储回 Redis 的矩阵)
    waveMatrix = initWaveMatrix(codeMatrix)
    # 3. 建立一个与 waveMatrix 等大小的 2D 数组 waveSets,用于保存车辆集合(用 set 存储)
    waveSets = np.empty((rows, cols), dtype=object)
    for i in range(rows):
        for j in range(cols):
            waveSets[i, j] = set()
    # 4. 准备 BFS 队列,队列元素格式: (x, y, direction, vehicle)
    queue = deque()
    # 5. 找到所有有车辆的格子,统一入队列
    mask = (dynamicMatrix['vehicle'] != '0') & (dynamicMatrix['vehicle'] != '-1')
    x_indices, y_indices = np.where(mask)
    for x, y in zip(x_indices, y_indices):
        v = dynamicMatrix[x][y]['vehicle']
        waveSets[x, y].add(v)
        queue.append((x, y, None, v))  # 起始时 direction=None
    # 6. 建立 visited 来避免重复访问
    #    因为一个位置 (x, y) 可能被多个车辆用多种方向访问,需要记录 (x, y, direction, vehicle)
    visited = set()
    # 7. BFS 主循环
    while queue:
        x, y, direction, vehicle = queue.popleft()
        if (x, y, direction, vehicle) in visited:
            continue
        visited.add((x, y, direction, vehicle))
        # 判断下一个要扩展的方向
        if direction is None:
            # 无方向时,四个方向同时扩散
            neighbors_info = [
                (x + 1, y, 'right'),
                (x - 1, y, 'left'),
                (x, y + 1, 'down'),
                (x, y - 1, 'up')
            ]
        else:
            # 有方向时,只往同一个方向扩展
            if direction == 'right':
                neighbors_info = [(x + 1, y, 'right')]
            elif direction == 'left':
                neighbors_info = [(x - 1, y, 'left')]
            elif direction == 'down':
                neighbors_info = [(x, y + 1, 'down')]
            elif direction == 'up':
                neighbors_info = [(x, y - 1, 'up')]
            else:
                neighbors_info = []
        # 遍历邻居
        for nx, ny, ndir in neighbors_info:
            # 边界检查
            if nx < 0 or nx >= rows or ny < 0 or ny >= cols:
                continue
            neighbor_code = codeMatrix[nx, ny]
            if neighbor_code == 'NONE':
                # 如果是 'NONE',则沿当前方向继续扩散
                if vehicle not in waveSets[nx, ny]:
                    waveSets[nx, ny].add(vehicle)
                    queue.append((nx, ny, ndir, vehicle))
            else:
                # 非 'NONE',需要用欧几里得距离判定
                c1 = convert_to_float_array(cdaMatrix[x, y])   # 当前坐标
                c2 = convert_to_float_array(cdaMatrix[nx, ny]) # 邻居坐标
                if c1.size < 2 or c2.size < 2:
                    continue
                dist_sqr = (c1[0] - c2[0])**2 + (c1[1] - c2[1])**2
                if dist_sqr <= radiusLen**2:
                    # 可以到达,则把车辆加入 waveSets
                    if vehicle not in waveSets[nx, ny]:
                        waveSets[nx, ny].add(vehicle)
                        # 方向重置为 None,表示四向扩散
                        queue.append((nx, ny, None, vehicle))
    # 8. BFS 完成后,将 waveSets 的信息写入 waveMatrix(字符串形式)
    for i in range(rows):
        for j in range(cols):
            if waveSets[i, j]:
                origin_str = waveMatrix[i][j]
                for v in waveSets[i, j]:
                    origin_str = mergeWave(origin_str, v)
                waveMatrix[i][j] = origin_str
    return waveMatrix
##########################
# 主函数入口
##########################
def main():
    global radiusLen, codeMatrix, cdaMatrix, waveMatrix  # 声明为全局变量
    if len(sys.argv) != 6:
        print("Usage: python script.py <radiusLen> <redisHost> <redisPwd> <redisPort> <redisIdx>")
        sys.exit(1)
    radiusLenStr = sys.argv[1]
    try:
        radiusLen = float(radiusLenStr)
    except ValueError:
        print("Error: radiusLen must be a float.")
        sys.exit(1)
    redisHost = sys.argv[2]
    redisPwd = sys.argv[3]
    redisPort = sys.argv[4]
    redisIdx = sys.argv[5]
    startTime = time.perf_counter()
    try:
        # 1) 连接 Redis
        pool = redis.ConnectionPool(host=redisHost, port=int(redisPort), password=redisPwd, db=int(redisIdx))
        r = redis.Redis(connection_pool=pool)
        # 2) 获取并加载 codeMatrix
        codeMatrixStr = r.get('KV.AGV_MAP_ASTAR_CODE_FLAG.1')
        if codeMatrixStr is None:
            print("Error: 'KV.AGV_MAP_ASTAR_CODE_FLAG.1' not found in Redis.")
            sys.exit(1)
        codeMatrix = np.array(json.loads(codeMatrixStr.decode('utf-8')), dtype=str)
        # 3) 获取并加载 cdaMatrix
        cdaMatrixStr = r.get('KV.AGV_MAP_ASTAR_CDA_FLAG.1')
        if cdaMatrixStr is None:
            print("Error: 'KV.AGV_MAP_ASTAR_CDA_FLAG.1' not found in Redis.")
            sys.exit(1)
        cdaMatrix = np.array(json.loads(cdaMatrixStr.decode('utf-8')), dtype=object)
        # 4) 获取并加载 dynamicMatrix
        dynamicMatrixStr = r.get('KV.AGV_MAP_ASTAR_DYNAMIC_FLAG.1')
        if dynamicMatrixStr is None:
            print("Error: 'KV.AGV_MAP_ASTAR_DYNAMIC_FLAG.1' not found in Redis.")
            sys.exit(1)
        dynamicMatrix = np.array(json.loads(dynamicMatrixStr.decode('utf-8')), dtype=object)
        # 5) 使用多源 BFS 计算 waveMatrix
        waveMatrix = process_dynamic_matrix_multi_source_bfs(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen)
        # 6) 将 waveMatrix 转为 JSON 并写回 Redis
        waveMatrixList = waveMatrix.tolist()
        waveMatrixJsonStr = json.dumps(waveMatrixList)
        r.set("KV.AGV_MAP_ASTAR_WAVE_FLAG.1", waveMatrixJsonStr)
        end = time.perf_counter()
        print(f"程序运行时间为: {end - startTime} Seconds")
        print("1")
    except Exception as e:
        print(f"An error occurred: {e}")
        sys.exit(1)
if __name__ == "__main__":
    main()