当前位置: 首页> 汽车> 维修 > 网页设计实验总结与体会_广告设计公司利润_如何进行网站推广?网站推广的基本手段有哪些_网站优化的方式有哪些

网页设计实验总结与体会_广告设计公司利润_如何进行网站推广?网站推广的基本手段有哪些_网站优化的方式有哪些

时间:2025/7/8 15:00:47来源:https://blog.csdn.net/qq_36649698/article/details/145946326 浏览次数: 1次
网页设计实验总结与体会_广告设计公司利润_如何进行网站推广?网站推广的基本手段有哪些_网站优化的方式有哪些

论文名称: An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale
论文下载链接:https://arxiv.org/abs/2010.11929
原论文对应源码:https://github.com/google-research/vision_transformer

我只导出了ViT-B-16模型,其他模型不确定是否能导出。

  1. 首先下载github仓库的源码,然后执行以下命令安装环境

    pip install -r vit_jax/requirements.txt
    

    我只有GPU的环境,没有TPU,并且导出onnx在cpu上也可以进行,因此使用这个命令安装没啥问题。

  2. 下载权重文件

    wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
    
  3. 使用以下代码将jax模型转换为tf模型
    注意替换pretrained_path 变量为下载的ViT-B_16.npz文件的位置。

    import jax
    import numpy as np
    from vit_jax import checkpoint
    from vit_jax import models
    from vit_jax.configs import models as models_config
    from jax.experimental import jax2tf
    import tensorflow as tf
    import tf2onnx# 1. 模型加载与参数初始化 -------------------------------------------------
    model_name = 'ViT-B_16'
    model_config = models_config.MODEL_CONFIGS[model_name]
    num_classes = 1000# 初始化模型
    model = models.VisionTransformer(num_classes=num_classes, **model_config)# 生成随机输入样本(使用float32类型)
    random_image = np.random.randn(1, 224, 224, 3).astype(np.float32)# 初始化模型参数
    variables = jax.jit(lambda: model.init(jax.random.PRNGKey(0),random_image,train=False,
    ), backend='cpu')()# 加载预训练参数
    params = checkpoint.load_pretrained(pretrained_path='imagenet21k_ViT-B_16.npz',init_params=variables['params'],model_config=model_config,
    )# 2. 创建JAX推理函数 ------------------------------------------------------
    def jax_predict(inputs):return model.apply({'params': params},inputs,train=False,mutable=False)# 3. 转换为TensorFlow模块 -------------------------------------------------
    class VitModule(tf.Module):def __init__(self, predict_fn):super().__init__()self.predict = predict_fntf_predict = tf.function(jax2tf.convert(jax_predict,enable_xla=False),  # 设置enable_xla=False很重要,否则导出onnx会遇到https://github.com/onnx/tensorflow-onnx/issues/2259  这个问题input_signature=[tf.TensorSpec(shape=(1, 224, 224, 3), dtype=tf.float32,name="input_image")  # 输入形状根据实际调整]
    )# 4. 保存为SavedModel(关键修正)-------------------------------------------
    tf_model = VitModule(tf_predict)# 显式获取具象函数
    concrete_func = tf_model.predict.get_concrete_function(tf.TensorSpec(shape=[1, 224, 224, 3], dtype=tf.float32)
    )tf.saved_model.save(tf_model,"vit_dynamic",signatures={'serving_default': concrete_func  # 直接使用具象函数}
    )
    
  4. tf模型导出为onnx
    终端执行以下命令:

    python -m tf2onnx.convert --saved-model ./vit_dynamic/ --output model.onnx --opset 15 --verbose
    

其中./vit_dynamic/对应第3个步骤中tf.saved_model.save( 这里指定的tf模型导出路径。

运行后就可以在当前路径生成onnx模型:
在这里插入图片描述

关键字:网页设计实验总结与体会_广告设计公司利润_如何进行网站推广?网站推广的基本手段有哪些_网站优化的方式有哪些

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: