| @@ -20,6 +20,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class tensorflow | public partial class tensorflow | ||||
| { | { | ||||
| public InitializersImpl initializers { get; } = new InitializersImpl(); | |||||
| public IInitializer constant_initializer<T>(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) | public IInitializer constant_initializer<T>(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) | ||||
| => new Constant<T>(value, dtype: dtype, verify_shape: verify_shape); | => new Constant<T>(value, dtype: dtype, verify_shape: verify_shape); | ||||
| public IInitializer zeros_initializer => new Zeros(); | public IInitializer zeros_initializer => new Zeros(); | ||||
| @@ -82,5 +84,20 @@ namespace Tensorflow | |||||
| uniform: uniform, | uniform: uniform, | ||||
| seed: seed, | seed: seed, | ||||
| dtype: dtype); | dtype: dtype); | ||||
/// <summary>
/// Factory methods that build initializer objects; an instance is exposed
/// through the <c>initializers</c> property of the enclosing class.
/// </summary>
public class InitializersImpl
{
    /// <summary>
    /// Creates a <c>RandomNormal</c> initializer.
    /// </summary>
    /// <param name="mean">Forwarded to RandomNormal as <c>mean</c>.</param>
    /// <param name="stddev">Forwarded to RandomNormal as <c>stddev</c>.</param>
    /// <param name="seed">Optional seed, forwarded unchanged.</param>
    /// <param name="dtype">Element type, forwarded unchanged.</param>
    public IInitializer random_normal_initializer(float mean = 0.0f,
        float stddev = 1.0f,
        int? seed = null,
        TF_DataType dtype = TF_DataType.TF_FLOAT)
    {
        return new RandomNormal(mean: mean, stddev: stddev, seed: seed, dtype: dtype);
    }

    /// <summary>
    /// Creates a <c>Zeros</c> initializer.
    /// </summary>
    /// <param name="shape">Optional shape, forwarded unchanged (may be null).</param>
    /// <param name="dtype">Element type, forwarded unchanged.</param>
    public IInitializer zeros_initializer(TensorShape shape = null,
        TF_DataType dtype = TF_DataType.TF_FLOAT)
    {
        return new Zeros(shape: shape, dtype: dtype);
    }
}
| } | } | ||||
| } | } | ||||
| @@ -27,6 +27,18 @@ namespace Tensorflow | |||||
/// <summary>
/// Factory methods for the Keras optimizers.
/// </summary>
public class KerasOptimizers
{
    /// <summary>
    /// Creates an <c>SGD</c> optimizer with the given learning rate.
    /// </summary>
    public SGD SGD(float learning_rate)
    {
        return new SGD(learning_rate);
    }

