diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index 76ff6e54..8452b81a 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -78,15 +78,15 @@ namespace Tensorflow
///
///
/// A `Tensor` resulting from concatenation of the input tensors.
- public Tensor concat(IList values, int axis, string name = "concat")
+ public Tensor concat(IEnumerable values, int axis, string name = "concat")
{
- if (values.Count == 1)
+ if (values.Count() == 1)
{
return tf_with(ops.name_scope(name), scope =>
{
var tensor = ops.convert_to_tensor(axis, name: "concat_dim", dtype: dtypes.int32);
Debug.Assert(tensor.TensorShape.ndim == 0);
- return identity(values[0], name: scope);
+ return identity(values.First(), name: scope);
});
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.reshape.cs b/src/TensorFlowNET.Core/APIs/tf.reshape.cs
index 3952b82c..9702e1dd 100644
--- a/src/TensorFlowNET.Core/APIs/tf.reshape.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.reshape.cs
@@ -19,15 +19,18 @@ namespace Tensorflow
public partial class tensorflow
{
public Tensor reshape(Tensor tensor,
- TensorShape shape,
- string name = null) => gen_array_ops.reshape(tensor, shape, name);
+ TensorShape shape,
+ string name = null)
+ => gen_array_ops.reshape(tensor, shape, name);
public Tensor reshape(Tensor tensor,
- Tensor[] shape,
- string name = null) => gen_array_ops.reshape(tensor, shape, name);
+ Tensor shape,
+ string name = null)
+ => gen_array_ops.reshape(tensor, shape, name);
public Tensor reshape(Tensor tensor,
- Tensor shape,
- string name = null) => gen_array_ops.reshape(tensor, shape, name);
+ object[] shape,
+ string name = null)
+ => gen_array_ops.reshape(tensor, shape, name);
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.tile.cs b/src/TensorFlowNET.Core/APIs/tf.tile.cs
index 71717e9c..7066ff82 100644
--- a/src/TensorFlowNET.Core/APIs/tf.tile.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.tile.cs
@@ -13,13 +13,22 @@
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/
+using static Tensorflow.Binding;
namespace Tensorflow
{
public partial class tensorflow
{
- public Tensor tile(Tensor input,
- T multiples,
- string name = null) => gen_array_ops.tile(input, multiples, name);
+ public Tensor tile(Tensor input, Tensor multiples, string name = null)
+ => gen_array_ops.tile(input, multiples, name);
+
+ public Tensor tile(Tensor input, object[] multiples, string name = null)
+ => gen_array_ops.tile(input, multiples, name);
+
+ public Tensor tile(Tensor input, TensorShape multiples, string name = null)
+ {
+ var multiples_tensor = constant_op.constant(multiples);
+ return gen_array_ops.tile(input, multiples_tensor, name);
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
index 1a5c00d2..a42b79f0 100644
--- a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
+++ b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
@@ -28,16 +28,27 @@ namespace Tensorflow.Contexts
///
public sealed partial class Context
{
- // [DebuggerStepThrough]
- public T RunInAutoMode(Func graphAction, Func eagerAction, params Tensor[] tensors)
+ public T RunInAutoMode(Func graphAction, Func eagerAction, params object[] args)
{
- var shouldRunInEager = executing_eagerly()
- && tensors.Count(x => x.IsEagerTensor) == tensors.Length;
-
- if (shouldRunInEager)
- return eagerAction();
- else
+ if (tf.Context.has_graph_arg(args))
+ {
return graphAction();
+ }
+ else
+ {
+ try
+ {
+ return eagerAction();
+ }
+ catch (InvalidArgumentError ex)
+ {
+ throw ex;
+ }
+ catch (Exception ex)
+ {
+ return graphAction();
+ }
+ }
}
// [DebuggerStepThrough]
@@ -46,12 +57,7 @@ namespace Tensorflow.Contexts
Action recordGradient,
Tensors tensors)
{
- var shouldRunInEager = executing_eagerly()
- && tensors.Count(x => x.IsEagerTensor) == tensors.Length;
-
- if (shouldRunInEager)
- return eagerAction();
- else
+ if (tf.Context.has_graph_arg(tensors))
{
if (executing_eagerly())
{
@@ -68,6 +74,10 @@ namespace Tensorflow.Contexts
return result;
}
}
+ else
+ {
+ return eagerAction();
+ }
}
}
}
diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs
index 43564fdb..5a9f15c9 100644
--- a/src/TensorFlowNET.Core/Contexts/Context.cs
+++ b/src/TensorFlowNET.Core/Contexts/Context.cs
@@ -20,6 +20,7 @@ using System.Linq;
using Tensorflow.Eager;
using static Tensorflow.Binding;
using Google.Protobuf;
+using Tensorflow.Util;
namespace Tensorflow.Contexts
{
@@ -103,6 +104,29 @@ namespace Tensorflow.Contexts
public void eager_mode(bool isFunc = false)
=> context_switches.Push(true, isFunc);
+ public bool switched_to_graph(params object[] args)
+ {
+ var switching_to_graph = has_graph_arg(args) && tf.Context.executing_eagerly();
+ if (switching_to_graph)
+ tf.Context.graph_mode(tf.Context.is_build_function());
+ return switching_to_graph;
+ }
+
+ public bool has_graph_arg(params object[] args)
+ {
+ var flatten_args = nest.flatten