Export to ONNX and inference using TensorRT
Note:
Currently, ONNX export is supported only for high-precision models and for the FP8 delayed scaling and MXFP8 recipes.
Transformer Engine (TE) is a library designed primarily for training DL models in low precision. It is not specifically optimized for inference, so dedicated inference solutions are recommended for deployment. NVIDIA provides several inference tools that optimize the entire inference pipeline. Two prominent NVIDIA inference SDKs are TensorRT and TensorRT-LLM.
This tutorial shows how to export a PyTorch model to ONNX format and then run inference with TensorRT. This approach is particularly useful when a model integrates Transformer Engine layers into a more complex architecture. For Transformer-based large language models (LLMs), TensorRT-LLM may provide a more optimized inference experience. The ONNX-to-TensorRT approach described here may be more suitable for other models, such as diffusion-based architectures or vision transformers.
Creating models with TE
Let’s begin by defining a simple model composed of layers both from Transformer Engine and standard PyTorch:
[1]:
import torch
import torch.nn as nn
import transformer_engine as te
import warnings

warnings.filterwarnings("ignore")

# batch size, sequence length, hidden dimension
B, S, H = 256, 512, 256

class Model(torch.nn.Module):
    def __init__(self, hidden_dim=H, num_non_te_layers=16, num_te_layers=4, num_te_heads=4):
        super().__init__()
        # Standard PyTorch part: a stack of Linear + GELU blocks.
        self.non_te_part = nn.Sequential(
            *[nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.GELU()) for _ in range(num_non_te_layers)]
        )
        # TE part: TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads).
        self.te_part = nn.Sequential(
            *[te.pytorch.TransformerLayer(hidden_dim, hidden_dim, num_te_heads) for _ in range(num_te_layers)]
        )

    def forward(self, x):
        x = self.non_te_part(x)
        return self.te_part(x)
Let’s run some simple inference benchmarks:
[2]:
from utils import _measure_time

model = Model().eval().cuda()
# Input shape is [sequence, batch, hidden], the default layout of TE layers.
inps = (torch.randn([S, B, H], device="cuda"),)

def _inference(fp8_enabled):
    with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8_enabled):
        model(*inps)

te_fp32_time = _measure_time(lambda: _inference(fp8_enabled=False))
te_fp8_time = _measure_time(lambda: _inference(fp8_enabled=True))
print(f"Average inference time FP32: {te_fp32_time} ms")
print(f"Average inference time FP8: {te_fp8_time} ms")
Average inference time FP32: 0.065 ms
Average inference time FP8: 0.062 ms
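The _measure_time helper comes from the tutorial's utils module, and its implementation is not shown. The sketch below is a minimal hypothetical stand-in that suggests what such a helper might do. The name measure_time and its warmup, iters, and sync parameters are assumptions. It runs a few untimed warm-up calls and then averages the wall-clock time over many calls. CUDA kernels launch asynchronously, so when timing GPU work you should pass a synchronization callback such as torch.cuda.synchronize. The real helper may use CUDA events instead.

```python
import time
from typing import Callable, Optional


def measure_time(
    fn: Callable[[], object],
    sync: Optional[Callable[[], None]] = None,
    warmup: int = 10,
    iters: int = 100,
) -> float:
    """Return the average wall-clock time of one fn() call, in milliseconds.

    Hypothetical stand-in for utils._measure_time, not the tutorial's code.
    """
    # Untimed warm-up calls (lazy initialization, kernel selection, caches).
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before starting the clock

    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()  # wait for any asynchronous (e.g. GPU) work to complete
    elapsed_s = time.perf_counter() - start
    return elapsed_s * 1000.0 / iters


if __name__ == "__main__":
    # CPU-only usage example; for GPU work, pass sync=torch.cuda.synchronize.
    avg_ms = measure_time(lambda: sum(range(10_000)), warmup=2, iters=20)
    print(f"Average time: {avg_ms:.4f} ms")
```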
Exporting the TE Model to ONNX Format
PyTorch has introduced a new ONNX exporter built on TorchDynamo and plans to phase out the existing TorchScript-based exporter. Because this exporter is still under active development, we recommend using the latest PyTorch version.
To export a Transformer Engine model into ONNX format, follow these steps:
1. Run a warm-up pass inside te.pytorch.fp8_autocast, using the recipe intended for export.
2. Wrap the export code in te.pytorch.onnx_export, keeping the warm-up run outside this wrapper.
3. Use the PyTorch Dynamo ONNX exporter by calling torch.onnx.export(..., dynamo=True).
[3]:
from transformer_engine.pytorch.export import te_translation_table

def export(model, fname, inputs, fp8=True):
    with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8):
        # ! IMPORTANT !
        # Transformer Engine models must have a warm-up run
        # before export. The FP8 recipe used during warm-up
        # should match the recipe used during export.
        model(*inputs)

        # Only dynamo=True mode is supported;
        # dynamo=False is deprecated and unsupported.
        #
        # te_translation_table contains the ONNX translations
        # needed for FP8 quantize/dequantize operators.
        print(f"Exporting {fname}")
        with te.pytorch.onnx_export(enabled=True):
            torch.onnx.export(
                model,
                inputs,
                fname,
                output_names=["output"],
                dynamo=True,
                custom_translation_table=te_translation_table,
            )

# Example usage:
export(model, "model_fp8.onnx", inps, fp8=True)
export(model, "model_fp32.onnx", inps, fp8=False)
Exporting model_fp8.onnx
[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 12 of general pattern rewrite rules.
Exporting model_fp32.onnx
[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 12 of general pattern rewrite rules.
Inference with TensorRT
TensorRT is a high-performance deep learning inference optimizer and runtime developed by NVIDIA. When deploying neural network models on NVIDIA GPUs, it maximizes inference throughput and reduces latency. It does this through optimization techniques such as layer fusion, precision calibration, kernel tuning, and memory optimization. For details, refer to the official TensorRT documentation.
Before TensorRT can run an ONNX model, the model must be compiled into a TensorRT engine. This step converts the ONNX model into an optimized representation tailored to the target GPU. Applications can then load the compiled engine file for fast and efficient inference.
[4]:
!trtexec --onnx=model_fp32.onnx --saveEngine=model_fp32.engine > output_fp32.log 2>&1
!trtexec --onnx=model_fp8.onnx --saveEngine=model_fp8.engine > output_fp8.log 2>&1
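Outside a notebook, the same compilation step can be scripted. The sketch below uses two hypothetical helper names, trtexec_command and build_engine. It builds the same command lines as the cell above, passing only the --onnx and --saveEngine flags. It then runs trtexec through subprocess with stdout and stderr sent to a log file, which mirrors the > output_fp32.log 2>&1 redirection. It assumes trtexec is on the PATH.

```python
import shlex
import shutil
import subprocess
from typing import List


def trtexec_command(onnx_path: str, engine_path: str) -> List[str]:
    """Argument list equivalent to `trtexec --onnx=... --saveEngine=...`."""
    return ["trtexec", f"--onnx={onnx_path}", f"--saveEngine={engine_path}"]


def build_engine(onnx_path: str, engine_path: str, log_path: str) -> None:
    """Compile an ONNX file into a TensorRT engine, logging stdout and stderr."""
    if shutil.which("trtexec") is None:
        raise FileNotFoundError("trtexec was not found on PATH")
    with open(log_path, "w") as log:
        # stderr is merged into stdout, like `> log 2>&1` in the shell.
        subprocess.run(
            trtexec_command(onnx_path, engine_path),
            stdout=log,
            stderr=subprocess.STDOUT,
            check=True,  # raise CalledProcessError if the build fails
        )


if __name__ == "__main__":
    for precision in ("fp32", "fp8"):
        cmd = trtexec_command(f"model_{precision}.onnx", f"model_{precision}.engine")
        print(shlex.join(cmd))
        # To actually build (requires TensorRT):
        # build_engine(f"model_{precision}.onnx", f"model_{precision}.engine",
        #              f"output_{precision}.log")
```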
Let’s run the benchmarks for inference:
[5]:
import tensorrt as trt

# Preallocate the output tensor: TRT needs a static memory address for it.
with torch.no_grad():
    output_tensor = torch.empty_like(model(*inps))

# Loads a serialized TRT engine from file.
def load_engine(engine_file_path):
    logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(logger)
    with open(engine_file_path, "rb") as f:
        engine_data = f.read()
    engine = runtime.deserialize_cuda_engine(engine_data)
    return engine

def benchmark_inference(model_name):
    engine = load_engine(model_name)
    context = engine.create_execution_context()
    stream = torch.cuda.Stream()

    # TRT needs static input and output addresses.
    # Here they are set.
    for i in range(len(inps)):
        context.set_tensor_address(engine.get_tensor_name(i), inps[i].data_ptr())
    context.set_tensor_address("output", output_tensor.data_ptr())

    def _inference():
        # Inputs are read from the static input addresses
        # and the output is written to the static output address.
        context.execute_async_v3(stream_handle=stream.cuda_stream)
        stream.synchronize()

    return _measure_time(_inference)

trt_fp8_time = benchmark_inference("model_fp8.engine")
trt_fp32_time = benchmark_inference("model_fp32.engine")
print(f"Average inference time without TRT (FP32 for all layers): {te_fp32_time} ms")
print(f"Average inference time without TRT (FP8 for TE layers, FP32 for non-TE layers): {te_fp8_time} ms, speedup = {te_fp32_time/te_fp8_time:.2f}x")
print(f"Average inference time with TRT (FP32 for all layers): {trt_fp32_time:.4f} ms, speedup = {te_fp32_time/trt_fp32_time:.2f}x")
print(f"Average inference time with TRT (FP8 for TE layers, FP32 for non-TE layers): {trt_fp8_time:.4f} ms, speedup = {te_fp32_time/trt_fp8_time:.2f}x")
Average inference time without TRT (FP32 for all layers): 0.065 ms
Average inference time without TRT (FP8 for TE layers, FP32 for non-TE layers): 0.062 ms, speedup = 1.05x
Average inference time with TRT (FP32 for all layers): 0.0500 ms, speedup = 1.30x
Average inference time with TRT (FP8 for TE layers, FP32 for non-TE layers): 0.0470 ms, speedup = 1.38x
Run | Inference Time (ms) | Speedup vs. PyTorch + TE
---|---|---
PyTorch + TE | 0.065 | 1.00x
PyTorch + TE (FP8 for TE layers) | 0.062 | 1.05x
TRT | 0.050 | 1.30x
TRT (FP8 for TE layers) | 0.047 | 1.38x
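Each speedup in the table is the baseline time (PyTorch + TE, FP32) divided by the time of the given run. The short sketch below recomputes the speedup column from the averages printed by the benchmark cells. The timings dictionary is simply those printed values copied by hand.

```python
# Average inference times (ms) copied from the benchmark output above.
times_ms = {
    "PyTorch + TE": 0.065,
    "PyTorch + TE (FP8 for TE layers)": 0.062,
    "TRT": 0.050,
    "TRT (FP8 for TE layers)": 0.047,
}

baseline_ms = times_ms["PyTorch + TE"]

# speedup = baseline time / run time (a value > 1.0 means faster than baseline)
speedups = {run: baseline_ms / t for run, t in times_ms.items()}

for run, s in speedups.items():
    print(f"{run:<35} {times_ms[run]:.3f} ms  {s:.2f}x")
```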
This example shows that TensorRT can speed up models composed of both TE and non-TE layers. Here, only the four TE layers run in FP8, while the 16 non-TE layers stay in FP32. If a larger part of the model were implemented with TE, the benefit of FP8 inference could be greater.
Compared with PyTorch + TE in FP32, TensorRT reduces the average inference time from 0.065 ms to 0.050 ms (1.30x). With the TE layers in FP8, it drops further to 0.047 ms (1.38x). These improvements may be even larger for more complex models, where TensorRT may find additional optimization opportunities.
Appendix: Low Precision Operators in ONNX and TensorRT
The ONNX standard does not currently support all precision types provided by Transformer Engine. The available operators are listed in the ONNX operator documentation. For this reason, TensorRT and Transformer Engine use certain specialized low-precision operators, described below.
TRT_FP8_QUANTIZE
Name: TRT_FP8_QUANTIZE
Domain: trt
Inputs:
x: float32 tensor
scale: float32 scalar
Outputs:
y: int8 tensor
Produces an int8 tensor that represents the binary encoding of FP8 values.
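To make "binary encoding of FP8 values" concrete, the following pure-Python sketch quantizes float values to FP8 and returns each result as the signed int8 with the same bit pattern. It is an illustration, not TensorRT's kernel, and it rests on several assumptions:
- The FP8 format is E4M3: 1 sign bit, 4 exponent bits with bias 7, and 3 mantissa bits.
- The scale is applied by division, y = fp8(x / scale), as in ONNX QuantizeLinear.
- Rounding is to nearest, with ties to even.
- Out-of-range magnitudes saturate to ±448, the largest finite E4M3 value.

All helper names are hypothetical.

```python
import math
from typing import List

E4M3_BIAS = 7
E4M3_MAX = 448.0  # largest finite E4M3 value: 1.75 * 2**8


def decode_e4m3(code: int) -> float:
    """Value of an 8-bit E4M3 pattern (0..255): 1 sign, 4 exponent, 3 mantissa bits."""
    sign = -1.0 if code & 0x80 else 1.0
    exponent = (code >> 3) & 0xF
    mantissa = code & 0x7
    if exponent == 0xF and mantissa == 0x7:
        return math.nan  # E4M3 has no infinities; S.1111.111 encodes NaN
    if exponent == 0:
        # Subnormal: no implicit leading 1.
        return sign * (mantissa / 8.0) * 2.0 ** (1 - E4M3_BIAS)
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - E4M3_BIAS)


# All finite non-negative patterns 0x00..0x7E with their values, in increasing order.
_NON_NEGATIVE = [(code, decode_e4m3(code)) for code in range(0x7F)]


def encode_e4m3(value: float) -> int:
    """Nearest E4M3 pattern (0..255), ties to even, saturating at +/-448 (assumption)."""
    if math.isnan(value):
        return 0x7F
    sign_bit = 0x80 if math.copysign(1.0, value) < 0 else 0x00
    magnitude = min(abs(value), E4M3_MAX)  # saturation instead of overflow
    best_code, best_err = 0, math.inf
    for code, code_value in _NON_NEGATIVE:
        err = abs(code_value - magnitude)
        # On an exact tie, prefer the even pattern (mantissa LSB == 0).
        if err < best_err or (err == best_err and code % 2 == 0):
            best_code, best_err = code, err
    return sign_bit | best_code


def fp8_quantize(x: List[float], scale: float) -> List[int]:
    """Sketch of TRT_FP8_QUANTIZE: float values -> int8 holding E4M3 bits."""
    out = []
    for v in x:
        code = encode_e4m3(v / scale)  # divide-by-scale convention (assumption)
        out.append(code - 256 if code >= 128 else code)  # reinterpret byte as int8
    return out


if __name__ == "__main__":
    print(fp8_quantize([1.0, -1.0, 0.3, 1000.0], scale=1.0))
```

Under the divide-by-scale assumption, choosing scale = amax / 448 for a tensor whose largest magnitude is amax maps that value exactly onto the top of the E4M3 range, so no element saturates.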
TRT_FP8_DEQUANTIZE
Name: TRT_FP8_DEQUANTIZE
Domain: trt
Inputs:
x: int8 tensor
scale: float32 scalar
Outputs:
y: float32 tensor
Converts FP8-encoded int8 tensor data back into float32 precision.
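The following pure-Python sketch illustrates what TRT_FP8_DEQUANTIZE computes. It is not TensorRT's implementation, and the function names are hypothetical. It makes two assumptions:
- The int8 input holds FP8 E4M3 bit patterns: 1 sign bit, 4 exponent bits with bias 7, and 3 mantissa bits.
- Dequantization multiplies by the scale. This is the counterpart of a quantizer that divides by it, as in ONNX DequantizeLinear and QuantizeLinear.

Each int8 is reinterpreted as its raw byte, decoded to the FP8 value, and multiplied by the scale.

```python
import math
from typing import List


def decode_e4m3(code: int) -> float:
    """Value of an unsigned 8-bit E4M3 pattern (0..255)."""
    sign = -1.0 if code & 0x80 else 1.0
    exponent = (code >> 3) & 0xF  # 4 exponent bits, bias 7
    mantissa = code & 0x7         # 3 mantissa bits
    if exponent == 0xF and mantissa == 0x7:
        return math.nan  # the only NaN pattern; E4M3 has no infinities
    if exponent == 0:
        return sign * (mantissa / 8.0) * 2.0 ** -6  # subnormal
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


def fp8_dequantize(y: List[int], scale: float) -> List[float]:
    """Sketch of TRT_FP8_DEQUANTIZE: int8 holding E4M3 bits -> float values."""
    # v & 0xFF reinterprets a signed int8 (-128..127) as its raw byte (0..255).
    return [decode_e4m3(v & 0xFF) * scale for v in y]


if __name__ == "__main__":
    print(fp8_dequantize([56, -72, 42, 126], scale=1.0))
    # prints [1.0, -1.0, 0.3125, 448.0]
```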
Note:
Standard ONNX operators do not support certain input and output precision types, so a workaround is used. Tensors are dequantized to higher precision (float32) before being passed to these operators and quantized back to lower precision afterwards. TensorRT recognizes such quantize-dequantize patterns and replaces them with optimized operations. More details are available in the TensorRT documentation.
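The quantize-dequantize pattern can be illustrated on a toy graph. The sketch below uses a hypothetical Node structure rather than the onnx package. It flags compute nodes whose inputs all come from TRT_FP8_DEQUANTIZE nodes: these are the places where a float32 operator sits inside a Q/DQ "sandwich" and could, in principle, run directly in FP8. This is a simplified illustration of the idea, not TensorRT's actual fusion logic. The default set of compute operators is an assumption.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence


@dataclass
class Node:
    """Minimal stand-in for an ONNX node (hypothetical, not onnx.NodeProto)."""
    name: str
    op_type: str
    inputs: List[str]
    outputs: List[str]


def find_fp8_candidates(
    nodes: Sequence[Node],
    compute_ops: Sequence[str] = ("MatMul", "Gemm"),
) -> List[str]:
    """Names of compute nodes whose every input is produced by a DQ node."""
    # Map each tensor name to the node that produces it.
    producer: Dict[str, Node] = {out: n for n in nodes for out in n.outputs}
    candidates = []
    for node in nodes:
        if node.op_type not in compute_ops:
            continue
        if all(
            i in producer and producer[i].op_type == "TRT_FP8_DEQUANTIZE"
            for i in node.inputs
        ):
            candidates.append(node.name)
    return candidates


if __name__ == "__main__":
    graph = [
        # Activation: float32 -> FP8 bits (int8) -> float32
        Node("q_x", "TRT_FP8_QUANTIZE", ["x", "scale_x"], ["x_q"]),
        Node("dq_x", "TRT_FP8_DEQUANTIZE", ["x_q", "scale_x"], ["x_dq"]),
        # Weight stored as FP8 bits and dequantized before use
        Node("dq_w", "TRT_FP8_DEQUANTIZE", ["w_q", "scale_w"], ["w_dq"]),
        # Standard ONNX op on float32 tensors, fed only by DQ nodes
        Node("mm1", "MatMul", ["x_dq", "w_dq"], ["y1"]),
        # Plain float32 MatMul without DQ inputs
        Node("mm2", "MatMul", ["y1", "w2"], ["y2"]),
    ]
    print(find_fp8_candidates(graph))
```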