diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index bc1ea0b7..48eddf3b 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -28,9 +28,9 @@ namespace Tensorflow { private Func, object> _contraction_fn; - public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn) + public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn, Graph graph = null) { - var g = ops.get_default_graph(); + var g = graph ?? ops.get_default_graph(); foreach(var fetch in fetches) { diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index e1a77d90..b7434089 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -34,7 +34,7 @@ namespace Tensorflow public _FetchHandler(Graph graph, object fetches, Dictionary feeds = null, Action feed_handles = null) { - _fetch_mapper = _FetchMapper.for_fetch(fetches); + _fetch_mapper = _FetchMapper.for_fetch(fetches, graph: graph); foreach(var fetch in _fetch_mapper.unique_fetches()) { switch (fetch) diff --git a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs index 534cdcd7..e28b76a1 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs @@ -25,7 +25,7 @@ namespace Tensorflow { protected List _unique_fetches = new List(); protected List _value_indices = new List(); - public static _FetchMapper for_fetch(object fetch) + public static _FetchMapper for_fetch(object fetch, Graph graph = null) { var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch }; @@ -34,7 +34,7 @@ namespace Tensorflow if (fetch.GetType().IsArray) return new _ListFetchMapper(fetches); else - return new _ElementFetchMapper(fetches, (List fetched_vals) => fetched_vals[0]); + return new _ElementFetchMapper(fetches, (List fetched_vals) => fetched_vals[0], graph: graph); } public virtual NDArray[] build_results(List values)