import ast
|
import multiprocessing
|
import sys
|
import numpy as np
|
import json
|
import time
|
import redis
|
from collections import deque, defaultdict
|
|
# Module-level inclusion radius. main() overwrites it (through its `global`
# statement) with the float parsed from argv[1]; the processing functions take
# the radius as an explicit `radiusLen` parameter rather than reading this global.
radiusLen = None
|
|
# Convert a string (or list / tuple / ndarray) into a float numpy array
def convert_to_float_array(str_array):
    """Convert a cda value into a numpy array of floats.

    Args:
        str_array: Either a string holding a Python literal (for example
            ``"[1.5, 2.0]"``), which is parsed with ``ast.literal_eval`` so no
            arbitrary code is executed; or a list, tuple or ``numpy.ndarray``
            of numbers.

    Returns:
        numpy.ndarray: The values with ``dtype=float``. Any other input type
        (e.g. ``None``) yields an empty float array.

    Raises:
        Exceptions from ``ast.literal_eval`` (e.g. ``ValueError``,
        ``SyntaxError``) for malformed strings, and ``ValueError`` /
        ``TypeError`` when the values cannot be converted to float.
    """
    if isinstance(str_array, str):
        return np.array(ast.literal_eval(str_array), dtype=float)
    # Tuples convert exactly like lists/arrays; previously they fell through
    # and silently became an empty array.
    if isinstance(str_array, (list, tuple, np.ndarray)):
        return np.array(str_array, dtype=float)
    return np.array([], dtype=float)
|
|
def getWaveScopeByCode_iterative(x, y, codeMatrix, cdaMatrix, radiusLen):
    """Collect the coded cells that lie within ``radiusLen`` of cell (x, y).

    Uses breadth-first search (BFS) instead of recursion so that large maps
    cannot hit Python's recursion-depth limit.

    Starting at (x, y), the search visits the 4-connected neighbours
    (x+1, x-1, y+1, y-1, in that order) that lie inside ``codeMatrix``:

    * a cell whose code is ``'NONE'`` is always traversed and never reported;
    * any other cell is reported and traversed further only when the squared
      Euclidean distance between the first two values of its cda entry and
      those of the origin's cda entry is ``<= radiusLen ** 2``; otherwise it
      is a dead end.

    Each cell is visited at most once, and the origin itself is never reported.

    Args:
        x, y: Row / column index of the origin cell.
        codeMatrix: 2-D array of cell codes; ``'NONE'`` marks uncoded cells.
        cdaMatrix: Array indexed like ``codeMatrix`` whose entries are accepted
            by ``convert_to_float_array``; only elements [0] and [1] are used.
        radiusLen: Inclusion radius, in the same units as the cda values.

    Returns:
        list[dict]: One ``{"x": int, "y": int, "code": str}`` per included
        cell, in BFS discovery order.
    """
    n_rows, n_cols = codeMatrix.shape[0], codeMatrix.shape[1]
    includeList = []
    visited = {(x, y)}
    queue = deque([(x, y)])
    # The origin's cda is parsed lazily, at most once, the first time a coded
    # neighbour is reached (it used to be re-parsed for every neighbour).
    origin_cda = None

    while queue:
        node_x, node_y = queue.popleft()
        for nx, ny in (
            (node_x + 1, node_y),
            (node_x - 1, node_y),
            (node_x, node_y + 1),
            (node_x, node_y - 1),
        ):
            # Skip cells outside the grid and cells already seen.
            if nx < 0 or nx >= n_rows or ny < 0 or ny >= n_cols:
                continue
            if (nx, ny) in visited:
                continue
            visited.add((nx, ny))

            # NOTE(review): 'NONE' cells are expanded with no distance bound, so
            # the search can cover the whole connected 'NONE' region.
            if codeMatrix[nx, ny] == 'NONE':
                queue.append((nx, ny))
                continue

            if origin_cda is None:
                origin_cda = convert_to_float_array(cdaMatrix[x, y])
            neighbor_cda = convert_to_float_array(cdaMatrix[nx, ny])

            dx2 = (origin_cda[0] - neighbor_cda[0]) ** 2
            dy2 = (origin_cda[1] - neighbor_cda[1]) ** 2
            if dx2 + dy2 <= radiusLen ** 2:
                includeList.append({
                    "x": int(nx),
                    "y": int(ny),
                    "code": str(codeMatrix[nx, ny]),
                })
                queue.append((nx, ny))

    return includeList
