2: override _prepare() for AdamOptimizer. 3: fix key name if _get_non_slot_variable.tags/v0.9
| @@ -46,7 +46,8 @@ namespace Tensorflow.Train | |||||
| var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)); | var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)); | ||||
| var m = get_slot(var, "m"); | var m = get_slot(var, "m"); | ||||
| var m_scaled_g_values = grad * (1 - beta1_t); | var m_scaled_g_values = grad * (1 - beta1_t); | ||||
| var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking); | |||||
| var mul = m * beta1_t; | |||||
| var m_t = state_ops.assign(m, mul, use_locking: _use_locking); | |||||
| with(ops.control_dependencies(new[] { m_t }), delegate | with(ops.control_dependencies(new[] { m_t }), delegate | ||||
| { | { | ||||
| m_t = scatter_add(m, indices, m_scaled_g_values); | m_t = scatter_add(m, indices, m_scaled_g_values); | ||||
| @@ -88,9 +89,15 @@ namespace Tensorflow.Train | |||||
| public override void _prepare() | public override void _prepare() | ||||
| { | { | ||||
| //copied from GradientDescentOptimizer | |||||
| LearningRate = _call_if_callable(LearningRate); | |||||
| LearningRateTensor = ops.convert_to_tensor(LearningRate, name: "learning_rate"); | |||||
| var lr = _call_if_callable(_lr); | |||||
| var beta1 = _call_if_callable(_beta1); | |||||
| var beta2 = _call_if_callable(_beta2); | |||||
| var epsilon = _call_if_callable(_epsilon); | |||||
| _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); | |||||
| _beta1_t = ops.convert_to_tensor(beta1, name: "beta1"); | |||||
| _beta2_t = ops.convert_to_tensor(beta2, name: "beta2"); | |||||
| _epsilon_t = ops.convert_to_tensor(epsilon, name: "epsilon"); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -26,14 +26,13 @@ namespace Tensorflow.Train | |||||
| public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent") | public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent") | ||||
| : base(learning_rate, use_locking, name) | : base(learning_rate, use_locking, name) | ||||
| { | { | ||||
| LearningRate = learning_rate; | |||||
| LearningRateTensor = null; | |||||
| _lr = learning_rate; | |||||
| } | } | ||||
| public override void _prepare() | public override void _prepare() | ||||
| { | { | ||||
| LearningRate = _call_if_callable(LearningRate); | |||||
| LearningRateTensor = ops.convert_to_tensor(LearningRate, name: "learning_rate"); | |||||
| var lr = _call_if_callable(_lr); | |||||
| _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -23,8 +23,10 @@ namespace Tensorflow | |||||
| string _name; | string _name; | ||||
| public string Name => _name; | public string Name => _name; | ||||
| public float LearningRate { get; set; } | |||||
| public Tensor LearningRateTensor { get; set; } | |||||
| protected float _lr; | |||||
| public float LearningRate => _lr; | |||||
| protected Tensor _lr_t; | |||||
| public Tensor LearningRateTensor => _lr_t; | |||||
| public bool _use_locking; | public bool _use_locking; | ||||
| public Dictionary<string, Dictionary<string, RefVariable>> _slots; | public Dictionary<string, Dictionary<string, RefVariable>> _slots; | ||||
| public Dictionary<string, RefVariable> _non_slot_dict; | public Dictionary<string, RefVariable> _non_slot_dict; | ||||
| @@ -38,7 +40,7 @@ namespace Tensorflow | |||||
| _name = name; | _name = name; | ||||
| _use_locking = use_locking; | _use_locking = use_locking; | ||||
| LearningRate = learning_rate; | |||||
| _lr = learning_rate; | |||||
| // Dictionary of slots. | // Dictionary of slots. | ||||
| _slots = new Dictionary<string, Dictionary<string, RefVariable>>(); | _slots = new Dictionary<string, Dictionary<string, RefVariable>>(); | ||||
| _non_slot_dict = new Dictionary<string, RefVariable>(); | _non_slot_dict = new Dictionary<string, RefVariable>(); | ||||
| @@ -302,7 +304,7 @@ namespace Tensorflow | |||||
| protected RefVariable _get_non_slot_variable(string name, Graph graph = null) | protected RefVariable _get_non_slot_variable(string name, Graph graph = null) | ||||
| { | { | ||||
| var key = $"{graph.graph_key}.{name}"; | |||||
| var key = $"{name}.{graph.graph_key}"; | |||||
| var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; | var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; | ||||
| return non_slot; | return non_slot; | ||||
| @@ -36,8 +36,8 @@ namespace Tensorflow | |||||
| validate_shape: validate_shape, | validate_shape: validate_shape, | ||||
| use_locking: use_locking, | use_locking: use_locking, | ||||
| name: name); | name: name); | ||||
| else | |||||
| throw new NotImplementedException("state_ops.assign"); | |||||
| throw new NotImplementedException("state_ops.assign"); | |||||
| //return @ref.assign(value, name: name); | |||||
| } | } | ||||
| public static Tensor assign_sub(RefVariable @ref, | public static Tensor assign_sub(RefVariable @ref, | ||||