diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs index 1786f340..653978d1 100644 --- a/src/TensorFlowNET.Core/APIs/tf.init.cs +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -20,6 +20,8 @@ namespace Tensorflow { public partial class tensorflow { + public InitializersImpl initializers { get; } = new InitializersImpl(); + public IInitializer constant_initializer(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) => new Constant(value, dtype: dtype, verify_shape: verify_shape); public IInitializer zeros_initializer => new Zeros(); @@ -82,5 +84,20 @@ namespace Tensorflow uniform: uniform, seed: seed, dtype: dtype); + + public class InitializersImpl + { + public IInitializer random_normal_initializer(float mean = 0.0f, + float stddev = 1.0f, + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) => new RandomNormal(mean: mean, + stddev: stddev, + seed: seed, + dtype: dtype); + + public IInitializer zeros_initializer(TensorShape shape = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) => new Zeros(shape: shape, + dtype: dtype); + } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.optimizers.cs b/src/TensorFlowNET.Core/APIs/tf.optimizers.cs index 760154ad..ceccca5b 100644 --- a/src/TensorFlowNET.Core/APIs/tf.optimizers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.optimizers.cs @@ -27,6 +27,18 @@ namespace Tensorflow public class KerasOptimizers { public SGD SGD(float learning_rate) => new SGD(learning_rate); + + public Adam Adam(float learning_rate = 0.001f, + float beta_1 = 0.9f, + float beta_2 = 0.999f, + float epsilon = 1e-7f, + bool amsgrad = false, + string name = "Adam") => new Adam(learning_rate: learning_rate, + beta_1: beta_1, + beta_2: beta_2, + epsilon: epsilon, + amsgrad: amsgrad, + name: name); } } } diff --git a/src/TensorFlowNET.Core/Eager/EagerOperation.cs b/src/TensorFlowNET.Core/Eager/EagerOperation.cs index 2aa7c04c..13e546c1 100644 --- a/src/TensorFlowNET.Core/Eager/EagerOperation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerOperation.cs @@ -51,28 +51,14 @@ namespace Tensorflow.Eager public override object get_attr(string attr_name) { - object value = null; - byte isList = 0; - var attrType = c_api.TFE_OpNameGetAttrType(tf.Context.Handle, Name, attr_name, ref isList, tf.Status.Handle); - switch (attrType) - { - case TF_AttrType.TF_ATTR_BOOL: - value = get_attr_bool(attr_name); - break; - default: - break; - } - - return value; - } - - public bool get_attr_bool(string attr_name) - { + // var attrType = c_api.TFE_OpNameGetAttrType(tf.Context.Handle, Name, attr_name, ref isList, tf.Status.Handle); for (int i = 0; i < Attrs.Length; i = i + 2) + { if (Attrs[i].Equals(attr_name)) - return Attrs[i + 1].Equals("1"); + return Attrs[i + 1]; + } - throw new ValueError($"Can't find attr: {attr_name}"); + return null; } public override string ToString() diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index 52b811de..4e61394c 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -344,6 +344,11 @@ namespace Tensorflow.Eager c_api.TFE_OpSetAttrTypeList(op, key, values2, values2.Length); attr_list_sizes[key] = values2.Length; } + else if (type == TF_AttrType.TF_ATTR_INT && values is int[] values4) + { + c_api.TFE_OpSetAttrIntList(op, key, values4.Select(x => Convert.ToInt64(x)).ToArray(), values4.Length); + attr_list_sizes[key] = values4.Length; + } else { throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index df68060b..7e69cc13 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -209,6 +209,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values); + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrIntList(SafeOpHandle op, string attr_name, long[] values, int num_values); + [DllImport(TensorFlowLibName)] public static extern void TFE_OpSetAttrValueProto(SafeOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); diff --git a/src/TensorFlowNET.Core/Gradients/GradientTape.cs b/src/TensorFlowNET.Core/Gradients/GradientTape.cs index dccb9574..e94dc62a 100644 --- a/src/TensorFlowNET.Core/Gradients/GradientTape.cs +++ b/src/TensorFlowNET.Core/Gradients/GradientTape.cs @@ -119,7 +119,7 @@ namespace Tensorflow.Gradients return (results[0], results[1]); } - public Tensor[] gradient(Tensor target, List sources) + public Tensor[] gradient(Tensor target, IEnumerable sources) { if (_recording) { diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs index b1393e15..6a2df6e9 100644 --- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs @@ -128,12 +128,12 @@ namespace Tensorflow.Gradients [RegisterGradient("Conv2D")] public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads) { - var dilations = (op.get_attr("dilations") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); - var strides = (op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); - var padding = op.get_attr("padding"); - var explicit_paddings = (op.get_attr("explicit_paddings") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); - var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu"); - var data_format = op.get_attr("data_format"); + var dilations = op.get_attr("dilations"); + var strides = op.get_attr("strides"); + var padding = op.get_attr("padding"); + var explicit_paddings = op.get_attr("explicit_paddings"); + var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu"); + var data_format = op.get_attr("data_format"); var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] }); return new Tensor[] @@ -287,8 +287,8 @@ namespace Tensorflow.Gradients op.inputs[0], op.outputs[0], grad, - (op.get_attr("ksize") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(), - (op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(), + op.get_attr("ksize") as int[], + op.get_attr("strides") as int[], padding: op.get_attr("padding").ToString(), data_format: op.get_attr("data_format").ToString()) }; diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/Adam.cs b/src/TensorFlowNET.Core/Keras/Optimizers/Adam.cs new file mode 100644 index 00000000..bd5c3a96 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Optimizers/Adam.cs @@ -0,0 +1,91 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Eager; + +namespace Tensorflow.Keras.Optimizers +{ + /// + /// Optimizer that implements the Adam algorithm. + /// Adam optimization is a stochastic gradient descent method that is based on + /// adaptive estimation of first-order and second-order moments. + /// + public class Adam : OptimizerV2 + { + protected override string _name => "Adam"; + float epsilon = 1e-7f; + bool amsgrad = false; + + public Adam(float learning_rate = 0.001f, + float beta_1 = 0.9f, + float beta_2 = 0.999f, + float epsilon = 1e-7f, + bool amsgrad = false, + string name = "Adam") + { + _set_hyper("learning_rate", learning_rate); + // _set_hyper("decay", _initial_decay); + _set_hyper("beta_1", beta_1); + _set_hyper("beta_2", beta_2); + this.epsilon = epsilon; + this.amsgrad = amsgrad; + } + + protected override void _create_slots(IVariableV1[] var_list) + { + foreach(var var in var_list) + add_slot(var, "m"); + foreach (var var in var_list) + add_slot(var, "v"); + if (amsgrad) + foreach (var var in var_list) + add_slot(var, "vhat"); + } + + protected override void _prepare_local(DeviceDType device_dtype, Dictionary> apply_state) + { + base._prepare_local(device_dtype, apply_state); + var var_dtype = device_dtype.DType; + var var_device = device_dtype.Device; + var local_step = math_ops.cast(iterations + 1, var_dtype); + var beta_1_t = array_ops.identity(_get_hyper("beta_1", var_dtype)); + var beta_2_t = array_ops.identity(_get_hyper("beta_2", var_dtype)); + var beta_1_power = math_ops.pow(beta_1_t, local_step); + var beta_2_power = math_ops.pow(beta_2_t, local_step); + var lr = apply_state[device_dtype]["lr_t"] * (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)); + // update state + apply_state[device_dtype]["lr"] = lr; + apply_state[device_dtype]["epsilon"] = ops.convert_to_tensor(epsilon); + apply_state[device_dtype]["beta_1_t"] = beta_1_t; + apply_state[device_dtype]["beta_1_power"] = beta_1_power; + apply_state[device_dtype]["one_minus_beta_1_t"] = 1 - beta_1_t; + apply_state[device_dtype]["beta_2_t"] = beta_2_t; + apply_state[device_dtype]["beta_2_power"] = beta_2_power; + apply_state[device_dtype]["one_minus_beta_2_t"] = 1 - beta_2_t; + } + + protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary> apply_state) + { + var (var_device, var_dtype) = (var.Device, var.dtype.as_base_dtype()); + var coefficients = apply_state.FirstOrDefault(x => x.Key.Device == var_device && x.Key.DType == var_dtype).Value ?? _fallback_apply_state(var_device, var_dtype); + var m = get_slot(var, "m"); + var v = get_slot(var, "v"); + + if (!amsgrad) + return gen_training_ops.resource_apply_adam(var.Handle, + m.Handle, + v.Handle, + coefficients["beta_1_power"], + coefficients["beta_2_power"], + coefficients["lr_t"], + coefficients["beta_1_t"], + coefficients["beta_2_t"], + coefficients["epsilon"], + grad, + use_locking: _use_locking); + else + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs index 6b926622..4f5d2545 100644 --- a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs +++ b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs @@ -18,22 +18,25 @@ namespace Tensorflow.Keras.Optimizers protected bool _hypers_created; protected virtual string _name { get; } - ResourceVariable _iterations; - List _weight; + IVariableV1 _iterations; + protected ResourceVariable iterations => _iterations as ResourceVariable; + List _weights; Dictionary _hyper; - Dictionary _hyper_variables; + Dictionary _hyper_variables; protected bool _momentum; protected float _initial_decay = 0.0f; protected bool _use_locking = true; - Dictionary> apply_state; + Dictionary> _slots; + List _slot_names; public OptimizerV2() : base() { - _weight = new List(); + _weights = new List(); _hyper = new Dictionary(); - _hyper_variables = new Dictionary(); - apply_state = new Dictionary>(); + _hyper_variables = new Dictionary(); + _slots = new Dictionary>(); + _slot_names = new List(); } public void apply_gradients((Tensor, ResourceVariable) grads_and_vars, @@ -61,7 +64,7 @@ namespace Tensorflow.Keras.Optimizers if (grads_and_vars == null || grads_and_vars.Count() == 0) return control_flow_ops.no_op(); - apply_state = _prepare(var_list); + var apply_state = _prepare(var_list); if(experimental_aggregate_gradients) { // var reduced_grads = _aggregate_gradients(grads_and_vars); @@ -72,13 +75,13 @@ namespace Tensorflow.Keras.Optimizers }); } - void apply_grad_to_update_var(ResourceVariable var, EagerTensor grad) + void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary> apply_state) { _resource_apply_dense(var, grad, apply_state); } protected virtual Operation _resource_apply_dense(IVariableV1 var, - EagerTensor grad, + Tensor grad, Dictionary> _apply_state) { throw new NotImplementedException("_resource_apply_dense"); @@ -94,7 +97,7 @@ namespace Tensorflow.Keras.Optimizers { tf_with(ops.name_scope("update"), delegate { - apply_grad_to_update_var(var, grad as EagerTensor); + apply_grad_to_update_var(var, grad, _apply_state); }); } @@ -107,6 +110,12 @@ namespace Tensorflow.Keras.Optimizers return grads_and_vars.Select(x => x.Item1).ToArray(); } + protected IVariableV1 get_slot(IVariableV1 var, string slot_name) + { + var slot_dict = _slots[var.UniqueId]; + return slot_dict[slot_name]; + } + Dictionary> _prepare(IVariableV1[] var_list) { var _apply_state = new Dictionary>(); @@ -125,6 +134,11 @@ namespace Tensorflow.Keras.Optimizers return _apply_state; } + protected Dictionary _fallback_apply_state(string var_device, TF_DataType var_dtype) + { + throw new NotImplementedException(""); + } + protected virtual void _prepare_local(DeviceDType device_dtype, Dictionary> _apply_state) { @@ -145,7 +159,7 @@ namespace Tensorflow.Keras.Optimizers return lr_t; } - protected ResourceVariable _get_hyper(string name, TF_DataType dtype = TF_DataType.DtInvalid) + protected Tensor _get_hyper(string name, TF_DataType dtype = TF_DataType.DtInvalid) { var value = _hyper_variables[name]; return math_ops.cast(value, dtype); @@ -160,7 +174,7 @@ namespace Tensorflow.Keras.Optimizers dtype: TF_DataType.TF_INT64, trainable: false, aggregation: VariableAggregation.OnlyFirstReplica); - _weight.Add(_iterations); + _weights.Add(_iterations); } _create_hypers(); @@ -190,7 +204,7 @@ namespace Tensorflow.Keras.Optimizers _hypers_created = true; } - void _create_slots(IVariableV1[] var_list) + protected virtual void _create_slots(IVariableV1[] var_list) { if(_momentum) { @@ -199,6 +213,35 @@ namespace Tensorflow.Keras.Optimizers } } + protected IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null) + { + if (initializer == null) + initializer = tf.zeros_initializer; + + if (!_slot_names.Contains(slot_name)) + _slot_names.append(slot_name); + + if (!_slots.ContainsKey(var.UniqueId)) + _slots[var.UniqueId] = new Dictionary(); + var slot_dict = _slots[var.UniqueId]; + if (!slot_dict.ContainsKey(slot_name)) + { + var weight = tf.Variable(initializer, + dtype: var.dtype, + trainable: false, + shape: var.shape, + name: $"{var.Name}/{slot_name}"); + + slot_dict[slot_name] = weight; + _weights.append(weight); + return weight; + } + else + { + return slot_dict[slot_name]; + } + } + ResourceVariable add_weight(string name, TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs b/src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs index afedb391..8ac1aa5c 100644 --- a/src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs +++ b/src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs @@ -40,7 +40,7 @@ namespace Tensorflow.Keras.Optimizers _get_hyper("momentum", device_dtype.DType)); } - protected override Operation _resource_apply_dense(IVariableV1 var, EagerTensor grad, Dictionary> _apply_state) + protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary> _apply_state) { if (_momentum) { diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs index e8d16820..de9f479b 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs @@ -36,11 +36,7 @@ namespace Tensorflow.Keras.Utils ops.init_scope(); - Func init_val = () => args.Initializer.Apply(new InitializerArgs - { - Shape = args.Shape, - DType = args.DType - }); + Func init_val = () => args.Initializer.Apply(new InitializerArgs(args.Shape, dtype: args.DType)); var variable_dtype = args.DType.as_base_dtype(); var v = tf.Variable(init_val, diff --git a/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs b/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs index 561664bc..5e0227f9 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs @@ -6,8 +6,20 @@ namespace Tensorflow { public class InitializerArgs { + public string Name { get; set; } public TensorShape Shape { get; set; } public TF_DataType DType { get; set; } public bool? VerifyShape { get; set; } = null; + + public InitializerArgs(TensorShape shape, + TF_DataType dtype = TF_DataType.DtInvalid, + bool? verify_shape = null, + string name = null) + { + Shape = shape; + DType = dtype; + VerifyShape = verify_shape; + Name = name; + } } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs index 67e5d424..a4de9508 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs @@ -18,17 +18,21 @@ namespace Tensorflow.Operations.Initializers { public class Zeros : IInitializer { - private TF_DataType dtype; + TensorShape shape; + TF_DataType dtype; - public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) + public Zeros(TensorShape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) { + this.shape = shape; this.dtype = dtype; } public Tensor Apply(InitializerArgs args) { if (args.DType == TF_DataType.DtInvalid) - args.DType = this.dtype; + args.DType = dtype; + if (args.Shape == null) + args.Shape = shape; return array_ops.zeros(args.Shape, dtype); } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs b/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs index dde018df..fa0d5bef 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs @@ -71,7 +71,7 @@ namespace Tensorflow.Operations public bool UseCudnnOnGpu { get; set; } = true; - public int[] Dilations { get; set; } = new [] { 1, 1, 1, 1 }; + public int[] Dilations { get; set; } = new int[] { 1, 1, 1, 1 }; public Conv2dParams() { diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index 31b06a32..b239cfd8 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -42,6 +42,22 @@ namespace Tensorflow.Operations /// public static Tensor conv2d(Conv2dParams parameters) { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Conv2D", parameters.Name, + null, + parameters.Input, parameters.Filter, + "strides", parameters.Strides, + "use_cudnn_on_gpu", parameters.UseCudnnOnGpu, + "padding", parameters.Padding, + "explicit_paddings", parameters.ExplicitPaddings, + "data_format", parameters.DataFormat, + "dilations", parameters.Dilations); + + return results[0]; + } + var _op = tf.OpDefLib._apply_op_helper("Conv2D", name: parameters.Name, args: new { input = parameters.Input, @@ -64,6 +80,22 @@ namespace Tensorflow.Operations /// public static Tensor conv2d_backprop_filter(Conv2dParams parameters) { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Conv2DBackpropFilter", parameters.Name, + null, + parameters.Input, parameters.FilterSizes, parameters.OutBackProp, + "strides", parameters.Strides, + "use_cudnn_on_gpu", parameters.UseCudnnOnGpu, + "padding", parameters.Padding, + "explicit_paddings", parameters.ExplicitPaddings, + "data_format", parameters.DataFormat, + "dilations", parameters.Dilations); + + return results[0]; + } + var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: parameters.Name, args: new { input = parameters.Input, @@ -87,6 +119,22 @@ namespace Tensorflow.Operations /// public static Tensor conv2d_backprop_input(Conv2dParams parameters) { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Conv2DBackpropInput", parameters.Name, + null, + parameters.InputSizes, parameters.Filter, parameters.OutBackProp, + "strides", parameters.Strides, + "use_cudnn_on_gpu", parameters.UseCudnnOnGpu, + "padding", parameters.Padding, + "explicit_paddings", parameters.ExplicitPaddings, + "data_format", parameters.DataFormat, + "dilations", parameters.Dilations); + + return results[0]; + } + var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: parameters.Name, args: new { input_sizes = parameters.InputSizes, @@ -341,6 +389,20 @@ namespace Tensorflow.Operations string data_format = "NHWC", string name = null) { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "MaxPool", name, + null, + input, + "ksize", ksize, + "strides", strides, + "padding", padding, + "data_format", data_format); + + return results[0]; + } + var _op = tf.OpDefLib._apply_op_helper("MaxPool", name: name, args: new { input, @@ -356,6 +418,20 @@ namespace Tensorflow.Operations public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, string data_format= "NHWC", string name= null) { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "MaxPoolGrad", name, + null, + orig_input, orig_output, grad, + "ksize", ksize, + "strides", strides, + "padding", padding, + "data_format", data_format); + + return results[0]; + } + var _op = tf.OpDefLib._apply_op_helper("MaxPoolGrad", name: name, args: new { orig_input, @@ -384,7 +460,7 @@ namespace Tensorflow.Operations public static Tensor relu_grad(Tensor gradients, Tensor features, string name = null) { - if (tf.Context.executing_eagerly()) + if (tf.executing_eagerly()) { var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "ReluGrad", name, diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 37ecdb7f..595c0ce8 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -227,7 +227,7 @@ namespace Tensorflow return grouped_inputs.ToArray(); } - public T get_attr(string name) + public virtual T get_attr(string name) => (T)get_attr(name); public virtual object get_attr(string name) diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 4ce25b3c..95d4d078 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -424,6 +424,17 @@ namespace Tensorflow /// public static Tensor[] shape_n(Tensor[] input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "ShapeN", name, + null, + input, + "out_type", out_type); + + return results; + } + var _op = tf.OpDefLib._apply_op_helper("ShapeN", name, new { input, out_type }); return _op.outputs; } @@ -450,7 +461,7 @@ namespace Tensorflow public static Tensor tile(Tensor input, T multiples, string name = null) { - if (tf.Context.executing_eagerly()) + if (tf.executing_eagerly()) { var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Tile", name, diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index d88dca8c..603b757e 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -320,7 +320,7 @@ namespace Tensorflow /// public static Tensor sigmoid(Tensor x, string name = "Sigmoid") { - if (tf.Context.executing_eagerly()) + if (tf.executing_eagerly()) { var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Sigmoid", name, @@ -1074,23 +1074,6 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor pow(Tx x, Ty y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Pow", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Pow", name, args: new { x, y }); - - return _op.outputs[0]; - } - public static Tensor _sum(Tx input, Ty axis = default, bool keep_dims = false, string name = null) { if (tf.Context.executing_eagerly()) diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 8a349846..8a93453e 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -681,7 +681,19 @@ namespace Tensorflow var x_tensor = ops.convert_to_tensor(x, name: "x"); var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype()); - return gen_math_ops.pow(x_tensor, y_tensor, name: name); + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Pow", name, + null, + x_tensor, y_tensor); + + return results[0]; + } + + var _op = tf.OpDefLib._apply_op_helper("Pow", name, args: new { x_tensor, y_tensor }); + + return _op.output; }); public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range") @@ -754,9 +766,6 @@ namespace Tensorflow if (transpose_b && adjoint_b) throw new ValueError("Only one of transpose_b and adjoint_b can be True."); - a = ops.convert_to_tensor(a, name: "a"); - b = ops.convert_to_tensor(b, name: "b"); - result = gen_math_ops.mat_mul(a, b, transpose_a, transpose_b, name); }); diff --git a/src/TensorFlowNET.Core/Operations/random_ops.cs b/src/TensorFlowNET.Core/Operations/random_ops.cs index 64530396..8c9186b9 100644 --- a/src/TensorFlowNET.Core/Operations/random_ops.cs +++ b/src/TensorFlowNET.Core/Operations/random_ops.cs @@ -30,7 +30,7 @@ namespace Tensorflow /// /// /// - public static Tensor random_normal(int[] shape, + public static Tensor random_normal(TensorShape shape, float mean = 0.0f, float stddev = 1.0f, TF_DataType dtype = TF_DataType.TF_FLOAT, diff --git a/src/TensorFlowNET.Core/Training/gen_training_ops.cs b/src/TensorFlowNET.Core/Training/gen_training_ops.cs index ff7e9ff4..36eca7d6 100644 --- a/src/TensorFlowNET.Core/Training/gen_training_ops.cs +++ b/src/TensorFlowNET.Core/Training/gen_training_ops.cs @@ -23,6 +23,24 @@ namespace Tensorflow { public class gen_training_ops { + public static Operation resource_apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power, + Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, + bool use_locking = false, bool use_nesterov = false, string name = null) + { + if (tf.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "ResourceApplyAdam", name, + null, + var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, + "use_locking", use_locking, + "use_nesterov", use_nesterov); + return null; + } + + throw new NotImplementedException(""); + } + public static Tensor apply_adam(IVariableV1 var, IVariableV1 m, IVariableV1 v, Tensor beta1_power, Tensor beta2_power, Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, bool use_locking = false, bool use_nesterov = false, string name = null) @@ -56,12 +74,12 @@ namespace Tensorflow use_locking }); - return _op.outputs[0]; + return _op.output; } public static Operation resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) { - if (tf.Context.executing_eagerly()) + if (tf.executing_eagerly()) { var result = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "ResourceApplyGradientDescent", name, diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index c8e24528..9b179b4e 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -18,7 +18,7 @@ namespace Tensorflow protected string handle_name => _handle_name; protected string _unique_id; - public string unique_id => _unique_id; + public string UniqueId => _unique_id; protected bool _in_graph_mode; diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs index 6295a1cd..9178d6ad 100644 --- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -31,6 +31,7 @@ namespace Tensorflow /// public interface IVariableV1 { + public string UniqueId { get; } public string Name { get; } public Tensor Handle { get; } public string Device { get; } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 3fccc04e..34f1d93f 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -25,6 +25,7 @@ namespace Tensorflow public partial class RefVariable : IVariableV1, IProtoBuf { protected string _name; + public string UniqueId => _name; public Tensor GraphElement { get; } public Tensor _variable; public Tensor Handle => _variable; diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index ab1854ad..671602d9 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -67,8 +67,6 @@ namespace Tensorflow dtype: dtype, shape: shape); } - - // handle.ResourceVar = this; } private void _init_from_args(object initial_value = null, @@ -79,7 +77,8 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.DtInvalid, TensorShape shape = null) { - var init_from_fn = initial_value.GetType().Name == "Func`1"; + var init_from_fn = initial_value.GetType().Name == "Func`1" || + initial_value.GetType().GetInterface("IInitializer") != null; if(collections == null) collections = new List() { tf.GraphKeys.GLOBAL_VARIABLES }; _trainable = trainable; @@ -112,9 +111,12 @@ namespace Tensorflow attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}")); tf_with(ops.name_scope("Initializer"), delegate { - initial_value = ops.convert_to_tensor(init_from_fn ? (initial_value as Func)() : initial_value, - name: "initial_value", - dtype: dtype); + if (initial_value.GetType().GetInterface("IInitializer") != null) + initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype))); + else + initial_value = ops.convert_to_tensor(init_from_fn ? (initial_value as Func)() : initial_value, + name: "initial_value", + dtype: dtype); }); _shape = shape ?? (initial_value as Tensor).TensorShape; _initial_value = initial_value as Tensor; diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs index fb76188b..a1fd03c9 100644 --- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs +++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs @@ -162,11 +162,7 @@ namespace Tensorflow } else { - Func init_val = () => initializer.Apply(new InitializerArgs - { - Shape = shape, - DType = dtype - }); + Func init_val = () => initializer.Apply(new InitializerArgs(shape, dtype: dtype)); var variable_dtype = dtype.as_base_dtype(); v = variable_scope.default_variable_creator(init_val,