| @@ -20,6 +20,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class tensorflow | public partial class tensorflow | ||||
| { | { | ||||
| public IInitializer constant_initializer<T>(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) | |||||
| => new Constant<T>(value, dtype: dtype, verify_shape: verify_shape); | |||||
| public IInitializer zeros_initializer => new Zeros(); | public IInitializer zeros_initializer => new Zeros(); | ||||
| public IInitializer ones_initializer => new Ones(); | public IInitializer ones_initializer => new Ones(); | ||||
| public IInitializer glorot_uniform_initializer => new GlorotUniform(); | public IInitializer glorot_uniform_initializer => new GlorotUniform(); | ||||
| @@ -0,0 +1,55 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| namespace Tensorflow.Operations.Initializers | |||||
| { | |||||
| public class Constant<T> : IInitializer | |||||
| { | |||||
| TF_DataType dtype; | |||||
| T value; | |||||
| bool _verify_shape; | |||||
| public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) | |||||
| { | |||||
| this.value = value; | |||||
| this.dtype = dtype; | |||||
| _verify_shape = verify_shape; | |||||
| } | |||||
| public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) | |||||
| { | |||||
| if (dtype == TF_DataType.DtInvalid) | |||||
| dtype = this.dtype; | |||||
| if (!verify_shape.HasValue) | |||||
| verify_shape = _verify_shape; | |||||
| return constant_op._constant_impl(value, dtype, shape, | |||||
| name: "Const", | |||||
| verify_shape: verify_shape.Value, | |||||
| allow_broadcast: false); | |||||
| } | |||||
| public object get_config() | |||||
| { | |||||
| return new | |||||
| { | |||||
| value, | |||||
| dtype = dtype.name() | |||||
| }; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -18,7 +18,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public interface IInitializer | public interface IInitializer | ||||
| { | { | ||||
| Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid); | |||||
| Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null); | |||||
| object get_config(); | object get_config(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -25,7 +25,7 @@ namespace Tensorflow.Operations.Initializers | |||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| } | } | ||||
| public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) | |||||
| { | { | ||||
| if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
| dtype = this.dtype; | dtype = this.dtype; | ||||
| @@ -38,7 +38,7 @@ namespace Tensorflow.Operations.Initializers | |||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| } | } | ||||
| public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) | |||||
| { | { | ||||
| if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
| dtype = this.dtype; | dtype = this.dtype; | ||||
| @@ -28,7 +28,7 @@ namespace Tensorflow.Operations.Initializers | |||||
| } | } | ||||
| public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) | |||||
| { | { | ||||
| return random_ops.random_uniform(shape, | return random_ops.random_uniform(shape, | ||||
| minval: minval, | minval: minval, | ||||
| @@ -34,7 +34,7 @@ namespace Tensorflow.Operations.Initializers | |||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| } | } | ||||
| public Tensor call(TensorShape shape, TF_DataType dtype) | |||||
| public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) | |||||
| { | { | ||||
| return random_ops.truncated_normal(shape, mean, stddev, dtype : dtype, seed: seed); | return random_ops.truncated_normal(shape, mean, stddev, dtype : dtype, seed: seed); | ||||
| } | } | ||||
| @@ -45,7 +45,7 @@ namespace Tensorflow.Operations.Initializers | |||||
| _dtype = dtype; | _dtype = dtype; | ||||
| } | } | ||||
| public Tensor call(TensorShape shape, TF_DataType dtype) | |||||
| public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) | |||||
| { | { | ||||
| var (fan_in, fan_out) = _compute_fans(shape); | var (fan_in, fan_out) = _compute_fans(shape); | ||||
| if (_mode == "fan_in") | if (_mode == "fan_in") | ||||
| @@ -25,7 +25,7 @@ namespace Tensorflow.Operations.Initializers | |||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| } | } | ||||
| public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) | |||||
| { | { | ||||
| if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
| dtype = this.dtype; | dtype = this.dtype; | ||||
| @@ -155,7 +155,7 @@ namespace Tensorflow | |||||
| public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | ||||
| public static explicit operator int(TensorShape shape) => shape.size; | public static explicit operator int(TensorShape shape) => shape.size; | ||||
| public static explicit operator TensorShape(int dim) => new TensorShape(dim); | |||||
| public static implicit operator TensorShape(int dim) => new TensorShape(dim); | |||||
| public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); | public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); | ||||
| public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | ||||