#
luxiaotao1123
2024-12-21 418c07bf5bb7e27d124cac00cf867039fad9060e
#
1个文件已修改
92 ■■■■ 已修改文件
zy-acs-manager/src/main/resources/agv.py 92 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
zy-acs-manager/src/main/resources/agv.py
@@ -1,10 +1,11 @@
import ast
import multiprocessing
import sys
import numpy as np
import json
import time
import redis
from collections import deque
from collections import deque, defaultdict
radiusLen = None
@@ -12,9 +13,11 @@
def convert_to_float_array(str_array):
    if isinstance(str_array, str):
        return np.array(ast.literal_eval(str_array), dtype=float)
    return str_array
    elif isinstance(str_array, list) or isinstance(str_array, np.ndarray):
        return np.array(str_array, dtype=float)
    return np.array([], dtype=float)
def getWaveScopeByCode_iterative(x, y):
def getWaveScopeByCode_iterative(x, y, codeMatrix, cdaMatrix, radiusLen):
    """
    使用广度优先搜索(BFS)来代替递归,以避免递归深度过大的问题。
    """
@@ -66,11 +69,11 @@
    return includeList
def find_value_in_matrix(value):
def find_value_in_matrix(value, codeMatrix):
    indices = np.where(codeMatrix == value)
    return list(zip(indices[0], indices[1]))
def initWaveMatrix():
def initWaveMatrix(codeMatrix):
    waveMatrix = np.empty_like(codeMatrix, dtype=object)
    for x in range(codeMatrix.shape[0]):
@@ -103,7 +106,7 @@
    structured_list = []
    for row in dynamicMatrix:
       for d in row:
           # 提取字段,确保 'time' 存在,否则设置为默认值(例如 0.0)
            # 提取字段,确保 'time' 存在,否则设置为默认值(例如 0)
           serial = d.get('serial', 0)
           vehicle = d.get('vehicle', '0')
           time_val = d.get('time', 0)
@@ -115,9 +118,7 @@
    return structured_array.reshape(len(dynamicMatrix), -1)
# 使用 numpy 加速的代码
def process_dynamic_matrix(dynamicMatrix, codeMatrix):
    global waveMatrix  # 确保 waveMatrix 是全局变量
def process_dynamic_matrix(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen, waveMatrix):
    # 将 dynamicMatrix 转换为结构化数组
    dynamicMatrix = convert_to_structured_array(dynamicMatrix)
@@ -133,10 +134,54 @@
    # 遍历满足条件的坐标
    for x, y in zip(x_indices, y_indices):
        vehicle = dynamicMatrix[x][y]['vehicle']
        includeList = getWaveScopeByCode_iterative(x, y)
        includeList = getWaveScopeByCode_iterative(x, y, codeMatrix, cdaMatrix, radiusLen)
        for include in includeList:
            originWave = waveMatrix[include['x']][include['y']]
            waveMatrix[include['x']][include['y']] = mergeWave(originWave, vehicle)
    return waveMatrix
def process_chunk(chunk, dynamicMatrix, codeMatrix, cdaMatrix, radiusLen):
    """处理数据块的函数,返回需要合并的 (x, y, vehicle) 列表"""
    local_wave_updates = []
    for data in chunk:
        x, y = data  # 假设每个数据项包含x和y坐标
        vehicle = dynamicMatrix[x][y]['vehicle']
        includeList = getWaveScopeByCode_iterative(x, y, codeMatrix, cdaMatrix, radiusLen)
        for include in includeList:
            local_wave_updates.append((include['x'], include['y'], vehicle))
    return local_wave_updates
def process_dynamic_matrix_parallel(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen):
    # 将 dynamicMatrix 转换为结构化数组
    dynamicMatrix = convert_to_structured_array(dynamicMatrix)
    # 创建布尔掩码
    mask = (dynamicMatrix['vehicle'] != '0') & (dynamicMatrix['vehicle'] != '-1')
    x_indices, y_indices = np.where(mask)
    # 将满足条件的坐标组合成任务,每个任务包含一个数据块
    tasks = list(zip(x_indices, y_indices))
