Training a machine learning model is one thing; serving it is another. When a model backs real-time inference, you need a runtime that returns results within milliseconds. If you fine-tuned a pre-trained PyTorch model, loading its state dictionary and running it in PyTorch will usually not give you the best inference time. Exporting the model to ONNX and serving it with ONNX Runtime is often a faster alternative. In this blog post, you will learn how to convert a PyTorch model saved as a state dictionary into ONNX format for faster inference.
Prerequisites
All you need for this tutorial is a Python execution environment that contains torch, torchvision, and onnx; the last section also uses onnxruntime. You can use Google Colab to avoid setup.
Direct Inference Using the PyTorch State Dictionary
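The most important part of direct inference is defining the model and sizing its final fully-connected layer to match the fine-tuned checkpoint: here, 2048 input features and 102 output classes. You also need to choose the device that handles inference, either the CPU or a GPU. The last step is putting the model in eval mode, so that layers such as batch normalization use their stored running statistics instead of per-batch statistics.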
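import torch
from torchvision import models

# Rebuild the fine-tuned architecture: ResNet-50 with a 102-class head
resnet = models.resnet50()
resnet.fc = torch.nn.Linear(in_features=2048, out_features=102)
resnet.load_state_dict(torch.load("model.pth", weights_only=True, map_location=torch.device('cpu')))
resnet.eval()  # inference mode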
With the model defined, pass a preprocessed image tensor to it and post-process the output into a predicted class.
# Here load_and_prepare_image must return a batched float tensor, e.g. of shape (1, 3, 224, 224)
tensor_image = load_and_prepare_image(args.image_path)
with torch.no_grad():
    output = resnet(tensor_image)
class_distribution = torch.nn.functional.softmax(output, dim=1).squeeze()  # Apply softmax for probability distribution
predicted_class = torch.argmax(class_distribution).item()
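To make the post-processing step concrete, here is a minimal sketch of what the softmax and argmax calls compute, written in plain Python. The four logits are made-up values for illustration (the article's model produces 102), and the helper names `softmax` and `argmax` are hypothetical, not part of PyTorch.

```python
import math


def softmax(logits):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    # Subtracting the maximum leaves the result unchanged but keeps exp() from overflowing.
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def argmax(values):
    """Return the index of the largest element."""
    return max(range(len(values)), key=lambda index: values[index])


# Hypothetical logits for a 4-class model (assumption for illustration only).
logits = [1.0, 3.0, 0.5, 2.0]
probabilities = softmax(logits)
predicted_class = argmax(probabilities)

print(predicted_class)                    # 1
print(argmax(logits) == predicted_class)  # True
```

Because softmax preserves the ordering of the scores, taking the argmax of the raw logits gives the same class; the softmax is only needed when you want the probabilities themselves.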
Converting to ONNX
ONNX (Open Neural Network Exchange) is an open format that makes machine learning models interoperable. Whatever framework trained the model, an ONNX file can be loaded by any runtime that supports the format. ONNX Runtime, for example, has bindings for Python, JavaScript (web), Java, and C++, among other languages. This makes machine learning models accessible to developers who want to integrate them into their applications. ONNX export is built into PyTorch through torch.onnx, so you can export a PyTorch model to ONNX format with a few lines of code.
import torch
import torch.onnx
from torchvision import models

# Dummy batch: one image, 3 color channels, 224 x 224 pixels
dummy_input = torch.randn(1, 3, 224, 224)

resnet = models.resnet50()
resnet.fc = torch.nn.Linear(in_features=2048, out_features=102)
resnet.load_state_dict(torch.load("model.pth", weights_only=True, map_location=torch.device('cpu')))
resnet.eval()  # export in inference mode so batch normalization uses its running statistics

torch.onnx.export(resnet, dummy_input, "model.onnx")
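This snippet mirrors the first one: before exporting, you rebuild the same architecture, with the same in_features and out_features on the final layer, so that the state dictionary loads cleanly. The dummy input only tells the exporter the shape and type of the input; its values do not matter.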
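Before relying on the exported file, it can be worth checking that it produces the same logits as the PyTorch model. The following is a rough sketch of such a check, not part of the original workflow: the function names `outputs_match` and `verify_export` and the tolerance values are assumptions chosen for illustration, and the code assumes numpy, torch, and onnxruntime are installed.

```python
import numpy as np


def outputs_match(torch_logits, onnx_logits, rtol=1e-3, atol=1e-5):
    """Return True when two logit arrays have the same shape and agree within tolerance."""
    torch_logits = np.asarray(torch_logits, dtype=np.float32)
    onnx_logits = np.asarray(onnx_logits, dtype=np.float32)
    if torch_logits.shape != onnx_logits.shape:
        return False
    return bool(np.allclose(torch_logits, onnx_logits, rtol=rtol, atol=atol))


def verify_export(model, onnx_path="model.onnx"):
    """Run one random input through PyTorch and ONNX Runtime and compare the logits."""
    # Imported lazily so the comparison helper above works without these packages.
    import onnxruntime
    import torch

    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        torch_logits = model(dummy_input).numpy()

    session = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    onnx_logits = session.run(None, {input_name: dummy_input.numpy()})[0]
    return outputs_match(torch_logits, onnx_logits)
```

After the export, calling `verify_export(resnet)` should return True. Small floating-point differences between the two runtimes are normal, which is why the comparison uses tolerances rather than exact equality.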
Running the ONNX Model
To run the ONNX model, install the ONNX Runtime Python library with pip (pip install onnxruntime). With the library installed locally, you can run inference with it. Below is an example of this process:
import argparse

import numpy as np
import onnxruntime
from PIL import Image
from torchvision import transforms


def load_and_prepare_image(image_path):
    # Resize, convert to a [0, 1] tensor, then normalize with the ImageNet mean and std
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert('RGB')
    img_tensor = preprocess(image)
    # Add the batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
    return img_tensor.numpy()[np.newaxis, :]


def main():
    # Read the image path from the command line
    parser = argparse.ArgumentParser()
    parser.add_argument("image_path")
    args = parser.parse_args()

    session = onnxruntime.InferenceSession("model.onnx", providers=['CPUExecutionProvider'])
    input_name = session.get_inputs()[0].name
    input_data = load_and_prepare_image(args.image_path)
    outputs = session.run(None, {input_name: input_data})
    scores = outputs[0][0]
    # Numerically stable softmax over the class scores
    exp_scores = np.exp(scores - np.max(scores))
    probabilities = exp_scores / exp_scores.sum()
    predicted_class = np.argmax(probabilities)
    print(predicted_class)


if __name__ == "__main__":
    main()
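Whether ONNX Runtime is actually faster depends on your model and hardware, so it helps to measure. Below is a minimal timing sketch using only the standard library. The helper name `time_inference` and the warm-up and run counts are assumptions for illustration, and the lambda is a stand-in workload so the block runs on its own.

```python
import statistics
import time


def time_inference(run_once, warmup=5, runs=50):
    """Return latency statistics in milliseconds for a zero-argument callable."""
    # Warm-up calls are not timed; they absorb one-off costs such as lazy initialization.
    for _ in range(warmup):
        run_once()

    latencies_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        run_once()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)

    return {
        "mean_ms": statistics.mean(latencies_ms),
        "median_ms": statistics.median(latencies_ms),
        "max_ms": max(latencies_ms),
        "runs": runs,
    }


# Stand-in workload (assumption); replace it with a real inference call.
stats = time_inference(lambda: sum(range(1000)), warmup=2, runs=10)
print(f"median latency: {stats['median_ms']:.4f} ms over {stats['runs']} runs")
```

To compare the two approaches from this article, you could pass `lambda: session.run(None, {input_name: input_data})` for ONNX Runtime and, inside a `torch.no_grad()` block, `lambda: resnet(tensor_image)` for PyTorch, using the same image for both.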
Wrap Up
In this article, you learnt how to run a PyTorch model directly from its saved state dictionary and through ONNX Runtime. You can take the exported ONNX model and run it in a web browser, a mobile app, an embedded app via C++, and more. The whole idea of ONNX is to provide open, interoperable machine learning models across platforms.