After you
train your own model using AutoML Vision Edge
,
you can use it in your app to detect objects in images.
There are two ways to integrate models trained from AutoML Vision Edge: You can
bundle the model by putting it inside your app’s asset folder, or you can
dynamically download it from Firebase.
Model bundling options
|
Bundled in your app
|
- The model is part of your app's APK
- The model is available immediately, even when the Android device is offline
- No need for a Firebase project
|
Hosted with Firebase
|
|
Before you begin
If you want to download a model
, make sure you
add Firebase to your Android project
,
if you have not already done so. This is not required when you bundle the model.
Add the dependencies for the TensorFlow Lite Task library to your module's
app-level gradle file, which is usually
app/build.gradle
:
For bundling a model with your app:
dependencies {
// ...
// Object detection with a bundled Auto ML model
implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly-SNAPSHOT'
}
For dynamically downloading a model from Firebase, also add the Firebase ML
dependency:
dependencies {
// ...
// Object detection with an Auto ML model deployed to Firebase
implementation platform('com.google.firebase:firebase-bom:26.1.1')
implementation 'com.google.firebase:firebase-ml-model-interpreter'
implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly'
}
1. Load the model
To bundle the model with your app:
- Extract the model from the zip archive you downloaded from the
Google Cloud console.
- Include your model in your app package:
- If you don't have an assets folder in your project, create one by
right-clicking the
app/
folder, then clicking
New > Folder > Assets Folder
.
- Copy your
tflite
model file with embedded metadata to the assets
folder.
Add the following to your app's
build.gradle
file to ensure
Gradle doesn’t compress the model file when building the app:
android {
// ...
aaptOptions {
noCompress "tflite"
}
}
The model file will be included in the app package and available
as a raw asset.
To use the remotely-hosted model, create a
RemoteModel
object,
specifying the name you assigned the model when you published it:
Java
// Specify the name you assigned when you deployed the model.
FirebaseCustomRemoteModel remoteModel =
new FirebaseCustomRemoteModel.Builder("your_model").build();
Kotlin
// Specify the name you assigned when you deployed the model.
val remoteModel =
FirebaseCustomRemoteModel.Builder("your_model_name").build()
Then, start the model download task, specifying the conditions under which
you want to allow downloading. If the model isn't on the device, or if a newer
version of the model is available, the task will asynchronously download the
model from Firebase:
Java
DownloadConditions downloadConditions = new DownloadConditions.Builder()
.requireWifi()
.build();
RemoteModelManager.getInstance().download(remoteModel, downloadConditions)
.addOnSuccessListener(new OnSuccessListener<Void>() {
@Override
public void onSuccess(@NonNull Task<Void> task) {
// Success.
}
});
Kotlin
val downloadConditions = DownloadConditions.Builder()
.requireWifi()
.build()
RemoteModelManager.getInstance().download(remoteModel, downloadConditions)
.addOnSuccessListener {
// Success.
}
Many apps start the download task in their initialization code, but you
can do so at any point before you need to use the model.
Create an object detector from your model
After you configure your model sources, create a
ObjectDetector
object from one
of them.
If you only have a locally-bundled model, just create an object detector from your
model file and configure the confidence score
threshold you want to require (see
Evaluate your model
):
Java
// Initialization
ObjectDetectorOptions options = ObjectDetectorOptions.builder()
.setScoreThreshold(0) // Evaluate your model in the Google Cloud console
// to determine an appropriate value.
.build();
ObjectDetector objectDetector = ObjectDetector.createFromFileAndOptions(context, modelFile, options);
Kotlin
// Initialization
val options = ObjectDetectorOptions.builder()
.setScoreThreshold(0) // Evaluate your model in the Google Cloud console
// to determine an appropriate value.
.build()
val objectDetector = ObjectDetector.createFromFileAndOptions(context, modelFile, options)
If you have a remotely-hosted model, you will have to check that it has been
downloaded before you run it. You can check the status of the model download
task using the model manager's
isModelDownloaded()
method.
Although you only have to confirm this before running the object detector, if you
have both a remotely-hosted model and a locally-bundled model, it might make
sense to perform this check when instantiating the object detector: create an
object detector from the remote model if it's been downloaded, and from the local
model otherwise.
Java
FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
.addOnSuccessListener(new OnSuccessListener<Boolean>() {
@Override
public void onSuccess(Boolean isDownloaded) {
}
});
Kotlin
FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
.addOnSuccessListener { success ->
}
If you only have a remotely-hosted model, you should disable model-related
functionality—for example, grey-out or hide part of your UI—until
you confirm the model has been downloaded. You can do so by attaching a listener
to the model manager's
download()
method.
Once you know your model has been downloaded, create an object detector from the
model file:
Java
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
.addOnCompleteListener(new OnCompleteListener<File>() {
@Override
public void onComplete(@NonNull Task<File> task) {
File modelFile = task.getResult();
if (modelFile != null) {
ObjectDetectorOptions options = ObjectDetectorOptions.builder()
.setScoreThreshold(0)
.build();
objectDetector = ObjectDetector.createFromFileAndOptions(
getApplicationContext(), modelFile.getPath(), options);
}
}
});
Kotlin
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
.addOnSuccessListener { modelFile ->
val options = ObjectDetectorOptions.builder()
.setScoreThreshold(0f)
.build()
objectDetector = ObjectDetector.createFromFileAndOptions(
applicationContext, modelFile.path, options)
}
Then, for each image you want to label, create a
TensorImage
object from your
image. You can create a
TensorImage
object from a
Bitmap
using the
fromBitmap
method:
Java
TensorImage image = TensorImage.fromBitmap(bitmap);
Kotlin
val image = TensorImage.fromBitmap(bitmap)
If your image data isn't in a
Bitmap
, you can load a pixel array as shown in
the
TensorFlow Lite docs
.
3. Run the object detector
To detect objects in an image, pass the
TensorImage
object to the
ObjectDetector
's
detect()
method.
Java
List<Detection> results = objectDetector.detect(image);
Kotlin
val results = objectDetector.detect(image)
If the object detection operation succeeds, it returns a list of
Detection
objects. Each
Detection
object represents something that was detected in the
image. You can get each object's bounding box and its labels.
For example:
Java
for (Detection result : results) {
RectF bounds = result.getBoundingBox();
List<Category> labels = result.getCategories();
}
Kotlin
for (result in results) {
val bounds = result.getBoundingBox()
val labels = result.getCategories()
}
If you want to label images in a real-time application, follow these
guidelines to achieve the best framerates:
- Throttle calls to the image labeler. If a new video frame becomes
available while the image labeler is running, drop the frame. See the
VisionProcessorBase
class in the quickstart sample app for an example.
- If you are using the output of the image labeler to overlay graphics on
the input image, first get the result, then render the image
and overlay in a single step. By doing so, you render to the display surface
only once for each input frame. See the
CameraSourcePreview
and
GraphicOverlay
classes in the quickstart sample app for an
example.
-
If you use the Camera2 API, capture images in
ImageFormat.YUV_420_888
format.
If you use the older Camera API, capture images in
ImageFormat.NV21
format.