
From PyTorch to Android: Creating a Quantized TensorFlow Lite Model

This step-by-step guide shows how to convert a pretrained computer vision model from PyTorch to a quantized TensorFlow Lite model. We demonstrate the process using the timm MobileNetV3 classifier: we export it to ONNX and then convert it to an INT8-quantized TensorFlow Lite model. Finally, we validate the converted model on an Android device using a sample application.

An official solution for converting PyTorch models to TFLite already exists. However, it does not cover the quantization process.

Edge AI is advancing rapidly, enabling real-time, efficient processing without relying on cloud computing. This shift enhances privacy, reduces latency, and lowers operational costs, making AI ideal for various applications:

  • Face Unlock on Smartphones – AI processes facial recognition locally for instant unlocking.
  • Voice Assistants – Devices like Google Nest execute commands locally for faster response times.
  • Healthcare Wearables – Smartwatches detect heart irregularities instantly, independent of an internet connection.

As edge AI continues to evolve, it is reshaping industries by making devices more intelligent, responsive, and independent.

Problem Statement

Deploying AI models on edge devices presents challenges in reducing resource demands while maintaining model performance. Limitations include restricted RAM, flash storage, CPU/GPU instruction support, and execution speed constraints.

While ONNX models can run on Android, converting them to TensorFlow Lite is advantageous: TensorFlow Lite models are typically smaller, offer fast inference, and support hardware acceleration through NNAPI, GPU, and Edge TPU delegates. Although ONNX Runtime is cross-platform, TensorFlow Lite is often preferred on Android because of its mobile-first optimizations and its deep integration with the Android ecosystem. TensorFlow Lite has since been renamed LiteRT; more information about LiteRT can be found in Google's blog post.


Prerequisites and Dependencies

  • A device running Android 14 or later
  • Permissions to deploy applications (sideloading) and to access device resources (the READ_EXTERNAL_STORAGE permission)
  • Android Studio 2024.2.2 or later
  • An environment capable of running Docker images

Proposed Solution

Setting Up the Environment

pip install timm==1.0.14
pip install onnx==1.17.0
pip install onnxsim==0.4.36
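To confirm that the pinned versions are the ones actually installed, a small check can be run before exporting the model. This is a sketch: the `check_versions` helper is hypothetical and simply compares the result of Python's standard `importlib.metadata.version` against the pins above.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned by the pip commands above
PINNED = {"timm": "1.0.14", "onnx": "1.17.0", "onnxsim": "0.4.36"}


def check_versions(pinned, lookup=version):
    """Return {package: installed_version_or_None} for every package that does not match its pin."""
    mismatches = {}
    for package, expected in pinned.items():
        try:
            installed = lookup(package)
        except PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != expected:
            mismatches[package] = installed
    return mismatches


if __name__ == "__main__":
    problems = check_versions(PINNED)
    print("All pinned versions installed" if not problems else f"Mismatched packages: {problems}")
```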

Downloading and Converting MobileNetV3 to ONNX

import onnx
from onnxsim import simplify
from pathlib import Path
import timm
import torch

# Load MobileNetV3 model (large variant)
model = timm.create_model("mobilenetv3_large_100", pretrained=True)
model.eval()  # Set to evaluation mode

# Create a dummy input tensor (batch size=1, 3 color channels, 224x224 image)
dummy_input = torch.randn(1, 3, 224, 224)

model_file_path = Path("mobilenetv3.onnx")
# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    str(model_file_path),
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=11
)

onnx_model = onnx.load(model_file_path)

# ONNX simplification step
simplified_model, check = simplify(onnx_model)
if not check:
    print("ONNX Model simplification failed, the model could not be validated.")
else:
    onnx.save(simplified_model, model_file_path)

if model_file_path.exists():
    print(f"File: {model_file_path.name}")
    print(f"Size: {model_file_path.stat().st_size} bytes")
    print(f"Last Modified: {model_file_path.stat().st_mtime}")  # Unix timestamp
    print(f"Absolute Path: {model_file_path.resolve()}")
else:
    print(f"File {model_file_path} not found!")

As an initial step in the quantization process, it is beneficial to simplify the model. This involves removing redundant nodes from the ONNX graph, which makes the subsequent conversion process smoother and more efficient. For more information about ONNX simplification, refer to the ONNX Simplifier documentation.
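Before moving on, it can be worth confirming that export and simplification did not change the model's behavior. The helper below is a hypothetical NumPy sketch: it compares two logit arrays computed from the same input (for example, the PyTorch output and the ONNX Runtime output) and reports the largest absolute difference and how often the top-1 predictions agree. The default tolerance of 1e-4 is an assumption, not a value taken from this guide.

```python
import numpy as np


def compare_outputs(reference, candidate, atol=1e-4):
    """Compare two (batch, num_classes) logit arrays produced from the same input."""
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    if reference.shape != candidate.shape:
        raise ValueError(f"Shape mismatch: {reference.shape} vs {candidate.shape}")
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    # Fraction of samples whose highest-scoring class is the same in both outputs
    top1_agreement = float(np.mean(reference.argmax(axis=1) == candidate.argmax(axis=1)))
    return {
        "max_abs_diff": max_abs_diff,
        "top1_agreement": top1_agreement,
        "within_tolerance": max_abs_diff <= atol,
    }
```

For example, `reference` could be `model(dummy_input).detach().numpy()` and `candidate` the first element returned by `onnxruntime.InferenceSession("mobilenetv3.onnx", providers=["CPUExecutionProvider"]).run(None, {"input": dummy_input.numpy()})`. Note that onnxruntime is not among the packages installed above and would have to be added.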

Pulling the Docker Image

To ensure a stable environment for the conversion from ONNX to TensorFlow, use the official onnx2tf Docker image:

docker pull docker.io/pinto0309/onnx2tf:1.26.7

Model Conversion

The key steps in converting an ONNX model to a TensorFlow model are as follows:

ONNX to TensorFlow Conversion:
The model is first converted from ONNX format to TensorFlow using a Docker container. This ensures compatibility with TensorFlow-based workflows and tools.

Post-Training Quantization:
The TensorFlow model undergoes quantization to reduce its size and improve performance. It is converted to TensorFlow Lite (TFLite) format with dynamic quantization applied, which improves the model's efficiency for deployment on edge devices. Depending on the target hardware, full integer quantization, described later in this guide, may be required instead.

Converting the ONNX Model to a TensorFlow Model

docker run --rm -v $(pwd):/workspace --user $(id -u):$(id -g)  pinto0309/onnx2tf:1.26.7 onnx2tf -i /workspace/mobilenetv3.onnx -o /workspace/model_tf

Converting TensorFlow Model to Quantized TensorFlow Lite

Save the following code as my_conversion_script.py in the current working directory. Because that directory is mounted at /workspace, the script can then be run from the Docker container.

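import tensorflow as tf

# SavedModel directory produced by onnx2tf in the previous step
TF_MODEL_FOLDER_LOCAL = "/workspace/model_tf"

converter = tf.lite.TFLiteConverter.from_saved_model(TF_MODEL_FOLDER_LOCAL)
# Optimize.DEFAULT without a representative dataset applies dynamic range quantization:
# weights are stored as INT8, activations stay in floating point
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Convert the model
tflite_quant_model = converter.convert()

# Save the quantized model
with open("/workspace/model_quant.tflite", "wb") as f:
    f.write(tflite_quant_model)

print("✅ Quantized model saved as 'model_quant.tflite'")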

However, dynamically quantized models may not be compatible with specialized integer-only hardware, which performs computations using only integer arithmetic, typically 8-bit integers (INT8), instead of floating-point numbers. This is because activations remain in floating point during execution. Dynamic quantization is applied post-training and does not require a calibration dataset.
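Conceptually, the weight side of dynamic quantization maps each float32 weight to an 8-bit integer plus a floating-point scale. The NumPy sketch below illustrates symmetric INT8 quantization with a single scale per tensor. TensorFlow Lite's actual implementation differs in its details (for example, it may use per-channel scales for convolution weights), so treat this as an illustration of the idea rather than the converter's algorithm.

```python
import numpy as np


def quantize_symmetric_int8(weights):
    """Symmetric per-tensor INT8 quantization: w ≈ scale * q, with q in [-127, 127]."""
    weights = np.asarray(weights, dtype=np.float32)
    max_abs = float(np.max(np.abs(weights)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0  # avoid division by zero for all-zero tensors
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    """Recover approximate float32 weights from INT8 values."""
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
# Random stand-in for a 3x3 convolution kernel with 64 output channels
w = rng.normal(0.0, 0.05, size=(64, 3, 3, 3)).astype(np.float32)
q, scale = quantize_symmetric_int8(w)
w_hat = dequantize(q, scale)

print(f"float32 bytes: {w.nbytes}, int8 bytes: {q.nbytes}")  # storage shrinks by 4x
print(f"max reconstruction error: {np.max(np.abs(w - w_hat)):.6f} (at most scale/2 = {scale / 2:.6f})")
```

The rounding error per weight is bounded by half the scale, which is why a tensor with a few extreme values (and therefore a large scale) loses more precision than one with a compact range.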

This cookbook focuses on dynamic quantization, but the full integer quantization procedure is also shown below. Full integer quantization quantizes both weights and activations, which allows the model to fully leverage hardware acceleration (e.g., Edge TPUs, DSPs, NNAPI).

To fully quantize the model, a calibration dataset is required. It should be a representative sample of the data that the model will process at runtime. Conversion with calibration is time-consuming and may not work for some models, but it lets the converter estimate the value range of each activation tensor, so that the quantized model's outputs stay close to those of the float model.
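The role of calibration can be sketched with simple min/max range tracking and asymmetric INT8 quantization (a scale plus a zero point), which is the general scheme TensorFlow Lite uses for activations. The helpers below (`calibrate_range`, `quantize_params`, `quantize_affine`, `dequantize_affine`) are hypothetical, the activations are random stand-ins, and the converter's real range estimation may differ.

```python
import numpy as np


def calibrate_range(batches):
    """Track the min/max of an activation tensor over calibration batches (range forced to include 0)."""
    lo, hi = 0.0, 0.0
    for batch in batches:
        lo = min(lo, float(np.min(batch)))
        hi = max(hi, float(np.max(batch)))
    return lo, hi


def quantize_params(lo, hi, qmin=-128, qmax=127):
    """Derive scale and zero point so that [lo, hi] maps onto [qmin, qmax]."""
    scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
    zero_point = int(np.clip(round(qmin - lo / scale), qmin, qmax))
    return scale, zero_point


def quantize_affine(x, scale, zero_point, qmin=-128, qmax=127):
    q = np.round(x / scale) + zero_point
    return np.clip(q, qmin, qmax).astype(np.int8)


def dequantize_affine(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale


rng = np.random.default_rng(0)
# Stand-in for non-negative (ReLU-like) activations observed while running calibration images
calibration_batches = [np.maximum(rng.normal(0.5, 1.0, size=(1, 1000)), 0.0) for _ in range(8)]
lo, hi = calibrate_range(calibration_batches)
scale, zero_point = quantize_params(lo, hi)

x = calibration_batches[0]
x_hat = dequantize_affine(quantize_affine(x, scale, zero_point), scale, zero_point)
print(f"range=[{lo:.3f}, {hi:.3f}] scale={scale:.5f} zero_point={zero_point}")
print(f"max error: {np.max(np.abs(x - x_hat)):.5f}")
```

Values seen at runtime that fall outside the calibrated range are clipped, which is why the calibration images should resemble real inputs.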

Create calibration generator for full quantization

from pathlib import Path

import numpy as np
from PIL import Image
from tqdm import tqdm

CALIBRATION_DATASET = Path("/workspace/calibration-dataset")

def load_images_from_folder(image_size=(224, 224)):
    # ImageNet normalization statistics (per RGB channel)
    normalization_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    normalization_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # Matches files such as .png, .jpg and .jpeg in all subfolders
    for image in tqdm(list(CALIBRATION_DATASET.rglob('*.[pj][np][ge]*')), desc="Calibrating"):
        img = Image.open(str(image)).convert("RGB")
        img = img.resize(image_size)  # Resize to expected model input size

        # data normalization
        raw_image = np.array(img, dtype=np.float32) / 255.0
        calibration_image = (raw_image - normalization_mean) / normalization_std

        # Ensure correct shape: (1, height, width, channels); onnx2tf converts the NCHW input to NHWC
        calibration_image = np.expand_dims(calibration_image, axis=0)
        yield [calibration_image]  # Must return float32
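Calibration runs every image through the model, so a large folder can make conversion slow. TensorFlow's post-training quantization guide suggests that a small subset, on the order of 100 to 500 samples, is usually enough. A minimal way to cap the sample count without changing `load_images_from_folder` is to wrap it with `itertools.islice`; the `limit_samples` wrapper and the cap of 200 in the usage comment are illustrative assumptions.

```python
import itertools


def limit_samples(dataset_fn, max_samples):
    """Wrap a representative-dataset generator function so that it yields at most max_samples items."""
    def limited():
        # islice stops pulling from the underlying generator after max_samples items,
        # so images beyond the cap are never opened
        yield from itertools.islice(dataset_fn(), max_samples)
    return limited


# Hypothetical usage in my_conversion_script.py:
# converter.representative_dataset = limit_samples(load_images_from_folder, 200)
```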

Next, set the following properties on the TFLiteConverter object in my_conversion_script.py, before converter.convert() is called:

converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.representative_dataset = load_images_from_folder

Execute conversion with Post Training Quantization

docker run --rm -v $(pwd):/workspace --user $(id -u):$(id -g) -e PYTHONUSERBASE=/workspace/.local pinto0309/onnx2tf:1.26.7 pip install tqdm pillow

docker run --rm -v $(pwd):/workspace --user $(id -u):$(id -g) -e PYTHONUSERBASE=/workspace/.local pinto0309/onnx2tf:1.26.7 python /workspace/my_conversion_script.py

The output is a model_quant.tflite file, ready for deployment and inference on an edge device. The model is approximately 4 times smaller than the original, because each weight is stored as a 1-byte INT8 value instead of a 4-byte float32 value. This enables more efficient storage and faster execution.

After quantization, the model size is:

lukaszkarlowski@ST60573:~/projects/cookbooks/notebooks$ du -sh *.tflite
5.6M    model_quant.tflite
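The roughly fourfold reduction can be checked with simple storage arithmetic. The sketch below estimates weight storage from a parameter count; the figure of about 5.5 million parameters for mobilenetv3_large_100 is an approximate assumption (the exact count can be obtained in the export script with `sum(p.numel() for p in model.parameters())`). Real files are somewhat larger because they also store the graph, quantization scales, and tensors that may remain in higher precision, such as biases.

```python
def estimated_size_mib(num_params, bytes_per_param):
    """Estimate weight storage in MiB, ignoring graph structure and metadata."""
    return num_params * bytes_per_param / 2**20


NUM_PARAMS = 5_500_000  # approximate parameter count for mobilenetv3_large_100 (assumption)
float32_mib = estimated_size_mib(NUM_PARAMS, 4)  # float32: 4 bytes per weight
int8_mib = estimated_size_mib(NUM_PARAMS, 1)  # INT8: 1 byte per weight
print(f"float32: {float32_mib:.1f} MiB, int8: {int8_mib:.1f} MiB, ratio: {float32_mib / int8_mib:.0f}x")
```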

Deploying the Model to Android

To integrate the TensorFlow Lite model into an Android app, add the required dependency to build.gradle.kts:

// build.gradle.kts
...
dependencies {
    implementation("org.tensorflow:tensorflow-lite:2.11.0")
}
...

Initialize and run the model:

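// loadModelFile(context, modelPath) is a helper (not shown) that returns the .tflite file as a ByteBuffer
Interpreter.Options options = new Interpreter.Options();
Interpreter interpreter = new Interpreter(loadModelFile(context, modelPath), options);

// inputBuffer: e.g. a direct ByteBuffer holding one preprocessed 224x224 RGB image (float32, NHWC)
// output: one row of 1000 ImageNet class scores
float[][] output = new float[1][1000];
interpreter.run(inputBuffer, output);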

Demonstration Application

The demonstration application can be accessed at the following link:
Demo Application.

This application allows you to:

  • Load an image and classify it using either a non-quantized or quantized model (depending on the selected option).
  • Crop the selected image to a square.
  • Resize the image to 224×224 pixels.
  • Pass the processed image as normalized input to the model.
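The preprocessing steps listed above can be mirrored in NumPy, which is handy for checking on a desktop that the app feeds the model the same values as the calibration code. This is a sketch under assumptions: it center-crops (the app may crop differently), uses a simple nearest-neighbor resize rather than whatever resampling filter the app uses, and reuses the ImageNet mean and standard deviation from the calibration generator. The function names are hypothetical.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def center_crop_square(image):
    """Crop an (H, W, 3) image to a centered square of side min(H, W)."""
    h, w = image.shape[:2]
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    return image[top:top + side, left:left + side]


def resize_nearest(image, size=224):
    """Nearest-neighbor resize of a square image to (size, size); a simplification of real resampling."""
    side = image.shape[0]
    idx = (np.arange(size) * side) // size  # source row/column for each output pixel
    return image[idx][:, idx]


def preprocess(image_uint8, size=224):
    """uint8 RGB (H, W, 3) -> normalized float32 tensor of shape (1, size, size, 3), NHWC."""
    square = center_crop_square(image_uint8)
    resized = resize_nearest(square, size)
    normalized = (resized.astype(np.float32) / 255.0 - MEAN) / STD
    return normalized[np.newaxis, ...]
```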

The classification results are displayed in a list view showing the top 3 classes along with their probabilities, which are calculated by applying the softmax function to the model's output.
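The top-3 computation can be sketched in Python as follows. The app itself runs on Android, so this is an illustrative equivalent rather than its code; mapping the returned class indices to human-readable ImageNet labels requires a label list that is not shown here.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)  # subtracting the max avoids overflow in exp
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


def top_k(logits, k=3):
    """Return (class_index, probability) pairs for the k most likely classes of one 1-D logit vector."""
    probs = softmax(np.asarray(logits, dtype=np.float64))
    indices = np.argsort(probs)[::-1][:k]  # indices sorted by descending probability
    return [(int(i), float(probs[i])) for i in indices]
```

For example, `top_k([2.0, 1.0, 0.1, 3.0])` returns the indices 3, 0 and 1 with decreasing probabilities. Because softmax is monotonic, the ranking is the same as ranking the raw logits; softmax only makes the scores interpretable as probabilities.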

Application with Model Inferences on Galaxy S24

This video demonstrates how the AI model performs inferences on the Samsung Galaxy S24.

Timings

To demonstrate the timing benefits, the conversion process described in this cookbook has also been applied to another timm-pretrained model, Adversarial Inception V3. It is also an image classification model, but it is larger and slower, while producing more accurate predictions.

The inference times were measured in the Android application. The table below presents the average time of a single inference.

Model                    | Non-Quantized (ms) | Quantized (ms)
MobileNetV3              | 9.7                | 5.6
Adversarial Inception v3 | 10.7               | 2.1
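Average inference times like these are typically obtained by running a few warm-up inferences and then averaging over many timed runs. The sketch below shows that pattern in Python for any callable (for example, a function wrapping a TensorFlow Lite interpreter call on a desktop). The `average_latency_ms` helper and its warm-up and run counts are assumptions; this is not the app's measurement code.

```python
import time
from statistics import mean


def average_latency_ms(infer, warmup=10, runs=100):
    """Average wall-clock time of infer() in milliseconds, excluding warm-up calls."""
    for _ in range(warmup):
        infer()  # warm-up: first calls often include one-time costs such as memory allocation
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        infer()
        timings.append((time.perf_counter() - start) * 1000.0)
    return mean(timings)
```

On Android, the same pattern could be applied by reading `System.nanoTime()` before and after `interpreter.run(inputBuffer, output)`.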

Model Evaluation

For MobileNetV3, the quantized version reduces the inference time by about 4 ms (from 9.7 ms to 5.6 ms, roughly 1.7× faster).
For Adversarial Inception v3, the quantized version improves performance significantly, reducing the inference time from 10.7 ms to 2.1 ms (roughly 5× faster).

The larger reduction for Adversarial Inception v3 is likely due to its size: it is a considerably larger model than MobileNetV3, so it benefits more from quantization.

Evaluation was performed using ImageNet1K from Kaggle, running on a local PC with CPU-based execution.

Model Type                                          | Top-1 Accuracy | Top-5 Accuracy
ONNX Model                                          | 67.13%         | 87.85%
TensorFlow Lite (Non-Quantized, Weights in Float32) | 67.13%         | 87.85%
TensorFlow Lite (Quantized, Weights in Int8)        | 67.11%         | 87.82%
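Top-1 and top-5 accuracy count a prediction as correct when the true label is the highest-scoring class, or among the five highest-scoring classes, respectively. A minimal NumPy sketch of the metric, assuming logits of shape (num_images, num_classes) and integer labels (the `top_k_accuracy` name is hypothetical), is:

```python
import numpy as np


def top_k_accuracy(logits, labels, k):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    # argsort is ascending, so the last k columns hold the indices of the k largest logits per row
    top_k_indices = np.argsort(logits, axis=1)[:, -k:]
    hits = np.any(top_k_indices == labels[:, None], axis=1)
    return float(np.mean(hits))
```

Running the same function over the outputs of each model variant on the same images gives directly comparable columns like those in the table above.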

The accuracy of the ONNX and TensorFlow Lite float32 models is identical because the model architecture, numerical precision, and computational graph are preserved. Quantizing the TFLite model to INT8 can degrade accuracy due to numerical approximation, while reducing model size and inference latency. In this case, the drop is only 0.02 percentage points in top-1 accuracy and 0.03 percentage points in top-5 accuracy.

Next Steps

Model quantization development:

  • present how quantization-aware training influences model performance.

Android application development:

  • modify the application to use the TensorFlow Lite GPU delegate for running model inference on the device’s GPU,
  • implement real-time image classification using the device’s camera.
