| @@ -603,7 +603,17 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| return gen_array_ops.shape(input, name: name, out_type: out_type); | |||
| return tf.Context.ExecuteOp("Shape", name, new ExecuteOpArgs(input) | |||
| { | |||
| GetGradientAttrs = (op) => new | |||
| { | |||
| T = op.get_attr<TF_DataType>("T"), | |||
| out_type = op.get_attr<TF_DataType>("out_type") | |||
| } | |||
| }.SetAttributes(new | |||
| { | |||
| out_type | |||
| })).First(); | |||
| }); | |||
| } | |||
| @@ -703,23 +713,26 @@ namespace Tensorflow | |||
| int new_axis_mask = 0, | |||
| int shrink_axis_mask = 0, | |||
| string name = null) | |||
| { | |||
| var op = gen_array_ops.strided_slice( | |||
| input: input_, | |||
| begin: begin, | |||
| end: end, | |||
| strides: strides, | |||
| begin_mask: begin_mask, | |||
| end_mask: end_mask, | |||
| ellipsis_mask: ellipsis_mask, | |||
| new_axis_mask: new_axis_mask, | |||
| shrink_axis_mask: shrink_axis_mask, | |||
| name: name); | |||
| string parent_name = name; | |||
| return op; | |||
| } | |||
| => tf.Context.ExecuteOp("StridedSlice", name, new ExecuteOpArgs(input_, begin, end, strides) | |||
| { | |||
| GetGradientAttrs = (op) => new | |||
| { | |||
| T = op.get_attr<TF_DataType>("T"), | |||
| Index = op.get_attr<TF_DataType>("Index"), | |||
| begin_mask = op.get_attr<long>("begin_mask"), | |||
| end_mask = op.get_attr<long>("end_mask"), | |||
| ellipsis_mask = op.get_attr<long>("ellipsis_mask"), | |||
| new_axis_mask = op.get_attr<long>("new_axis_mask"), | |||
| shrink_axis_mask = op.get_attr<long>("shrink_axis_mask") | |||
| } | |||
| }.SetAttributes(new | |||
| { | |||
| begin_mask, | |||
| end_mask, | |||
| ellipsis_mask, | |||
| new_axis_mask, | |||
| shrink_axis_mask | |||
| })); | |||
| /// <summary> | |||
| /// Returns the gradient of `StridedSlice`. | |||
| @@ -5,12 +5,17 @@ | |||
| /// </summary> | |||
| public class KerasTensor | |||
| { | |||
| private Tensor _tensor; | |||
| public void SetTensor(Tensors tensor) | |||
| => _tensor = tensor; | |||
| private Tensors _inferred_value; | |||
| public Tensors inferred_value | |||
| { | |||
| get => _inferred_value; | |||
| set => _inferred_value = value; | |||
| } | |||
| private TensorSpec _type_spec; | |||
| private string _name; | |||
| private TensorSpec _type_spec; | |||
| public Shape shape => _type_spec.shape; | |||
| public TF_DataType dtype => _type_spec.dtype; | |||
| public KerasTensor(TensorSpec type_spec, string name = null) | |||
| { | |||
| @@ -22,15 +27,23 @@ public class KerasTensor | |||
| { | |||
| var type_spec = tensor.ToTensorSpec(); | |||
| var kt = new KerasTensor(type_spec, name: tensor.name); | |||
| kt.SetTensor(tensor); | |||
| kt.inferred_value = tensor; | |||
| return kt; | |||
| } | |||
| public override string ToString() | |||
| => _inferred_value.Length switch | |||
| { | |||
| > 1 => "[" + string.Join(", ", _inferred_value.Select(x => $"<KerasTensor: shape={x.shape} dtype={x.dtype}>")) + "]", | |||
| 1 => $"<KerasTensor: shape={_inferred_value.shape} dtype={_inferred_value.dtype}>", | |||
| _ => _inferred_value.ToString(), | |||
| }; | |||
| public static implicit operator Tensors(KerasTensor kt) | |||
| => kt._tensor; | |||
| => kt._inferred_value; | |||
| public static implicit operator Tensor(KerasTensor kt) | |||
| => kt._tensor; | |||
| => kt._inferred_value; | |||
| public static implicit operator KerasTensor(Tensor tensor) | |||
| => from_tensor(tensor); | |||
| @@ -42,7 +42,7 @@ namespace Tensorflow | |||
| array_ops.stack(args.End), | |||
| array_ops.stack(args.Strides)); | |||
| return gen_array_ops.strided_slice( | |||
| return array_ops.strided_slice( | |||
| this, | |||
| packed_begin, | |||
| packed_end, | |||