|
|
def find_value_in_matrix(value, codeMatrix):
    """Return every (row, col) position in ``codeMatrix`` whose entry equals ``value``.

    Positions are listed in row-major order as pairs of numpy integers; an
    empty list means ``value`` does not occur.
    """
    hits = np.nonzero(codeMatrix == value)
    return [(row, col) for row, col in zip(hits[0], hits[1])]
|
|
def initWaveMatrix(codeMatrix):
    """Build the initial wave matrix for ``codeMatrix``.

    Returns an object array with the same shape in which every ``'NONE'``
    cell holds ``"-"`` and every other cell starts with the empty wave ``'[]'``.
    """
    waveMatrix = np.empty_like(codeMatrix, dtype=object)
    for row, col in np.ndindex(codeMatrix.shape[0], codeMatrix.shape[1]):
        is_blank = codeMatrix[row][col] == 'NONE'
        waveMatrix[row][col] = "-" if is_blank else '[]'
    return waveMatrix
|
|
# Set-based merge: duplicates are dropped in O(1) per element
def mergeWave(originWave, vehicle):
    """Add ``vehicle`` to the serialized wave list ``originWave``.

    Args:
        originWave: The cell's current wave, normally a list literal such as
            ``'[]'`` or ``'["A1"]'``. A value that ``ast.literal_eval`` rejects
            with ``ValueError``/``SyntaxError`` (e.g. the ``"-"`` placeholder)
            is treated as an empty wave.
        vehicle: The vehicle identifier to add.

    Returns:
        str: JSON list of the distinct vehicles, sorted by their string form,
        so the same set of vehicles always serializes to the same string.
    """
    try:
        vehicles = set(ast.literal_eval(originWave))
    except (ValueError, SyntaxError):
        vehicles = set()

    # Adding is a no-op when the vehicle is already present.
    vehicles.add(vehicle)

    # Set iteration order for str depends on per-process hash randomization
    # (PYTHONHASHSEED); sorting makes the stored output deterministic.
    return json.dumps(sorted(vehicles, key=str))
|
|
# Convert dynamicMatrix into a numpy structured array
def convert_to_structured_array(dynamicMatrix):
    """Convert a 2-D grid of cell dicts into a 2-D numpy structured array.

    Only ``.get`` is used on each cell. Missing keys fall back to defaults:
    ``serial`` -> 0, ``vehicle`` -> ``'0'``, ``time`` -> 0.

    Args:
        dynamicMatrix: Iterable of rows, each an iterable of dicts. All rows
            must have the same length.

    Returns:
        numpy.ndarray: Structured array of shape ``(n_rows, n_cols)`` with
        fields ``serial`` (int), ``vehicle`` (unicode) and ``time`` (int).
        The ``vehicle`` field is at least 2 characters wide and is widened to
        fit the longest value, so IDs are never truncated. An empty input
        yields a ``(0, 0)`` array.

    Raises:
        ValueError: If the rows do not all have the same length.
    """
    records = []
    row_widths = set()
    n_rows = 0
    for row in dynamicMatrix:
        row_records = [
            (d.get('serial', 0), d.get('vehicle', '0'), d.get('time', 0))
            for d in row
        ]
        records.extend(row_records)
        row_widths.add(len(row_records))
        n_rows += 1

    if len(row_widths) > 1:
        raise ValueError(
            f"dynamicMatrix rows must all have the same length, got {sorted(row_widths)}"
        )
    n_cols = row_widths.pop() if row_widths else 0

    # A fixed 'U2' field silently truncated longer IDs (e.g. '101' and '102'
    # both became '10'); size the field to the longest value, minimum 2.
    vehicle_width = max([2] + [len(str(vehicle)) for _, vehicle, _ in records])
    dtype = [('serial', int), ('vehicle', f'U{vehicle_width}'), ('time', int)]

    # Explicit shape instead of reshape(len, -1), which fails for empty input.
    return np.array(records, dtype=dtype).reshape(n_rows, n_cols)
