From 23edfd09f8d8e7435b23e5e9b15a48a34a5aa390 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 7 Sep 2019 09:55:44 -0500 Subject: [PATCH] _safe_initial_value_from_op #360 --- src/TensorFlowNET.Core/Graphs/Graph.cs | 2 +- .../Train/ExponentialMovingAverage.cs | 5 +++++ src/TensorFlowNET.Core/Train/moving_averages.cs | 6 +++--- src/TensorFlowNET.Core/Variables/RefVariable.cs | 12 ++++++++---- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index cad7a5a6..fa806156 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -242,7 +242,7 @@ namespace Tensorflow throw new RuntimeError("Graph is finalized and cannot be modified."); } - public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, + public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary attrs = null, OpDef op_def = null) { diff --git a/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs b/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs index 64068c18..2d4effca 100644 --- a/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs +++ b/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs @@ -49,6 +49,11 @@ namespace Tensorflow.Train ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var); _averages[var] = avg; } + else + { + // avg = slot_creator.create_zeros_slot( + throw new NotImplementedException(""); + } } return tf_with(ops.name_scope(name), scope => diff --git a/src/TensorFlowNET.Core/Train/moving_averages.cs b/src/TensorFlowNET.Core/Train/moving_averages.cs index 5aee7901..d77367f3 100644 --- a/src/TensorFlowNET.Core/Train/moving_averages.cs +++ b/src/TensorFlowNET.Core/Train/moving_averages.cs @@ -19,14 +19,14 @@ namespace Tensorflow.Train public static Tensor assign_moving_average(RefVariable variable, RefVariable value, Tensor decay, bool zero_debias = true, string name = null) { - tf_with(ops.name_scope(name, "", new { variable, value, decay }), scope => + return tf_with(ops.name_scope(name, "AssignMovingAvg", new { variable, value, decay }), scope => { decay = ops.convert_to_tensor(1.0f - decay, name: "decay"); if (decay.dtype != variable.dtype.as_base_dtype()) decay = math_ops.cast(decay, variable.dtype.as_base_dtype()); - }); - throw new NotImplementedException("assign_moving_average"); + return decay; + }); } } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 1f7ca41b..97e1d0f4 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -273,10 +273,14 @@ namespace Tensorflow new_op_type = "Switch"; var new_op_name = op.node_def.Name + "_" + name; new_op_name = new_op_name.Replace(":", "_"); - var attrs = new Dictionary(); - attrs[op.node_def.Name] = op.node_def.Attr.ElementAt(0).Value; - /*return op.graph.create_op(new_op_type, new_op_inputs.ToArray(), op._output_types, - name: new_op_name, attrs: attrs);*/ + + // Convert attr values to AttrValue protos. + var attr_protos = new Dictionary(); + foreach (var attr_def in op.node_def.Attr) + attr_protos[attr_def.Key] = attr_def.Value; + + return op.graph.create_op(new_op_type, new_op_inputs.ToArray(), op._output_types, + name: new_op_name, attrs: attr_protos); } return op; }