Browse Source

Gradient function for Conv2D

tags/v0.9
Oceania2018 6 years ago
parent
commit
13189451f7
5 changed files with 122 additions and 2 deletions
  1. +44
    -0
      src/TensorFlowNET.Core/Gradients/nn_grad.cs
  2. +16
    -0
      src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs
  3. +46
    -0
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  4. +13
    -0
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  5. +3
    -2
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

+ 44
- 0
src/TensorFlowNET.Core/Gradients/nn_grad.cs View File

@@ -97,6 +97,50 @@ namespace Tensorflow.Gradients
}; };
} }


/// <summary>
/// Gradient function for Conv2D.
/// </summary>
/// <param name="op"></param>
/// <param name="grads"></param>
/// <returns></returns>
[RegisterGradient("Conv2D")]
public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
{
var dilations = op.get_attr("dilations");
var strides = op.get_attr("strides");
var padding = op.get_attr("padding");
var explicit_paddings = op.get_attr("explicit_paddings");
var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu");
var data_format = op.get_attr("data_format");
var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });
return new Tensor[]
{
gen_nn_ops.conv2d_backprop_input(new Conv2dParams
{
InputSizes = shape[0],
Filter = op.inputs[1],
Dilations = dilations == null ? null : dilations as int[],
Strides = strides == null ? null : strides as int[],
Padding = padding.ToString(),
ExplicitPaddings = explicit_paddings == null ? null : explicit_paddings as int[],
UseCudnnOnGpu = (bool)use_cudnn_on_gpu,
DataFormat = data_format.ToString()
}),
gen_nn_ops.conv2d_backprop_filter(new Conv2dParams
{
Input = op.inputs[0],
FilterSizes = shape[1],
Dilations = dilations == null ? null : dilations as int[],
Strides = strides == null ? null : strides as int[],
Padding = padding.ToString(),
ExplicitPaddings = explicit_paddings == null ? null : explicit_paddings as int[],
UseCudnnOnGpu = (bool)use_cudnn_on_gpu,
DataFormat = data_format.ToString()
})
};
}

private static bool IsZero(Tensor g) private static bool IsZero(Tensor g)
{ {
if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type)) if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type))


+ 16
- 0
src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs View File

@@ -22,11 +22,27 @@ namespace Tensorflow.Operations
/// </summary> /// </summary>
public Tensor Input { get; set; } public Tensor Input { get; set; }


/// <summary>
/// An integer vector representing the shape of `input`
/// </summary>
public Tensor InputSizes { get; set; }

/// <summary> /// <summary>
/// A 4-D tensor of shape /// A 4-D tensor of shape
/// </summary> /// </summary>
public Tensor Filter { get; set; } public Tensor Filter { get; set; }


/// <summary>
/// An integer vector representing the tensor shape of `filter`
/// </summary>
public Tensor FilterSizes { get; set; }

/// <summary>
/// A `Tensor`. Must have the same type as `filter`.
/// 4-D with shape `[batch, out_height, out_width, out_channels]`.
/// </summary>
public Tensor OutBackProp { get; set; }

/// <summary> /// <summary>
/// The stride of the sliding window for each /// The stride of the sliding window for each
/// dimension of `input`. The dimension order is determined by the value of /// dimension of `input`. The dimension order is determined by the value of


+ 46
- 0
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -43,6 +43,52 @@ namespace Tensorflow.Operations
}); });


return _op.outputs[0]; return _op.outputs[0];
}
/// <summary>
/// Computes the gradients of convolution with respect to the filter.
/// </summary>
/// <param name="parameters"></param>
/// <returns></returns>
public static Tensor conv2d_backprop_filter(Conv2dParams parameters)
{
var _op = _op_def_lib._apply_op_helper("Conv2DBackpropFilter", name: parameters.Name, args: new
{
input = parameters.Input,
filter_sizes = parameters.FilterSizes,
out_backprop = parameters.OutBackProp,
strides = parameters.Strides,
padding = parameters.Padding,
use_cudnn_on_gpu = parameters.UseCudnnOnGpu,
explicit_paddings = parameters.ExplicitPaddings,
data_format = parameters.DataFormat,
dilations = parameters.Dilations
});

return _op.outputs[0];
}

/// <summary>
/// Computes the gradients of convolution with respect to the input.
/// </summary>
/// <param name="parameters"></param>
/// <returns></returns>
public static Tensor conv2d_backprop_input(Conv2dParams parameters)
{
var _op = _op_def_lib._apply_op_helper("Conv2DBackpropInput", name: parameters.Name, args: new
{
input_sizes = parameters.InputSizes,
filter = parameters.Filter,
out_backprop = parameters.OutBackProp,
strides = parameters.Strides,
padding = parameters.Padding,
use_cudnn_on_gpu = parameters.UseCudnnOnGpu,
explicit_paddings = parameters.ExplicitPaddings,
data_format = parameters.DataFormat,
dilations = parameters.Dilations
});

return _op.outputs[0];
} }


public static Tensor bias_add(Tensor value, public static Tensor bias_add(Tensor value,


+ 13
- 0
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -252,6 +252,19 @@ namespace Tensorflow
return _op.outputs[0]; return _op.outputs[0];
} }


/// <summary>
/// Returns shape of tensors.
/// </summary>
/// <param name="input"></param>
/// <param name="out_type"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor[] shape_n(Tensor[] input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
{
var _op = _op_def_lib._apply_op_helper("ShapeN", name, new { input, out_type });
return _op.outputs;
}

public static Tensor size(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) public static Tensor size(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("Size", name, new { input, out_type }); var _op = _op_def_lib._apply_op_helper("Size", name, new { input, out_type });


+ 3
- 2
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -20,8 +20,9 @@ Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.8.1.0</AssemblyVersion> <AssemblyVersion>0.8.1.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.8: <PackageReleaseNotes>Changes since v0.8:


1. Removed global static graph instance.
2. Provide custom gradient function.</PackageReleaseNotes>
1. Remove global static graph instance.
2. Provide custom gradient function.
3. Add gradient function for Conv2D.</PackageReleaseNotes>
<LangVersion>7.2</LangVersion> <LangVersion>7.2</LangVersion>
<FileVersion>0.8.1.0</FileVersion> <FileVersion>0.8.1.0</FileVersion>
</PropertyGroup> </PropertyGroup>


Loading…
Cancel
Save