diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index c9ad6402..f8746f7f 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -75,9 +75,9 @@ namespace Tensorflow /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. /// /// https://www.tensorflow.org/guide/graphs

https://www.tensorflow.org/api_docs/python/tf/Graph
- public partial class Graph : DisposableObject, + public partial class Graph : DisposableObject #if !SERIALIZABLE - IEnumerable + ,IEnumerable #endif { private Dictionary _nodes_by_id; diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index 77bf68a1..abe8e9c1 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -60,6 +60,9 @@ namespace Tensorflow /// /// List this operation's output types. /// +#if SERIALIZABLE + [JsonIgnore] +#endif public TF_DataType[] _output_types { get diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 0f9ed2eb..359dc870 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -78,7 +78,10 @@ namespace Tensorflow #if SERIALIZABLE [JsonIgnore] #endif - bool _is_stateful; + bool _is_stateful; +#if SERIALIZABLE + [JsonIgnore] +#endif public NodeDef node_def { get @@ -181,8 +184,8 @@ namespace Tensorflow // This will be set by self.inputs. if (op_def == null) - op_def = g.GetOpDef(node_def.Op); - + op_def = g.GetOpDef(node_def.Op); + var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); _is_stateful = op_def.IsStateful; diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs index d916f624..26c251b0 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs @@ -79,6 +79,8 @@ namespace Tensorflow } strides.Add(s.Step); + if (s.IsIndex) + shrink_axis_mask |= (1 << index); } index += 1; diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index 5a3d6f79..2a901e07 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -16,6 +16,7 @@ using NumSharp; using System; +using System.Collections.Generic; using System.Linq; using System.Numerics; using static Tensorflow.Binding; @@ -25,7 +26,7 @@ namespace Tensorflow public partial class Tensor { #if _REGEN - #region Compute + #region Compute %operators = ["add", "sub", "mul", "div", "mod"] %operators_sign = ["+", "-", "*", "/", "%"] %operators_comparers = [">", "<", ">=", "<="] @@ -49,11 +50,11 @@ namespace Tensorflow % % public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); - #endregion + #endregion #else - #region Compute + #region Compute + - public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); public static Tensor operator +(Tensor lhs, NDArray rhs) => BinaryOpWrapper("add", lhs, rhs); public static Tensor operator +(NDArray lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); @@ -129,31 +130,31 @@ namespace Tensorflow public static Tensor operator *(double lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); public static Tensor operator *(Tensor lhs, Complex rhs) => BinaryOpWrapper("mul", lhs, rhs); public static Tensor operator *(Complex lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); - public static Tensor operator /(Tensor lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(Tensor lhs, NDArray rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(NDArray lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(Tensor lhs, sbyte rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(sbyte lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(Tensor lhs, byte rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(byte lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(Tensor lhs, short rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(short lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(Tensor lhs, ushort rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(ushort lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(Tensor lhs, int rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(int lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(Tensor lhs, uint rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(uint lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(Tensor lhs, ulong rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(ulong lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(Tensor lhs, long rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(long lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(Tensor lhs, float rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(float lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(Tensor lhs, double rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(double lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(Tensor lhs, Complex rhs) => BinaryOpWrapper("truediv", lhs, rhs); - public static Tensor operator /(Complex lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); + public static Tensor operator /(Tensor lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, NDArray rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(NDArray lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, sbyte rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(sbyte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, byte rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(byte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, short rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(short lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, ushort rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(ushort lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, int rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(int lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, uint rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(uint lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, ulong rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(ulong lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, long rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(long lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, float rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(float lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, double rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(double lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, Complex rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Complex lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); public static Tensor operator %(Tensor lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); public static Tensor operator %(Tensor lhs, NDArray rhs) => BinaryOpWrapper("mod", lhs, rhs); public static Tensor operator %(NDArray lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); @@ -281,24 +282,43 @@ namespace Tensorflow public static Tensor operator <=(Tensor lhs, Complex rhs) => gen_math_ops.less_equal(lhs, rhs); public static Tensor operator <=(Complex lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); - #endregion + #endregion #endif - + private static readonly TF_DataType[] _intTfDataTypes = { TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64, TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 }; + private static string div_or_truediv(string name, Tx x, Ty y) + { + bool is_floating = false; + var types = new List(); + + if (x is Tensor t1) + types.add(t1.dtype.is_floating()); + + if (y is Tensor t2) + types.add(t2.dtype.is_floating()); + + is_floating = types.Contains(true); + + return is_floating ? "truediv" : name; + } + private static Tensor BinaryOpWrapper(string name, Tx x, Ty y) { TF_DataType dtype = TF_DataType.DtInvalid; - + if (x is Tensor tl) dtype = tl.dtype.as_base_dtype(); if (y is Tensor tr) dtype = tr.dtype.as_base_dtype(); + if (name == "div") + name = div_or_truediv(name, x, y); + return tf_with(ops.name_scope(null, name, new { x, y }), scope => { Tensor result; diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 9f505419..67474eb9 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -102,6 +102,9 @@ namespace Tensorflow [JsonIgnore] #endif public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; +#if SERIALIZABLE + [JsonIgnore] +#endif public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); #if SERIALIZABLE