| @@ -1,6 +1,6 @@ | |||||
|  |  | ||||
| **TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. | |||||
| **TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. | |||||
| [](https://gitter.im/sci-sharp/community) | [](https://gitter.im/sci-sharp/community) | ||||
| [](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) | [](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) | ||||
| @@ -34,40 +34,15 @@ PM> Install-Package TensorFlow.NET | |||||
| ### Install tensorflow binary | ### Install tensorflow binary | ||||
| ### For CPU version | ### For CPU version | ||||
| PM> Install-Package SciSharp.TensorFlow.Redist | PM> Install-Package SciSharp.TensorFlow.Redist | ||||
| ### For GPU version (CUDA and cuDNN are required) | ### For GPU version (CUDA and cuDNN are required) | ||||
| PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU | PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU | ||||
| ``` | ``` | ||||
| Import TF.NET. | |||||
| ```cs | |||||
| using Tensorflow; | |||||
| ``` | |||||
| Add two constants: | |||||
| ```cs | |||||
| // Create a Constant op | |||||
| var a = tf.constant(4.0f); | |||||
| var b = tf.constant(5.0f); | |||||
| var c = tf.add(a, b); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var o = sess.run(c); | |||||
| } | |||||
| ``` | |||||
| Import TF.NET in your project. | |||||
| Feed placeholder: | |||||
| ```cs | ```cs | ||||
| // Create a placeholder op | |||||
| var a = tf.placeholder(tf.float32); | |||||
| var b = tf.placeholder(tf.float32); | |||||
| var c = tf.add(a, b); | |||||
| using(var sess = tf.Session()) | |||||
| { | |||||
| var o = sess.run(c, new FeedItem(a, 3.0f), new FeedItem(b, 2.0f)); | |||||
| } | |||||
| using static Tensorflow.Binding; | |||||
| ``` | ``` | ||||
| Linear Regression: | Linear Regression: | ||||
| @@ -91,39 +66,40 @@ var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); | |||||
| var init = tf.global_variables_initializer(); | var init = tf.global_variables_initializer(); | ||||
| // Start training | // Start training | ||||
| with(tf.Session(), sess => | |||||
| using(tf.Session()) | |||||
| { | { | ||||
| // Run the initializer | // Run the initializer | ||||
| sess.run(init); | sess.run(init); | ||||
| // Fit all training data | // Fit all training data | ||||
| for (int epoch = 0; epoch < training_epochs; epoch++) | for (int epoch = 0; epoch < training_epochs; epoch++) | ||||
| { | { | ||||
| foreach (var (x, y) in zip<float>(train_X, train_Y)) | foreach (var (x, y) in zip<float>(train_X, train_Y)) | ||||
| sess.run(optimizer, new FeedItem(X, x), new FeedItem(Y, y)); | |||||
| sess.run(optimizer, (X, x), (Y, y)); | |||||
| // Display logs per epoch step | // Display logs per epoch step | ||||
| if ((epoch + 1) % display_step == 0) | if ((epoch + 1) % display_step == 0) | ||||
| { | { | ||||
| var c = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y)); | |||||
| var c = sess.run(cost, (X, train_X), (Y, train_Y)); | |||||
| Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | ||||
| } | } | ||||
| Console.WriteLine("Optimization Finished!"); | |||||
| var training_cost = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y)); | |||||
| Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | |||||
| // Testing example | |||||
| var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); | |||||
| var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | |||||
| Console.WriteLine("Testing... (Mean square loss Comparison)"); | |||||
| var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), new FeedItem(X, test_X), new FeedItem(Y, test_Y)); | |||||
| Console.WriteLine($"Testing cost={testing_cost}"); | |||||
| var diff = Math.Abs((float)training_cost - (float)testing_cost); | |||||
| Console.WriteLine($"Absolute mean square loss difference: {diff}"); | |||||
| } | } | ||||
| Console.WriteLine("Optimization Finished!"); | |||||
| var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); | |||||
| Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | |||||
| // Testing example | |||||
| var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); | |||||
| var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | |||||
| Console.WriteLine("Testing... (Mean square loss Comparison)"); | |||||
| var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), | |||||
| (X, test_X), (Y, test_Y)); | |||||
| Console.WriteLine($"Testing cost={testing_cost}"); | |||||
| var diff = Math.Abs((float)training_cost - (float)testing_cost); | |||||
| Console.WriteLine($"Absolute mean square loss difference: {diff}"); | |||||
| return diff < 0.01; | |||||
| }); | }); | ||||
| ``` | ``` | ||||
| @@ -14,11 +14,16 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using static Tensorflow.ops; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class tensorflow | public partial class tensorflow | ||||
| { | { | ||||
| public graph_util_impl graph_util => new graph_util_impl(); | public graph_util_impl graph_util => new graph_util_impl(); | ||||
| public GraphKeys GraphKeys { get; } = new GraphKeys(); | |||||
| public Graph get_default_graph() | public Graph get_default_graph() | ||||
| { | { | ||||
| return ops.get_default_graph(); | return ops.get_default_graph(); | ||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -22,7 +23,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public VariableV1[] global_variables(string scope = null) | public VariableV1[] global_variables(string scope = null) | ||||
| { | { | ||||
| return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>) | |||||
| return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>) | |||||
| .ToArray(); | .ToArray(); | ||||
| } | } | ||||
| @@ -95,7 +95,7 @@ namespace Tensorflow | |||||
| break; | break; | ||||
| case KindOneofCase.BytesList: | case KindOneofCase.BytesList: | ||||
| //var proto_type = ops.get_collection_proto_type(key) | //var proto_type = ops.get_collection_proto_type(key) | ||||
| if (ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) | |||||
| if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) | |||||
| { | { | ||||
| foreach (var value in col.Value.BytesList.Value) | foreach (var value in col.Value.BytesList.Value) | ||||
| { | { | ||||
| @@ -146,7 +146,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| var variables = graph.get_collection<VariableV1>(ops.GraphKeys.GLOBAL_VARIABLES, | |||||
| var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, | |||||
| scope: scope_to_prepend_to_names); | scope: scope_to_prepend_to_names); | ||||
| var var_list = new Dictionary<string, VariableV1>(); | var var_list = new Dictionary<string, VariableV1>(); | ||||
| variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | ||||
| @@ -180,7 +180,7 @@ namespace Tensorflow | |||||
| var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
| var var_list = new Dictionary<string, RefVariable>(); | var var_list = new Dictionary<string, RefVariable>(); | ||||
| var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>; | |||||
| var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>; | |||||
| if (variables != null) | if (variables != null) | ||||
| { | { | ||||
| @@ -81,7 +81,7 @@ namespace Tensorflow.Layers | |||||
| // Update global default collections. | // Update global default collections. | ||||
| _add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS }); | |||||
| _add_elements_to_collection(_updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| @@ -152,9 +152,9 @@ namespace Tensorflow.Operations | |||||
| public (T, Tensor) BuildCondBranch<T>(Func<T> fn) | public (T, Tensor) BuildCondBranch<T>(Func<T> fn) | ||||
| { | { | ||||
| // Add the subgraph defined by fn() to the graph. | // Add the subgraph defined by fn() to the graph. | ||||
| var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
| var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||||
| var original_result = fn(); | var original_result = fn(); | ||||
| var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||||
| //TODO: port this chunck of missing code: | //TODO: port this chunck of missing code: | ||||
| /* | /* | ||||
| @@ -191,9 +191,9 @@ namespace Tensorflow.Operations | |||||
| public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn) | public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn) | ||||
| { | { | ||||
| // Add the subgraph defined by fn() to the graph. | // Add the subgraph defined by fn() to the graph. | ||||
| var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
| var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||||
| var original_result = fn(); | var original_result = fn(); | ||||
| var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||||
| switch (original_result) | switch (original_result) | ||||
| { | { | ||||
| @@ -195,7 +195,7 @@ namespace Tensorflow.Operations | |||||
| // their associated TensorArrays for calling the body. | // their associated TensorArrays for calling the body. | ||||
| var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | ||||
| var body_result = body(packed_vars_for_body[0]); | var body_result = body(packed_vars_for_body[0]); | ||||
| var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||||
| // Store body_result to keep track of TensorArrays returned by body | // Store body_result to keep track of TensorArrays returned by body | ||||
| var original_body_result = new[] { body_result }; | var original_body_result = new[] { body_result }; | ||||
| @@ -2,7 +2,7 @@ | |||||
| { | { | ||||
| public class Util | public class Util | ||||
| { | { | ||||
| public static void add_loss(Tensor loss, string loss_collection = ops.GraphKeys.LOSSES) | |||||
| public static void add_loss(Tensor loss, string loss_collection = "losses") | |||||
| { | { | ||||
| if (!string.IsNullOrEmpty(loss_collection)) | if (!string.IsNullOrEmpty(loss_collection)) | ||||
| ops.add_to_collection(loss_collection, loss); | ops.add_to_collection(loss_collection, loss); | ||||
| @@ -22,7 +22,7 @@ namespace Tensorflow | |||||
| public class LossesImpl | public class LossesImpl | ||||
| { | { | ||||
| public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, | public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, | ||||
| string loss_collection = ops.GraphKeys.LOSSES, string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | |||||
| string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | |||||
| { | { | ||||
| return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate | return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate | ||||
| { | { | ||||
| @@ -101,7 +101,7 @@ namespace Tensorflow | |||||
| Tensor logits, | Tensor logits, | ||||
| float weights = 1.0f, | float weights = 1.0f, | ||||
| string scope = null, | string scope = null, | ||||
| string loss_collection= ops.GraphKeys.LOSSES, | |||||
| string loss_collection= "losses", | |||||
| string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | ||||
| { | { | ||||
| return tf_with(ops.name_scope(scope, | return tf_with(ops.name_scope(scope, | ||||
| @@ -431,8 +431,8 @@ namespace Tensorflow | |||||
| merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges); | merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges); | ||||
| ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||||
| return merges[0]; | return merges[0]; | ||||
| }); | }); | ||||
| @@ -479,8 +479,8 @@ namespace Tensorflow | |||||
| merges = _convert_flows_to_tensorarrays(orig_res_t, merges); | merges = _convert_flows_to_tensorarrays(orig_res_t, merges); | ||||
| ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||||
| return merges; | return merges; | ||||
| }); | }); | ||||
| @@ -596,7 +596,7 @@ namespace Tensorflow | |||||
| swap_memory: swap_memory); | swap_memory: swap_memory); | ||||
| if (loop_context.outer_context == null) | if (loop_context.outer_context == null) | ||||
| ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context); | |||||
| ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); | |||||
| var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, | var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, | ||||
| return_same_structure); | return_same_structure); | ||||
| @@ -33,11 +33,11 @@ namespace Tensorflow.Summaries | |||||
| { | { | ||||
| var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }, default_name: "HistogramSummary"); | var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }, default_name: "HistogramSummary"); | ||||
| var val = gen_logging_ops.histogram_summary(tag: tag, values: tensor, name: scope); | var val = gen_logging_ops.histogram_summary(tag: tag, values: tensor, name: scope); | ||||
| collect(val, collections?.ToList(), new List<string> { ops.GraphKeys.SUMMARIES }); | |||||
| collect(val, collections?.ToList(), new List<string> { tf.GraphKeys.SUMMARIES }); | |||||
| return val; | return val; | ||||
| } | } | ||||
| public Tensor merge_all(string key = ops.GraphKeys.SUMMARIES, string scope= null, string name= null) | |||||
| public Tensor merge_all(string key = "summaries", string scope= null, string name= null) | |||||
| { | { | ||||
| var summary_ops = ops.get_collection(key, scope: scope); | var summary_ops = ops.get_collection(key, scope: scope); | ||||
| if (summary_ops == null) | if (summary_ops == null) | ||||
| @@ -67,7 +67,7 @@ namespace Tensorflow.Summaries | |||||
| { | { | ||||
| var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }); | var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }); | ||||
| var val = gen_logging_ops.scalar_summary(tags: tag, values: tensor, name: scope); | var val = gen_logging_ops.scalar_summary(tags: tag, values: tensor, name: scope); | ||||
| collect(val, collections?.ToList(), new List<string> { ops.GraphKeys.SUMMARIES }); | |||||
| collect(val, collections?.ToList(), new List<string> { tf.GraphKeys.SUMMARIES }); | |||||
| return val; | return val; | ||||
| } | } | ||||
| @@ -198,7 +198,7 @@ namespace Tensorflow | |||||
| if (!tf.context.executing_eagerly()) | if (!tf.context.executing_eagerly()) | ||||
| { | { | ||||
| var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>; | |||||
| var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>; | |||||
| if (train_op != null && train_op.Contains(apply_updates)) | if (train_op != null && train_op.Contains(apply_updates)) | ||||
| train_op.Add(apply_updates); | train_op.Add(apply_updates); | ||||
| } | } | ||||
| @@ -359,7 +359,7 @@ namespace Tensorflow | |||||
| var tmp = variables.trainable_variables(); | var tmp = variables.trainable_variables(); | ||||
| var vars = ops.get_collection<RefVariable>(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||||
| var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||||
| switch (tmp) | switch (tmp) | ||||
| { | { | ||||
| case List<RefVariable> values: | case List<RefVariable> values: | ||||
| @@ -370,7 +370,7 @@ namespace Tensorflow | |||||
| break; | break; | ||||
| } | } | ||||
| var_list = var_list.Concat(ops.get_collection<RefVariable>(ops.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||||
| var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||||
| var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); | var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); | ||||
| var var_refs = processors.Select(x => x.target()).ToArray(); | var var_refs = processors.Select(x => x.target()).ToArray(); | ||||
| @@ -121,7 +121,7 @@ namespace Tensorflow | |||||
| if(collections == null) | if(collections == null) | ||||
| { | { | ||||
| collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES }; | |||||
| collections = new List<string> { tf.GraphKeys.GLOBAL_VARIABLES }; | |||||
| } | } | ||||
| // Store the graph key so optimizers know how to only retrieve variables from | // Store the graph key so optimizers know how to only retrieve variables from | ||||
| @@ -129,8 +129,8 @@ namespace Tensorflow | |||||
| _graph_key = ops.get_default_graph().graph_key; | _graph_key = ops.get_default_graph().graph_key; | ||||
| _trainable = trainable; | _trainable = trainable; | ||||
| if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES)) | |||||
| collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); | |||||
| if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) | |||||
| collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | |||||
| ops.init_scope(); | ops.init_scope(); | ||||
| var values = init_from_fn ? new object[0] : new object[] { initial_value }; | var values = init_from_fn ? new object[0] : new object[] { initial_value }; | ||||
| @@ -17,6 +17,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -28,7 +29,7 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static object trainable_variables() | public static object trainable_variables() | ||||
| { | { | ||||
| return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); | |||||
| return ops.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -40,11 +41,11 @@ namespace Tensorflow | |||||
| { | { | ||||
| var all = new List<VariableV1>(); | var all = new List<VariableV1>(); | ||||
| var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| var collection = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| if(collection != null) | if(collection != null) | ||||
| all.AddRange(collection as List<VariableV1>); | all.AddRange(collection as List<VariableV1>); | ||||
| collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope); | |||||
| collection = ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope); | |||||
| if (collection != null) | if (collection != null) | ||||
| all.AddRange(collection as List<VariableV1>); | all.AddRange(collection as List<VariableV1>); | ||||
| @@ -64,7 +65,7 @@ namespace Tensorflow | |||||
| /// <returns>A list of `Variable` objects.</returns> | /// <returns>A list of `Variable` objects.</returns> | ||||
| public static List<VariableV1> global_variables(string scope = null) | public static List<VariableV1> global_variables(string scope = null) | ||||
| { | { | ||||
| var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| var result = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| return result == null ? new List<VariableV1>() : result as List<VariableV1>; | return result == null ? new List<VariableV1>() : result as List<VariableV1>; | ||||
| } | } | ||||
| @@ -27,57 +27,57 @@ namespace Tensorflow | |||||
| /// specified, but it is also possible to pass an explicit list of | /// specified, but it is also possible to pass an explicit list of | ||||
| /// variables. | /// variables. | ||||
| /// </summary> | /// </summary> | ||||
| public static class GraphKeys | |||||
| public class GraphKeys | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// the subset of `Variable` objects that will be trained by an optimizer. | /// the subset of `Variable` objects that will be trained by an optimizer. | ||||
| /// </summary> | /// </summary> | ||||
| public static string TRAINABLE_VARIABLES = "trainable_variables"; | |||||
| public string TRAINABLE_VARIABLES = "trainable_variables"; | |||||
| /// <summary> | /// <summary> | ||||
| /// Trainable resource-style variables. | /// Trainable resource-style variables. | ||||
| /// </summary> | /// </summary> | ||||
| public static string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; | |||||
| public string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key for streaming model ports. | /// Key for streaming model ports. | ||||
| /// </summary> | /// </summary> | ||||
| public static string _STREAMING_MODEL_PORTS = "streaming_model_ports"; | |||||
| public string _STREAMING_MODEL_PORTS = "streaming_model_ports"; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect losses | /// Key to collect losses | ||||
| /// </summary> | /// </summary> | ||||
| public const string LOSSES = "losses"; | |||||
| public string LOSSES = "losses"; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect Variable objects that are global (shared across machines). | /// Key to collect Variable objects that are global (shared across machines). | ||||
| /// Default collection for all variables, except local ones. | /// Default collection for all variables, except local ones. | ||||
| /// </summary> | /// </summary> | ||||
| public static string GLOBAL_VARIABLES = "variables"; | |||||
| public string GLOBAL_VARIABLES = "variables"; | |||||
| public static string TRAIN_OP = "train_op"; | |||||
| public string TRAIN_OP = "train_op"; | |||||
| public static string GLOBAL_STEP = GLOBAL_STEP = "global_step"; | |||||
| public string GLOBAL_STEP = "global_step"; | |||||
| public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; | |||||
| public string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | ||||
| /// </summary> | /// </summary> | ||||
| public static string SAVEABLE_OBJECTS = "saveable_objects"; | |||||
| public string SAVEABLE_OBJECTS = "saveable_objects"; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect update_ops | /// Key to collect update_ops | ||||
| /// </summary> | /// </summary> | ||||
| public static string UPDATE_OPS = "update_ops"; | |||||
| public string UPDATE_OPS = "update_ops"; | |||||
| // Key to collect summaries. | // Key to collect summaries. | ||||
| public const string SUMMARIES = "summaries"; | |||||
| public string SUMMARIES = "summaries"; | |||||
| // Used to store v2 summary names. | // Used to store v2 summary names. | ||||
| public static string _SUMMARY_COLLECTION = "_SUMMARY_V2"; | |||||
| public string _SUMMARY_COLLECTION = "_SUMMARY_V2"; | |||||
| // Key for control flow context. | // Key for control flow context. | ||||
| public static string COND_CONTEXT = "cond_context"; | |||||
| public static string WHILE_CONTEXT = "while_context"; | |||||
| public string COND_CONTEXT = "cond_context"; | |||||
| public string WHILE_CONTEXT = "while_context"; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -80,26 +80,18 @@ namespace TensorFlowNET.Examples | |||||
| for (int epoch = 0; epoch < training_epochs; epoch++) | for (int epoch = 0; epoch < training_epochs; epoch++) | ||||
| { | { | ||||
| foreach (var (x, y) in zip<float>(train_X, train_Y)) | foreach (var (x, y) in zip<float>(train_X, train_Y)) | ||||
| { | |||||
| sess.run(optimizer, | |||||
| new FeedItem(X, x), | |||||
| new FeedItem(Y, y)); | |||||
| } | |||||
| sess.run(optimizer, (X, x), (Y, y)); | |||||
| // Display logs per epoch step | // Display logs per epoch step | ||||
| if ((epoch + 1) % display_step == 0) | if ((epoch + 1) % display_step == 0) | ||||
| { | { | ||||
| var c = sess.run(cost, | |||||
| new FeedItem(X, train_X), | |||||
| new FeedItem(Y, train_Y)); | |||||
| var c = sess.run(cost, (X, train_X), (Y, train_Y)); | |||||
| Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | ||||
| } | } | ||||
| } | } | ||||
| Console.WriteLine("Optimization Finished!"); | Console.WriteLine("Optimization Finished!"); | ||||
| var training_cost = sess.run(cost, | |||||
| new FeedItem(X, train_X), | |||||
| new FeedItem(Y, train_Y)); | |||||
| var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); | |||||
| Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | ||||
| // Testing example | // Testing example | ||||
| @@ -107,8 +99,7 @@ namespace TensorFlowNET.Examples | |||||
| var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | ||||
| Console.WriteLine("Testing... (Mean square loss Comparison)"); | Console.WriteLine("Testing... (Mean square loss Comparison)"); | ||||
| var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), | var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), | ||||
| new FeedItem(X, test_X), | |||||
| new FeedItem(Y, test_Y)); | |||||
| (X, test_X), (Y, test_Y)); | |||||
| Console.WriteLine($"Testing cost={testing_cost}"); | Console.WriteLine($"Testing cost={testing_cost}"); | ||||
| var diff = Math.Abs((float)training_cost - (float)testing_cost); | var diff = Math.Abs((float)training_cost - (float)testing_cost); | ||||
| Console.WriteLine($"Absolute mean square loss difference: {diff}"); | Console.WriteLine($"Absolute mean square loss difference: {diff}"); | ||||
| @@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples.Text | |||||
| var y_one_hot = tf.one_hot(y, num_class); | var y_one_hot = tf.one_hot(y, num_class); | ||||
| loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot)); | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot)); | ||||
| var update_ops = tf.get_collection(ops.GraphKeys.UPDATE_OPS) as List<object>; | |||||
| var update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) as List<object>; | |||||
| tf_with(tf.control_dependencies(update_ops.Select(x => (Operation)x).ToArray()), delegate | tf_with(tf.control_dependencies(update_ops.Select(x => (Operation)x).ToArray()), delegate | ||||
| { | { | ||||
| var adam = tf.train.AdamOptimizer(learning_rate); | var adam = tf.train.AdamOptimizer(learning_rate); | ||||