chiachan
chiachan
Published on 2025-02-07 / 30 Visits
0

gocv的yolo-detection代码分析

https://github.com/hybridgroup/gocv/tree/release/cmd/yolo-detection

文件分析及关键函数注释

文件概述

该文件实现了使用YOLOv8深度神经网络进行对象检测的功能。它从指定的视频源(摄像头或视频文件)读取帧,并使用预训练的YOLOv8模型进行对象检测,最后在窗口中显示带有检测框和标签的图像。

关键部分及注释

package main

import (
	"fmt"
	"image"
	"image/color"
	"os"

	"gocv.io/x/gocv"
)

func main() {
	// 检查命令行参数数量是否足够
	if len(os.Args) < 3 {
		fmt.Println("How to run:\nyolo-detection [videosource] [modelfile] ([backend] [device])")
		return
	}

	// 解析命令行参数
	deviceID := os.Args[1]
	model := os.Args[2]
	backend := gocv.NetBackendDefault
	if len(os.Args) > 3 {
		backend = gocv.ParseNetBackend(os.Args[3])
	}

	target := gocv.NetTargetCPU
	if len(os.Args) > 4 {
		target = gocv.ParseNetTarget(os.Args[4])
	}

	// 打开视频捕获设备(摄像头或视频文件)
	webcam, err := gocv.OpenVideoCapture(deviceID)
	if err != nil {
		fmt.Printf("Error opening video capture device: %v\n", deviceID)
		return
	}
	defer webcam.Close()

	// 创建一个窗口用于显示检测结果
	window := gocv.NewWindow("YOLO Detection")
	defer window.Close()

	// 创建一个Mat对象用于存储每一帧图像
	img := gocv.NewMat()
	defer img.Close()

	// 加载YOLOv8 ONNX模型
	net := gocv.ReadNetFromONNX(model)
	if net.Empty() {
		fmt.Printf("Error reading network model from : %v\n", model)
		return
	}
	defer net.Close()
	net.SetPreferableBackend(gocv.NetBackendType(backend))
	net.SetPreferableTarget(gocv.NetTargetType(target))

	// 获取输出层名称
	outputNames := getOutputNames(&net)
	if len(outputNames) == 0 {
		fmt.Println("Error reading output layer names")
		return
	}

	fmt.Printf("Start reading device: %v\n", deviceID)

	// 主循环:读取帧、进行检测并显示结果
	for {
		if ok := webcam.Read(&img); !ok {
			fmt.Printf("Device closed: %v\n", deviceID)
			return
		}
		if img.Empty() {
			continue
		}

		detect(&net, &img, outputNames)

		window.IMShow(img)
		if window.WaitKey(1) >= 0 {
			break
		}
	}
}

