zhang
6 天以前 9483baffba9a24a2a36fc8739fc65b59317d9142
zy-acs-manager/src/main/resources/agv.py
@@ -1,70 +1,100 @@
import ast
import sys
# -*- coding: utf-8 -*-
import ast
import multiprocessing
import sys
import numpy as np
import json
import time
import redis
from collections import deque, defaultdict
# Module-level placeholder for the wave radius. main() rebinds it (via
# `global`) from the <radiusLen> command-line argument;
# getWaveScopeByCode_iterative receives the radius as an explicit parameter.
radiusLen = None
# Convert a stored coordinate value into a float numpy array.
def convert_to_float_array(str_array):
    """Convert a coordinate value into a float numpy array.

    Args:
        str_array: A string holding a Python literal (e.g. ``"[1.5, 2]"``),
            a list, or a numpy array.

    Returns:
        np.ndarray with dtype float. Any other input type yields an empty
        float array.

    Raises:
        ValueError / SyntaxError: if a string argument is not a valid literal.
    """
    if isinstance(str_array, str):
        # literal_eval only accepts literals, so the string is never executed.
        return np.array(ast.literal_eval(str_array), dtype=float)
    if isinstance(str_array, (list, np.ndarray)):
        return np.array(str_array, dtype=float)
    return np.array([], dtype=float)
# BFS step per travel direction as (dx, dy). Insertion order matters: it fixes
# the order in which neighbours are enqueued, and therefore which path claims
# a cell first.
_WAVE_STEPS = {
    'right': (1, 0),
    'left': (-1, 0),
    'down': (0, 1),
    'up': (0, -1),
}


def getWaveScopeByCode_iterative(x, y, codeMatrix, cdaMatrix, radiusLen):
    """Collect the cells reached by the wave spreading from cell (x, y).

    Iterative breadth-first search (no recursion-depth limit) with direction
    tracking:

    * A cell whose code is ``'NONE'`` is passed through: the search keeps
      travelling in the same direction, without a distance check.
    * Any other cell is accepted only if its ``cdaMatrix`` coordinate lies
      within ``radiusLen`` (Euclidean, first two components) of the origin
      cell's coordinate. Accepted cells are recorded and spread again in all
      four directions; rejected cells are not expanded.

    Every cell is visited at most once.

    Args:
        x, y: Row/column index of the origin cell.
        codeMatrix: 2-D array of cell codes; ``'NONE'`` marks a pass-through
            cell.
        cdaMatrix: Array indexed like ``codeMatrix`` whose entries are
            accepted by ``convert_to_float_array`` (presumably coordinate
            pairs -- confirm against the producer).
        radiusLen: Maximum distance from the origin coordinate.

    Returns:
        list[dict]: ``{"x": int, "y": int, "code": str}`` for each accepted
        cell, in BFS order. NOTE(review): the origin cell itself is never
        added to the result; confirm callers do not expect it.
    """
    rows, cols = codeMatrix.shape[0], codeMatrix.shape[1]
    radius_sq = radiusLen ** 2
    # Origin coordinate is loop-invariant; converted lazily on first use so it
    # is only parsed when a non-'NONE' cell is actually reached.
    origin_cda = None

    includeList = []
    existNodes = {(x, y)}
    # Queue items are (x, y, direction); direction None means "spread in all
    # four directions" (the start cell and every accepted cell).
    queue = deque([(x, y, None)])

    while queue:
        node_x, node_y, current_dir = queue.popleft()
        if current_dir is None:
            directions = _WAVE_STEPS
        else:
            # Passing through a 'NONE' cell: keep going straight.
            directions = (current_dir,) if current_dir in _WAVE_STEPS else ()

        for direction in directions:
            dx, dy = _WAVE_STEPS[direction]
            nx, ny = node_x + dx, node_y + dy
            if not (0 <= nx < rows and 0 <= ny < cols):
                continue
            if (nx, ny) in existNodes:
                continue
            existNodes.add((nx, ny))

            if codeMatrix[nx, ny] == 'NONE':
                queue.append((nx, ny, direction))
                continue

            if origin_cda is None:
                origin_cda = convert_to_float_array(cdaMatrix[x, y])
            node_cda = convert_to_float_array(cdaMatrix[nx, ny])
            dist_sq = (origin_cda[0] - node_cda[0]) ** 2 + (origin_cda[1] - node_cda[1]) ** 2
            if dist_sq <= radius_sq:
                includeList.append({
                    "x": int(nx),
                    "y": int(ny),
                    "code": str(codeMatrix[nx, ny]),
                })
                # Accepted (non-'NONE') cell: reset direction and spread again.
                queue.append((nx, ny, None))

    return includeList
