Browse Source

Use tensor.Id instead of GetHashCode.

tags/yolov3
Oceania2018 4 years ago
parent
commit
c3dd96b1c7
4 changed files with 28 additions and 33 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Binding.cs
  2. +5
    -5
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  3. +15
    -20
      src/TensorFlowNET.Keras/Engine/Functional.cs
  4. +7
    -7
      src/TensorFlowNET.Keras/Engine/Node.cs

+ 1
- 1
src/TensorFlowNET.Core/Binding.cs View File

@@ -4,7 +4,7 @@ namespace Tensorflow
{ {
public static partial class Binding public static partial class Binding
{ {
[DebuggerNonUserCode]
[DebuggerHidden]
public static tensorflow tf { get; } = New<tensorflow>(); public static tensorflow tf { get; } = New<tensorflow>();


/// <summary> /// <summary>


+ 5
- 5
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -4,7 +4,7 @@ using static Tensorflow.Binding;


namespace Tensorflow.Eager namespace Tensorflow.Eager
{ {
public partial class EagerTensor : Tensor
public partial class EagerTensor
{ {
public EagerTensor() : base(IntPtr.Zero) public EagerTensor() : base(IntPtr.Zero)
{ {
@@ -48,8 +48,8 @@ namespace Tensorflow.Eager
if (_handle == IntPtr.Zero) if (_handle == IntPtr.Zero)
_handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.Status.Handle); _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.Status.Handle);


//print($"new Tensor {Id} {_handle.ToString("x16")}");
//print($"new TensorHandle {Id} {EagerTensorHandle.ToString("x16")}");
// print($"New TensorHandle {Id} 0x{_handle.ToString("x16")}");
// print($"New EagerTensorHandle {Id} {EagerTensorHandle}");


return this; return this;
} }
@@ -96,14 +96,14 @@ namespace Tensorflow.Eager
{ {
base.DisposeManagedResources(); base.DisposeManagedResources();


//print($"deleting DeleteTensorHandle {Id} {EagerTensorHandle.ToString("x16")}");
// print($"Delete EagerTensorHandle {Id} {EagerTensorHandle}");
EagerTensorHandle.Dispose(); EagerTensorHandle.Dispose();
} }


protected override void DisposeUnmanagedResources(IntPtr handle) protected override void DisposeUnmanagedResources(IntPtr handle)
{ {
base.DisposeUnmanagedResources(handle); base.DisposeUnmanagedResources(handle);
//print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}");
// print($"Delete TensorHandle {Id} 0x{_handle.ToString("x16")}");
} }
} }
} }

+ 15
- 20
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -23,8 +23,7 @@ namespace Tensorflow.Keras.Engine
List<KerasHistory> _output_coordinates; List<KerasHistory> _output_coordinates;
public string[] NetworkNodes { get; set; } public string[] NetworkNodes { get; set; }


Dictionary<int, int> tensor_usage_count;
public Dictionary<int, int> TensorUsageCount => tensor_usage_count;
Dictionary<long, int> tensor_usage_count;


public Functional(Tensors inputs, Tensors outputs, string name = null) public Functional(Tensors inputs, Tensors outputs, string name = null)
: base(new ModelArgs : base(new ModelArgs
@@ -38,7 +37,7 @@ namespace Tensorflow.Keras.Engine
_output_layers = new List<ILayer>(); _output_layers = new List<ILayer>();
_input_coordinates = new List<KerasHistory>(); _input_coordinates = new List<KerasHistory>();
_output_coordinates = new List<KerasHistory>(); _output_coordinates = new List<KerasHistory>();
tensor_usage_count = new Dictionary<int, int>();
tensor_usage_count = new Dictionary<long, int>();
if (this is Sequential) if (this is Sequential)
return; return;
_init_graph_network(inputs, outputs); _init_graph_network(inputs, outputs);
@@ -116,33 +115,33 @@ namespace Tensorflow.Keras.Engine


void ComputeTensorUsageCount() void ComputeTensorUsageCount()
{ {
var available_tensors = inputs.Select(x => x.GetHashCode()).ToList();
var available_tensors = inputs.Select(x => x.Id).ToList();
var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().Skip(1).ToArray(); var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().Skip(1).ToArray();
foreach (var depth in depth_keys) foreach (var depth in depth_keys)
{ {
foreach (var node in NodesByDepth[depth]) foreach (var node in NodesByDepth[depth])
{ {
var input_tensors = node.KerasInputs.Select(x => x.GetHashCode()).ToArray();
var input_tensors = node.KerasInputs.Select(x => x.Id).ToArray();
if (input_tensors.issubset(available_tensors)) if (input_tensors.issubset(available_tensors))
{ {
foreach (var tensor in node.KerasInputs) foreach (var tensor in node.KerasInputs)
{ {
if (!tensor_usage_count.ContainsKey(tensor.GetHashCode()))
tensor_usage_count[tensor.GetHashCode()] = 0;
tensor_usage_count[tensor.GetHashCode()] += 1;
if (!tensor_usage_count.ContainsKey(tensor.Id))
tensor_usage_count[tensor.Id] = 0;
tensor_usage_count[tensor.Id] += 1;
} }


foreach (var output_tensor in node.Outputs) foreach (var output_tensor in node.Outputs)
available_tensors.Add(output_tensor.GetHashCode());
available_tensors.Add(output_tensor.Id);
} }
} }
} }


foreach (var tensor in outputs) foreach (var tensor in outputs)
{ {
if (!tensor_usage_count.ContainsKey(tensor.GetHashCode()))
tensor_usage_count[tensor.GetHashCode()] = 0;
tensor_usage_count[tensor.GetHashCode()] += 1;
if (!tensor_usage_count.ContainsKey(tensor.Id))
tensor_usage_count[tensor.Id] = 0;
tensor_usage_count[tensor.Id] += 1;
} }
} }


@@ -316,12 +315,11 @@ namespace Tensorflow.Keras.Engine
input_t.KerasMask = masks[i]; input_t.KerasMask = masks[i];
} }


var tensor_dict = new Dictionary<int, Queue<Tensor>>();
var tensor_dict = new Dictionary<long, Queue<Tensor>>();
foreach (var (x, y) in zip(this.inputs, inputs)) foreach (var (x, y) in zip(this.inputs, inputs))
{ {
var y1 = conform_to_reference_input(y, x); var y1 = conform_to_reference_input(y, x);
var x_id = x.GetHashCode();
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y1));
tensor_dict[x.Id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x.Id]).Select(x => y1));
} }


