| @@ -21,9 +21,32 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class tensorflow | public partial class tensorflow | ||||
| { | { | ||||
| public IoApi io { get; } = new IoApi(); | |||||
| public class IoApi | |||||
| { | |||||
| io_ops ops; | |||||
| public IoApi() | |||||
| { | |||||
| ops = new io_ops(); | |||||
| } | |||||
| public Tensor read_file(string filename, string name = null) | |||||
| => ops.read_file(filename, name); | |||||
| public Tensor read_file(Tensor filename, string name = null) | |||||
| => ops.read_file(filename, name); | |||||
| public Operation save_v2(Tensor prefix, string[] tensor_names, | |||||
| string[] shape_and_slices, Tensor[] tensors, string name = null) | |||||
| => ops.save_v2(prefix, tensor_names, shape_and_slices, tensors, name: name); | |||||
| public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, | |||||
| string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | |||||
| => ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name); | |||||
| } | |||||
| public GFile gfile = new GFile(); | public GFile gfile = new GFile(); | ||||
| public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); | |||||
| public Tensor read_file(Tensor filename, string name = null) => gen_io_ops.read_file(filename, name); | |||||
| public ITensorOrOperation[] import_graph_def(GraphDef graph_def, | public ITensorOrOperation[] import_graph_def(GraphDef graph_def, | ||||
| Dictionary<string, Tensor> input_map = null, | Dictionary<string, Tensor> input_map = null, | ||||
| @@ -21,12 +21,28 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class tensorflow | public partial class tensorflow | ||||
| { | { | ||||
| public strings_internal strings = new strings_internal(); | |||||
| public class strings_internal | |||||
| public StringsApi strings { get; } = new StringsApi(); | |||||
| public class StringsApi | |||||
| { | { | ||||
| string_ops ops = new string_ops(); | |||||
| /// <summary> | |||||
| /// Return substrings from `Tensor` of strings. | |||||
| /// </summary> | |||||
| /// <param name="input"></param> | |||||
| /// <param name="pos"></param> | |||||
| /// <param name="len"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <param name="uint"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor substr(Tensor input, int pos, int len, | public Tensor substr(Tensor input, int pos, int len, | ||||
| string name = null, string @uint = "BYTE") | string name = null, string @uint = "BYTE") | ||||
| => string_ops.substr(input, pos, len, name: name, @uint: @uint); | |||||
| => ops.substr(input, pos, len, @uint: @uint, name: name); | |||||
| public Tensor substr(string input, int pos, int len, | |||||
| string name = null, string @uint = "BYTE") | |||||
| => ops.substr(input, pos, len, @uint: @uint, name: name); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -66,11 +66,18 @@ namespace Tensorflow | |||||
| /// <param name="name">A name for the operation (optional)</param> | /// <param name="name">A name for the operation (optional)</param> | ||||
| /// <returns>if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; | /// <returns>if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; | ||||
| /// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.</returns> | /// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.</returns> | ||||
| public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) => gen_array_ops.split( | |||||
| public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) | |||||
| => array_ops.split( | |||||
| value: value, | value: value, | ||||
| num_split: num_split, | |||||
| axis: axis, | axis: axis, | ||||
| name: name); | |||||
| public Tensor[] split(Tensor value, int num_split, int axis, string name = null) | |||||
| => array_ops.split( | |||||
| value: value, | |||||
| num_split: num_split, | num_split: num_split, | ||||
| name: name | |||||
| ); | |||||
| axis: axis, | |||||
| name: name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -44,10 +44,11 @@ namespace Tensorflow.Eager | |||||
| break; | break; | ||||
| } | } | ||||
| c_api.TFE_OpAddInput(op, tensor_handle, status.Handle); | c_api.TFE_OpAddInput(op, tensor_handle, status.Handle); | ||||
| status.Check(true); | |||||
| } | } | ||||
| } | } | ||||
| if (status.ok()) | |||||
| SetOpAttrs(op, attrs, status.Handle); | |||||
| if (status.ok() && attrs != null) | |||||
| SetOpAttrs(op, attrs); | |||||
| var outputs = new IntPtr[num_outputs]; | var outputs = new IntPtr[num_outputs]; | ||||
| if (status.ok()) | if (status.ok()) | ||||
| @@ -64,10 +64,8 @@ namespace Tensorflow.Eager | |||||
| } | } | ||||
| } | } | ||||
| var flattened_inputs = args.Take(op_def.InputArg.Count) | |||||
| .Select(x => x as Tensor) | |||||
| .ToArray(); | |||||
| var flattened_attrs = args.Skip(op_def.InputArg.Count).ToArray(); | |||||
| var flattened_attrs = new List<object>(op_def.InputArg.Count); | |||||
| var flattened_inputs = new List<Tensor>(op_def.InputArg.Count); | |||||
| c_api.TFE_OpSetDevice(op, device_name, status.Handle); | c_api.TFE_OpSetDevice(op, device_name, status.Handle); | ||||
| status.Check(true); | status.Check(true); | ||||
| @@ -80,31 +78,36 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length; | int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length; | ||||
| c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len); | c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len); | ||||
| if (op_exec_info.run_callbacks) | |||||
| { | |||||
| flattened_attrs.Add(input_arg.NumberAttr); | |||||
| flattened_attrs.Add(len); | |||||
| } | |||||
| attr_list_sizes[input_arg.NumberAttr] = len; | attr_list_sizes[input_arg.NumberAttr] = len; | ||||
| if (len > 0) | if (len > 0) | ||||
| { | { | ||||
| var fast_input_array = (object[])args[i]; | var fast_input_array = (object[])args[i]; | ||||
| // First item adds the type attr. | // First item adds the type attr. | ||||
| if (!AddInputToOp(fast_input_array[i], true, input_arg, op, status)) | |||||
| if (!AddInputToOp(fast_input_array[i], true, input_arg, flattened_attrs, flattened_inputs, op, status)) | |||||
| return null; | return null; | ||||
| for (var j = 1; j < len; j++) | for (var j = 1; j < len; j++) | ||||
| { | { | ||||
| // Since the list is homogeneous, we don't need to re-add the attr. | // Since the list is homogeneous, we don't need to re-add the attr. | ||||
| if (!AddInputToOp(fast_input_array[j], false, input_arg, op, status)) | |||||
| if (!AddInputToOp(fast_input_array[j], false, input_arg, flattened_attrs, flattened_inputs, op, status)) | |||||
| return null; | return null; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | ||||
| { | { | ||||
| throw new NotImplementedException(""); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| // The item is a single item. | // The item is a single item. | ||||
| AddInputToOp(args[i], true, input_arg, op, status); | |||||
| AddInputToOp(args[i], true, input_arg, flattened_attrs, flattened_inputs, op, status); | |||||
| } | } | ||||
| } | } | ||||
| @@ -133,7 +136,7 @@ namespace Tensorflow.Eager | |||||
| if (!RunCallbacks( | if (!RunCallbacks( | ||||
| op_exec_info, | op_exec_info, | ||||
| kFastPathExecuteInputStartIndex + op_def.InputArg.Count(), | kFastPathExecuteInputStartIndex + op_def.InputArg.Count(), | ||||
| flattened_inputs, flattened_attrs, flat_result)) | |||||
| flattened_inputs.ToArray(), flattened_attrs.ToArray(), flat_result)) | |||||
| { | { | ||||
| return null; | return null; | ||||
| } | } | ||||
| @@ -187,6 +190,8 @@ namespace Tensorflow.Eager | |||||
| bool AddInputToOp(object inputs, | bool AddInputToOp(object inputs, | ||||
| bool add_type_attr, | bool add_type_attr, | ||||
| ArgDef input_arg, | ArgDef input_arg, | ||||
| List<object> flattened_attrs, | |||||
| List<Tensor> flattened_inputs, | |||||
| IntPtr op, | IntPtr op, | ||||
| Status status) | Status status) | ||||
| { | { | ||||
| @@ -197,9 +202,7 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| case EagerTensor input: | case EagerTensor input: | ||||
| input_handle = input.EagerTensorHandle; | input_handle = input.EagerTensorHandle; | ||||
| break; | |||||
| case EagerTensor[] input_list: | |||||
| input_handle = input_list[0].EagerTensorHandle; | |||||
| flattened_inputs.Add(input); | |||||
| break; | break; | ||||
| default: | default: | ||||
| var tensor = tf.convert_to_tensor(inputs); | var tensor = tf.convert_to_tensor(inputs); | ||||
| @@ -211,6 +214,8 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| var dtype = c_api.TFE_TensorHandleDataType(input_handle); | var dtype = c_api.TFE_TensorHandleDataType(input_handle); | ||||
| c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype); | c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype); | ||||
| flattened_attrs.Add(input_arg.TypeAttr); | |||||
| flattened_attrs.Add(dtype); | |||||
| } | } | ||||
| c_api.TFE_OpAddInput(op, input_handle, status.Handle); | c_api.TFE_OpAddInput(op, input_handle, status.Handle); | ||||
| @@ -34,7 +34,7 @@ namespace Tensorflow.Eager | |||||
| public EagerTensor Resolve() | public EagerTensor Resolve() | ||||
| { | { | ||||
| _id = get_uid(); | |||||
| _id = ops.uid(); | |||||
| if (_handle == IntPtr.Zero) | if (_handle == IntPtr.Zero) | ||||
| _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.status.Handle); | _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.status.Handle); | ||||
| @@ -55,8 +55,5 @@ namespace Tensorflow.Eager | |||||
| //print($"deleting DeleteTensorHandle {Id} {EagerTensorHandle.ToString("x16")}"); | //print($"deleting DeleteTensorHandle {Id} {EagerTensorHandle.ToString("x16")}"); | ||||
| c_api.TFE_DeleteTensorHandle(EagerTensorHandle); | c_api.TFE_DeleteTensorHandle(EagerTensorHandle); | ||||
| } | } | ||||
| static long _uid = 0; | |||||
| long get_uid() => _uid++; | |||||
| } | } | ||||
| } | } | ||||
| @@ -44,27 +44,27 @@ namespace Tensorflow.Eager | |||||
| return results; | return results; | ||||
| } | } | ||||
| public (TF_DataType, EagerTensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null) | |||||
| public (TF_DataType, Tensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null) | |||||
| { | { | ||||
| if (args.Length == 0 && default_dtype != TF_DataType.DtInvalid) | if (args.Length == 0 && default_dtype != TF_DataType.DtInvalid) | ||||
| return (default_dtype, null); | return (default_dtype, null); | ||||
| if (args.Count(x => x is EagerTensor) == args.Length) | |||||
| return ((args[0] as EagerTensor).dtype, args.Select(x => x as EagerTensor).ToArray()); | |||||
| if (args.Count(x => x is Tensor) == args.Length) | |||||
| return ((args[0] as Tensor).dtype, args.Select(x => x as Tensor).ToArray()); | |||||
| var dtype = TF_DataType.DtInvalid; | var dtype = TF_DataType.DtInvalid; | ||||
| foreach (var x in args) | foreach (var x in args) | ||||
| { | { | ||||
| if (x is EagerTensor et) | |||||
| if (x is Tensor et) | |||||
| dtype = et.dtype; | dtype = et.dtype; | ||||
| } | } | ||||
| if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
| { | { | ||||
| var ret = new List<EagerTensor>(); | |||||
| var ret = new List<Tensor>(); | |||||
| foreach (var t in args) | foreach (var t in args) | ||||
| { | { | ||||
| ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as EagerTensor); | |||||
| ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as Tensor); | |||||
| if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
| dtype = ret.Last().dtype; | dtype = ret.Last().dtype; | ||||
| } | } | ||||
| @@ -24,8 +24,8 @@ namespace Tensorflow.Gradients | |||||
| /// </summary> | /// </summary> | ||||
| public class GradientTape : IDisposable | public class GradientTape : IDisposable | ||||
| { | { | ||||
| static bool _recording; | |||||
| public static bool Recording => _recording; | |||||
| bool _recording; | |||||
| public bool Recording => _recording; | |||||
| bool _persistent; | bool _persistent; | ||||
| bool _watch_accessed_variables; | bool _watch_accessed_variables; | ||||
| ResourceVariable[] _watched_variables; | ResourceVariable[] _watched_variables; | ||||
| @@ -117,7 +117,7 @@ namespace Tensorflow.Gradients | |||||
| new Tensor[] { non_neg_concat_dim, tf.constant(0) }, | new Tensor[] { non_neg_concat_dim, tf.constant(0) }, | ||||
| new Tensor[] { tf.constant(1), tf.constant(-1) }); | new Tensor[] { tf.constant(1), tf.constant(-1) }); | ||||
| var squeeze_sizes = array_ops.squeeze(slice); | var squeeze_sizes = array_ops.squeeze(slice); | ||||
| out_grads = gen_array_ops.split(grad, squeeze_sizes, (int)non_neg_concat_dim).ToList(); | |||||
| out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split:(int)non_neg_concat_dim).ToList(); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -13,12 +13,14 @@ namespace Tensorflow.Gradients | |||||
| "FusedBatchNormGradV3" => new[] { 5 }, | "FusedBatchNormGradV3" => new[] { 5 }, | ||||
| "FusedBatchNormV2" => new[] { 2 }, | "FusedBatchNormV2" => new[] { 2 }, | ||||
| "FusedBatchNormV3" => new[] { 2 }, | "FusedBatchNormV3" => new[] { 2 }, | ||||
| "ReadVariableOp" => new int[0], | |||||
| _ => null | _ => null | ||||
| }; | }; | ||||
| public static int[] OpGradientUnusedOutputIndices(string op_name) | public static int[] OpGradientUnusedOutputIndices(string op_name) | ||||
| => op_name switch | => op_name switch | ||||
| { | { | ||||
| "ReadVariableOp" => new int[0], | |||||
| "SoftmaxCrossEntropyWithLogits" => new[] { 0 }, | "SoftmaxCrossEntropyWithLogits" => new[] { 0 }, | ||||
| "TensorArrayConcat" => new[] { 0 }, | "TensorArrayConcat" => new[] { 0 }, | ||||
| "TensorArrayConcatV2" => new[] { 0 }, | "TensorArrayConcatV2" => new[] { 0 }, | ||||
| @@ -64,6 +64,22 @@ namespace Tensorflow.Gradients | |||||
| return new Tensor[] { r1, r2 }; | return new Tensor[] { r1, r2 }; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Copies the gradient to all inputs. | |||||
| /// </summary> | |||||
| /// <param name="op"></param> | |||||
| /// <param name="grads"></param> | |||||
| /// <returns></returns> | |||||
| [RegisterGradient("AddN")] | |||||
| public static Tensor[] _AddNGrad(Operation op, Tensor[] grads) | |||||
| { | |||||
| var grad = grads[0]; | |||||
| return Enumerable.Range(0, len(op.inputs)) | |||||
| .Select(x => grad) | |||||
| .ToArray(); | |||||
| } | |||||
| [RegisterGradient("Cumsum")] | [RegisterGradient("Cumsum")] | ||||
| public static Tensor[] _CumsumGrad(Operation op, Tensor[] grads) | public static Tensor[] _CumsumGrad(Operation op, Tensor[] grads) | ||||
| { | { | ||||
| @@ -91,7 +91,7 @@ namespace Tensorflow | |||||
| gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); | gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); | ||||
| // i = input_gate, j = new_input, f = forget_gate, o = output_gate | // i = input_gate, j = new_input, f = forget_gate, o = output_gate | ||||
| var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one); | |||||
| var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one); | |||||
| var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); | var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); | ||||
| var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); | var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); | ||||
| @@ -18,6 +18,7 @@ using NumSharp; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Eager; | |||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -668,14 +669,29 @@ namespace Tensorflow | |||||
| }); | }); | ||||
| } | } | ||||
| public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis, | |||||
| public static Tensor[] split<T>(Tensor value, int num_split, T axis, | |||||
| string name = "split") | string name = "split") | ||||
| { | { | ||||
| var size_splits = ops.convert_to_tensor(num_or_size_splits); | |||||
| return gen_array_ops.split(axis: axis, | |||||
| num_split: num_or_size_splits, | |||||
| value: value, | |||||
| name: name); | |||||
| var size_splits = ops.convert_to_tensor(num_split); | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| return split_eager_fallback(axis, value, num_split: num_split, name: name, ctx: tf.context); | |||||
| } | |||||
| var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); | |||||
| return _op.outputs; | |||||
| } | |||||
| private static Tensor[] split_eager_fallback<Ta, Tv>(Ta axis, Tv value, int num_split, string name, Context ctx = null) | |||||
| { | |||||
| var (_attr_T, input) = tf._execute.args_to_matching_eager(ctx, args: new object[] { value }); | |||||
| var axis_tensor = ops.convert_to_tensor(axis, dtype: TF_DataType.TF_INT32); | |||||
| var _inputs_flat = new List<Tensor> { axis_tensor }; | |||||
| _inputs_flat.AddRange(input); | |||||
| var _attrs = new object[] { "num_split", num_split, "T", _attr_T }; | |||||
| return tf._execute.execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name); | |||||
| } | } | ||||
| public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | ||||
| @@ -376,6 +376,16 @@ namespace Tensorflow | |||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| if (pred.ToArray<bool>()[0]) | |||||
| return true_fn() as Tensor; | |||||
| else | |||||
| return false_fn() as Tensor; | |||||
| return null; | |||||
| } | |||||
| // Add the Switch to the graph. | // Add the Switch to the graph. | ||||
| var switch_result= @switch(pred, pred); | var switch_result= @switch(pred, pred); | ||||
| var (p_2, p_1 )= (switch_result[0], switch_result[1]); | var (p_2, p_1 )= (switch_result[0], switch_result[1]); | ||||
| @@ -450,6 +460,16 @@ namespace Tensorflow | |||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| if (pred.ToArray<bool>()[0]) | |||||
| return true_fn() as Tensor[]; | |||||
| else | |||||
| return false_fn() as Tensor[]; | |||||
| return null; | |||||
| } | |||||
| // Add the Switch to the graph. | // Add the Switch to the graph. | ||||
| var switch_result = @switch(pred, pred); | var switch_result = @switch(pred, pred); | ||||
| var p_2 = switch_result[0]; | var p_2 = switch_result[0]; | ||||
| @@ -124,6 +124,16 @@ namespace Tensorflow | |||||
| /// </remarks> | /// </remarks> | ||||
| public static Tensor diag(Tensor diagonal, string name = null) | public static Tensor diag(Tensor diagonal, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Diag", name, | |||||
| null, | |||||
| diagonal); | |||||
| return results[0]; | |||||
| } | |||||
| var op = tf._op_def_lib._apply_op_helper("Diag", name: name, args: new { diagonal }); | var op = tf._op_def_lib._apply_op_helper("Diag", name: name, args: new { diagonal }); | ||||
| return op.output; | return op.output; | ||||
| @@ -131,6 +141,16 @@ namespace Tensorflow | |||||
| public static Tensor expand_dims(Tensor input, int axis, string name = null) | public static Tensor expand_dims(Tensor input, int axis, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "ExpandDims", name, | |||||
| null, | |||||
| input, tf.convert_to_tensor(axis)); | |||||
| return results[0]; | |||||
| } | |||||
| var _op = tf._op_def_lib._apply_op_helper("ExpandDims", name: name, args: new { input, dim = axis }); | var _op = tf._op_def_lib._apply_op_helper("ExpandDims", name: name, args: new { input, dim = axis }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| @@ -463,12 +483,6 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null) | |||||
| { | |||||
| var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); | |||||
| return _op.outputs; | |||||
| } | |||||
| public static Tensor tile<T>(Tensor input, T multiples, string name = null) | public static Tensor tile<T>(Tensor input, T multiples, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
| @@ -1,40 +0,0 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class gen_string_ops | |||||
| { | |||||
| public static Tensor substr(Tensor input, int pos, int len, | |||||
| string name = null, string @uint = "BYTE") | |||||
| { | |||||
| var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new | |||||
| { | |||||
| input, | |||||
| pos, | |||||
| len, | |||||
| unit = @uint | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -17,6 +17,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Index; | using System.Index; | ||||
| using System.Linq; | |||||
| using System.Range; | using System.Range; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| @@ -160,7 +161,7 @@ namespace Tensorflow | |||||
| Func<ITensorOrOperation> _bmp = () => | Func<ITensorOrOperation> _bmp = () => | ||||
| { | { | ||||
| int bmp_channels = channels; | int bmp_channels = channels; | ||||
| var signature = string_ops.substr(contents, 0, 2); | |||||
| var signature = tf.strings.substr(contents, 0, 2); | |||||
| var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); | var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); | ||||
| string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; | string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; | ||||
| var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); | var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); | ||||
| @@ -195,7 +196,7 @@ namespace Tensorflow | |||||
| return tf_with(ops.name_scope(name, "decode_image"), scope => | return tf_with(ops.name_scope(name, "decode_image"), scope => | ||||
| { | { | ||||
| substr = string_ops.substr(contents, 0, 3); | |||||
| substr = tf.strings.substr(contents, 0, 3); | |||||
| return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); | return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -225,8 +226,11 @@ namespace Tensorflow | |||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "is_jpeg"), scope => | return tf_with(ops.name_scope(name, "is_jpeg"), scope => | ||||
| { | { | ||||
| var substr = string_ops.substr(contents, 0, 3); | |||||
| return math_ops.equal(substr, "\xff\xd8\xff", name: name); | |||||
| var substr = tf.strings.substr(contents, 0, 3); | |||||
| var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff }); | |||||
| var jpg_tensor = tf.constant(jpg); | |||||
| var result = math_ops.equal(substr, jpg_tensor, name: name); | |||||
| return result; | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -234,7 +238,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "is_png"), scope => | return tf_with(ops.name_scope(name, "is_png"), scope => | ||||
| { | { | ||||
| var substr = string_ops.substr(contents, 0, 3); | |||||
| var substr = tf.strings.substr(contents, 0, 3); | |||||
| return math_ops.equal(substr, @"\211PN", name: name); | return math_ops.equal(substr, @"\211PN", name: name); | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -14,31 +14,45 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Tensorflow.Eager; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class gen_io_ops | |||||
| public class io_ops | |||||
| { | { | ||||
| public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) | |||||
| public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) | |||||
| { | { | ||||
| var _op = tf._op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); | var _op = tf._op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); | ||||
| return _op; | return _op; | ||||
| } | } | ||||
| public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | |||||
| public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | |||||
| { | { | ||||
| var _op = tf._op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | var _op = tf._op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | ||||
| return _op.outputs; | return _op.outputs; | ||||
| } | } | ||||
| public static Tensor read_file<T>(T filename, string name = null) | |||||
| public Tensor read_file<T>(T filename, string name = null) | |||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| return read_file_eager_fallback(filename, name: name, tf.context); | |||||
| } | |||||
| var _op = tf._op_def_lib._apply_op_helper("ReadFile", name: name, args: new { filename }); | var _op = tf._op_def_lib._apply_op_helper("ReadFile", name: name, args: new { filename }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| private Tensor read_file_eager_fallback<T>(T filename, string name = null, Context ctx = null) | |||||
| { | |||||
| var filename_tensor = ops.convert_to_tensor(filename, TF_DataType.TF_STRING); | |||||
| var _inputs_flat = new[] { filename_tensor }; | |||||
| return tf._execute.execute(ctx, "ReadFile", 1, _inputs_flat, null, name: name)[0]; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -17,6 +17,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -31,8 +32,30 @@ namespace Tensorflow | |||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <param name="uint"></param> | /// <param name="uint"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor substr(Tensor input, int pos, int len, | |||||
| string name = null, string @uint = "BYTE") | |||||
| => gen_string_ops.substr(input, pos, len, name: name, @uint: @uint); | |||||
| public Tensor substr<T>(T input, int pos, int len, | |||||
| string @uint = "BYTE", string name = null) | |||||
| { | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var input_tensor = tf.constant(input); | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Substr", name, | |||||
| null, | |||||
| input, pos, len, | |||||
| "unit", @uint); | |||||
| return results[0]; | |||||
| } | |||||
| var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new | |||||
| { | |||||
| input, | |||||
| pos, | |||||
| len, | |||||
| unit = @uint | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -68,9 +68,9 @@ namespace Tensorflow | |||||
| throw new ArgumentException($"{nameof(Tensor)} can only be scalar."); | throw new ArgumentException($"{nameof(Tensor)} can only be scalar."); | ||||
| IntPtr stringStartAddress = IntPtr.Zero; | IntPtr stringStartAddress = IntPtr.Zero; | ||||
| UIntPtr dstLen = UIntPtr.Zero; | |||||
| ulong dstLen = 0; | |||||
| c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, tf.status.Handle); | |||||
| c_api.TF_StringDecode((byte*) this.buffer + 8, this.bytesize, (byte**) &stringStartAddress, ref dstLen, tf.status.Handle); | |||||
| tf.status.Check(true); | tf.status.Check(true); | ||||
| var dstLenInt = checked((int) dstLen); | var dstLenInt = checked((int) dstLen); | ||||
| @@ -453,7 +453,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| var buffer = Encoding.UTF8.GetBytes(str); | var buffer = Encoding.UTF8.GetBytes(str); | ||||
| var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | ||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | |||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + sizeof(ulong))); | |||||
| AllocationType = AllocationType.Tensorflow; | AllocationType = AllocationType.Tensorflow; | ||||
| IntPtr tensor = c_api.TF_TensorData(handle); | IntPtr tensor = c_api.TF_TensorData(handle); | ||||
| @@ -235,13 +235,12 @@ namespace Tensorflow | |||||
| var buffer = new byte[size][]; | var buffer = new byte[size][]; | ||||
| var src = c_api.TF_TensorData(_handle); | var src = c_api.TF_TensorData(_handle); | ||||
| var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); | |||||
| src += (int)(size * 8); | src += (int)(size * 8); | ||||
| for (int i = 0; i < buffer.Length; i++) | for (int i = 0; i < buffer.Length; i++) | ||||
| { | { | ||||
| IntPtr dst = IntPtr.Zero; | IntPtr dst = IntPtr.Zero; | ||||
| UIntPtr dstLen = UIntPtr.Zero; | |||||
| var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, tf.status.Handle); | |||||
| ulong dstLen = 0; | |||||
| var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle); | |||||
| tf.status.Check(true); | tf.status.Check(true); | ||||
| buffer[i] = new byte[(int)dstLen]; | buffer[i] = new byte[(int)dstLen]; | ||||
| Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | ||||
| @@ -254,5 +253,35 @@ namespace Tensorflow | |||||
| return _str; | return _str; | ||||
| } | } | ||||
| public unsafe byte[][] StringBytes() | |||||
| { | |||||
| if (dtype != TF_DataType.TF_STRING) | |||||
| throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); | |||||
| // | |||||
| // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | |||||
| // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] | |||||
| // | |||||
| long size = 1; | |||||
| foreach (var s in TensorShape.dims) | |||||
| size *= s; | |||||
| var buffer = new byte[size][]; | |||||
| var src = c_api.TF_TensorData(_handle); | |||||
| src += (int)(size * 8); | |||||
| for (int i = 0; i < buffer.Length; i++) | |||||
| { | |||||
| IntPtr dst = IntPtr.Zero; | |||||
| ulong dstLen = 0; | |||||
| var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle); | |||||
| tf.status.Check(true); | |||||
| buffer[i] = new byte[(int)dstLen]; | |||||
| Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | |||||
| src += (int)read; | |||||
| } | |||||
| return buffer; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -207,7 +207,7 @@ namespace Tensorflow | |||||
| public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, SafeStatusHandle status); | public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, SafeStatusHandle status); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe UIntPtr TF_StringDecode(byte* src, UIntPtr src_len, byte** dst, UIntPtr* dst_len, SafeStatusHandle status); | |||||
| public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status); | |||||
| public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator; | public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator; | ||||
| @@ -132,28 +132,54 @@ namespace Tensorflow | |||||
| switch (value) | switch (value) | ||||
| { | { | ||||
| case EagerTensor val: | |||||
| return val; | |||||
| case NDArray val: | case NDArray val: | ||||
| return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
| case string val: | case string val: | ||||
| return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
| case bool val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case byte val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case byte[] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case byte[,] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case byte[,,] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case int val: | case int val: | ||||
| return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
| case int[] val: | case int[] val: | ||||
| return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
| case int[,] val: | case int[,] val: | ||||
| return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
| case int[,,] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case long val: | case long val: | ||||
| return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
| case long[] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case long[,] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case long[,,] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case float val: | case float val: | ||||
| return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
| case float[] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case float[,] val: | case float[,] val: | ||||
| return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
| case double val: | |||||
| case float[,,] val: | |||||
| return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
| case float[] val: | |||||
| case double val: | |||||
| return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
| case double[] val: | case double[] val: | ||||
| return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
| case double[,] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case double[,,] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| default: | default: | ||||
| throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}"); | throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}"); | ||||
| } | } | ||||
| @@ -55,7 +55,7 @@ namespace Tensorflow | |||||
| if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2) | if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2) | ||||
| { | { | ||||
| return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); | |||||
| return tf.io.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -76,7 +76,7 @@ namespace Tensorflow | |||||
| dtypes.Add(spec.dtype); | dtypes.Add(spec.dtype); | ||||
| } | } | ||||
| return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); | |||||
| return tf.io.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); | |||||
| } | } | ||||
| public virtual SaverDef _build_internal(IVariableV1[] names_to_saveables, | public virtual SaverDef _build_internal(IVariableV1[] names_to_saveables, | ||||
| @@ -39,5 +39,8 @@ namespace Tensorflow.Util | |||||
| } | } | ||||
| public override bool IsInvalid => handle == IntPtr.Zero; | public override bool IsInvalid => handle == IntPtr.Zero; | ||||
| public override string ToString() | |||||
| => $"0x{handle.ToString("x16")}"; | |||||
| } | } | ||||
| } | } | ||||
| @@ -28,10 +28,10 @@ namespace Tensorflow.Util | |||||
| } | } | ||||
| public void push_back(Tk key, Tv value) | public void push_back(Tk key, Tv value) | ||||
| => Add(key, value); | |||||
| => this[key] = value; | |||||
| public void emplace(Tk key, Tv value) | public void emplace(Tk key, Tv value) | ||||
| => Add(key, value); | |||||
| => this[key] = value; | |||||
| public bool find(Tk key) | public bool find(Tk key) | ||||
| => ContainsKey(key); | => ContainsKey(key); | ||||
| @@ -22,56 +22,21 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class ResourceVariable | public partial class ResourceVariable | ||||
| { | { | ||||
| public static Tensor operator +(ResourceVariable x, int y) => op_helper("add", x, y); | |||||
| public static Tensor operator +(ResourceVariable x, float y) => op_helper("add", x, y); | |||||
| public static Tensor operator +(ResourceVariable x, double y) => op_helper("add", x, y); | |||||
| public static Tensor operator +(ResourceVariable x, ResourceVariable y) => op_helper("add", x, y); | |||||
| public static Tensor operator -(ResourceVariable x, int y) => op_helper("sub", x, y); | |||||
| public static Tensor operator -(ResourceVariable x, float y) => op_helper("sub", x, y); | |||||
| public static Tensor operator -(ResourceVariable x, double y) => op_helper("sub", x, y); | |||||
| public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y); | |||||
| public static Tensor operator -(ResourceVariable x, ResourceVariable y) => op_helper("sub", x, y); | |||||
| public static Tensor operator *(ResourceVariable x, ResourceVariable y) => op_helper("mul", x, y); | |||||
| public static Tensor operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y); | |||||
| public static Tensor operator <(ResourceVariable x, Tensor y) => op_helper("less", x, y); | |||||
| public static Tensor operator >(ResourceVariable x, Tensor y) => op_helper("greater", x, y); | |||||
| private static Tensor op_helper<T>(string default_name, ResourceVariable x, T y) | |||||
| => tf_with(ops.name_scope(null, default_name, new { x, y }), scope => | |||||
| { | |||||
| string name = scope; | |||||
| var xVal = x.value(); | |||||
| var yTensor = ops.convert_to_tensor(y, xVal.dtype.as_base_dtype(), "y"); | |||||
| Tensor result = null; | |||||
| switch (default_name) | |||||
| { | |||||
| case "add": | |||||
| result = x.dtype == TF_DataType.TF_STRING ? | |||||
| gen_math_ops.add(xVal, yTensor, name) : | |||||
| gen_math_ops.add_v2(xVal, yTensor, name); | |||||
| break; | |||||
| case "sub": | |||||
| result = gen_math_ops.sub(xVal, yTensor, name); | |||||
| break; | |||||
| case "mul": | |||||
| result = gen_math_ops.mul(xVal, yTensor, name: name); | |||||
| break; | |||||
| case "less": | |||||
| result = gen_math_ops.less(xVal, yTensor, name); | |||||
| break; | |||||
| case "greater": | |||||
| result = gen_math_ops.greater(xVal, yTensor, name); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| // x.assign(result); | |||||
| // result.ResourceVar = x; | |||||
| return result; | |||||
| }); | |||||
| public static Tensor operator +(ResourceVariable x, int y) => x.value() + y; | |||||
| public static Tensor operator +(ResourceVariable x, float y) => x.value() + y; | |||||
| public static Tensor operator +(ResourceVariable x, double y) => x.value() + y; | |||||
| public static Tensor operator +(ResourceVariable x, ResourceVariable y) => x.value() + y.value(); | |||||
| public static Tensor operator -(ResourceVariable x, int y) => x.value() - y; | |||||
| public static Tensor operator -(ResourceVariable x, float y) => x.value() - y; | |||||
| public static Tensor operator -(ResourceVariable x, double y) => x.value() - y; | |||||
| public static Tensor operator -(ResourceVariable x, Tensor y) => x.value() - y; | |||||
| public static Tensor operator -(ResourceVariable x, ResourceVariable y) => x.value() - y.value(); | |||||
| public static Tensor operator *(ResourceVariable x, ResourceVariable y) => x.value() * y.value(); | |||||
| public static Tensor operator *(ResourceVariable x, NDArray y) => x.value() * y; | |||||
| public static Tensor operator <(ResourceVariable x, Tensor y) => x.value() < y; | |||||
| public static Tensor operator >(ResourceVariable x, Tensor y) => x.value() > y; | |||||
| } | } | ||||
| } | } | ||||
| @@ -277,7 +277,7 @@ namespace Tensorflow | |||||
| return ops.control_dependencies(null); | return ops.control_dependencies(null); | ||||
| } | } | ||||
| private static int uid_number = 0; | |||||
| private static int uid_number = -1; | |||||
| /// <summary> | /// <summary> | ||||
| /// A unique (within this program execution) integer. | /// A unique (within this program execution) integer. | ||||
| @@ -160,7 +160,6 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| Assert.AreEqual(6.0, (double)c); | Assert.AreEqual(6.0, (double)c); | ||||
| } | } | ||||
| [Ignore] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void StringEncode() | public void StringEncode() | ||||
| { | { | ||||
| @@ -175,7 +174,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); | string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); | ||||
| Assert.AreEqual(encoded_str, str); | Assert.AreEqual(encoded_str, str); | ||||
| Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); | Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); | ||||
| //c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); | |||||
| // c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status.Handle); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -2,8 +2,10 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Reflection; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.UnitTest; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.Basics | namespace TensorFlowNET.UnitTest.Basics | ||||
| @@ -20,11 +22,10 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| [TestInitialize] | [TestInitialize] | ||||
| public void Initialize() | public void Initialize() | ||||
| { | { | ||||
| imgPath = Path.GetFullPath(imgPath); | |||||
| contents = tf.read_file(imgPath); | |||||
| imgPath = TestHelper.GetFullPathFromDataDir(imgPath); | |||||
| contents = tf.io.read_file(imgPath); | |||||
| } | } | ||||
| [Ignore("")] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void decode_image() | public void decode_image() | ||||
| { | { | ||||
| @@ -24,6 +24,25 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
| Assert.AreEqual((float)grad, 3.0f); | Assert.AreEqual((float)grad, 3.0f); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Calcute the gradient of w * w * w | |||||
| /// 高阶梯度 | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void HighGradient() | |||||
| { | |||||
| var x = tf.Variable(1.0f); | |||||
| using var tape1 = tf.GradientTape(); | |||||
| using var tape2 = tf.GradientTape(); | |||||
| var y = x * x * x; | |||||
| tape2.Dispose(); | |||||
| var dy_dx = tape2.gradient(y, x); | |||||
| Assert.AreEqual((float)dy_dx, 3.0f); | |||||
| tape1.Dispose(); | |||||
| var d2y_d2x = tape1.gradient(dy_dx, x); | |||||
| Assert.AreEqual((float)d2y_d2x, 6.0f); | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void ConstantMultiply() | public void ConstantMultiply() | ||||
| { | { | ||||
| @@ -56,5 +75,33 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
| var dz_dy = tape.gradient(z, y); | var dz_dy = tape.gradient(z, y); | ||||
| Assert.AreEqual((float)dz_dy, 8.0f); | Assert.AreEqual((float)dz_dy, 8.0f); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void ConditionalMultiply() | |||||
| { | |||||
| Func<Tensor, int, Tensor> func = (x, y) => | |||||
| { | |||||
| Tensor output = tf.constant(1.0f); | |||||
| foreach (var i in range(y)) | |||||
| { | |||||
| if (i > 1) | |||||
| output = tf.multiply(output, x); | |||||
| } | |||||
| return output; | |||||
| }; | |||||
| Func<Tensor, int, Tensor> grad = (x, y) => | |||||
| { | |||||
| using var tape = tf.GradientTape(); | |||||
| tape.watch(x); | |||||
| var output = func(x, y); | |||||
| var grad = tape.gradient(output, x); | |||||
| return grad; | |||||
| }; | |||||
| var x = tf.constant(2.0f); | |||||
| var result = grad(x, 4); | |||||
| Assert.AreEqual((float)result, 4.0f); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -6,10 +6,10 @@ using System.Text; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.math_test | |||||
| namespace TensorFlowNET.UnitTest.TF_API | |||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class MathOperationTest : TFNetApiTest | |||||
| public class MathApiTest : TFNetApiTest | |||||
| { | { | ||||
| // A constant vector of size 6 | // A constant vector of size 6 | ||||
| Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); | Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); | ||||
| @@ -0,0 +1,43 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.UnitTest.TF_API | |||||
| { | |||||
| [TestClass] | |||||
| public class StringsApiTest | |||||
| { | |||||
| [TestMethod] | |||||
| public void StringEqual() | |||||
| { | |||||
| var str1 = tf.constant("Hello1"); | |||||
| var str2 = tf.constant("Hello2"); | |||||
| var result = tf.equal(str1, str2); | |||||
| Assert.IsFalse(result.ToScalar<bool>()); | |||||
| var str3 = tf.constant("Hello1"); | |||||
| result = tf.equal(str1, str3); | |||||
| Assert.IsTrue(result.ToScalar<bool>()); | |||||
| var str4 = tf.strings.substr(str1, 0, 5); | |||||
| var str5 = tf.strings.substr(str2, 0, 5); | |||||
| result = tf.equal(str4, str5); | |||||
| Assert.IsTrue(result.ToScalar<bool>()); | |||||
| } | |||||
| [TestMethod] | |||||
| public void ImageType() | |||||
| { | |||||
| var imgPath = TestHelper.GetFullPathFromDataDir("shasta-daisy.jpg"); | |||||
| var contents = tf.io.read_file(imgPath); | |||||
| var substr = tf.strings.substr(contents, 0, 3); | |||||
| var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff }); | |||||
| var jpg_tensor = tf.constant(jpg); | |||||
| var result = math_ops.equal(substr, jpg_tensor); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -23,6 +23,58 @@ namespace Tensorflow.UnitTest.TF_API | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(transpose_a.numpy().ToArray<int>(), b.numpy().ToArray<int>())); | Assert.IsTrue(Enumerable.SequenceEqual(transpose_a.numpy().ToArray<int>(), b.numpy().ToArray<int>())); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void InitTensorTest() | |||||
| { | |||||
| var a = tf.constant(np.array(new[, ,] | |||||
| { | |||||
| { { 1 }, { 2 }, { 3 } }, | |||||
| { { 4 }, { 5 }, { 6 } } | |||||
| })); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape)); | |||||
| var b = tf.constant(new[, ,] | |||||
| { | |||||
| { { 1 }, { 2 }, { 3 } }, | |||||
| { { 4 }, { 5 }, { 6 } } | |||||
| }); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape)); | |||||
| } | |||||
| [TestMethod] | |||||
| public void ConcatTest() | |||||
| { | |||||
| var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } }); | |||||
| var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } }); | |||||
| var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } }); | |||||
| var concatValue = tf.concat(new[] { a, b, c }, axis: 0); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); | |||||
| } | |||||
| [TestMethod] | |||||
| public void ConcatDoubleTest() | |||||
| { | |||||
| var a = tf.constant(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } }); | |||||
| var b = tf.constant(new[,] { { 5.0, 6.0 }, { 7.0, 8.0 } }); | |||||
| var c = tf.constant(new[,] { { 9.0, 10.0 }, { 11.0, 12.0 } }); | |||||
| var concatValue = tf.concat(new[] { a, b, c }, axis: 0); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); | |||||
| } | |||||
| [TestMethod] | |||||
| public void ConcatAndSplitTest() | |||||
| { | |||||
| var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } }); | |||||
| var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } }); | |||||
| var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } }); | |||||
| var value = tf.concat(new[] { a, b, c }, axis: 0); | |||||
| var splitValue = tf.split(value, 3, axis: 0); | |||||
| Assert.AreEqual(3, splitValue.Length); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 2 }, splitValue[0].shape)); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,16 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| using System.Text; | |||||
| namespace Tensorflow.UnitTest | |||||
| { | |||||
| public class TestHelper | |||||
| { | |||||
| public static string GetFullPathFromDataDir(string fileName) | |||||
| { | |||||
| var dir = Path.Combine(Directory.GetCurrentDirectory(), "..", "..", "..", "..", "..", "data"); | |||||
| return Path.GetFullPath(Path.Combine(dir, fileName)); | |||||
| } | |||||
| } | |||||
| } | |||||