Browse Source

Allow Tensors to extend.

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
505565550f
9 changed files with 49 additions and 52 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  2. +10
    -1
      src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs
  4. +17
    -10
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  5. +1
    -19
      src/TensorFlowNET.Core/Graphs/Graph.cs
  6. +10
    -12
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  7. +4
    -3
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  8. +2
    -2
      src/TensorFlowNET.Keras/Engine/Functional.cs
  9. +1
    -1
      src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs

+ 1
- 1
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -139,7 +139,7 @@ namespace Tensorflow.Functions
"executor_type", "", "executor_type", "",
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
}; };
return tf.Runner.Execute(tf.Context, func_graph.FuncName, 1, args, attrs);
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs);
} }


ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)


+ 10
- 1
src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Graphs; using Tensorflow.Graphs;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -61,7 +62,7 @@ namespace Tensorflow.Functions
processed_args.add(arg); processed_args.add(arg);
input_index += 1; input_index += 1;
} }
tf.Logger.Debug($"Invoke backward function: {backward.Name}");
return backward.CallFlat(processed_args.ToArray(), outputs); return backward.CallFlat(processed_args.ToArray(), outputs);
}; };


@@ -91,6 +92,14 @@ namespace Tensorflow.Functions
grad_ys: gradients_wrt_outputs.ToArray(), grad_ys: gradients_wrt_outputs.ToArray(),
src_graph: _func_graph); src_graph: _func_graph);


var captures_from_forward = backwards_graph.external_captures()
.Where(x => !x.IsEagerTensor && x.graph == _func_graph)
.ToArray();
foreach(var capture in captures_from_forward)
{
_func_graph.Outputs.Add(capture);
}

var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}";
var backward_function_attr = new Dictionary<string, string>(); var backward_function_attr = new Dictionary<string, string>();
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;


+ 3
- 3
src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow.Graphs


