论文名称: An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale
论文下载链接:https://arxiv.org/abs/2010.11929
原论文对应源码:https://github.com/google-research/vision_transformer
我只导出了ViT-B-16
模型,其他模型不确定是否能导出。
-
首先下载github仓库的源码,然后执行以下命令安装环境
pip install -r vit_jax/requirements.txt
我只有GPU的环境,没有TPU,并且导出onnx在cpu上也可以进行,因此使用这个命令安装没啥问题。
-
下载权重文件
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
-
使用以下代码将jax模型转换为tf模型
注意替换pretrained_path
变量为下载的ViT-B_16.npz
文件的位置。import jax import numpy as np from vit_jax import checkpoint from vit_jax import models from vit_jax.configs import models as models_config from jax.experimental import jax2tf import tensorflow as tf import tf2onnx# 1. 模型加载与参数初始化 ------------------------------------------------- model_name = 'ViT-B_16' model_config = models_config.MODEL_CONFIGS[model_name] num_classes = 1000# 初始化模型 model = models.VisionTransformer(num_classes=num_classes, **model_config)# 生成随机输入样本(使用float32类型) random_image = np.random.randn(1, 224, 224, 3).astype(np.float32)# 初始化模型参数 variables = jax.jit(lambda: model.init(jax.random.PRNGKey(0),random_image,train=False, ), backend='cpu')()# 加载预训练参数 params = checkpoint.load_pretrained(pretrained_path='imagenet21k_ViT-B_16.npz',init_params=variables['params'],model_config=model_config, )# 2. 创建JAX推理函数 ------------------------------------------------------ def jax_predict(inputs):return model.apply({'params': params},inputs,train=False,mutable=False)# 3. 转换为TensorFlow模块 ------------------------------------------------- class VitModule(tf.Module):def __init__(self, predict_fn):super().__init__()self.predict = predict_fntf_predict = tf.function(jax2tf.convert(jax_predict,enable_xla=False), # 设置enable_xla=False很重要,否则导出onnx会遇到https://github.com/onnx/tensorflow-onnx/issues/2259 这个问题input_signature=[tf.TensorSpec(shape=(1, 224, 224, 3), dtype=tf.float32,name="input_image") # 输入形状根据实际调整] )# 4. 保存为SavedModel(关键修正)------------------------------------------- tf_model = VitModule(tf_predict)# 显式获取具象函数 concrete_func = tf_model.predict.get_concrete_function(tf.TensorSpec(shape=[1, 224, 224, 3], dtype=tf.float32) )tf.saved_model.save(tf_model,"vit_dynamic",signatures={'serving_default': concrete_func # 直接使用具象函数} )
-
将
tf
模型导出为onnx
终端执行以下命令:python -m tf2onnx.convert --saved-model ./vit_dynamic/ --output model.onnx --opset 15 --verbose
其中./vit_dynamic/
对应第3个步骤中tf.saved_model.save(
这里指定的tf
模型导出路径。
运行后就可以在当前路径生成onnx模型: