diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 5237ec44..02bf0e86 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -603,7 +603,17 @@ namespace Tensorflow } } - return gen_array_ops.shape(input, name: name, out_type: out_type); + return tf.Context.ExecuteOp("Shape", name, new ExecuteOpArgs(input) + { + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + out_type = op.get_attr("out_type") + } + }.SetAttributes(new + { + out_type + })).First(); }); } @@ -703,23 +713,26 @@ namespace Tensorflow int new_axis_mask = 0, int shrink_axis_mask = 0, string name = null) - { - var op = gen_array_ops.strided_slice( - input: input_, - begin: begin, - end: end, - strides: strides, - begin_mask: begin_mask, - end_mask: end_mask, - ellipsis_mask: ellipsis_mask, - new_axis_mask: new_axis_mask, - shrink_axis_mask: shrink_axis_mask, - name: name); - - string parent_name = name; - - return op; - } + => tf.Context.ExecuteOp("StridedSlice", name, new ExecuteOpArgs(input_, begin, end, strides) + { + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + Index = op.get_attr("Index"), + begin_mask = op.get_attr("begin_mask"), + end_mask = op.get_attr("end_mask"), + ellipsis_mask = op.get_attr("ellipsis_mask"), + new_axis_mask = op.get_attr("new_axis_mask"), + shrink_axis_mask = op.get_attr("shrink_axis_mask") + } + }.SetAttributes(new + { + begin_mask, + end_mask, + ellipsis_mask, + new_axis_mask, + shrink_axis_mask + })); /// /// Returns the gradient of `StridedSlice`. diff --git a/src/TensorFlowNET.Core/Tensors/KerasTensor.cs b/src/TensorFlowNET.Core/Tensors/KerasTensor.cs index 1034dcc8..3204b4ac 100644 --- a/src/TensorFlowNET.Core/Tensors/KerasTensor.cs +++ b/src/TensorFlowNET.Core/Tensors/KerasTensor.cs @@ -5,12 +5,17 @@ /// public class KerasTensor { - private Tensor _tensor; - public void SetTensor(Tensors tensor) - => _tensor = tensor; + private Tensors _inferred_value; + public Tensors inferred_value + { + get => _inferred_value; + set => _inferred_value = value; + } - private TensorSpec _type_spec; private string _name; + private TensorSpec _type_spec; + public Shape shape => _type_spec.shape; + public TF_DataType dtype => _type_spec.dtype; public KerasTensor(TensorSpec type_spec, string name = null) { @@ -22,15 +27,23 @@ public class KerasTensor { var type_spec = tensor.ToTensorSpec(); var kt = new KerasTensor(type_spec, name: tensor.name); - kt.SetTensor(tensor); + kt.inferred_value = tensor; return kt; } + public override string ToString() + => _inferred_value.Length switch + { + > 1 => "[" + string.Join(", ", _inferred_value.Select(x => $"")) + "]", + 1 => $"", + _ => _inferred_value.ToString(), + }; + public static implicit operator Tensors(KerasTensor kt) - => kt._tensor; + => kt._inferred_value; public static implicit operator Tensor(KerasTensor kt) - => kt._tensor; + => kt._inferred_value; public static implicit operator KerasTensor(Tensor tensor) => from_tensor(tensor); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs index 217712fe..51062cf3 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs @@ -42,7 +42,7 @@ namespace Tensorflow array_ops.stack(args.End), array_ops.stack(args.Strides)); - return gen_array_ops.strided_slice( + return array_ops.strided_slice( this, packed_begin, packed_end,