
TensorFlow Study Notes 13: Model Persistence

2021/1/28 23:35:08


13.1 Typical Model-Saving Methods

  • The train.Saver class is a low-level API provided by TensorFlow 1.x for saving and restoring a neural network model.
 
import tensorflow as tf
import numpy as np

a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b")
result=a+b

init_op=tf.global_variables_initializer()
saver=tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess,"/home/xxy/model/model.ckpt")

This method produces four files:

  • The checkpoint file is a text file listing all model files in the directory. It is updated automatically.

  • The model.ckpt.data-00000-of-00001 file stores the value of every variable in the TensorFlow program.

  • The model.ckpt.index file stores the name of every variable. It is a string-to-string table whose keys are tensor names and whose values are serialized BundleEntryProto messages.

  • The model.ckpt.meta file stores the structure of the computation graph, i.e. the structure of the neural network.

  • The restore() function requires all operations on the computation graph to be defined before the parameters are restored, and the variable names must match the names stored in the model; the saved values can then be loaded into the variables.
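As an illustration of the naming scheme (a sketch, assuming the default V2 checkpoint format with a single data shard), the file names produced for a save prefix of model.ckpt can be generated as:

```python
def checkpoint_files(prefix):
    # files written by saver.save(sess, path) for save prefix `prefix`,
    # assuming the default V2 format with one data shard
    return ["checkpoint",                       # text index of all checkpoints in the directory
            prefix + ".data-00000-of-00001",    # variable values
            prefix + ".index",                  # variable-name -> BundleEntryProto table
            prefix + ".meta"]                   # the computation graph (MetaGraphDef)

print(checkpoint_files("model.ckpt"))
```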

import tensorflow as tf
import numpy as np

a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b")
result=a+b

saver=tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess,"/home/xxy/model/model.ckpt")
    print(sess.run(result))
  • import_meta_graph() loads an already-persisted computation graph directly. Its input argument is the path to a .meta file. It returns a Saver instance; calling restore() on that instance then restores the parameters.

import tensorflow as tf
meta_graph=tf.train.import_meta_graph("/home/xxy/model/model.ckpt.meta")
with tf.Session() as sess:
    meta_graph.restore(sess,"/home/xxy/model/model.ckpt")
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
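The "add:0" above follows TensorFlow's tensor-naming convention: the name of the op that produces the tensor, a colon, and the index of that op's output. A small helper sketching the rule (hypothetical, for illustration only):

```python
def split_tensor_name(tensor_name):
    # "add:0" -> ("add", 0): output 0 of the op named "add"
    op_name, index = tensor_name.rsplit(":", 1)
    return op_name, int(index)

print(split_tensor_name("add:0"))
```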
  • Saving and loading a subset of variables

import tensorflow as tf
a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b")
result=a+b

saver=tf.train.Saver([a])

with tf.Session() as sess:
    saver.restore(sess,"/home/xxy/model/model.ckpt")
    print(sess.run(a))

A subset of variables to save or load can be specified by providing a list when constructing the train.Saver class.

  • Renaming variables when saving or loading

import tensorflow as tf
a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a2")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b2")
result=a+b

saver=tf.train.Saver({"a":a,"b":b})

with tf.Session() as sess:
    saver.restore(sess,"/home/xxy/model/model.ckpt")
    print(sess.run(result))
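Note the direction of the dictionary passed to tf.train.Saver: keys are the variable names stored in the checkpoint, values are the variables in the current graph. A plain-Python sketch (variables stood in for by their new names, for illustration only):

```python
# checkpoint name -> variable in the current graph
saved_to_current = {"a": "a2", "b": "b2"}

def restore_target(saved_name):
    # which current variable receives the value saved under saved_name
    return saved_to_current[saved_name]

print(restore_target("a"))
```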
  • Saving moving-average variables

writer

import tensorflow as tf
a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b")

## define the exponential moving average over all variables
averages_class=tf.train.ExponentialMovingAverage(0.99)
averages_op=averages_class.apply(tf.global_variables())

for variables in tf.global_variables():
    print(variables.name)    
    
init_op=tf.global_variables_initializer()
saver=tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    #assign() updates the variable values; the new value must match the variable's shape
    sess.run(tf.assign(a,[10.0,10.0]))
    sess.run(tf.assign(b,[5.0,5.0]))
    
    sess.run(averages_op)
    saver.save(sess,"/home/xxy/model/model2.ckpt")
    print(sess.run([a,averages_class.average(a)]))
    print(sess.run([b,averages_class.average(b)]))
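The shadow value maintained by ExponentialMovingAverage follows shadow = decay * shadow + (1 - decay) * value. A pure-Python sketch of one update step (assuming the decay stays at 0.99, i.e. no num_updates correction):

```python
def ema_update(shadow, value, decay=0.99):
    # one moving-average step, element-wise over a vector variable
    return [decay * s + (1 - decay) * v for s, v in zip(shadow, value)]

# shadow variables start at the variables' initial values, e.g. [1.0, 2.0] for a;
# after assigning a new value of 10 to each element and running averages_op once:
shadow_a = ema_update([1.0, 2.0], [10.0, 10.0])
print(shadow_a)
```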

reader

import tensorflow as tf
a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b")

## define the exponential moving average, matching the one used when saving
averages_class=tf.train.ExponentialMovingAverage(0.99)

saver=tf.train.Saver({"a/ExponentialMovingAverage":a,
                      "b/ExponentialMovingAverage":b})
#the variables_to_restore() function provided by train.ExponentialMovingAverage generates the dictionary above directly; the following is equivalent to the code above
'''
saver=tf.train.Saver(averages_class.variables_to_restore())
'''

with tf.Session() as sess:
    saver.restore(sess,"/home/xxy/model/model2.ckpt")
    print(sess.run([a,b]))
    print(averages_class.variables_to_restore())
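variables_to_restore() essentially builds the shadow-name dictionary shown above. A rough sketch of the naming rule (variables represented by their names, for illustration only):

```python
def variables_to_restore(variable_names):
    # map each variable's shadow name "<name>/ExponentialMovingAverage" back to the variable
    return {name + "/ExponentialMovingAverage": name for name in variable_names}

print(variables_to_restore(["a", "b"]))
```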

13.2 How Model Persistence Works
1. The model.ckpt.meta file
The model.ckpt.meta file stores the metagraph of the TensorFlow program, i.e. the node information of the computation graph. The metagraph is serialized in the MetaGraphDef format.

message MetaGraphDef{
    MetaInfoDef meta_info_def=1;
    GraphDef graph_def=2;
    SaverDef saver_def=3;
    map<string,CollectionDef> collection_def=4;
    map<string,SignatureDef> signature_def=5;
};
import tensorflow as tf
a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b")
result=a+b
saver=tf.train.Saver()
saver.export_meta_graph("/home/xxy/model_ckpt_meta_json",as_text=True)
  • The meta_info_def attribute

message MetaInfoDef{
    string meta_graph_version=1;
    OpList stripped_op_list = 2;
    google.protobuf.Any any_info=3;
    repeated string tags=4;
};

The stripped_op_list field records information about every operation used in the computation graph, stored as OpDef messages. For example, an op's attr field looks like:

attr{
    name:"T"
    type:"type"
    allowed_values{
        list{
        type:DT_HALF
        type:DT_FLOAT
        type:DT_DOUBLE
        type:DT_UINT8
        type:DT_INT8
        type:DT_INT16
        type:DT_INT32
        type:DT_INT64
        type:DT_COMPLEX64
        type:DT_COMPLEX128
        type:DT_STRING
    }
    }
}
  • The graph_def attribute

message GraphDef{
    repeated NodeDef node=1;
    VersionDef versions=4;
};
message NodeDef{
    string name=1;
    string op=2;
    repeated string input=3;
    string device =4;
    map<string,AttrValue> attr =5;
};
  • The saver_def attribute
message SaverDef{
    string filename_tensor_name=1;
    string save_tensor_name=2;
    string restore_op_name=3;
    int32 max_to_keep=4;
    bool sharded=5;
    float keep_checkpoint_every_n_hours=6;

    enum CheckpointFormatVersion{
        LEGACY=0;
        V1=1;
        V2=2;
    }
    CheckpointFormatVersion version=7;
};
saver_def{
    filename_tensor_name:"save/Const:0"
    save_tensor_name:"save/control_dependency:0"
    restore_op_name:"save/restore_all"
    max_to_keep:5
    keep_checkpoint_every_n_hours:10000.0
    version:V2
};
  • The collection_def attribute

message CollectionDef{
    message NodeList{
        repeated string value=1;
    }
    message BytesList{
        repeated bytes value=1;
    }
    message Int64List{
        repeated int64 value=1[packed=true];
    }
    message FloatList{
        repeated float value=1[packed=true];
    }
    message AnyList{
        repeated google.protobuf.Any value=1;
    }
    oneof kind{
        NodeList node_list=1;
        BytesList bytes_list=2;
        Int64List int64_list=3;
        FloatList float_list=4;
        AnyList any_list=5;
    }   
};

2. Reading variable values from the .index and .data files

import tensorflow as tf
reader=tf.train.NewCheckpointReader("/home/xxy/model/model.ckpt")
all_variables=reader.get_variable_to_shape_map()
print(all_variables)

for variable_name in all_variables:
    print(variable_name,"shape is:",all_variables[variable_name])
print("Value for variable a is : ",reader.get_tensor("a"))
print("Value for variable b is : ",reader.get_tensor("b"))

13.3 Saving Models in TensorFlow 2.0
save

import tensorflow as tf
from tensorflow.keras import layers

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.fashion_mnist.load_data()

train_images=train_images.reshape(60000,784).astype('float32')/255
test_images=test_images.reshape(10000,784).astype('float32')/255

inputs=tf.keras.Input(shape=(784,),name='digits')

x=layers.Dense(500,activation='relu',name='dense_1')(inputs)

outputs=layers.Dense(10,activation='softmax',name='predictions')(x)

mlpmodel=tf.keras.Model(inputs=inputs,outputs=outputs,name='MLPModel')

mlpmodel.summary()

mlpmodel.compile(loss='sparse_categorical_crossentropy',optimizer=tf.keras.optimizers.SGD(),metrics=['accuracy'])
mlpmodel.fit(x=train_images,y=train_labels,epochs=10,batch_size=100,validation_data=(test_images,test_labels))
mlpmodel.save("/home/xxy/model/model.h5")
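The reshape-and-divide preprocessing above flattens each 28x28 image into a 784-vector and scales the uint8 pixels into [0, 1]; the scaling step in plain Python:

```python
def normalize(pixels):
    # scale 8-bit pixel values (0..255) to floats in [0, 1]
    return [p / 255.0 for p in pixels]

print(normalize([0, 255]))
```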

read

import tensorflow as tf
import numpy as np

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.fashion_mnist.load_data()

test_images=test_images.reshape(10000,784).astype('float32')/255

load_mlpmodel=tf.keras.models.load_model("/home/xxy/model/model.h5")
class_name=['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankleboot']
predictions=load_mlpmodel.predict(test_images)
for i in range(100):
    print("predict class result is : ",class_name[np.argmax(predictions[i])])
    print("correct class result is : ", class_name[test_labels[i]])
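np.argmax above picks the index of the largest softmax output; the label lookup can be sketched without NumPy as:

```python
def predicted_class(probs, class_names):
    # index of the highest probability, then its label
    best = max(range(len(probs)), key=lambda i: probs[i])
    return class_names[best]

print(predicted_class([0.1, 0.7, 0.2], ["Trouser", "Pullover", "Dress"]))
```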

save

import tensorflow as tf
from tensorflow.keras import layers

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.fashion_mnist.load_data()

train_images=train_images.reshape(60000,784).astype('float32')/255
test_images=test_images.reshape(10000,784).astype('float32')/255

inputs=tf.keras.Input(shape=(784,),name='digits')

x=layers.Dense(500,activation='relu',name='dense_1')(inputs)

outputs=layers.Dense(10,activation='softmax',name='predictions')(x)

mlpmodel=tf.keras.Model(inputs=inputs,outputs=outputs,name='MLPModel')

mlpmodel.summary()

mlpmodel.compile(loss='sparse_categorical_crossentropy',optimizer=tf.keras.optimizers.SGD(),metrics=['accuracy'])
mlpmodel.fit(x=train_images,y=train_labels,epochs=10,batch_size=100,validation_data=(test_images,test_labels))
tf.keras.experimental.export_saved_model(mlpmodel,"/home/xxy/model/")

read

import tensorflow as tf
import numpy as np

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.fashion_mnist.load_data()

test_images=test_images.reshape(10000,784).astype('float32')/255

load_mlpmodel=tf.keras.experimental.load_from_saved_model("/home/xxy/model/")
class_name=['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankleboot']
predictions=load_mlpmodel.predict(test_images)
for i in range(100):
    print("predict class result is : ",class_name[np.argmax(predictions[i])])
    print("correct class result is : ", class_name[test_labels[i]])
config=load_mlpmodel.get_config()
load_mlpmodel=tf.keras.Model.from_config(config)

save_weight

import tensorflow as tf
from tensorflow.keras import layers

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.fashion_mnist.load_data()

test_images=test_images.reshape(10000,784).astype('float32')/255

class MLPModel(tf.keras.Model):
    def __init__(self,name=None):
        super(MLPModel,self).__init__(name=name)
        self.dense=layers.Dense(500,activation='relu',name='dense')
        self.dense_1=layers.Dense(10,activation='softmax',name='dense_1')
    def call(self,inputs):
        x=self.dense(inputs)
        return self.dense_1(x)
    
mlpmodel=MLPModel()
mlpmodel.compile(loss='sparse_categorical_crossentropy',optimizer=tf.keras.optimizers.SGD())
history=mlpmodel.fit(train_images,train_labels,batch_size=100,epochs=10)

mlpmodel.save_weights("/home/xxy/model/",save_format='tf')
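The sparse_categorical_crossentropy loss used throughout is, per example, the negative log of the probability the model assigns to the integer true label. A minimal sketch:

```python
import math

def sparse_categorical_crossentropy(probs, label):
    # negative log-likelihood of the true class index
    return -math.log(probs[label])

print(sparse_categorical_crossentropy([0.1, 0.7, 0.2], 1))
```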

read_weight

import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.fashion_mnist.load_data()

test_images=test_images.reshape(10000,784).astype('float32')/255

class MLPModel(tf.keras.Model):
    def __init__(self,name=None):
        super(MLPModel,self).__init__(name=name)
        self.dense=layers.Dense(500,activation='relu',name='dense')
        self.dense_1=layers.Dense(10,activation='softmax',name='dense_1')
    def call(self,inputs):
        x=self.dense(inputs)
        return self.dense_1(x)

new_model=MLPModel()
new_model.compile(loss='sparse_categorical_crossentropy',optimizer=tf.keras.optimizers.SGD())
new_model.load_weights('/home/xxy/model/')

class_name=['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankleboot']
new_predictions=new_model.predict(test_images)
for i in range(100):
    print("predict class result is : ",class_name[np.argmax(new_predictions[i])])
    print("correct class result is : ", class_name[test_labels[i]])

13.4 PB文件
writer

import tensorflow as tf
#convert_variables_to_constants is defined in tensorflow/python/framework/graph_util.py
from tensorflow.python.framework import graph_util

a=tf.Variable(tf.constant([1.0],shape=[1]),name="a")
b=tf.Variable(tf.constant([3.0],shape=[1]),name="b")
result=a+b
init_op=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)

    graph_def=tf.get_default_graph().as_graph_def()

    output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,['add'])
    with tf.gfile.GFile("/home/xxy/model/model.pb","wb") as f:
        f.write(output_graph_def.SerializeToString())
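convert_variables_to_constants "freezes" the graph: every variable reachable from the listed output nodes is replaced by a Const node holding its current value from the session, so the resulting GraphDef can be loaded without a checkpoint. A rough sketch of the idea (plain dicts standing in for NodeDef messages; names are illustrative):

```python
def freeze(nodes, variable_values):
    # replace each Variable node with a Const node carrying its current value
    frozen = []
    for node in nodes:
        if node["op"] in ("Variable", "VariableV2"):
            frozen.append({"name": node["name"], "op": "Const",
                           "value": variable_values[node["name"]]})
        else:
            frozen.append(node)
    return frozen

graph = [{"name": "a", "op": "VariableV2"}, {"name": "add", "op": "Add"}]
frozen = freeze(graph, {"a": [1.0]})
print(frozen)
```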

reader

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    with gfile.FastGFile("/home/xxy/model/model.pb","rb") as f:
        graph_def=tf.GraphDef()
        graph_def.ParseFromString(f.read())
        
    result=tf.import_graph_def(graph_def,return_elements=["add:0"])
    
    print(sess.run(result))

Original article: http://www.dtmao.cc/news_show_650250.shtml
