Trying Android's NNAPI ML accelerator with Object Detection on a Pixel 4 XL

Testing and comparing Android's ML acceleration with a MobileDet model.

As the demand for private, fast, and low-latency machine learning increases, so does the need for accessible, on-device solutions capable of performing well on the so-called "edge." Two of these solutions are the Pixel Neural Core (PNC) hardware and its Edge TPU architecture, currently available on the Google Pixel 4 mobile phone, and the Android Neural Networks API (NNAPI), an API designed for executing machine learning operations on Android devices.

In this article, I will show how I modified the TensorFlow Lite Object Detection demo for Android to use an Edge TPU-optimized model running under the NNAPI on a Pixel 4 XL. Additionally, I will present the changes I made to log the prediction latencies and compare the inference times obtained with the default TensorFlow Lite API against those obtained with the NNAPI. But before that, let me give a brief overview of the terms I've introduced so far.

Pixel Neural Core, Edge TPU and NNAPI

The Pixel Neural Core, the successor to the Pixel Visual Core, is a domain-specific chip that is part of the Pixel 4 hardware. Its architecture follows that of the Edge TPU (tensor processing unit), Google's machine learning accelerator for edge computing devices. Being a chip designed for "the edge" means it is smaller and more energy-efficient (it can perform 4 trillion operations per second while consuming just 2 W) than its larger counterparts found in Google's cloud platform.

The Edge TPU, however, is not a general-purpose accelerator for every kind of machine learning workload. The hardware is designed to speed up forward-pass operations, meaning that it excels as an inference engine, not as a tool for training. That's why you will mostly find applications where the model used on the device was trained somewhere else.

On the software side of things, we have the NNAPI. This Android API, written in C, provides acceleration for TensorFlow Lite models on devices that employ hardware accelerators such as the Pixel Neural Core, GPUs, and digital signal processors (DSPs). The TensorFlow Lite framework for Android includes an NNAPI delegate, so don't worry, we won't write any C code.
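To give an idea of what the delegate looks like from Java, here is a minimal sketch (not the example app's exact code; the class name is mine) of how a TensorFlow Lite Interpreter can be configured with the NNAPI delegate:

// Minimal sketch: routing a TensorFlow Lite model through the NNAPI delegate.
// Assumes the TensorFlow Lite Android dependency is on the classpath and that the
// .tflite model has been memory-mapped elsewhere.
import java.nio.MappedByteBuffer;

import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.nnapi.NnApiDelegate;

class NnapiInterpreterFactory {
  static Interpreter create(MappedByteBuffer modelBuffer, boolean useNnapi) {
    Interpreter.Options options = new Interpreter.Options();
    if (useNnapi) {
      // Operations supported by the device's drivers run through the NNAPI;
      // the rest fall back to the default CPU kernels.
      options.addDelegate(new NnApiDelegate());
    }
    return new Interpreter(modelBuffer, options);
  }
}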

Figure 1. System architecture for Android Neural Networks API. Source: https://developer.android.com/ndk/guides/neuralnetworks

The model

The model we will use for this project is the float32 version of the MobileDet object detection model, optimized for the Edge TPU and trained on the COCO dataset (link). Let me quickly explain what these terms mean. MobileDet (Xiong et al.) is a very recent state-of-the-art family of lightweight object detection models for devices with low computational power, such as mobile phones. The float32 variant means it is not a quantized model, that is, a model transformed to reduce its size at the cost of some accuracy; a fully quantized model, in contrast, uses small weights based on 8-bit integers (source). Then we have the COCO dataset, short for "Common Objects in Context" (Lin et al.). This collection has over 200k labeled images spread across 90 classes that include "bird," "cat," "person," and "car."
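To make the quantization idea a bit more concrete, here is a small illustrative sketch of the affine 8-bit scheme fully quantized TFLite models rely on, where a real value is approximated as scale * (q - zero_point); the scale and zero point below are made up for demonstration, not taken from MobileDet:

// Illustrative sketch of 8-bit affine quantization: real ~ scale * (q - zeroPoint).
// The scale and zero point are hypothetical values, used only for demonstration.
public class QuantizationSketch {
  public static void main(String[] args) {
    float scale = 0.02f; // hypothetical per-tensor scale
    int zeroPoint = 5;   // hypothetical zero point

    float real = 0.37f;
    int q = Math.round(real / scale) + zeroPoint;   // quantize
    q = Math.max(-128, Math.min(127, q));           // clamp to the int8 range
    float dequantized = scale * (q - zeroPoint);    // map back to a float

    System.out.printf("quantized=%d, dequantized=%.4f%n", q, dequantized);
  }
}

The float32 MobileDet we use skips this step, keeping full-precision weights at the cost of a larger file.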

Now, after that bit of theory, let's take a look at the app.

The app

The app I used is based on the object detection example app for Android provided in the TensorFlow repository. However, I altered it to use the NNAPI and to log the inference times to a file, data I later used to compare the prediction times of the NNAPI and the default TFLITE API. Below is the DetectorActivity.java file, responsible for producing the detections. The complete source code is on my GitHub; I'm showing only this file since it has the most changes. In it, I changed the name of the model (after adding the MobileDet model to the assets directory), changed the variable TF_OD_API_INPUT_SIZE to reflect MobileDet's input size, and set TF_OD_API_IS_QUANTIZED to false since the model is not quantized. Besides this, I added two lists to collect the inference times of the predictions (one list per API) and an overridden onStop method that dumps the lists to a file once the user closes the app. Other small changes were setting NUM_DETECTIONS in TFLiteObjectDetectionAPIModel.java from 10 to 100 and adding the WRITE_EXTERNAL_STORAGE permission to the Android manifest so the app can write the files to the Documents directory.

/*
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tensorflow.lite.examples.detection;

import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Style;
import android.graphics.RectF;
import android.graphics.Typeface;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.Environment;
import android.os.SystemClock;
import android.util.Size;
import android.util.TypedValue;
import android.widget.Toast;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.io.Writer;
import java.text.SimpleDateFormat;
import java.util.Calendar;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;

import org.json.JSONObject;
import org.tensorflow.lite.examples.detection.customview.OverlayView;
import org.tensorflow.lite.examples.detection.customview.OverlayView.DrawCallback;
import org.tensorflow.lite.examples.detection.env.BorderedText;
import org.tensorflow.lite.examples.detection.env.ImageUtils;
import org.tensorflow.lite.examples.detection.env.Logger;
import org.tensorflow.lite.examples.detection.tflite.Classifier;
import org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel;
import org.tensorflow.lite.examples.detection.tracking.MultiBoxTracker;

/**
 * An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to detect and then track
 * objects.
 */
public class DetectorActivity extends CameraActivity implements OnImageAvailableListener {
  private static final Logger LOGGER = new Logger();

  // Configuration values for the prepackaged SSD model.

  private static final int TF_OD_API_INPUT_SIZE = 320;
  private static final boolean TF_OD_API_IS_QUANTIZED = false;


  private static final String TF_OD_API_MODEL_FILE = "md_non_quant.tflite";


  private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/labelmap.txt";
  private static final DetectorMode MODE = DetectorMode.TF_OD_API;
  // Minimum detection confidence to track a detection.
  private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.5f;
  private static final boolean MAINTAIN_ASPECT = false;
  private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
  private static final boolean SAVE_PREVIEW_BITMAP = false;
  private static final float TEXT_SIZE_DIP = 10;
  OverlayView trackingOverlay;
  private Integer sensorOrientation;

  private Classifier detector;

  private long lastProcessingTimeMs;
  private Bitmap rgbFrameBitmap = null;
  private Bitmap croppedBitmap = null;
  private Bitmap cropCopyBitmap = null;

  private boolean computingDetection = false;

  private long timestamp = 0;

  private Matrix frameToCropTransform;
  private Matrix cropToFrameTransform;

  private MultiBoxTracker tracker;

  private BorderedText borderedText;

  private boolean isUsingNNAPI = false;
  // Log detections
  private Map<String, Integer> DETECTIONS_OUTPUT_MAP = new HashMap<>();
  private ArrayList<Long> nnapiTimes = new ArrayList<Long>();
  private ArrayList<Long> nonNNAPITimes = new ArrayList<Long>();

  @Override
  public void onPreviewSizeChosen(final Size size, final int rotation) {
    final float textSizePx =
        TypedValue.applyDimension(
            TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
    borderedText = new BorderedText(textSizePx);
    borderedText.setTypeface(Typeface.MONOSPACE);

    tracker = new MultiBoxTracker(this);

    int cropSize = TF_OD_API_INPUT_SIZE;

    try {
      detector =
          TFLiteObjectDetectionAPIModel.create(
              getAssets(),
              TF_OD_API_MODEL_FILE,
              TF_OD_API_LABELS_FILE,
              TF_OD_API_INPUT_SIZE,
              TF_OD_API_IS_QUANTIZED);
      cropSize = TF_OD_API_INPUT_SIZE;
    } catch (final IOException e) {
      e.printStackTrace();
      LOGGER.e(e, "Exception initializing classifier!");
      Toast toast =
          Toast.makeText(
              getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
      toast.show();
      finish();
    }

    previewWidth = size.getWidth();
    previewHeight = size.getHeight();

    sensorOrientation = rotation - getScreenOrientation();
    LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);

    LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
    rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
    croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);

    frameToCropTransform =
        ImageUtils.getTransformationMatrix(
            previewWidth, previewHeight,
            cropSize, cropSize,
            sensorOrientation, MAINTAIN_ASPECT);

    cropToFrameTransform = new Matrix();
    frameToCropTransform.invert(cropToFrameTransform);

    trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay);
    trackingOverlay.addCallback(
        new DrawCallback() {
          @Override
          public void drawCallback(final Canvas canvas) {
            tracker.draw(canvas);
            if (isDebug()) {
              tracker.drawDebug(canvas);
            }
          }
        });

    tracker.setFrameConfiguration(previewWidth, previewHeight, sensorOrientation);
  }

  @Override
  protected void processImage() {
    ++timestamp;
    final long currTimestamp = timestamp;
    trackingOverlay.postInvalidate();

    // No mutex needed as this method is not reentrant.
    if (computingDetection) {
      readyForNextImage();
      return;
    }
    computingDetection = true;

    rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);

    readyForNextImage();

    final Canvas canvas = new Canvas(croppedBitmap);
    canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
    // For examining the actual TF input.
    if (SAVE_PREVIEW_BITMAP) {
      ImageUtils.saveBitmap(croppedBitmap);
    }

    runInBackground(
        new Runnable() {
          @Override
          public void run() {
            //LOGGER.i("Running detection on image " + currTimestamp);
            final long startTime = SystemClock.uptimeMillis();
            final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
            lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;

            cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
            final Canvas canvas = new Canvas(cropCopyBitmap);
            final Paint paint = new Paint();
            paint.setColor(Color.RED);
            paint.setStyle(Style.STROKE);
            paint.setStrokeWidth(2.0f);

            float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
            switch (MODE) {
              case TF_OD_API:
                minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
                break;
            }

            final List<Classifier.Recognition> mappedRecognitions =
                new LinkedList<Classifier.Recognition>();

            for (final Classifier.Recognition result : results) {
              final RectF location = result.getLocation();
              if (location != null && result.getConfidence() >= minimumConfidence) {
                canvas.drawRect(location, paint);

                cropToFrameTransform.mapRect(location);

                result.setLocation(location);
                mappedRecognitions.add(result);

                if (!DETECTIONS_OUTPUT_MAP.containsKey(result.getTitle())) {
                  DETECTIONS_OUTPUT_MAP.put(result.getTitle(), 0);
                }

                DETECTIONS_OUTPUT_MAP.put(result.getTitle(), DETECTIONS_OUTPUT_MAP.get(result.getTitle()) + 1);
              }
            }

            tracker.trackResults(mappedRecognitions, currTimestamp);
            trackingOverlay.postInvalidate();

            computingDetection = false;
            if (isUsingNNAPI) {
              nnapiTimes.add(lastProcessingTimeMs);
            } else {
              nonNNAPITimes.add(lastProcessingTimeMs);
            }

            runOnUiThread(
                new Runnable() {
                  @Override
                  public void run() {
                    showFrameInfo(previewWidth + "x" + previewHeight);
                    showCropInfo(cropCopyBitmap.getWidth() + "x" + cropCopyBitmap.getHeight());
                    showInference(lastProcessingTimeMs + "ms");
                  }
                });
          }
        });
  }

  @Override
  protected int getLayoutId() {
    return R.layout.tfe_od_camera_connection_fragment_tracking;
  }

  @Override
  protected Size getDesiredPreviewFrameSize() {
    return DESIRED_PREVIEW_SIZE;
  }

  // Which detection model to use: by default uses Tensorflow Object Detection API frozen
  // checkpoints.
  private enum DetectorMode {
    TF_OD_API;
  }

  @Override
  protected void setUseNNAPI(final boolean isChecked) {
    isUsingNNAPI = isChecked;
    runInBackground(() -> detector.setUseNNAPI(isChecked));
  }

  @Override
  protected void setNumThreads(final int numThreads) {
    runInBackground(() -> detector.setNumThreads(numThreads));
  }

  @Override
  public synchronized void onStop() {
    JSONObject obj = new JSONObject(DETECTIONS_OUTPUT_MAP);
    try {
      Calendar cal = Calendar.getInstance();
      Date date = cal.getTime();
      SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd_HH:mm:ss");
      String formattedDate = dateFormat.format(date);
      String filename = String.format("%s_%s.json", "detections", formattedDate);

      File f = new File(
              Environment.getExternalStoragePublicDirectory(
                      Environment.DIRECTORY_DOCUMENTS), "/TensorFlowLiteDetections");

      if (!f.exists()) {
        f.mkdirs();
      }

      // Write detections logs
      File file = new File(f, filename);
      Writer output = new BufferedWriter(new FileWriter(file));
      output.write(obj.toString());
      output.close();

      // Write NNAPI times
      filename = String.format("%s_%s.txt", "nnapi_times", formattedDate);
      file = new File(f, filename);
      output = new BufferedWriter(new FileWriter(file));
      for (Long l : nnapiTimes) {
        output.write(l + System.lineSeparator());
      }
      output.close();


      // Write non-NNAPI times
      filename = String.format("%s_%s.txt", "non_nnapi_times", formattedDate);
      file = new File(f, filename);
      output = new BufferedWriter(new FileWriter(file));
      for (Long l : nonNNAPITimes) {
        output.write(l + System.lineSeparator());
      }
      output.close();
    } catch (Exception e) {
      LOGGER.d("DetectorActivity", "Couldn't write file." + e);
    }
    LOGGER.d("Data written to directory");

    super.onStop();
  }
}

Another important change I made was uncommenting (and thus enabling) the toggle button that allows us to run the app using the NNAPI. By default, that part of the code is commented out, so it is not possible to activate the API from within the app; I might be wrong, though (please correct me if you find otherwise).
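To illustrate what enabling that toggle amounts to (this is an assumption about the wiring, not a verbatim copy of the example's code), the switch essentially just forwards its checked state to setUseNNAPI:

// Illustrative sketch only: a listener that forwards the API switch state to the
// activity. The example app wires something equivalent inside CameraActivity;
// the names here (ApiSwitchListener, NnapiToggleHandler) are mine.
import android.widget.CompoundButton;

class ApiSwitchListener implements CompoundButton.OnCheckedChangeListener {
  interface NnapiToggleHandler {
    void setUseNNAPI(boolean useNnapi); // implemented by DetectorActivity in the demo
  }

  private final NnapiToggleHandler handler;

  ApiSwitchListener(NnapiToggleHandler handler) {
    this.handler = handler;
  }

  @Override
  public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) {
    handler.setUseNNAPI(isChecked); // true -> NNAPI delegate, false -> default TFLITE path
  }
}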

You can find the complete source code behind the app in my GitHub repo at https://github.com/juandes/mobiledet-tflite-nnapi. To run it, open Android Studio, select "Open an existing Android Studio project," and choose the project's root directory. Once the project opens, click the small green hammer icon to build it, then click the play icon to run it on either a virtual device or an actual device (if one is connected). If possible, use a real device.

Measuring the latency

So, how fast is the app at detecting objects? To measure the latency, I added a small piece of functionality that writes the prediction times (in milliseconds) of the inferences made with the standard TFLITE API and with the NNAPI to files. After that, I took the files and performed a small analysis in R to get insights from the data. Below are the results.

Figure 2: Inference times of predictions done with the TFLITE API
Figure 3: Inference times of predictions done with the NNAPI

Figures 2 and 3 are histograms of the inference times under each API. The first one (n=909), corresponding to the default TFLITE API, has a peak around the 100 ms mark and several extreme outliers at the higher end of the visualization. Figure 3 (n=1169), corresponding to predictions done using the NNAPI, has its peak around 50 ms. However, those extreme outlier values shift the mean and stretch the distribution towards the right. So, to better visualize the times, I removed these values and drew the same visualizations without them. Now, they look as follows:

Figure 4: Inference times (without outliers) of predictions done with the TFLITE API
Figure 5: Inference times (without outliers) of predictions done with the NNAPI

Better, right? The black vertical lines on both plots indicate the mean value. For the TFLITE graph, the mean inference time is 103 ms and the median is 100 ms. On the NNAPI side, the average prediction takes 55 ms, with a median of 54 ms. That is almost twice as fast.
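If you want to reproduce these summary numbers without leaving Java (the analysis in this article was done in R), a quick sketch like the following computes the mean and median from one of the dumped times files; the file name below is just a placeholder:

// Quick sketch: mean and median of one of the dumped *_times.txt files.
// The path is a placeholder; point it at the file the app wrote under
// Documents/TensorFlowLiteDetections.
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;

public class LatencySummary {
  public static void main(String[] args) throws IOException {
    long[] times = Files.readAllLines(Paths.get("nnapi_times.txt")).stream()
        .map(String::trim)
        .filter(s -> !s.isEmpty())
        .mapToLong(Long::parseLong)
        .toArray();
    Arrays.sort(times);

    double mean = Arrays.stream(times).average().orElse(Double.NaN);
    double median = times.length % 2 == 0
        ? (times[times.length / 2 - 1] + times[times.length / 2]) / 2.0
        : times[times.length / 2];

    System.out.printf("n=%d, mean=%.1f ms, median=%.1f ms%n", times.length, mean, median);
  }
}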

The following video shows the app in action. Here, I'm simply pointing the phone at my computer to detect objects from a video:

Recap and conclusion

The advances in machine learning and, more generally, AI are truly fascinating. First, they took over our computers and the cloud, and now they are on their way to our mobile devices. Yet there is a big difference between the latter and the platforms we traditionally use to deploy machine learning models: smaller processors and battery constraints are some of these differences. As a result, the development of frameworks and hardware specialized for this sort of device is advancing rapidly.

Several of these tools are the Pixel Neural Core, the Edge TPU, and the NNAPI. This combination of hardware and software aims to bring highly efficient and accurate AI to our mobile devices. In this article, I presented an overview of these technologies. Then, I showed how to update the TensorFlow Lite object detection example for Android to enable the NNAPI and write the inference times to files. Using these files, I did a small analysis in R to visualize the data and found that the predictions done with the NNAPI took around half the time of those done with the default API.

Thanks for reading :)