Browse Source

add tf.random_normal_initializer

tags/v0.12
Oceania2018 6 years ago
parent
commit
e8ff9f03d3
3 changed files with 72 additions and 0 deletions
  1. +8
    -0
      src/TensorFlowNET.Core/APIs/tf.init.cs
  2. +51
    -0
      src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
  3. +13
    -0
      src/TensorFlowNET.Core/Operations/nn_ops.cs

+ 8
- 0
src/TensorFlowNET.Core/APIs/tf.init.cs View File

@@ -52,5 +52,13 @@ namespace Tensorflow
stddev: stddev,
seed: seed,
dtype: dtype);

public IInitializer random_normal_initializer(float mean = 0.0f,
float stddev = 1.0f,
int? seed = null,
TF_DataType dtype = TF_DataType.DtInvalid) => new RandomNormal(mean: mean,
stddev: stddev,
seed: seed,
dtype: dtype);
}
}

+ 51
- 0
src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs View File

@@ -0,0 +1,51 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations.Initializers
{
public class RandomNormal : IInitializer
{
private float mean;
private float stddev;
private int? seed;
private TF_DataType dtype;

public RandomNormal(float mean = 0.0f,
float stddev = 1.0f,
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT)
{
this.mean = mean;
this.stddev = stddev;
this.seed = seed;
this.dtype = dtype;
}

public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
{
throw new NotImplementedException();
}

public object get_config()
{
throw new NotImplementedException();
}
}
}

+ 13
- 0
src/TensorFlowNET.Core/Operations/nn_ops.cs View File

@@ -116,6 +116,19 @@ namespace Tensorflow
return _softmax(logits, gen_nn_ops.log_softmax, axis, name);
}

public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null)
{
return tf_with(ops.name_scope(name, "LeakyRelu", new { features, alpha }), scope =>
{
name = scope;
features = ops.convert_to_tensor(features, name: "features");
if (features.dtype.is_integer())
features = math_ops.cast(features, dtypes.float32);
return gen_nn_ops.leaky_relu(features, alpha: alpha, name: name);
//return math_ops.maximum(alpha * features, features, name: name);
});
}

/// <summary>
/// Performs the max pooling on the input.
/// </summary>


Loading…
Cancel
Save