diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 76ff6e54..8452b81a 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -78,15 +78,15 @@ namespace Tensorflow /// /// /// A `Tensor` resulting from concatenation of the input tensors. - public Tensor concat(IList values, int axis, string name = "concat") + public Tensor concat(IEnumerable values, int axis, string name = "concat") { - if (values.Count == 1) + if (values.Count() == 1) { return tf_with(ops.name_scope(name), scope => { var tensor = ops.convert_to_tensor(axis, name: "concat_dim", dtype: dtypes.int32); Debug.Assert(tensor.TensorShape.ndim == 0); - return identity(values[0], name: scope); + return identity(values.First(), name: scope); }); } diff --git a/src/TensorFlowNET.Core/APIs/tf.reshape.cs b/src/TensorFlowNET.Core/APIs/tf.reshape.cs index 3952b82c..9702e1dd 100644 --- a/src/TensorFlowNET.Core/APIs/tf.reshape.cs +++ b/src/TensorFlowNET.Core/APIs/tf.reshape.cs @@ -19,15 +19,18 @@ namespace Tensorflow public partial class tensorflow { public Tensor reshape(Tensor tensor, - TensorShape shape, - string name = null) => gen_array_ops.reshape(tensor, shape, name); + TensorShape shape, + string name = null) + => gen_array_ops.reshape(tensor, shape, name); public Tensor reshape(Tensor tensor, - Tensor[] shape, - string name = null) => gen_array_ops.reshape(tensor, shape, name); + Tensor shape, + string name = null) + => gen_array_ops.reshape(tensor, shape, name); public Tensor reshape(Tensor tensor, - Tensor shape, - string name = null) => gen_array_ops.reshape(tensor, shape, name); + object[] shape, + string name = null) + => gen_array_ops.reshape(tensor, shape, name); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.tile.cs b/src/TensorFlowNET.Core/APIs/tf.tile.cs index 71717e9c..7066ff82 100644 --- a/src/TensorFlowNET.Core/APIs/tf.tile.cs +++ b/src/TensorFlowNET.Core/APIs/tf.tile.cs @@ -13,13 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ******************************************************************************/ +using static Tensorflow.Binding; namespace Tensorflow { public partial class tensorflow { - public Tensor tile(Tensor input, - T multiples, - string name = null) => gen_array_ops.tile(input, multiples, name); + public Tensor tile(Tensor input, Tensor multiples, string name = null) + => gen_array_ops.tile(input, multiples, name); + + public Tensor tile(Tensor input, object[] multiples, string name = null) + => gen_array_ops.tile(input, multiples, name); + + public Tensor tile(Tensor input, TensorShape multiples, string name = null) + { + var multiples_tensor = constant_op.constant(multiples); + return gen_array_ops.tile(input, multiples_tensor, name); + } } } diff --git a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs index 1a5c00d2..a42b79f0 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs @@ -28,16 +28,27 @@ namespace Tensorflow.Contexts /// public sealed partial class Context { - // [DebuggerStepThrough] - public T RunInAutoMode(Func graphAction, Func eagerAction, params Tensor[] tensors) + public T RunInAutoMode(Func graphAction, Func eagerAction, params object[] args) { - var shouldRunInEager = executing_eagerly() - && tensors.Count(x => x.IsEagerTensor) == tensors.Length; - - if (shouldRunInEager) - return eagerAction(); - else + if (tf.Context.has_graph_arg(args)) + { return graphAction(); + } + else + { + try + { + return eagerAction(); + } + catch (InvalidArgumentError ex) + { + throw ex; + } + catch (Exception ex) + { + return graphAction(); + } + } } // [DebuggerStepThrough] @@ -46,12 +57,7 @@ namespace Tensorflow.Contexts Action recordGradient, Tensors tensors) { - var shouldRunInEager = executing_eagerly() - && tensors.Count(x => x.IsEagerTensor) == tensors.Length; - - if (shouldRunInEager) - return eagerAction(); - else + if (tf.Context.has_graph_arg(tensors)) { if (executing_eagerly()) { @@ -68,6 +74,10 @@ namespace Tensorflow.Contexts return result; } } + else + { + return eagerAction(); + } } } } diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 43564fdb..5a9f15c9 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -20,6 +20,7 @@ using System.Linq; using Tensorflow.Eager; using static Tensorflow.Binding; using Google.Protobuf; +using Tensorflow.Util; namespace Tensorflow.Contexts { @@ -103,6 +104,29 @@ namespace Tensorflow.Contexts public void eager_mode(bool isFunc = false) => context_switches.Push(true, isFunc); + public bool switched_to_graph(params object[] args) + { + var switching_to_graph = has_graph_arg(args) && tf.Context.executing_eagerly(); + if (switching_to_graph) + tf.Context.graph_mode(tf.Context.is_build_function()); + return switching_to_graph; + } + + public bool has_graph_arg(params object[] args) + { + var flatten_args = nest.flatten(args); + bool has_graph_arg = false; + foreach (var el in flatten_args) + { + if (el is Tensor tensor && !tensor.IsEagerTensor) + { + has_graph_arg = true; + break; + } + } + return has_graph_arg; + } + public void restore_mode() { context_switches.Pop(); diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs index ad3bd244..d072306a 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs @@ -38,9 +38,9 @@ namespace Tensorflow.Eager } }*/ } - - tf.Logger.Debug($"RecordGradient: should_record={should_record}, op_name={op_name}"); + if (!should_record) return should_record; + tf.Logger.Debug($"RecordGradient: op_name={op_name}"); Tensor[] op_outputs; #pragma warning disable CS0219 // Variable is assigned but its value is never used diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index aa56ede5..4264c929 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -50,7 +50,7 @@ namespace Tensorflow.Eager var op_def = tf.get_default_graph().GetOpDef(opName); - var flattened_attrs = new List(op_def.InputArg.Count); + var flattened_attrs = new List(op_def.Attr.Count * 2); var flattened_inputs = new List(op_def.InputArg.Count); // Set non-inferred attrs, including setting defaults if the attr is passed in @@ -221,23 +221,9 @@ namespace Tensorflow.Eager SafeTensorHandleHandle input_handle; // ConvertToTensor(); - switch (inputs) - { - case EagerTensor input: - input_handle = input.EagerTensorHandle; - flattened_inputs.Add(input); - break; - case ResourceVariable variable: - var var_tensor = variable.AsTensor(); - input_handle = var_tensor.EagerTensorHandle; - flattened_inputs.Add(var_tensor); - break; - default: - var tensor = tf.convert_to_tensor(inputs); - input_handle = tensor.EagerTensorHandle; - flattened_inputs.Add(tensor); - break; - } + var tensor = tf.convert_to_tensor(inputs); + input_handle = tensor.EagerTensorHandle; + flattened_inputs.Add(tensor); if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) { diff --git a/src/TensorFlowNET.Core/Framework/tensor_shape.cs b/src/TensorFlowNET.Core/Framework/tensor_shape.cs index 35557701..0cdb633a 100644 --- a/src/TensorFlowNET.Core/Framework/tensor_shape.cs +++ b/src/TensorFlowNET.Core/Framework/tensor_shape.cs @@ -2,6 +2,7 @@ using System; using System.Linq; using System.Text; +using static Tensorflow.Binding; namespace Tensorflow.Framework { @@ -65,5 +66,17 @@ namespace Tensorflow.Framework public static TensorShape as_shape(this Shape shape) => new TensorShape(shape.Dimensions); + + public static TensorShape most_specific_compatible_shape(this TensorShape self, TensorShape other) + { + var dims = range(self.rank).Select(x => -1).ToArray(); + foreach(var (i, (d1, d2)) in enumerate(zip(self.dims, other.dims))) + { + if (d1 == d2) + dims[i] = d1; + } + + return new TensorShape(dims); + } } } diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 18cd74a9..dff2db88 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -134,7 +134,7 @@ namespace Tensorflow.Functions /// /// /// - public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs) + public Tensors CallFlat(Tensor[] args, Tensor[] captured_inputs) { var executing_eagerly = tf.Context.executing_eagerly(); var default_graph = ops.get_default_graph(); diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index 78f8e794..b4356107 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -99,8 +99,18 @@ namespace Tensorflow.Functions if (input_index >= backward_function_inputs) break; } + tf.Logger.Debug($"Invoke backward function: {backward.Name}"); - return backward.CallFlat(processed_args, remapped_captures); + var gradients = backward.CallFlat(processed_args, remapped_captures); + + foreach (var unneeded_gradient_index in unneeded_gradients) + { + var index = Convert.ToInt32(unneeded_gradient_index); + if (gradients.Length <= index) + gradients.Insert(index, null); + } + + return gradients; }; return (_backward_function_wrapper, recorded_outputs); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index e3f14336..2ff98103 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -151,6 +151,7 @@ namespace Tensorflow /// public virtual Graph as_default() { + tf.Context.graph_mode(isFunc: false); return ops.set_default_graph(this); } @@ -532,6 +533,7 @@ namespace Tensorflow public virtual void Exit() { + tf.Context.restore_mode(); ops.pop_graph(); } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs index 83cdb28a..0a260b74 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs @@ -3,5 +3,6 @@ public class ReshapeArgs : LayerArgs { public TensorShape TargetShape { get; set; } + public object[] TargetShapeObjects { get; set; } } } diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 1b2dcd24..752b1d51 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -132,7 +132,7 @@ namespace Tensorflow if (!input_arg.IsRef && dtype != DataType.DtInvalid) dtype = dtype.as_base_dtype(); - values = ops.internal_convert_n_to_tensor(values, + values = ops.internal_convert_n_to_tensor(values as object[], name: input_arg.Name, dtype: dtype.as_tf_dtype(), preferred_dtype: default_dtype.as_tf_dtype(), diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 34670070..bf5324dd 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -19,6 +19,7 @@ using System; using System.Collections.Generic; using System.Linq; using Tensorflow.Contexts; +using Tensorflow.Eager; using Tensorflow.Framework; using static Tensorflow.Binding; @@ -215,7 +216,7 @@ namespace Tensorflow } } - public static Tensor _autopacking_conversion_function(object[] v, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) + public static Tensor _autopacking_conversion_function(IEnumerable v, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) { var inferred_dtype = _get_dtype_from_nested_lists(v); if (dtype == TF_DataType.DtInvalid) @@ -224,7 +225,7 @@ namespace Tensorflow return _autopacking_helper(v, dtype, name == null ? "packed" : name); } - private static TF_DataType _get_dtype_from_nested_lists(object[] list_or_tuple) + private static TF_DataType _get_dtype_from_nested_lists(IEnumerable list_or_tuple) { TF_DataType dtype = TF_DataType.DtInvalid; @@ -251,11 +252,14 @@ namespace Tensorflow /// /// /// A `tf.Tensor` with value equivalent to `list_or_tuple`. - public static Tensor _autopacking_helper(object[] list_or_tuple, TF_DataType dtype, string name) + public static Tensor _autopacking_helper(IEnumerable list_or_tuple, TF_DataType dtype, string name) { var must_pack = false; var converted_elems = new List(); - return tf_with(ops.name_scope(name), scope => + + bool switch_to_graph = tf.Context.switched_to_graph(list_or_tuple.ToArray()); + + var result = tf_with(ops.name_scope(name), scope => { foreach (var (i, elem) in enumerate(list_or_tuple)) { @@ -268,8 +272,17 @@ namespace Tensorflow var elems_as_tensors = new List(); foreach (var (i, elem) in enumerate(converted_elems)) { - if (elem is Tensor tensor) + if (elem is EagerTensor eager_tensor) + { + if(switch_to_graph) + elems_as_tensors.Add(constant_op.constant(eager_tensor.numpy(), dtype: dtype, name: i.ToString())); + else + elems_as_tensors.Add(eager_tensor); + } + else if (elem is Tensor tensor) + { elems_as_tensors.Add(tensor); + } else { var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString()); @@ -284,6 +297,11 @@ namespace Tensorflow return tf.constant(np.array(new float[0])); } }); + + if (switch_to_graph) + tf.Context.restore_mode(); + + return result; } public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) @@ -351,8 +369,14 @@ namespace Tensorflow public static Tensor ones_like(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) => ones_like_impl(tensor, dtype, name, optimize); - public static Tensor reshape(Tensor tensor, T2 shape, string name = null) - => gen_array_ops.reshape(tensor, shape, null); + public static Tensor reshape(Tensor tensor, Tensor shape, string name = null) + => gen_array_ops.reshape(tensor, shape, name: name); + + public static Tensor reshape(Tensor tensor, TensorShape shape, string name = null) + => gen_array_ops.reshape(tensor, shape, name: name); + + public static Tensor reshape(Tensor tensor, object[] shape, string name = null) + => gen_array_ops.reshape(tensor, shape, name: name); private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true) { diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs index 2ccff1c0..3a8d70b4 100644 --- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs +++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs @@ -22,7 +22,11 @@ namespace Tensorflow return results[0]; } - throw new NotImplementedException(""); + var _op = tf.OpDefLib._apply_op_helper("TensorDataset", + name: name, + args: new { components, output_shapes }); + + return _op.output; } /// @@ -180,6 +184,28 @@ namespace Tensorflow throw new NotImplementedException(""); } + public Tensor concatenate_dataset(Tensor input_dataset, Tensor another_dataset, + TF_DataType[] output_types, TensorShape[] output_shapes, + string name = null) + { + if (tf.Context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "ConcatenateDataset", name, + null, + input_dataset, another_dataset, + "output_types", output_types, + "output_shapes", output_shapes); + return results[0]; + } + + var _op = tf.OpDefLib._apply_op_helper("ConcatenateDataset", + name: name, + args: new { input_dataset, another_dataset, output_types, output_shapes }); + + return _op.outputs[0]; + } + public Tensor cache_dataset_v2(Tensor input_dataset, Tensor filename, Tensor cache, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index d56813f4..a2db25d9 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -202,20 +202,14 @@ namespace Tensorflow } public static Tensor pack(Tensor[] values, int axis = 0, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("Pack", name, new { values, axis }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Pack", name, null, values, - "axis", axis); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Pack", name: name, args: new { values, axis }); - return _op.output; - } + "axis", axis).FirstOrDefault(), + values, axis); /// /// Return a tensor with the same shape and contents as the input tensor or value. @@ -338,12 +332,39 @@ namespace Tensorflow "Reshape", name, null, tensor, shape).FirstOrDefault(), - tensor); + tensor, shape); - public static Tensor reshape(Tensor tensor, int[] shape, string name = null) + public static Tensor reshape(Tensor tensor, object[] shape, string name = null) { - var _op = tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }); - return _op.outputs[0]; + try + { + return tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Reshape", name, + null, + tensor, shape).FirstOrDefault(), + tensor, shape); + } + catch (InvalidArgumentError ex) + { + return reshape_eager_fallback(tensor, shape, name, tf.Context); + } + } + + private static Tensor reshape_eager_fallback(Tensor tensor, object[] shape, string name, Context ctx) + { + var (_attr_T, _input) = tf.Runner.ArgsToMatchingEager(ctx, args: new[] { tensor }); + var (_attr_Tshape, _input_shape) = tf.Runner.ArgsToMatchingEager(ctx, args: new object[] { shape }, default_dtype: TF_DataType.TF_INT32); + var _inputs_flat = new[] { _input[0], _input_shape[0] }; + var _attrs = new object[] { "T", _attr_T, "Tshape", _attr_Tshape }; + + var results = tf.Runner.Execute(ctx, "Reshape", 1, _inputs_flat, _attrs, name: name); + if (tf.Runner.MustRecordGradient()) + { + tf.Runner.RecordGradient("Reshape", _inputs_flat, _attrs, results); + } + return results[0]; } /// @@ -537,14 +558,23 @@ namespace Tensorflow return _op.outputs; } - public static Tensor tile(Tensor input, T multiples, string name = null) + public static Tensor tile(Tensor input, Tensor multiples, string name = null) => tf.Context.RunInAutoMode(() => tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Tile", name, null, input, multiples).FirstOrDefault(), - input); + input, multiples); + + public static Tensor tile(Tensor input, object[] multiples, string name = null) + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Tile", name, + null, + input, multiples).FirstOrDefault(), + input, multiples); public static Tensor transpose(Tensor x, T1 perm, string name = null) { diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index e48cb031..5cfa7664 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -1874,7 +1874,7 @@ new_height, new_width"); { using (ops.name_scope("suppression_loop_body")) { - var num_tiles = Math.Floor((double)array_ops.shape(boxes).dims[1] / tile_size); + var num_tiles = array_ops.shape(boxes).dims[1] / tile_size; var batch_size = array_ops.shape(boxes).dims[0]; (Tensor, Tensor, Tensor, Tensor) cross_suppression_func(Tensor boxes, Tensor box_slice, Tensor iou_threshold, Tensor inner_idx, int tile_size) diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index edccbc17..635ef5d0 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 2.2.0 - 0.32.0 + 0.33.0 8.0 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK @@ -19,7 +19,7 @@ Google's TensorFlow full binding in .NET Standard. Building, training and infering deep learning models. https://tensorflownet.readthedocs.io - 0.32.0.0 + 0.33.0.0 tf.net 0.20.x and above are based on tensorflow native 2.x. * Eager Mode is added finally. @@ -28,7 +28,7 @@ https://tensorflownet.readthedocs.io * autograph works partially. TensorFlow .NET v0.3x is focused on making more Keras API works - 0.32.0.0 + 0.33.0.0 LICENSE true true diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs index 9d6c4af6..bbcb5f28 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs @@ -4,6 +4,12 @@ namespace Tensorflow { public partial class TensorShape { + public void Deconstruct(out int h, out int w) + { + h = dims[0]; + w = dims[1]; + } + public static implicit operator TensorShape(Shape shape) => new TensorShape((int[])shape.Dimensions.Clone()); public static implicit operator Shape(TensorShape shape) => new Shape((int[])shape.dims.Clone()); diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index 8e0315ef..aed72222 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -60,6 +60,9 @@ namespace Tensorflow public void AddRange(Tensor[] tensors) => items.AddRange(tensors); + public void Insert(int index, Tensor tensor) + => items.Insert(index, tensor); + IEnumerator IEnumerable.GetEnumerator() { throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 45a809ca..d35ed34e 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -156,8 +156,8 @@ namespace Tensorflow return val; case NDArray val: return new EagerTensor(val, ctx.DeviceName); - //case TensorShape val: - //return new EagerTensor(val.dims, ctx.DeviceName); + case TensorShape val: + return new EagerTensor(val.dims, ctx.DeviceName); case string val: return new EagerTensor(val, ctx.DeviceName); case string[] val: diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index c5e964f6..140a1ca6 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -123,6 +123,17 @@ namespace Tensorflow { nparray = nd; } + else if(values is string str) + { + // scalar string + nparray = convert_to_numpy_ndarray(values); + shape = new int[0]; + } + else if(values is string[] strings) + { + nparray = convert_to_numpy_ndarray(values); + shape = new[] { strings.Length }; + } else { if (values == null) @@ -151,9 +162,14 @@ namespace Tensorflow { if (numpy_dtype == TF_DataType.TF_STRING) { - // scalar string - shape = new int[0]; - shape_size = 0; + if (nparray.ndim == 0) + { + // scalar string + shape = new int[0]; + shape_size = 0; + } + else + throw new NotImplementedException($"Not implemented for {nparray.ndim} dims string array."); } else { @@ -428,6 +444,9 @@ would not be rank 1.", tensor.op.get_attr("axis"))); case NDArray val: nd = val; break; + case TensorShape val: + nd = val.dims; + break; case bool boolVal: nd = boolVal; break; @@ -471,7 +490,7 @@ would not be rank 1.", tensor.op.get_attr("axis"))); nd = new NDArray(Encoding.ASCII.GetBytes(strVal)); break; case string[] strVals: - nd = strVals; + nd = np.array(strVals); break; case byte[] byteValues: nd = byteValues; diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index eaa78399..3c651310 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -150,7 +150,8 @@ namespace Tensorflow TensorShape ts => constant_op.constant(ts.dims, dtype: dtype, name: name), int[] dims => constant_op.constant(dims, dtype: dtype, name: name), string str => constant_op.constant(str, dtype: tf.@string, name: name), - object[] objects => array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name), + string[] str => constant_op.constant(str, dtype: tf.@string, name: name), + IEnumerable objects => array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name), _ => constant_op.constant(value, dtype: dtype, name: name) }; @@ -500,18 +501,16 @@ namespace Tensorflow return ret.ToArray(); } - public static Tensor[] internal_convert_n_to_tensor(object values, TF_DataType dtype = TF_DataType.DtInvalid, + public static Tensor[] internal_convert_n_to_tensor(object[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid, bool as_ref = false) { var ret = new List(); - - foreach ((int i, object value) in enumerate(values as object[])) + foreach ((int i, object value) in enumerate(values)) { string n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}"; ret.Add(convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype)); } - return ret.ToArray(); } diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index f7349dc5..60b22f71 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -48,7 +48,7 @@ namespace Tensorflow public tensorflow() { Logger = new LoggerConfiguration() - .MinimumLevel.Warning() + .MinimumLevel.Error() .WriteTo.Console() .CreateLogger();