|
|
|
@@ -23,8 +23,7 @@ namespace Tensorflow.Keras.Engine |
|
|
|
List<KerasHistory> _output_coordinates; |
|
|
|
public string[] NetworkNodes { get; set; } |
|
|
|
|
|
|
|
Dictionary<int, int> tensor_usage_count; |
|
|
|
public Dictionary<int, int> TensorUsageCount => tensor_usage_count; |
|
|
|
Dictionary<long, int> tensor_usage_count; |
|
|
|
|
|
|
|
public Functional(Tensors inputs, Tensors outputs, string name = null) |
|
|
|
: base(new ModelArgs |
|
|
|
@@ -38,7 +37,7 @@ namespace Tensorflow.Keras.Engine |
|
|
|
_output_layers = new List<ILayer>(); |
|
|
|
_input_coordinates = new List<KerasHistory>(); |
|
|
|
_output_coordinates = new List<KerasHistory>(); |
|
|
|
tensor_usage_count = new Dictionary<int, int>(); |
|
|
|
tensor_usage_count = new Dictionary<long, int>(); |
|
|
|
if (this is Sequential) |
|
|
|
return; |
|
|
|
_init_graph_network(inputs, outputs); |
|
|
|
@@ -116,33 +115,33 @@ namespace Tensorflow.Keras.Engine |
|
|
|
|
|
|
|
void ComputeTensorUsageCount() |
|
|
|
{ |
|
|
|
var available_tensors = inputs.Select(x => x.GetHashCode()).ToList(); |
|
|
|
var available_tensors = inputs.Select(x => x.Id).ToList(); |
|
|
|
var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().Skip(1).ToArray(); |
|
|
|
foreach (var depth in depth_keys) |
|
|
|
{ |
|
|
|
foreach (var node in NodesByDepth[depth]) |
|
|
|
{ |
|
|
|
var input_tensors = node.KerasInputs.Select(x => x.GetHashCode()).ToArray(); |
|
|
|
var input_tensors = node.KerasInputs.Select(x => x.Id).ToArray(); |
|
|
|
if (input_tensors.issubset(available_tensors)) |
|
|
|
{ |
|
|
|
foreach (var tensor in node.KerasInputs) |
|
|
|
{ |
|
|
|
if (!tensor_usage_count.ContainsKey(tensor.GetHashCode())) |
|
|
|
tensor_usage_count[tensor.GetHashCode()] = 0; |
|
|
|
tensor_usage_count[tensor.GetHashCode()] += 1; |
|
|
|
if (!tensor_usage_count.ContainsKey(tensor.Id)) |
|
|
|
tensor_usage_count[tensor.Id] = 0; |
|
|
|
tensor_usage_count[tensor.Id] += 1; |
|
|
|
} |
|
|
|
|
|
|
|
foreach (var output_tensor in node.Outputs) |
|
|
|
available_tensors.Add(output_tensor.GetHashCode()); |
|
|
|
available_tensors.Add(output_tensor.Id); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
foreach (var tensor in outputs) |
|
|
|
{ |
|
|
|
if (!tensor_usage_count.ContainsKey(tensor.GetHashCode())) |
|
|
|
tensor_usage_count[tensor.GetHashCode()] = 0; |
|
|
|
tensor_usage_count[tensor.GetHashCode()] += 1; |
|
|
|
if (!tensor_usage_count.ContainsKey(tensor.Id)) |
|
|
|
tensor_usage_count[tensor.Id] = 0; |
|
|
|
tensor_usage_count[tensor.Id] += 1; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -316,12 +315,11 @@ namespace Tensorflow.Keras.Engine |
|
|
|
input_t.KerasMask = masks[i]; |
|
|
|
} |
|
|
|
|
|
|
|
var tensor_dict = new Dictionary<int, Queue<Tensor>>(); |
|
|
|
var tensor_dict = new Dictionary<long, Queue<Tensor>>(); |
|
|
|
foreach (var (x, y) in zip(this.inputs, inputs)) |
|
|
|
{ |
|
|
|
var y1 = conform_to_reference_input(y, x); |
|
|
|
var x_id = x.GetHashCode(); |
|
|
|
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y1)); |
|
|
|
tensor_dict[x.Id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x.Id]).Select(x => y1)); |
|
|
|
} |
|
|
|
|
|
|
|
var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray(); |
|
|
|
@@ -347,13 +345,10 @@ namespace Tensorflow.Keras.Engine |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
var output_tensors = new List<Tensor>(); |
|
|
|
var output_tensors = new Tensors(); |
|
|
|
|
|
|
|
foreach (var x in outputs) |
|
|
|
{ |
|
|
|
var x_id = x.GetHashCode(); |
|
|
|
output_tensors.append(tensor_dict[x_id].Dequeue()); |
|
|
|
} |
|
|
|
output_tensors.Add(tensor_dict[x.Id].Dequeue()); |
|
|
|
|
|
|
|
return output_tensors; |
|
|
|
} |
|
|
|
|