#     num_processes = multiprocessing.cpu_count()
    num_processes = 5
    chunk_size = max(1, len(tasks) // num_processes)
    chunks = [tasks[i:i + chunk_size] for i in range(0, len(tasks), chunk_size)]
    all_results = []  # 存储所有进程的结果
    # 设置进程池
    with multiprocessing.Pool(processes=num_processes) as pool:
        # 使用map方法并行处理任务
        results = pool.starmap(process_chunk, [(chunk, dynamicMatrix, codeMatrix, cdaMatrix, radiusLen) for chunk in chunks])
        for result in results:
            all_results.extend(result)
    # 使用 defaultdict 来收集每个 (x, y) 对应的所有 vehicle
    wave_updates = defaultdict(set)
    for x, y, vehicle in all_results:
        wave_updates[(x, y)].add(vehicle)
    return wave_updates
def main():
    global radiusLen, codeMatrix, cdaMatrix, waveMatrix  # 声明为全局变量
@@ -146,7 +191,11 @@
        sys.exit(1)
    radiusLenStr = sys.argv[1]
    try:
    radiusLen = float(radiusLenStr)
    except ValueError:
        print("Error: radiusLen must be a float.")
        sys.exit(1)
    redisHost = sys.argv[2]
    redisPwd = sys.argv[3]
@@ -155,6 +204,7 @@
    startTime = time.perf_counter()
    try:
    # 创建一个连接池
    pool = redis.ConnectionPool(host=redisHost, port=int(redisPort), password=redisPwd, db=int(redisIdx))
    r = redis.Redis(connection_pool=pool)
@@ -164,27 +214,34 @@
    if codeMatrixStr is None:
        print("Error: 'KV.AGV_MAP_ASTAR_CODE_FLAG.1' not found in Redis.")
        sys.exit(1)
    codeMatrix = np.array(json.loads(codeMatrixStr))
        codeMatrix = np.array(json.loads(codeMatrixStr.decode('utf-8')), dtype=str)
    # 获取并加载 cdaMatrix
    cdaMatrixStr = r.get('KV.AGV_MAP_ASTAR_CDA_FLAG.1')
    if cdaMatrixStr is None:
        print("Error: 'KV.AGV_MAP_ASTAR_CDA_FLAG.1' not found in Redis.")
        sys.exit(1)
    cdaMatrix = np.array(json.loads(cdaMatrixStr))
        cdaMatrix = np.array(json.loads(cdaMatrixStr.decode('utf-8')), dtype=object)
    # 获取并加载 dynamicMatrix
    dynamicMatrixStr = r.get('KV.AGV_MAP_ASTAR_DYNAMIC_FLAG.1')
    if dynamicMatrixStr is None:
        print("Error: 'KV.AGV_MAP_ASTAR_DYNAMIC_FLAG.1' not found in Redis.")
        sys.exit(1)
    dynamicMatrix = np.array(json.loads(dynamicMatrixStr))
        dynamicMatrix = np.array(json.loads(dynamicMatrixStr.decode('utf-8')), dtype=object)
    # 初始化 waveMatrix
    waveMatrix = initWaveMatrix()
        waveMatrix = initWaveMatrix(codeMatrix)
    # 处理 dynamicMatrix
    process_dynamic_matrix(dynamicMatrix, codeMatrix)
        # 调用并行处理的函数
        wave_updates = process_dynamic_matrix_parallel(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen)
        # 应用所有更新到 waveMatrix
        for (x, y), vehicles in wave_updates.items():
            originWave = waveMatrix[x][y]
            for vehicle in vehicles:
                originWave = mergeWave(originWave, vehicle)
            waveMatrix[x][y] = originWave
    # 将 numpy.ndarray 转换为嵌套列表
    waveMatrixList = waveMatrix.tolist()
@@ -198,6 +255,9 @@
    # 打印程序运行时间
#     print(f"程序运行时间为: {end - startTime} Seconds")
    print("1")
    except Exception as e:
        print(f"An error occurred: {e}")
        sys.exit(1)
if __name__ == "__main__":
    main()