import os
import argparse
import glob
import cv2
import numpy as np
import onnxruntime
import tqdm
import pymysql
import time
import json
import traceback

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use GPU 0.

# Table the detector/tracker writes crop records into.
TRACKING_TABLE = "new_detection_tracking_results_1"

# Folder holding the reference ("gallery") feature .npy files, one file per identity.
GALLERY_FOLDER = "/home/hitsz/yk_workspace/Yolov5_track/db_tools/human_feature/deature_npy"

# Squared-Euclidean distance below which a crop is attributed to a gallery identity.
MATCH_THRESHOLD = 175.0

# NOTE(review): these are hard-coded defaults; each value can be overridden through an
# environment variable so credentials need not live in source control.
DB_CONFIG = {
    "host": os.environ.get("DB_HOST", "localhost"),
    "user": os.environ.get("DB_USER", "root"),
    "password": os.environ.get("DB_PASSWORD", "123456"),
    "database": os.environ.get("DB_NAME", "video_streaming_database"),
    "charset": "utf8mb4",
}


def get_connection():
    """Create and return a new database connection built from ``DB_CONFIG``."""
    return pymysql.connect(**DB_CONFIG)


def ensure_connection(connection):
    """Return ``connection`` if it is still open, otherwise open a new one.

    :param connection: an existing connection, or ``None``.
    :return: a usable connection (the same object when it is still open).
    """
    if connection is None or not connection.open:
        print("Connection is invalid or closed. Reconnecting...")
        return get_connection()
    return connection


def _close_quietly(connection):
    """Close ``connection`` if it is open, ignoring errors from an already-broken link."""
    if connection is None:
        return
    try:
        if connection.open:
            connection.close()
    except pymysql.MySQLError:
        pass


def npy_files_list(folder_path):
    """Load every ``.npy`` file found directly in ``folder_path``.

    Files are visited in sorted file-name order so the result is deterministic
    (``os.listdir`` returns entries in arbitrary order).

    :param folder_path: directory containing the ``.npy`` feature files.
    :return: ``(all_data, all_name)`` -- parallel lists holding each loaded array and
        the corresponding file name truncated at its first ``"."``.
    """
    all_data = []
    all_name = []
    for file_name in sorted(os.listdir(folder_path)):
        if not file_name.endswith(".npy"):
            continue
        all_data.append(np.load(os.path.join(folder_path, file_name)))
        all_name.append(file_name.split(".")[0])
    return all_data, all_name


def get_parser():
    """Build the command-line parser for the ONNX re-identification worker."""
    parser = argparse.ArgumentParser(description="onnx model inference")
    parser.add_argument(
        "--model-path",
        default=R"/home/hitsz/yk_workspace/Yolov5_track/weights/sbs_r50_0206_export_params_True.onnx",
        help="onnx model path",
    )
    parser.add_argument(
        "--input",
        # A list, matching what nargs="+" produces when the option is given.
        default=["/home/hitsz/yk_workspace/Yolov5_track/test_4S_videos/test_yk1_det3/save_crops/test_yk1/person/1/*jpg"],
        nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        default='/home/hitsz/yk_workspace/Yolov5_track/02_output_det/onnx_output',
        help='path to save the output features',
    )
    parser.add_argument("--height", type=int, default=384, help="height of image")
    parser.add_argument("--width", type=int, default=128, help="width of image")
    return parser


def preprocess(image_path, image_height, image_width):
    """Load an image file and convert it into a model input tensor.

    Steps: scale pixels to [0, 1], apply the mean/std normalization, reverse the
    channel axis, resize with bicubic interpolation, and reorder to NCHW float32.

    :return: float32 array of shape (1, 3, image_height, image_width).
    :raises ValueError: if OpenCV cannot read ``image_path``.
    """
    original_image = cv2.imread(image_path)
    if original_image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"cannot read image: {image_path}")
    norm_mean = np.array([0.485, 0.456, 0.406])
    norm_std = np.array([0.229, 0.224, 0.225])
    # NOTE(review): these look like ImageNet RGB statistics, but they are applied while
    # the image is still in OpenCV's BGR order (the channel flip happens afterwards).
    # Kept unchanged because the stored gallery features and MATCH_THRESHOLD were
    # presumably produced with this exact pipeline -- confirm before changing.
    normalized_img = (original_image / 255.0 - norm_mean) / norm_std
    flipped_img = normalized_img[:, :, ::-1]
    img = cv2.resize(flipped_img, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
    return img.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)


def normalize(nparray, order=2, axis=-1):
    """Normalize a N-D numpy array along the specified axis."""
    norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
    return nparray / (norm + np.finfo(np.float32).eps)


def _name_json_path(crop_image_path):
    """Return the per-track JSON file that collects identity matches for a crop.

    The file is placed in the first ``/``-separated component of ``crop_image_path``
    and named after the first two ``_``-separated fields of the crop's file name,
    e.g. ``"out/a/cam1_seq2_x.jpg"`` -> ``"out/cam1_seq2_track_name.json"``.
    NOTE(review): for an absolute path the first component is ``""``, so the file
    lands in the current working directory -- confirm this is intended.
    """
    parts = crop_image_path.split("/")
    fields = parts[-1].split("_")
    return os.path.join(parts[0], fields[0] + "_" + fields[1] + "_track_name.json")


