From c8ddb8792d51a6fafba18241e274b5871a96f99b Mon Sep 17 00:00:00 2001 From: carb0n <58676303+carb0n@users.noreply.github.com> Date: Thu, 20 Aug 2020 19:43:54 -0400 Subject: [PATCH] fix rot90_3d --- .../Operations/control_flow_ops.cs | 43 ++++++++++++++++++- .../Operations/image_ops_impl.cs | 14 +++--- 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index ec3824ed..631c97e9 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -1,4 +1,4 @@ -/***************************************************************************** +/***************************************************************************** Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -226,6 +226,47 @@ namespace Tensorflow }); } + internal static Tensor _case_helper(Func cond_fn, Tensor[] pred_fn_pairs, Func callable_default, bool exclusive, string name, + bool allow_python_preds = false) + { + /* + (Tensor[] predicates, Tensor[] actions) = _case_verify_and_canonicalize_args( + pred_fn_pairs, exclusive, name, allow_python_preds); + return tf_with(ops.name_scope(name, "case", new [] {predicates}), delegate + { + if (callable_default == null) + { + (callable_default, predicates, actions) = _case_create_default_action( + predicates, actions); + } + var fn = callable_default; + }); + */ + + throw new NotImplementedException("_case_helper"); + } + + internal static (Func, Tensor[], Tensor[]) _case_create_default_action(Tensor[] predicates, Tensor[] actions) + { + throw new NotImplementedException("_case_create_default_action"); + } + + internal static (Tensor[], Tensor[]) _case_verify_and_canonicalize_args(Tensor[] pred_fn_pairs, bool exclusive, string name, bool allow_python_preds) + { + throw new NotImplementedException("_case_verify_and_canonicalize_args"); + } + + public static Tensor case_v2(Tensor[] pred_fn_pairs, Func callable_default = null, bool exclusive = false, bool strict = false, string name = "case") + => _case_helper( + cond_fn: (Tensor x) => cond(x), + pred_fn_pairs, + default, + exclusive, + name, + allow_python_preds: false//, + //strict: strict + ); + /// /// Produces the content of `output_tensor` only after `dependencies`. /// diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index dc3c0985..11cc6740 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -265,15 +265,13 @@ namespace Tensorflow return gen_array_ops.reverse(array_ops.transpose(image, new [] {1, 0, 2}), new [] {1}); }; - var cases = new [] {new [] {math_ops.equal(k, 1), _rot90()}, - new [] {math_ops.equal(k, 2), _rot180()}, - new [] {math_ops.equal(k, 3), _rot270()}}; + var cases = new [] {math_ops.equal(k, 1), _rot90(), + math_ops.equal(k, 2), _rot180(), + math_ops.equal(k, 3), _rot270()}; - // ! control_flow_ops doesn't have an implementation for case yet ! - // var result = control_flow_ops.case(cases, default: () => image, exclusive: true, name: name_scope); - // result.set_shape(new [] {null, null, image.shape.dims[2]}) - // return result - throw new NotImplementedException(); + var result = control_flow_ops.case_v2(cases, callable_default: () => new Tensor[] {image}, exclusive: true, name: name_scope); + result.set_shape(new [] {-1, -1, image.TensorShape.dims[2]}); + return result; } public static Tensor transpose(Tensor image, string name = null)