TensorFlow Body Segmentation API

· 667 words · 4 minute read

demo image

My experience with AI/Machine Learning is mostly as a user and application developer. This eye-tracking project was one of my first personal development ideas that I brought to life. In undergrad, I took an intro to AI course and learned the basics: motivation, history, basic models, and applications. Computer Vision continues to motivate exciting use cases for me, and I’ve found a cool tool for my next project!

BlazePose 🔗

Google Research published the original BlazePose paper in 2020, demonstrating an accurate body pose tracking model that could run on generally accessible hardware (30 frames/second on a Pixel 2!).

At a high level, the technology uses neural networks to predict heatmaps for the joints and then regress directly to their coordinates, plus some optimization to keep it fast.

We […] use an encoder-decoder network architecture to predict heatmaps for all joints, followed by another encoder that regresses directly to the coordinates of all joints.
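To make the heatmap idea concrete, here's a toy sketch (not the paper's actual inference path): each joint gets a grid of confidence scores, and the hottest cell marks the most likely joint location.

// Toy illustration only: a per-joint heatmap is a grid of confidence values,
// and the cell with the highest value marks the predicted joint position.
function heatmapToCoordinate(heatmap) {
  let best = {x: 0, y: 0, score: -Infinity};
  heatmap.forEach((row, y) => {
    row.forEach((score, x) => {
      if (score > best.score) best = {x, y, score};
    });
  });
  return best;
}

// A 3x3 heatmap whose center cell is most confident:
heatmapToCoordinate([[0.0, 0.1, 0.0],
                     [0.1, 0.9, 0.2],
                     [0.0, 0.2, 0.1]]);  // -> {x: 1, y: 1, score: 0.9}

The second encoder from the quote skips this kind of lookup and regresses straight to the coordinates.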

Especially important to me is the ability to run this on commodity devices that everyone carries in their pockets. I was extremely impressed to see the performance comparisons.

At the same time, BlazePose performs 25–75 times faster on a single mid-tier phone CPU compared to OpenPose on a 20 core desktop CPU[5] depending on the requested quality

Creating a Demo App 🔗

Despite JavaScript being the first language I fell in love with and a huge part of my path to becoming the engineer I am today, I absolutely hate it. The popularity and general applicability of JavaScript have been its downfall, leaving it with a complex and intimidating ecosystem.

I tried to recreate the demo from the repo using a single HTML file with script includes for the library and an inline script to run the logic. Silly me, that's not how ES modules, or CommonJS, or whatever flavor of modules work. I fiddled with this for a while before abandoning it and just pulling the example code out of the repo. After digging through the example, it makes sense why my initial attempt wasn't fruitful. This is a complex little app; I don't know why I expected my minimal-effort solution to work!
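In hindsight, the core problem is that browsers can't resolve bare module specifiers on their own. Here's a rough, purely illustrative sketch of what my single-file attempt boiled down to:

// Inside an inline <script type="module"> in my lone HTML file.
// This fails in the browser: '@tensorflow-models/pose-detection' is a bare
// specifier that only a bundler (or an import map) knows how to resolve to
// actual files.
import * as posedetection from '@tensorflow-models/pose-detection';

A bundler (or an import map) maps that package name to real files, and that's exactly the step my single HTML file was missing.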

By stripping away the parts of the example I didn't care about, I got down to under 200 lines of really interesting code: the code that deals with the camera feed, processing, and drawing predictions.

This is what the dependencies in my package.json look like:

"dependencies": {
    "@mediapipe/pose": "~0.4.0",
    "@tensorflow-models/pose-detection": "^2.0.0",
    "@tensorflow/tfjs-backend-webgl": "^4.0.0",
    "@tensorflow/tfjs-converter": "^4.0.0",
    "@tensorflow/tfjs-core": "^4.0.0"
  }

I import them like so in my index.js. This snippet shows initializing a detector as well as gathering poses from it.

import '@tensorflow/tfjs-backend-webgl';
import * as posedetection from '@tensorflow-models/pose-detection';
...
// Create a BlazePose detector backed by the tfjs runtime.
let detector = await posedetection.createDetector(
    posedetection.SupportedModels.BlazePose,
    {runtime: 'tfjs', modelType: STATE.modelConfig.type});
...
try {
    // Estimate poses for the current video frame.
    poses = await detector.estimatePoses(
        camera.video,
        {maxPoses: STATE.modelConfig.maxPoses, flipHorizontal: false});
} catch (error) {
    // If estimation blows up, dispose of the detector so it can be recreated.
    detector.dispose();
    detector = null;
}
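For context, estimatePoses runs once per video frame. Here's a minimal sketch of the loop that drives it, reusing the detector and camera above; drawFrame is a hypothetical stand-in for the example's actual drawing helpers:

// Minimal render loop sketch (drawFrame is a made-up helper name).
async function renderLoop() {
  if (detector != null) {
    const poses = await detector.estimatePoses(
        camera.video,
        {maxPoses: STATE.modelConfig.maxPoses, flipHorizontal: false});
    drawFrame(camera, poses);  // paint the current frame plus its keypoints
  }
  requestAnimationFrame(renderLoop);
}
renderLoop();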

The camera does a lot of heavy lifting, and I recommend looking at that file. Here are the highlights: the camera uses the getUserMedia API to ingest video, posedetection outputs keypoints, and each keypoint comes with a score we can check.

// Grab the webcam stream and feed it into the Camera's <video> element.
const stream = await navigator.mediaDevices.getUserMedia(videoConfig);
const camera = new Camera();
camera.video.srcObject = stream;
...
// Keypoint indices grouped by body side (used in the example for color-coding).
const keypointInd = posedetection.util.getKeypointIndexBySide(params.STATE.model);
...
if (pose.keypoints != null) {
  this.drawKeypoints(pose.keypoints);
  this.drawSkeleton(pose.keypoints, pose.id);
}
...
// Connect each pair of adjacent keypoints with a line to form the skeleton.
posedetection.util.getAdjacentPairs(params.STATE.model).forEach(([i, j]) => {
  const kp1 = keypoints[i];
  const kp2 = keypoints[j];

  // If score is null, just show the keypoint.
  const score1 = kp1.score != null ? kp1.score : 1;
  const score2 = kp2.score != null ? kp2.score : 1;
  const scoreThreshold = params.STATE.modelConfig.scoreThreshold || 0;

  if (score1 >= scoreThreshold && score2 >= scoreThreshold) {
    this.ctx.beginPath();
    this.ctx.moveTo(kp1.x, kp1.y);
    this.ctx.lineTo(kp2.x, kp2.y);
    this.ctx.stroke();
  }
});
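drawKeypoints follows the same pattern as the skeleton loop above. Here's a minimal sketch (not the example's exact implementation) that circles each keypoint whose score clears the threshold:

// Sketch of a drawKeypoints method in the same style as the skeleton code.
drawKeypoints(keypoints) {
  const scoreThreshold = params.STATE.modelConfig.scoreThreshold || 0;
  for (const kp of keypoints) {
    // If score is null, assume the keypoint is good enough to draw.
    const score = kp.score != null ? kp.score : 1;
    if (score >= scoreThreshold) {
      this.ctx.beginPath();
      this.ctx.arc(kp.x, kp.y, 4, 0, 2 * Math.PI);  // small dot at (x, y)
      this.ctx.fill();
    }
  }
}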

Exciting! 🔗

This gives us a skeleton drawn over the video feed. The skeleton is a set of lines connecting keypoints that estimate the subject's joints and focal points. With this, we can understand the body position of the subject in the video feed, and that opens up some really exciting use cases!

Resources 🔗