Tf下graph太大保存的解决办法

Posted by Shaoguang Blog on March 6, 2019
  1. tensorflow在使用tf.train.CheckpointSaverHook时,会在训练正式开始之前出发write_graph的操作,如果定义的graph大小超过2G,就会保存失败,具体的错误示例为:Invalid arguments graph.pbtxt.tmp1246f70ae1304a66abaa4a7cab067148。
  2. 一般模型太大的原因是graph中保存了很大的常量,如外部加载的资源等。具体表现形式为在代码中有如下语句:
    (1)var = tf.Variable(tf.constant(x), trainable=False)
    (2)var = tf.Variable(x, trainable=False)
    其中x为一个确定数值的常量。

此式问题的解决方法可以为:使用tf.Variable的load函数,先初始化variable,然后调用load; 例如加载一个很大的外部变量 ``` import tensorflow as tf from gensim import models

model = models.KeyedVectors.load_word2vec_format('./GoogleNews-vectors-negative300.bin', binary=True)
X = model.syn0

embeddings = tf.Variable(tf.random_uniform(X.shape, minval=-0.1, maxval=0.1), trainable=False)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
embeddings.load(model.syn0, sess)
```
  1. graph保存时保存的内容为graph_def(具体逻辑可以在write_graph函数中查看),具体内容的查看方式可参考如下:
     from tensorflow.python.framework import ops  
     graph_def = ops.get_default_graph().as_graph_def(add_shapes=True) 
    
  2. 在estimator里,不能直接获取到session,可以使用tf.train.Scaffold来初始化embedding,参见https://stackoverflow.com/questions/48217599/how-to-initialize-embeddings-layer-within-estimator-api。 一个示例程序如下:

     def model_fn(features, labels, mode):
       size = 10
       initial_value = np.random.randn(size, size).astype(np.float32)
       x = tf.get_variable("x", [size, size])
    
       def init_fn(scaffold, sess):
           sess.run(x.initializer, {x.initial_value: initial_value})
       scaffold = tf.train.Scaffold(init_fn=init_fn)
    
       loss = ...
       train_op = ...
    
       return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, scaffold=scaffold)