var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray(); var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray();
@@ -347,13 +345,10 @@ namespace Tensorflow.Keras.Engine
} }
} }


var output_tensors = new List<Tensor>();
var output_tensors = new Tensors();


foreach (var x in outputs) foreach (var x in outputs)
{
var x_id = x.GetHashCode();
output_tensors.append(tensor_dict[x_id].Dequeue());
}
output_tensors.Add(tensor_dict[x.Id].Dequeue());


return output_tensors; return output_tensors;
} }


+ 7
- 7
src/TensorFlowNET.Keras/Engine/Node.cs View File

@@ -42,10 +42,10 @@ namespace Tensorflow.Keras.Engine
public List<Tensor> KerasInputs { get; set; } = new List<Tensor>(); public List<Tensor> KerasInputs { get; set; } = new List<Tensor>();
public ILayer Layer { get; set; } public ILayer Layer { get; set; }
public bool is_input => args.InputTensors == null; public bool is_input => args.InputTensors == null;
public int[] FlatInputIds { get; set; }
public int[] FlatOutputIds { get; set; }
public long[] FlatInputIds { get; set; }
public long[] FlatOutputIds { get; set; }
bool _single_positional_tensor_passed => KerasInputs.Count() == 1; bool _single_positional_tensor_passed => KerasInputs.Count() == 1;
Dictionary<int, int> _keras_inputs_ids_and_indices = new Dictionary<int, int>();
Dictionary<int, long> _keras_inputs_ids_and_indices = new Dictionary<int, long>();
public INode[] ParentNodes public INode[] ParentNodes
{ {
get get
@@ -70,7 +70,7 @@ namespace Tensorflow.Keras.Engine
KerasInputs.AddRange(args.InputTensors); KerasInputs.AddRange(args.InputTensors);


foreach (var (i, ele) in enumerate(KerasInputs)) foreach (var (i, ele) in enumerate(KerasInputs))
_keras_inputs_ids_and_indices[i] = ele.GetHashCode();
_keras_inputs_ids_and_indices[i] = ele.Id;


// Wire up Node to Layers. // Wire up Node to Layers.
layer.InboundNodes.Add(this); layer.InboundNodes.Add(this);
@@ -89,8 +89,8 @@ namespace Tensorflow.Keras.Engine
tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor); tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor);


// Cached for performance. // Cached for performance.
FlatInputIds = KerasInputs.Select(x => x.GetHashCode()).ToArray();
FlatOutputIds = Outputs.Select(x => x.GetHashCode()).ToArray();
FlatInputIds = KerasInputs.Select(x => x.Id).ToArray();
FlatOutputIds = Outputs.Select(x => x.Id).ToArray();
} }


/// <summary> /// <summary>
@@ -98,7 +98,7 @@ namespace Tensorflow.Keras.Engine
/// </summary> /// </summary>
/// <param name="tensor_dict"></param> /// <param name="tensor_dict"></param>
/// <returns></returns> /// <returns></returns>
public Tensors MapArguments(Dictionary<int, Queue<Tensor>> tensor_dict)
public Tensors MapArguments(Dictionary<long, Queue<Tensor>> tensor_dict)
{ {
if (_single_positional_tensor_passed) if (_single_positional_tensor_passed)
{ {


Loading…
Cancel
Save