diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index c9ad6402..f8746f7f 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -75,9 +75,9 @@ namespace Tensorflow
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
///
/// https://www.tensorflow.org/guide/graphs
https://www.tensorflow.org/api_docs/python/tf/Graph
- public partial class Graph : DisposableObject,
+ public partial class Graph : DisposableObject
#if !SERIALIZABLE
- IEnumerable
+ ,IEnumerable
#endif
{
private Dictionary _nodes_by_id;
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
index 77bf68a1..abe8e9c1 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
@@ -60,6 +60,9 @@ namespace Tensorflow
///
/// List this operation's output types.
///
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public TF_DataType[] _output_types
{
get
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index 0f9ed2eb..359dc870 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -78,7 +78,10 @@ namespace Tensorflow
#if SERIALIZABLE
[JsonIgnore]
#endif
- bool _is_stateful;
+ bool _is_stateful;
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public NodeDef node_def
{
get
@@ -181,8 +184,8 @@ namespace Tensorflow
// This will be set by self.inputs.
if (op_def == null)
- op_def = g.GetOpDef(node_def.Op);
-
+ op_def = g.GetOpDef(node_def.Op);
+
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
_is_stateful = op_def.IsStateful;
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
index d916f624..26c251b0 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
@@ -79,6 +79,8 @@ namespace Tensorflow
}
strides.Add(s.Step);
+ if (s.IsIndex)
+ shrink_axis_mask |= (1 << index);
}
index += 1;
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
index 5a3d6f79..2a901e07 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
@@ -16,6 +16,7 @@
using NumSharp;
using System;
+using System.Collections.Generic;
using System.Linq;
using System.Numerics;
using static Tensorflow.Binding;
@@ -25,7 +26,7 @@ namespace Tensorflow
public partial class Tensor
{
#if _REGEN
- #region Compute
+ #region Compute
%operators = ["add", "sub", "mul", "div", "mod"]
%operators_sign = ["+", "-", "*", "/", "%"]
%operators_comparers = [">", "<", ">=", "<="]
@@ -49,11 +50,11 @@ namespace Tensorflow
%
%
public static Tensor operator -(Tensor x) => gen_math_ops.neg(x);
- #endregion
+ #endregion
#else
- #region Compute
+ #region Compute
+
-
public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs);
public static Tensor operator +(Tensor lhs, NDArray rhs) => BinaryOpWrapper("add", lhs, rhs);
public static Tensor operator +(NDArray lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs);
@@ -129,31 +130,31 @@ namespace Tensorflow
public static Tensor operator *(double lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs);
public static Tensor operator *(Tensor lhs, Complex rhs) => BinaryOpWrapper("mul", lhs, rhs);
public static Tensor operator *(Complex lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs);
- public static Tensor operator /(Tensor lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(Tensor lhs, NDArray rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(NDArray lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(Tensor lhs, sbyte rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(sbyte lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(Tensor lhs, byte rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(byte lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(Tensor lhs, short rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(short lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(Tensor lhs, ushort rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(ushort lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(Tensor lhs, int rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(int lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(Tensor lhs, uint rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(uint lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(Tensor lhs, ulong rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(ulong lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(Tensor lhs, long rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(long lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(Tensor lhs, float rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(float lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(Tensor lhs, double rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(double lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(Tensor lhs, Complex rhs) => BinaryOpWrapper("truediv", lhs, rhs);
- public static Tensor operator /(Complex lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
+ public static Tensor operator /(Tensor lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(Tensor lhs, NDArray rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(NDArray lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(Tensor lhs, sbyte rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(sbyte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(Tensor lhs, byte rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(byte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(Tensor lhs, short rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(short lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(Tensor lhs, ushort rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(ushort lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(Tensor lhs, int rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(int lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(Tensor lhs, uint rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(uint lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(Tensor lhs, ulong rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(ulong lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(Tensor lhs, long rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(long lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(Tensor lhs, float rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(float lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(Tensor lhs, double rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(double lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(Tensor lhs, Complex rhs) => BinaryOpWrapper("div", lhs, rhs);
+ public static Tensor operator /(Complex lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator %(Tensor lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs);
public static Tensor operator %(Tensor lhs, NDArray rhs) => BinaryOpWrapper("mod", lhs, rhs);
public static Tensor operator %(NDArray lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs);
@@ -281,24 +282,43 @@ namespace Tensorflow
public static Tensor operator <=(Tensor lhs, Complex rhs) => gen_math_ops.less_equal(lhs, rhs);
public static Tensor operator <=(Complex lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs);
public static Tensor operator -(Tensor x) => gen_math_ops.neg(x);
- #endregion
+ #endregion
#endif
-
+
private static readonly TF_DataType[] _intTfDataTypes = {
TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64,
TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32,
TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64
};
+ private static string div_or_truediv(string name, Tx x, Ty y)
+ {
+ bool is_floating = false;
+ var types = new List();
+
+ if (x is Tensor t1)
+ types.add(t1.dtype.is_floating());
+
+ if (y is Tensor t2)
+ types.add(t2.dtype.is_floating());
+
+ is_floating = types.Contains(true);
+
+ return is_floating ? "truediv" : name;
+ }
+
private static Tensor BinaryOpWrapper(string name, Tx x, Ty y)
{
TF_DataType dtype = TF_DataType.DtInvalid;
-
+
if (x is Tensor tl)
dtype = tl.dtype.as_base_dtype();
if (y is Tensor tr)
dtype = tr.dtype.as_base_dtype();
+ if (name == "div")
+ name = div_or_truediv(name, x, y);
+
return tf_with(ops.name_scope(null, name, new { x, y }), scope =>
{
Tensor result;
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 9f505419..67474eb9 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -102,6 +102,9 @@ namespace Tensorflow
[JsonIgnore]
#endif
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
#if SERIALIZABLE