From 3d7ff13d2c6213bf4978f4e4a115904e10f491c8 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Wed, 13 Feb 2019 18:11:40 -0600 Subject: [PATCH] change constant creation method. --- src/TensorFlowNET.Core/Framework/tf.ops.cs | 12 ++++ .../Operations/Losses/losses_impl.py.cs | 10 +++ .../Operations/array_ops.py.cs | 5 +- .../Sessions/BaseSession.cs | 40 ++++++----- src/TensorFlowNET.Core/Tensors/constant_op.cs | 2 +- src/TensorFlowNET.Core/Tensors/tf.constant.cs | 20 +++++- src/TensorFlowNET.Core/Train/Saving/Saver.cs | 35 +++++++++- .../Train/Saving/checkpoint_management.py.cs | 19 ++++++ .../Train/Saving/saver.py.cs | 25 ++++++- test/TensorFlowNET.Examples/MetaGraph.cs | 29 ++++++++ .../TensorFlowNET.Examples.csproj | 4 -- .../python/meta_graph.py | 67 +++++++++++++++++++ 12 files changed, 240 insertions(+), 28 deletions(-) create mode 100644 src/TensorFlowNET.Core/Framework/tf.ops.cs create mode 100644 src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs create mode 100644 test/TensorFlowNET.Examples/MetaGraph.cs create mode 100644 test/TensorFlowNET.Examples/python/meta_graph.py diff --git a/src/TensorFlowNET.Core/Framework/tf.ops.cs b/src/TensorFlowNET.Core/Framework/tf.ops.cs new file mode 100644 index 00000000..a38caee4 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/tf.ops.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public static partial class tf + { + public static object get_collection(string key, string scope = "") => get_default_graph() + .get_collection(key, scope: scope); + } +} diff --git a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs new file mode 100644 index 00000000..e057ba01 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations.Losses +{ + class 
losses_impl + { + } +} diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index f68d7747..38ca2799 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -82,9 +82,9 @@ namespace Tensorflow return shape_internal(input, name, optimize: true, out_type: out_type); } - public static Tensor size(Tensor input, string name = "", TF_DataType out_type = TF_DataType.TF_INT32) + public static Tensor size(Tensor input, string name = "", bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) { - return size_internal(input, name, optimize: true, out_type: out_type); + return size_internal(input, name, optimize: optimize, out_type: out_type); } private static Tensor shape_internal(Tensor input, string name = "", bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) @@ -132,6 +132,7 @@ namespace Tensorflow else { // result = gen_array_ops.shape(); + throw new NotImplementedException("array_ops.size_internal"); } return null; diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 656ab344..9748e33f 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -46,28 +46,36 @@ namespace Tensorflow var feed_dict_tensor = new Dictionary(); var feed_map = new Dictionary(); + Func> feed_fn = (item) => + { + return new (object, object)[] { (item.Key, item.Value) }; + }; + // Validate and process feed_dict. 
if (feed_dict != null) { - foreach(var subfeed in feed_dict) + foreach (var feed in feed_dict) { - var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); - var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); - switch(subfeed.Value) + foreach (var (subfeed, subfeed_val) in feed_fn(feed)) { - case float floatVal: - feed_dict_tensor[subfeed_t] = (NDArray)floatVal; - break; - case int intVal: - feed_dict_tensor[subfeed_t] = (NDArray)intVal; - break; - case string str: - feed_dict_tensor[subfeed_t] = (NDArray)str; - break; - default: - throw new NotImplementedException("_run subfeed"); + var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); + var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); + switch (subfeed_val) + { + case float floatVal: + feed_dict_tensor[subfeed_t] = (NDArray)floatVal; + break; + case int intVal: + feed_dict_tensor[subfeed_t] = (NDArray)intVal; + break; + case string str: + feed_dict_tensor[subfeed_t] = (NDArray)str; + break; + default: + throw new NotImplementedException("_run subfeed"); + } + feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); } - feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); } } diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index f65d71c4..c4af698e 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -24,7 +24,7 @@ namespace Tensorflow return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true); } - private static Tensor _constant_impl(object value, TF_DataType dtype, int[] shape, string name, bool verify_shape, bool allow_broadcast) + public static Tensor _constant_impl(object value, TF_DataType dtype, int[] shape, string name, bool verify_shape, bool allow_broadcast) { if (tf.context.executing_eagerly()) { diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs 
b/src/TensorFlowNET.Core/Tensors/tf.constant.cs index 93bbc459..b3896f20 100644 --- a/src/TensorFlowNET.Core/Tensors/tf.constant.cs +++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs @@ -7,8 +7,26 @@ namespace Tensorflow { public static partial class tf { - public static Tensor constant(NDArray nd, string name = "Const") => constant_op.constant(nd, name: name); + // public static Tensor constant(NDArray nd, string name = "Const") => constant_op.constant(nd, name: name); + + public static Tensor constant(object value, + TF_DataType dtype = TF_DataType.DtInvalid, + int[] shape = null, + string name = "Const", + bool verify_shape = false) => constant_op._constant_impl(value, + dtype, + shape, + name, + verify_shape: verify_shape, + allow_broadcast: false); public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") => array_ops.zeros(shape, dtype, name); + + public static Tensor size(Tensor input, + string name = "", + TF_DataType out_type = TF_DataType.TF_INT32) => array_ops.size(input, + name, + optimize: true, + out_type: out_type); } } diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs index 5d436e00..4cdc85a5 100644 --- a/src/TensorFlowNET.Core/Train/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs @@ -55,6 +55,7 @@ namespace Tensorflow _keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours; _name = name; _restore_sequentially = restore_sequentially; + _saver_def = saver_def; _builder = builder; _is_built = false; _allow_empty = allow_empty; @@ -122,7 +123,7 @@ namespace Tensorflow } else if (_saver_def != null && !string.IsNullOrEmpty(_name)) { - throw new NotImplementedException(""); + throw new NotImplementedException("Saver._build"); } _check_saver_def(); @@ -200,6 +201,38 @@ namespace Tensorflow return saver._import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope); } + /// + /// Restores previously saved 
variables. + /// + /// This method runs the ops added by the constructor for restoring variables. + /// It requires a session in which the graph was launched. The variables to + /// restore do not have to have been initialized, as restoring is itself a way + /// to initialize variables. + /// + /// A `Session` to use to restore the parameters. None in eager mode. + /// Path where parameters were previously saved. + public void restore(Session sess, string save_path) + { + if (_is_empty) + return; + + if (string.IsNullOrEmpty(save_path)) + throw new ValueError("Can't load save_path when it is None."); + + if (!checkpoint_management.checkpoint_exists(save_path)) + throw new ValueError($"The passed save_path is not a valid checkpoint: {save_path}"); + + Console.WriteLine($"Restoring parameters from {save_path}"); + + if (tf.context.executing_eagerly()) + throw new NotImplementedException("Saver.restore eager"); // eager restore is not ported yet; fail loudly instead of silently restoring nothing + else + sess.run(_saver_def.RestoreOpName, new FeedItem[] + { + new FeedItem(_saver_def.FilenameTensorName, save_path) + }); + } + /// /// Writes `MetaGraphDef` to save_path/filename. /// diff --git a/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs index f301a027..4bd224da 100644 --- a/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs +++ b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; +using static Tensorflow.SaverDef.Types; namespace Tensorflow { @@ -105,5 +106,23 @@ namespace Tensorflow string suffixed_filename = basename + "." 
+ meta_graph_suffix; return suffixed_filename; } + + public static bool checkpoint_exists(string checkpoint_prefix) + { + string pathname = _prefix_to_checkpoint_path(checkpoint_prefix, CheckpointFormatVersion.V2); + if (File.Exists(pathname)) + return true; + else if (File.Exists(checkpoint_prefix)) + return true; + else + return false; + } + + private static string _prefix_to_checkpoint_path(string prefix, CheckpointFormatVersion format_version) + { + if (format_version == CheckpointFormatVersion.V2) + return prefix + ".index"; + return prefix; + } } } diff --git a/src/TensorFlowNET.Core/Train/Saving/saver.py.cs b/src/TensorFlowNET.Core/Train/Saving/saver.py.cs index bb22702f..d5d1ff47 100644 --- a/src/TensorFlowNET.Core/Train/Saving/saver.py.cs +++ b/src/TensorFlowNET.Core/Train/Saving/saver.py.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow @@ -13,25 +14,43 @@ namespace Tensorflow { var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file); - var imported_vars = meta_graph.import_scoped_meta_graph_with_return_elements( + var meta = meta_graph.import_scoped_meta_graph_with_return_elements( meta_graph_def, clear_devices: clear_devices, import_scope: import_scope, return_elements: return_elements); + var (imported_vars, imported_return_elements) = meta; + var saver = _create_saver_from_imported_meta_graph( meta_graph_def, import_scope, imported_vars); return (saver, null); } + /// + /// Return a saver for restoring variable values to an imported MetaGraph. + /// + /// + /// + /// + /// public static Saver _create_saver_from_imported_meta_graph(MetaGraphDef meta_graph_def, string import_scope, - (Dictionary, ITensorOrOperation[]) imported_vars) + Dictionary imported_vars) { if(meta_graph_def.SaverDef != null) { - throw new NotImplementedException("_create_saver_from_imported_meta_graph"); + // Infer the scope that is prepended by `import_scoped_meta_graph`. 
+ string scope = import_scope; + var var_names = imported_vars.Keys.ToArray(); + if(var_names.Length > 0) + { + var sample_key = var_names[0]; + var sample_var = imported_vars[sample_key]; + scope = sample_var.name.Substring(0, Math.Max(0, sample_var.name.Length - sample_key.Length)); // keep the prefix that precedes the original key, i.e. the prepended import scope + } + return new Saver(saver_def: meta_graph_def.SaverDef, name: scope); } else { diff --git a/test/TensorFlowNET.Examples/MetaGraph.cs b/test/TensorFlowNET.Examples/MetaGraph.cs new file mode 100644 index 00000000..d257e712 --- /dev/null +++ b/test/TensorFlowNET.Examples/MetaGraph.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.Examples +{ + public class MetaGraph : Python, IExample + { + public void Run() + { + ImportMetaGraph("my-save-dir/"); + } + + private void ImportMetaGraph(string dir) + { + with(tf.Session(), sess => + { + var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta"); + new_saver.restore(sess, dir + "my-model-10000"); + var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); + var batch_size = tf.size(labels); + var logits = (tf.get_collection("logits") as List)[0]; + var loss = tf.losses.sparse_softmax_cross_entropy(labels, + logits); + }); + } + } +} diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index ce9ebf07..5ed00751 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -14,8 +14,4 @@ - - - - diff --git a/test/TensorFlowNET.Examples/python/meta_graph.py b/test/TensorFlowNET.Examples/python/meta_graph.py new file mode 100644 index 00000000..cb426091 --- /dev/null +++ b/test/TensorFlowNET.Examples/python/meta_graph.py @@ -0,0 +1,67 @@ + +import tensorflow as tf +import math + +# Creates an inference graph. 
+# Hidden 1 +images = tf.constant(1.2, tf.float32, shape=[100, 28]) +with tf.name_scope("hidden1"): + weights = tf.Variable( + tf.truncated_normal([28, 128], + stddev=1.0 / math.sqrt(float(28))), + name="weights") + biases = tf.Variable(tf.zeros([128]), + name="biases") + hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases) +# Hidden 2 +with tf.name_scope("hidden2"): + weights = tf.Variable( + tf.truncated_normal([128, 32], + stddev=1.0 / math.sqrt(float(128))), + name="weights") + biases = tf.Variable(tf.zeros([32]), + name="biases") + hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases) +# Linear +with tf.name_scope("softmax_linear"): + weights = tf.Variable( + tf.truncated_normal([32, 10], + stddev=1.0 / math.sqrt(float(32))), + name="weights") + biases = tf.Variable(tf.zeros([10]), + name="biases") + logits = tf.matmul(hidden2, weights) + biases + tf.add_to_collection("logits", logits) + +init_all_op = tf.global_variables_initializer() + +with tf.Session() as sess: + # Initializes all the variables. + sess.run(init_all_op) + # Runs to logit. + sess.run(logits) + # Creates a saver. + saver0 = tf.train.Saver() + saver0.save(sess, 'my-save-dir/my-model-10000') + # Generates MetaGraphDef. + saver0.export_meta_graph('my-save-dir/my-model-10000.meta') + + +# Then later import it and extend it to a training graph. +with tf.Session() as sess: + new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta') + new_saver.restore(sess, 'my-save-dir/my-model-10000') + # Adds loss and train. + labels = tf.constant(0, tf.int32, shape=[100], name="labels") + batch_size = tf.size(labels) + logits = tf.get_collection("logits")[0] + loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, + logits=logits) + + tf.summary.scalar('loss', loss) + # Creates the gradient descent optimizer with the given learning rate. + optimizer = tf.train.GradientDescentOptimizer(0.01) + + # Runs train_op. + train_op = optimizer.minimize(loss) + sess.run(train_op)