关键函数及其注释

  1. getOutputNames:

    func getOutputNames(net *gocv.Net) []string {
        var outputLayers []string
        for _, i := range net.GetUnconnectedOutLayers() {
            layer := net.GetLayer(i)
            layerName := layer.GetName()
            if layerName != "_input" {
                outputLayers = append(outputLayers, layerName)
            }
        }
        return outputLayers
    }
    
    • 功能: 获取YOLOv8模型的输出层名称。
    • 输入: net - YOLOv8模型的网络对象。
    • 输出: 包含所有输出层名称的字符串切片。
    • 说明: 遍历未连接的输出层,获取每个层的名称,并过滤掉输入层。
  2. performDetection:

    func performDetection(outs []gocv.Mat) ([]image.Rectangle, []float32, []int) {
        var classIds []int
        var confidences []float32
        var boxes []image.Rectangle
    
        // 转置输出矩阵以适应YOLOv8的格式要求
        gocv.TransposeND(outs[0], []int{0, 2, 1}, &outs[0])
    
        for _, out := range outs {
            out = out.Reshape(1, out.Size()[1])
    
            for i := 0; i < out.Rows(); i++ {
                cols := out.Cols()
                scoresCol := out.RowRange(i, i+1)
    
                scores := scoresCol.ColRange(4, cols)
                _, confidence, _, classIDPoint := gocv.MinMaxLoc(scores)
    
                if confidence > 0.5 {
                    centerX := out.GetFloatAt(i, cols-4)
                    centerY := out.GetFloatAt(i, cols-3)
                    width := out.GetFloatAt(i, cols-2)
                    height := out.GetFloatAt(i, cols-1)
    
                    left := centerX - width/2
                    top := centerY - height/2
                    right := centerX + width/2
                    bottom := centerY + height/2
                    classIds = append(classIds, classIDPoint.X)
                    confidences = append(confidences, float32(confidence))
    
                    boxes = append(boxes, image.Rect(int(left), int(top), int(right), int(bottom)))
                }
            }
        }
    
        return boxes, confidences, classIds
    }
    
    • 功能: 从网络输出中提取边界框、置信度和类别ID。
    • 输入: outs - 网络的输出层数据。
    • 输出:
      • boxes - 检测到的对象边界框。
      • confidences - 对应的置信度。
      • classIds - 对应的类别ID。
    • 说明: 遍历每个输出层的数据,提取每个检测结果的中心坐标、宽度、高度等信息,并计算边界框。
  3. drawRects:

    func drawRects(img *gocv.Mat, boxes []image.Rectangle, classes []string, classIds []int, indices []int) []string {
        var detectClass []string
        for _, idx := range indices {
            if idx == 0 {
                continue
            }
            gocv.Rectangle(img, image.Rect(boxes[idx].Min.X, boxes[idx].Min.Y, boxes[idx].Max.X, boxes[idx].Max.Y), color.RGBA{0, 255, 0, 0}, 2)
            gocv.PutText(img, classes[classIds[idx]], image.Point{boxes[idx].Min.X, boxes[idx].Min.Y - 10}, gocv.FontHersheyPlain, 0.6, color.RGBA{0, 255, 0, 0}, 1)
            detectClass = append(detectClass, classes[classIds[idx]])
        }
    
        return detectClass
    }
    
    • 功能: 在图像上绘制检测到的对象的边界框和标签。
    • 输入:
      • img - 原始图像。
      • boxes - 检测到的对象边界框。
      • classes - 类别名称列表。
      • classIds - 类别ID列表。
      • indices - 经过非极大值抑制后的索引列表。
    • 输出: 检测到的类别名称列表。
    • 说明: 根据传入的边界框和类别信息,在图像上绘制矩形框和文本标签。
  4. detect:

    func detect(net *gocv.Net, src *gocv.Mat, outputNames []string) {
        params := gocv.NewImageToBlobParams(ratio, image.Pt(640, 640), mean, swapRGB, gocv.MatTypeCV32F, gocv.DataLayoutNCHW, gocv.PaddingModeLetterbox, padValue)
        blob := gocv.BlobFromImageWithParams(*src, params)
        defer blob.Close()
    
        // 将blob输入到网络中
        net.SetInput(blob, "")
    
        // 进行前向传播,获取输出
        probs := net.ForwardLayers(outputNames)
        defer func() {
            for _, prob := range probs {
                prob.Close()
            }
        }()
    
        // 提取检测结果
        boxes, confidences, classIds := performDetection(probs)
        if len(boxes) == 0 {
            fmt.Println("No classes detected")
            return
        }
    
        // 将边界框从blob坐标转换回图像坐标
        iboxes := params.BlobRectsToImageRects(boxes, image.Pt(src.Cols(), src.Rows()))
        // 应用非极大值抑制
        indices := gocv.NMSBoxes(iboxes, confidences, scoreThreshold, nmsThreshold)
        // 绘制检测结果
        drawRects(src, iboxes, classes, classIds, indices)
    }
    
    • 功能: 对给定的图像进行对象检测。
    • 输入:
      • net - YOLOv8模型的网络对象。
      • src - 输入图像。
      • outputNames - 输出层名称列表。
    • 说明:
      • 将图像转换为blob格式并输入到网络中。
      • 进行前向传播,获取检测结果。
      • 提取边界框、置信度和类别ID。
      • 应用非极大值抑制(NMS)去除冗余检测框。
      • 在图像上绘制检测结果。

总结

该文件通过调用YOLOv8模型对视频流中的每一帧进行对象检测,并将检测结果实时显示在一个窗口中。代码结构清晰,主要由几个关键函数组成,每个函数负责不同的处理步骤,如获取输出层名称、执行检测、绘制结果等。