树莓派: tensorflow-lite 目标检测

2020-12-25  本文已影响0人  洗洗睡吧i

0. 安装 tflite-runtime

ref: https://tensorflow.google.cn/lite/guide/python

pip3 install https://dl.google.com/coral/python/tflite_runtime-2.1.0.post1-cp37-cp37m-linux_armv7l.whl

1. tensorflow官方示例

TensorFlow 官方提供了一个目标检测示例，使用的是 picamera 库（树莓派 CSI 摄像头模块）。

ref: https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/raspberry_pi/

# 1. Clone the TensorFlow examples repository (shallow clone, latest commit only)
git clone https://github.com/tensorflow/examples --depth 1

# 2. Enter the Raspberry Pi object-detection example folder
cd examples/lite/examples/object_detection/raspberry_pi

# The folder contained 5 files at the time of writing:
# README.md              #
# annotation.py          # draws the boxes and labels
# detect_picamera.py     # main program
# download.sh            # installs the Python dependencies and downloads the trained model
# requirements.txt       #

# 3. Download the pre-trained model into /tmp
bash download.sh /tmp
# - installs the Python dependencies: numpy  picamera  Pillow
# - downloads coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip, which holds two files:
#   detect.tflite and labelmap.txt (the author found this label file garbled/unusable)
# - use this label file instead: https://dl.google.com/coral/canned_models/coco_labels.txt
#   e.g. wget -O /tmp/coco_labels.txt https://dl.google.com/coral/canned_models/coco_labels.txt
#   (step 4 below reads it from /tmp/coco_labels.txt)

# 4. Run the example
python3 detect_picamera.py  --model /tmp/detect.tflite --labels /tmp/coco_labels.txt

2. 使用 opencv 调用 usb camera

我手头没有 picamera 摄像头模块，只有一个老式 USB 摄像头，而 picamera 的 API 不支持 USB 摄像头。

下面修改代码，改用 OpenCV 来调用 USB 摄像头。

Example using TF Lite to detect objects with a USB camera on a Raspberry Pi.

- Raspberry Pi 3B+
- USB camera

- Python 3.7.3
- tflite-runtime 2.1
- OpenCV

- Model: coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip

import re
import time

import numpy as np
import cv2

from tflite_runtime.interpreter import Interpreter

# Capture size requested from the camera. This is only a request: OpenCV does
# not guarantee the device accepts it.
args_camera_width = 640
args_camera_height = 480
# TF Lite model and label file; relative paths resolve against the current
# working directory.
args_model = 'detect.tflite'
args_labels = 'coco_labels.txt'
# Minimum detection score for a result to be printed and drawn.
args_threshold = 0.4

def load_labels(path):
    """Load a label file into a ``{class_id: name}`` dict.

    Two line layouts are supported, and they may be mixed:

    * indexed lines such as ``0  person`` or ``0: person`` -- the leading
      integer becomes the key and the rest of the line the name;
    * plain lines such as ``traffic light`` -- the zero-based line number
      becomes the key and the whole stripped line the name.

    Blank lines are skipped, but they still count toward line numbers.

    Args:
        path: path of a UTF-8 encoded label file.

    Returns:
        dict mapping int class id to label string.
    """
    labels = {}
    with open(path, 'r', encoding='utf-8') as f:
        for row_number, content in enumerate(f):
            line = content.strip()
            if not line:
                continue
            pair = re.split(r'[:\s]+', line, maxsplit=1)
            if len(pair) == 2 and pair[0].isdigit():
                labels[int(pair[0])] = pair[1].strip()
            else:
                # Plain label: keep the full line so multi-word names
                # such as "traffic light" are not truncated.
                labels[row_number] = line
    return labels

def detect_objects(interpreter, image, threshold):
    """Run one inference on ``image`` and keep detections scoring >= ``threshold``.

    Args:
        interpreter: TF Lite ``Interpreter`` whose tensors have already been
            allocated with ``allocate_tensors()``.
        image: input tensor data, already resized, batched and typed to match
            the model's first input. In this script it is a
            ``(1, 300, 300, 3)`` ``uint8`` array; confirm for other models.
        threshold: minimum score a detection needs to be returned.

    Returns:
        list of dicts, one per kept detection, with keys
        ``'box'`` (``[ymin, xmin, ymax, xmax]``; the code that draws the boxes
        treats these as normalized 0..1 coordinates), ``'class_id'`` (int)
        and ``'score'`` (float).
    """
    # Fill the input tensor with the image we were given, then actually run
    # the model. Without invoke() the output tensors are never computed.
    input_index = interpreter.get_input_details()[0]['index']
    interpreter.set_tensor(input_index, image)
    interpreter.invoke()

    # Output order used by this script: 0 = boxes, 1 = classes, 2 = scores.
    # reshape/ravel (unlike squeeze) keep 1-D per-detection arrays even when
    # the model reports a single detection.
    output_details = interpreter.get_output_details()
    boxes = np.reshape(interpreter.get_tensor(output_details[0]['index']), (-1, 4))
    classes = np.ravel(interpreter.get_tensor(output_details[1]['index'])).astype(np.int32)
    scores = np.ravel(interpreter.get_tensor(output_details[2]['index']))

    # Drop low-confidence detections.
    return [
        {'box': box, 'class_id': int(class_id), 'score': float(score)}
        for box, class_id, score in zip(boxes, classes, scores)
        if score >= threshold
    ]

