Browse Source

fix return value issue for weight_variable and bias_variable.

tags/v0.10
Oceania2018 6 years ago
parent
commit
051855e839
2 changed files with 10 additions and 9 deletions
  1. +8
    -6
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  2. +2
    -3
      test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs

+ 8
- 6
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -30,19 +30,21 @@ namespace Tensorflow
public static Tensor conv2d(Tensor input, RefVariable filter, int[] strides, string padding, bool use_cudnn_on_gpu = true, public static Tensor conv2d(Tensor input, RefVariable filter, int[] strides, string padding, bool use_cudnn_on_gpu = true,
string data_format= "NHWC", int[] dilations= null, string name = null) string data_format= "NHWC", int[] dilations= null, string name = null)
{ {
if (dilations == null)
dilations = new[] { 1, 1, 1, 1 };

return gen_nn_ops.conv2d(new Conv2dParams
var parameters = new Conv2dParams
{ {
Input = input, Input = input,
Filter = filter, Filter = filter,
Strides = strides, Strides = strides,
Padding = padding,
UseCudnnOnGpu = use_cudnn_on_gpu, UseCudnnOnGpu = use_cudnn_on_gpu,
DataFormat = data_format, DataFormat = data_format,
Dilations = dilations,
Name = name Name = name
});
};

if (dilations != null)
parameters.Dilations = dilations;

return gen_nn_ops.conv2d(parameters);
} }


/// <summary> /// <summary>


+ 2
- 3
test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs View File

@@ -158,7 +158,6 @@ namespace TensorFlowNET.Examples.ImageProcess
layer += b; layer += b;
return tf.nn.relu(layer); return tf.nn.relu(layer);
}); });

} }


/// <summary> /// <summary>
@@ -195,7 +194,7 @@ namespace TensorFlowNET.Examples.ImageProcess
}); });
} }


private Tensor weight_variable(string name, int[] shape)
private RefVariable weight_variable(string name, int[] shape)
{ {
var initer = tf.truncated_normal_initializer(stddev: 0.01f); var initer = tf.truncated_normal_initializer(stddev: 0.01f);
return tf.get_variable(name, return tf.get_variable(name,
@@ -210,7 +209,7 @@ namespace TensorFlowNET.Examples.ImageProcess
/// <param name="name"></param> /// <param name="name"></param>
/// <param name="shape"></param> /// <param name="shape"></param>
/// <returns></returns> /// <returns></returns>
private Tensor bias_variable(string name, int[] shape)
private RefVariable bias_variable(string name, int[] shape)
{ {
var initial = tf.constant(0f, shape: shape, dtype: tf.float32); var initial = tf.constant(0f, shape: shape, dtype: tf.float32);
return tf.get_variable(name, return tf.get_variable(name,


Loading…
Cancel
Save