Browse Source

Shape error for gradients/Sum_grad/Tile #193

tags/v0.8.0
haiping008 6 years ago
parent
commit
7753183941
9 changed files with 27 additions and 22 deletions
  1. +5
    -0
      src/TensorFlowNET.Core/Framework/common_shapes.py.cs
  2. +2
    -1
      src/TensorFlowNET.Core/Gradients/math_grad.cs
  3. +2
    -1
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  4. +1
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  5. +2
    -2
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  6. +11
    -8
      src/TensorFlowNET.Core/Operations/math_ops.py.cs
  7. +2
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  8. +2
    -2
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  9. +0
    -7
      test/TensorFlowNET.Examples/LogisticRegression.cs

+ 5
- 0
src/TensorFlowNET.Core/Framework/common_shapes.py.cs View File

@@ -34,5 +34,10 @@ namespace Tensorflow.Framework
{ {
return tensor.rank; return tensor.rank;
} }

public static bool has_fully_defined_shape(Tensor tensor)
{
return tensor.getShape().is_fully_defined();
}
} }
} }

+ 2
- 1
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -1,4 +1,5 @@
using System;
//using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;


+ 2
- 1
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -1,4 +1,5 @@
using System;
//using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;


+ 1
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -1,4 +1,5 @@
using Google.Protobuf.Collections; using Google.Protobuf.Collections;
//using Newtonsoft.Json;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;


+ 2
- 2
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -207,14 +207,14 @@ namespace Tensorflow
return _op.outputs[0]; return _op.outputs[0];
} }
public static Tensor sum(Tensor input, Tensor axis = null, bool keep_dims = false, string name = null)
public static Tensor _sum(Tensor input, Tensor axis = null, bool keep_dims = false, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims }); var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims });
return _op.outputs[0]; return _op.outputs[0];
} }
public static Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null)
public static Tensor _sum(Tensor input, int axis, bool keep_dims = false, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims }); var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims });


+ 11
- 8
src/TensorFlowNET.Core/Operations/math_ops.py.cs View File

@@ -212,26 +212,29 @@ namespace Tensorflow
throw new NotImplementedException(); throw new NotImplementedException();
} }


public static Tensor reduce_sum(Tensor input_tensor, Tensor axis = null, bool keepdims = false)
public static Tensor reduce_sum(Tensor input_tensor, Tensor axis = null, bool keepdims = false, string name = null)
{ {
var r = _ReductionDims(input_tensor, axis); var r = _ReductionDims(input_tensor, axis);
var m = gen_math_ops.sum(input_tensor, r);
return _may_reduce_to_scalar(keepdims, m);
var m = gen_math_ops._sum(input_tensor, r, keep_dims: keepdims, name: name);
return _may_reduce_to_scalar(keepdims, axis, m);
} }


public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false) public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false)
{ {
var m = gen_math_ops.sum(input_tensor, axis);
return _may_reduce_to_scalar(keepdims, m);
var m = gen_math_ops._sum(input_tensor, axis);
return _may_reduce_to_scalar(keepdims, new int[] { axis }, m);
} }


private static Tensor _may_reduce_to_scalar(bool keepdims, Tensor output)
private static Tensor _may_reduce_to_scalar(bool keepdims, Tensor axis, Tensor output)
{ {
output.shape = new long[0];
if (!common_shapes.has_fully_defined_shape(output) &&
!keepdims &&
axis == null)
output.shape = new long[0];
return output; return output;
} }


private static Tensor _may_reduce_to_scalar(bool keepdims, int[] axos, Tensor output)
private static Tensor _may_reduce_to_scalar(bool keepdims, int[] axis, Tensor output)
{ {
output.shape = new long[0]; output.shape = new long[0];
return output; return output;


+ 2
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -1,4 +1,5 @@
using NumSharp.Core;
//using Newtonsoft.Json;
using NumSharp.Core;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;


+ 2
- 2
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -302,7 +302,7 @@ namespace Tensorflow
default: default:
throw new NotImplementedException("as_shape Not Implemented"); throw new NotImplementedException("as_shape Not Implemented");
} }
dim.Name = $"dim_{i}";
// dim.Name = $"dim_{i}";


shape.Dim.Add(dim); shape.Dim.Add(dim);
} }
@@ -333,7 +333,7 @@ namespace Tensorflow
{ {
var dim = new TensorShapeProto.Types.Dim(); var dim = new TensorShapeProto.Types.Dim();
dim.Size = tshape.Dimensions[i]; dim.Size = tshape.Dimensions[i];
dim.Name = $"dim_{i}";
//dim.Name = $"dim_{i}";


shape.Dim.Add(dim); shape.Dim.Add(dim);
} }


+ 0
- 7
test/TensorFlowNET.Examples/LogisticRegression.cs View File

@@ -49,13 +49,6 @@ namespace TensorFlowNET.Examples
// Gradient Descent // Gradient Descent
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);


//var new_saver = tf.train.import_meta_graph("logistic_regression.meta.bin");

/*var text = JsonConvert.SerializeObject(tf.get_default_graph(), new JsonSerializerSettings
{
Formatting = Formatting.Indented
});*/

// Initialize the variables (i.e. assign their default value) // Initialize the variables (i.e. assign their default value)
var init = tf.global_variables_initializer(); var init = tf.global_variables_initializer();




Loading…
Cancel
Save