From 67949251b2935f6cb9a24d197154efa6733aa251 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 2 Nov 2019 11:17:53 -0500 Subject: [PATCH] override graph --- src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs | 4 ++-- src/TensorFlowNET.Core/Sessions/_FetchHandler.cs | 2 +- src/TensorFlowNET.Core/Sessions/_FetchMapper.cs | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index bc1ea0b7..48eddf3b 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -28,9 +28,9 @@ namespace Tensorflow { private Func, object> _contraction_fn; - public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn) + public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn, Graph graph = null) { - var g = ops.get_default_graph(); + var g = graph ?? ops.get_default_graph(); foreach(var fetch in fetches) { diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index e1a77d90..b7434089 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -34,7 +34,7 @@ namespace Tensorflow public _FetchHandler(Graph graph, object fetches, Dictionary feeds = null, Action feed_handles = null) { - _fetch_mapper = _FetchMapper.for_fetch(fetches); + _fetch_mapper = _FetchMapper.for_fetch(fetches, graph: graph); foreach(var fetch in _fetch_mapper.unique_fetches()) { switch (fetch) diff --git a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs index 534cdcd7..e28b76a1 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs @@ -25,7 +25,7 @@ namespace Tensorflow { protected List _unique_fetches = new List(); protected List _value_indices = new List(); - public static _FetchMapper for_fetch(object fetch) + public static _FetchMapper for_fetch(object fetch, Graph graph = null) { var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch }; @@ -34,7 +34,7 @@ namespace Tensorflow if (fetch.GetType().IsArray) return new _ListFetchMapper(fetches); else - return new _ElementFetchMapper(fetches, (List fetched_vals) => fetched_vals[0]); + return new _ElementFetchMapper(fetches, (List fetched_vals) => fetched_vals[0], graph: graph); } public virtual NDArray[] build_results(List values)