def _min_squared_distance(feat, gallery):
    """Find the gallery entry closest to ``feat``.

    :param feat: query feature matrix, shape (n, d).
    :param gallery: list of gallery feature matrices, each of shape (m_i, d).
    :return: ``(distance, index)`` -- the smallest squared Euclidean distance between
        any row of ``feat`` and any row of a gallery entry, and that entry's index
        (first one wins on ties); ``(inf, -1)`` when ``gallery`` is empty.
    """
    best_value, best_index = float("inf"), -1
    feat_sq = np.sum(np.square(feat), axis=1)  # (n,)
    for index, gallery_feats in enumerate(gallery):
        gallery_sq = np.sum(np.square(gallery_feats), axis=1, keepdims=True)  # (m, 1)
        # ||g - f||^2 = ||g||^2 + ||f||^2 - 2 g.f for every (gallery row, query row) pair.
        distmat = gallery_sq + feat_sq[np.newaxis, :] - 2 * np.dot(gallery_feats, feat.T)
        distance = np.min(distmat)
        if distance < best_value:
            best_value, best_index = distance, index
    return best_value, best_index


def _record_match(json_path, track_id, label):
    """Merge ``{track_id: label}`` into the JSON object stored at ``json_path``.

    A missing file, unparsable content, or a non-object JSON value is treated as an
    empty object, so the file always ends up holding a JSON object.
    """
    existing_data = {}
    if os.path.exists(json_path):
        with open(json_path, "r", encoding="utf-8") as json_file:
            try:
                loaded = json.load(json_file)
            except ValueError:  # JSONDecodeError and UnicodeDecodeError
                loaded = {}
        if isinstance(loaded, dict):
            existing_data = loaded
    existing_data[track_id] = label
    with open(json_path, "w", encoding="utf-8") as json_file:
        json.dump(existing_data, json_file, indent=4)


def _process_crop(crop_image_path, ort_sess, input_name, args, all_features, all_names):
    """Embed one crop, match it against the gallery, and record the match if close enough."""
    # The stored path's last 4 characters (its extension) are replaced with ".jpg".
    image_path = crop_image_path[:-4] + ".jpg"
    if not os.path.exists(image_path):
        return
    image = preprocess(image_path, args.height, args.width)
    feat = ort_sess.run(None, {input_name: image})[0]
    min_value, min_index = _min_squared_distance(feat, all_features)
    if min_value < MATCH_THRESHOLD:
        track_id = crop_image_path.split("_")[-6]
        label = str(all_names[min_index]) + "_" + str(int(min_value))
        _record_match(_name_json_path(crop_image_path), track_id, label)


def main():
    """Poll the tracking table forever and attribute new crops to gallery identities."""
    args = get_parser().parse_args()
    batch_size = 500  # at most this many of the newest rows are handled per round
    # Rows are selected with ``id > pre_end_id``; starting at 0 (not 1) so id 1 is included.
    pre_end_id = 0

    connection = get_connection()
    # Run inference on the GPU (CUDA execution provider).
    ort_sess = onnxruntime.InferenceSession(args.model_path, providers=["CUDAExecutionProvider"])
    all_features, all_names = npy_files_list(GALLERY_FOLDER)
    if not all_features:
        print(f"WARNING: no .npy gallery features found in {GALLERY_FOLDER}")
    print("Providers being used:", ort_sess.get_providers())
    input_name = ort_sess.get_inputs()[0].name

    try:
        while True:
            try:
                connection = ensure_connection(connection)
                start_time = time.time()
                with connection.cursor() as cursor:
                    cursor.execute(f"SELECT MAX(id) FROM {TRACKING_TABLE}")
                    max_id = cursor.fetchone()[0]
                    if max_id is None:
                        connection.commit()
                        time.sleep(1)
                        continue
                    # Rows older than max_id - batch_size that were never processed are
                    # skipped; only the newest window is handled each round.
                    start_id = max(pre_end_id, max_id - batch_size)
                    end_id = max(pre_end_id, max_id)
                    cursor.execute(
                        f"SELECT crop_image_path FROM {TRACKING_TABLE} WHERE id <= %s AND id > %s",
                        (end_id, start_id),
                    )
                    batch = cursor.fetchall()
                    # pymysql does not autocommit by default; ending the read transaction
                    # lets the next MAX(id) see rows inserted in the meantime (InnoDB's
                    # default REPEATABLE READ keeps a snapshot per transaction).
                    connection.commit()

                for row in batch:
                    try:
                        _process_crop(row[0], ort_sess, input_name, args, all_features, all_names)
                    except Exception:
                        # One bad record must not abort the batch: pre_end_id would not
                        # advance and the same batch would be retried forever.
                        print(f"ERROR... while processing {row[0]!r}")
                        traceback.print_exc()

                runtime = time.time() - start_time
                pre_end_id = end_id
                print(f"起始ID:{start_id} 结束帧:{end_id} 处理ID数:{end_id - start_id} 程序运行时间:{runtime}秒")
                if runtime <= 10:
                    # Pace polling rounds to at most one every 10 seconds.
                    print("休眠: " + str(10 - runtime) + "秒\n")
                    time.sleep(10 - runtime)
                print(f"总时间{time.time() - start_time}秒")
                print("开始下一轮访问数据库\n")
            except pymysql.MySQLError:
                # Database trouble: drop the connection so the next round reconnects.
                print("ERROR...")
                traceback.print_exc()
                _close_quietly(connection)
                connection = None
                time.sleep(1)
            except Exception:
                # Any other failure: report it and retry on the next round.
                print("ERROR...")
                traceback.print_exc()
                time.sleep(1)
    finally:
        _close_quietly(connection)


if __name__ == "__main__":
    main()
Ubuntu环境
onnxruntime-gpu版本可以说是一个非常简单易用的框架。
通常在安装 onnxruntime 时,需要使其版本与 PyTorch 版本和 CUDA 版本相匹配;ONNX Runtime 与 CUDA 版本的对应关系见下表。
onnxruntime-gpu、CUDA、cuDNN 版本对应关系
cuDNN 下载地址
CUDA 下载地址
PyTorch 地址