diff --git a/src/TensorFlowNET.Core/APIs/tf.io.cs b/src/TensorFlowNET.Core/APIs/tf.io.cs index 40da04b1..25d9cfe8 100644 --- a/src/TensorFlowNET.Core/APIs/tf.io.cs +++ b/src/TensorFlowNET.Core/APIs/tf.io.cs @@ -21,9 +21,32 @@ namespace Tensorflow { public partial class tensorflow { + public IoApi io { get; } = new IoApi(); + + public class IoApi + { + io_ops ops; + public IoApi() + { + ops = new io_ops(); + } + + public Tensor read_file(string filename, string name = null) + => ops.read_file(filename, name); + + public Tensor read_file(Tensor filename, string name = null) + => ops.read_file(filename, name); + + public Operation save_v2(Tensor prefix, string[] tensor_names, + string[] shape_and_slices, Tensor[] tensors, string name = null) + => ops.save_v2(prefix, tensor_names, shape_and_slices, tensors, name: name); + + public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, + string[] shape_and_slices, TF_DataType[] dtypes, string name = null) + => ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name); + } + public GFile gfile = new GFile(); - public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); - public Tensor read_file(Tensor filename, string name = null) => gen_io_ops.read_file(filename, name); public ITensorOrOperation[] import_graph_def(GraphDef graph_def, Dictionary input_map = null, diff --git a/src/TensorFlowNET.Core/APIs/tf.strings.cs b/src/TensorFlowNET.Core/APIs/tf.strings.cs index 38d92803..e19136a9 100644 --- a/src/TensorFlowNET.Core/APIs/tf.strings.cs +++ b/src/TensorFlowNET.Core/APIs/tf.strings.cs @@ -21,12 +21,28 @@ namespace Tensorflow { public partial class tensorflow { - public strings_internal strings = new strings_internal(); - public class strings_internal + public StringsApi strings { get; } = new StringsApi(); + + public class StringsApi { + string_ops ops = new string_ops(); + + /// + /// Return substrings from `Tensor` of strings. + /// + /// + /// + /// + /// + /// + /// public Tensor substr(Tensor input, int pos, int len, string name = null, string @uint = "BYTE") - => string_ops.substr(input, pos, len, name: name, @uint: @uint); + => ops.substr(input, pos, len, @uint: @uint, name: name); + + public Tensor substr(string input, int pos, int len, + string name = null, string @uint = "BYTE") + => ops.substr(input, pos, len, @uint: @uint, name: name); } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.tensor.cs b/src/TensorFlowNET.Core/APIs/tf.tensor.cs index 8ba78f42..3121e354 100644 --- a/src/TensorFlowNET.Core/APIs/tf.tensor.cs +++ b/src/TensorFlowNET.Core/APIs/tf.tensor.cs @@ -66,11 +66,18 @@ namespace Tensorflow /// A name for the operation (optional) /// if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; /// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value. - public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) => gen_array_ops.split( + public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) + => array_ops.split( value: value, + num_split: num_split, axis: axis, + name: name); + + public Tensor[] split(Tensor value, int num_split, int axis, string name = null) + => array_ops.split( + value: value, num_split: num_split, - name: name - ); + axis: axis, + name: name); } } diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs index 5fe9986c..78a63b77 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs @@ -44,10 +44,11 @@ namespace Tensorflow.Eager break; } c_api.TFE_OpAddInput(op, tensor_handle, status.Handle); + status.Check(true); } } - if (status.ok()) - SetOpAttrs(op, attrs, status.Handle); + if (status.ok() && attrs != null) + SetOpAttrs(op, attrs); var outputs = new IntPtr[num_outputs]; if (status.ok()) diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index 936877bc..0385e588 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -64,10 +64,8 @@ namespace Tensorflow.Eager } } - var flattened_inputs = args.Take(op_def.InputArg.Count) - .Select(x => x as Tensor) - .ToArray(); - var flattened_attrs = args.Skip(op_def.InputArg.Count).ToArray(); + var flattened_attrs = new List(op_def.InputArg.Count); + var flattened_inputs = new List(op_def.InputArg.Count); c_api.TFE_OpSetDevice(op, device_name, status.Handle); status.Check(true); @@ -80,31 +78,36 @@ namespace Tensorflow.Eager { int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length; c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len); + if (op_exec_info.run_callbacks) + { + flattened_attrs.Add(input_arg.NumberAttr); + flattened_attrs.Add(len); + } attr_list_sizes[input_arg.NumberAttr] = len; - + if (len > 0) { var fast_input_array = (object[])args[i]; // First item adds the type attr. - if (!AddInputToOp(fast_input_array[i], true, input_arg, op, status)) + if (!AddInputToOp(fast_input_array[i], true, input_arg, flattened_attrs, flattened_inputs, op, status)) return null; for (var j = 1; j < len; j++) { // Since the list is homogeneous, we don't need to re-add the attr. - if (!AddInputToOp(fast_input_array[j], false, input_arg, op, status)) + if (!AddInputToOp(fast_input_array[j], false, input_arg, flattened_attrs, flattened_inputs, op, status)) return null; } } } else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) { - + throw new NotImplementedException(""); } else { // The item is a single item. - AddInputToOp(args[i], true, input_arg, op, status); + AddInputToOp(args[i], true, input_arg, flattened_attrs, flattened_inputs, op, status); } } @@ -133,7 +136,7 @@ namespace Tensorflow.Eager if (!RunCallbacks( op_exec_info, kFastPathExecuteInputStartIndex + op_def.InputArg.Count(), - flattened_inputs, flattened_attrs, flat_result)) + flattened_inputs.ToArray(), flattened_attrs.ToArray(), flat_result)) { return null; } @@ -187,6 +190,8 @@ namespace Tensorflow.Eager bool AddInputToOp(object inputs, bool add_type_attr, ArgDef input_arg, + List flattened_attrs, + List flattened_inputs, IntPtr op, Status status) { @@ -197,9 +202,7 @@ namespace Tensorflow.Eager { case EagerTensor input: input_handle = input.EagerTensorHandle; - break; - case EagerTensor[] input_list: - input_handle = input_list[0].EagerTensorHandle; + flattened_inputs.Add(input); break; default: var tensor = tf.convert_to_tensor(inputs); @@ -211,6 +214,8 @@ namespace Tensorflow.Eager { var dtype = c_api.TFE_TensorHandleDataType(input_handle); c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype); + flattened_attrs.Add(input_arg.TypeAttr); + flattened_attrs.Add(dtype); } c_api.TFE_OpAddInput(op, input_handle, status.Handle); diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index 1dea32bf..95f808a8 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -34,7 +34,7 @@ namespace Tensorflow.Eager public EagerTensor Resolve() { - _id = get_uid(); + _id = ops.uid(); if (_handle == IntPtr.Zero) _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.status.Handle); @@ -55,8 +55,5 @@ namespace Tensorflow.Eager //print($"deleting DeleteTensorHandle {Id} {EagerTensorHandle.ToString("x16")}"); c_api.TFE_DeleteTensorHandle(EagerTensorHandle); } - - static long _uid = 0; - long get_uid() => _uid++; } } diff --git a/src/TensorFlowNET.Core/Eager/Execute.cs b/src/TensorFlowNET.Core/Eager/Execute.cs index 52df5a7d..04c11a1d 100644 --- a/src/TensorFlowNET.Core/Eager/Execute.cs +++ b/src/TensorFlowNET.Core/Eager/Execute.cs @@ -44,27 +44,27 @@ namespace Tensorflow.Eager return results; } - public (TF_DataType, EagerTensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null) + public (TF_DataType, Tensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null) { if (args.Length == 0 && default_dtype != TF_DataType.DtInvalid) return (default_dtype, null); - if (args.Count(x => x is EagerTensor) == args.Length) - return ((args[0] as EagerTensor).dtype, args.Select(x => x as EagerTensor).ToArray()); + if (args.Count(x => x is Tensor) == args.Length) + return ((args[0] as Tensor).dtype, args.Select(x => x as Tensor).ToArray()); var dtype = TF_DataType.DtInvalid; foreach (var x in args) { - if (x is EagerTensor et) + if (x is Tensor et) dtype = et.dtype; } if (dtype == TF_DataType.DtInvalid) { - var ret = new List(); + var ret = new List(); foreach (var t in args) { - ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as EagerTensor); + ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as Tensor); if (dtype == TF_DataType.DtInvalid) dtype = ret.Last().dtype; } diff --git a/src/TensorFlowNET.Core/Gradients/GradientTape.cs b/src/TensorFlowNET.Core/Gradients/GradientTape.cs index 69bf264f..d33d38a2 100644 --- a/src/TensorFlowNET.Core/Gradients/GradientTape.cs +++ b/src/TensorFlowNET.Core/Gradients/GradientTape.cs @@ -24,8 +24,8 @@ namespace Tensorflow.Gradients /// public class GradientTape : IDisposable { - static bool _recording; - public static bool Recording => _recording; + bool _recording; + public bool Recording => _recording; bool _persistent; bool _watch_accessed_variables; ResourceVariable[] _watched_variables; diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index 33c5f7c5..36148287 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -117,7 +117,7 @@ namespace Tensorflow.Gradients new Tensor[] { non_neg_concat_dim, tf.constant(0) }, new Tensor[] { tf.constant(1), tf.constant(-1) }); var squeeze_sizes = array_ops.squeeze(slice); - out_grads = gen_array_ops.split(grad, squeeze_sizes, (int)non_neg_concat_dim).ToList(); + out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split:(int)non_neg_concat_dim).ToList(); } else { diff --git a/src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs b/src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs index 7e3449b6..c6eab3b3 100644 --- a/src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs +++ b/src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs @@ -13,12 +13,14 @@ namespace Tensorflow.Gradients "FusedBatchNormGradV3" => new[] { 5 }, "FusedBatchNormV2" => new[] { 2 }, "FusedBatchNormV3" => new[] { 2 }, + "ReadVariableOp" => new int[0], _ => null }; public static int[] OpGradientUnusedOutputIndices(string op_name) => op_name switch { + "ReadVariableOp" => new int[0], "SoftmaxCrossEntropyWithLogits" => new[] { 0 }, "TensorArrayConcat" => new[] { 0 }, "TensorArrayConcatV2" => new[] { 0 }, diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 4e5a5e85..af584658 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -64,6 +64,22 @@ namespace Tensorflow.Gradients return new Tensor[] { r1, r2 }; } + /// + /// Copies the gradient to all inputs. + /// + /// + /// + /// + [RegisterGradient("AddN")] + public static Tensor[] _AddNGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + + return Enumerable.Range(0, len(op.inputs)) + .Select(x => grad) + .ToArray(); + } + [RegisterGradient("Cumsum")] public static Tensor[] _CumsumGrad(Operation op, Tensor[] grads) { diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs index 1cb352ae..d8d1cb3d 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs @@ -91,7 +91,7 @@ namespace Tensorflow gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); // i = input_gate, j = new_input, f = forget_gate, o = output_gate - var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one); + var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one); var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 4cb55119..a4335046 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -18,6 +18,7 @@ using NumSharp; using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Eager; using Tensorflow.Framework; using static Tensorflow.Binding; @@ -668,14 +669,29 @@ namespace Tensorflow }); } - public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis, + public static Tensor[] split(Tensor value, int num_split, T axis, string name = "split") { - var size_splits = ops.convert_to_tensor(num_or_size_splits); - return gen_array_ops.split(axis: axis, - num_split: num_or_size_splits, - value: value, - name: name); + var size_splits = ops.convert_to_tensor(num_split); + + if (tf.context.executing_eagerly()) + { + return split_eager_fallback(axis, value, num_split: num_split, name: name, ctx: tf.context); + } + + var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); + return _op.outputs; + } + + private static Tensor[] split_eager_fallback(Ta axis, Tv value, int num_split, string name, Context ctx = null) + { + var (_attr_T, input) = tf._execute.args_to_matching_eager(ctx, args: new object[] { value }); + var axis_tensor = ops.convert_to_tensor(axis, dtype: TF_DataType.TF_INT32); + var _inputs_flat = new List { axis_tensor }; + _inputs_flat.AddRange(input); + var _attrs = new object[] { "num_split", num_split, "T", _attr_T }; + + return tf._execute.execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name); } public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 2852c05c..5df45e61 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -376,6 +376,16 @@ namespace Tensorflow { return tf_with(ops.name_scope(name, "cond", new { pred }), delegate { + if (tf.context.executing_eagerly()) + { + if (pred.ToArray()[0]) + return true_fn() as Tensor; + else + return false_fn() as Tensor; + + return null; + } + // Add the Switch to the graph. var switch_result= @switch(pred, pred); var (p_2, p_1 )= (switch_result[0], switch_result[1]); @@ -450,6 +460,16 @@ namespace Tensorflow { return tf_with(ops.name_scope(name, "cond", new { pred }), delegate { + if (tf.context.executing_eagerly()) + { + if (pred.ToArray()[0]) + return true_fn() as Tensor[]; + else + return false_fn() as Tensor[]; + + return null; + } + // Add the Switch to the graph. var switch_result = @switch(pred, pred); var p_2 = switch_result[0]; diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 2111564c..575ea46e 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -124,6 +124,16 @@ namespace Tensorflow /// public static Tensor diag(Tensor diagonal, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Diag", name, + null, + diagonal); + + return results[0]; + } + var op = tf._op_def_lib._apply_op_helper("Diag", name: name, args: new { diagonal }); return op.output; @@ -131,6 +141,16 @@ namespace Tensorflow public static Tensor expand_dims(Tensor input, int axis, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, + "ExpandDims", name, + null, + input, tf.convert_to_tensor(axis)); + + return results[0]; + } + var _op = tf._op_def_lib._apply_op_helper("ExpandDims", name: name, args: new { input, dim = axis }); return _op.outputs[0]; @@ -463,12 +483,6 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null) - { - var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); - return _op.outputs; - } - public static Tensor tile(Tensor input, T multiples, string name = null) { if (tf.context.executing_eagerly()) diff --git a/src/TensorFlowNET.Core/Operations/gen_string_ops.cs b/src/TensorFlowNET.Core/Operations/gen_string_ops.cs deleted file mode 100644 index bb407e77..00000000 --- a/src/TensorFlowNET.Core/Operations/gen_string_ops.cs +++ /dev/null @@ -1,40 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using System; -using System.Collections.Generic; -using System.Text; -using static Tensorflow.Binding; - -namespace Tensorflow -{ - public class gen_string_ops - { - public static Tensor substr(Tensor input, int pos, int len, - string name = null, string @uint = "BYTE") - { - var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new - { - input, - pos, - len, - unit = @uint - }); - - return _op.output; - } - } -} diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index d3de812e..5a06b136 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; using System.Index; +using System.Linq; using System.Range; using System.Text; using Tensorflow.Operations; @@ -160,7 +161,7 @@ namespace Tensorflow Func _bmp = () => { int bmp_channels = channels; - var signature = string_ops.substr(contents, 0, 2); + var signature = tf.strings.substr(contents, 0, 2); var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); @@ -195,7 +196,7 @@ namespace Tensorflow return tf_with(ops.name_scope(name, "decode_image"), scope => { - substr = string_ops.substr(contents, 0, 3); + substr = tf.strings.substr(contents, 0, 3); return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); }); } @@ -225,8 +226,11 @@ namespace Tensorflow { return tf_with(ops.name_scope(name, "is_jpeg"), scope => { - var substr = string_ops.substr(contents, 0, 3); - return math_ops.equal(substr, "\xff\xd8\xff", name: name); + var substr = tf.strings.substr(contents, 0, 3); + var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff }); + var jpg_tensor = tf.constant(jpg); + var result = math_ops.equal(substr, jpg_tensor, name: name); + return result; }); } @@ -234,7 +238,7 @@ namespace Tensorflow { return tf_with(ops.name_scope(name, "is_png"), scope => { - var substr = string_ops.substr(contents, 0, 3); + var substr = tf.strings.substr(contents, 0, 3); return math_ops.equal(substr, @"\211PN", name: name); }); } diff --git a/src/TensorFlowNET.Core/Operations/gen_io_ops.cs b/src/TensorFlowNET.Core/Operations/io_ops.cs similarity index 60% rename from src/TensorFlowNET.Core/Operations/gen_io_ops.cs rename to src/TensorFlowNET.Core/Operations/io_ops.cs index d7462116..9b8a9889 100644 --- a/src/TensorFlowNET.Core/Operations/gen_io_ops.cs +++ b/src/TensorFlowNET.Core/Operations/io_ops.cs @@ -14,31 +14,45 @@ limitations under the License. ******************************************************************************/ +using Tensorflow.Eager; using static Tensorflow.Binding; namespace Tensorflow { - public class gen_io_ops + public class io_ops { - public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) + public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) { var _op = tf._op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); return _op; } - public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) + public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) { var _op = tf._op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); return _op.outputs; } - public static Tensor read_file(T filename, string name = null) + public Tensor read_file(T filename, string name = null) { + if (tf.context.executing_eagerly()) + { + return read_file_eager_fallback(filename, name: name, tf.context); + } + var _op = tf._op_def_lib._apply_op_helper("ReadFile", name: name, args: new { filename }); return _op.outputs[0]; } + + private Tensor read_file_eager_fallback(T filename, string name = null, Context ctx = null) + { + var filename_tensor = ops.convert_to_tensor(filename, TF_DataType.TF_STRING); + var _inputs_flat = new[] { filename_tensor }; + + return tf._execute.execute(ctx, "ReadFile", 1, _inputs_flat, null, name: name)[0]; + } } } diff --git a/src/TensorFlowNET.Core/Operations/string_ops.cs b/src/TensorFlowNET.Core/Operations/string_ops.cs index ee46cf78..a0b46c48 100644 --- a/src/TensorFlowNET.Core/Operations/string_ops.cs +++ b/src/TensorFlowNET.Core/Operations/string_ops.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; using System.Text; +using static Tensorflow.Binding; namespace Tensorflow { @@ -31,8 +32,30 @@ namespace Tensorflow /// /// /// - public static Tensor substr(Tensor input, int pos, int len, - string name = null, string @uint = "BYTE") - => gen_string_ops.substr(input, pos, len, name: name, @uint: @uint); + public Tensor substr(T input, int pos, int len, + string @uint = "BYTE", string name = null) + { + if (tf.context.executing_eagerly()) + { + var input_tensor = tf.constant(input); + var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Substr", name, + null, + input, pos, len, + "unit", @uint); + + return results[0]; + } + + var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new + { + input, + pos, + len, + unit = @uint + }); + + return _op.output; + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs index aa9e7d90..1845e9fd 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs @@ -68,9 +68,9 @@ namespace Tensorflow throw new ArgumentException($"{nameof(Tensor)} can only be scalar."); IntPtr stringStartAddress = IntPtr.Zero; - UIntPtr dstLen = UIntPtr.Zero; + ulong dstLen = 0; - c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, tf.status.Handle); + c_api.TF_StringDecode((byte*) this.buffer + 8, this.bytesize, (byte**) &stringStartAddress, ref dstLen, tf.status.Handle); tf.status.Check(true); var dstLenInt = checked((int) dstLen); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 3d7e4cbc..d1a75338 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -453,7 +453,7 @@ namespace Tensorflow { var buffer = Encoding.UTF8.GetBytes(str); var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); - var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + sizeof(ulong))); AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs index 04b22a68..579ff566 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -235,13 +235,12 @@ namespace Tensorflow var buffer = new byte[size][]; var src = c_api.TF_TensorData(_handle); - var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); src += (int)(size * 8); for (int i = 0; i < buffer.Length; i++) { IntPtr dst = IntPtr.Zero; - UIntPtr dstLen = UIntPtr.Zero; - var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, tf.status.Handle); + ulong dstLen = 0; + var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle); tf.status.Check(true); buffer[i] = new byte[(int)dstLen]; Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); @@ -254,5 +253,35 @@ namespace Tensorflow return _str; } + + public unsafe byte[][] StringBytes() + { + if (dtype != TF_DataType.TF_STRING) + throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); + + // + // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. + // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] + // + long size = 1; + foreach (var s in TensorShape.dims) + size *= s; + + var buffer = new byte[size][]; + var src = c_api.TF_TensorData(_handle); + src += (int)(size * 8); + for (int i = 0; i < buffer.Length; i++) + { + IntPtr dst = IntPtr.Zero; + ulong dstLen = 0; + var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle); + tf.status.Check(true); + buffer[i] = new byte[(int)dstLen]; + Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); + src += (int)read; + } + + return buffer; + } } } diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index ebc2b192..c9dd5e13 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -207,7 +207,7 @@ namespace Tensorflow public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, SafeStatusHandle status); [DllImport(TensorFlowLibName)] - public static extern unsafe UIntPtr TF_StringDecode(byte* src, UIntPtr src_len, byte** dst, UIntPtr* dst_len, SafeStatusHandle status); + public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status); public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator; diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index b97ba1cd..d3c28938 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -132,28 +132,54 @@ namespace Tensorflow switch (value) { + case EagerTensor val: + return val; case NDArray val: return new EagerTensor(val, ctx.device_name); case string val: return new EagerTensor(val, ctx.device_name); + case bool val: + return new EagerTensor(val, ctx.device_name); + case byte val: + return new EagerTensor(val, ctx.device_name); + case byte[] val: + return new EagerTensor(val, ctx.device_name); + case byte[,] val: + return new EagerTensor(val, ctx.device_name); + case byte[,,] val: + return new EagerTensor(val, ctx.device_name); case int val: return new EagerTensor(val, ctx.device_name); case int[] val: return new EagerTensor(val, ctx.device_name); case int[,] val: return new EagerTensor(val, ctx.device_name); + case int[,,] val: + return new EagerTensor(val, ctx.device_name); case long val: return new EagerTensor(val, ctx.device_name); + case long[] val: + return new EagerTensor(val, ctx.device_name); + case long[,] val: + return new EagerTensor(val, ctx.device_name); + case long[,,] val: + return new EagerTensor(val, ctx.device_name); case float val: return new EagerTensor(val, ctx.device_name); + case float[] val: + return new EagerTensor(val, ctx.device_name); case float[,] val: return new EagerTensor(val, ctx.device_name); - case double val: + case float[,,] val: return new EagerTensor(val, ctx.device_name); - case float[] val: + case double val: return new EagerTensor(val, ctx.device_name); case double[] val: return new EagerTensor(val, ctx.device_name); + case double[,] val: + return new EagerTensor(val, ctx.device_name); + case double[,,] val: + return new EagerTensor(val, ctx.device_name); default: throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}"); } diff --git a/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs index 1aae389b..7ebf94d6 100644 --- a/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs +++ b/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs @@ -55,7 +55,7 @@ namespace Tensorflow if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2) { - return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); + return tf.io.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); } else { @@ -76,7 +76,7 @@ namespace Tensorflow dtypes.Add(spec.dtype); } - return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); + return tf.io.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); } public virtual SaverDef _build_internal(IVariableV1[] names_to_saveables, diff --git a/src/TensorFlowNET.Core/Util/SafeTensorflowHandle.cs b/src/TensorFlowNET.Core/Util/SafeTensorflowHandle.cs index abfb3446..a3f5dfed 100644 --- a/src/TensorFlowNET.Core/Util/SafeTensorflowHandle.cs +++ b/src/TensorFlowNET.Core/Util/SafeTensorflowHandle.cs @@ -39,5 +39,8 @@ namespace Tensorflow.Util } public override bool IsInvalid => handle == IntPtr.Zero; + + public override string ToString() + => $"0x{handle.ToString("x16")}"; } } diff --git a/src/TensorFlowNET.Core/Util/UnorderedMap.cs b/src/TensorFlowNET.Core/Util/UnorderedMap.cs index 397d3719..51bbecae 100644 --- a/src/TensorFlowNET.Core/Util/UnorderedMap.cs +++ b/src/TensorFlowNET.Core/Util/UnorderedMap.cs @@ -28,10 +28,10 @@ namespace Tensorflow.Util } public void push_back(Tk key, Tv value) - => Add(key, value); + => this[key] = value; public void emplace(Tk key, Tv value) - => Add(key, value); + => this[key] = value; public bool find(Tk key) => ContainsKey(key); diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs index 03ab556f..8ff760b9 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs @@ -22,56 +22,21 @@ namespace Tensorflow { public partial class ResourceVariable { - public static Tensor operator +(ResourceVariable x, int y) => op_helper("add", x, y); - public static Tensor operator +(ResourceVariable x, float y) => op_helper("add", x, y); - public static Tensor operator +(ResourceVariable x, double y) => op_helper("add", x, y); - public static Tensor operator +(ResourceVariable x, ResourceVariable y) => op_helper("add", x, y); - public static Tensor operator -(ResourceVariable x, int y) => op_helper("sub", x, y); - public static Tensor operator -(ResourceVariable x, float y) => op_helper("sub", x, y); - public static Tensor operator -(ResourceVariable x, double y) => op_helper("sub", x, y); - public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y); - public static Tensor operator -(ResourceVariable x, ResourceVariable y) => op_helper("sub", x, y); - - public static Tensor operator *(ResourceVariable x, ResourceVariable y) => op_helper("mul", x, y); - public static Tensor operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y); - - public static Tensor operator <(ResourceVariable x, Tensor y) => op_helper("less", x, y); - - public static Tensor operator >(ResourceVariable x, Tensor y) => op_helper("greater", x, y); - - private static Tensor op_helper(string default_name, ResourceVariable x, T y) - => tf_with(ops.name_scope(null, default_name, new { x, y }), scope => - { - string name = scope; - var xVal = x.value(); - var yTensor = ops.convert_to_tensor(y, xVal.dtype.as_base_dtype(), "y"); - Tensor result = null; - switch (default_name) - { - case "add": - result = x.dtype == TF_DataType.TF_STRING ? - gen_math_ops.add(xVal, yTensor, name) : - gen_math_ops.add_v2(xVal, yTensor, name); - break; - case "sub": - result = gen_math_ops.sub(xVal, yTensor, name); - break; - case "mul": - result = gen_math_ops.mul(xVal, yTensor, name: name); - break; - case "less": - result = gen_math_ops.less(xVal, yTensor, name); - break; - case "greater": - result = gen_math_ops.greater(xVal, yTensor, name); - break; - default: - throw new NotImplementedException(""); - } - - // x.assign(result); - // result.ResourceVar = x; - return result; - }); + public static Tensor operator +(ResourceVariable x, int y) => x.value() + y; + public static Tensor operator +(ResourceVariable x, float y) => x.value() + y; + public static Tensor operator +(ResourceVariable x, double y) => x.value() + y; + public static Tensor operator +(ResourceVariable x, ResourceVariable y) => x.value() + y.value(); + public static Tensor operator -(ResourceVariable x, int y) => x.value() - y; + public static Tensor operator -(ResourceVariable x, float y) => x.value() - y; + public static Tensor operator -(ResourceVariable x, double y) => x.value() - y; + public static Tensor operator -(ResourceVariable x, Tensor y) => x.value() - y; + public static Tensor operator -(ResourceVariable x, ResourceVariable y) => x.value() - y.value(); + + public static Tensor operator *(ResourceVariable x, ResourceVariable y) => x.value() * y.value(); + public static Tensor operator *(ResourceVariable x, NDArray y) => x.value() * y; + + public static Tensor operator <(ResourceVariable x, Tensor y) => x.value() < y; + + public static Tensor operator >(ResourceVariable x, Tensor y) => x.value() > y; } } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 697d3c04..ecc00f55 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -277,7 +277,7 @@ namespace Tensorflow return ops.control_dependencies(null); } - private static int uid_number = 0; + private static int uid_number = -1; /// /// A unique (within this program execution) integer. diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs index 344e4374..cb3ea87a 100644 --- a/test/TensorFlowNET.UnitTest/ConstantTest.cs +++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs @@ -160,7 +160,6 @@ namespace TensorFlowNET.UnitTest.Basics Assert.AreEqual(6.0, (double)c); } - [Ignore] [TestMethod] public void StringEncode() { @@ -175,7 +174,7 @@ namespace TensorFlowNET.UnitTest.Basics string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); Assert.AreEqual(encoded_str, str); Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); - //c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); + // c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status.Handle); } [TestMethod] diff --git a/test/TensorFlowNET.UnitTest/ImageTest.cs b/test/TensorFlowNET.UnitTest/ImageTest.cs index d94101cc..02ae5e43 100644 --- a/test/TensorFlowNET.UnitTest/ImageTest.cs +++ b/test/TensorFlowNET.UnitTest/ImageTest.cs @@ -2,8 +2,10 @@ using System; using System.Collections.Generic; using System.IO; +using System.Reflection; using System.Text; using Tensorflow; +using Tensorflow.UnitTest; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.Basics @@ -20,11 +22,10 @@ namespace TensorFlowNET.UnitTest.Basics [TestInitialize] public void Initialize() { - imgPath = Path.GetFullPath(imgPath); - contents = tf.read_file(imgPath); + imgPath = TestHelper.GetFullPathFromDataDir(imgPath); + contents = tf.io.read_file(imgPath); } - [Ignore("")] [TestMethod] public void decode_image() { diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs index 7c24ee26..1dac4e39 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs @@ -24,6 +24,25 @@ namespace TensorFlowNET.UnitTest.Gradient Assert.AreEqual((float)grad, 3.0f); } + /// + /// Calcute the gradient of w * w * w + /// 高阶梯度 + /// + [TestMethod] + public void HighGradient() + { + var x = tf.Variable(1.0f); + using var tape1 = tf.GradientTape(); + using var tape2 = tf.GradientTape(); + var y = x * x * x; + tape2.Dispose(); + var dy_dx = tape2.gradient(y, x); + Assert.AreEqual((float)dy_dx, 3.0f); + tape1.Dispose(); + var d2y_d2x = tape1.gradient(dy_dx, x); + Assert.AreEqual((float)d2y_d2x, 6.0f); + } + [TestMethod] public void ConstantMultiply() { @@ -56,5 +75,33 @@ namespace TensorFlowNET.UnitTest.Gradient var dz_dy = tape.gradient(z, y); Assert.AreEqual((float)dz_dy, 8.0f); } + + [TestMethod] + public void ConditionalMultiply() + { + Func func = (x, y) => + { + Tensor output = tf.constant(1.0f); + foreach (var i in range(y)) + { + if (i > 1) + output = tf.multiply(output, x); + } + return output; + }; + + Func grad = (x, y) => + { + using var tape = tf.GradientTape(); + tape.watch(x); + var output = func(x, y); + var grad = tape.gradient(output, x); + return grad; + }; + + var x = tf.constant(2.0f); + var result = grad(x, 4); + Assert.AreEqual((float)result, 4.0f); + } } } diff --git a/test/TensorFlowNET.UnitTest/nn_test/ActivationFunctionTest.cs b/test/TensorFlowNET.UnitTest/TF_API/ActivationFunctionTest.cs similarity index 100% rename from test/TensorFlowNET.UnitTest/nn_test/ActivationFunctionTest.cs rename to test/TensorFlowNET.UnitTest/TF_API/ActivationFunctionTest.cs diff --git a/test/TensorFlowNET.UnitTest/math_test/MathOperationTest.cs b/test/TensorFlowNET.UnitTest/TF_API/MathApiTest.cs similarity index 91% rename from test/TensorFlowNET.UnitTest/math_test/MathOperationTest.cs rename to test/TensorFlowNET.UnitTest/TF_API/MathApiTest.cs index ccc9c2d9..12023bd4 100644 --- a/test/TensorFlowNET.UnitTest/math_test/MathOperationTest.cs +++ b/test/TensorFlowNET.UnitTest/TF_API/MathApiTest.cs @@ -6,10 +6,10 @@ using System.Text; using Tensorflow; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.math_test +namespace TensorFlowNET.UnitTest.TF_API { [TestClass] - public class MathOperationTest : TFNetApiTest + public class MathApiTest : TFNetApiTest { // A constant vector of size 6 Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); diff --git a/test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs b/test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs new file mode 100644 index 00000000..3049505b --- /dev/null +++ b/test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs @@ -0,0 +1,43 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.UnitTest.TF_API +{ + [TestClass] + public class StringsApiTest + { + [TestMethod] + public void StringEqual() + { + var str1 = tf.constant("Hello1"); + var str2 = tf.constant("Hello2"); + var result = tf.equal(str1, str2); + Assert.IsFalse(result.ToScalar()); + + var str3 = tf.constant("Hello1"); + result = tf.equal(str1, str3); + Assert.IsTrue(result.ToScalar()); + + var str4 = tf.strings.substr(str1, 0, 5); + var str5 = tf.strings.substr(str2, 0, 5); + result = tf.equal(str4, str5); + Assert.IsTrue(result.ToScalar()); + } + + [TestMethod] + public void ImageType() + { + var imgPath = TestHelper.GetFullPathFromDataDir("shasta-daisy.jpg"); + var contents = tf.io.read_file(imgPath); + + var substr = tf.strings.substr(contents, 0, 3); + var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff }); + var jpg_tensor = tf.constant(jpg); + + var result = math_ops.equal(substr, jpg_tensor); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/TFNetApiTest.cs b/test/TensorFlowNET.UnitTest/TF_API/TFNetApiTest.cs similarity index 100% rename from test/TensorFlowNET.UnitTest/TFNetApiTest.cs rename to test/TensorFlowNET.UnitTest/TF_API/TFNetApiTest.cs diff --git a/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs b/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs index b3ce7a4a..39efc8e6 100644 --- a/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs +++ b/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs @@ -23,6 +23,58 @@ namespace Tensorflow.UnitTest.TF_API Assert.IsTrue(Enumerable.SequenceEqual(transpose_a.numpy().ToArray(), b.numpy().ToArray())); } + [TestMethod] + public void InitTensorTest() + { + var a = tf.constant(np.array(new[, ,] + { + { { 1 }, { 2 }, { 3 } }, + { { 4 }, { 5 }, { 6 } } + })); + Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape)); + + var b = tf.constant(new[, ,] + { + { { 1 }, { 2 }, { 3 } }, + { { 4 }, { 5 }, { 6 } } + }); + Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape)); + } + + [TestMethod] + public void ConcatTest() + { + var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } }); + var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } }); + var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } }); + var concatValue = tf.concat(new[] { a, b, c }, axis: 0); + Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); + } + + [TestMethod] + public void ConcatDoubleTest() + { + var a = tf.constant(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } }); + var b = tf.constant(new[,] { { 5.0, 6.0 }, { 7.0, 8.0 } }); + var c = tf.constant(new[,] { { 9.0, 10.0 }, { 11.0, 12.0 } }); + + var concatValue = tf.concat(new[] { a, b, c }, axis: 0); + Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); + } + + [TestMethod] + public void ConcatAndSplitTest() + { + var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } }); + var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } }); + var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } }); + + var value = tf.concat(new[] { a, b, c }, axis: 0); + + var splitValue = tf.split(value, 3, axis: 0); + Assert.AreEqual(3, splitValue.Length); + Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 2 }, splitValue[0].shape)); + } } } diff --git a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/TF_API/ZeroFractionTest.cs similarity index 100% rename from test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs rename to test/TensorFlowNET.UnitTest/TF_API/ZeroFractionTest.cs diff --git a/test/TensorFlowNET.UnitTest/nn_test/nn_test.py b/test/TensorFlowNET.UnitTest/TF_API/nn_test.py similarity index 100% rename from test/TensorFlowNET.UnitTest/nn_test/nn_test.py rename to test/TensorFlowNET.UnitTest/TF_API/nn_test.py diff --git a/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs b/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs new file mode 100644 index 00000000..dbc0d3a6 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; + +namespace Tensorflow.UnitTest +{ + public class TestHelper + { + public static string GetFullPathFromDataDir(string fileName) + { + var dir = Path.Combine(Directory.GetCurrentDirectory(), "..", "..", "..", "..", "..", "data"); + return Path.GetFullPath(Path.Combine(dir, fileName)); + } + } +}