From 4aee841ab39d739be34a8d438c0b249ce8bc3a80 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 28 Aug 2019 20:10:20 -0500 Subject: [PATCH] Finish RandomNormal.call(). Override TensorShape.ToString(). --- .../Operations/Initializers/RandomNormal.cs | 12 ++++++++++-- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 5 +++++ .../ImageProcessing/YOLO/Main.cs | 2 +- .../ImageProcessing/YOLO/YOLOv3.cs | 15 +++++++++++++++ 4 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs index 5b1b5713..f553d45b 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs @@ -40,12 +40,20 @@ namespace Tensorflow.Operations.Initializers public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) { - throw new NotImplementedException(); + if (dtype == TF_DataType.DtInvalid) + dtype = this.dtype; + return random_ops.random_normal(shape, mean, stddev, dtype, seed: seed); } public object get_config() { - throw new NotImplementedException(); + return new + { + mean, + stddev, + seed, + dtype + }; } } } diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index cf62ce04..8c9e571e 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -143,6 +143,11 @@ namespace Tensorflow } } + public override string ToString() + { + return shape.ToString(); + } + public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs index 935b9914..c3201f8c 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs @@ -91,7 +91,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO tf_with(tf.name_scope("define_loss"), scope => { - // model = new YOLOv3(cfg, input_data, trainable); + model = new YOLOv3(cfg, input_data, trainable); }); tf_with(tf.name_scope("define_weight_decay"), scope => diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs index 5125c603..de5f0acc 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs @@ -37,6 +37,21 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO upsample_method = cfg.YOLO.UPSAMPLE_METHOD; (conv_lbbox, conv_mbbox, conv_sbbox) = __build_nework(input_data); + + tf_with(tf.variable_scope("pred_sbbox"), scope => + { + // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + }); + + tf_with(tf.variable_scope("pred_mbbox"), scope => + { + // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + }); + + tf_with(tf.variable_scope("pred_lbbox"), scope => + { + // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + }); } private (Tensor, Tensor, Tensor) __build_nework(Tensor input_data)