def annotate_objects(image, results):
    """Draw a box and a ``"<label> <score>"`` caption for each detection.

    Box coordinates are scaled by the frame's actual size (``image.shape``),
    not the requested camera size, because the camera is not guaranteed to
    deliver the resolution that was requested.

    Args:
        image: camera frame (numpy image array) to draw on; modified in place.
        results: detections as produced by ``detect_objects``. Each is a dict
            with ``'box'`` (normalized ``[ymin, xmin, ymax, xmax]``),
            ``'class_id'`` and ``'score'``.
    """
    frame_height, frame_width = image.shape[:2]
    for rst in results:
        ymin, xmin, ymax, xmax = rst['box']
        class_id = rst['class_id']
        # Fall back to the numeric id for classes missing from the label file.
        name = labels_dict.get(class_id, str(class_id))
        score = rst['score']

        xmin = int(xmin * frame_width)
        xmax = int(xmax * frame_width)
        ymin = int(ymin * frame_height)
        ymax = int(ymax * frame_height)
        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 255, 0))

        txt = f'{name} {score:.2%}'
        # Font face 0 is cv2.FONT_HERSHEY_SIMPLEX.
        cv2.putText(image, txt, (xmin, ymin), 0, 1, (255, 255, 255), 2)

# 1. Load the labels.
labels_dict = load_labels(args_labels)
print('labels_dict: \n ', labels_dict)

# 2. Load the model. allocate_tensors() must run before set_tensor()/invoke().
interpreter = Interpreter(args_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Read the expected input size from the model instead of hard-coding 300x300.
# Assumes an NHWC input (batch, height, width, channels), as in the official
# TF Lite example.
input_shape = input_details[0]['shape']
input_height, input_width = int(input_shape[1]), int(input_shape[2])

# 3. Open the USB camera and request the capture size
#    (CAP_PROP_FRAME_WIDTH is property 3, CAP_PROP_FRAME_HEIGHT is property 4).
camera = cv2.VideoCapture(0)
camera.set(cv2.CAP_PROP_FRAME_WIDTH, args_camera_width)
camera.set(cv2.CAP_PROP_FRAME_HEIGHT, args_camera_height)

frame_rate_calc = 1.0
freq = cv2.getTickFrequency()

# 4. Detection loop; press 'q' in the preview window to quit.
try:
    while True:
        # 4.1 Start timing this frame for the FPS estimate.
        t1 = cv2.getTickCount()

        # 4.2 Grab a frame and turn it into the model's input tensor.
        ret, frame = camera.read()
        if not ret:
            print('Failed to read a frame from the camera; exiting.')
            break
        input_image = cv2.resize(frame, (input_width, input_height))
        # OpenCV frames are BGR; the model expects RGB input.
        input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
        # Add the batch dimension. Frames are already uint8, which is what
        # the quantized model takes.
        input_image = np.expand_dims(input_image, axis=0)

        # 4.3 Run inference and log the detections.
        results = detect_objects(interpreter, input_image, args_threshold)
        print(f'--- {time.strftime("%Y-%m-%d %H:%M:%S")} ---')
        for rst in results:
            box = rst['box']
            class_id = rst['class_id']
            name = labels_dict.get(class_id, str(class_id))
            score = rst['score']
            print(f'* {name} : {score:.2%}  @ {box}')

        # 4.4 Draw the detections on the original frame.
        annotate_objects(frame, results)

        # 4.5 Draw the FPS measured on the previous frame.
        txt = f'FPS: {frame_rate_calc:.2f}'
        cv2.putText(frame, txt, (20, 30), 0, 1, (0, 255, 255), 2)

        # 4.6 Show the frame. waitKey() lets HighGUI process window events
        # (without it the window never repaints) and polls the keyboard.
        cv2.imshow('Object detect', frame)
        key = cv2.waitKey(1) & 0xFF

        # 4.7 Update the FPS estimate.
        t2 = cv2.getTickCount()
        frame_rate_calc = freq / (t2 - t1)

        if key == ord('q'):
            break
finally:
    # Release the camera and close the window even if the loop raised.
    camera.release()
    cv2.destroyAllWindows()



