| @@ -53,7 +53,7 @@ namespace Tensorflow | |||
| for (int i = 0; i < NumInputs; i++) | |||
| { | |||
| var tf_output = Input(i); | |||
| var op = new Operation(tf_output.oper); | |||
| var op = GetOperation(tf_output.oper); | |||
| retval[i] = op.outputs[tf_output.index]; | |||
| } | |||
| @@ -0,0 +1,41 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Operation | |||
| { | |||
| // cache the mapping between managed and unmanaged op | |||
| // some data is stored in managed instance, so when | |||
| // create Operation by IntPtr, it will lost some data. | |||
| private static Dictionary<IntPtr, Operation> OpInstances = new Dictionary<IntPtr, Operation>(); | |||
| /// <summary> | |||
| /// Get operation by handle | |||
| /// </summary> | |||
| /// <param name="handle"></param> | |||
| /// <returns></returns> | |||
| public Operation GetOperation(IntPtr handle) | |||
| { | |||
| return OpInstances.ContainsKey(handle) ? | |||
| OpInstances[handle] : | |||
| new Operation(handle); | |||
| } | |||
| } | |||
| } | |||
| @@ -84,9 +84,10 @@ namespace Tensorflow | |||
| _control_flow_context = _graph._get_control_flow_context(); | |||
| // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. | |||
| OpInstances[_handle] = this; | |||
| } | |||
| public Operation(Graph g, string opType, string oper_name) | |||
| /*public Operation(Graph g, string opType, string oper_name) | |||
| { | |||
| _graph = g; | |||
| @@ -102,7 +103,7 @@ namespace Tensorflow | |||
| // Dict mapping op name to file and line information for op colocation | |||
| // context managers. | |||
| _control_flow_context = graph._get_control_flow_context(); | |||
| } | |||
| }*/ | |||
| /// <summary> | |||
| /// Creates an `Operation`. | |||
| @@ -151,11 +152,6 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| if(node_def.Name == "define_loss/conv_lobj_branch/batch_normalization/cond/FusedBatchNorm_1") | |||
| { | |||
| } | |||
| // Dict mapping op name to file and line information for op colocation | |||
| // context managers. | |||
| _control_flow_context = graph._get_control_flow_context(); | |||
| @@ -164,7 +160,7 @@ namespace Tensorflow | |||
| if (op_def == null) | |||
| op_def = g.GetOpDef(node_def.Op); | |||
| var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||
| var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||
| _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||
| // Initialize self._outputs. | |||
| @@ -180,6 +176,8 @@ namespace Tensorflow | |||
| if (_handle != IntPtr.Zero) | |||
| _control_flow_post_processing(); | |||
| OpInstances[_handle] = this; | |||
| } | |||
| public void run(FeedItem[] feed_dict = null, Session session = null) | |||
| @@ -227,29 +227,30 @@ namespace Tensorflow | |||
| throw new NotImplementedException("_create_c_op"); | |||
| } | |||
| var status = new Status(); | |||
| // Add control inputs | |||
| foreach (var control_input in control_inputs) | |||
| c_api.TF_AddControlInput(op_desc, control_input); | |||
| // Add attrs | |||
| foreach (var attr in node_def.Attr) | |||
| using (var status = new Status()) | |||
| { | |||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
| uint len = (uint) bytes.Length; | |||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||
| // Add control inputs | |||
| foreach (var control_input in control_inputs) | |||
| c_api.TF_AddControlInput(op_desc, control_input); | |||
| status.Check(true); | |||
| } | |||
| // Add attrs | |||
| foreach (var attr in node_def.Attr) | |||
| { | |||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||
| var protoHandle = Marshal.AllocHGlobal(bytes.Length); | |||
| Marshal.Copy(bytes, 0, protoHandle, bytes.Length); | |||
| uint len = (uint)bytes.Length; | |||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status); | |||
| status.Check(true); | |||
| Marshal.FreeHGlobal(protoHandle); | |||
| } | |||
| var c_op = c_api.TF_FinishOperation(op_desc, status); | |||
| var c_op = c_api.TF_FinishOperation(op_desc, status); | |||
| status.Check(true); | |||
| status.Check(true); | |||
| return c_op; | |||
| return c_op; | |||
| } | |||
| } | |||
| } | |||