Browse Source

fix Incompatible shapes for node define_loss/bigger_box_loss/mul_13 #424

tags/v0.13
Oceania2018 6 years ago
parent
commit
fa4a931d5a
6 changed files with 68 additions and 37 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  3. +6
    -3
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +2
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
  5. +52
    -32
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  6. +3
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs

+ 2
- 2
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -75,9 +75,9 @@ namespace Tensorflow
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
/// </summary> /// </summary>
/// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks>
public partial class Graph : DisposableObject,
public partial class Graph : DisposableObject
#if !SERIALIZABLE #if !SERIALIZABLE
IEnumerable<Operation>
,IEnumerable<Operation>
#endif #endif
{ {
private Dictionary<int, ITensorOrOperation> _nodes_by_id; private Dictionary<int, ITensorOrOperation> _nodes_by_id;


+ 3
- 0
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -60,6 +60,9 @@ namespace Tensorflow
/// <summary> /// <summary>
/// List this operation's output types. /// List this operation's output types.
/// </summary> /// </summary>
#if SERIALIZABLE
[JsonIgnore]
#endif
public TF_DataType[] _output_types public TF_DataType[] _output_types
{ {
get get


+ 6
- 3
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -78,7 +78,10 @@ namespace Tensorflow
#if SERIALIZABLE #if SERIALIZABLE
[JsonIgnore] [JsonIgnore]
#endif #endif
bool _is_stateful;
bool _is_stateful;
#if SERIALIZABLE
[JsonIgnore]
#endif
public NodeDef node_def public NodeDef node_def
{ {
get get
@@ -181,8 +184,8 @@ namespace Tensorflow
// This will be set by self.inputs. // This will be set by self.inputs.
if (op_def == null) if (op_def == null)
op_def = g.GetOpDef(node_def.Op);
op_def = g.GetOpDef(node_def.Op);
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
_is_stateful = op_def.IsStateful; _is_stateful = op_def.IsStateful;


+ 2
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Index.cs View File

@@ -79,6 +79,8 @@ namespace Tensorflow
} }


strides.Add(s.Step); strides.Add(s.Step);
if (s.IsIndex)
shrink_axis_mask |= (1 << index);
} }


index += 1; index += 1;


+ 52
- 32
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -16,6 +16,7 @@


using NumSharp; using NumSharp;
using System; using System;
using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Numerics; using System.Numerics;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -25,7 +26,7 @@ namespace Tensorflow
public partial class Tensor public partial class Tensor
{ {
#if _REGEN #if _REGEN
#region Compute
#region Compute
%operators = ["add", "sub", "mul", "div", "mod"] %operators = ["add", "sub", "mul", "div", "mod"]
%operators_sign = ["+", "-", "*", "/", "%"] %operators_sign = ["+", "-", "*", "/", "%"]
%operators_comparers = [">", "<", ">=", "<="] %operators_comparers = [">", "<", ">=", "<="]
@@ -49,11 +50,11 @@ namespace Tensorflow
% %
% %
public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); public static Tensor operator -(Tensor x) => gen_math_ops.neg(x);
#endregion
#endregion
#else #else
#region Compute
#region Compute



public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs);
public static Tensor operator +(Tensor lhs, NDArray rhs) => BinaryOpWrapper("add", lhs, rhs); public static Tensor operator +(Tensor lhs, NDArray rhs) => BinaryOpWrapper("add", lhs, rhs);
public static Tensor operator +(NDArray lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); public static Tensor operator +(NDArray lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs);
@@ -129,31 +130,31 @@ namespace Tensorflow
public static Tensor operator *(double lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); public static Tensor operator *(double lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs);
public static Tensor operator *(Tensor lhs, Complex rhs) => BinaryOpWrapper("mul", lhs, rhs); public static Tensor operator *(Tensor lhs, Complex rhs) => BinaryOpWrapper("mul", lhs, rhs);
public static Tensor operator *(Complex lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); public static Tensor operator *(Complex lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs);
public static Tensor operator /(Tensor lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Tensor lhs, NDArray rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(NDArray lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Tensor lhs, sbyte rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(sbyte lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Tensor lhs, byte rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(byte lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Tensor lhs, short rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(short lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Tensor lhs, ushort rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(ushort lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Tensor lhs, int rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(int lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Tensor lhs, uint rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(uint lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Tensor lhs, ulong rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(ulong lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Tensor lhs, long rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(long lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Tensor lhs, float rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(float lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Tensor lhs, double rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(double lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Tensor lhs, Complex rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Complex lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs);
public static Tensor operator /(Tensor lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(Tensor lhs, NDArray rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(NDArray lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(Tensor lhs, sbyte rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(sbyte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(Tensor lhs, byte rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(byte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(Tensor lhs, short rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(short lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(Tensor lhs, ushort rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(ushort lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(Tensor lhs, int rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(int lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(Tensor lhs, uint rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(uint lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(Tensor lhs, ulong rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(ulong lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(Tensor lhs, long rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(long lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(Tensor lhs, float rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(float lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(Tensor lhs, double rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(double lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(Tensor lhs, Complex rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator /(Complex lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs);
public static Tensor operator %(Tensor lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); public static Tensor operator %(Tensor lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs);
public static Tensor operator %(Tensor lhs, NDArray rhs) => BinaryOpWrapper("mod", lhs, rhs); public static Tensor operator %(Tensor lhs, NDArray rhs) => BinaryOpWrapper("mod", lhs, rhs);
public static Tensor operator %(NDArray lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); public static Tensor operator %(NDArray lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs);
@@ -281,24 +282,43 @@ namespace Tensorflow
public static Tensor operator <=(Tensor lhs, Complex rhs) => gen_math_ops.less_equal(lhs, rhs); public static Tensor operator <=(Tensor lhs, Complex rhs) => gen_math_ops.less_equal(lhs, rhs);
public static Tensor operator <=(Complex lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); public static Tensor operator <=(Complex lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs);
public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); public static Tensor operator -(Tensor x) => gen_math_ops.neg(x);
#endregion
#endregion
#endif #endif
private static readonly TF_DataType[] _intTfDataTypes = { private static readonly TF_DataType[] _intTfDataTypes = {
TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64, TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64,
TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32,
TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64
}; };


private static string div_or_truediv<Tx, Ty>(string name, Tx x, Ty y)
{
bool is_floating = false;
var types = new List<bool>();
if (x is Tensor t1)
types.add(t1.dtype.is_floating());

if (y is Tensor t2)
types.add(t2.dtype.is_floating());

is_floating = types.Contains(true);

return is_floating ? "truediv" : name;
}

private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y)
{ {
TF_DataType dtype = TF_DataType.DtInvalid; TF_DataType dtype = TF_DataType.DtInvalid;
if (x is Tensor tl) if (x is Tensor tl)
dtype = tl.dtype.as_base_dtype(); dtype = tl.dtype.as_base_dtype();
if (y is Tensor tr) if (y is Tensor tr)
dtype = tr.dtype.as_base_dtype(); dtype = tr.dtype.as_base_dtype();


if (name == "div")
name = div_or_truediv(name, x, y);

return tf_with(ops.name_scope(null, name, new { x, y }), scope => return tf_with(ops.name_scope(null, name, new { x, y }), scope =>
{ {
Tensor result; Tensor result;


+ 3
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -102,6 +102,9 @@ namespace Tensorflow
[JsonIgnore] [JsonIgnore]
#endif #endif
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
#if SERIALIZABLE
[JsonIgnore]
#endif
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
#if SERIALIZABLE #if SERIALIZABLE


Loading…
Cancel
Save