| @@ -1,5 +1,5 @@ | |||||
| /***************************************************************************** | /***************************************************************************** | ||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | you may not use this file except in compliance with the License. | ||||
| @@ -20,9 +20,15 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class tensorflow | public partial class tensorflow | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Record operations for automatic differentiation. | |||||
| /// </summary> | |||||
| /// <param name="persistent"></param> | |||||
| /// <param name="watch_accessed_variables"></param> | |||||
| /// <returns></returns> | |||||
| public GradientTape GradientTape(bool persistent = false, | public GradientTape GradientTape(bool persistent = false, | ||||
| bool watch_accessed_variables = true) | |||||
| => new GradientTape(persistent: persistent, | |||||
| bool watch_accessed_variables = true) | |||||
| => new GradientTape(persistent: persistent, | |||||
| watch_accessed_variables: watch_accessed_variables); | watch_accessed_variables: watch_accessed_variables); | ||||
| public Tensor[] gradients(Tensor[] ys, | public Tensor[] gradients(Tensor[] ys, | ||||
| @@ -389,42 +389,6 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx); | public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx); | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="device_name"></param> | |||||
| /// <param name="op_name"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <param name="args"></param> | |||||
| /// <param name="input_size"></param> | |||||
| /// <param name="set_op_attrs"></param> | |||||
| /// <param name="status"></param> | |||||
| /// <returns>EagerTensorHandle</returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern SafeStatusHandle TFE_FastPathExecute(SafeContextHandle ctx, | |||||
| string device_name, | |||||
| string op_name, | |||||
| string name, | |||||
| IntPtr[] inputs, | |||||
| int input_size, | |||||
| string attrs_string, | |||||
| TFE_FastPathExecute_SetOpAttrs set_op_attrs, | |||||
| IntPtr[] outputs, | |||||
| int output_size); | |||||
| [UnmanagedFunctionPointer(CallingConvention.StdCall)] | |||||
| public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern SafeStatusHandle TFE_QuickExecute(SafeContextHandle ctx, | |||||
| string device_name, | |||||
| string op_name, | |||||
| IntPtr[] inputs, | |||||
| int input_size, | |||||
| TFE_FastPathExecute_SetOpAttrs set_op_attrs, | |||||
| IntPtr[] outputs, | |||||
| int output_size); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); | public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); | ||||
| @@ -142,6 +142,14 @@ namespace Tensorflow.Gradients | |||||
| return results; | return results; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Temporarily stops recording operations on this tape. | |||||
| /// </summary> | |||||
| public void stop_recording() | |||||
| { | |||||
| _pop_tape(); | |||||
| } | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| if (_recording) | if (_recording) | ||||
| @@ -328,7 +328,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| dtype = dtype.as_base_dtype(); | dtype = dtype.as_base_dtype(); | ||||
| name = scope; | name = scope; | ||||
| var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||||
| Tensor ones = null; | Tensor ones = null; | ||||
| switch (dtype) | switch (dtype) | ||||
| { | { | ||||
| @@ -342,6 +342,11 @@ namespace Tensorflow | |||||
| ones = constant(1); | ones = constant(1); | ||||
| break; | break; | ||||
| } | } | ||||
| if (shape.ndim == 0) | |||||
| return ones; | |||||
| var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||||
| return fill(shape_tensor, ones, name: name); | return fill(shape_tensor, ones, name: name); | ||||
| }); | }); | ||||
| @@ -46,15 +46,15 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
| tape.watch(x); | tape.watch(x); | ||||
| var y = tf.reduce_sum(x); | var y = tf.reduce_sum(x); | ||||
| var z = tf.multiply(y, y); | var z = tf.multiply(y, y); | ||||
| tape.Dispose(); | |||||
| var dz_dx = tape.gradient(z, x); | var dz_dx = tape.gradient(z, x); | ||||
| var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; | var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; | ||||
| Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected)); | Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected)); | ||||
| var dz_dy = tape.gradient(z, y); | var dz_dy = tape.gradient(z, y); | ||||
| expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected)); | |||||
| Assert.AreEqual((float)dz_dy, 8.0f); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||