https://github.com/hybridgroup/gocv/tree/release/cmd/yolo-detection
文件分析及关键函数注释
文件概述
该文件实现了使用YOLOv8深度神经网络进行对象检测的功能。它从指定的视频源(摄像头或视频文件)读取帧,并使用预训练的YOLOv8模型进行对象检测,最后在窗口中显示带有检测框和标签的图像。
关键部分及注释
package main
import (
"fmt"
"image"
"image/color"
"os"
"gocv.io/x/gocv"
)
func main() {
// 检查命令行参数数量是否足够
if len(os.Args) < 3 {
fmt.Println("How to run:\nyolo-detection [videosource] [modelfile] ([backend] [device])")
return
}
// 解析命令行参数
deviceID := os.Args[1]
model := os.Args[2]
backend := gocv.NetBackendDefault
if len(os.Args) > 3 {
backend = gocv.ParseNetBackend(os.Args[3])
}
target := gocv.NetTargetCPU
if len(os.Args) > 4 {
target = gocv.ParseNetTarget(os.Args[4])
}
// 打开视频捕获设备(摄像头或视频文件)
webcam, err := gocv.OpenVideoCapture(deviceID)
if err != nil {
fmt.Printf("Error opening video capture device: %v\n", deviceID)
return
}
defer webcam.Close()
// 创建一个窗口用于显示检测结果
window := gocv.NewWindow("YOLO Detection")
defer window.Close()
// 创建一个Mat对象用于存储每一帧图像
img := gocv.NewMat()
defer img.Close()
// 加载YOLOv8 ONNX模型
net := gocv.ReadNetFromONNX(model)
if net.Empty() {
fmt.Printf("Error reading network model from : %v\n", model)
return
}
defer net.Close()
net.SetPreferableBackend(gocv.NetBackendType(backend))
net.SetPreferableTarget(gocv.NetTargetType(target))
// 获取输出层名称
outputNames := getOutputNames(&net)
if len(outputNames) == 0 {
fmt.Println("Error reading output layer names")
return
}
fmt.Printf("Start reading device: %v\n", deviceID)
// 主循环:读取帧、进行检测并显示结果
for {
if ok := webcam.Read(&img); !ok {
fmt.Printf("Device closed: %v\n", deviceID)
return
}
if img.Empty() {
continue
}
detect(&net, &img, outputNames)
window.IMShow(img)
if window.WaitKey(1) >= 0 {
break
}
}
}
关键函数及其注释
-
getOutputNames
:func getOutputNames(net *gocv.Net) []string { var outputLayers []string for _, i := range net.GetUnconnectedOutLayers() { layer := net.GetLayer(i) layerName := layer.GetName() if layerName != "_input" { outputLayers = append(outputLayers, layerName) } } return outputLayers }
- 功能: 获取YOLOv8模型的输出层名称。
- 输入:
net
- YOLOv8模型的网络对象。 - 输出: 包含所有输出层名称的字符串切片。
- 说明: 遍历未连接的输出层,获取每个层的名称,并过滤掉输入层。
-
performDetection
:func performDetection(outs []gocv.Mat) ([]image.Rectangle, []float32, []int) { var classIds []int var confidences []float32 var boxes []image.Rectangle // 转置输出矩阵以适应YOLOv8的格式要求 gocv.TransposeND(outs[0], []int{0, 2, 1}, &outs[0]) for _, out := range outs { out = out.Reshape(1, out.Size()[1]) for i := 0; i < out.Rows(); i++ { cols := out.Cols() scoresCol := out.RowRange(i, i+1) scores := scoresCol.ColRange(4, cols) _, confidence, _, classIDPoint := gocv.MinMaxLoc(scores) if confidence > 0.5 { centerX := out.GetFloatAt(i, cols-4) centerY := out.GetFloatAt(i, cols-3) width := out.GetFloatAt(i, cols-2) height := out.GetFloatAt(i, cols-1) left := centerX - width/2 top := centerY - height/2 right := centerX + width/2 bottom := centerY + height/2 classIds = append(classIds, classIDPoint.X) confidences = append(confidences, float32(confidence)) boxes = append(boxes, image.Rect(int(left), int(top), int(right), int(bottom))) } } } return boxes, confidences, classIds }
- 功能: 从网络输出中提取边界框、置信度和类别ID。
- 输入:
outs
- 网络的输出层数据。 - 输出:
boxes
- 检测到的对象边界框。confidences
- 对应的置信度。classIds
- 对应的类别ID。
- 说明: 遍历每个输出层的数据,提取每个检测结果的中心坐标、宽度、高度等信息,并计算边界框。
-
drawRects
:func drawRects(img *gocv.Mat, boxes []image.Rectangle, classes []string, classIds []int, indices []int) []string { var detectClass []string for _, idx := range indices { if idx == 0 { continue } gocv.Rectangle(img, image.Rect(boxes[idx].Min.X, boxes[idx].Min.Y, boxes[idx].Max.X, boxes[idx].Max.Y), color.RGBA{0, 255, 0, 0}, 2) gocv.PutText(img, classes[classIds[idx]], image.Point{boxes[idx].Min.X, boxes[idx].Min.Y - 10}, gocv.FontHersheyPlain, 0.6, color.RGBA{0, 255, 0, 0}, 1) detectClass = append(detectClass, classes[classIds[idx]]) } return detectClass }
- 功能: 在图像上绘制检测到的对象的边界框和标签。
- 输入:
img
- 原始图像。boxes
- 检测到的对象边界框。classes
- 类别名称列表。classIds
- 类别ID列表。indices
- 经过非极大值抑制后的索引列表。
- 输出: 检测到的类别名称列表。
- 说明: 根据传入的边界框和类别信息,在图像上绘制矩形框和文本标签。
-
detect
:func detect(net *gocv.Net, src *gocv.Mat, outputNames []string) { params := gocv.NewImageToBlobParams(ratio, image.Pt(640, 640), mean, swapRGB, gocv.MatTypeCV32F, gocv.DataLayoutNCHW, gocv.PaddingModeLetterbox, padValue) blob := gocv.BlobFromImageWithParams(*src, params) defer blob.Close() // 将blob输入到网络中 net.SetInput(blob, "") // 进行前向传播,获取输出 probs := net.ForwardLayers(outputNames) defer func() { for _, prob := range probs { prob.Close() } }() // 提取检测结果 boxes, confidences, classIds := performDetection(probs) if len(boxes) == 0 { fmt.Println("No classes detected") return } // 将边界框从blob坐标转换回图像坐标 iboxes := params.BlobRectsToImageRects(boxes, image.Pt(src.Cols(), src.Rows())) // 应用非极大值抑制 indices := gocv.NMSBoxes(iboxes, confidences, scoreThreshold, nmsThreshold) // 绘制检测结果 drawRects(src, iboxes, classes, classIds, indices) }
- 功能: 对给定的图像进行对象检测。
- 输入:
net
- YOLOv8模型的网络对象。src
- 输入图像。outputNames
- 输出层名称列表。
- 说明:
- 将图像转换为blob格式并输入到网络中。
- 进行前向传播,获取检测结果。
- 提取边界框、置信度和类别ID。
- 应用非极大值抑制(NMS)去除冗余检测框。
- 在图像上绘制检测结果。
总结
该文件通过调用YOLOv8模型对视频流中的每一帧进行对象检测,并将检测结果实时显示在一个窗口中。代码结构清晰,主要由几个关键函数组成,每个函数负责不同的处理步骤,如获取输出层名称、执行检测、绘制结果等。