|
|
# Numpy-accelerated, single-process version
def process_dynamic_matrix(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen, waveMatrix):
    """Merge every occupied cell's vehicle into the waves of the cells it covers.

    Args:
        dynamicMatrix: 2-D grid of cell dicts (converted with
            ``convert_to_structured_array``).
        codeMatrix, cdaMatrix, radiusLen: Forwarded to
            ``getWaveScopeByCode_iterative``.
        waveMatrix: Grid indexable as ``waveMatrix[x][y]`` holding each cell's
            serialized wave (as consumed by ``mergeWave``); updated in place.

    Returns:
        The same ``waveMatrix`` object, after the updates.
    """
    structured = convert_to_structured_array(dynamicMatrix)
    vehicles = structured['vehicle']

    # Cells whose vehicle is '0' (also the default for a missing key) or '-1'
    # are skipped; presumably these are "no vehicle" sentinels — confirm.
    occupied = (vehicles != '0') & (vehicles != '-1')

    for x, y in zip(*np.where(occupied)):
        vehicle = vehicles[x, y]
        for include in getWaveScopeByCode_iterative(x, y, codeMatrix, cdaMatrix, radiusLen):
            ix, iy = include['x'], include['y']
            waveMatrix[ix][iy] = mergeWave(waveMatrix[ix][iy], vehicle)

    return waveMatrix
|
|
def process_chunk(chunk, dynamicMatrix, codeMatrix, cdaMatrix, radiusLen):
    """Worker task: compute the wave-scope updates for one chunk of cells.

    Args:
        chunk: Iterable of ``(x, y)`` grid positions to process.
        dynamicMatrix: Grid indexed ``[x][y]`` whose cells expose a
            ``'vehicle'`` field.
        codeMatrix, cdaMatrix, radiusLen: Forwarded to
            ``getWaveScopeByCode_iterative``.

    Returns:
        list[tuple]: ``(x, y, vehicle)`` for every cell covered by the vehicle
        at each chunk position, left for the caller to merge.
    """
    pending = []
    for pos_x, pos_y in chunk:
        owner = dynamicMatrix[pos_x][pos_y]['vehicle']
        covered = getWaveScopeByCode_iterative(pos_x, pos_y, codeMatrix, cdaMatrix, radiusLen)
        pending.extend((cell['x'], cell['y'], owner) for cell in covered)
    return pending
