|
|
@@ -1,26 +1,26 @@ |
|
|
/*****************************************************************************
|
|
|
|
|
|
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
|
|
|
|
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
|
limitations under the License.
|
|
|
|
|
|
|
|
|
/***************************************************************************** |
|
|
|
|
|
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. |
|
|
|
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
|
|
you may not use this file except in compliance with the License. |
|
|
|
|
|
You may obtain a copy of the License at |
|
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software |
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
|
|
See the License for the specific language governing permissions and |
|
|
|
|
|
limitations under the License. |
|
|
******************************************************************************/ |
|
|
******************************************************************************/ |
|
|
|
|
|
|
|
|
namespace Tensorflow.Train |
|
|
namespace Tensorflow.Train |
|
|
{ |
|
|
{ |
|
|
/// <summary>
|
|
|
|
|
|
/// Optimizer that implements the gradient descent algorithm.
|
|
|
|
|
|
|
|
|
/// <summary> |
|
|
|
|
|
/// Optimizer that implements the gradient descent algorithm. |
|
|
/// </summary> |
|
|
/// </summary> |
|
|
public class GradientDescentOptimizer : Optimizer |
|
|
public class GradientDescentOptimizer : Optimizer |
|
|
{
|
|
|
|
|
|
|
|
|
{ |
|
|
/// <summary> |
|
|
/// <summary> |
|
|
/// Construct a new gradient descent optimizer. |
|
|
/// Construct a new gradient descent optimizer. |
|
|
/// </summary> |
|
|
/// </summary> |
|
|
@@ -41,9 +41,9 @@ namespace Tensorflow.Train |
|
|
{ |
|
|
{ |
|
|
_lr = learning_rate; |
|
|
_lr = learning_rate; |
|
|
_useTensor = false; |
|
|
_useTensor = false; |
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent")
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent") |
|
|
: base(learning_rate, use_locking, name) |
|
|
: base(learning_rate, use_locking, name) |
|
|
{ |
|
|
{ |
|
|
_lr_t = learning_rate; |
|
|
_lr_t = learning_rate; |
|
|
@@ -52,10 +52,10 @@ namespace Tensorflow.Train |
|
|
|
|
|
|
|
|
public override void _prepare() |
|
|
public override void _prepare() |
|
|
{ |
|
|
{ |
|
|
if(!_useTensor)
|
|
|
|
|
|
{
|
|
|
|
|
|
var lr = _call_if_callable(_lr);
|
|
|
|
|
|
_lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
|
|
|
|
|
|
|
|
|
if(!_useTensor) |
|
|
|
|
|
{ |
|
|
|
|
|
var lr = _call_if_callable(_lr); |
|
|
|
|
|
_lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
} |
|
|
} |