public override void OnEntry(MethodExecutionArgs args) public override void OnEntry(MethodExecutionArgs args)
{ {
func_name = $"autograph_{args.Instance.GetType().FullName}.{args.Method.Name}";
func_name = $"autograph_{args.Instance.GetHashCode()}.{args.Method.Name}";


if (functions.ContainsKey(func_name)) if (functions.ContainsKey(func_name))
{ {
@@ -44,13 +44,13 @@ namespace Tensorflow.Graphs
} }
else else
{ {
originalInputs = new Tensors(args.Arguments.Length);
originalInputs = new Tensors();
// convert args to placeholder // convert args to placeholder
for (var i = 0; i < args.Arguments.Length; i++) for (var i = 0; i < args.Arguments.Length; i++)
{ {
if (args.Arguments[i] is EagerTensor tensor) if (args.Arguments[i] is EagerTensor tensor)
{ {
originalInputs[i] = tensor;
originalInputs.Add(tensor);
args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape, name: "inputs"); args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape, name: "inputs");
} }
} }


+ 17
- 10
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -16,16 +16,23 @@ namespace Tensorflow.Graphs
Graph outer_graph; Graph outer_graph;
public Graph OuterGraph => outer_graph; public Graph OuterGraph => outer_graph;


string func_name;

// _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); // _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle));
IntPtr func_handle; IntPtr func_handle;
public string FuncName => func_name;
public string FuncName => _graph_key;


public Tensors Inputs { get; set; } public Tensors Inputs { get; set; }
public Tensors Outputs { get; set; } public Tensors Outputs { get; set; }
public Dictionary<string, string> Attrs { get; set; } public Dictionary<string, string> Attrs { get; set; }


public Dictionary<long, (Tensor, Tensor)> _captures
= new Dictionary<long, (Tensor, Tensor)>();

public Tensor[] external_captures()
=> _captures.Select(x => x.Value.Item1).ToArray();

public Tensor[] internal_captures()
=> _captures.Select(x => x.Value.Item2).ToArray();

// new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>(); // new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>();
// public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray(); // public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray();


@@ -35,7 +42,7 @@ namespace Tensorflow.Graphs
public FuncGraph(string name) : base() public FuncGraph(string name) : base()
{ {
outer_graph = ops.get_default_graph(); outer_graph = ops.get_default_graph();
func_name = name;
_graph_key = name;


tf.Context.graph_mode(); tf.Context.graph_mode();
as_default(); as_default();
@@ -44,7 +51,7 @@ namespace Tensorflow.Graphs
public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base() public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base()
{ {
outer_graph = ops.get_default_graph(); outer_graph = ops.get_default_graph();
func_name = name;
_graph_key = name;
Attrs = attrs; Attrs = attrs;
// Will to test if FuncGraph has memory leak // Will to test if FuncGraph has memory leak
// c_api.TF_DeleteGraph(_handle); // c_api.TF_DeleteGraph(_handle);
@@ -60,7 +67,7 @@ namespace Tensorflow.Graphs
{ {
using var status = new Status(); using var status = new Status();
func_handle = c_api.TF_GraphToFunction(_handle, func_handle = c_api.TF_GraphToFunction(_handle,
func_name,
_graph_key,
false, false,
opers.Length, opers.Length,
opers.Select(x => (IntPtr)x).ToArray(), opers.Select(x => (IntPtr)x).ToArray(),
@@ -82,7 +89,7 @@ namespace Tensorflow.Graphs
c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, status.Handle); c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, status.Handle);
status.Check(true); status.Check(true);


func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle));
_graph_key = c_api.StringPiece(c_api.TF_FunctionName(func_handle));


Inputs = inputs; Inputs = inputs;
// mark_as_return // mark_as_return
@@ -131,7 +138,7 @@ namespace Tensorflow.Graphs
Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null)
{ {
Tensor placeholder = null; Tensor placeholder = null;
if (!_captures.Contains(tensor.Id))
if (!_captures.ContainsKey(tensor.Id))
{ {
placeholder = _create_substitute_placeholder(tensor, placeholder = _create_substitute_placeholder(tensor,
name: name, name: name,
@@ -141,7 +148,7 @@ namespace Tensorflow.Graphs
} }
else else
{ {
placeholder = (((Tensor, Tensor))_captures[tensor.Id]).Item2;
placeholder = _captures[tensor.Id].Item2;
} }


BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
@@ -160,7 +167,7 @@ namespace Tensorflow.Graphs


void add_capture(Tensor tensor, Tensor placeholder) void add_capture(Tensor tensor, Tensor placeholder)
{ {
_captures[tensor.Id] = (tensor, placeholder);
_captures.Add(tensor.Id, (tensor, placeholder));
if (Inputs == null) if (Inputs == null)
Inputs = new Tensors(placeholder); Inputs = new Tensors(placeholder);
else else


+ 1
- 19
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -87,7 +87,7 @@ namespace Tensorflow
private List<Tensor> _unfeedable_tensors = new List<Tensor>(); private List<Tensor> _unfeedable_tensors = new List<Tensor>();


public string _name_stack = ""; public string _name_stack = "";
private string _graph_key;
protected string _graph_key;
public string graph_key => _graph_key; public string graph_key => _graph_key;
public string _last_loss_reduction; public string _last_loss_reduction;
public bool _is_loss_scaled_by_optimizer { get; set; } public bool _is_loss_scaled_by_optimizer { get; set; }
@@ -552,23 +552,5 @@ namespace Tensorflow
{ {
return graph._handle; return graph._handle;
} }

public OrderedDictionary _captures => new OrderedDictionary();

public Tensor[] external_captures()
{
Tensor[] captures = new Tensor[_captures.Count];
ICollection inner = _captures.Keys; // c[0]
inner.CopyTo(captures, 0);
return captures;
}

public Tensor[] internal_captures()
{
Tensor[] captures = new Tensor[_captures.Count];
ICollection inner = _captures.Values; // c[1]
inner.CopyTo(captures, 0);
return captures;
}
} }
} }

+ 10
- 12
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow
/// </summary> /// </summary>
public class Tensors : IEnumerable<Tensor> public class Tensors : IEnumerable<Tensor>
{ {
Tensor[] items;
List<Tensor> items = new List<Tensor>();


public TF_DataType dtype => items.First().dtype; public TF_DataType dtype => items.First().dtype;
public TensorShape shape => items.First().TensorShape; public TensorShape shape => items.First().TensorShape;
@@ -23,7 +23,7 @@ namespace Tensorflow
public Graph graph => items.First().graph; public Graph graph => items.First().graph;
public bool IsEagerTensor => items.First().IsEagerTensor; public bool IsEagerTensor => items.First().IsEagerTensor;
public bool IsList { get; set; } public bool IsList { get; set; }
public int Length => items.Length;
public int Length => items.Count();


public Tensor this[int index] public Tensor this[int index]
{ {
@@ -40,17 +40,12 @@ namespace Tensorflow


public Tensors(params Tensor[] tensors) public Tensors(params Tensor[] tensors)
{ {
items = tensors;
items.AddRange(tensors);
} }


public Tensors(NDArray nd) public Tensors(NDArray nd)
{ {
items = new[] { ops.convert_to_tensor(nd) };
}

public Tensors(int count)
{
items = new Tensor[count];
items.Add(ops.convert_to_tensor(nd));
} }


public IEnumerator<Tensor> GetEnumerator() public IEnumerator<Tensor> GetEnumerator()
@@ -59,6 +54,9 @@ namespace Tensorflow
yield return tensor; yield return tensor;
} }


public void Add(Tensor tensor)
=> items.Add(tensor);

IEnumerator IEnumerable.GetEnumerator() IEnumerator IEnumerable.GetEnumerator()
{ {
throw new NotImplementedException(); throw new NotImplementedException();
@@ -80,11 +78,11 @@ namespace Tensorflow
=> tensors.FirstOrDefault(); => tensors.FirstOrDefault();


public static implicit operator Tensor[](Tensors tensors) public static implicit operator Tensor[](Tensors tensors)
=> tensors.items;
=> tensors.items.ToArray();


public override string ToString() public override string ToString()
=> items.Length == 1
=> items.Count() == 1
? items.First().ToString() ? items.First().ToString()
: items.Length + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name));
: items.Count() + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name));
} }
} }

+ 4
- 3
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -20,6 +20,7 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Eager; using Tensorflow.Eager;
using Tensorflow.Graphs;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
@@ -392,14 +393,14 @@ would not be rank 1.", tensor.op.get_attr("axis")));
} }
else if (tensor.op.type == "Placeholder" && else if (tensor.op.type == "Placeholder" &&
tensor.op.graph.building_function && tensor.op.graph.building_function &&
hasattr(tensor.op.graph, "internal_captures"))
tensor.op.graph is FuncGraph func_graph)
{ {
int i = 0; int i = 0;
foreach (Tensor capture in tensor.op.graph.internal_captures())
foreach (Tensor capture in func_graph.internal_captures())
{ {
if (capture.GetType() == typeof(Tensor)) if (capture.GetType() == typeof(Tensor))
{ {
var external_capture = tensor.op.graph.external_captures()[i];
var external_capture = func_graph.external_captures()[i];
return constant_value_as_shape(external_capture); return constant_value_as_shape(external_capture);
} }




+ 2
- 2
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -337,10 +337,10 @@ namespace Tensorflow.Keras.Engine


var layer_inputs = node.MapArguments(tensor_dict); var layer_inputs = node.MapArguments(tensor_dict);


tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}");
tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}");
var outputs = node.Layer.Apply(layer_inputs, is_training: training); var outputs = node.Layer.Apply(layer_inputs, is_training: training);
foreach (var output in outputs.Where(x => x != null)) foreach (var output in outputs.Where(x => x != null))
tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}");
tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}");
// Update tensor_dict for next input // Update tensor_dict for next input
foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs))
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y));


+ 1
- 1
src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs View File

@@ -60,7 +60,7 @@ namespace Tensorflow.Keras.Engine


tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs);


return op.output;
return op.outputs;
} }
} }
} }

Loading…
Cancel
Save