|
|
def process_dynamic_matrix_parallel(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen, num_processes=5):
    """Compute wave-scope updates for all occupied cells with a process pool.

    Args:
        dynamicMatrix: 2-D grid of cell dicts (converted with
            ``convert_to_structured_array``).
        codeMatrix, cdaMatrix, radiusLen: Forwarded to each worker's
            ``getWaveScopeByCode_iterative`` calls.
        num_processes: Maximum number of worker processes. Defaults to 5 (the
            previously hard-coded value); ``None`` means
            ``multiprocessing.cpu_count()``.

    Returns:
        collections.defaultdict: Maps ``(x, y)`` -> set of vehicle IDs whose
        wave scope covers that cell. Empty when no cell is occupied.

    Raises:
        ValueError: If ``num_processes`` is less than 1.
    """
    if num_processes is None:
        num_processes = multiprocessing.cpu_count()
    if num_processes < 1:
        raise ValueError(f"num_processes must be >= 1, got {num_processes}")

    structured = convert_to_structured_array(dynamicMatrix)
    vehicles = structured['vehicle']

    # Cells whose vehicle is '0' (also the default for a missing key) or '-1'
    # are skipped; presumably these are "no vehicle" sentinels — confirm.
    occupied = (vehicles != '0') & (vehicles != '-1')
    tasks = list(zip(*np.where(occupied)))

    wave_updates = defaultdict(set)
    if not tasks:
        # Nothing to do: don't start any worker processes.
        return wave_updates

    # Ceil division gives at most num_processes roughly equal chunks.
    chunk_size = -(-len(tasks) // num_processes)
    chunks = [tasks[i:i + chunk_size] for i in range(0, len(tasks), chunk_size)]

    # Each task tuple carries the full matrices, which multiprocessing pickles
    # for the worker along with the chunk.
    task_args = [(chunk, structured, codeMatrix, cdaMatrix, radiusLen) for chunk in chunks]
    with multiprocessing.Pool(processes=min(num_processes, len(chunks))) as pool:
        results = pool.starmap(process_chunk, task_args)

    for chunk_updates in results:
        for x, y, vehicle in chunk_updates:
            wave_updates[(x, y)].add(vehicle)

    return wave_updates
|
|
def _load_json_matrix(r, key, dtype):
    """Fetch ``key`` from Redis and decode its UTF-8 JSON value into a numpy array.

    Prints an error and exits with status 1 when the key is missing.
    """
    raw = r.get(key)
    if raw is None:
        print(f"Error: '{key}' not found in Redis.")
        sys.exit(1)
    return np.array(json.loads(raw.decode('utf-8')), dtype=dtype)


def main():
    """Command-line entry point: compute the wave matrix and store it in Redis.

    Usage: python script.py <radiusLen> <redisHost> <redisPwd> <redisPort> <redisIdx>

    The function works in these steps:

    * Read the code, cda and dynamic matrices from Redis.
    * Find, using the process pool, which coded cells lie within ``radiusLen``
      of each vehicle.
    * Write the resulting wave matrix as JSON to ``KV.AGV_MAP_ASTAR_WAVE_FLAG.1``.

    Prints ``1`` on success. On any failure it prints a message and exits with
    status 1.
    """
    global radiusLen, codeMatrix, cdaMatrix, waveMatrix  # exposed as module globals

    if len(sys.argv) != 6:
        print("Usage: python script.py <radiusLen> <redisHost> <redisPwd> <redisPort> <redisIdx>")
        sys.exit(1)

    try:
        radiusLen = float(sys.argv[1])
    except ValueError:
        print("Error: radiusLen must be a float.")
        sys.exit(1)

    # NOTE(review): the Redis password arrives as a CLI argument and is therefore
    # visible in process listings; consider an environment variable instead.
    redisHost, redisPwd, redisPort, redisIdx = sys.argv[2:6]

    try:
        pool = redis.ConnectionPool(host=redisHost, port=int(redisPort), password=redisPwd, db=int(redisIdx))
        r = redis.Redis(connection_pool=pool)

        codeMatrix = _load_json_matrix(r, 'KV.AGV_MAP_ASTAR_CODE_FLAG.1', str)
        cdaMatrix = _load_json_matrix(r, 'KV.AGV_MAP_ASTAR_CDA_FLAG.1', object)
        dynamicMatrix = _load_json_matrix(r, 'KV.AGV_MAP_ASTAR_DYNAMIC_FLAG.1', object)

        waveMatrix = initWaveMatrix(codeMatrix)

        wave_updates = process_dynamic_matrix_parallel(dynamicMatrix, codeMatrix, cdaMatrix, radiusLen)

        # Fold every vehicle that reaches a cell into that cell's wave.
        for (x, y), vehicles in wave_updates.items():
            wave = waveMatrix[x][y]
            for vehicle in vehicles:
                wave = mergeWave(wave, vehicle)
            waveMatrix[x][y] = wave

        # Serialize the object array as nested JSON lists and store it.
        r.set("KV.AGV_MAP_ASTAR_WAVE_FLAG.1", json.dumps(waveMatrix.tolist()))

        print("1")
    except Exception as e:
        # Top-level boundary: report and signal failure via the exit status.
        print(f"An error occurred: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
|