TF中CheckPoint源码追踪
契机:saver.save() 会保存检查点,而Saver默认最多只保留五个checkpoint(max_to_keep=5),所以此处需要追踪第六个检查点产生时系统是如何工作的;除此之外,还要摸清检查点的底层实现,包括I/O等事宜。
- 第一步,最上层用户创建,每1000轮保存一次检查点
if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step=step)
- 追溯saver.save()
This method runs the ops added by the constructor for saving variables.
It requires a session in which the graph was launched. The variables to
save must also have been initialized.
def save(self,
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix="meta",
write_meta_graph=True,
write_state=True,
strip_default_attrs=False,
save_debug_info=False):
# pylint: disable=line-too-long
"""Saves variables.
This method runs the ops added by the constructor for saving variables.
It requires a session in which the graph was launched. The variables to
save must also have been initialized.
The method returns the path prefix of the newly created checkpoint files.
This string can be passed directly to a call to `restore()`.
Args:
sess: A Session to use to save the variables.
save_path: String. Prefix of filenames created for the checkpoint.
global_step: If provided the global step number is appended to `save_path`
to create the checkpoint filenames. The optional argument can be a
`Tensor`, a `Tensor` name or an integer.
latest_filename: Optional name for the protocol buffer file that will
contains the list of most recent checkpoints. That file, kept in the
same directory as the checkpoint files, is automatically managed by the
saver to keep track of recent checkpoints. Defaults to 'checkpoint'.
meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
write_meta_graph: `Boolean` indicating whether or not to write the meta
graph file.
write_state: `Boolean` indicating whether or not to write the
`CheckpointStateProto`.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued
Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
which in the same directory of save_path and with `_debug` added before
the file extension. This is only enabled when `write_meta_graph` is
`True`
Returns:
A string: path prefix used for the checkpoint files. If the saver is
sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
is the number of shards created.
If the saver is empty, returns None.
Raises:
TypeError: If `sess` is not a `Session`.
ValueError: If `latest_filename` contains path components, or if it
collides with `save_path`.
RuntimeError: If save and restore ops weren't built.
"""
# pylint: enable=line-too-long
if not self._is_built and not context.executing_eagerly():
raise RuntimeError(
"`build()` should be called before save if defer_build==True")
if latest_filename is None:
latest_filename = "checkpoint"
if self._write_version != saver_pb2.SaverDef.V2:
logging.warning("*******************************************************")
logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
logging.warning("Consider switching to the more efficient V2 format:")
logging.warning(" `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
logging.warning("now on by default.")
logging.warning("*******************************************************")
if os.path.split(latest_filename)[0]:
raise ValueError("'latest_filename' must not contain path components")
if global_step is not None:
if not isinstance(global_step, compat.integral_types):
global_step = training_util.global_step(sess, global_step)
checkpoint_file = "%s-%d" % (save_path, global_step)
if self._pad_step_number:
# Zero-pads the step numbers, so that they are sorted when listed.
checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
else:
checkpoint_file = save_path
if os.path.basename(save_path) == latest_filename and not self._sharded:
# Guard against collision between data file and checkpoint state file.
raise ValueError(
"'latest_filename' collides with 'save_path': '%s' and '%s'" %
(latest_filename, save_path))
if (not context.executing_eagerly() and
not isinstance(sess, session.SessionInterface)):
raise TypeError("'sess' must be a Session; %s" % sess)
save_path_parent = os.path.dirname(save_path)
if not self._is_empty:
try:
if context.executing_eagerly():
self._build_eager(
checkpoint_file, build_save=True, build_restore=False)
model_checkpoint_path = self.saver_def.save_tensor_name
else:
model_checkpoint_path = sess.run(
self.saver_def.save_tensor_name,
{self.saver_def.filename_tensor_name: checkpoint_file})
model_checkpoint_path = compat.as_str(model_checkpoint_path)
if write_state:
## 上面始终在对checkpoint的名字做修改,下句将最近的checkpoint保存下来
self._RecordLastCheckpoint(model_checkpoint_path)
## 更新"checkpoint"文件里面的值,该函数涉及写入操作
checkpoint_management.update_checkpoint_state_internal(
save_dir=save_path_parent,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=self.last_checkpoints,
latest_filename=latest_filename,
save_relative_paths=self._save_relative_paths)
##
self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
if not gfile.IsDirectory(save_path_parent):
exc = ValueError(
"Parent directory of {} doesn't exist, can't save.".format(
save_path))
raise exc
if write_meta_graph:
meta_graph_filename = checkpoint_management.meta_graph_filename(
checkpoint_file, meta_graph_suffix=meta_graph_suffix)
if not context.executing_eagerly():
with sess.graph.as_default():
self.export_meta_graph(
meta_graph_filename,
strip_default_attrs=strip_default_attrs,
save_debug_info=save_debug_info)
if self._is_empty:
return None
else:
return model_checkpoint_path
这里代码中反复提到了context.executing_eagerly():
这里针对该内容进行总结:
TensorFlow 引入了「Eager Execution」,它是一个命令式、由运行定义(define-by-run)的接口,一旦从 Python 被调用,其操作立即被执行。这使得入门 TensorFlow 变得更简单,也使研发更直观。
简单来说,这是用户可定义的一个机制,可选择是否开启。这里关心I/O层面的问题,所以这里不详述。
进入该函数,进行判断:
## global_step非空则进入
if global_step is not None:
## 只有当global_step不是整数(例如是Tensor或Tensor名)时才进入该分支;本例传入的step应为Python整数,此判断为False,直接跳过
if not isinstance(global_step, compat.integral_types):
## 获取当前步数
global_step = training_util.global_step(sess, global_step)
## 拼接出检查点文件名:“checkpoint路径-步数”
checkpoint_file = "%s-%d" % (save_path, global_step)
之后进入创建checkpoint部分:
if not self._is_empty:
try:
if context.executing_eagerly():
self._build_eager(
checkpoint_file, build_save=True, build_restore=False)
## Store the tensor values to the tensor_names.
model_checkpoint_path = self.saver_def.save_tensor_name
else:
model_checkpoint_path = sess.run(
self.saver_def.save_tensor_name,
{self.saver_def.filename_tensor_name: checkpoint_file})
model_checkpoint_path = compat.as_str(model_checkpoint_path)
其中save_tensor_name
=save_tensor.numpy() if build_save else ""
# 用tf.train.Saver()创建一个Saver来管理模型中的所有变量
saver = tf.train.Saver(tf.all_variables())
再看下面的函数:
checkpoint_management.update_checkpoint_state_internal(
save_dir=save_path_parent,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=self.last_checkpoints,
latest_filename=latest_filename,
save_relative_paths=self._save_relative_paths)
def update_checkpoint_state_internal(save_dir, ##model save path
model_checkpoint_path, ##checkpoint
all_model_checkpoint_paths=None, ## from old to new
latest_filename=None,
save_relative_paths=False,
all_model_checkpoint_timestamps=None,
last_preserved_timestamp=None):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
proto.
Args:
save_dir: Directory where the model was saved.
model_checkpoint_path: The checkpoint file.
all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
latest_filename: Optional name of the checkpoint file. Default to
'checkpoint'.
save_relative_paths: If `True`, will write relative paths to the checkpoint
state file.
all_model_checkpoint_timestamps: Optional list of timestamps (floats,
seconds since the Epoch) indicating when the checkpoints in
`all_model_checkpoint_paths` were created.
last_preserved_timestamp: A float, indicating the number of seconds since
the Epoch when the last preserved checkpoint was written, e.g. due to a
`keep_checkpoint_every_n_hours` parameter (see
`tf.contrib.checkpoint.CheckpointManager` for an implementation).
Raises:
RuntimeError: If any of the model checkpoint paths conflict with the file
containing CheckpointSate.
"""
# Writes the "checkpoint" file for the coordinator for later restoration.
coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
if save_relative_paths:
if os.path.isabs(model_checkpoint_path):
rel_model_checkpoint_path = os.path.relpath(
model_checkpoint_path, save_dir)
else:
rel_model_checkpoint_path = model_checkpoint_path
rel_all_model_checkpoint_paths = []
for p in all_model_checkpoint_paths:
if os.path.isabs(p):
rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
else:
rel_all_model_checkpoint_paths.append(p)
ckpt = generate_checkpoint_state_proto(
save_dir,
rel_model_checkpoint_path,
all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
last_preserved_timestamp=last_preserved_timestamp)
else:
## generate related checkpoint
ckpt = generate_checkpoint_state_proto(
save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths,
all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
last_preserved_timestamp=last_preserved_timestamp)
if coord_checkpoint_filename == ckpt.model_checkpoint_path:
raise RuntimeError("Save path '%s' conflicts with path used for "
"checkpoint state. Please use a different save path." %
model_checkpoint_path)
# Preventing potential read/write race condition by *atomically* writing to a
# file.
file_io.atomic_write_string_to_file(coord_checkpoint_filename,
text_format.MessageToString(ckpt))
该函数比较关键(对checkpoint这个文件进行写入),需要详细分析:
该函数的目的是更新checkpoint文件中的内容,即更新保存着CheckpointState proto(protocol buffer消息)的检查点状态文件。
传入参数介绍:
-
save_dir: Directory where the model was saved.(模型存储的目录)
-
model_checkpoint_path: 检查点文件名字(/tmp/cifar10_train/model.ckpt-400)
-
all_model_checkpoint_paths: 字符串列表,即当前没有删除的所有checkpoint内容。“checkpoint”这个文件中的内容:
CP_DIY_last_checkpoints: [('/tmp/cifar10_train/model.ckpt-0', 1567741570.985191), ('/tmp/cifar10_train/model.ckpt-100', 1567741597.823174), ('/tmp/cifar10_train/model.ckpt-200', 1567741622.994202), ('/tmp/cifar10_train/model.ckpt-300', 1567741648.138035), ('/tmp/cifar10_train/model.ckpt-400', 1567741674.720897)]
-
latest_filename: checkpoint这个文件的名字,默认叫做 'checkpoint'
-
save_relative_paths: 是否写入相对路径。若为True,则把绝对路径转换为相对于save_dir的相对路径后再写入checkpoint状态文件;否则按原路径写入
在该函数中打印ckpt:
得到:
CP_DIY_ckpt: model_checkpoint_path: "/tmp/cifar10_train/model.ckpt-0"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-0"
CP_DIY_ckpt: model_checkpoint_path: "/tmp/cifar10_train/model.ckpt-100"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-0"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-100"
说明存储的内容包括当前的model_checkpoint_path以及all_model_checkpoint_paths。
file_io.atomic_write_string_to_file(coord_checkpoint_filename,
text_format.MessageToString(ckpt))
最后一句话表明:将ckpt中的内容(转换为文本格式后)写入coord_checkpoint_filename,即将
model_checkpoint_path: "/tmp/cifar10_train/model.ckpt-999999"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-998000"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-998500"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-999000"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-999500"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-999999"
写入checkpoint这个文件中。
下面生成计算图元数据(meta graph)文件的名字:
meta_graph_filename = checkpoint_management.meta_graph_filename(
checkpoint_file, meta_graph_suffix=meta_graph_suffix)
得到xxxx.meta文件。
为了弄清如何向该文件中写入,我们查看export_meta_graph函数:
@tf_export(v1=["train.export_meta_graph"])
def export_meta_graph(filename=None,
meta_info_def=None,
graph_def=None,
saver_def=None,
collection_list=None,
as_text=False,
graph=None,
export_scope=None,
clear_devices=False,
clear_extraneous_savers=False,
strip_default_attrs=False,
save_debug_info=False,
**kwargs):
# pylint: disable=line-too-long
"""Returns `MetaGraphDef` proto.
Optionally writes it to filename.
This function exports the graph, saver, and collection objects into
`MetaGraphDef` protocol buffer with the intention of it being imported
at a later time or location to restart training, run inference, or be
a subgraph.
Args:
filename: Optional filename including the path for writing the generated
`MetaGraphDef` protocol buffer.
meta_info_def: `MetaInfoDef` protocol buffer.
graph_def: `GraphDef` protocol buffer.
saver_def: `SaverDef` protocol buffer.
collection_list: List of string keys to collect.
as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
graph: The `Graph` to export. If `None`, use the default graph.
export_scope: Optional `string`. Name scope under which to extract the
subgraph. The scope name will be striped from the node definitions for
easy import later into new name scopes. If `None`, the whole graph is
exported. graph_def and export_scope cannot both be specified.
clear_devices: Whether or not to clear the device field for an `Operation`
or `Tensor` during export.
clear_extraneous_savers: Remove any Saver-related information from the graph
(both Save/Restore ops and SaverDefs) that are not associated with the
provided SaverDef.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
which in the same directory of filename and with `_debug` added before the
file extend.
**kwargs: Optional keyed arguments.
Returns:
A `MetaGraphDef` proto.
Raises:
ValueError: When the `GraphDef` is larger than 2GB.
RuntimeError: If called with eager execution enabled.
@compatibility(eager)
Exporting/importing meta graphs is not supported unless both `graph_def` and
`graph` are provided. No graph exists when eager execution is enabled.
@end_compatibility
"""
# pylint: enable=line-too-long
if context.executing_eagerly() and not (graph_def is not None and
graph is not None):
raise RuntimeError("Exporting/importing meta graphs is not supported when "
"eager execution is enabled. No graph exists when eager "
"execution is enabled.")
meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
filename=filename,
meta_info_def=meta_info_def,
graph_def=graph_def,
saver_def=saver_def,
collection_list=collection_list,
as_text=as_text,
graph=graph,
export_scope=export_scope,
clear_devices=clear_devices,
clear_extraneous_savers=clear_extraneous_savers,
strip_default_attrs=strip_default_attrs,
save_debug_info=save_debug_info,
**kwargs)
return meta_graph_def
那么系统如何将计算图导出呢?
def export_scoped_meta_graph(filename=None,
graph_def=None,
graph=None,
export_scope=None,
as_text=False,
unbound_inputs_col_name="unbound_inputs",
clear_devices=False,
saver_def=None,
clear_extraneous_savers=False,
strip_default_attrs=False,
save_debug_info=False,
**kwargs):
"""Returns `MetaGraphDef` proto. Optionally writes it to filename.
This function exports the graph, saver, and collection objects into
`MetaGraphDef` protocol buffer with the intention of it being imported
at a later time or location to restart training, run inference, or be
a subgraph.
Args:
filename: Optional filename including the path for writing the
generated `MetaGraphDef` protocol buffer.
graph_def: `GraphDef` protocol buffer.
graph: The `Graph` to export. If `None`, use the default graph.
export_scope: Optional `string`. Name scope under which to extract
the subgraph. The scope name will be stripped from the node definitions
for easy import later into new name scopes. If `None`, the whole graph
is exported.
as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
unbound_inputs_col_name: Optional `string`. If provided, a string collection
with the given name will be added to the returned `MetaGraphDef`,
containing the names of tensors that must be remapped when importing the
`MetaGraphDef`.
clear_devices: Boolean which controls whether to clear device information
before exporting the graph.
saver_def: `SaverDef` protocol buffer.
clear_extraneous_savers: Remove any Saver-related information from the
graph (both Save/Restore ops and SaverDefs) that are not associated
with the provided SaverDef.
strip_default_attrs: Set to true if default valued attributes must be
removed while exporting the GraphDef.
save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
which in the same directory of filename and with `_debug` added before the
file extension.
**kwargs: Optional keyed arguments, including meta_info_def and
collection_list.
Returns:
A `MetaGraphDef` proto and dictionary of `Variables` in the exported
name scope.
Raises:
ValueError: When the `GraphDef` is larger than 2GB.
ValueError: When executing in Eager mode and either `graph_def` or `graph`
is undefined.
"""
if context.executing_eagerly() and not (graph_def is not None and
graph is not None):
raise ValueError("Exporting/importing meta graphs is not supported when "
"Eager Execution is enabled.")
graph = graph or ops.get_default_graph()
exclude_nodes = None
unbound_inputs = []
if export_scope or clear_extraneous_savers or clear_devices:
if graph_def:
new_graph_def = graph_pb2.GraphDef()
new_graph_def.versions.CopyFrom(graph_def.versions)
new_graph_def.library.CopyFrom(graph_def.library)
if clear_extraneous_savers:
exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)
for node_def in graph_def.node:
if _should_include_node(node_def.name, export_scope, exclude_nodes):
new_node_def = _node_def(node_def, export_scope, unbound_inputs,
clear_devices=clear_devices)
new_graph_def.node.extend([new_node_def])
graph_def = new_graph_def
else:
# Only do this complicated work if we want to remove a name scope.
graph_def = graph_pb2.GraphDef()
# pylint: disable=protected-access
graph_def.versions.CopyFrom(graph.graph_def_versions)
bytesize = 0
if clear_extraneous_savers:
exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
saver_def)
for key in sorted(graph._nodes_by_id):
if _should_include_node(graph._nodes_by_id[key].name,
export_scope,
exclude_nodes):
value = graph._nodes_by_id[key]
# pylint: enable=protected-access
node_def = _node_def(value.node_def, export_scope, unbound_inputs,
clear_devices=clear_devices)
graph_def.node.extend([node_def])
if value.outputs:
assert "_output_shapes" not in graph_def.node[-1].attr
graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
output.get_shape().as_proto() for output in value.outputs])
bytesize += value.node_def.ByteSize()
if bytesize >= (1 << 31) or bytesize < 0:
raise ValueError("GraphDef cannot be larger than 2GB.")
graph._copy_functions_to_graph_def(graph_def, bytesize) # pylint: disable=protected-access
# It's possible that not all the inputs are in the export_scope.
# If we would like such information included in the exported meta_graph,
# add them to a special unbound_inputs collection.
if unbound_inputs_col_name:
# Clears the unbound_inputs collections.
graph.clear_collection(unbound_inputs_col_name)
for k in unbound_inputs:
graph.add_to_collection(unbound_inputs_col_name, k)
var_list = {}
variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
scope=export_scope)
for v in variables:
if _should_include_node(v, export_scope, exclude_nodes):
var_list[ops.strip_name_scope(v.name, export_scope)] = v
scoped_meta_graph_def = create_meta_graph_def(
graph_def=graph_def,
graph=graph,
export_scope=export_scope,
exclude_nodes=exclude_nodes,
clear_extraneous_savers=clear_extraneous_savers,
saver_def=saver_def,
strip_default_attrs=strip_default_attrs,
**kwargs)
if filename:
graph_io.write_graph(
scoped_meta_graph_def,
os.path.dirname(filename),
os.path.basename(filename),
as_text=as_text)
if save_debug_info:
name, _ = os.path.splitext(filename)
debug_filename = "{name}{ext}".format(name=name, ext=".debug")
# Gets the operation from the graph by the name. Exludes variable nodes,
# so only the nodes in the frozen models are included.
ops_to_export = []
for node in scoped_meta_graph_def.graph_def.node:
scoped_op_name = ops.prepend_name_scope(node.name, export_scope)
ops_to_export.append(graph.get_operation_by_name(scoped_op_name))
graph_debug_info = create_graph_debug_info_def(ops_to_export)
graph_io.write_graph(
graph_debug_info,
os.path.dirname(debug_filename),
os.path.basename(debug_filename),
as_text=as_text)
return scoped_meta_graph_def, var_list
该函数有些复杂,后续继续分析。那么如何将内容写入该文件呢?
def write_graph(graph_or_graph_def, logdir, name, as_text=True):
"""Writes a graph proto to a file.
The graph is written as a text proto unless `as_text` is `False`.
python
v = tf.Variable(0, name='my_variable')
sess = tf.compat.v1.Session()
tf.io.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')
or
python
v = tf.Variable(0, name='my_variable')
sess = tf.compat.v1.Session()
tf.io.write_graph(sess.graph, '/tmp/my-model', 'train.pbtxt')
Args:
graph_or_graph_def: A `Graph` or a `GraphDef` protocol buffer.
logdir: Directory where to write the graph. This can refer to remote
filesystems, such as Google Cloud Storage (GCS).
name: Filename for the graph.
as_text: If `True`, writes the graph as an ASCII proto.
Returns:
The path of the output proto file.
"""
if isinstance(graph_or_graph_def, ops.Graph):
graph_def = graph_or_graph_def.as_graph_def()
else:
graph_def = graph_or_graph_def
# gcs does not have the concept of directory at the moment.
if not file_io.file_exists(logdir) and not logdir.startswith('gs:'):
file_io.recursive_create_dir(logdir)
path = os.path.join(logdir, name)
if as_text:
file_io.atomic_write_string_to_file(path,
text_format.MessageToString(
graph_def, float_format=''))
else:
file_io.atomic_write_string_to_file(path, graph_def.SerializeToString())
return path
首先判断是否存在文件夹,如果不存在则:file_io.recursive_create_dir(logdir)
。
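继续向下走,write_graph调用的atomic_write_string_to_file会先通过write_string_to_file把内容写入临时文件,再将其重命名为目标文件,因此我们提取出write_string_to_file: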
def write_string_to_file(filename, file_content):
"""Writes a string to a given file.
Args:
filename: string, path to a file
file_content: string, contents that need to be written to the file
Raises:
errors.OpError: If there are errors during the operation.
"""
with FileIO(filename, mode="w") as f:
f.write(file_content)
对于FileIO:
The constructor takes the following arguments:
name: name of the file
mode: one of 'r', 'w', 'a', 'r+', 'w+', 'a+'. Append 'b' for bytes mode.
Can be used as an iterator to iterate over lines in the file.
The default buffer size used for the BufferedInputStream used for reading
the file line by line is 1024 * 512 bytes.
"""
def write(self, file_content):
"""Writes file_content to the file. Appends to the end of the file."""
self._prewrite_check()
pywrap_tensorflow.AppendToFile(
compat.as_bytes(file_content), self._writable_file)
这里调用write
函数分为两步:
- 第一步
self._prewrite_check()
我们查看底层相关代码:
def _prewrite_check(self):
if not self._writable_file:
if not self._write_check_passed:
raise errors.PermissionDeniedError(None, None,
"File isn't open for writing")
self._writable_file = pywrap_tensorflow.CreateWritableFile(
compat.as_bytes(self.__name), compat.as_bytes(self.__mode))
前面先判断文件是否允许写入(若未通过写检查则抛出PermissionDeniedError),然后直接进入:
self._writable_file = pywrap_tensorflow.CreateWritableFile(
compat.as_bytes(self.__name), compat.as_bytes(self.__mode))
def CreateWritableFile(filename, mode):
return _pywrap_tensorflow_internal.CreateWritableFile(filename, mode)
CreateWritableFile = _pywrap_tensorflow_internal.CreateWritableFile
这里关键的函数是:CreateWritableFile
定位到底层C++:在file_io.i文件(SWIG接口文件)中
tensorflow::WritableFile* CreateWritableFile(
const string& filename, const string& mode, TF_Status* status) {
std::unique_ptr<tensorflow::WritableFile> file;
tensorflow::Status s;
if (mode.find("a") != std::string::npos) {
s = tensorflow::Env::Default()->NewAppendableFile(filename, &file);
} else {
s = tensorflow::Env::Default()->NewWritableFile(filename, &file);
}
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return nullptr;
}
return file.release();
}
使用a
进入s = tensorflow::Env::Default()->NewAppendableFile(filename, &file);
Status PosixFileSystem::NewAppendableFile(
const string& fname, std::unique_ptr<WritableFile>* result) {
string translated_fname = TranslateName(fname);
Status s;
FILE* f = fopen(translated_fname.c_str(), "a");
if (f == nullptr) {
s = IOError(fname, errno);
} else {
result->reset(new PosixWritableFile(translated_fname, f));
}
return s;
}
其他的进入s = tensorflow::Env::Default()->NewWritableFile(filename, &file);
Status PosixFileSystem::NewWritableFile(const string& fname,
std::unique_ptr<WritableFile>* result) {
string translated_fname = TranslateName(fname);
Status s;
FILE* f = fopen(translated_fname.c_str(), "w");
if (f == nullptr) {
s = IOError(fname, errno);
} else {
result->reset(new PosixWritableFile(translated_fname, f));
}
return s;
}
w 打开只写文件,若文件存在则文件长度清为0,即该文件内容会消失。若文件不存在则建立该文件。
w+ 打开可读写文件,若文件存在则文件长度清为零,即该文件内容会消失。若文件不存在则建立该文件。
也就是说,用w打开的文件只能写不能读;用w+打开的文件既可写也可读。
均在PosixFileSystem
中。
- 第二步
pywrap_tensorflow.AppendToFile(compat.as_bytes(file_content), self._writable_file)
第一步已经把原本为None的self._writable_file更新为当前所操作的文件,之后将file_content与_writable_file作为参数传入:
def AppendToFile(file_content, file):
return _pywrap_tensorflow_internal.AppendToFile(file_content, file)
AppendToFile = _pywrap_tensorflow_internal.AppendToFile
在file_io.i
中存在AppendToFile
底层源码:
void AppendToFile(const string& file_content, tensorflow::WritableFile* file,
TF_Status* status) {
tensorflow::Status s = file->Append(file_content);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
}
tensorflow::Status s = file->Append(file_content);
Status Append(StringPiece data) override {
size_t r = fwrite(data.data(), 1, data.size(), file_);
if (r != data.size()) {
return IOError(filename_, errno);
}
return Status::OK();
}