Browse Source

GraphTensorArray write, size, stack.

tags/v0.12
Oceania2018 6 years ago
parent
commit
1ef2ec1ca1
3 changed files with 76 additions and 3 deletions
  1. +4
    -2
      src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
  2. +53
    -1
      src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
  3. +19
    -0
      src/TensorFlowNET.Core/Operations/tensor_array_ops.cs

+ 4
- 2
src/TensorFlowNET.Core/Operations/BasicRNNCell.cs View File

@@ -66,12 +66,14 @@ namespace Tensorflow
built = true;
}

protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{
// Most basic RNN: output = new_state = act(W * input + U * state + B).
var concat = array_ops.concat(new[] { inputs, state }, 1);
var gate_inputs = math_ops.matmul(concat, _kernel as RefVariable);
return (inputs, inputs);
gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable);
var output = _activation(gate_inputs, null);
return new[] { output, output };
}
}
}

+ 53
- 1
src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs View File

@@ -22,7 +22,7 @@ using static Tensorflow.Binding;

namespace Tensorflow.Operations
{
internal class _GraphTensorArray
public class _GraphTensorArray
{
internal TF_DataType _dtype;
public TF_DataType dtype => _dtype;
@@ -174,5 +174,57 @@ namespace Tensorflow.Operations

return value;
}

public TensorArray write(Tensor index, Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
_maybe_colocate_with(value);
var flow_out = gen_data_flow_ops.tensor_array_write_v3(
handle: _handle,
index: index,
value: value,
flow_in: _flow,
name: name);

return tensor_array_ops.build_ta_with_new_flow(this, flow_out);
});
}

private Tensor size(string name = null)
{
return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name);
}

public Tensor stack(string name = null)
{
ops.colocate_with(_handle);
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
{
return gather(math_ops.range(0, size()), name: name);
});
}

public Tensor gather(Tensor indices, string name = null)
{
var element_shape = new TensorShape();

if (_element_shape.Count > 0)
element_shape = _element_shape[0];

var value = gen_data_flow_ops.tensor_array_gather_v3(
handle: _handle,
indices: indices,
flow_in: _flow,
dtype: _dtype,
name: name,
element_shape: element_shape);

//if (element_shape != null)
//value.set_shape(-1, element_shape.dims);

return value;
}
}
}

+ 19
- 0
src/TensorFlowNET.Core/Operations/tensor_array_ops.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;

namespace Tensorflow
{
@@ -29,5 +30,23 @@ namespace Tensorflow
new_impl._element_shape = impl._element_shape;
return new_ta;
}

public static TensorArray build_ta_with_new_flow(_GraphTensorArray old_ta, Tensor flow)
{
var impl = old_ta;

var new_ta = new TensorArray(
dtype: impl.dtype,
handle: impl.handle,
flow: flow,
infer_shape: impl.infer_shape,
colocate_with_first_write_call: impl.colocate_with_first_write_call);

var new_impl = new_ta._implementation;
new_impl._dynamic_size = impl._dynamic_size;
new_impl._colocate_with = impl._colocate_with;
new_impl._element_shape = impl._element_shape;
return new_ta;
}
}
}

Loading…
Cancel
Save