From 66f7e6d87f3425cc0f4459523b506611687e0cb8 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Tue, 1 Oct 2019 13:35:10 -0500 Subject: [PATCH] Can't get Output Context fixed #411 --- .../Operations/Operation.Input.cs | 2 +- .../Operations/Operation.Instance.cs | 41 +++++++++++++++++++ .../Operations/Operation.cs | 14 +++---- src/TensorFlowNET.Core/ops.cs | 37 +++++++++-------- 4 files changed, 67 insertions(+), 27 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/Operation.Instance.cs diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs index 6d6403c9..c80e99f6 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -53,7 +53,7 @@ namespace Tensorflow for (int i = 0; i < NumInputs; i++) { var tf_output = Input(i); - var op = new Operation(tf_output.oper); + var op = GetOperation(tf_output.oper); retval[i] = op.outputs[tf_output.index]; } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Instance.cs b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs new file mode 100644 index 00000000..6f6c8226 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs @@ -0,0 +1,41 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; + +namespace Tensorflow +{ + public partial class Operation + { + // cache the mapping between managed and unmanaged op + // some data is stored in managed instance, so when + // create Operation by IntPtr, it will lost some data. + private static Dictionary OpInstances = new Dictionary(); + + /// + /// Get operation by handle + /// + /// + /// + public Operation GetOperation(IntPtr handle) + { + return OpInstances.ContainsKey(handle) ? + OpInstances[handle] : + new Operation(handle); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index b6811917..caf5ac18 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -84,9 +84,10 @@ namespace Tensorflow _control_flow_context = _graph._get_control_flow_context(); // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. + OpInstances[_handle] = this; } - public Operation(Graph g, string opType, string oper_name) + /*public Operation(Graph g, string opType, string oper_name) { _graph = g; @@ -102,7 +103,7 @@ namespace Tensorflow // Dict mapping op name to file and line information for op colocation // context managers. _control_flow_context = graph._get_control_flow_context(); - } + }*/ /// /// Creates an `Operation`. @@ -151,11 +152,6 @@ namespace Tensorflow } } - if(node_def.Name == "define_loss/conv_lobj_branch/batch_normalization/cond/FusedBatchNorm_1") - { - - } - // Dict mapping op name to file and line information for op colocation // context managers. _control_flow_context = graph._get_control_flow_context(); @@ -164,7 +160,7 @@ namespace Tensorflow if (op_def == null) op_def = g.GetOpDef(node_def.Op); - var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); + var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); // Initialize self._outputs. @@ -180,6 +176,8 @@ namespace Tensorflow if (_handle != IntPtr.Zero) _control_flow_post_processing(); + + OpInstances[_handle] = this; } public void run(FeedItem[] feed_dict = null, Session session = null) diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 6ab9feae..846de1ea 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -227,29 +227,30 @@ namespace Tensorflow throw new NotImplementedException("_create_c_op"); } - var status = new Status(); - - // Add control inputs - foreach (var control_input in control_inputs) - c_api.TF_AddControlInput(op_desc, control_input); - - // Add attrs - foreach (var attr in node_def.Attr) + using (var status = new Status()) { - var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. - var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak - Marshal.Copy(bytes, 0, proto, bytes.Length); - uint len = (uint) bytes.Length; - c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); + // Add control inputs + foreach (var control_input in control_inputs) + c_api.TF_AddControlInput(op_desc, control_input); - status.Check(true); - } + // Add attrs + foreach (var attr in node_def.Attr) + { + var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. + var protoHandle = Marshal.AllocHGlobal(bytes.Length); + Marshal.Copy(bytes, 0, protoHandle, bytes.Length); + uint len = (uint)bytes.Length; + c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status); + status.Check(true); + Marshal.FreeHGlobal(protoHandle); + } - var c_op = c_api.TF_FinishOperation(op_desc, status); + var c_op = c_api.TF_FinishOperation(op_desc, status); - status.Check(true); + status.Check(true); - return c_op; + return c_op; + } } }