| @@ -19,6 +19,7 @@ using System.Collections.Generic; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.Linq; | using System.Linq; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Operations; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -309,5 +310,27 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public Tensor stop_gradient(Tensor x, string name = null) | public Tensor stop_gradient(Tensor x, string name = null) | ||||
| => gen_array_ops.stop_gradient(x, name: name); | => gen_array_ops.stop_gradient(x, name: name); | ||||
| public TensorArray TensorArray(TF_DataType dtype, int size = 0, bool dynamic_size = false, | |||||
| bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true, | |||||
| bool infer_shape = true) | |||||
| => tf.executing_eagerly() ? | |||||
| new _EagerTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size, | |||||
| clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, | |||||
| colocate_with_first_write_call: colocate_with_first_write_call) : | |||||
| new _GraphTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size, | |||||
| clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, | |||||
| colocate_with_first_write_call: colocate_with_first_write_call); | |||||
| public TensorArray TensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false, | |||||
| bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true, | |||||
| bool infer_shape = true) | |||||
| => tf.executing_eagerly() ? | |||||
| new _EagerTensorArray(dtype, size: size, dynamic_size: dynamic_size, | |||||
| clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, | |||||
| colocate_with_first_write_call: colocate_with_first_write_call) : | |||||
| new _GraphTensorArray(dtype, size: size, dynamic_size: dynamic_size, | |||||
| clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, | |||||
| colocate_with_first_write_call: colocate_with_first_write_call); | |||||
| } | } | ||||
| } | } | ||||
| @@ -294,10 +294,9 @@ namespace Tensorflow.Operations | |||||
| Func<string, Shape, TF_DataType, TensorArray> _create_ta = (name, element_shape, dtype_) => | Func<string, Shape, TF_DataType, TensorArray> _create_ta = (name, element_shape, dtype_) => | ||||
| { | { | ||||
| var ta = new TensorArray(dtype: dtype_, | |||||
| var ta = tf.TensorArray(dtype: dtype_, | |||||
| size: time_steps, | size: time_steps, | ||||
| element_shape: element_shape, | |||||
| tensor_array_name: base_name + name); | |||||
| element_shape: element_shape); | |||||
| return ta; | return ta; | ||||
| }; | }; | ||||
| @@ -0,0 +1,184 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2022 Haiping Chen. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Framework; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Operations | |||||
| { | |||||
| public class _EagerTensorArray : TensorArray | |||||
| { | |||||
| TF_DataType _dtype; | |||||
| public override TF_DataType dtype => _dtype; | |||||
| /// <summary> | |||||
| /// Used to keep track of what tensors the TensorArray should be | |||||
| /// colocated with. We choose to colocate the TensorArray with the | |||||
| /// first tensor written to it. | |||||
| /// </summary> | |||||
| bool _colocate_with_first_write_call; | |||||
| public override bool colocate_with_first_write_call => _colocate_with_first_write_call; | |||||
| bool _infer_shape; | |||||
| public override bool infer_shape => _infer_shape; | |||||
| public bool _dynamic_size; | |||||
| public Shape _element_shape; | |||||
| public List<Tensor> _colocate_with; | |||||
| Tensor _handle; | |||||
| public override Tensor handle => _handle; | |||||
| Tensor _flow; | |||||
| public override Tensor flow => _flow; | |||||
| bool _clear_after_read; | |||||
| List<Tensor> _tensor_array; | |||||
| public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false, | |||||
| bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||||
| bool infer_shape = true, Shape? element_shape = null, | |||||
| bool colocate_with_first_write_call = true, string name = null) | |||||
| { | |||||
| _flow = constant_op.constant(0); | |||||
| _infer_shape = infer_shape; | |||||
| _element_shape = element_shape ?? Shape.Null; | |||||
| _colocate_with_first_write_call = colocate_with_first_write_call; | |||||
| _dtype = dtype.as_base_dtype(); | |||||
| _dynamic_size = dynamic_size; | |||||
| _clear_after_read = clear_after_read; | |||||
| _tensor_array = new List<Tensor>(); | |||||
| } | |||||
| public override TensorArray unstack(Tensor value, string name = null) | |||||
| { | |||||
| return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate | |||||
| { | |||||
| var num_elements = array_ops.shape(value)[0]; | |||||
| return scatter(indices: math_ops.range(0, num_elements), value: value, name: name); | |||||
| }); | |||||
| } | |||||
| public TensorArray scatter(Tensor indices, Tensor value, string name = null) | |||||
| { | |||||
| /*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate | |||||
| { | |||||
| value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||||
| if (_infer_shape) | |||||
| { | |||||
| var shape = new Shape(value.shape.dims.Skip(1).ToArray()); | |||||
| _merge_element_shape(shape); | |||||
| } | |||||
| _maybe_colocate_with(value); | |||||
| var flow_out = gen_data_flow_ops.tensor_array_scatter_v3( | |||||
| handle: _handle, | |||||
| indices: indices, | |||||
| value: value, | |||||
| flow_in: _flow, | |||||
| name: name); | |||||
| var ta = new _EagerTensorArray(_dtype, | |||||
| infer_shape: _infer_shape, | |||||
| element_shape: _element_shape[0], | |||||
| dynamic_size: _dynamic_size, | |||||
| handle: _handle, | |||||
| flow: flow_out, | |||||
| colocate_with_first_write_call: _colocate_with_first_write_call); | |||||
| return ta; | |||||
| });*/ | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| public void _merge_element_shape(Shape shape) | |||||
| { | |||||
| _element_shape.concatenate(shape); | |||||
| } | |||||
| public void _maybe_colocate_with(Tensor value) | |||||
| { | |||||
| _colocate_with.Add(value); | |||||
| } | |||||
| public override Tensor read<T>(T index, string name = null) | |||||
| { | |||||
| int index_int = -1; | |||||
| if (index is int int_index) | |||||
| index_int = int_index; | |||||
| else if (index is Tensor tensor_index) | |||||
| index_int = tensor_index.numpy(); | |||||
| else | |||||
| throw new ValueError(""); | |||||
| if (_clear_after_read) | |||||
| { | |||||
| _tensor_array[index_int] = null; | |||||
| } | |||||
| return _tensor_array[index_int]; | |||||
| } | |||||
| public override TensorArray write(Tensor index, Tensor value, string name = null) | |||||
| { | |||||
| if (_infer_shape) | |||||
| _element_shape = _element_shape.merge_with(value.shape); | |||||
| _tensor_array.add(value); | |||||
| return this; | |||||
| } | |||||
| public override TensorArray write<T>(int index, T value, string name = null) | |||||
| { | |||||
| var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||||
| var index_tensor = ops.convert_to_tensor(index, name: "index"); | |||||
| return write(index_tensor, value_tensor, name: name); | |||||
| } | |||||
| private Tensor size(string name = null) | |||||
| { | |||||
| return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name); | |||||
| } | |||||
| public override Tensor stack(string name = null) | |||||
| { | |||||
| ops.colocate_with(_handle); | |||||
| return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate | |||||
| { | |||||
| return gather(math_ops.range(0, size()), name: name); | |||||
| }); | |||||
| } | |||||
| public override Tensor gather(Tensor indices, string name = null) | |||||
| { | |||||
| var element_shape = Shape.Null; | |||||
| var value = gen_data_flow_ops.tensor_array_gather_v3( | |||||
| handle: _handle, | |||||
| indices: indices, | |||||
| flow_in: _flow, | |||||
| dtype: _dtype, | |||||
| name: name, | |||||
| element_shape: element_shape); | |||||
| //if (element_shape != null) | |||||
| //value.set_shape(-1, element_shape.dims); | |||||
| return value; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -21,7 +21,7 @@ using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
| { | { | ||||
| public class _GraphTensorArray | |||||
| public class _GraphTensorArray : TensorArray | |||||
| { | { | ||||
| internal TF_DataType _dtype; | internal TF_DataType _dtype; | ||||
| public TF_DataType dtype => _dtype; | public TF_DataType dtype => _dtype; | ||||
| @@ -47,7 +47,7 @@ namespace Tensorflow.Operations | |||||
| public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | ||||
| bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | ||||
| bool infer_shape = true, Shape element_shape = null, | |||||
| bool infer_shape = true, Shape? element_shape = null, | |||||
| bool colocate_with_first_write_call = true, string name = null) | bool colocate_with_first_write_call = true, string name = null) | ||||
| { | { | ||||
| clear_after_read = clear_after_read ?? true; | clear_after_read = clear_after_read ?? true; | ||||
| @@ -108,7 +108,7 @@ namespace Tensorflow.Operations | |||||
| }); | }); | ||||
| } | } | ||||
| public TensorArray unstack(Tensor value, string name = null) | |||||
| public override TensorArray unstack(Tensor value, string name = null) | |||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate | return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate | ||||
| { | { | ||||
| @@ -119,7 +119,7 @@ namespace Tensorflow.Operations | |||||
| public TensorArray scatter(Tensor indices, Tensor value, string name = null) | public TensorArray scatter(Tensor indices, Tensor value, string name = null) | ||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate | |||||
| /*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate | |||||
| { | { | ||||
| value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | ||||
| if (_infer_shape) | if (_infer_shape) | ||||
| @@ -136,7 +136,7 @@ namespace Tensorflow.Operations | |||||
| flow_in: _flow, | flow_in: _flow, | ||||
| name: name); | name: name); | ||||
| var ta = new TensorArray(_dtype, | |||||
| var ta = new _GraphTensorArray(_dtype, | |||||
| infer_shape: _infer_shape, | infer_shape: _infer_shape, | ||||
| element_shape: _element_shape[0], | element_shape: _element_shape[0], | ||||
| dynamic_size: _dynamic_size, | dynamic_size: _dynamic_size, | ||||
| @@ -144,9 +144,9 @@ namespace Tensorflow.Operations | |||||
| flow: flow_out, | flow: flow_out, | ||||
| colocate_with_first_write_call: _colocate_with_first_write_call); | colocate_with_first_write_call: _colocate_with_first_write_call); | ||||
| return ta; | return ta; | ||||
| }); | |||||
| });*/ | |||||
| throw new NotImplementedException(""); | |||||
| } | } | ||||
| public void _merge_element_shape(Shape shape) | public void _merge_element_shape(Shape shape) | ||||
| @@ -159,11 +159,11 @@ namespace Tensorflow.Operations | |||||
| _colocate_with.Add(value); | _colocate_with.Add(value); | ||||
| } | } | ||||
| public Tensor read(Tensor index, string name = null) | |||||
| public override Tensor read<T>(T index, string name = null) | |||||
| { | { | ||||
| var value = gen_data_flow_ops.tensor_array_read_v3( | var value = gen_data_flow_ops.tensor_array_read_v3( | ||||
| handle: _handle, | handle: _handle, | ||||
| index: index, | |||||
| index: constant_op.constant(index), | |||||
| flow_in: _flow, | flow_in: _flow, | ||||
| dtype: _dtype, | dtype: _dtype, | ||||
| name: name); | name: name); | ||||
| @@ -174,11 +174,10 @@ namespace Tensorflow.Operations | |||||
| return value; | return value; | ||||
| } | } | ||||
| public TensorArray write(Tensor index, Tensor value, string name = null) | |||||
| public override TensorArray write(Tensor index, Tensor value, string name = null) | |||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate | return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate | ||||
| { | { | ||||
| value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||||
| _maybe_colocate_with(value); | _maybe_colocate_with(value); | ||||
| var flow_out = gen_data_flow_ops.tensor_array_write_v3( | var flow_out = gen_data_flow_ops.tensor_array_write_v3( | ||||
| handle: _handle, | handle: _handle, | ||||
| @@ -191,12 +190,19 @@ namespace Tensorflow.Operations | |||||
| }); | }); | ||||
| } | } | ||||
| public override TensorArray write<T>(int index, T value, string name = null) | |||||
| { | |||||
| var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||||
| var index_tensor = ops.convert_to_tensor(index, name: "index"); | |||||
| return write(index_tensor, value_tensor); | |||||
| } | |||||
| private Tensor size(string name = null) | private Tensor size(string name = null) | ||||
| { | { | ||||
| return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name); | return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name); | ||||
| } | } | ||||
| public Tensor stack(string name = null) | |||||
| public override Tensor stack(string name = null) | |||||
| { | { | ||||
| ops.colocate_with(_handle); | ops.colocate_with(_handle); | ||||
| return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate | return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate | ||||
| @@ -205,7 +211,7 @@ namespace Tensorflow.Operations | |||||
| }); | }); | ||||
| } | } | ||||
| public Tensor gather(Tensor indices, string name = null) | |||||
| public override Tensor gather(Tensor indices, string name = null) | |||||
| { | { | ||||
| var element_shape = Shape.Null; | var element_shape = Shape.Null; | ||||
| @@ -87,9 +87,9 @@ namespace Tensorflow | |||||
| // n = array_ops.shape(elems_flat[0])[0]; | // n = array_ops.shape(elems_flat[0])[0]; | ||||
| //} | //} | ||||
| var elems_ta = elems_flat.Select(elem => new TensorArray( | |||||
| var elems_ta = elems_flat.Select(elem => tf.TensorArray( | |||||
| elem.dtype, | elem.dtype, | ||||
| size: tf.constant(n), | |||||
| size: n, | |||||
| dynamic_size: false, | dynamic_size: false, | ||||
| element_shape: elem.shape.dims.Skip(1).ToArray(), | element_shape: elem.shape.dims.Skip(1).ToArray(), | ||||
| infer_shape: true)).ToList(); | infer_shape: true)).ToList(); | ||||
| @@ -113,9 +113,9 @@ namespace Tensorflow | |||||
| i = 0; | i = 0; | ||||
| } | } | ||||
| var accs_ta = a_flat.Select(init => new TensorArray( | |||||
| var accs_ta = a_flat.Select(init => tf.TensorArray( | |||||
| dtype: init.dtype, | dtype: init.dtype, | ||||
| size: tf.constant(n), | |||||
| size: n, | |||||
| element_shape: infer_shape ? init.shape : null, | element_shape: infer_shape ? init.shape : null, | ||||
| dynamic_size: false, | dynamic_size: false, | ||||
| infer_shape: infer_shape)).ToArray(); | infer_shape: infer_shape)).ToArray(); | ||||
| @@ -124,7 +124,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| for (int index = 0; index < accs_ta.Length; index++) | for (int index = 0; index < accs_ta.Length; index++) | ||||
| { | { | ||||
| accs_ta[index].write(tf.constant(reverse ? n - 1 : 0), a_flat[index]); | |||||
| accs_ta[index].write(reverse ? n - 1 : 0, a_flat[index]); | |||||
| } | } | ||||
| } | } | ||||
| @@ -78,8 +78,8 @@ namespace Tensorflow | |||||
| var n = static_shape[0]; | var n = static_shape[0]; | ||||
| // TensorArrays are always flat | // TensorArrays are always flat | ||||
| var elems_ta = elems_flat.Select(elem => new TensorArray(dtype: elem.dtype, | |||||
| size: ops.convert_to_tensor(n), | |||||
| var elems_ta = elems_flat.Select(elem => tf.TensorArray(dtype: elem.dtype, | |||||
| size: Convert.ToInt32(n), | |||||
| dynamic_size: false, | dynamic_size: false, | ||||
| infer_shape: true)).ToArray(); | infer_shape: true)).ToArray(); | ||||
| @@ -92,8 +92,8 @@ namespace Tensorflow | |||||
| var i = constant_op.constant(0); | var i = constant_op.constant(0); | ||||
| var accs_ta = dtype_flat.Select(dt => new TensorArray(dtype: dt, | |||||
| size: ops.convert_to_tensor(n), | |||||
| var accs_ta = dtype_flat.Select(dt => tf.TensorArray(dtype: dt, | |||||
| size: Convert.ToInt32(n), | |||||
| dynamic_size: false, | dynamic_size: false, | ||||
| infer_shape: infer_shape)).ToArray(); | infer_shape: infer_shape)).ToArray(); | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -12,37 +13,21 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow) | public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow) | ||||
| { | { | ||||
| var impl = old_ta._implementation; | |||||
| var new_ta = tf.TensorArray( | |||||
| dtype: old_ta.dtype, | |||||
| infer_shape: old_ta.infer_shape, | |||||
| colocate_with_first_write_call: old_ta.colocate_with_first_write_call); | |||||
| var new_ta = new TensorArray( | |||||
| dtype: impl.dtype, | |||||
| handle: impl.handle, | |||||
| flow: flow, | |||||
| infer_shape: impl.infer_shape, | |||||
| colocate_with_first_write_call: impl.colocate_with_first_write_call); | |||||
| var new_impl = new_ta._implementation; | |||||
| new_impl._dynamic_size = impl._dynamic_size; | |||||
| new_impl._colocate_with = impl._colocate_with; | |||||
| new_impl._element_shape = impl._element_shape; | |||||
| return new_ta; | return new_ta; | ||||
| } | } | ||||
| public static TensorArray build_ta_with_new_flow(_GraphTensorArray old_ta, Tensor flow) | public static TensorArray build_ta_with_new_flow(_GraphTensorArray old_ta, Tensor flow) | ||||
| { | { | ||||
| var impl = old_ta; | |||||
| var new_ta = new TensorArray( | |||||
| dtype: impl.dtype, | |||||
| handle: impl.handle, | |||||
| flow: flow, | |||||
| infer_shape: impl.infer_shape, | |||||
| colocate_with_first_write_call: impl.colocate_with_first_write_call); | |||||
| var new_ta = tf.TensorArray( | |||||
| dtype: old_ta.dtype, | |||||
| infer_shape: old_ta.infer_shape, | |||||
| colocate_with_first_write_call: old_ta.colocate_with_first_write_call); | |||||
| var new_impl = new_ta._implementation; | |||||
| new_impl._dynamic_size = impl._dynamic_size; | |||||
| new_impl._colocate_with = impl._colocate_with; | |||||
| new_impl._element_shape = impl._element_shape; | |||||
| return new_ta; | return new_ta; | ||||
| } | } | ||||
| } | } | ||||
| @@ -27,42 +27,22 @@ namespace Tensorflow | |||||
| /// `while_loop` and `map_fn`. It supports gradient back-propagation via special | /// `while_loop` and `map_fn`. It supports gradient back-propagation via special | ||||
| /// "flow" control flow dependencies. | /// "flow" control flow dependencies. | ||||
| /// </summary> | /// </summary> | ||||
| public class TensorArray : ITensorOrTensorArray | |||||
| public abstract class TensorArray : ITensorOrTensorArray | |||||
| { | { | ||||
| internal _GraphTensorArray _implementation; | |||||
| public virtual TF_DataType dtype { get; } | |||||
| public virtual Tensor handle { get; } | |||||
| public virtual Tensor flow { get; } | |||||
| public virtual bool infer_shape { get; } | |||||
| public virtual bool colocate_with_first_write_call { get; } | |||||
| public TF_DataType dtype => _implementation._dtype; | |||||
| public Tensor handle => _implementation._handle; | |||||
| public Tensor flow => _implementation._flow; | |||||
| public abstract TensorArray unstack(Tensor value, string name = null); | |||||
| public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, | |||||
| string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||||
| bool infer_shape = true, Shape element_shape = null, | |||||
| bool colocate_with_first_write_call = true, string name = null) | |||||
| { | |||||
| _implementation = new _GraphTensorArray(dtype, | |||||
| size: size, | |||||
| dynamic_size: dynamic_size, | |||||
| clear_after_read: clear_after_read, | |||||
| tensor_array_name: tensor_array_name, | |||||
| handle: handle, | |||||
| flow: flow, | |||||
| infer_shape: infer_shape, | |||||
| element_shape: element_shape, | |||||
| colocate_with_first_write_call: colocate_with_first_write_call, | |||||
| name: name); | |||||
| } | |||||
| public abstract Tensor read<T>(T index, string name = null); | |||||
| public TensorArray unstack(Tensor value, string name = null) | |||||
| => _implementation.unstack(value, name: name); | |||||
| public abstract TensorArray write<T>(int index, T value, string name = null); | |||||
| public abstract TensorArray write(Tensor index, Tensor value, string name = null); | |||||
| public Tensor read(Tensor index, string name = null) | |||||
| => _implementation.read(index, name: name); | |||||
| public TensorArray write(Tensor index, Tensor value, string name = null) | |||||
| => _implementation.write(index, value, name: name); | |||||
| public Tensor stack(string name = null) | |||||
| => _implementation.stack(name: name); | |||||
| public abstract Tensor stack(string name = null); | |||||
| public abstract Tensor gather(Tensor indices, string name = null); | |||||
| } | } | ||||
| } | } | ||||
| @@ -77,5 +77,20 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
| var r3 = tf.gather(p2, i2, axis: 1); | var r3 = tf.gather(p2, i2, axis: 1); | ||||
| Assert.AreEqual(new Shape(4,1,2), r3.shape); | Assert.AreEqual(new Shape(4,1,2), r3.shape); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/TensorArray | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void TensorArray() | |||||
| { | |||||
| var ta = tf.TensorArray(tf.float32, size: 0, dynamic_size: true, clear_after_read: false); | |||||
| ta.write(0, 10); | |||||
| ta.write(1, 20); | |||||
| ta.write(2, 30); | |||||
| Assert.AreEqual(ta.read(0).numpy(), 10f); | |||||
| Assert.AreEqual(ta.read(1).numpy(), 20f); | |||||
| Assert.AreEqual(ta.read(2).numpy(), 30f); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||