diff --git a/src/TensorFlowNET.Core/APIs/tf.control.cs b/src/TensorFlowNET.Core/APIs/tf.control.cs index dcccb6fe..f0805c25 100644 --- a/src/TensorFlowNET.Core/APIs/tf.control.cs +++ b/src/TensorFlowNET.Core/APIs/tf.control.cs @@ -20,6 +20,13 @@ namespace Tensorflow { public partial class tensorflow { + public Tensor cond(Tensor pred, + Func true_fn = null, + Func false_fn = null, + bool strict = false, + string name = null) + => control_flow_ops.cond(pred, true_fn, false_fn, strict: strict, name: name); + public Tensor while_loop(Func cond, Func body, Tensor[] loop_vars, TensorShape shape_invariants = null, int parallel_iterations = 10, diff --git a/src/TensorFlowNET.Core/APIs/tf.state.cs b/src/TensorFlowNET.Core/APIs/tf.state.cs new file mode 100644 index 00000000..c57d03c6 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.state.cs @@ -0,0 +1,25 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + public Tensor assign_add(RefVariable @ref, T value, + bool use_locking = false, string name = null) + => state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); + } +} diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs index b4c77226..4ee9db76 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs @@ -29,6 +29,10 @@ namespace Tensorflow public static Tensor operator -(RefVariable x, double y) => op_helper("sub", x, y); public static Tensor operator -(RefVariable x, Tensor y) => op_helper("sub", x, y); + public static Tensor operator <(RefVariable x, Tensor y) => op_helper("Less", x, y); + + public static Tensor operator >(RefVariable x, Tensor y) => op_helper("Greater", x, y); + private static Tensor op_helper(string default_name, RefVariable x, T y) { var tensor1 = x.value(); diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index 5c8744b6..24cb11f5 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -126,7 +126,7 @@ namespace Tensorflow // name: A name for the operation(optional). // Returns: // A mutable `Tensor`. Has the same type as `ref`. - public static Tensor assign_add(RefVariable @ref, Tensor value, bool use_locking = false, string name = null) + public static Tensor assign_add(RefVariable @ref, T value, bool use_locking = false, string name = null) { var _op = _op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking }); return _op.outputs[0]; diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index 8f478f2d..cd8d4f3f 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -94,10 +94,15 @@ namespace Tensorflow // Returns: // Same as "ref". Returned as a convenience for operations that want // to use the new value after the variable has been updated. - public static Tensor assign_add(RefVariable @ref, - Tensor value, + public static Tensor assign_add(RefVariable @ref, + T value, bool use_locking = false, - string name = null) => gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); + string name = null) + { + if (@ref.dtype.is_ref_dtype()) + return gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); + throw new NotImplementedException("assign_add"); + } public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) {