|
|
|
@@ -22,7 +22,7 @@ using static Tensorflow.Binding; |
|
|
|
|
|
|
|
namespace Tensorflow.Operations |
|
|
|
{ |
|
|
|
internal class _GraphTensorArray |
|
|
|
public class _GraphTensorArray |
|
|
|
{ |
|
|
|
internal TF_DataType _dtype; |
|
|
|
public TF_DataType dtype => _dtype; |
|
|
|
@@ -174,5 +174,57 @@ namespace Tensorflow.Operations |
|
|
|
|
|
|
|
return value; |
|
|
|
} |
|
|
|
|
|
|
|
public TensorArray write(Tensor index, Tensor value, string name = null) |
|
|
|
{ |
|
|
|
return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate |
|
|
|
{ |
|
|
|
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); |
|
|
|
_maybe_colocate_with(value); |
|
|
|
var flow_out = gen_data_flow_ops.tensor_array_write_v3( |
|
|
|
handle: _handle, |
|
|
|
index: index, |
|
|
|
value: value, |
|
|
|
flow_in: _flow, |
|
|
|
name: name); |
|
|
|
|
|
|
|
return tensor_array_ops.build_ta_with_new_flow(this, flow_out); |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
private Tensor size(string name = null) |
|
|
|
{ |
|
|
|
return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name); |
|
|
|
} |
|
|
|
|
|
|
|
public Tensor stack(string name = null) |
|
|
|
{ |
|
|
|
ops.colocate_with(_handle); |
|
|
|
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate |
|
|
|
{ |
|
|
|
return gather(math_ops.range(0, size()), name: name); |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
public Tensor gather(Tensor indices, string name = null) |
|
|
|
{ |
|
|
|
var element_shape = new TensorShape(); |
|
|
|
|
|
|
|
if (_element_shape.Count > 0) |
|
|
|
element_shape = _element_shape[0]; |
|
|
|
|
|
|
|
var value = gen_data_flow_ops.tensor_array_gather_v3( |
|
|
|
handle: _handle, |
|
|
|
indices: indices, |
|
|
|
flow_in: _flow, |
|
|
|
dtype: _dtype, |
|
|
|
name: name, |
|
|
|
element_shape: element_shape); |
|
|
|
|
|
|
|
//if (element_shape != null) |
|
|
|
//value.set_shape(-1, element_shape.dims); |
|
|
|
|
|
|
|
return value; |
|
|
|
} |
|
|
|
} |
|
|
|
} |