diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index ea52ab57..e9805010 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -115,8 +115,8 @@ namespace Tensorflow public Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name); public Tensor[] fused_batch_norm(Tensor x, - RefVariable scale, - RefVariable offset, + VariableV1 scale, + VariableV1 offset, Tensor mean = null, Tensor variance = null, float epsilon = 0.001f, diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 530ca76c..0428b2ad 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -37,8 +37,8 @@ namespace Tensorflow.Keras.Layers private IInitializer gamma_initializer; private IInitializer moving_mean_initializer; private IInitializer moving_variance_initializer; - private RefVariable gamma; - private RefVariable beta; + private VariableV1 gamma; + private VariableV1 beta; private RefVariable moving_mean; private RefVariable moving_variance; @@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers var param_shape = new int[] { input_shape.dims[axis[0]] }; if (scale) - gamma = (RefVariable)add_weight("gamma", + gamma = add_weight("gamma", param_shape, dtype: param_dtype, initializer: gamma_initializer, @@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Layers throw new NotImplementedException("add_weight gamma"); if (center) - beta = (RefVariable)add_weight("beta", + beta = add_weight("beta", param_shape, dtype: param_dtype, initializer: beta_initializer, diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs index d4feebc2..c7839b04 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs @@ -32,36 +32,21 @@ namespace Tensorflow.Keras.Utils /// /// /// - public static RefVariable make_variable(string name, + public static VariableV1 make_variable(string name, int[] shape, TF_DataType dtype = TF_DataType.TF_FLOAT, IInitializer initializer = null, - bool trainable = true) => make_variable(name, shape, dtype, initializer, trainable, true); - - /// - /// Adds a new variable to the layer. - /// - /// - /// - /// - /// - /// - /// - public static RefVariable make_variable(string name, - int[] shape, - TF_DataType dtype = TF_DataType.TF_FLOAT, - IInitializer initializer = null, - bool trainable = true, - bool use_resource = true) + bool trainable = true) { var initializing_from_value = false; + bool use_resource = true; ops.init_scope(); Func init_val = () => initializer.call(new TensorShape(shape), dtype: dtype); var variable_dtype = dtype.as_base_dtype(); - var v = tf.Variable(init_val); + var v = tf.VariableV1(init_val); return v; } diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs index bced0047..42103b00 100644 --- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs @@ -97,9 +97,9 @@ namespace Tensorflow /// /// /// - public static Tensor[] fused_batch_norm(Tensor x, - RefVariable scale, - RefVariable offset, + public static Tensor[] fused_batch_norm(Tensor x, + VariableV1 scale, + VariableV1 offset, Tensor mean, Tensor variance, float epsilon = 0.001f, diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index cf973864..4b260632 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -58,6 +58,21 @@ namespace Tensorflow dtype: dtype); } + public VariableV1 VariableV1(T data, + bool trainable = true, + bool validate_shape = true, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + bool use_resource = false) + { + return Tensorflow.variable_scope.default_variable_creator(data, + trainable: trainable, + validate_shape: validate_shape, + name: name, + dtype: dtype, + use_resource: use_resource); + } + public unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) { return gen_array_ops.placeholder(dtype, shape, name);