浏览器用TensorFlow PoseNet进行姿态估计,支持本
![](https://img.haomeiwen.com/i33776/5dfe198bfcc04045.png)
TensorFlow 在18年的一篇文章里面有报道:
英文版:Real-time Human Pose Estimation in the Browser with TensorFlow.js
翻译版:在浏览器中用TensorFlow.js进行实时人体姿态估计
简单介绍下,就是利用TensorFlow.js版本的PoseNet模型来识别,现在官网中也能看到了。
![](https://img.haomeiwen.com/i33776/2257457d32a253e3.png)
点击TensorFlow的姿态估计直接跳到了GitHub,里面有详细的介绍,教你如何配置,并且还贴心的给出了一个在线的demo,但是需要梯子才能运行。
使用介绍
大概看了下文档,使用起来还是比较简单的,demo里面其实给了两个例子,一个是静态图片的姿态识别,一个是摄像头的实时姿态识别。
下面就先用静态图片识别做个介绍,分解一下主要就是下面4个步骤:
- 首先需要加载 TensorFlow.js 和 Posenet
- 执行posenet.load方法拿到模型对象
- 调用模型对象的estimateSinglePose方法去识别图片
- 上一步会返回17个关键点的坐标,将坐标点绘制到网页即可
看起来还挺简单,但是理想很丰满,现实很骨感呀,下面说说可能遇到的坑:
TensorFlow.js无法加载
demo给出的TensorFlow.js
和 posenet
地址无法访问,得用梯子才能拿到,所以最好是保存到本地,然后每次从自己服务器加载。
模型load不了
执行posenet.load半天没反应,打开浏览器的控制台,发现一堆bin
文件的网络请求。
![](https://img.haomeiwen.com/i33776/ac317bfb56ee8a51.png)
原来是模型文件都在google服务器存着的,每次load都是根据你的配置信息在线下载对应的模型文件,所以运行的时候还得用梯子,好在文档中说load方法有个modelUrl
参数,可以指定模型文件的位置。modelUrl是设置.json
文件的存放位置,框架会根据这个.json
自动去下载对应的.bin
文件,所以只需要把.json
文件和.bin
文件手动下载下来一起放到自己服务器上就行。然后调用如下:
posenet.load({ modelUrl: '/pose_models/model-stride16.json' })
.then(net => {
console.log(net);
})
.catch(err => {
console.log(err);
})
模型文件一般都是几十兆甚至上百兆,如果自己服务器带宽比较小,可以放到七牛或者阿里云的oss上。
辅助函数
识别之后拿到了坐标,还需要自己去绘制出来,所以得自己去写绘制方法。
下面是源码
<!DOCTYPE html>
<html>
<head>
<!-- 加载 TensorFlow.js,建议换成自己的服务器地址 -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<!-- 加载 Posenet,建议换成自己的服务器地址 -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/posenet"></script>
<style type="text/css">
.wrap {
display: flex;
justify-content: center;
align-items: center;
flex-direction: row;
}
#myImg {
width: 300px;
}
</style>
</head>
<body>
<div class="wrap">
<img id="myImg" crossOrigin="anonymous" src="https://img.haomeiwen.com/i33776/129d3871bccfaa22.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240"/>
<canvas id="output"></canvas>
</div>
</body>
<script>
const color = "#ff0000";
const minConfidence = 0.2;
const lineWidth = 1;
// ######################################### 工具函数
// 坐标转换
function toTuple({ y, x }) {
return [y, x];
}
// 将图片绘制到canvas
function renderImageToCanvas(image, size, canvas) {
canvas.width = size[0];
canvas.height = size[1];
const ctx = canvas.getContext('2d');
ctx.drawImage(image, 0, 0, size[0], size[1]);
}
// 画关键点
function drawKeypoints(keypoints, ctx, scale = 1) {
for (let i = 0; i < keypoints.length; i++) {
const keypoint = keypoints[i];
if (keypoint.score < minConfidence) {
continue;
}
const { y, x } = keypoint.position;
drawPoint(ctx, y * scale, x * scale, 3, color);
}
}
// canvas画点
function drawPoint(ctx, y, x, r, color) {
ctx.beginPath();
ctx.arc(x, y, r, 0, 2 * Math.PI);
ctx.fillStyle = color;
ctx.fill();
}
// 关键点连线
function drawSkeleton(keypoints, ctx, scale = 1) {
const adjacentKeyPoints =
posenet.getAdjacentKeyPoints(keypoints, minConfidence);
adjacentKeyPoints.forEach((keypoints) => {
drawSegment(
toTuple(keypoints[0].position), toTuple(keypoints[1].position), color, scale, ctx);
});
}
// canvas画线
function drawSegment([ay, ax], [by, bx], color, scale, ctx) {
ctx.beginPath();
ctx.moveTo(ax * scale, ay * scale);
ctx.lineTo(bx * scale, by * scale);
ctx.lineWidth = lineWidth;
ctx.strokeStyle = color;
ctx.stroke();
}
// ######################################### 识别图片
function detectImg() {
let imageElement = document.getElementById('myImg');
let canvas = document.getElementById('output');
// 设置加载模型走自己服务器,不设置则走google的服务器
posenet.load({ modelUrl: '/pose_models/model-stride16.json' }).then(net => {
// posenet.load({}).then(net => {
return net.estimateSinglePose(imageElement, {
flipHorizontal: false
});
}).then(pose => {
console.log(pose);
renderImageToCanvas(imageElement, [imageElement.width, imageElement.height], canvas);
let ctx = canvas.getContext('2d');
drawKeypoints(pose.keypoints, ctx);
drawSkeleton(pose.keypoints, ctx);
});
}
// 识别
detectImg();
</script>
</html>
结果如下,这年头文章不放两张妹子图都没人看呀:
![](https://img.haomeiwen.com/i33776/77f0816d802029b8.png)
代码优化
上面的代码还有个问题就是,每次加载页面都需要去加载模型文件,虽然说是在自己的服务器上,但是毕竟模型文件还是有几十兆呀。好在浏览器能缓存模型文件,但是还是有缓存失效的问题,要是其他页面要用也要再加载一遍,太浪费资源了。
不久前,看到有些文章说 TensorFlow.js 微信小程序插件开始支持模型缓存了,然后去翻官方文档,发现确实可以将模型缓存到本地的,并且支持多种方式缓存,文档地址:https://www.tensorflow.org/js/guide/save_load
浏览器上建议缓存到indexeddb
,一来没有文件大小限制,二来同域名下均可读取,完美。
但是posenet这个库有点坑,居然没有开放出这个方法,没办法,只能翻了源码之后自己写一个:
const MOBILENET_V1_CONFIG = {
architecture: 'MobileNetV1',
outputStride: 16,
multiplier: 0.75,
inputResolution: 257,
}
async function loadMobileNet(config=MOBILENET_V1_CONFIG) {
const outputStride = config.outputStride;
const quantBytes = config.quantBytes;
const multiplier = config.multiplier;
var graphModel = null;
try {
graphModel = await tf.loadGraphModel('indexeddb://my-model');
console.log('从缓存中加载模型');
} catch(e) {
console.log(e);
graphModel = await tf.loadGraphModel('/pose_models/model-stride16.json');
graphModel.save('indexeddb://my-model');
console.log('从网络中加载模型');
}
const mobilenet = new posenet.MobileNet(graphModel, outputStride);
const validInputResolution = [config.inputResolution, config.inputResolution];
return new posenet.PoseNet(mobilenet, validInputResolution);
}
ok,现在把以前的posenet.load
方法替换成loadMobileNet
就行了,这样第一次会从网络加载,后面就走本地缓存了,而且同域名下其他页面也能直接从缓存加载:
loadMobileNet().then(net => {
return net.estimateSinglePose(imageElement, {
flipHorizontal: false
});
}).then(pose => {
console.log(pose);
});
实时视频识别
这个得先拉起用户摄像头,然后一帧一帧的去绘制,其实和识别图片原理是一样的。官方的demo里面有具体代码可以自行查看。
![](https://img.haomeiwen.com/i33776/5fac84addea1ee1a.gif)