From f226ad704f417a781168b421e95857c6e20f7bf6 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 20 Sep 2020 08:56:12 -0500 Subject: [PATCH] add overload for Layer call function, be able to input array and return array. --- .../MemoryTestingCases.cs | 17 +++ src/TensorFlowNET.Console/Program.cs | 3 + src/TensorFlowNET.Core/APIs/c_api.cs | 2 +- .../Eager/EagerTensor.Creation.cs | 2 +- .../Gradients/gradients_util.cs | 1 + src/TensorFlowNET.Core/Graphs/Graph.cs | 4 + .../Keras/Engine/Flatten.cs | 2 +- src/TensorFlowNET.Core/Keras/Engine/Layer.cs | 40 +++++- .../Keras/Layers/BatchNormalization.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/Conv.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/Dense.cs | 2 +- .../Keras/Layers/Dropout.cs | 2 +- .../Keras/Layers/Embedding.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/LSTM.cs | 4 +- .../Keras/Layers/Pooling2D.cs | 2 +- .../Keras/Layers/Rescaling.cs | 2 +- src/TensorFlowNET.Core/Layers/Layer.cs | 49 ++++++- .../Operations/ControlFlows/WhileContext.cs | 2 +- .../Operations/NnOps/BasicLSTMCell.cs | 4 +- .../Operations/NnOps/BasicRNNCell.cs | 6 +- .../Operations/NnOps/rnn.cs | 2 +- .../Operations/Operation.cs | 2 +- .../Tensorflow.Binding.csproj | 6 +- .../Tensors/Tensor.Creation.cs | 2 + src/TensorFlowNET.Core/Tensors/Tensor.cs | 1 - .../Variables/ResourceVariable.cs | 131 +++++++++--------- .../Tensorflow.Benchmark.csproj | 1 + .../Tensorflow.UnitTest.csproj | 2 +- 28 files changed, 199 insertions(+), 98 deletions(-) diff --git a/src/TensorFlowNET.Console/MemoryTestingCases.cs b/src/TensorFlowNET.Console/MemoryTestingCases.cs index f9356955..09121513 100644 --- a/src/TensorFlowNET.Console/MemoryTestingCases.cs +++ b/src/TensorFlowNET.Console/MemoryTestingCases.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using NumSharp; using static Tensorflow.Binding; namespace Tensorflow @@ -18,6 +19,22 @@ namespace Tensorflow var tensor = tf.constant(3112.0f); } }; + + public Action Constant2x3 + 
=> (iterate) => + { + var nd = np.array(new byte[,] + { + {1, 2, 3}, + {4, 5, 6} + }); + for (int i = 0; i < iterate; i++) + { + var tensor = tf.constant(nd); + var data = tensor.numpy(); + } + }; + public Action Variable => (iterate) => { diff --git a/src/TensorFlowNET.Console/Program.cs b/src/TensorFlowNET.Console/Program.cs index b8709849..e2360dff 100644 --- a/src/TensorFlowNET.Console/Program.cs +++ b/src/TensorFlowNET.Console/Program.cs @@ -15,6 +15,9 @@ namespace Tensorflow int batchSize = 1000; + // explanation of constant + mm.Execute(10, 100 * batchSize, cases.Constant2x3); + // 1 million float tensor 68M. mm.Execute(10, 100 * batchSize, cases.Constant); diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index db8f1c8f..4fb1d32e 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -43,7 +43,7 @@ namespace Tensorflow /// public partial class c_api { - public const string TensorFlowLibName = "tensorflow"; + public const string TensorFlowLibName = @"C:\Users\haipi\Documents\Projects\tensorflow\bazel-bin\tensorflow\tensorflow"; public static string StringPiece(IntPtr handle) { diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index 68ef56b8..809c4cea 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -70,8 +70,8 @@ namespace Tensorflow.Eager protected override void DisposeUnmanagedResources(IntPtr handle) { + base.DisposeUnmanagedResources(handle); //print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}"); - c_api.TF_DeleteTensor(_handle); } } } diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index d802d28d..6eec094e 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -311,6 +311,7 
@@ namespace Tensorflow while (queue.Count > 0) { var op = queue.Dequeue(); + if (reached_ops.Contains(op)) { between_ops.Add(op); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 35275d4e..9f6e87ac 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -278,7 +278,11 @@ namespace Tensorflow // after removing the trailing '/'. name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); var node_def = ops._NodeDef(op_type, name, attrs: attrs); + if (name == "rnn/while/basic_rnn_cell/MatMul" + || name == "rnn/while/basic_rnn_cell/MatMul/Enter") + { + } var input_ops = inputs.Select(x => x.op).ToArray(); var control_inputs = _control_dependencies_for_inputs(input_ops); diff --git a/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs b/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs index 45cfd8f2..6bd10151 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine _channels_first = args.DataFormat == "channels_first"; } - protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool is_training = false) { if (_channels_first) { diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index 530c4b27..2887a97b 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -121,7 +121,7 @@ namespace Tensorflow.Keras.Engine /// /// /// - public Tensor Apply(Tensor inputs, bool is_training = false, Tensor state = null) + public Tensor Apply(Tensor inputs, bool is_training = false) { Tensor outputs = null; @@ -148,7 +148,7 @@ namespace Tensorflow.Keras.Engine if (!built) MaybeBuild(inputs); - outputs = call(inputs, is_training: is_training, state: state); + outputs = call(inputs, 
is_training: is_training); outputs = _set_connectivity_metadata_(inputs, outputs); _handle_activity_regularization(inputs, outputs); @@ -161,6 +161,35 @@ namespace Tensorflow.Keras.Engine return outputs; } + public Tensor[] Apply(Tensor[] inputs, Tensor state, bool is_training = false) + { + Tensor[] outputs = null; + + callContext = callContext ?? new ThreadLocal() + { + Value = new CallContext() + }; + + var eager = tf.executing_eagerly(); + using var ctxManager = CallContext.enter(); + + string nameScope = ""; + if (eager) + nameScope = name; + else + nameScope = _name_scope(); + + tf_with(ops.name_scope(nameScope), scope => + { + if (!built) + MaybeBuild(inputs[0]); + + outputs = call(inputs, is_training: is_training, state: state); + }); + + return outputs; + } + private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs) { /*var returnOutputs = new List(); @@ -200,7 +229,12 @@ namespace Tensorflow.Keras.Engine return null; } - protected virtual Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + protected virtual Tensor call(Tensor inputs, bool is_training = false) + { + throw new NotImplementedException(""); + } + + protected virtual Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false) { throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index c8298234..a1c0ab7b 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -143,7 +143,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool is_training = false) { Tensor outputs = null; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index c85c4379..282cef9d 
100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool training = false) { var outputs = _convolution_op.__call__(inputs, kernel); if (use_bias) diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index b6258aea..9c117fd4 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool training = false) { Tensor outputs = null; var rank = inputs.rank; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs b/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs index 6449be48..b581ac62 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs @@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool is_training = false) { var output = tf_utils.smart_cond(is_training, () => tf.nn.dropout(inputs, diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index bbc9e66d..f07c9c73 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool is_training = false) { var dtype = 
inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) diff --git a/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs b/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs index 41ce3033..e5ddb1ec 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs @@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers .ToArray(); } - protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool is_training = false) { - return base.call(inputs, is_training, state); + return base.call(inputs, is_training); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs index e2b5fa4d..83bfdaab 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers input_spec = new InputSpec(ndim: 4); } - protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool is_training = false) { int[] pool_shape; int[] strides; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs b/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs index ec32d75b..99d3a9f5 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs @@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool is_training = false) { scale = math_ops.cast(args.Scale, args.DType); offset = math_ops.cast(args.Offset, args.DType); diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 43fd90bc..4aaae7d0 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -61,9 +61,8 
@@ namespace Tensorflow.Layers return (results[0], results[1]); } - public Tensor[] __call__(Tensor inputs, + public Tensor __call__(Tensor inputs, Tensor training = null, - Tensor state = null, VariableScope scope = null) { _set_scope(scope); @@ -88,16 +87,54 @@ namespace Tensorflow.Layers { _current_scope = scope2; // Actually call layer - outputs = base.Apply(inputs, - is_training: training == null ? false : false, - state: state); + outputs = base.Apply(inputs[0], + is_training: training == null ? false : false); + }); + + + // Update global default collections. + _add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); + + return outputs; + } + + public Tensor[] __call__(Tensor[] inputs, + Tensor state = null, + Tensor training = null, + VariableScope scope = null) + { + _set_scope(scope); + _graph = ops._get_graph_from_inputs(inputs, graph: _graph); + + variable_scope scope_context_manager = null; + if (built) + { + scope_context_manager = tf.variable_scope(_scope, + reuse: true, + auxiliary_name_scope: false); + } + else + { + scope_context_manager = tf.variable_scope(_scope, + reuse: _reuse, + auxiliary_name_scope: false); + } + + Tensor[] outputs = null; + tf_with(scope_context_manager, scope2 => + { + _current_scope = scope2; + // Actually call layer + outputs = base.Apply(inputs, + state, + is_training: training == null ? false : false); }); // Update global default collections. 
_add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); - return new Tensor[] { outputs }; + return outputs; } protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index a7a98743..2e634a1c 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -326,7 +326,7 @@ namespace Tensorflow.Operations protected override void _AddOpInternal(Operation op) { - if (op.name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad") + if (op.name == "rnn/basic_rnn_cell/kernel/Initializer/random_uniform/shape") { } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs index 37f21377..bb53a468 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs @@ -61,7 +61,7 @@ namespace Tensorflow built = true; } - public Tensor[] __call__(Tensor inputs, LSTMStateTuple state) + public Tensor __call__(Tensor inputs, LSTMStateTuple state) { _state = state; return base.__call__(inputs); @@ -74,7 +74,7 @@ namespace Tensorflow /// /// /// - protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool is_training = false) { var one = constant_op.constant(1, dtype: dtypes.int32); // Parameters of gates are concatenated into one multiply for efficiency. 
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs index 592be625..3754072d 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs @@ -67,14 +67,14 @@ namespace Tensorflow built = true; } - protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + protected override Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false) { // Most basic RNN: output = new_state = act(W * input + U * state + B). - var concat = array_ops.concat(new[] { inputs, state }, 1); + var concat = array_ops.concat(new[] { inputs[0], state }, 1); var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor()); gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); var output = _activation(gate_inputs, null); - return output; + return new[] { output, output }; } } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 5509ba2c..66327cb5 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -364,7 +364,7 @@ namespace Tensorflow.Operations if (sequence_length != null) throw new NotImplementedException("sequence_length != null"); else - outputs = cell.__call__(input_t_t, state: state1); + outputs = cell.__call__(new[] { input_t_t }, state: state1); var (output, new_state) = (outputs[0], outputs[1]); // Keras cells always wrap state as list, even if it's a single tensor. 
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 5a99deff..db528e70 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -326,7 +326,7 @@ namespace Tensorflow // the updated inputs are reloaded from the c_api lock (Locks.ProcessWide) { - // c_api.UpdateEdge(_graph, output, input, tf.Status.Handle); + c_api.UpdateEdge(_graph, output, input, tf.Status.Handle); //var updated_inputs = inputs; tf.Status.Check(); } diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index b483987e..5ec2c317 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 2.2.0 - 0.20.0 + 0.20.1 8.0 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK @@ -19,13 +19,13 @@ Google's TensorFlow full binding in .NET Standard. Building, training and infering deep learning models. https://tensorflownet.readthedocs.io - 0.20.0.0 + 0.20.1.0 tf.net 0.20.x and above are based on tensorflow native 2.x. * Eager Mode is added finally. * tf.keras is partially working. * tf.data is added. 
- 0.20.0.0 + 0.20.1.0 LICENSE true true diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index be961738..0306eb8e 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -50,6 +50,8 @@ namespace Tensorflow /// public AllocationType AllocationType { get; protected set; } + public IntPtr TensorDataPointer => TF_TensorData(_handle); + /// /// Create a Tensor object from an existing TF handle /// diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 21d5ba00..b1b6700d 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -261,7 +261,6 @@ namespace Tensorflow protected override void DisposeUnmanagedResources(IntPtr handle) { c_api.TF_DeleteTensor(handle); - if (AllocationHandle == null) return; diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index d42eb3dd..3655a6db 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -88,80 +88,83 @@ namespace Tensorflow if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); - - ops.init_scope(); + _in_graph_mode = !tf.Context.executing_eagerly(); - tf_with(ops.name_scope(name, "Variable"), scope => + tf_with(ops.init_scope2(), delegate { - name = scope; - var handle_name = ops.name_from_scope_name(name); - string unique_id = ""; - string shared_name = ""; - - if (_in_graph_mode) - { - shared_name = handle_name; - unique_id = shared_name; - } - else + var values = init_from_fn ? 
new object[0] : new object[] { initial_value }; + tf_with(ops.name_scope(name, "Variable", values), scope => { - unique_id = $"{handle_name}_{ops.uid()}"; - shared_name = tf.Context.shared_name(); - } - - var attr = new AttrValue(); - attr.List = new AttrValue.Types.ListValue(); - attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}")); - tf_with(ops.name_scope("Initializer"), delegate - { - if (initial_value.GetType().GetInterface("IInitializer") != null) - initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype))); + name = scope; + var handle_name = ops.name_from_scope_name(name); + string unique_id = ""; + string shared_name = ""; + + if (_in_graph_mode) + { + shared_name = handle_name; + unique_id = shared_name; + } else { - var value = init_from_fn ? (initial_value as Func)() : initial_value; - initial_value = ops.convert_to_tensor(value, - name: "initial_value", - dtype: dtype); + unique_id = $"{handle_name}_{ops.uid()}"; + shared_name = tf.Context.shared_name(); } - }); - _shape = shape ?? (initial_value as Tensor).TensorShape; - _initial_value = initial_value as Tensor; - + var attr = new AttrValue(); + attr.List = new AttrValue.Types.ListValue(); + attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}")); + tf_with(ops.name_scope("Initializer"), delegate + { + if (initial_value.GetType().GetInterface("IInitializer") != null) + initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype))); + else + { + var value = init_from_fn ? (initial_value as Func)() : initial_value; + initial_value = ops.convert_to_tensor(value, + name: "initial_value", + dtype: dtype); + } + }); + _shape = shape ?? 
(initial_value as Tensor).TensorShape; + _initial_value = initial_value as Tensor; + + + + if (_in_graph_mode) + { + handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); + initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; - if (_in_graph_mode) - { - handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); - initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; + ops.colocate_with(initializer_op); - ops.colocate_with(initializer_op); + _graph_element = gen_array_ops.identity(handle, name = "read"); + ops.add_to_collections(collections, this); + _dtype = handle.dtype; + } + else + { + handle = resource_variable_ops.eager_safe_variable_handle( + initial_value: _initial_value, + shape: _shape, + shared_name: shared_name, + name: name, + graph_mode: _in_graph_mode); + + gen_resource_variable_ops.assign_variable_op(handle, _initial_value); + is_initialized_op = null; + initializer_op = null; + _graph_element = null; + _dtype = _initial_value.dtype.as_base_dtype(); + initial_value = _in_graph_mode ? initial_value : null; + } - _graph_element = gen_array_ops.identity(handle, name = "read"); - ops.add_to_collections(collections, this); - _dtype = handle.dtype; - } - else - { - handle = resource_variable_ops.eager_safe_variable_handle( - initial_value: _initial_value, - shape: _shape, - shared_name: shared_name, - name: name, - graph_mode: _in_graph_mode); - - gen_resource_variable_ops.assign_variable_op(handle, _initial_value); - is_initialized_op = null; - initializer_op = null; - _graph_element = null; - _dtype = _initial_value.dtype.as_base_dtype(); - initial_value = _in_graph_mode ? 
initial_value : null; - } - - base.__init__(trainable: trainable, - handle: handle, - name: name, - unique_id: unique_id, - handle_name: handle_name); + base.__init__(trainable: trainable, + handle: handle, + name: name, + unique_id: unique_id, + handle_name: handle_name); + }); }); } diff --git a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj index cb63772b..c539919c 100644 --- a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj +++ b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj @@ -30,6 +30,7 @@ + diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj index d5c854c0..1b9cae28 100644 --- a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj @@ -43,7 +43,7 @@ - +