Browse Source

Finish RandomNormal.call().

Override TensorShape.ToString().
tags/v0.12
Oceania2018 6 years ago
parent
commit
4aee841ab3
4 changed files with 31 additions and 3 deletions
  1. +10
    -2
      src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
  2. +5
    -0
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  3. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs
  4. +15
    -0
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs

+ 10
- 2
src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs View File

@@ -40,12 +40,20 @@ namespace Tensorflow.Operations.Initializers

public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
{
throw new NotImplementedException();
if (dtype == TF_DataType.DtInvalid)
dtype = this.dtype;
return random_ops.random_normal(shape, mean, stddev, dtype, seed: seed);
}

public object get_config()
{
throw new NotImplementedException();
return new
{
mean,
stddev,
seed,
dtype
};
}
}
}

+ 5
- 0
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -143,6 +143,11 @@ namespace Tensorflow
}
}

public override string ToString()
{
return shape.ToString();
}

public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone());
public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone());


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs View File

@@ -91,7 +91,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO

tf_with(tf.name_scope("define_loss"), scope =>
{
// model = new YOLOv3(cfg, input_data, trainable);
model = new YOLOv3(cfg, input_data, trainable);
});

tf_with(tf.name_scope("define_weight_decay"), scope =>


+ 15
- 0
test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs View File

@@ -37,6 +37,21 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
upsample_method = cfg.YOLO.UPSAMPLE_METHOD;

(conv_lbbox, conv_mbbox, conv_sbbox) = __build_nework(input_data);

tf_with(tf.variable_scope("pred_sbbox"), scope =>
{
// pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
});

tf_with(tf.variable_scope("pred_mbbox"), scope =>
{
// pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
});

tf_with(tf.variable_scope("pred_lbbox"), scope =>
{
// pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
});
}

private (Tensor, Tensor, Tensor) __build_nework(Tensor input_data)


Loading…
Cancel
Save