Browse Source

Use tensor.Id instead of GetHashCode.

tags/yolov3
Oceania2018 4 years ago
parent
commit
c3dd96b1c7
4 changed files with 28 additions and 33 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Binding.cs
  2. +5
    -5
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  3. +15
    -20
      src/TensorFlowNET.Keras/Engine/Functional.cs
  4. +7
    -7
      src/TensorFlowNET.Keras/Engine/Node.cs

+ 1
- 1
src/TensorFlowNET.Core/Binding.cs View File

@@ -4,7 +4,7 @@ namespace Tensorflow
{
public static partial class Binding
{
[DebuggerNonUserCode]
[DebuggerHidden]
public static tensorflow tf { get; } = New<tensorflow>();

/// <summary>


+ 5
- 5
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -4,7 +4,7 @@ using static Tensorflow.Binding;

namespace Tensorflow.Eager
{
public partial class EagerTensor : Tensor
public partial class EagerTensor
{
public EagerTensor() : base(IntPtr.Zero)
{
@@ -48,8 +48,8 @@ namespace Tensorflow.Eager
if (_handle == IntPtr.Zero)
_handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.Status.Handle);

//print($"new Tensor {Id} {_handle.ToString("x16")}");
//print($"new TensorHandle {Id} {EagerTensorHandle.ToString("x16")}");
// print($"New TensorHandle {Id} 0x{_handle.ToString("x16")}");
// print($"New EagerTensorHandle {Id} {EagerTensorHandle}");

return this;
}
@@ -96,14 +96,14 @@ namespace Tensorflow.Eager
{
base.DisposeManagedResources();

//print($"deleting DeleteTensorHandle {Id} {EagerTensorHandle.ToString("x16")}");
// print($"Delete EagerTensorHandle {Id} {EagerTensorHandle}");
EagerTensorHandle.Dispose();
}

protected override void DisposeUnmanagedResources(IntPtr handle)
{
base.DisposeUnmanagedResources(handle);
//print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}");
// print($"Delete TensorHandle {Id} 0x{_handle.ToString("x16")}");
}
}
}

+ 15
- 20
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -23,8 +23,7 @@ namespace Tensorflow.Keras.Engine
List<KerasHistory> _output_coordinates;
public string[] NetworkNodes { get; set; }

Dictionary<int, int> tensor_usage_count;
public Dictionary<int, int> TensorUsageCount => tensor_usage_count;
Dictionary<long, int> tensor_usage_count;

public Functional(Tensors inputs, Tensors outputs, string name = null)
: base(new ModelArgs
@@ -38,7 +37,7 @@ namespace Tensorflow.Keras.Engine
_output_layers = new List<ILayer>();
_input_coordinates = new List<KerasHistory>();
_output_coordinates = new List<KerasHistory>();
tensor_usage_count = new Dictionary<int, int>();
tensor_usage_count = new Dictionary<long, int>();
if (this is Sequential)
return;
_init_graph_network(inputs, outputs);
@@ -116,33 +115,33 @@ namespace Tensorflow.Keras.Engine

void ComputeTensorUsageCount()
{
var available_tensors = inputs.Select(x => x.GetHashCode()).ToList();
var available_tensors = inputs.Select(x => x.Id).ToList();
var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().Skip(1).ToArray();
foreach (var depth in depth_keys)
{
foreach (var node in NodesByDepth[depth])
{
var input_tensors = node.KerasInputs.Select(x => x.GetHashCode()).ToArray();
var input_tensors = node.KerasInputs.Select(x => x.Id).ToArray();
if (input_tensors.issubset(available_tensors))
{
foreach (var tensor in node.KerasInputs)
{
if (!tensor_usage_count.ContainsKey(tensor.GetHashCode()))
tensor_usage_count[tensor.GetHashCode()] = 0;
tensor_usage_count[tensor.GetHashCode()] += 1;
if (!tensor_usage_count.ContainsKey(tensor.Id))
tensor_usage_count[tensor.Id] = 0;
tensor_usage_count[tensor.Id] += 1;
}

foreach (var output_tensor in node.Outputs)
available_tensors.Add(output_tensor.GetHashCode());
available_tensors.Add(output_tensor.Id);
}
}
}

foreach (var tensor in outputs)
{
if (!tensor_usage_count.ContainsKey(tensor.GetHashCode()))
tensor_usage_count[tensor.GetHashCode()] = 0;
tensor_usage_count[tensor.GetHashCode()] += 1;
if (!tensor_usage_count.ContainsKey(tensor.Id))
tensor_usage_count[tensor.Id] = 0;
tensor_usage_count[tensor.Id] += 1;
}
}

@@ -316,12 +315,11 @@ namespace Tensorflow.Keras.Engine
input_t.KerasMask = masks[i];
}

var tensor_dict = new Dictionary<int, Queue<Tensor>>();
var tensor_dict = new Dictionary<long, Queue<Tensor>>();
foreach (var (x, y) in zip(this.inputs, inputs))
{
var y1 = conform_to_reference_input(y, x);
var x_id = x.GetHashCode();
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y1));
tensor_dict[x.Id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x.Id]).Select(x => y1));
}