def find_value_in_matrix(value, codeMatrix):
    """Return the (row, col) index pairs of every cell of codeMatrix equal to value.

    Pairs are listed in row-major order; an empty list means no match.
    """
    matches = np.nonzero(codeMatrix == value)
    row_hits, col_hits = matches[0], matches[1]
    return list(zip(row_hits, col_hits))
def initWaveMatrix():
    lev = 1
def initWaveMatrix(codeMatrix):
    waveMatrix = np.empty_like(codeMatrix, dtype=object)
    for x in range(codeMatrix.shape[0]):
@@ -79,7 +109,10 @@
# 优化版本:使用集合来提高性能
def mergeWave(originWave, vehicle):
    # 将字符串解析为集合
    set_data = set(ast.literal_eval(originWave))
    try:
        set_data = set(ast.literal_eval(originWave))
    except (ValueError, SyntaxError):
        set_data = set()
    # 如果 vehicle 不在集合中,则添加
    set_data.add(vehicle)
    # 返回序列化后的字符串
@@ -88,14 +121,25 @@
# Convert dynamicMatrix into a numpy structured array.
def convert_to_structured_array(dynamicMatrix):
    """Convert a 2-D grid of cell dicts into a structured numpy array.

    Each cell dict contributes its ``'serial'``, ``'vehicle'`` and ``'time'``
    entries; missing keys default to ``0``, ``'0'`` and ``0`` respectively.

    Args:
        dynamicMatrix: Sequence (list or object ndarray) of rows, each row a
            sequence of dicts.

    Returns:
        np.ndarray with fields ``serial`` (int), ``vehicle`` (unicode) and
        ``time`` (int), shaped ``(len(dynamicMatrix), cells_per_row)``. Empty
        input yields a ``(0, 0)`` array.
    """
    records = [
        (cell.get('serial', 0), cell.get('vehicle', '0'), cell.get('time', 0))
        for row in dynamicMatrix
        for cell in row
    ]
    # Size the vehicle field to the longest id: a fixed 'U2' would silently
    # truncate ids such as '123' to '12' and merge different vehicles.
    vehicle_width = max([2] + [len(str(rec[1])) for rec in records])
    dtype = [('serial', int), ('vehicle', f'U{vehicle_width}'), ('time', int)]
    structured_array = np.array(records, dtype=dtype)
    if len(dynamicMatrix) == 0:
        # reshape(0, -1) is ambiguous and raises, so handle empty input here.
        return structured_array.reshape(0, 0)
    # Restore the original 2-D layout.
    return structured_array.reshape(len(dynamicMatrix), -1)
# 使用 numpy 加速的代码
def process_dynamic_matrix(dynamicMatrix, codeMatrix):
def process_dynamic_matrix(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen, waveMatrix):
    # 将 dynamicMatrix 转换为结构化数组
    dynamicMatrix = convert_to_structured_array(dynamicMatrix)
@@ -110,60 +154,131 @@
    # 遍历满足条件的坐标
    for x, y in zip(x_indices, y_indices):
        # print(code)
        data = dynamicMatrix[x][y]
        vehicle = data['vehicle']
        includeList = getWaveScopeByCode(x,y)
        vehicle = dynamicMatrix[x][y]['vehicle']
        includeList = getWaveScopeByCode_iterative(x, y, codeMatrix, cdaMatrix, radiusLen)
        for include in includeList:
            originWave = waveMatrix[include['x']][include['y']]
            waveMatrix[include['x']][include['y']] = mergeWave(originWave, vehicle)
radiusLenStr = sys.argv[1]
radiusLen = float(radiusLenStr)
    return waveMatrix
def process_chunk(chunk, dynamicMatrix, codeMatrix, cdaMatrix, radiusLen):
    """Compute the wave updates for one chunk of occupied cells.

    Runs inside a worker process of process_dynamic_matrix_parallel.

    Args:
        chunk: Iterable of ``(x, y)`` index pairs of cells holding a vehicle.
        dynamicMatrix: Structured array with a ``'vehicle'`` field, indexable
            as ``dynamicMatrix[x][y]``.
        codeMatrix, cdaMatrix, radiusLen: Forwarded unchanged to
            getWaveScopeByCode_iterative.

    Returns:
        list[tuple]: ``(x, y, vehicle)`` for every cell reached by the wave of
        each vehicle in the chunk, for the caller to merge.
    """
    local_wave_updates = []
    for x, y in chunk:
        vehicle = dynamicMatrix[x][y]['vehicle']
        scope = getWaveScopeByCode_iterative(x, y, codeMatrix, cdaMatrix, radiusLen)
        local_wave_updates.extend((cell['x'], cell['y'], vehicle) for cell in scope)
    return local_wave_updates
