Browse Source

Can't get Output Context fixed #411

tags/v0.12
Oceania2018 6 years ago
parent
commit
66f7e6d87f
4 changed files with 67 additions and 27 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.Input.cs
  2. +41
    -0
      src/TensorFlowNET.Core/Operations/Operation.Instance.cs
  3. +6
    -8
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +19
    -18
      src/TensorFlowNET.Core/ops.cs

+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.Input.cs View File

@@ -53,7 +53,7 @@ namespace Tensorflow
for (int i = 0; i < NumInputs; i++)
{
var tf_output = Input(i);
var op = new Operation(tf_output.oper);
var op = GetOperation(tf_output.oper);
retval[i] = op.outputs[tf_output.index];
}


+ 41
- 0
src/TensorFlowNET.Core/Operations/Operation.Instance.cs View File

@@ -0,0 +1,41 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;

namespace Tensorflow
{
public partial class Operation
{
// cache the mapping between managed and unmanaged op
// some data is stored in managed instance, so when
// create Operation by IntPtr, it will lost some data.
private static Dictionary<IntPtr, Operation> OpInstances = new Dictionary<IntPtr, Operation>();

/// <summary>
/// Get operation by handle
/// </summary>
/// <param name="handle"></param>
/// <returns></returns>
public Operation GetOperation(IntPtr handle)
{
return OpInstances.ContainsKey(handle) ?
OpInstances[handle] :
new Operation(handle);
}
}
}

+ 6
- 8
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -84,9 +84,10 @@ namespace Tensorflow
_control_flow_context = _graph._get_control_flow_context();

// Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor.
OpInstances[_handle] = this;
}

public Operation(Graph g, string opType, string oper_name)
/*public Operation(Graph g, string opType, string oper_name)
{
_graph = g;

@@ -102,7 +103,7 @@ namespace Tensorflow
// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context();
}
}*/

/// <summary>
/// Creates an `Operation`.
@@ -151,11 +152,6 @@ namespace Tensorflow
}
}

if(node_def.Name == "define_loss/conv_lobj_branch/batch_normalization/cond/FusedBatchNorm_1")
{
}

// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context();
@@ -164,7 +160,7 @@ namespace Tensorflow
if (op_def == null)
op_def = g.GetOpDef(node_def.Op);

var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

// Initialize self._outputs.
@@ -180,6 +176,8 @@ namespace Tensorflow

if (_handle != IntPtr.Zero)
_control_flow_post_processing();

OpInstances[_handle] = this;
}

public void run(FeedItem[] feed_dict = null, Session session = null)


+ 19
- 18
src/TensorFlowNET.Core/ops.cs View File

@@ -227,29 +227,30 @@ namespace Tensorflow
throw new NotImplementedException("_create_c_op");
}

var status = new Status();

// Add control inputs
foreach (var control_input in control_inputs)
c_api.TF_AddControlInput(op_desc, control_input);

// Add attrs
foreach (var attr in node_def.Attr)
using (var status = new Status())
{
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream.
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak
Marshal.Copy(bytes, 0, proto, bytes.Length);
uint len = (uint) bytes.Length;
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status);
// Add control inputs
foreach (var control_input in control_inputs)
c_api.TF_AddControlInput(op_desc, control_input);

status.Check(true);
}
// Add attrs
foreach (var attr in node_def.Attr)
{
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream.
var protoHandle = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, protoHandle, bytes.Length);
uint len = (uint)bytes.Length;
c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status);
status.Check(true);
Marshal.FreeHGlobal(protoHandle);
}

var c_op = c_api.TF_FinishOperation(op_desc, status);
var c_op = c_api.TF_FinishOperation(op_desc, status);

status.Check(true);
status.Check(true);

return c_op;
return c_op;
}
}
}



Loading…
Cancel
Save