| @@ -18,22 +18,27 @@ namespace Tensorflow | |||
| { | |||
| public partial class tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Outputs random values from a normal distribution. | |||
| /// </summary> | |||
| /// <param name="shape"></param> | |||
| /// <param name="mean"></param> | |||
| /// <param name="stddev"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="seed"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public Tensor random_normal(TensorShape shape, | |||
| float mean = 0.0f, | |||
| float stddev = 1.0f, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| int? seed = null, | |||
| string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | |||
| public Random random => new Random(); | |||
| public class Random | |||
| { | |||
| /// <summary> | |||
| /// Outputs random values from a normal distribution. | |||
| /// </summary> | |||
| /// <param name="shape"></param> | |||
| /// <param name="mean"></param> | |||
| /// <param name="stddev"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="seed"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public Tensor normal(TensorShape shape, | |||
| float mean = 0.0f, | |||
| float stddev = 1.0f, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| int? seed = null, | |||
| string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | |||
| } | |||
| public Tensor random_uniform(TensorShape shape, | |||
| float minval = 0, | |||
| @@ -45,7 +45,7 @@ namespace Tensorflow.Eager | |||
| op_name, | |||
| inputs.Select(x => (x as EagerTensor).GetTfeTensorHandle()).ToArray(), | |||
| inputs.Length, | |||
| op => wrap_tfe_src.SetOpAttrs(ctx, op, attrs, status), | |||
| op => wrap_tfe_src.SetOpAttrs(op, attrs), | |||
| outputs, | |||
| num_outputs, | |||
| status); | |||
| @@ -2,6 +2,7 @@ | |||
| using System.Linq; | |||
| using System; | |||
| using static Tensorflow.OpDef.Types; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| @@ -10,8 +11,9 @@ namespace Tensorflow.Eager | |||
| /// </summary> | |||
| public partial class wrap_tfe_src | |||
| { | |||
| public static void SetOpAttrs(Context ctx, TFE_Op op, object[] attrs, Status out_status) | |||
| public static void SetOpAttrs(TFE_Op op, params object[] attrs) | |||
| { | |||
| using var status = new Status(); | |||
| var len = attrs.Length; | |||
| for (int i = 0; i < len; i += 2) | |||
| { | |||
| @@ -19,13 +21,13 @@ namespace Tensorflow.Eager | |||
| var value = attrs[i + 1]; | |||
| byte is_list = 0; | |||
| var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, out_status); | |||
| if (!out_status.ok()) return; | |||
| var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status); | |||
| if (!status.ok()) return; | |||
| if (is_list != 0) | |||
| SetOpAttrList(ctx, op, key, value, type, null, out_status); | |||
| SetOpAttrList(tf.context, op, key, value, type, null, status); | |||
| else | |||
| SetOpAttrScalar(ctx, op, key, value, type, null, out_status); | |||
| out_status.Check(true); | |||
| SetOpAttrScalar(tf.context, op, key, value, type, null, status); | |||
| status.Check(true); | |||
| } | |||
| } | |||
| @@ -165,7 +165,7 @@ namespace Tensorflow | |||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "Pack", name, | |||
| values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), values.Length, | |||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "axis", axis } , status), | |||
| op => wrap_tfe_src.SetOpAttrs(op, "axis", axis), | |||
| status); | |||
| status.Check(true); | |||
| return new EagerTensor(tensor); | |||
| @@ -421,11 +421,8 @@ namespace Tensorflow | |||
| "Shape", name, new IntPtr[] | |||
| { | |||
| input as EagerTensor, | |||
| }, 1, | |||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] | |||
| { | |||
| "out_type", out_type | |||
| }, status), | |||
| }, 1, | |||
| op => wrap_tfe_src.SetOpAttrs(op, "out_type", out_type), | |||
| status); | |||
| status.Check(true); | |||
| return tensor; | |||
| @@ -531,14 +528,12 @@ namespace Tensorflow | |||
| end as EagerTensor, | |||
| strides as EagerTensor, | |||
| }, 4, | |||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] | |||
| { | |||
| op => wrap_tfe_src.SetOpAttrs(op, | |||
| "begin_mask", begin_mask, | |||
| "end_mask", end_mask, | |||
| "ellipsis_mask", ellipsis_mask, | |||
| "new_axis_mask", new_axis_mask, | |||
| "shrink_axis_mask", shrink_axis_mask | |||
| }, status), | |||
| "shrink_axis_mask", shrink_axis_mask), | |||
| status); | |||
| status.Check(true); | |||
| return tensor; | |||
| @@ -44,13 +44,13 @@ namespace Tensorflow | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| using var status = new Status(); | |||
| var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| EagerTensorHandle _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "AddN", name, | |||
| inputs.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), inputs.Length, | |||
| null, | |||
| status); | |||
| status.Check(true); | |||
| return new EagerTensor(_result); | |||
| return _result; | |||
| } | |||
| var _op = _op_def_lib._apply_op_helper("AddN", name, args: new { inputs }); | |||
| @@ -132,17 +132,17 @@ namespace Tensorflow | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| using var status = new Status(); | |||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "Mean", name, | |||
| new IntPtr[] | |||
| { | |||
| input as EagerTensor, | |||
| axis as EagerTensor | |||
| }, 2, | |||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "keep_dims", keep_dims }, status), | |||
| op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), | |||
| status); | |||
| status.Check(true); | |||
| return new EagerTensor(tensor); | |||
| return tensor; | |||
| } | |||
| var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims }); | |||
| @@ -185,10 +185,7 @@ namespace Tensorflow | |||
| input as EagerTensor, | |||
| axis as EagerTensor | |||
| }, 2, | |||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] | |||
| { | |||
| "keep_dims", keep_dims | |||
| }, status), | |||
| op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), | |||
| status); | |||
| status.Check(true); | |||
| return tensor; | |||
| @@ -232,14 +229,14 @@ namespace Tensorflow | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| using var status = new Status(); | |||
| var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| EagerTensorHandle _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "Add", name, new IntPtr[] | |||
| { | |||
| x as EagerTensor, | |||
| y as EagerTensor | |||
| }, 2, null, status); | |||
| status.Check(true); | |||
| return new EagerTensor(_result); | |||
| return _result; | |||
| } | |||
| var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | |||
| @@ -273,14 +270,14 @@ namespace Tensorflow | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| using var status = new Status(); | |||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "AddV2", name, new IntPtr[] | |||
| { | |||
| x as EagerTensor, | |||
| y as EagerTensor | |||
| }, 2, null, status); | |||
| status.Check(true); | |||
| return new EagerTensor(tensor); | |||
| return tensor; | |||
| } | |||
| var _op = _op_def_lib._apply_op_helper("AddV2", name, args: new { x, y }); | |||
| @@ -574,6 +571,18 @@ namespace Tensorflow | |||
| /// <returns> A `Tensor`. Has the same type as `x`.</returns> | |||
| public static Tensor square(Tensor x, string name = null) | |||
| { | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| using var status = new Status(); | |||
| EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "Square", name, new IntPtr[] | |||
| { | |||
| x as EagerTensor, | |||
| }, 1, null, status); | |||
| status.Check(true); | |||
| return tensor; | |||
| } | |||
| var _op = _op_def_lib._apply_op_helper("Square", name, args: new { x }); | |||
| return _op.outputs[0]; | |||
| @@ -633,7 +642,7 @@ namespace Tensorflow | |||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "Cast", name, | |||
| new IntPtr[] { x as EagerTensor }, 1, | |||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "DstT", DstT, "Truncate", Truncate }, status), | |||
| op => wrap_tfe_src.SetOpAttrs(op, "DstT", DstT, "Truncate", Truncate), | |||
| status); | |||
| status.Check(true); | |||
| return new EagerTensor(tensor); | |||
| @@ -918,11 +927,9 @@ namespace Tensorflow | |||
| a as EagerTensor, | |||
| b as EagerTensor | |||
| }, 2, | |||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] | |||
| { | |||
| op => wrap_tfe_src.SetOpAttrs(op, | |||
| "transpose_a", transpose_a, | |||
| "transpose_b", transpose_b | |||
| }, status), | |||
| "transpose_b", transpose_b), | |||
| status); | |||
| status.Check(true); | |||
| return new EagerTensor(tensor); | |||
| @@ -1049,7 +1056,7 @@ namespace Tensorflow | |||
| input as EagerTensor, | |||
| axis as EagerTensor | |||
| }, 2, | |||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "keep_dims", keep_dims }, status), | |||
| op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), | |||
| status); | |||
| status.Check(true); | |||
| return new EagerTensor(tensor); | |||
| @@ -13,6 +13,9 @@ | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using Tensorflow.Eager; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -36,6 +39,23 @@ namespace Tensorflow | |||
| if (!seed2.HasValue) | |||
| seed2 = 0; | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| using var status = new Status(); | |||
| EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "RandomStandardNormal", name, new IntPtr[] | |||
| { | |||
| shape as EagerTensor, | |||
| }, 1, | |||
| op => wrap_tfe_src.SetOpAttrs(op, | |||
| "seed", seed, | |||
| "seed2", seed2, | |||
| "dtype", dtype), | |||
| status); | |||
| status.Check(true); | |||
| return tensor; | |||
| } | |||
| var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", | |||
| name: name, | |||
| args: new { shape, dtype, seed, seed2 }); | |||
| @@ -25,6 +25,25 @@ namespace Tensorflow | |||
| { | |||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
| public static Operation assign_sub_variable_op(Tensor resource, Tensor value, string name = null) | |||
| { | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| using var status = new Status(); | |||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "AssignSubVariableOp", name, | |||
| new IntPtr[] | |||
| { | |||
| resource as EagerTensor, | |||
| value as EagerTensor | |||
| }, 2, null, status); | |||
| status.Check(true); | |||
| return tensor; | |||
| } | |||
| return null; | |||
| } | |||
| public static Operation assign_variable_op(Tensor resource, Tensor value, string name = null) | |||
| { | |||
| if (tf.context.executing_eagerly()) | |||
| @@ -51,12 +70,12 @@ namespace Tensorflow | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| using var status = new Status(); | |||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "VarIsInitializedOp", name, | |||
| new IntPtr[] { resource as EagerTensor }, | |||
| 1, null, status); | |||
| status.Check(true); | |||
| return new EagerTensor(tensor); | |||
| return tensor; | |||
| } | |||
| var _op = _op_def_lib._apply_op_helper("VarIsInitializedOp", name, new { resource }); | |||
| @@ -79,18 +98,16 @@ namespace Tensorflow | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| using var status = new Status(); | |||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "VarHandleOp", name, null, 0, | |||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] | |||
| { | |||
| op => wrap_tfe_src.SetOpAttrs(op, | |||
| "container", container, | |||
| "shared_name", shared_name, | |||
| "dtype", dtype, | |||
| "shape", shape.dims | |||
| }, status), | |||
| "shape", shape.dims), | |||
| status); | |||
| status.Check(true); | |||
| return new EagerTensor(tensor); | |||
| return tensor; | |||
| } | |||
| var _op = _op_def_lib._apply_op_helper("VarHandleOp", name, new { | |||
| @@ -118,7 +135,7 @@ namespace Tensorflow | |||
| EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "ReadVariableOp", name, | |||
| new IntPtr[] { resource as EagerTensor }, 1, | |||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "dtype", dtype }, status), | |||
| op => wrap_tfe_src.SetOpAttrs(op, "dtype", dtype), | |||
| status); | |||
| status.Check(true); | |||
| return tensor; | |||
| @@ -47,6 +47,7 @@ namespace Tensorflow | |||
| var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2); | |||
| var mul = rnd * stddev_tensor; | |||
| var value = math_ops.add(mul, mean_tensor, name: name); | |||
| // tensor_util.maybe_set_static_shape(value, shape) | |||
| return value; | |||
| }); | |||
| } | |||
| @@ -0,0 +1,37 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using NumSharp; | |||
| using System; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class ResourceVariable | |||
| { | |||
| /// <summary> | |||
| /// Subtracts a value from this variable. | |||
| /// </summary> | |||
| /// <param name="delta"></param> | |||
| /// <param name="use_locking"></param> | |||
| /// <param name="name"></param> | |||
| /// <param name="read_value"></param> | |||
| public void assign_sub(Tensor delta, bool use_locking = false, string name = null, bool read_value = true) | |||
| { | |||
| gen_resource_variable_ops.assign_sub_variable_op(handle, delta, name: name); | |||
| } | |||
| } | |||
| } | |||
| @@ -11,22 +11,52 @@ namespace TensorFlowNET.UnitTest.Training | |||
| [TestClass] | |||
| public class BasicLinearModel | |||
| { | |||
| int NUM_EXAMPLES = 1000; | |||
| /// <summary> | |||
| /// Linear Regression without tf.train.Optimizer | |||
| /// https://www.tensorflow.org/tutorials/customization/custom_training | |||
| /// </summary> | |||
| [TestMethod] | |||
| public void FitLinear() | |||
| public void LinearRegression() | |||
| { | |||
| // Initialize the weights to `5.0` and the bias to `0.0` | |||
| // In practice, these should be initialized to random values (for example, with `tf.random.normal`) | |||
| var W = tf.Variable(5.0f); | |||
| var b = tf.Variable(0.0); | |||
| var b = tf.Variable(0.0f); | |||
| // Define linear model | |||
| Func<Tensor, Tensor> model = (x) => W * x + b; | |||
| // Define the loss function | |||
| Func<Tensor, Tensor, Tensor> loss = (target_y, predicted_y) | |||
| => tf.reduce_mean(tf.square(target_y - predicted_y)); | |||
| int NUM_EXAMPLES = 1000; | |||
| float TRUE_W = 3.0f; | |||
| float TRUE_b = 2.0f; | |||
| var inputs = tf.random.normal(shape: NUM_EXAMPLES); | |||
| var noise = tf.random.normal(shape: NUM_EXAMPLES); | |||
| var outputs = inputs * TRUE_W + TRUE_b + noise; | |||
| print($"Current loss: {loss(model(inputs), outputs).numpy()}"); | |||
| // define linear model | |||
| Func<NDArray, Tensor> model = (x) => W * x + b; | |||
| // Define a training loop | |||
| Action<Tensor, Tensor, float> train = (inputs, outputs, learning_rate) | |||
| => | |||
| { | |||
| using var t = tf.GradientTape(); | |||
| var current_loss = loss(outputs, model(inputs)); | |||
| var (dW, db) = t.gradient(current_loss, (W, b)); | |||
| W.assign_sub(learning_rate * dW); | |||
| b.assign_sub(learning_rate * db); | |||
| }; | |||
| // var inputs = tf.random.normal(shape =[NUM_EXAMPLES]); | |||
| // noise = tf.random.normal(shape =[NUM_EXAMPLES]) | |||
| // outputs = inputs * TRUE_W + TRUE_b + noise | |||
| var epochs = range(10); | |||
| foreach(var epoch in epochs) | |||
| { | |||
| train(inputs, outputs, 0.1f); | |||
| print($"Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f"); | |||
| } | |||
| } | |||
| } | |||
| } | |||