From 8a3a16b72fa26c5ac2ac246481027143418153c1 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 30 Sep 2019 11:03:01 -0500 Subject: [PATCH] resove conflict. --- src/TensorFlowNET.Core/APIs/tf.ops.cs | 30 ++++++++ .../Operations/NnOps/rnn.cs | 2 +- .../Operations/TensorArray.cs | 11 ++- .../Operations/_GraphTensorArray.cs | 75 ++++++++++++++++--- .../Operations/gen_data_flow_ops.cs | 23 +++++- test/TensorFlowNET.UnitTest/OperationsTest.cs | 17 +++++ 6 files changed, 143 insertions(+), 15 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs index cd7a3642..c91f283d 100644 --- a/src/TensorFlowNET.Core/APIs/tf.ops.cs +++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System; using System.Collections.Generic; namespace Tensorflow @@ -61,5 +62,34 @@ namespace Tensorflow /// public Tensor no_op(string name = null) => gen_control_flow_ops.no_op(name: name); + + /// + /// map on the list of tensors unpacked from `elems` on dimension 0. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A tensor or (possibly nested) sequence of tensors. + public Tensor map_fn(Func fn, + Tensor elems, + TF_DataType dtype = TF_DataType.DtInvalid, + int parallel_iterations = -1, + bool back_prop = true, + bool swap_memory = false, + bool infer_shape = true, + string name = null) + => Operation.map_fn(fn, + elems, + dtype, + parallel_iterations: parallel_iterations, + back_prop: back_prop, + swap_memory: swap_memory, + infer_shape: infer_shape, + name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 1b68d1cd..8e7425e5 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -145,7 +145,7 @@ namespace Tensorflow.Operations { var ta = new TensorArray(dtype: dtype_, size: time_steps, - element_shape: element_shape, + element_shape: new[] { element_shape }, tensor_array_name: base_name + name); return ta; }; diff --git a/src/TensorFlowNET.Core/Operations/TensorArray.cs b/src/TensorFlowNET.Core/Operations/TensorArray.cs index 858dac47..7251bf85 100644 --- a/src/TensorFlowNET.Core/Operations/TensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/TensorArray.cs @@ -33,9 +33,13 @@ namespace Tensorflow.Operations { _GraphTensorArray _implementation; - public TensorArray(TF_DataType dtype, Tensor size = null, bool? clear_after_read = null, bool? dynamic_size = null, + public TF_DataType dtype => _implementation._dtype; + public Tensor handle => _implementation._handle; + public Tensor flow => _implementation._flow; + + public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, - bool infer_shape = true, TensorShape element_shape = null, + bool infer_shape = true, TensorShape[] element_shape = null, bool colocate_with_first_write_call = true, string name = null) { _implementation = new _GraphTensorArray(dtype, @@ -50,5 +54,8 @@ namespace Tensorflow.Operations colocate_with_first_write_call: colocate_with_first_write_call, name: name); } + + public TensorArray unstack(Tensor value, string name = null) + => _implementation.unstack(value, name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs index 4c700a5f..bd919ad8 100644 --- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs @@ -16,6 +16,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using static Tensorflow.Binding; @@ -23,7 +24,7 @@ namespace Tensorflow.Operations { internal class _GraphTensorArray { - TF_DataType _dtype; + internal TF_DataType _dtype; /// /// Used to keep track of what tensors the TensorArray should be @@ -33,23 +34,27 @@ namespace Tensorflow.Operations bool _colocate_with_first_write_call; bool _infer_shape; + bool _dynamic_size; List _element_shape; - object _colocate_with; + List _colocate_with; + + internal Tensor _handle; + internal Tensor _flow; public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, - bool infer_shape = true, TensorShape element_shape = null, + bool infer_shape = true, TensorShape[] element_shape = null, bool colocate_with_first_write_call = true, string name = null) { clear_after_read = clear_after_read ?? true; dynamic_size = dynamic_size ?? false; - + _dynamic_size = dynamic_size.Value; _dtype = dtype; _colocate_with_first_write_call = colocate_with_first_write_call; if (colocate_with_first_write_call) - _colocate_with = new Tensor[0]; + _colocate_with = new List(); // Record the current static shape for the array elements. The element // shape is defined either by `element_shape` or the shape of the tensor @@ -66,11 +71,12 @@ namespace Tensorflow.Operations _element_shape = new List { }; } - tf_with(ops.name_scope(name, "", new { handle, size, flow }), scope => + tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope => { if(handle != null) { - + _handle = handle; + _flow = flow; } else { @@ -89,14 +95,65 @@ namespace Tensorflow.Operations if (colocate_with_first_write_call) { ops.colocate_with(ignore_existing: true); - create(); + (_handle, _flow) = create(); } else { - + (_handle, _flow) = create(); } } }); } + + public TensorArray unstack(Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate + { + var num_elements = array_ops.shape(value)[0]; + return scatter(indices: math_ops.range(0, num_elements), value: value, name: name); + }); + } + + public TensorArray scatter(Tensor indices, Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate + { + value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + if (_infer_shape) + { + var shape = new TensorShape(value.TensorShape.dims.Skip(1).ToArray()); + _merge_element_shape(shape); + } + + _maybe_colocate_with(value); + var flow_out = gen_data_flow_ops.tensor_array_scatter_v3( + handle: _handle, + indices: indices, + value: value, + flow_in: _flow, + name: name); + + var ta = new TensorArray(_dtype, + infer_shape:_infer_shape, + element_shape: _element_shape.ToArray(), + dynamic_size: _dynamic_size, + handle: _handle, + flow: flow_out, + colocate_with_first_write_call: _colocate_with_first_write_call); + + + return ta; + }); + } + + public void _merge_element_shape(TensorShape shape) + { + _element_shape.Add(shape); + } + + public void _maybe_colocate_with(Tensor value) + { + _colocate_with.Add(value); + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs index 4e5bd1f6..fa194934 100644 --- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -27,10 +27,13 @@ namespace Tensorflow return _op.output; } - public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid, - int[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, + public static (Tensor, Tensor) tensor_array_v3(T size, TF_DataType dtype = TF_DataType.DtInvalid, + TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null) { + if (tensor_array_name == null) + tensor_array_name = string.Empty; + var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new { size, @@ -42,7 +45,21 @@ namespace Tensorflow tensor_array_name }); - return (null, null); + return (_op.outputs[0], _op.outputs[1]); + } + + public static Tensor tensor_array_scatter_v3(Tensor handle, Tensor indices, Tensor value, + Tensor flow_in, string name = null) + { + var _op = _op_def_lib._apply_op_helper("TensorArrayScatterV3", name, new + { + handle, + indices, + value, + flow_in + }); + + return _op.output; } public static Tensor padding_fifo_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 226a4839..2086bb36 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -1494,5 +1494,22 @@ namespace TensorFlowNET.UnitTest } #endregion } + + [TestMethod] + public void map_fn() + { + var a = tf.constant(new[] { 1, 2, 3, 4 }); + var b = tf.constant(new[] { 17, 12, 11, 10 }); + var ab = tf.stack(new[] { a, b }, 1); + + Func map_operation = (value_ab) => + { + var value_a = value_ab[0]; + var value_b = value_ab[1]; + return value_a + value_b; + }; + + var map_result = tf.map_fn(map_operation, ab); + } } }