| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -61,5 +62,34 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public Tensor no_op(string name = null) | public Tensor no_op(string name = null) | ||||
| => gen_control_flow_ops.no_op(name: name); | => gen_control_flow_ops.no_op(name: name); | ||||
| /// <summary> | |||||
| /// map on the list of tensors unpacked from `elems` on dimension 0. | |||||
| /// </summary> | |||||
| /// <param name="fn"></param> | |||||
| /// <param name="elems"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <param name="parallel_iterations"></param> | |||||
| /// <param name="back_prop"></param> | |||||
| /// <param name="swap_memory"></param> | |||||
| /// <param name="infer_shape"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns>A tensor or (possibly nested) sequence of tensors.</returns> | |||||
| public Tensor map_fn(Func<Tensor, Tensor> fn, | |||||
| Tensor elems, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| int parallel_iterations = -1, | |||||
| bool back_prop = true, | |||||
| bool swap_memory = false, | |||||
| bool infer_shape = true, | |||||
| string name = null) | |||||
| => Operation.map_fn(fn, | |||||
| elems, | |||||
| dtype, | |||||
| parallel_iterations: parallel_iterations, | |||||
| back_prop: back_prop, | |||||
| swap_memory: swap_memory, | |||||
| infer_shape: infer_shape, | |||||
| name: name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -145,7 +145,7 @@ namespace Tensorflow.Operations | |||||
| { | { | ||||
| var ta = new TensorArray(dtype: dtype_, | var ta = new TensorArray(dtype: dtype_, | ||||
| size: time_steps, | size: time_steps, | ||||
| element_shape: element_shape, | |||||
| element_shape: new[] { element_shape }, | |||||
| tensor_array_name: base_name + name); | tensor_array_name: base_name + name); | ||||
| return ta; | return ta; | ||||
| }; | }; | ||||
| @@ -33,9 +33,13 @@ namespace Tensorflow.Operations | |||||
| { | { | ||||
| _GraphTensorArray _implementation; | _GraphTensorArray _implementation; | ||||
| public TensorArray(TF_DataType dtype, Tensor size = null, bool? clear_after_read = null, bool? dynamic_size = null, | |||||
| public TF_DataType dtype => _implementation._dtype; | |||||
| public Tensor handle => _implementation._handle; | |||||
| public Tensor flow => _implementation._flow; | |||||
| public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, | |||||
| string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | ||||
| bool infer_shape = true, TensorShape element_shape = null, | |||||
| bool infer_shape = true, TensorShape[] element_shape = null, | |||||
| bool colocate_with_first_write_call = true, string name = null) | bool colocate_with_first_write_call = true, string name = null) | ||||
| { | { | ||||
| _implementation = new _GraphTensorArray(dtype, | _implementation = new _GraphTensorArray(dtype, | ||||
| @@ -50,5 +54,8 @@ namespace Tensorflow.Operations | |||||
| colocate_with_first_write_call: colocate_with_first_write_call, | colocate_with_first_write_call: colocate_with_first_write_call, | ||||
| name: name); | name: name); | ||||
| } | } | ||||
| public TensorArray unstack(Tensor value, string name = null) | |||||
| => _implementation.unstack(value, name: name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -16,6 +16,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -23,7 +24,7 @@ namespace Tensorflow.Operations | |||||
| { | { | ||||
| internal class _GraphTensorArray | internal class _GraphTensorArray | ||||
| { | { | ||||
| TF_DataType _dtype; | |||||
| internal TF_DataType _dtype; | |||||
| /// <summary> | /// <summary> | ||||
| /// Used to keep track of what tensors the TensorArray should be | /// Used to keep track of what tensors the TensorArray should be | ||||
| @@ -33,23 +34,27 @@ namespace Tensorflow.Operations | |||||
| bool _colocate_with_first_write_call; | bool _colocate_with_first_write_call; | ||||
| bool _infer_shape; | bool _infer_shape; | ||||
| bool _dynamic_size; | |||||
| List<TensorShape> _element_shape; | List<TensorShape> _element_shape; | ||||
| object _colocate_with; | |||||
| List<Tensor> _colocate_with; | |||||
| internal Tensor _handle; | |||||
| internal Tensor _flow; | |||||
| public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | ||||
| bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | ||||
| bool infer_shape = true, TensorShape element_shape = null, | |||||
| bool infer_shape = true, TensorShape[] element_shape = null, | |||||
| bool colocate_with_first_write_call = true, string name = null) | bool colocate_with_first_write_call = true, string name = null) | ||||
| { | { | ||||
| clear_after_read = clear_after_read ?? true; | clear_after_read = clear_after_read ?? true; | ||||
| dynamic_size = dynamic_size ?? false; | dynamic_size = dynamic_size ?? false; | ||||
| _dynamic_size = dynamic_size.Value; | |||||
| _dtype = dtype; | _dtype = dtype; | ||||
| _colocate_with_first_write_call = colocate_with_first_write_call; | _colocate_with_first_write_call = colocate_with_first_write_call; | ||||
| if (colocate_with_first_write_call) | if (colocate_with_first_write_call) | ||||
| _colocate_with = new Tensor[0]; | |||||
| _colocate_with = new List<Tensor>(); | |||||
| // Record the current static shape for the array elements. The element | // Record the current static shape for the array elements. The element | ||||
| // shape is defined either by `element_shape` or the shape of the tensor | // shape is defined either by `element_shape` or the shape of the tensor | ||||
| @@ -66,11 +71,12 @@ namespace Tensorflow.Operations | |||||
| _element_shape = new List<TensorShape> { }; | _element_shape = new List<TensorShape> { }; | ||||
| } | } | ||||
| tf_with(ops.name_scope(name, "", new { handle, size, flow }), scope => | |||||
| tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope => | |||||
| { | { | ||||
| if(handle != null) | if(handle != null) | ||||
| { | { | ||||
| _handle = handle; | |||||
| _flow = flow; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -89,14 +95,65 @@ namespace Tensorflow.Operations | |||||
| if (colocate_with_first_write_call) | if (colocate_with_first_write_call) | ||||
| { | { | ||||
| ops.colocate_with(ignore_existing: true); | ops.colocate_with(ignore_existing: true); | ||||
| create(); | |||||
| (_handle, _flow) = create(); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| (_handle, _flow) = create(); | |||||
| } | } | ||||
| } | } | ||||
| }); | }); | ||||
| } | } | ||||
| public TensorArray unstack(Tensor value, string name = null) | |||||
| { | |||||
| return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate | |||||
| { | |||||
| var num_elements = array_ops.shape(value)[0]; | |||||
| return scatter(indices: math_ops.range(0, num_elements), value: value, name: name); | |||||
| }); | |||||
| } | |||||
| public TensorArray scatter(Tensor indices, Tensor value, string name = null) | |||||
| { | |||||
| return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate | |||||
| { | |||||
| value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||||
| if (_infer_shape) | |||||
| { | |||||
| var shape = new TensorShape(value.TensorShape.dims.Skip(1).ToArray()); | |||||
| _merge_element_shape(shape); | |||||
| } | |||||
| _maybe_colocate_with(value); | |||||
| var flow_out = gen_data_flow_ops.tensor_array_scatter_v3( | |||||
| handle: _handle, | |||||
| indices: indices, | |||||
| value: value, | |||||
| flow_in: _flow, | |||||
| name: name); | |||||
| var ta = new TensorArray(_dtype, | |||||
| infer_shape:_infer_shape, | |||||
| element_shape: _element_shape.ToArray(), | |||||
| dynamic_size: _dynamic_size, | |||||
| handle: _handle, | |||||
| flow: flow_out, | |||||
| colocate_with_first_write_call: _colocate_with_first_write_call); | |||||
| return ta; | |||||
| }); | |||||
| } | |||||
| public void _merge_element_shape(TensorShape shape) | |||||
| { | |||||
| _element_shape.Add(shape); | |||||
| } | |||||
| public void _maybe_colocate_with(Tensor value) | |||||
| { | |||||
| _colocate_with.Add(value); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -27,10 +27,13 @@ namespace Tensorflow | |||||
| return _op.output; | return _op.output; | ||||
| } | } | ||||
| public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| int[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, | |||||
| public static (Tensor, Tensor) tensor_array_v3<T>(T size, TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, | |||||
| bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null) | bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null) | ||||
| { | { | ||||
| if (tensor_array_name == null) | |||||
| tensor_array_name = string.Empty; | |||||
| var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new | var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new | ||||
| { | { | ||||
| size, | size, | ||||
| @@ -42,7 +45,21 @@ namespace Tensorflow | |||||
| tensor_array_name | tensor_array_name | ||||
| }); | }); | ||||
| return (null, null); | |||||
| return (_op.outputs[0], _op.outputs[1]); | |||||
| } | |||||
| public static Tensor tensor_array_scatter_v3(Tensor handle, Tensor indices, Tensor value, | |||||
| Tensor flow_in, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("TensorArrayScatterV3", name, new | |||||
| { | |||||
| handle, | |||||
| indices, | |||||
| value, | |||||
| flow_in | |||||
| }); | |||||
| return _op.output; | |||||
| } | } | ||||
| public static Tensor padding_fifo_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, | public static Tensor padding_fifo_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, | ||||
| @@ -1494,5 +1494,22 @@ namespace TensorFlowNET.UnitTest | |||||
| } | } | ||||
| #endregion | #endregion | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void map_fn() | |||||
| { | |||||
| var a = tf.constant(new[] { 1, 2, 3, 4 }); | |||||
| var b = tf.constant(new[] { 17, 12, 11, 10 }); | |||||
| var ab = tf.stack(new[] { a, b }, 1); | |||||
| Func<Tensor, Tensor> map_operation = (value_ab) => | |||||
| { | |||||
| var value_a = value_ab[0]; | |||||
| var value_b = value_ab[1]; | |||||
| return value_a + value_b; | |||||
| }; | |||||
| var map_result = tf.map_fn(map_operation, ab); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||