def process_dynamic_matrix_parallel(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen, num_processes=5):
    """Spread every vehicle's wave in parallel and group the reached cells.

    Cells whose ``vehicle`` is neither ``'0'`` nor ``'-1'`` are treated as
    occupied. They are split into contiguous chunks and processed by
    process_chunk in a multiprocessing pool.

    Args:
        dynamicMatrix: 2-D grid of cell dicts (see convert_to_structured_array).
        codeMatrix, cdaMatrix, radiusLen: Passed through to the wave search.
        num_processes: Worker count (default 5, the previous hard-coded
            value). ``None`` uses ``multiprocessing.cpu_count()``. It is never
            more than the number of occupied cells.

    Returns:
        defaultdict(set): maps ``(x, y)`` to the set of vehicles whose wave
        reaches that cell.
    """
    structured = convert_to_structured_array(dynamicMatrix)
    occupied = (structured['vehicle'] != '0') & (structured['vehicle'] != '-1')
    x_indices, y_indices = np.where(occupied)
    tasks = list(zip(x_indices, y_indices))

    wave_updates = defaultdict(set)
    if not tasks:
        # Nothing to spread; avoid spinning up a process pool.
        return wave_updates

    if num_processes is None:
        num_processes = multiprocessing.cpu_count()
    num_processes = max(1, min(num_processes, len(tasks)))
    # Ceiling division gives at most num_processes evenly sized chunks.
    chunk_size = -(-len(tasks) // num_processes)
    chunks = [tasks[i:i + chunk_size] for i in range(0, len(tasks), chunk_size)]

    with multiprocessing.Pool(processes=num_processes) as pool:
        results = pool.starmap(
            process_chunk,
            [(chunk, structured, codeMatrix, cdaMatrix, radiusLen) for chunk in chunks],
        )

    # starmap keeps chunk order, so updates are merged in task order.
    for result in results:
        for x, y, vehicle in result:
            wave_updates[(x, y)].add(vehicle)
    return wave_updates
def _load_matrix_from_redis(r, key, dtype):
    """Read the JSON document stored under *key* and return it as a numpy array.

    Prints an error message and exits with status 1 when the key is absent.
    """
    raw = r.get(key)
    if raw is None:
        print(f"Error: '{key}' not found in Redis.")
        sys.exit(1)
    return np.array(json.loads(raw.decode('utf-8')), dtype=dtype)


def main():
    """Compute the AGV wave matrix from Redis inputs and store the result back.

    Command line: ``<radiusLen> <redisHost> <redisPwd> <redisPort> <redisIdx>``.

    Reads the code, CDA and dynamic matrices from Redis and spreads every
    vehicle's wave in parallel. Merges the vehicles into each reached cell and
    writes the JSON-encoded wave matrix to ``KV.AGV_MAP_ASTAR_WAVE_FLAG.1``.
    Prints ``1`` on success; on any failure it prints a message and exits
    with status 1.
    """
    # The loaded inputs and the result are also published as module globals.
    global radiusLen, codeMatrix, cdaMatrix, waveMatrix

    if len(sys.argv) != 6:
        print("Usage: python script.py <radiusLen> <redisHost> <redisPwd> <redisPort> <redisIdx>")
        sys.exit(1)

    try:
        radiusLen = float(sys.argv[1])
    except ValueError:
        print("Error: radiusLen must be a float.")
        sys.exit(1)

    redisHost, redisPwd, redisPort, redisIdx = sys.argv[2:6]

    try:
        pool = redis.ConnectionPool(host=redisHost, port=int(redisPort), password=redisPwd, db=int(redisIdx))
        r = redis.Redis(connection_pool=pool)

        codeMatrix = _load_matrix_from_redis(r, 'KV.AGV_MAP_ASTAR_CODE_FLAG.1', str)
        cdaMatrix = _load_matrix_from_redis(r, 'KV.AGV_MAP_ASTAR_CDA_FLAG.1', object)
        dynamicMatrix = _load_matrix_from_redis(r, 'KV.AGV_MAP_ASTAR_DYNAMIC_FLAG.1', object)

        waveMatrix = initWaveMatrix(codeMatrix)
        wave_updates = process_dynamic_matrix_parallel(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen)

        # Fold every vehicle that reached a cell into that cell's wave entry.
        for (x, y), vehicles in wave_updates.items():
            cell_wave = waveMatrix[x][y]
            for vehicle in vehicles:
                cell_wave = mergeWave(cell_wave, vehicle)
            waveMatrix[x][y] = cell_wave

        r.set("KV.AGV_MAP_ASTAR_WAVE_FLAG.1", json.dumps(waveMatrix.tolist()))
        print("1")
    except Exception as e:
        # Top-level boundary: report and signal failure via the exit status.
        print(f"An error occurred: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()