diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 0525c6ec..1595e52f 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -26,7 +26,7 @@ namespace Tensorflow
public class nn_internal
{
- public Tensor conv2d(Tensor input, IVariableV1 filter, int[] strides, string padding, bool use_cudnn_on_gpu = true,
+ public Tensor conv2d(Tensor input, Tensor filter, int[] strides, string padding, bool use_cudnn_on_gpu = true,
string data_format = "NHWC", int[] dilations = null, string name = null)
{
var parameters = new Conv2dParams
diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
index 3d98854c..d461595b 100644
--- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
@@ -133,6 +133,46 @@ namespace Tensorflow.Gradients
-x_grad
};
}
+
+ /// <summary>
+ /// The derivatives for deconvolution.
+ /// </summary>
+ /// <param name="op">The Deconvolution op.</param>
+ /// <param name="grads">The tensor representing the gradient w.r.t. the output</param>
+ /// <returns>The gradients w.r.t. the input and the filter</returns>
+ [RegisterGradient("Conv2DBackpropInput")]
+ public static Tensor[] _Conv2DBackpropInputGrad(Operation op, Tensor[] grads)
+ {
+ var grad = grads[0];
+ var dilations = op.get_attr_list("dilations");
+ var strides = op.get_attr_list("strides");
+ var padding = op.get_attr("padding");
+ var explicit_paddings = op.get_attr_list("explicit_paddings");
+ var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu");
+ var data_format = op.get_attr("data_format");
+
+ return new Tensor[]
+ {
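+ // Gradient w.r.t. the filter (op.inputs[1]): Conv2DBackpropFilter of the incoming
+ // gradient against the original out_backprop (op.inputs[2]).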
+ gen_nn_ops.conv2d_backprop_filter(grad, array_ops.shape(op.inputs[1]), op.inputs[2],
+ strides, padding,
+ use_cudnn_on_gpu: use_cudnn_on_gpu,
+ explicit_paddings: explicit_paddings,
+ dilations: dilations,
+ data_format: data_format),
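+ // Gradient w.r.t. the incoming out_backprop (op.inputs[2]): a forward Conv2D of the
+ // gradient with the original filter.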
+ gen_nn_ops.conv2d(new Conv2dParams
+ {
+ Input = grad,
+ Filter = op.inputs[1],
+ Strides = strides,
+ Padding = padding,
+ DataFormat = data_format,
+ Dilations = dilations,
+ ExplicitPaddings = explicit_paddings,
+ UseCudnnOnGpu = use_cudnn_on_gpu
+ })
+ };
+ }
+
/// <summary>
/// Gradient function for Conv2D.
/// </summary>
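
For context, this gradient fires whenever backprop reaches a `Conv2DBackpropInput` node, which is what a transposed convolution lowers to. A minimal usage sketch, assuming this build exposes `tf.nn.conv2d_transpose`, `tf.GradientTape` and `tf.random.normal` with the usual TensorFlow-style signatures (argument order, shapes and names below are illustrative assumptions, not part of this patch):

```csharp
using Tensorflow;
using static Tensorflow.Binding;

// Differentiate a transposed convolution w.r.t. its filter. The tape records a
// Conv2DBackpropInput op; backprop then invokes the gradient registered above,
// emitting a Conv2DBackpropFilter op and a forward Conv2D op.
var x = tf.ones(new Shape(1, 8, 8, 3));
var filter = tf.Variable(tf.random.normal(new Shape(3, 3, 16, 3)));  // [h, w, out_ch, in_ch]
using var tape = tf.GradientTape();
// assumed argument order: value, filter, output_shape, strides, padding
var up = tf.nn.conv2d_transpose(x, filter, tf.constant(new[] { 1, 16, 16, 16 }),
    new[] { 1, 2, 2, 1 }, "SAME");
var loss = tf.reduce_sum(up);
var dFilter = tape.gradient(loss, filter);
```
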
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index 6c652040..2a982274 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -283,7 +283,7 @@ namespace Tensorflow
// This was causing duplicate graph node name errors when testing a conv2d autoencoder
// https://keras.io/guides/functional_api/#:~:text=keras.,graph%20(DAG)%20of%20layers.
// name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name);
- name = name.EndsWith("/") ? unique_name(ops.name_from_scope_name(name)) : unique_name(name);
+ name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name);
var node_def = ops._NodeDef(op_type, name, attrs: attrs);
var input_ops = inputs.Select(x => x.op).ToArray();
@@ -386,10 +386,6 @@ namespace Tensorflow
/// to name the operation being created.
public string unique_name(string name, bool mark_as_used = true)
{
- if (name.EndsWith("basic_r_n_n_cell"))
- {
-
- }
if (!String.IsNullOrEmpty(_name_stack))
name = _name_stack + "/" + name;
// For the sake of checking for names in use, we treat names as case
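
The restored line follows the upstream TensorFlow convention: a name ending in "/" is an explicit scope name produced by `name_scope` and is used verbatim (minus the trailing slash), while any other name is still uniquified per graph. A hypothetical illustration of the two branches (not part of the patch):

```csharp
var g = new Graph().as_default();

// Branch 1: an explicit scope name arrives with a trailing slash and is taken as-is,
//   create_op(..., name: "autoencoder/conv2d/")  =>  node named "autoencoder/conv2d"

// Branch 2: a plain name goes through unique_name and gets suffixed on collision:
var n1 = g.unique_name("conv2d");   // "conv2d"
var n2 = g.unique_name("conv2d");   // "conv2d_1"
```
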
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs b/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs
index ff594077..fa0d5bef 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs
@@ -42,7 +42,7 @@ namespace Tensorflow.Operations
/// <summary>
/// A 4-D tensor of shape
/// </summary>
- public IVariableV1 Filter { get; set; }
+ public Tensor Filter { get; set; }
/// <summary>
/// An integer vector representing the tensor shape of `filter`
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs b/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs
index 0e041836..dbf53988 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs
@@ -36,7 +36,7 @@ namespace Tensorflow.Operations
name = args.Name;
}
- public Tensor Apply(Tensors input, IVariableV1 filters)
+ public Tensor Apply(Tensors input, Tensor filters)
{
var filters_rank = filters.shape.ndim;
var inputs_rank = input.shape.ndim;
diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
index 7fc85dff..3ccf0c19 100644
--- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
+++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
@@ -60,201 +60,204 @@ namespace Tensorflow
object values = null;
g.as_default();
- var ret_op = tf_with(ops.name_scope(name), scope =>
- {
- var inferred_from = new Dictionary<string, object>();
- var base_types = new List<TF_DataType>();
- var types = new List<TF_DataType>();
- string _scope_name = scope;
- // Perform input type inference
- foreach (var (i, input_arg) in enumerate(op_def.InputArg))
- {
- var input_name = input_arg.Name;
+ var scope = ops.name_scope(name);
+ scope.__enter__();
+
+ var inferred_from = new Dictionary<string, object>();
+ var base_types = new List<TF_DataType>();
+ var types = new List<TF_DataType>();
+ string _scope_name = scope;
+
+ // Perform input type inference
+ foreach (var (i, input_arg) in enumerate(op_def.InputArg))
+ {
+ var input_name = input_arg.Name;
- if (keywords.ContainsKey(input_name))
- values = keywords[input_name];
- else if (keywords.ContainsKey(input_name + "_"))
- {
- input_name += "_";
- values = keywords[input_name];
- }
- else if (keywords.ContainsKey($"input_{i}"))
- {
- values = keywords[$"input_{i}"];
- }
- else
- throw new TypeError("No argument for input " + input_name);
-
- // Goals:
- // * Convert values to Tensors if it contains constants.
- // * Verify that values is a list if that matches the input_arg's
- // type.
- // * If the input_arg's type is determined by attrs, either set
- // those attrs and validate those attr values are legal (if
- // they have not yet been set) or validate the input matches
- // the type indicated by the attrs (if they have already been
- // inferred via an earlier input).
- // * If the input_arg has an explicit type, make sure the input
- // conforms.
-
- DataType dtype = DataType.DtInvalid;
- DataType default_dtype = DataType.DtInvalid;
-
- if (_IsListParameter(input_arg))
- {
- if (!_IsListValue(values))
- throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}.");
- if (input_arg.Type != DataType.DtInvalid)
- dtype = input_arg.Type;
- else if (!String.IsNullOrEmpty(input_arg.NumberAttr))
- {
- if (attrs.ContainsKey(input_arg.TypeAttr))
- dtype = (DataType)attrs[input_arg.TypeAttr];
- else
- switch (values)
- {
- case Tensor[] values1:
- dtype = values1[0].dtype.as_datatype_enum();
- break;
- case object[] values1:
- foreach (var t in values1)
- if (t is Tensor tensor)
- {
- dtype = tensor.dtype.as_datatype_enum();
- break;
- }
- break;
- default:
- throw new NotImplementedException($"can't infer the dtype for {values.GetType()}");
- }
-
- if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr))
- default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr];
- }
-
- if (!input_arg.IsRef && dtype != DataType.DtInvalid)
- dtype = dtype.as_base_dtype();
-
- values = ops.internal_convert_n_to_tensor(values as object[],
- name: input_arg.Name,
- dtype: dtype.as_tf_dtype(),
- preferred_dtype: default_dtype.as_tf_dtype(),
- as_ref: input_arg.IsRef);
- }
- else
+ if (keywords.ContainsKey(input_name))
+ values = keywords[input_name];
+ else if (keywords.ContainsKey(input_name + "_"))
+ {
+ input_name += "_";
+ values = keywords[input_name];
+ }
+ else if (keywords.ContainsKey($"input_{i}"))
+ {
+ values = keywords[$"input_{i}"];
+ }
+ else
+ throw new TypeError("No argument for input " + input_name);
+
+ // Goals:
+ // * Convert values to Tensors if it contains constants.
+ // * Verify that values is a list if that matches the input_arg's
+ // type.
+ // * If the input_arg's type is determined by attrs, either set
+ // those attrs and validate those attr values are legal (if
+ // they have not yet been set) or validate the input matches
+ // the type indicated by the attrs (if they have already been
+ // inferred via an earlier input).
+ // * If the input_arg has an explicit type, make sure the input
+ // conforms.
+
+ DataType dtype = DataType.DtInvalid;
+ DataType default_dtype = DataType.DtInvalid;
+
+ if (_IsListParameter(input_arg))
+ {
+ if (!_IsListValue(values))
+ throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}.");
+ if (input_arg.Type != DataType.DtInvalid)
+ dtype = input_arg.Type;
+ else if (!String.IsNullOrEmpty(input_arg.NumberAttr))
{
- if (input_arg.Type != DataType.DtInvalid)
- dtype = input_arg.Type;
- else if (attrs.ContainsKey(input_arg.TypeAttr))
+ if (attrs.ContainsKey(input_arg.TypeAttr))
dtype = (DataType)attrs[input_arg.TypeAttr];
- else if (isinstance(values, typeof(string)) && dtype == DataType.DtInvalid)
- dtype = DataType.DtString;
- else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr))
+ else
+ switch (values)
+ {
+ case Tensor[] values1:
+ dtype = values1[0].dtype.as_datatype_enum();
+ break;
+ case object[] values1:
+ foreach (var t in values1)
+ if (t is Tensor tensor)
+ {
+ dtype = tensor.dtype.as_datatype_enum();
+ break;
+ }
+ break;
+ default:
+ throw new NotImplementedException($"can't infer the dtype for {values.GetType()}");
+ }
+
+ if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr))
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr];
+ }
- var value = ops.convert_to_tensor(values,
- name: input_name,
- dtype: dtype.as_tf_dtype(),
- as_ref: input_arg.IsRef,
- preferred_dtype: default_dtype.as_tf_dtype());
-
- //if (!String.IsNullOrEmpty(input_arg.TypeAttr))
- //attrs[input_arg.TypeAttr] = values.dtype;
+ if (!input_arg.IsRef && dtype != DataType.DtInvalid)
+ dtype = dtype.as_base_dtype();
- values = new Tensor[] { value };
- }
+ values = ops.internal_convert_n_to_tensor(values as object[],
+ name: input_arg.Name,
+ dtype: dtype.as_tf_dtype(),
+ preferred_dtype: default_dtype.as_tf_dtype(),
+ as_ref: input_arg.IsRef);
+ }
+ else
+ {
+ if (input_arg.Type != DataType.DtInvalid)
+ dtype = input_arg.Type;
+ else if (attrs.ContainsKey(input_arg.TypeAttr))
+ dtype = (DataType)attrs[input_arg.TypeAttr];
+ else if (isinstance(values, typeof(string)) && dtype == DataType.DtInvalid)
+ dtype = DataType.DtString;
+ else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr))
+ default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr];
+
+ var value = ops.convert_to_tensor(values,
+ name: input_name,
+ dtype: dtype.as_tf_dtype(),
+ as_ref: input_arg.IsRef,
+ preferred_dtype: default_dtype.as_tf_dtype());
+
+ //if (!String.IsNullOrEmpty(input_arg.TypeAttr))
+ //attrs[input_arg.TypeAttr] = values.dtype;
+
+ values = new Tensor[] { value };
+ }
- if (values is Tensor[] values2)
- {
- types = values2.Select(x => x.dtype).ToList();
- inputs.AddRange(values2);
- base_types = values2.Select(x => x.dtype.as_base_dtype()).ToList();
- }
- else throw new NotImplementedException("_IsListParameter");
-
- SetAttrs(op_type_name,
- input_arg,
- op_def,
- attrs,
- inferred_from,
- types,
- base_types,
- input_types,
- values);
+ if (values is Tensor[] values2)
+ {
+ types = values2.Select(x => x.dtype).ToList();
+ inputs.AddRange(values2);
+ base_types = values2.Select(x => x.dtype.as_base_dtype()).ToList();
}
+ else throw new NotImplementedException("_IsListParameter");
+
+ SetAttrs(op_type_name,
+ input_arg,
+ op_def,
+ attrs,
+ inferred_from,
+ types,
+ base_types,
+ input_types,
+ values);
+ }
- // Process remaining attrs
- foreach (var attr in op_def.Attr)
+ // Process remaining attrs
+ foreach (var attr in op_def.Attr)
+ {
+ if (keywords.ContainsKey(attr.Name))
{
- if (keywords.ContainsKey(attr.Name))
- {
- attrs[attr.Name] = keywords[attr.Name];
- }
+ attrs[attr.Name] = keywords[attr.Name];
}
+ }
- // Convert attr values to AttrValue protos.
- var attr_protos = new Dictionary<string, AttrValue>();
- foreach (AttrDef attr_def in op_def.Attr)
+ // Convert attr values to AttrValue protos.
+ var attr_protos = new Dictionary<string, AttrValue>();
+ foreach (AttrDef attr_def in op_def.Attr)
+ {
+ var key = attr_def.Name;
+ if (attrs.ContainsKey(key))
{
- var key = attr_def.Name;
- if (attrs.ContainsKey(key))
- {
- attr_protos[key] = SetAttrValue(op_def, attr_def, attrs[key]);
- }
- else
+ attr_protos[key] = SetAttrValue(op_def, attr_def, attrs[key]);
+ }
+ else
+ {
+ if (attr_def.DefaultValue == null)
{
- if (attr_def.DefaultValue == null)
- {
- throw new TypeError("Missing required positional argument " + key);
- }
+ throw new TypeError("Missing required positional argument " + key);
}
}
+ }
- attrs.Clear();
+ attrs.Clear();
- // Determine output types (possibly using attrs)
- var output_types = new List<TF_DataType>();
+ // Determine output types (possibly using attrs)
+ var output_types = new List<TF_DataType>();
- foreach (var arg in op_def.OutputArg)
+ foreach (var arg in op_def.OutputArg)
+ {
+ types = new List<TF_DataType>();
+ if (!string.IsNullOrEmpty(arg.NumberAttr))
{
- types = new List<TF_DataType>();
- if (!string.IsNullOrEmpty(arg.NumberAttr))
- {
- }
- else if (!string.IsNullOrEmpty(arg.TypeAttr))
- {
- types = new List<TF_DataType>() { (TF_DataType)attr_protos[arg.TypeAttr].Type };
- }
+ }
+ else if (!string.IsNullOrEmpty(arg.TypeAttr))
+ {
+ types = new List<TF_DataType>() { (TF_DataType)attr_protos[arg.TypeAttr].Type };
+ }
- if (arg.IsRef)
- types = types.Select(x => x.as_ref()).ToList();
+ if (arg.IsRef)
+ types = types.Select(x => x.as_ref()).ToList();
- output_types.AddRange(types);
- }
+ output_types.AddRange(types);
+ }
+
+ // We add an explicit colocation constraint between
+ // the newly created op and any of its reference-typed inputs.
+ var must_colocate_inputs = zip(op_def.InputArg, inputs)
+ .Where(x => x.Item1.IsRef)
+ .Select(x => x.Item2)
+ .ToArray();
+
+ _MaybeColocateWith(must_colocate_inputs);
+
+ // Add Op to graph
+ var ret_op = g.create_op(op_type_name,
+ inputs.ToArray(),
+ output_types.ToArray(),
+ name: _scope_name,
+ input_types: input_types.ToArray(),
+ attrs: attr_protos,
+ op_def: op_def);
+
+ scope.__exit__();
- // We add an explicit colocation constraint between
- // the newly created op and any of its reference-typed inputs.
- var must_colocate_inputs = zip(op_def.InputArg, inputs)
- .Where(x => x.Item1.IsRef)
- .Select(x => x.Item2)
- .ToArray();
-
- _MaybeColocateWith(must_colocate_inputs);
-
- // Add Op to graph
- var op = g.create_op(op_type_name,
- inputs.ToArray(),
- output_types.ToArray(),
- name: _scope_name,
- input_types: input_types.ToArray(),
- attrs: attr_protos,
- op_def: op_def);
-
- return op;
- });
g.Exit();
+
return ret_op;
}
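
For reference, `tf_with` is just sugar over the same `__enter__`/`__exit__` protocol that the unrolled body now calls directly; roughly (a sketch under the assumption that `Binding.tf_with` wraps the callback in try/finally as in upstream TensorFlow.NET; the helper name below is made up):

```csharp
// Sketch of the protocol the unrolled code relies on; not the actual Binding source.
public static TOut with_scope<TIn, TOut>(TIn scope, Func<TIn, TOut> body)
    where TIn : ITensorFlowObject
{
    try
    {
        scope.__enter__();
        return body(scope);
    }
    finally
    {
        scope.__exit__();
    }
}
```

Note that the unrolled form above only calls `scope.__exit__()` after `create_op` returns, so an exception thrown while building the op would leave the name scope entered.
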
diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
index 3c5e0d5d..218d1369 100644
--- a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
+++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
@@ -103,7 +103,7 @@ namespace Tensorflow.Keras.Layers
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false)
{
- var outputs = _convolution_op.Apply(inputs, kernel);
+ var outputs = _convolution_op.Apply(inputs, kernel.AsTensor());
if (use_bias)
{
if (data_format == "channels_first")
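
With the convolution entry points now taking `Tensor` filters (see `tf.nn.conv2d` above), a variable-backed kernel has to be materialized explicitly, which is what `kernel.AsTensor()` does at this call site. A minimal caller-side sketch, assuming the usual `Binding` helpers (`tf.ones`, `tf.random.normal`, `tf.Variable`); shapes are illustrative:

```csharp
using Tensorflow;
using static Tensorflow.Binding;

var images = tf.ones(new Shape(1, 28, 28, 3));                        // NHWC input batch
var kernel = tf.Variable(tf.random.normal(new Shape(3, 3, 3, 16)));   // [h, w, in_ch, out_ch]

// conv2d now expects a Tensor filter, so the variable is converted explicitly.
var y = tf.nn.conv2d(images, kernel.AsTensor(),
    strides: new[] { 1, 1, 1, 1 }, padding: "SAME");
```
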