【vue】动作捕捉tensorflow-models/pose-

2022-12-04  本文已影响0人  西叶web

前端利用tensorflow实现动作捕捉

tensorflow介绍

官方模型https://www.tensorflow.org/js/models

里面可以看到姿势检测,这就是我们要实现动作捕捉关键的库

官方提供三种模型选项

https://github.com/tensorflow/tfjs-models/tree/master/pose-detection

可以自己点demo进行观看:

MoveNet Demo 提供17个点位 帧率可达50帧以上

BlazePose Demo 除了17个点位,还提供了面部手脚额外的点位,总共33个

PoseNet Demo 可检测多个姿势,每个提供17个点位

官方也提供了的相应的源码

源码中会使用stats和dat.gui,看起来会很乱


image.png

下面我抽离主要功能进行说明,移除stats性能监控、dat.gui操作这些逻辑

以PoseNet 为例,其他两个也是一样的,使用的差别就是在于用哪个模型而已

文字说明

主要逻辑

1 获取摄像头数据,也可以使用本地视频或者远端视频,总之就是一个video标签
2 需要一个画布canvas,把摄像头数据和点位画上去,一帧画一个,组合起来就是一个视频
3 连接模型,或者说是以模型创建探测器,因为他api叫createDetector
4 探测器创建完后,接着就是探测的对象
--因为是一个画面一个画面的探测的,所以需要先弄一个画布
--这个画布去画视频每一帧
-- 我们只需要把这画布canvas传进去就行
5 调用estimatePoses对画布进行分析,api字面意思是估计姿势,他会返回17个点位
6 拿到点位就可以进行你的业务逻辑了,现在我的业务就是把那些点按一定规律连起来,连起来的线术语叫骨骼
7 最后,一些变量要释放,防止内存泄漏

代码说明

引入依赖包

import { PoseDetector } from '@tensorflow-models/pose-detection';
import * as poseDetection from '@tensorflow-models/pose-detection';
import '@tensorflow/tfjs-backend-webgl';

虽然只使用了两个,但你需要安装这些依赖

"@mediapipe/pose": "^0.5.1635988162",
"@tensorflow-models/pose-detection": "^2.0.0",
"@tensorflow/tfjs-backend-webgl": "^4.1.0",
"@tensorflow/tfjs-converter": "^4.1.0",
"@tensorflow/tfjs-core": "^4.1.0",

html

<div>
    <canvas id="output"></canvas>
    <video id="video" playsinline autoplay width="360" height="270"></video>
</div>

顶级变量

let videoEl: HTMLVideoElement;
let canvasEl: HTMLCanvasElement;
let canvasCtx: CanvasRenderingContext2D;
let detector: PoseDetector;
let model = poseDetection.SupportedModels.PoseNet;

const DEFAULT_LINE_WIDTH = 2;
const DEFAULT_RADIUS = 4;
const SCORE_THRESHOLD = 0.5;

let requestID: any; // requestAnimationFrame

启动函数

const init = async () => {
  // 获取dom
  canvasEl = document.getElementById('output') as HTMLCanvasElement;
  videoEl = document.getElementById('video') as HTMLVideoElement;
  // 获取画布
  canvasCtx = canvasEl.getContext('2d')!;

  // 设置视频源,这里使用摄像头
  const stream = await navigator.mediaDevices.getUserMedia({
    audio: false,
    video: true,
  });
  // 设置流
  videoEl.srcObject = stream;
  // 视频加载后执行
  videoEl.onloadeddata = async function () {
    // 下一步这里开始
  };
};

video onload之后

videoEl.onloadeddata = async function () {
    const { width, height } = videoEl.getBoundingClientRect();
    canvasEl.width = width;
    canvasEl.height = height;
    // 加载模型,model 在顶级变量里已经设置为poseDetection.SupportedModels.PoseNet
    detector = await poseDetection.createDetector(model, {
      quantBytes: 4,
      architecture: 'MobileNetV1',
      outputStride: 16,
      inputResolution: { width, height },
      multiplier: 0.75,
    });
    // 开始检测
    startDetect();
  };

探测函数

// 开始检测
async function startDetect() {
  const video = document.getElementById('video') as HTMLVideoElement;
  
  // 检测画布动作
  const poses = await detector.estimatePoses(canvasEl, {
    flipHorizontal: false, // 是否水平翻转
    maxPoses: 1, // 最大检测人数
    // scoreThreshold: 0.5, // 置信度
    // nmsRadius: 20, // 非极大值抑制
  });
  // 绘制视频
  canvasCtx.drawImage(video, 0, 0, canvasEl.width, canvasEl.height);
  // 画第一个人的姿势 poses[0]
  // 画点
  drawKeypoints(canvasCtx, poses[0].keypoints);
  // 画骨骼
  drawSkeleton(canvasCtx, poses[0].keypoints, poses.id);
  // 一帧执行一次  可替换为setTimeout方案: setTimeout(()=>startDetect(),1000/16)
  requestID = requestAnimationFrame(() => startDetect());
}

画点画线的函数drawKeypoints drawSkeleton

// 画点
function drawKeypoints(ctx: CanvasRenderingContext2D, keypoints) {
  // keypointInd 主要按left middle right  返回索引,left是单数索引,right是双数索引,打印一下你就知道了
  const keypointInd = poseDetection.util.getKeypointIndexBySide(model);
  ctx.strokeStyle = 'White';
  ctx.lineWidth = DEFAULT_LINE_WIDTH;

  ctx.fillStyle = 'Red';
  for (const i of keypointInd.middle) {
    drawKeypoint(keypoints[i]);
  }

  ctx.fillStyle = 'Green';
  for (const i of keypointInd.left) {
    drawKeypoint(keypoints[i]);
  }

  ctx.fillStyle = 'Orange';
  for (const i of keypointInd.right) {
    drawKeypoint(keypoints[i]);
  }
}
function drawKeypoint(ctx: CanvasRenderingContext2D, keypoint) {
  // If score is null, just show the keypoint.
  const score = keypoint.score != null ? keypoint.score : 1;

  if (score >= SCORE_THRESHOLD) {
    const circle = new Path2D();
    circle.arc(keypoint.x, keypoint.y, DEFAULT_RADIUS, 0, 2 * Math.PI);
    ctx.fill(circle);
    ctx.stroke(circle);
  }
}
// 画骨架
function drawSkeleton(ctx: CanvasRenderingContext2D, keypoints: any, poseId?: any) {
  // Each poseId is mapped to a color in the color palette.
  const color = 'White';

  ctx.fillStyle = color;
  ctx.strokeStyle = color;
  ctx.lineWidth = DEFAULT_LINE_WIDTH;

  poseDetection.util.getAdjacentPairs(model).forEach(([i, j]) => {
    const kp1 = keypoints[i];
    const kp2 = keypoints[j];

    // If score is null, just show the keypoint.
    const score1 = kp1.score != null ? kp1.score : 1;
    const score2 = kp2.score != null ? kp2.score : 1;

    if (score1 >= SCORE_THRESHOLD && score2 >= SCORE_THRESHOLD) {
      ctx.beginPath();
      ctx.moveTo(kp1.x, kp1.y);
      ctx.lineTo(kp2.x, kp2.y);
      ctx.stroke();
    }
  });
}

准备完成,开始执行

onMounted(() => {
  init();
});

页面离开前记得释放变量

onUnmounted(() => {
  detector.dispose();
  detector = null;
  cancelAnimationFrame(requestID);
});

效果展示

gege.gif
上一篇下一篇

猜你喜欢

热点阅读