diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 06d48555..793318cb 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -99,6 +99,18 @@ namespace Tensorflow int axis = -1, string name = null) => array_ops.one_hot(indices, depth, dtype: dtype, axis: axis, name: name); + /// + /// Pads a tensor + /// + /// + /// + /// + /// + /// + /// + public Tensor pad(Tensor tensor, Tensor paddings, string mode = "CONSTANT", string name = null, int constant_values = 0) + => array_ops.pad(tensor, paddings, mode: mode, name: name, constant_values: constant_values); + /// /// A placeholder op that passes through `input` when its output is not fed. /// diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 92f65906..c85ac245 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -552,6 +552,40 @@ namespace Tensorflow throw new NotImplementedException("array_ops.stack"); } + public static Tensor pad(Tensor tensor, Tensor paddings, string mode = "CONSTANT", string name = null, int constant_values = 0) + { + Tensor result = null; + mode = mode.ToUpper(); + if(mode == "CONSTANT") + { + if (constant_values != 0) + throw new NotImplementedException("gen_array_ops.pad_v2"); + else + result = gen_array_ops.pad(tensor, paddings, name: name); + } + + // Restore shape information where possible. + var paddings_constant = tensor_util.constant_value( + result.op.inputs[1], partial: true); + var input_shape = result.op.inputs[0].TensorShape; + if (input_shape.ndim > -1 && + !result.TensorShape.is_fully_defined() && + !(paddings_constant is null)) + { + var new_shape = new List(); + foreach((NDArray padding, int dim) in zip(paddings_constant.GetNDArrays(), np.array(input_shape.dims).GetNDArrays())) + { + if (padding is null || dim == -1 || padding.GetData().Contains(-1)) + new_shape.Add(-1); + else + new_shape.Add(np.sum(padding) + dim); + } + result.set_shape(new_shape.ToArray()); + } + + return result; + } + public static Tensor placeholder(TF_DataType dtype) { throw new NotImplementedException("array_ops.placeholder"); diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 092d152c..61fa956b 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -99,6 +99,13 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor pad(Tensor input, Tensor paddings, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Pad", name: name, args: new { input, paddings }); + + return _op.output; + } + public static Tensor pack(Tensor[] values, int axis = 0, string name = null) { var _op = _op_def_lib._apply_op_helper("Pack", name: name, args: new { values, axis }); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs index 57105aa1..52aafa97 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs @@ -19,6 +19,9 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO if (downsample) { + (int pad_h, int pad_w) = ((int)Math.Floor((filters_shape[0] - 2) / 2.0f) + 1, (int)Math.Floor((filters_shape[1] - 2) / 2.0f) + 1); + var paddings = tf.constant(new int[,] { { 0, 0 }, { pad_h, pad_h }, { pad_w, pad_w }, { 0, 0 } }); + input_data = tf.pad(input_data, paddings, "CONSTANT"); throw new NotImplementedException(""); } else