The Task Library BertQuestionAnswerer API loads a Bert model and answers questions based on the content of a given passage. For more information, see the documentation for the Question-Answer model here.
Key features of the BertQuestionAnswerer API
- Takes two text inputs as question and context and outputs a list of possible answers.
- Performs out-of-graph Wordpiece or Sentencepiece tokenizations on input text.
Supported BertQuestionAnswerer models
The following models are compatible with the BertQuestionAnswerer API:
- Models created by TensorFlow Lite Model Maker for BERT Question Answer.
- Custom models that meet the model compatibility requirements below.
Run inference in Java
Step 1: Import Gradle dependency and other settings
Copy the .tflite model file to the assets directory of the Android module where the model will be run. Specify that the file should not be compressed, and add the TensorFlow Lite library to the module's build.gradle file:
android {
    // Other settings

    // Specify tflite file should not be compressed for the app apk
    aaptOptions {
        noCompress "tflite"
    }
}

dependencies {
    // Other dependencies

    // Import the Task Text Library dependency (NNAPI is included)
    implementation 'org.tensorflow:tensorflow-lite-task-text:0.4.4'
}
Step 2: Run inference using the API
// Initialization
BertQuestionAnswererOptions options =
    BertQuestionAnswererOptions.builder()
        .setBaseOptions(BaseOptions.builder().setNumThreads(4).build())
        .build();
BertQuestionAnswerer answerer =
    BertQuestionAnswerer.createFromFileAndOptions(
        androidContext, modelFile, options);

// Run inference
List<QaAnswer> answers = answerer.answer(contextOfTheQuestion, questionToAsk);
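Each returned candidate pairs an answer string with its position and score in the context. A minimal sketch of reading the results, assuming QaAnswer exposes text and pos (with start, end and logit) fields as in the Task Library's Java API:

// Iterate over the candidate answers; `text` is the answer span and
// `pos.logit` its score (field names assumed from the QaAnswer class).
for (QaAnswer answer : answers) {
  Log.d("BertQA", "answer: " + answer.text + ", logit: " + answer.pos.logit);
}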
See the source code for more details.
Run inference in Swift
Step 1: Import CocoaPods
Add the TensorFlowLiteTaskText pod to your Podfile:
target 'MySwiftAppWithTaskAPI' do
  use_frameworks!
  pod 'TensorFlowLiteTaskText', '~> 0.4.4'
end
Step 2: Run inference using the API
// Initialization
let mobileBertAnswerer = TFLBertQuestionAnswerer.questionAnswerer(
    modelPath: mobileBertModelPath)

// Run inference
let answers = mobileBertAnswerer.answer(
    context: context, question: question)
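Each returned answer carries the matched text; a minimal sketch of reading the results, assuming the answer type exposes a text property:

// Print each candidate answer (the `text` property is assumed from the
// TFLBertQuestionAnswerer answer type).
for answer in answers {
  print("answer: \(answer.text)")
}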
See the source code for more details.
Run inference in C++
// Initialization
BertQuestionAnswererOptions options;
options.mutable_base_options()->mutable_model_file()->set_file_name(model_path);
std::unique_ptr<BertQuestionAnswerer> answerer =
    BertQuestionAnswerer::CreateFromOptions(options).value();

// Run inference with your inputs, `context_of_question` and `question_to_ask`.
std::vector<QaAnswer> positive_results =
    answerer->Answer(context_of_question, question_to_ask);
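The returned QaAnswer structs can be inspected directly; a minimal sketch, assuming QaAnswer carries text and pos (with start, end and logit) members as in the Task Library's C++ API:

// Print each candidate answer with its score (member names assumed from the
// QaAnswer struct).
for (const QaAnswer& answer : positive_results) {
  std::cout << "answer: " << answer.text
            << ", logit: " << answer.pos.logit << std::endl;
}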
See the source code for more details.
Run inference in Python
Step 1: Install the pip package
pip install tflite-support
Step 2: Using the model
# Imports
from tflite_support.task import text
# Initialization
answerer = text.BertQuestionAnswerer.create_from_file(model_path)
# Run inference
bert_qa_result = answerer.answer(context, question)
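The result bundles the candidate answers; a minimal sketch of reading them, assuming the result exposes an answers list whose items carry text and pos (with logit) attributes:

# Print each candidate answer with its score (attribute names assumed from
# the Task Library's Python result types).
for answer in bert_qa_result.answers:
    print(f"answer: {answer.text}, logit: {answer.pos.logit}")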
See the source code for more options to configure BertQuestionAnswerer.
Example results
Here is an example of the answer results of the ALBERT model.
Context: "The Amazon rainforest, alternatively, the Amazon Jungle, also known in
English as Amazonia, is a moist broadleaf tropical rainforest in the Amazon
biome that covers most of the Amazon basin of South America. This basin
encompasses 7,000,000 km2 (2,700,000 sq mi), of which
5,500,000 km2 (2,100,000 sq mi) are covered by the rainforest. This region
includes territory belonging to nine nations."
Question: "Where is Amazon rainforest?"
Answers:
answer[0]: 'South America.'
  logit: 1.84847, start_index: 39, end_index: 40
answer[1]: 'most of the Amazon basin of South America.'
  logit: 1.2921, start_index: 34, end_index: 40
answer[2]: 'the Amazon basin of South America.'
  logit: -0.0959535, start_index: 36, end_index: 40
answer[3]: 'the Amazon biome that covers most of the Amazon basin of South America.'
  logit: -0.498558, start_index: 28, end_index: 40
answer[4]: 'Amazon basin of South America.'
  logit: -0.774266, start_index: 37, end_index: 40
Try out the simple CLI demo tool for BertQuestionAnswerer with your own model and test data.
Model compatibility requirements
The BertQuestionAnswerer API expects a TFLite model with mandatory TFLite Model Metadata. The Metadata should meet the following requirements:
- input_process_units for the Wordpiece/Sentencepiece Tokenizer
- 3 input tensors with names "ids", "mask" and "segment_ids" for the output of the tokenizer
- 2 output tensors with names "end_logits" and "start_logits" to indicate the answer's relative position in the context
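To check whether a candidate model meets these requirements, you can dump its metadata with the tflite-support pip package; a minimal sketch (the model path is a placeholder):

# Print the model's metadata as JSON and verify the tokenizer process units
# and the input/output tensor names listed above.
from tflite_support import metadata

displayer = metadata.MetadataDisplayer.with_model_file("model.tflite")
print(displayer.get_metadata_json())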