var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray();
@@ -347,13 +345,10 @@ namespace Tensorflow.Keras.Engine
}
}

var output_tensors = new List<Tensor>();
var output_tensors = new Tensors();

foreach (var x in outputs)
{
var x_id = x.GetHashCode();
output_tensors.append(tensor_dict[x_id].Dequeue());
}
output_tensors.Add(tensor_dict[x.Id].Dequeue());

return output_tensors;
}


+ 7
- 7
src/TensorFlowNET.Keras/Engine/Node.cs View File

@@ -42,10 +42,10 @@ namespace Tensorflow.Keras.Engine
public List<Tensor> KerasInputs { get; set; } = new List<Tensor>();
public ILayer Layer { get; set; }
public bool is_input => args.InputTensors == null;
public int[] FlatInputIds { get; set; }
public int[] FlatOutputIds { get; set; }
public long[] FlatInputIds { get; set; }
public long[] FlatOutputIds { get; set; }
bool _single_positional_tensor_passed => KerasInputs.Count() == 1;
Dictionary<int, int> _keras_inputs_ids_and_indices = new Dictionary<int, int>();
Dictionary<int, long> _keras_inputs_ids_and_indices = new Dictionary<int, long>();
public INode[] ParentNodes
{
get
@@ -70,7 +70,7 @@ namespace Tensorflow.Keras.Engine
KerasInputs.AddRange(args.InputTensors);

foreach (var (i, ele) in enumerate(KerasInputs))
_keras_inputs_ids_and_indices[i] = ele.GetHashCode();
_keras_inputs_ids_and_indices[i] = ele.Id;

// Wire up Node to Layers.
layer.InboundNodes.Add(this);
@@ -89,8 +89,8 @@ namespace Tensorflow.Keras.Engine
tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor);

// Cached for performance.
FlatInputIds = KerasInputs.Select(x => x.GetHashCode()).ToArray();
FlatOutputIds = Outputs.Select(x => x.GetHashCode()).ToArray();
FlatInputIds = KerasInputs.Select(x => x.Id).ToArray();
FlatOutputIds = Outputs.Select(x => x.Id).ToArray();
}

/// <summary>
@@ -98,7 +98,7 @@ namespace Tensorflow.Keras.Engine
/// </summary>
/// <param name="tensor_dict"></param>
/// <returns></returns>
public Tensors MapArguments(Dictionary<int, Queue<Tensor>> tensor_dict)
public Tensors MapArguments(Dictionary<long, Queue<Tensor>> tensor_dict)
{
if (_single_positional_tensor_passed)
{


Loading…
Cancel
Save