diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs index 5b1b5713..f553d45b 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs @@ -40,12 +40,20 @@ namespace Tensorflow.Operations.Initializers public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) { - throw new NotImplementedException(); + if (dtype == TF_DataType.DtInvalid) + dtype = this.dtype; + return random_ops.random_normal(shape, mean, stddev, dtype, seed: seed); } public object get_config() { - throw new NotImplementedException(); + return new + { + mean, + stddev, + seed, + dtype + }; } } } diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index cf62ce04..8c9e571e 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -143,6 +143,11 @@ namespace Tensorflow } } + public override string ToString() + { + return shape.ToString(); + } + public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs index 935b9914..c3201f8c 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs @@ -91,7 +91,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO tf_with(tf.name_scope("define_loss"), scope => { - // model = new YOLOv3(cfg, input_data, trainable); + model = new YOLOv3(cfg, input_data, trainable); }); tf_with(tf.name_scope("define_weight_decay"), scope => diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs index 5125c603..de5f0acc 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs @@ -37,6 +37,21 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO upsample_method = cfg.YOLO.UPSAMPLE_METHOD; (conv_lbbox, conv_mbbox, conv_sbbox) = __build_nework(input_data); + + tf_with(tf.variable_scope("pred_sbbox"), scope => + { + // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + }); + + tf_with(tf.variable_scope("pred_mbbox"), scope => + { + // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + }); + + tf_with(tf.variable_scope("pred_lbbox"), scope => + { + // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + }); } private (Tensor, Tensor, Tensor) __build_nework(Tensor input_data)