天池街景字符识别:基于YOLOv5的端到端实战方案解析

📅 2026/6/20 20:39:28
天池街景字符识别:基于YOLOv5的端到端实战方案解析
1. 从零开始YOLOv5与天池街景字符识别的完美邂逅第一次接触阿里天池街景字符识别比赛时我完全没想到YOLOv5这个目标检测模型能在OCR任务上表现得如此出色。这个比赛的任务是识别街景图片中的多位数字符编码传统OCR方案往往需要复杂的预处理和后处理而YOLOv5居然可以直接端到端搞定实测下来准确率能达到0.924这让我这个常年折腾OCR的老手都感到惊喜。为什么选择YOLOv5三个字快、准、省。相比传统OCR方案需要单独处理文字检测和识别两个阶段YOLOv5能一次性输出字符位置和类别部署起来特别方便。而且它的预训练模型在小样本场景下表现惊人我用的还是GTX1050这种入门级显卡训练100个epoch就能达到不错效果。如果你手头有更好的硬件效果还能更上一层楼。2. 环境搭建避开那些新手必踩的坑2.1 模型获取与依赖安装直接从YOLOv5的GitHub仓库克隆最新代码是第一步但这里有个小技巧建议使用git clone而不是下载zip包因为后续更新会更方便。安装依赖时最常见的坑就是pycocotools特别是在Windows环境下。我试过最稳的解决方案是先用conda安装Cython和Visual Studio Build Tools再用pip安装pycocotools。git clone https://github.com/ultralytics/yolov5 cd yolov5 pip install -r requirements.txt如果遇到权限问题可以加上--user参数。记得检查CUDA和cuDNN版本是否匹配这是影响训练速度的关键因素。我遇到过torch版本不兼容导致GPU无法使用的情况这时候需要先卸载原有torch再安装指定版本pip uninstall torch torchvision pip install torch1.8.1cu111 torchvision0.9.1cu111 -f https://download.pytorch.org/whl/torch_stable.html2.2 数据集准备与解析天池比赛提供的训练集包含3万张街景图片每张图片都带有JSON格式的标注。YOLOv5需要的是txt格式的标注文件每个文件对应一张图片格式为类别 x_center y_center width height。我参考了论坛大佬的解析代码但做了些优化def convert_annotation(json_path, img_dir, output_dir): os.makedirs(output_dir, exist_okTrue) with open(json_path) as f: data json.load(f) for img_name in tqdm(data): img cv2.imread(f{img_dir}/{img_name}) h, w img.shape[:2] labels data[img_name][label] boxes zip(data[img_name][left], data[img_name][top], data[img_name][width], data[img_name][height]) with open(f{output_dir}/{img_name.replace(.png, .txt)}, w) as f_txt: for label, (l, t, bw, bh) in zip(labels, boxes): x_center (l bw/2) / w y_center (t bh/2) / h norm_w bw / w norm_h bh / h f_txt.write(f{label} {x_center:.6f} {y_center:.6f} {norm_w:.6f} {norm_h:.6f}\n)这个版本增加了进度条显示和浮点数精度控制处理3万张图片大约需要15分钟。记得检查生成的txt文件是否与图片一一对应这是后续训练能否成功的关键。3. 模型配置让YOLOv5适应字符识别任务3.1 修改配置文件YOLOv5默认配置是针对COCO数据集的80类目标检测我们需要调整两个关键文件模型配置文件yolov5s.yaml# 修改类别数 nc: 10 # 数字0-9 depth_multiple: 0.33 width_multiple: 0.50数据配置文件street_yolo.yamltrain: ../tianchi/images/train val: ../tianchi/images/val nc: 10 names: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]这里有个细节容易忽略YOLOv5默认使用相对路径所以要注意文件路径的层级关系。我建议在yolov5目录下创建专门的tianchi文件夹存放所有比赛相关文件结构如下yolov5/ ├── tianchi/ │ ├── images/ │ │ ├── train/ │ │ └── val/ │ ├── labels/ │ │ ├── train/ │ │ └── val/ │ ├── street_yolo.yaml │ └── street_yolov5s.yaml3.2 数据增强策略街景字符的特点是尺寸变化大、角度多样适当的数据增强能显著提升模型鲁棒性。在data.yaml中可以配置# 在street_yolo.yaml中添加 augment: True hsv_h: 0.015 # 色相增强 hsv_s: 0.7 # 饱和度增强 hsv_v: 0.4 # 明度增强 degrees: 10 # 旋转角度 translate: 0.1 # 平移比例 scale: 0.5 # 缩放比例 shear: 2 # 剪切幅度但要注意过度增强反而会降低效果。比如字符识别任务不适合大角度旋转因为数字6和9旋转后容易混淆。我测试发现degrees设置在5-15度之间效果最佳。4. 训练技巧从入门到精通的实战经验4.1 基础训练命令启动训练最简单的命令是python train.py --data tianchi/street_yolo.yaml --cfg tianchi/street_yolov5s.yaml --epochs 100 --batch-size 16但这样直接跑效果可能不理想有几个关键参数需要调整--img-size默认640x640但街景字符通常较小可以尝试缩小到416x416--batch-size根据显存调整GTX1050建议设为8-16--weights加载预训练权重能加速收敛使用yolov5s.pt我的最佳实践组合是python train.py --data tianchi/street_yolo.yaml --cfg tianchi/street_yolov5s.yaml \ --weights yolov5s.pt --epochs 150 --batch-size 16 --img-size 416 \ --hyp data/hyps/hyp.scratch-low.yaml --optimizer AdamW --patience 104.2 训练过程监控YOLOv5内置了TensorBoard支持启动命令tensorboard --logdir runs/train重点关注三个指标mAP0.5验证集上的平均精度达到0.9以上说明模型不错obj_loss目标检测损失应该稳步下降cls_loss分类损失反映数字识别的准确度如果发现过拟合训练集指标持续提升但验证集停滞可以增加--patience参数提前停止使用--freeze参数冻结部分层增大--dropout概率5. 预测与提交完整Pipeline实现5.1 批量预测命令预测测试集并保存结果python detect.py --weights runs/train/exp/weights/best.pt \ --source tianchi/images/test/ --save-txt --save-conf \ --imgsz 416 --conf-thres 0.5 --iou-thres 0.45关键参数说明--conf-thres置信度阈值过滤低质量预测--iou-thresNMS的IoU阈值防止重复检测--save-txt保存预测结果为YOLO格式txt5.2 结果格式转换比赛要求提交CSV文件需要将YOLO格式转换为比赛格式。我优化后的转换脚本如下def yolo_to_submission(label_dir, sample_csv, output_csv): df pd.read_csv(sample_csv) pred_dict {} for txt_path in glob.glob(f{label_dir}/*.txt): with open(txt_path) as f: lines sorted(f.readlines(), keylambda x: float(x.split()[1])) pred .join([line.split()[0] for line in lines]) img_name os.path.basename(txt_path).replace(.txt, .png) pred_dict[img_name] pred df[file_code] df[file_name].map(pred_dict).fillna() df.to_csv(output_csv, indexFalse)这个版本增加了对空预测的处理fillna()避免出现NaN值导致提交失败。转换后的CSR文件可以直接在天池平台提交。6. 进阶优化突破0.95分的技巧6.1 模型集成策略单个模型达到瓶颈后可以尝试多模型投票训练yolov5s、yolov5m、yolov5l三个模型对预测结果投票TTA测试时增强使用--augment参数启用翻转增强多尺度预测组合不同--imgsz的预测结果集成示例命令for model in yolov5s yolov5m yolov5l; do python detect.py --weights runs/train/${model}/weights/best.pt \ --source tianchi/images/test/ --save-txt --nosave \ --imgsz 416 640 --augment done6.2 后处理优化原始预测可能存在以下问题重复检测同一字符漏检部分字符字符顺序错乱改进的后处理方案def refine_predictions(txt_dir, conf_thresh0.6, iou_thresh0.3): for txt_path in glob.glob(f{txt_dir}/*.txt): with open(txt_path) as f: boxes [] for line in f: cls, x, y, w, h, conf map(float, line.split()) if conf conf_thresh: boxes.append([x, cls, conf]) # 按x坐标排序确保字符顺序正确 boxes.sort(keylambda b: b[0]) # NMS去除重叠检测 boxes nms(boxes, iou_thresh) # 保存优化后的结果 with open(txt_path, w) as f: for x, cls, conf in boxes: f.write(f{int(cls)} {x} 0.5 0.5 0.5 0.5 {conf}\n)这个后处理能提升约1-2%的准确率特别是在处理长数字串时效果明显。实际项目中我通过这套方案最终将准确率提升到了0.952进入了当时比赛的前10%。