| @@ -603,7 +603,17 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| return gen_array_ops.shape(input, name: name, out_type: out_type); | |||||
| return tf.Context.ExecuteOp("Shape", name, new ExecuteOpArgs(input) | |||||
| { | |||||
| GetGradientAttrs = (op) => new | |||||
| { | |||||
| T = op.get_attr<TF_DataType>("T"), | |||||
| out_type = op.get_attr<TF_DataType>("out_type") | |||||
| } | |||||
| }.SetAttributes(new | |||||
| { | |||||
| out_type | |||||
| })).First(); | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -703,23 +713,26 @@ namespace Tensorflow | |||||
| int new_axis_mask = 0, | int new_axis_mask = 0, | ||||
| int shrink_axis_mask = 0, | int shrink_axis_mask = 0, | ||||
| string name = null) | string name = null) | ||||
| { | |||||
| var op = gen_array_ops.strided_slice( | |||||
| input: input_, | |||||
| begin: begin, | |||||
| end: end, | |||||
| strides: strides, | |||||
| begin_mask: begin_mask, | |||||
| end_mask: end_mask, | |||||
| ellipsis_mask: ellipsis_mask, | |||||
| new_axis_mask: new_axis_mask, | |||||
| shrink_axis_mask: shrink_axis_mask, | |||||
| name: name); | |||||
| string parent_name = name; | |||||
| return op; | |||||
| } | |||||
| => tf.Context.ExecuteOp("StridedSlice", name, new ExecuteOpArgs(input_, begin, end, strides) | |||||
| { | |||||
| GetGradientAttrs = (op) => new | |||||
| { | |||||
| T = op.get_attr<TF_DataType>("T"), | |||||
| Index = op.get_attr<TF_DataType>("Index"), | |||||
| begin_mask = op.get_attr<long>("begin_mask"), | |||||
| end_mask = op.get_attr<long>("end_mask"), | |||||
| ellipsis_mask = op.get_attr<long>("ellipsis_mask"), | |||||
| new_axis_mask = op.get_attr<long>("new_axis_mask"), | |||||
| shrink_axis_mask = op.get_attr<long>("shrink_axis_mask") | |||||
| } | |||||
| }.SetAttributes(new | |||||
| { | |||||
| begin_mask, | |||||
| end_mask, | |||||
| ellipsis_mask, | |||||
| new_axis_mask, | |||||
| shrink_axis_mask | |||||
| })); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the gradient of `StridedSlice`. | /// Returns the gradient of `StridedSlice`. | ||||
| @@ -5,12 +5,17 @@ | |||||
| /// </summary> | /// </summary> | ||||
| public class KerasTensor | public class KerasTensor | ||||
| { | { | ||||
| private Tensor _tensor; | |||||
| public void SetTensor(Tensors tensor) | |||||
| => _tensor = tensor; | |||||
| private Tensors _inferred_value; | |||||
| public Tensors inferred_value | |||||
| { | |||||
| get => _inferred_value; | |||||
| set => _inferred_value = value; | |||||
| } | |||||
| private TensorSpec _type_spec; | |||||
| private string _name; | private string _name; | ||||
| private TensorSpec _type_spec; | |||||
| public Shape shape => _type_spec.shape; | |||||
| public TF_DataType dtype => _type_spec.dtype; | |||||
| public KerasTensor(TensorSpec type_spec, string name = null) | public KerasTensor(TensorSpec type_spec, string name = null) | ||||
| { | { | ||||
| @@ -22,15 +27,23 @@ public class KerasTensor | |||||
| { | { | ||||
| var type_spec = tensor.ToTensorSpec(); | var type_spec = tensor.ToTensorSpec(); | ||||
| var kt = new KerasTensor(type_spec, name: tensor.name); | var kt = new KerasTensor(type_spec, name: tensor.name); | ||||
| kt.SetTensor(tensor); | |||||
| kt.inferred_value = tensor; | |||||
| return kt; | return kt; | ||||
| } | } | ||||
| public override string ToString() | |||||
| => _inferred_value.Length switch | |||||
| { | |||||
| > 1 => "[" + string.Join(", ", _inferred_value.Select(x => $"<KerasTensor: shape={x.shape} dtype={x.dtype}>")) + "]", | |||||
| 1 => $"<KerasTensor: shape={_inferred_value.shape} dtype={_inferred_value.dtype}>", | |||||
| _ => _inferred_value.ToString(), | |||||
| }; | |||||
| public static implicit operator Tensors(KerasTensor kt) | public static implicit operator Tensors(KerasTensor kt) | ||||
| => kt._tensor; | |||||
| => kt._inferred_value; | |||||
| public static implicit operator Tensor(KerasTensor kt) | public static implicit operator Tensor(KerasTensor kt) | ||||
| => kt._tensor; | |||||
| => kt._inferred_value; | |||||
| public static implicit operator KerasTensor(Tensor tensor) | public static implicit operator KerasTensor(Tensor tensor) | ||||
| => from_tensor(tensor); | => from_tensor(tensor); | ||||
| @@ -42,7 +42,7 @@ namespace Tensorflow | |||||
| array_ops.stack(args.End), | array_ops.stack(args.End), | ||||
| array_ops.stack(args.Strides)); | array_ops.stack(args.Strides)); | ||||
| return gen_array_ops.strided_slice( | |||||
| return array_ops.strided_slice( | |||||
| this, | this, | ||||
| packed_begin, | packed_begin, | ||||
| packed_end, | packed_end, | ||||