diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index ea52ab57..e9805010 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -115,8 +115,8 @@ namespace Tensorflow
public Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name);
public Tensor[] fused_batch_norm(Tensor x,
- RefVariable scale,
- RefVariable offset,
+ VariableV1 scale,
+ VariableV1 offset,
Tensor mean = null,
Tensor variance = null,
float epsilon = 0.001f,
diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
index 530ca76c..0428b2ad 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
@@ -37,8 +37,8 @@ namespace Tensorflow.Keras.Layers
private IInitializer gamma_initializer;
private IInitializer moving_mean_initializer;
private IInitializer moving_variance_initializer;
- private RefVariable gamma;
- private RefVariable beta;
+ private VariableV1 gamma;
+ private VariableV1 beta;
private RefVariable moving_mean;
private RefVariable moving_variance;
@@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers
var param_shape = new int[] { input_shape.dims[axis[0]] };
if (scale)
- gamma = (RefVariable)add_weight("gamma",
+ gamma = add_weight("gamma",
param_shape,
dtype: param_dtype,
initializer: gamma_initializer,
@@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Layers
throw new NotImplementedException("add_weight gamma");
if (center)
- beta = (RefVariable)add_weight("beta",
+ beta = add_weight("beta",
param_shape,
dtype: param_dtype,
initializer: beta_initializer,
diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
index d4feebc2..c7839b04 100644
--- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
+++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
@@ -32,36 +32,21 @@ namespace Tensorflow.Keras.Utils
///
///
///
- public static RefVariable make_variable(string name,
+ public static VariableV1 make_variable(string name,
int[] shape,
TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null,
- bool trainable = true) => make_variable(name, shape, dtype, initializer, trainable, true);
-
- ///
- /// Adds a new variable to the layer.
- ///
- ///
- ///
- ///
- ///
- ///
- ///
- public static RefVariable make_variable(string name,
- int[] shape,
- TF_DataType dtype = TF_DataType.TF_FLOAT,
- IInitializer initializer = null,
- bool trainable = true,
- bool use_resource = true)
+ bool trainable = true)
{
var initializing_from_value = false;
+ bool use_resource = true;
ops.init_scope();
Func init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);
var variable_dtype = dtype.as_base_dtype();
- var v = tf.Variable(init_val);
+ var v = tf.VariableV1(init_val);
return v;
}
diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
index bced0047..42103b00 100644
--- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
@@ -97,9 +97,9 @@ namespace Tensorflow
///
///
///
- public static Tensor[] fused_batch_norm(Tensor x,
- RefVariable scale,
- RefVariable offset,
+ public static Tensor[] fused_batch_norm(Tensor x,
+ VariableV1 scale,
+ VariableV1 offset,
Tensor mean,
Tensor variance,
float epsilon = 0.001f,
diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs
index cf973864..4b260632 100644
--- a/src/TensorFlowNET.Core/tensorflow.cs
+++ b/src/TensorFlowNET.Core/tensorflow.cs
@@ -58,6 +58,21 @@ namespace Tensorflow
dtype: dtype);
}
+ public VariableV1 VariableV1(T data,
+ bool trainable = true,
+ bool validate_shape = true,
+ string name = null,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ bool use_resource = false)
+ {
+ return Tensorflow.variable_scope.default_variable_creator(data,
+ trainable: trainable,
+ validate_shape: validate_shape,
+ name: name,
+ dtype: dtype,
+ use_resource: use_resource);
+ }
+
public unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)
{
return gen_array_ops.placeholder(dtype, shape, name);