    /// <summary>
    /// Creates an <c>Adam</c> optimizer; every argument is forwarded by name
    /// to the Adam constructor.
    /// </summary>
    public Adam Adam(float learning_rate = 0.001f,
        float beta_1 = 0.9f,
        float beta_2 = 0.999f,
        float epsilon = 1e-7f,
        bool amsgrad = false,
        string name = "Adam")
    {
        return new Adam(
            learning_rate: learning_rate,
            beta_1: beta_1,
            beta_2: beta_2,
            epsilon: epsilon,
            amsgrad: amsgrad,
            name: name);
    }
}
| } | } | ||||
| } | } | ||||
| @@ -51,28 +51,14 @@ namespace Tensorflow.Eager | |||||
| public override object get_attr(string attr_name) | public override object get_attr(string attr_name) | ||||
| { | { | ||||
| object value = null; | |||||
| byte isList = 0; | |||||
| var attrType = c_api.TFE_OpNameGetAttrType(tf.Context.Handle, Name, attr_name, ref isList, tf.Status.Handle); | |||||
| switch (attrType) | |||||
| { | |||||
| case TF_AttrType.TF_ATTR_BOOL: | |||||
| value = get_attr_bool(attr_name); | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
| return value; | |||||
| } | |||||
| public bool get_attr_bool(string attr_name) | |||||
| { | |||||
| // var attrType = c_api.TFE_OpNameGetAttrType(tf.Context.Handle, Name, attr_name, ref isList, tf.Status.Handle); | |||||
| for (int i = 0; i < Attrs.Length; i = i + 2) | for (int i = 0; i < Attrs.Length; i = i + 2) | ||||
| { | |||||
| if (Attrs[i].Equals(attr_name)) | if (Attrs[i].Equals(attr_name)) | ||||
| return Attrs[i + 1].Equals("1"); | |||||
| return Attrs[i + 1]; | |||||
| } | |||||
| throw new ValueError($"Can't find attr: {attr_name}"); | |||||
| return null; | |||||
| } | } | ||||
| public override string ToString() | public override string ToString() | ||||
| @@ -344,6 +344,11 @@ namespace Tensorflow.Eager | |||||
| c_api.TFE_OpSetAttrTypeList(op, key, values2, values2.Length); | c_api.TFE_OpSetAttrTypeList(op, key, values2, values2.Length); | ||||
| attr_list_sizes[key] = values2.Length; | attr_list_sizes[key] = values2.Length; | ||||
| } | } | ||||
| else if (type == TF_AttrType.TF_ATTR_INT && values is int[] values4) | |||||
| { | |||||
| c_api.TFE_OpSetAttrIntList(op, key, values4.Select(x => Convert.ToInt64(x)).ToArray(), values4.Length); | |||||
| attr_list_sizes[key] = values4.Length; | |||||
| } | |||||
| else | else | ||||
| { | { | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| @@ -209,6 +209,9 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values); | public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_OpSetAttrIntList(SafeOpHandle op, string attr_name, long[] values, int num_values); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_OpSetAttrValueProto(SafeOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); | public static extern void TFE_OpSetAttrValueProto(SafeOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); | ||||
| @@ -119,7 +119,7 @@ namespace Tensorflow.Gradients | |||||
| return (results[0], results[1]); | return (results[0], results[1]); | ||||
| } | } | ||||
| public Tensor[] gradient(Tensor target, List<IVariableV1> sources) | |||||
| public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources) | |||||
| { | { | ||||
| if (_recording) | if (_recording) | ||||
| { | { | ||||
| @@ -128,12 +128,12 @@ namespace Tensorflow.Gradients | |||||
| [RegisterGradient("Conv2D")] | [RegisterGradient("Conv2D")] | ||||
| public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads) | public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads) | ||||
| { | { | ||||
| var dilations = (op.get_attr("dilations") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); | |||||
| var strides = (op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); | |||||
| var padding = op.get_attr("padding"); | |||||
| var explicit_paddings = (op.get_attr("explicit_paddings") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); | |||||
| var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu"); | |||||
| var data_format = op.get_attr("data_format"); | |||||
| var dilations = op.get_attr<int[]>("dilations"); | |||||
| var strides = op.get_attr<int[]>("strides"); | |||||
| var padding = op.get_attr<string>("padding"); | |||||
| var explicit_paddings = op.get_attr<int[]>("explicit_paddings"); | |||||
| var use_cudnn_on_gpu = op.get_attr<bool>("use_cudnn_on_gpu"); | |||||
| var data_format = op.get_attr<string>("data_format"); | |||||
| var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] }); | var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] }); | ||||
| return new Tensor[] | return new Tensor[] | ||||
| @@ -287,8 +287,8 @@ namespace Tensorflow.Gradients | |||||
| op.inputs[0], | op.inputs[0], | ||||
| op.outputs[0], | op.outputs[0], | ||||
| grad, | grad, | ||||
| (op.get_attr("ksize") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(), | |||||
| (op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(), | |||||
| op.get_attr("ksize") as int[], | |||||
| op.get_attr("strides") as int[], | |||||
| padding: op.get_attr("padding").ToString(), | padding: op.get_attr("padding").ToString(), | ||||
| data_format: op.get_attr("data_format").ToString()) | data_format: op.get_attr("data_format").ToString()) | ||||
| }; | }; | ||||
| @@ -0,0 +1,91 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow.Keras.Optimizers | |||||
| { | |||||
/// <summary>
/// Optimizer that implements the Adam algorithm.
/// Adam optimization is a stochastic gradient descent method that is based on
/// adaptive estimation of first-order and second-order moments.
/// </summary>
public class Adam : OptimizerV2
{
    // Initialised inline so that _name already yields "Adam" should the base
    // constructor read it before this class's constructor body runs.
    string name = "Adam";
    float epsilon = 1e-7f;
    bool amsgrad = false;

    /// <summary>Name of this optimizer instance, as given to the constructor.</summary>
    protected override string _name => name;

    /// <summary>
    /// Creates an Adam optimizer.
    /// </summary>
    /// <param name="learning_rate">Stored as the "learning_rate" hyper-parameter.</param>
    /// <param name="beta_1">Stored as the "beta_1" hyper-parameter.</param>
    /// <param name="beta_2">Stored as the "beta_2" hyper-parameter.</param>
    /// <param name="epsilon">Passed to the update op as the "epsilon" coefficient.</param>
    /// <param name="amsgrad">Adds a "vhat" slot per variable; applying gradients with it enabled is not implemented.</param>
    /// <param name="name">Optimizer name; null falls back to "Adam".</param>
    public Adam(float learning_rate = 0.001f,
        float beta_1 = 0.9f,
        float beta_2 = 0.999f,
        float epsilon = 1e-7f,
        bool amsgrad = false,
        string name = "Adam")
    {
        // Previously the name argument was accepted but silently dropped.
        this.name = name ?? "Adam";

        _set_hyper("learning_rate", learning_rate);
        // _set_hyper("decay", _initial_decay);
        _set_hyper("beta_1", beta_1);
        _set_hyper("beta_2", beta_2);
        this.epsilon = epsilon;
        this.amsgrad = amsgrad;
    }

    /// <summary>
    /// Adds the "m" and "v" slots (and "vhat" when amsgrad is set) for every variable.
    /// All "m" slots are created first, then all "v" slots, then "vhat".
    /// </summary>
    protected override void _create_slots(IVariableV1[] var_list)
    {
        foreach (var var in var_list)
            add_slot(var, "m");
        foreach (var var in var_list)
            add_slot(var, "v");
        if (amsgrad)
            foreach (var var in var_list)
                add_slot(var, "vhat");
    }

    /// <summary>
    /// Extends the base per-(device, dtype) state with the Adam coefficients:
    /// "lr", "epsilon", "beta_1_t", "beta_1_power", "one_minus_beta_1_t",
    /// "beta_2_t", "beta_2_power" and "one_minus_beta_2_t".
    /// </summary>
    protected override void _prepare_local(DeviceDType device_dtype, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
    {
        base._prepare_local(device_dtype, apply_state);

        var var_dtype = device_dtype.DType;
        var state = apply_state[device_dtype];

        var local_step = math_ops.cast(iterations + 1, var_dtype);
        var beta_1_t = array_ops.identity(_get_hyper("beta_1", var_dtype));
        var beta_2_t = array_ops.identity(_get_hyper("beta_2", var_dtype));
        var beta_1_power = math_ops.pow(beta_1_t, local_step);
        var beta_2_power = math_ops.pow(beta_2_t, local_step);

        // "lr" is "lr_t" scaled by sqrt(1 - beta_2^step) / (1 - beta_1^step).
        var lr = state["lr_t"] * (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power));

        // update state
        state["lr"] = lr;
        state["epsilon"] = ops.convert_to_tensor(epsilon);
        state["beta_1_t"] = beta_1_t;
        state["beta_1_power"] = beta_1_power;
        state["one_minus_beta_1_t"] = 1 - beta_1_t;
        state["beta_2_t"] = beta_2_t;
        state["beta_2_power"] = beta_2_power;
        state["one_minus_beta_2_t"] = 1 - beta_2_t;
    }

    /// <summary>
    /// Applies one dense Adam update to <paramref name="var"/> via
    /// <c>gen_training_ops.resource_apply_adam</c>.
    /// </summary>
    /// <exception cref="NotImplementedException">Thrown when amsgrad is enabled.</exception>
    protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
    {
        if (amsgrad)
            throw new NotImplementedException("Adam: applying gradients with amsgrad=true is not implemented.");

        var (var_device, var_dtype) = (var.Device, var.dtype.as_base_dtype());

        // Use the prepared coefficients for this variable's (device, dtype) pair,
        // or fall back when none were prepared.
        var coefficients = apply_state.FirstOrDefault(x => x.Key.Device == var_device && x.Key.DType == var_dtype).Value
            ?? _fallback_apply_state(var_device, var_dtype);

        var m = get_slot(var, "m");
        var v = get_slot(var, "v");

        // The unscaled "lr_t" is passed here; the beta powers are supplied to the op separately.
        return gen_training_ops.resource_apply_adam(var.Handle,
            m.Handle,
            v.Handle,
            coefficients["beta_1_power"],
            coefficients["beta_2_power"],
            coefficients["lr_t"],
            coefficients["beta_1_t"],
            coefficients["beta_2_t"],
            coefficients["epsilon"],
            grad,
            use_locking: _use_locking);
    }
}
| } | |||||
| @@ -18,22 +18,25 @@ namespace Tensorflow.Keras.Optimizers | |||||
| protected bool _hypers_created; | protected bool _hypers_created; | ||||
| protected virtual string _name { get; } | protected virtual string _name { get; } | ||||
| ResourceVariable _iterations; | |||||
| List<ResourceVariable> _weight; | |||||
| IVariableV1 _iterations; | |||||
| protected ResourceVariable iterations => _iterations as ResourceVariable; | |||||
| List<IVariableV1> _weights; | |||||
| Dictionary<string, float> _hyper; | Dictionary<string, float> _hyper; | ||||
| Dictionary<string, ResourceVariable> _hyper_variables; | |||||
| Dictionary<string, IVariableV1> _hyper_variables; | |||||
| protected bool _momentum; | protected bool _momentum; | ||||
| protected float _initial_decay = 0.0f; | protected float _initial_decay = 0.0f; | ||||
| protected bool _use_locking = true; | protected bool _use_locking = true; | ||||
| Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state; | |||||
| Dictionary<string, Dictionary<string, IVariableV1>> _slots; | |||||
| List<string> _slot_names; | |||||
| public OptimizerV2() : base() | public OptimizerV2() : base() | ||||
| { | { | ||||
| _weight = new List<ResourceVariable>(); | |||||
| _weights = new List<IVariableV1>(); | |||||
| _hyper = new Dictionary<string, float>(); | _hyper = new Dictionary<string, float>(); | ||||
| _hyper_variables = new Dictionary<string, ResourceVariable>(); | |||||
| apply_state = new Dictionary<DeviceDType, Dictionary<string, Tensor>>(); | |||||
| _hyper_variables = new Dictionary<string, IVariableV1>(); | |||||
| _slots = new Dictionary<string, Dictionary<string, IVariableV1>>(); | |||||
| _slot_names = new List<string>(); | |||||
| } | } | ||||
| public void apply_gradients((Tensor, ResourceVariable) grads_and_vars, | public void apply_gradients((Tensor, ResourceVariable) grads_and_vars, | ||||
| @@ -61,7 +64,7 @@ namespace Tensorflow.Keras.Optimizers | |||||
| if (grads_and_vars == null || grads_and_vars.Count() == 0) | if (grads_and_vars == null || grads_and_vars.Count() == 0) | ||||
| return control_flow_ops.no_op(); | return control_flow_ops.no_op(); | ||||
| apply_state = _prepare(var_list); | |||||
| var apply_state = _prepare(var_list); | |||||
| if(experimental_aggregate_gradients) | if(experimental_aggregate_gradients) | ||||
| { | { | ||||
| // var reduced_grads = _aggregate_gradients(grads_and_vars); | // var reduced_grads = _aggregate_gradients(grads_and_vars); | ||||
| @@ -72,13 +75,13 @@ namespace Tensorflow.Keras.Optimizers | |||||
| }); | }); | ||||
| } | } | ||||
| void apply_grad_to_update_var(ResourceVariable var, EagerTensor grad) | |||||
| void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state) | |||||
| { | { | ||||
| _resource_apply_dense(var, grad, apply_state); | _resource_apply_dense(var, grad, apply_state); | ||||
| } | } | ||||
| protected virtual Operation _resource_apply_dense(IVariableV1 var, | protected virtual Operation _resource_apply_dense(IVariableV1 var, | ||||
| EagerTensor grad, | |||||
| Tensor grad, | |||||
| Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | ||||
| { | { | ||||
| throw new NotImplementedException("_resource_apply_dense"); | throw new NotImplementedException("_resource_apply_dense"); | ||||
| @@ -94,7 +97,7 @@ namespace Tensorflow.Keras.Optimizers | |||||
| { | { | ||||
| tf_with(ops.name_scope("update"), delegate | tf_with(ops.name_scope("update"), delegate | ||||
| { | { | ||||
| apply_grad_to_update_var(var, grad as EagerTensor); | |||||
| apply_grad_to_update_var(var, grad, _apply_state); | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -107,6 +110,12 @@ namespace Tensorflow.Keras.Optimizers | |||||
| return grads_and_vars.Select(x => x.Item1).ToArray(); | return grads_and_vars.Select(x => x.Item1).ToArray(); | ||||
| } | } | ||||
/// <summary>
/// Returns the slot variable registered for <paramref name="var"/> under
/// <paramref name="slot_name"/>. Slots are keyed by the variable's UniqueId;
/// a missing variable or slot name surfaces as the dictionary's lookup failure.
/// </summary>
protected IVariableV1 get_slot(IVariableV1 var, string slot_name)
{
    return _slots[var.UniqueId][slot_name];
}
| Dictionary<DeviceDType, Dictionary<string, Tensor>> _prepare(IVariableV1[] var_list) | Dictionary<DeviceDType, Dictionary<string, Tensor>> _prepare(IVariableV1[] var_list) | ||||
| { | { | ||||
| var _apply_state = new Dictionary<DeviceDType, Dictionary<string, Tensor>>(); | var _apply_state = new Dictionary<DeviceDType, Dictionary<string, Tensor>>(); | ||||
| @@ -125,6 +134,11 @@ namespace Tensorflow.Keras.Optimizers | |||||
| return _apply_state; | return _apply_state; | ||||
| } | } | ||||
/// <summary>
/// Placeholder for building coefficient state when none was prepared for the
/// given device/dtype pair. Not implemented yet: always throws.
/// </summary>
/// <exception cref="NotImplementedException">Always.</exception>
protected Dictionary<string, Tensor> _fallback_apply_state(string var_device, TF_DataType var_dtype)
{
    // Name the missing feature and the key so the failure is diagnosable.
    throw new NotImplementedException(
        $"_fallback_apply_state is not implemented (device: '{var_device}', dtype: {var_dtype}).");
}
| protected virtual void _prepare_local(DeviceDType device_dtype, | protected virtual void _prepare_local(DeviceDType device_dtype, | ||||
| Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | ||||
| { | { | ||||
| @@ -145,7 +159,7 @@ namespace Tensorflow.Keras.Optimizers | |||||
| return lr_t; | return lr_t; | ||||
| } | } | ||||
| protected ResourceVariable _get_hyper(string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| protected Tensor _get_hyper(string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| { | { | ||||
| var value = _hyper_variables[name]; | var value = _hyper_variables[name]; | ||||
| return math_ops.cast(value, dtype); | return math_ops.cast(value, dtype); | ||||
| @@ -160,7 +174,7 @@ namespace Tensorflow.Keras.Optimizers | |||||
| dtype: TF_DataType.TF_INT64, | dtype: TF_DataType.TF_INT64, | ||||
| trainable: false, | trainable: false, | ||||
| aggregation: VariableAggregation.OnlyFirstReplica); | aggregation: VariableAggregation.OnlyFirstReplica); | ||||
| _weight.Add(_iterations); | |||||
| _weights.Add(_iterations); | |||||
| } | } | ||||
| _create_hypers(); | _create_hypers(); | ||||
| @@ -190,7 +204,7 @@ namespace Tensorflow.Keras.Optimizers | |||||
| _hypers_created = true; | _hypers_created = true; | ||||
| } | } | ||||
| void _create_slots(IVariableV1[] var_list) | |||||
| protected virtual void _create_slots(IVariableV1[] var_list) | |||||
| { | { | ||||
| if(_momentum) | if(_momentum) | ||||
| { | { | ||||
| @@ -199,6 +213,35 @@ namespace Tensorflow.Keras.Optimizers | |||||
| } | } | ||||
| } | } | ||||
/// <summary>
/// Returns the slot variable named <paramref name="slot_name"/> for
/// <paramref name="var"/>, creating it on first use.
/// A new slot is a non-trainable variable with the dtype and shape of
/// <paramref name="var"/>, named "{var.Name}/{slot_name}", and is also
/// appended to the optimizer's weight list.
/// </summary>
/// <param name="var">Variable the slot belongs to (keyed by its UniqueId).</param>
/// <param name="slot_name">Slot name; recorded once in the slot-name list.</param>
/// <param name="initializer">Initializer for a new slot; defaults to <c>tf.zeros_initializer</c>.</param>
protected IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null)
{
    initializer = initializer ?? tf.zeros_initializer;

    if (!_slot_names.Contains(slot_name))
        _slot_names.append(slot_name);

    // Fetch (or lazily create) the per-variable slot table.
    if (!_slots.TryGetValue(var.UniqueId, out var slots_for_var))
    {
        slots_for_var = new Dictionary<string, IVariableV1>();
        _slots[var.UniqueId] = slots_for_var;
    }

    if (slots_for_var.TryGetValue(slot_name, out var existing))
        return existing;

    var created = tf.Variable(initializer,
        dtype: var.dtype,
        trainable: false,
        shape: var.shape,
        name: $"{var.Name}/{slot_name}");

    slots_for_var[slot_name] = created;
    _weights.append(created);
    return created;
}
| ResourceVariable add_weight(string name, | ResourceVariable add_weight(string name, | ||||
| TensorShape shape, | TensorShape shape, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
| @@ -40,7 +40,7 @@ namespace Tensorflow.Keras.Optimizers | |||||
| _get_hyper("momentum", device_dtype.DType)); | _get_hyper("momentum", device_dtype.DType)); | ||||
| } | } | ||||
| protected override Operation _resource_apply_dense(IVariableV1 var, EagerTensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | |||||
| protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | |||||
| { | { | ||||
| if (_momentum) | if (_momentum) | ||||
| { | { | ||||
| @@ -36,11 +36,7 @@ namespace Tensorflow.Keras.Utils | |||||
| ops.init_scope(); | ops.init_scope(); | ||||
| Func<Tensor> init_val = () => args.Initializer.Apply(new InitializerArgs | |||||
| { | |||||
| Shape = args.Shape, | |||||
| DType = args.DType | |||||
| }); | |||||
| Func<Tensor> init_val = () => args.Initializer.Apply(new InitializerArgs(args.Shape, dtype: args.DType)); | |||||
| var variable_dtype = args.DType.as_base_dtype(); | var variable_dtype = args.DType.as_base_dtype(); | ||||
| var v = tf.Variable(init_val, | var v = tf.Variable(init_val, | ||||
| @@ -6,8 +6,20 @@ namespace Tensorflow | |||||
| { | { | ||||
/// <summary>
/// Argument bundle passed to an initializer's <c>Apply</c> method.
/// </summary>
public class InitializerArgs
{
    /// <summary>Optional name; null when not supplied.</summary>
    public string Name { get; set; }

    /// <summary>Requested shape; may be null.</summary>
    public TensorShape Shape { get; set; }

    /// <summary>Requested element type; <c>DtInvalid</c> when not supplied.</summary>
    public TF_DataType DType { get; set; }

    /// <summary>Optional verify-shape flag; null when not supplied.</summary>
    public bool? VerifyShape { get; set; } = null;

    /// <summary>
    /// Creates the argument bundle. Every parameter is optional so that both
    /// <c>new InitializerArgs(shape, dtype: ...)</c> and the object-initializer
    /// form <c>new InitializerArgs { Shape = ..., DType = ... }</c> compile.
    /// </summary>
    public InitializerArgs(TensorShape shape = null,
        TF_DataType dtype = TF_DataType.DtInvalid,
        bool? verify_shape = null,
        string name = null)
    {
        Shape = shape;
        DType = dtype;
        VerifyShape = verify_shape;
        Name = name;
    }
}
| } | } | ||||
| @@ -18,17 +18,21 @@ namespace Tensorflow.Operations.Initializers | |||||
| { | { | ||||
| public class Zeros : IInitializer | public class Zeros : IInitializer | ||||
| { | { | ||||
| private TF_DataType dtype; | |||||
| TensorShape shape; | |||||
| TF_DataType dtype; | |||||
| public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| public Zeros(TensorShape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| { | { | ||||
| this.shape = shape; | |||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| } | } | ||||
/// <summary>
/// Produces a zero-filled tensor. Values missing from <paramref name="args"/>
/// are filled in from this initializer: DType when it is <c>DtInvalid</c>,
/// Shape when it is null.
/// </summary>
/// <param name="args">Requested shape and dtype; updated in place with the defaults applied.</param>
/// <returns>Result of <c>array_ops.zeros</c> for the resolved shape and dtype.</returns>
public Tensor Apply(InitializerArgs args)
{
    if (args.DType == TF_DataType.DtInvalid)
        args.DType = dtype;
    if (args.Shape == null)
        args.Shape = shape;

    // Use the resolved dtype: passing the field here ignored a dtype
    // explicitly requested by the caller.
    return array_ops.zeros(args.Shape, args.DType);
}
| @@ -71,7 +71,7 @@ namespace Tensorflow.Operations | |||||
| public bool UseCudnnOnGpu { get; set; } = true; | public bool UseCudnnOnGpu { get; set; } = true; | ||||
| public int[] Dilations { get; set; } = new [] { 1, 1, 1, 1 }; | |||||
| public int[] Dilations { get; set; } = new int[] { 1, 1, 1, 1 }; | |||||
| public Conv2dParams() | public Conv2dParams() | ||||
| { | { | ||||
| @@ -42,6 +42,22 @@ namespace Tensorflow.Operations | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor conv2d(Conv2dParams parameters) | public static Tensor conv2d(Conv2dParams parameters) | ||||
| { | { | ||||
| if (tf.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "Conv2D", parameters.Name, | |||||
| null, | |||||
| parameters.Input, parameters.Filter, | |||||
| "strides", parameters.Strides, | |||||
| "use_cudnn_on_gpu", parameters.UseCudnnOnGpu, | |||||
| "padding", parameters.Padding, | |||||
| "explicit_paddings", parameters.ExplicitPaddings, | |||||
| "data_format", parameters.DataFormat, | |||||
| "dilations", parameters.Dilations); | |||||
| return results[0]; | |||||
| } | |||||
| var _op = tf.OpDefLib._apply_op_helper("Conv2D", name: parameters.Name, args: new | var _op = tf.OpDefLib._apply_op_helper("Conv2D", name: parameters.Name, args: new | ||||
| { | { | ||||
| input = parameters.Input, | input = parameters.Input, | ||||
| @@ -64,6 +80,22 @@ namespace Tensorflow.Operations | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor conv2d_backprop_filter(Conv2dParams parameters) | public static Tensor conv2d_backprop_filter(Conv2dParams parameters) | ||||
| { | { | ||||
| if (tf.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "Conv2DBackpropFilter", parameters.Name, | |||||
| null, | |||||
| parameters.Input, parameters.FilterSizes, parameters.OutBackProp, | |||||
| "strides", parameters.Strides, | |||||
| "use_cudnn_on_gpu", parameters.UseCudnnOnGpu, | |||||
| "padding", parameters.Padding, | |||||
| "explicit_paddings", parameters.ExplicitPaddings, | |||||
| "data_format", parameters.DataFormat, | |||||
| "dilations", parameters.Dilations); | |||||
| return results[0]; | |||||
| } | |||||
| var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: parameters.Name, args: new | var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: parameters.Name, args: new | ||||
| { | { | ||||
| input = parameters.Input, | input = parameters.Input, | ||||
| @@ -87,6 +119,22 @@ namespace Tensorflow.Operations | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor conv2d_backprop_input(Conv2dParams parameters) | public static Tensor conv2d_backprop_input(Conv2dParams parameters) | ||||
| { | { | ||||
| if (tf.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "Conv2DBackpropInput", parameters.Name, | |||||
| null, | |||||
| parameters.InputSizes, parameters.Filter, parameters.OutBackProp, | |||||
| "strides", parameters.Strides, | |||||
| "use_cudnn_on_gpu", parameters.UseCudnnOnGpu, | |||||
| "padding", parameters.Padding, | |||||
| "explicit_paddings", parameters.ExplicitPaddings, | |||||
| "data_format", parameters.DataFormat, | |||||
| "dilations", parameters.Dilations); | |||||
| return results[0]; | |||||
| } | |||||
| var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: parameters.Name, args: new | var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: parameters.Name, args: new | ||||
| { | { | ||||
| input_sizes = parameters.InputSizes, | input_sizes = parameters.InputSizes, | ||||
| @@ -341,6 +389,20 @@ namespace Tensorflow.Operations | |||||
| string data_format = "NHWC", | string data_format = "NHWC", | ||||
| string name = null) | string name = null) | ||||
| { | { | ||||
| if (tf.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "MaxPool", name, | |||||
| null, | |||||
| input, | |||||
| "ksize", ksize, | |||||
| "strides", strides, | |||||
| "padding", padding, | |||||
| "data_format", data_format); | |||||
| return results[0]; | |||||
| } | |||||
| var _op = tf.OpDefLib._apply_op_helper("MaxPool", name: name, args: new | var _op = tf.OpDefLib._apply_op_helper("MaxPool", name: name, args: new | ||||
| { | { | ||||
| input, | input, | ||||
| @@ -356,6 +418,20 @@ namespace Tensorflow.Operations | |||||
| public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, | public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, | ||||
| string data_format= "NHWC", string name= null) | string data_format= "NHWC", string name= null) | ||||
| { | { | ||||
| if (tf.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "MaxPoolGrad", name, | |||||
| null, | |||||
| orig_input, orig_output, grad, | |||||
| "ksize", ksize, | |||||
| "strides", strides, | |||||
| "padding", padding, | |||||
| "data_format", data_format); | |||||
| return results[0]; | |||||
| } | |||||
| var _op = tf.OpDefLib._apply_op_helper("MaxPoolGrad", name: name, args: new | var _op = tf.OpDefLib._apply_op_helper("MaxPoolGrad", name: name, args: new | ||||
| { | { | ||||
| orig_input, | orig_input, | ||||
| @@ -384,7 +460,7 @@ namespace Tensorflow.Operations | |||||
| public static Tensor relu_grad(Tensor gradients, Tensor features, string name = null) | public static Tensor relu_grad(Tensor gradients, Tensor features, string name = null) | ||||
| { | { | ||||
| if (tf.Context.executing_eagerly()) | |||||
| if (tf.executing_eagerly()) | |||||
| { | { | ||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | ||||
| "ReluGrad", name, | "ReluGrad", name, | ||||
| @@ -227,7 +227,7 @@ namespace Tensorflow | |||||
| return grouped_inputs.ToArray(); | return grouped_inputs.ToArray(); | ||||
| } | } | ||||
| public T get_attr<T>(string name) | |||||
| public virtual T get_attr<T>(string name) | |||||
| => (T)get_attr(name); | => (T)get_attr(name); | ||||
| public virtual object get_attr(string name) | public virtual object get_attr(string name) | ||||
| @@ -424,6 +424,17 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
/// <summary>
/// Runs the "ShapeN" op on <paramref name="input"/>.
/// In eager mode the op goes through the fast-path runner; otherwise it is
/// added via the op-definition library and its outputs are returned.
/// </summary>
public static Tensor[] shape_n(Tensor[] input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
{
    if (!tf.executing_eagerly())
    {
        var graph_op = tf.OpDefLib._apply_op_helper("ShapeN", name, new { input, out_type });
        return graph_op.outputs;
    }

    // Eager: every output of the op is returned.
    return tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
        "ShapeN", name,
        null,
        input,
        "out_type", out_type);
}
| @@ -450,7 +461,7 @@ namespace Tensorflow | |||||
| public static Tensor tile<T>(Tensor input, T multiples, string name = null) | public static Tensor tile<T>(Tensor input, T multiples, string name = null) | ||||
| { | { | ||||
| if (tf.Context.executing_eagerly()) | |||||
| if (tf.executing_eagerly()) | |||||
| { | { | ||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | ||||
| "Tile", name, | "Tile", name, | ||||
| @@ -320,7 +320,7 @@ namespace Tensorflow | |||||
| /// </remarks> | /// </remarks> | ||||
| public static Tensor sigmoid(Tensor x, string name = "Sigmoid") | public static Tensor sigmoid(Tensor x, string name = "Sigmoid") | ||||
| { | { | ||||
| if (tf.Context.executing_eagerly()) | |||||
| if (tf.executing_eagerly()) | |||||
| { | { | ||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | ||||
| "Sigmoid", name, | "Sigmoid", name, | ||||
| @@ -1074,23 +1074,6 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor pow<Tx, Ty>(Tx x, Ty y, string name = null) | |||||
| { | |||||
| if (tf.Context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "Pow", name, | |||||
| null, | |||||
| x, y); | |||||
| return results[0]; | |||||
| } | |||||
| var _op = tf.OpDefLib._apply_op_helper("Pow", name, args: new { x, y }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| public static Tensor _sum<Tx, Ty>(Tx input, Ty axis = default, bool keep_dims = false, string name = null) | public static Tensor _sum<Tx, Ty>(Tx input, Ty axis = default, bool keep_dims = false, string name = null) | ||||
| { | { | ||||
| if (tf.Context.executing_eagerly()) | if (tf.Context.executing_eagerly()) | ||||
| @@ -681,7 +681,19 @@ namespace Tensorflow | |||||
| var x_tensor = ops.convert_to_tensor(x, name: "x"); | var x_tensor = ops.convert_to_tensor(x, name: "x"); | ||||
| var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype()); | var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype()); | ||||
| return gen_math_ops.pow(x_tensor, y_tensor, name: name); | |||||
if (tf.executing_eagerly())
{
    var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
        "Pow", name,
        null,
        x_tensor, y_tensor);

    return results[0];
}

// An anonymous type takes its property names from the initializer expressions,
// so `new { x_tensor, y_tensor }` would expose "x_tensor"/"y_tensor". The Pow
// op's inputs are named "x" and "y", so name the properties explicitly.
var _op = tf.OpDefLib._apply_op_helper("Pow", name, args: new { x = x_tensor, y = y_tensor });

return _op.output;
| }); | }); | ||||
| public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range") | public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range") | ||||
| @@ -754,9 +766,6 @@ namespace Tensorflow | |||||
| if (transpose_b && adjoint_b) | if (transpose_b && adjoint_b) | ||||
| throw new ValueError("Only one of transpose_b and adjoint_b can be True."); | throw new ValueError("Only one of transpose_b and adjoint_b can be True."); | ||||
| a = ops.convert_to_tensor(a, name: "a"); | |||||
| b = ops.convert_to_tensor(b, name: "b"); | |||||
| result = gen_math_ops.mat_mul(a, b, transpose_a, transpose_b, name); | result = gen_math_ops.mat_mul(a, b, transpose_a, transpose_b, name); | ||||
| }); | }); | ||||
| @@ -30,7 +30,7 @@ namespace Tensorflow | |||||
| /// <param name="seed"></param> | /// <param name="seed"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor random_normal(int[] shape, | |||||
| public static Tensor random_normal(TensorShape shape, | |||||
| float mean = 0.0f, | float mean = 0.0f, | ||||
| float stddev = 1.0f, | float stddev = 1.0f, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
| @@ -23,6 +23,24 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class gen_training_ops | public class gen_training_ops | ||||
| { | { | ||||
| public static Operation resource_apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power, | |||||
| Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, | |||||
| bool use_locking = false, bool use_nesterov = false, string name = null) | |||||
| { | |||||
| if (tf.executing_eagerly()) | |||||
| { | |||||
| var result = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "ResourceApplyAdam", name, | |||||
| null, | |||||
| var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, | |||||
| "use_locking", use_locking, | |||||
| "use_nesterov", use_nesterov); | |||||
| return null; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| public static Tensor apply_adam(IVariableV1 var, IVariableV1 m, IVariableV1 v, Tensor beta1_power, Tensor beta2_power, | public static Tensor apply_adam(IVariableV1 var, IVariableV1 m, IVariableV1 v, Tensor beta1_power, Tensor beta2_power, | ||||
| Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, | Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, | ||||
| bool use_locking = false, bool use_nesterov = false, string name = null) | bool use_locking = false, bool use_nesterov = false, string name = null) | ||||
| @@ -56,12 +74,12 @@ namespace Tensorflow | |||||
| use_locking | use_locking | ||||
| }); | }); | ||||
| return _op.outputs[0]; | |||||
| return _op.output; | |||||
| } | } | ||||
| public static Operation resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) | public static Operation resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) | ||||
| { | { | ||||
| if (tf.Context.executing_eagerly()) | |||||
| if (tf.executing_eagerly()) | |||||
| { | { | ||||
| var result = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | var result = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | ||||
| "ResourceApplyGradientDescent", name, | "ResourceApplyGradientDescent", name, | ||||
| @@ -18,7 +18,7 @@ namespace Tensorflow | |||||
| protected string handle_name => _handle_name; | protected string handle_name => _handle_name; | ||||
| protected string _unique_id; | protected string _unique_id; | ||||
| public string unique_id => _unique_id; | |||||
| public string UniqueId => _unique_id; | |||||
| protected bool _in_graph_mode; | protected bool _in_graph_mode; | ||||
| @@ -31,6 +31,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public interface IVariableV1 | public interface IVariableV1 | ||||
| { | { | ||||
| public string UniqueId { get; } | |||||
| public string Name { get; } | public string Name { get; } | ||||
| public Tensor Handle { get; } | public Tensor Handle { get; } | ||||
| public string Device { get; } | public string Device { get; } | ||||
| @@ -25,6 +25,7 @@ namespace Tensorflow | |||||
| public partial class RefVariable : IVariableV1, IProtoBuf<VariableDef, RefVariable> | public partial class RefVariable : IVariableV1, IProtoBuf<VariableDef, RefVariable> | ||||
| { | { | ||||
| protected string _name; | protected string _name; | ||||
| public string UniqueId => _name; | |||||
| public Tensor GraphElement { get; } | public Tensor GraphElement { get; } | ||||
| public Tensor _variable; | public Tensor _variable; | ||||
| public Tensor Handle => _variable; | public Tensor Handle => _variable; | ||||
| @@ -67,8 +67,6 @@ namespace Tensorflow | |||||
| dtype: dtype, | dtype: dtype, | ||||
| shape: shape); | shape: shape); | ||||
| } | } | ||||
| // handle.ResourceVar = this; | |||||
| } | } | ||||
| private void _init_from_args(object initial_value = null, | private void _init_from_args(object initial_value = null, | ||||
| @@ -79,7 +77,8 @@ namespace Tensorflow | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| TensorShape shape = null) | TensorShape shape = null) | ||||
| { | { | ||||
| var init_from_fn = initial_value.GetType().Name == "Func`1"; | |||||
| var init_from_fn = initial_value.GetType().Name == "Func`1" || | |||||
| initial_value.GetType().GetInterface("IInitializer") != null; | |||||
| if(collections == null) | if(collections == null) | ||||
| collections = new List<string>() { tf.GraphKeys.GLOBAL_VARIABLES }; | collections = new List<string>() { tf.GraphKeys.GLOBAL_VARIABLES }; | ||||
| _trainable = trainable; | _trainable = trainable; | ||||
| @@ -112,9 +111,12 @@ namespace Tensorflow | |||||
| attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}")); | attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}")); | ||||
| tf_with(ops.name_scope("Initializer"), delegate | tf_with(ops.name_scope("Initializer"), delegate | ||||
| { | { | ||||
| initial_value = ops.convert_to_tensor(init_from_fn ? (initial_value as Func<Tensor>)() : initial_value, | |||||
| name: "initial_value", | |||||
| dtype: dtype); | |||||
| if (initial_value.GetType().GetInterface("IInitializer") != null) | |||||
| initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype))); | |||||
| else | |||||
| initial_value = ops.convert_to_tensor(init_from_fn ? (initial_value as Func<Tensor>)() : initial_value, | |||||
| name: "initial_value", | |||||
| dtype: dtype); | |||||
| }); | }); | ||||
| _shape = shape ?? (initial_value as Tensor).TensorShape; | _shape = shape ?? (initial_value as Tensor).TensorShape; | ||||
| _initial_value = initial_value as Tensor; | _initial_value = initial_value as Tensor; | ||||
| @@ -162,11 +162,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| Func<Tensor> init_val = () => initializer.Apply(new InitializerArgs | |||||
| { | |||||
| Shape = shape, | |||||
| DType = dtype | |||||
| }); | |||||
| Func<Tensor> init_val = () => initializer.Apply(new InitializerArgs(shape, dtype: dtype)); | |||||
| var variable_dtype = dtype.as_base_dtype(); | var variable_dtype = dtype.as_base_dtype(); | ||||
| v = variable_scope.default_variable_creator(init_val, | v = variable_scope.default_variable_creator(init_val, | ||||