Browse Source

Check and refine the code.

pull/976/head
AsakusaRinne 2 years ago
parent
commit
3a6a59e18c
31 changed files with 194 additions and 241 deletions
  1. +3
    -3
      src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs
  3. +2
    -2
      src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs
  5. +10
    -9
      src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Checkpoint/TrackableView.cs
  7. +71
    -90
      src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
  8. +1
    -1
      src/TensorFlowNET.Core/DisposableObject.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Exceptions/AssertionError.cs
  10. +3
    -3
      src/TensorFlowNET.Core/Training/AutoTrackable.cs
  11. +2
    -2
      src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs
  12. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs
  13. +2
    -2
      src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs
  14. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs
  15. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs
  16. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs
  17. +8
    -8
      src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs
  18. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs
  19. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs
  20. +3
    -3
      src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs
  21. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs
  22. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs
  23. +10
    -11
      src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs
  24. +1
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs
  25. +0
    -11
      src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs
  26. +58
    -8
      src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs
  27. +0
    -66
      src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs
  28. +1
    -2
      src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs
  29. +1
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs
  30. +1
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs
  31. +4
    -4
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs

+ 3
- 3
src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs View File

@@ -12,9 +12,9 @@ namespace Tensorflow.Checkpoint;
public static class CheckPointUtils public static class CheckPointUtils
{ {
private static string _ESCAPE_CHAR = "."; private static string _ESCAPE_CHAR = ".";
public static (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>, Dictionary<Trackable, int>,
public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>,
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>, IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
Dictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view)
IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view)
{ {
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
Dictionary<Trackable, string> object_names = new(); Dictionary<Trackable, string> object_names = new();
@@ -149,4 +149,4 @@ public static class CheckPointUtils
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
// } // }
} }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs View File

@@ -2,4 +2,4 @@


public record class CheckpointOptions( public record class CheckpointOptions(
string? experimental_io_device = null, string? experimental_io_device = null,
bool experimental_enable_async_checkpoint = false);
bool experimental_enable_async_checkpoint = false);

+ 2
- 2
src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs View File

@@ -45,7 +45,7 @@ public class ObjectGraphView: TrackableView, ICloneable
get => _attached_dependencies; get => _attached_dependencies;
} }


public virtual (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
public virtual (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
{ {
return base._descendants_with_paths(); return base._descendants_with_paths();
} }
@@ -61,4 +61,4 @@ public class ObjectGraphView: TrackableView, ICloneable
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }
}
}

+ 2
- 2
src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs View File

@@ -58,7 +58,7 @@ namespace Tensorflow.Checkpoint
return (serialized_tensors, feed_additions, registered_savers, object_graph_proto); return (serialized_tensors, feed_additions, registered_savers, object_graph_proto);
} }


private static (List<TrackableData>, Dictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map)
private static (IList<TrackableData>, IDictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map)
{ {
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
Dictionary<Trackable, string> object_names = new(); Dictionary<Trackable, string> object_names = new();
@@ -173,7 +173,7 @@ namespace Tensorflow.Checkpoint


tensor_dict[checkpoint_key] = maybe_tensor; tensor_dict[checkpoint_key] = maybe_tensor;


if(maybe_tensor.GetValueA() is SaveSpec)
if(maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>())
{ {
throw new NotImplementedException(); throw new NotImplementedException();
//((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name;


+ 10
- 9
src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs View File

@@ -13,7 +13,7 @@ namespace Tensorflow.Checkpoint;


public static class SaveUtilV1 public static class SaveUtilV1
{ {
public static (Dictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(IDictionary<Trackable, string> object_names,
public static (IDictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(IDictionary<Trackable, string> object_names,
IDictionary<Trackable, Trackable>? object_map = null) IDictionary<Trackable, Trackable>? object_map = null)
{ {
// According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md, // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md,
@@ -44,7 +44,7 @@ public static class SaveUtilV1
return (checkpoint_factory_map, null); return (checkpoint_factory_map, null);
} }


public static (List<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view,
public static (IList<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view,
IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures, IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures,
object? saveables_cache = null) object? saveables_cache = null)
{ {
@@ -73,7 +73,7 @@ public static class SaveUtilV1
} }
} }


public static (List<MySaveableObject>, TrackableObjectGraph, object?, IDictionary<string, IDictionary<string, Trackable>>?) serialize_gathered_objects(ObjectGraphView graph_view,
public static (IList<MySaveableObject>, TrackableObjectGraph, object?, IDictionary<string, IDictionary<string, Trackable>>?) serialize_gathered_objects(ObjectGraphView graph_view,
IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null) IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
{ {
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
@@ -129,7 +129,8 @@ public static class SaveUtilV1
return object_graph_proto; return object_graph_proto;
} }


private static (List<MySaveableObject>, object?, IDictionary<string, IDictionary<string, Trackable>>?) add_attributes_to_object_graph(IList<Trackable> trackable_objects,
private static (IList<MySaveableObject>, object?, IDictionary<string, IDictionary<string, Trackable>>?) add_attributes_to_object_graph(
IList<Trackable> trackable_objects,
TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids, TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids,
IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map, IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map,
bool call_with_mapped_captures, object? saveables_cache = null) bool call_with_mapped_captures, object? saveables_cache = null)
@@ -150,7 +151,7 @@ public static class SaveUtilV1
return (named_saveable_objects, feed_additions, null); return (named_saveable_objects, feed_additions, null);
} }


public static (List<MySaveableObject>, object?) generate_saveable_objects(
public static (IList<MySaveableObject>, object?) generate_saveable_objects(
IDictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map, IDictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map,
TrackableObjectGraph? object_graph_proto, IDictionary<Trackable, int>? node_ids, TrackableObjectGraph? object_graph_proto, IDictionary<Trackable, int>? node_ids,
IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null) IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
@@ -178,13 +179,13 @@ public static class SaveUtilV1


// TODO: oneflow python has a process with callable `saveable_factory`. // TODO: oneflow python has a process with callable `saveable_factory`.
List<MySaveableObject> saveables = new(); List<MySaveableObject> saveables = new();
if (maybe_saveable.DataType == typeof(MySaveableObject))
if (maybe_saveable.TryGet<MySaveableObject>(out var s))
{ {
saveables.Add(maybe_saveable.GetValueB());
saveables.Add(s);
} }
else else
{ {
saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key));
saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue<BaseResourceVariable>() as Trackable, key));
} }


foreach (var saveable in saveables) foreach (var saveable in saveables)
@@ -219,4 +220,4 @@ public record class CheckpointFactoryData
Maybe<BaseResourceVariable, MySaveableObject> factory, Maybe<BaseResourceVariable, MySaveableObject> factory,
string name, string name,
string checkpoint_key string checkpoint_key
);
);

+ 1
- 1
src/TensorFlowNET.Core/Checkpoint/TrackableView.cs View File

@@ -52,7 +52,7 @@ public class TrackableView
/// Returns a list of all nodes and its paths from self.root using a breadth first traversal. /// Returns a list of all nodes and its paths from self.root using a breadth first traversal.
/// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths
/// </summary> /// </summary>
protected (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths()
protected (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths()
{ {
List<Trackable> bfs_sorted = new(); List<Trackable> bfs_sorted = new();
Queue<Trackable> to_visit = new(); Queue<Trackable> to_visit = new();


+ 71
- 90
src/TensorFlowNET.Core/Checkpoint/functional_saver.cs View File

@@ -14,112 +14,91 @@ using Tensorflow.Training;
using Tensorflow.Graphs; using Tensorflow.Graphs;
using System.Xml.Linq; using System.Xml.Linq;
using System.Diagnostics; using System.Diagnostics;
using RestoreFunc = System.Func<object, object>;


namespace Tensorflow.Checkpoint namespace Tensorflow.Checkpoint
{ {
/// <summary>
/// `FunctionHolder` is a series of containers to help dynamically call some dotnet functions.
/// Note that this API does not gurantee performance. Besides, it is not supposed to be exposed to users.
/// </summary>
public interface IFunctionHolder
{
int ArgCount { get; }
object DynamicInvoke(params object[] args);
}
internal record class FunctionHolder<TR>(Func<TR> Func): IFunctionHolder
{
public int ArgCount => 0;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
public TR Invoke()
{
return Func.Invoke();
}
}
internal record class FunctionHolder<TA1, TR>(Func<TA1, TR> Func) : IFunctionHolder
{
public int ArgCount => 1;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
internal record class FunctionHolder<TA1, TA2, TR>(Func<TA1, TA2, TR> Func) : IFunctionHolder
{
public int ArgCount => 2;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
internal record class FunctionHolder<TA1, TA2, TA3, TR>(Func<TA1, TA2, TA3, TR> Func) : IFunctionHolder
{
public int ArgCount => 3;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
public class Maybe<TA, TB> public class Maybe<TA, TB>
{ {
private TA? _valueA = default(TA); private TA? _valueA = default(TA);
private TB? _valueB = default(TB); private TB? _valueB = default(TB);
private Type _type; private Type _type;
private bool _assigned = false;
private bool _assignedTA;
public Maybe(TA value) public Maybe(TA value)
{ {
_valueA = value; _valueA = value;
_type= typeof(TA); _type= typeof(TA);
_assigned = true;
_assignedTA = true;
} }
public Maybe(TB value) public Maybe(TB value)
{ {
_valueB = value; _valueB = value;
_type = typeof(TB); _type = typeof(TB);
_assigned = true;
_assignedTA = false;
} }


public Type DataType => _type; public Type DataType => _type;


public TA GetValueA()
/// <summary>
/// Try to get the type T member of this instance. It returns true when TA or TB derive from T and is correspondingly assigned.
/// It returns
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="res"></param>
/// <returns></returns>
public bool TryGet<T>(out T? res)
{ {
if(!_assigned || DataType != typeof(TA))
if(_valueA is T && _valueB is not T)
{ {
throw new TypeError("Cannot get the data because of wrong specified type.");
res = (T)(object)_valueA;
return _assignedTA;
} }
return _valueA;
}
public TB GetValueB()
{
if (!_assigned || DataType != typeof(TB))
else if(_valueA is not T && _valueB is T)
{ {
throw new TypeError("Cannot get the data because of wrong specified type.");
res = (T)(object)_valueB;
return !_assignedTA;
} }
return _valueB;
res = default(T);
return false;
} }
public object GetValue()

public bool IsTypeOrDeriveFrom<T>()
{ {
if (!_assigned)
if (_valueA is T && _valueB is not T)
{ {
throw new TypeError("Cannot get the data because of wrong specified type.");
return _assignedTA;
} }
if(DataType == typeof(TA) && _valueA is not null)
else if (_valueA is not T && _valueB is T)
{ {
return _valueA;
return !_assignedTA;
} }
else if(DataType == typeof(TB) && _valueB is not null)
else if (_valueA is T && _valueB is T)
{ {
return _valueB;
return true;
} }
else if(DataType == typeof(TA))
else
{ {
return _valueA;
return false;
}
}

public T GetValue<T>()
{
if (_valueA is T && _valueB is not T)
{
return (T)(object)_valueA;
}
else if (_valueA is not T && _valueB is T)
{
return (T)(object)_valueB;
}
else if (_valueA is T && _valueB is T)
{
throw new TypeError("The type is vague, this is always because TA and TB both derive from T.");
} }
else else
{ {
return _valueB;
throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}.");
} }
} }


@@ -170,9 +149,8 @@ namespace Tensorflow.Checkpoint
{ {
var slice_spec = slice.Key; var slice_spec = slice.Key;
var maybe_tensor = slice.Value; var maybe_tensor = slice.Value;
if(maybe_tensor.DataType == typeof(SaveSpec))
if(maybe_tensor.TryGet<SaveSpec>(out var spec))
{ {
var spec = maybe_tensor.GetValueB();
var tensor_value = spec.tensor; var tensor_value = spec.tensor;
if (tensor_value is not null) if (tensor_value is not null)
{ {
@@ -183,7 +161,7 @@ namespace Tensorflow.Checkpoint
} }
else else
{ {
var tensor = maybe_tensor.GetValueA();
var tensor = maybe_tensor.GetValue<Tensor>();
tensor_names.Add(checkpoint_key); tensor_names.Add(checkpoint_key);
tensors.Add(tensor); tensors.Add(tensor);
slice_specs.Add(slice_spec); slice_specs.Add(slice_spec);
@@ -215,16 +193,15 @@ namespace Tensorflow.Checkpoint
var slice_spec = slice.Key; var slice_spec = slice.Key;
var maybe_tensor = slice.Value; var maybe_tensor = slice.Value;
// TODO: deal with other types. Currently only `SaveSpec` is allowed. // TODO: deal with other types. Currently only `SaveSpec` is allowed.
if(maybe_tensor.DataType == typeof(SaveSpec))
if(maybe_tensor.TryGet<SaveSpec>(out var spec))
{ {
var spec = maybe_tensor.GetValueB();
tensor_dtypes.Add(spec.dtype); tensor_dtypes.Add(spec.dtype);
slice_specs.Add(spec.slice_spec); slice_specs.Add(spec.slice_spec);
tensor_names.Add(spec.name); tensor_names.Add(spec.name);
} }
else else
{ {
var tensor = maybe_tensor.GetValueA();
var tensor = maybe_tensor.GetValue<Tensor>();
tensor_dtypes.Add(tensor.dtype); tensor_dtypes.Add(tensor.dtype);
slice_specs.Add(slice_spec); slice_specs.Add(slice_spec);
tensor_names.Add(checkpoint_key); tensor_names.Add(checkpoint_key);
@@ -268,9 +245,9 @@ namespace Tensorflow.Checkpoint
public class MultiDeviceSaver public class MultiDeviceSaver
{ {
private Dictionary<string, SingleDeviceSaver> _single_device_savers; private Dictionary<string, SingleDeviceSaver> _single_device_savers;
private IDictionary<string, (IFunctionHolder, IFunctionHolder)> _registered_savers;
private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn;
private Dictionary<IFunctionHolder, IList<(string, string)>> _restore_fn_to_keys;
private IDictionary<string, (RestoreFunc, RestoreFunc)> _registered_savers;
private Dictionary<(string, string), RestoreFunc> _keys_to_restore_fn;
private Dictionary<RestoreFunc, IList<(string, string)>> _restore_fn_to_keys;
/// <summary> /// <summary>
/// ///
/// </summary> /// </summary>
@@ -280,24 +257,28 @@ namespace Tensorflow.Checkpoint
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors,
IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false)
{ {
_keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>();
_restore_fn_to_keys = new Dictionary<IFunctionHolder, IList<(string, string)>>();
_keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>();
_restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>();
Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new(); Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new();
foreach(var pair in serialized_tensors) foreach(var pair in serialized_tensors)
{ {
var obj = pair.Key; var obj = pair.Key;
var tensor_dict = pair.Value; var tensor_dict = pair.Value;
IFunctionHolder restore_fn;
RestoreFunc restore_fn;
if(obj == Trackable.None) if(obj == Trackable.None)
{ {
restore_fn = new FunctionHolder<object?>(() => null);
restore_fn = new RestoreFunc(x => null);
} }
else else
{ {
restore_fn = new FunctionHolder<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>>(x =>
restore_fn = new RestoreFunc(x =>
{ {
return obj._restore_from_tensors(x);
if(x is IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>)
{
return obj._restore_from_tensors(x as IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>);
}
throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}.");
}); });
} }


@@ -305,14 +286,14 @@ namespace Tensorflow.Checkpoint
{ {
var checkpoint_key = item.Key; var checkpoint_key = item.Key;
IDictionary<string, Tensor> spec_to_tensor; IDictionary<string, Tensor> spec_to_tensor;
if(item.Value.DataType != typeof(IDictionary<string, Tensor>))
if(item.Value.TryGet<Tensor>(out var t))
{ {
spec_to_tensor = new Dictionary<string, Tensor>(); spec_to_tensor = new Dictionary<string, Tensor>();
spec_to_tensor[""] = item.Value.GetValueA();
spec_to_tensor[""] = t;
} }
else else
{ {
spec_to_tensor = item.Value.GetValueB();
spec_to_tensor = item.Value.GetValue<IDictionary<string, Tensor>>();
} }


foreach(var spec in spec_to_tensor) foreach(var spec in spec_to_tensor)
@@ -342,7 +323,7 @@ namespace Tensorflow.Checkpoint


_single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value)); _single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value));


_registered_savers = new Dictionary<string, (IFunctionHolder, IFunctionHolder)>();
_registered_savers = new Dictionary<string, (RestoreFunc, RestoreFunc)>();
if(registered_savers is not null && registered_savers.Count > 0) if(registered_savers is not null && registered_savers.Count > 0)
{ {
// TODO: complete the implementation. // TODO: complete the implementation.
@@ -418,8 +399,8 @@ namespace Tensorflow.Checkpoint


IDictionary<string, Operation> restore_func() IDictionary<string, Operation> restore_func()
{ {
Dictionary<IFunctionHolder, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new();
Dictionary<IFunctionHolder, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count);
Dictionary<RestoreFunc, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new();
Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count);
Dictionary<string, Operation> restore_ops = new(); Dictionary<string, Operation> restore_ops = new();


foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key))
@@ -449,7 +430,7 @@ namespace Tensorflow.Checkpoint
} }
else else
{ {
internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor;
internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor;
} }
} }
else else


+ 1
- 1
src/TensorFlowNET.Core/DisposableObject.cs View File

@@ -158,4 +158,4 @@ namespace Tensorflow
Dispose(false); Dispose(false);
} }
} }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Exceptions/AssertionError.cs View File

@@ -11,4 +11,4 @@ public class AssertionError : TensorflowException
{ {


} }
}
}

+ 3
- 3
src/TensorFlowNET.Core/Training/AutoTrackable.cs View File

@@ -37,10 +37,10 @@ namespace Tensorflow.Train
var properties = this.GetType().GetProperties(); var properties = this.GetType().GetProperties();
foreach ( var property in properties ) foreach ( var property in properties )
{ {
string name = property.Name;
object value = property.GetValue(this, null);
if(value is Function || value is ConcreteFunction)
if(property.PropertyType == typeof(Function) || property.PropertyType == typeof(ConcreteFunction))
{ {
string name = property.Name;
object value = property.GetValue(this, null);
functions[name] = (Trackable)value; functions[name] = (Trackable)value;
} }
} }


+ 2
- 2
src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs View File

@@ -25,9 +25,9 @@ namespace Tensorflow
{ {
get get
{ {
if(_op.DataType == typeof(Tensor))
if(_op.TryGet<Tensor>(out var tensor))
{ {
return _op.GetValueA();
return tensor;
} }
else else
{ {


+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs View File

@@ -8,4 +8,4 @@ public record class AssetInfo
Dictionary<object, object> asset_initializers_by_resource, Dictionary<object, object> asset_initializers_by_resource,
Dictionary<AssetInfo, string> asset_filename_map, Dictionary<AssetInfo, string> asset_filename_map,
Dictionary<object, object> asset_index Dictionary<object, object> asset_index
);
);

+ 2
- 2
src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs View File

@@ -86,7 +86,7 @@ public class AugmentedGraphView: ObjectGraphView
return concrete_function; return concrete_function;
} }


public override (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
public override (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
{ {
Trackable get_merged_trackable(Trackable x) Trackable get_merged_trackable(Trackable x)
{ {
@@ -130,4 +130,4 @@ public class AugmentedGraphView: ObjectGraphView
{ {
return _children_cache[obj][name]; return _children_cache[obj][name];
} }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs View File

@@ -30,4 +30,4 @@ public static class Constants


public static readonly string VARIABLES_DIRECTORY = "variables"; public static readonly string VARIABLES_DIRECTORY = "variables";
public static readonly string VARIABLES_FILENAME = "variables"; public static readonly string VARIABLES_FILENAME = "variables";
}
}

+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs View File

@@ -14,4 +14,4 @@ public class RevivedTypes
// TODO: complete the implementation. // TODO: complete the implementation.
return null; return null;
} }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs View File

@@ -6,4 +6,4 @@ public enum SaveType
{ {
SAVEDMODEL, SAVEDMODEL,
CHECKPOINT CHECKPOINT
}
}

+ 8
- 8
src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs View File

@@ -18,13 +18,13 @@ public class SaveableView
{ {
private AugmentedGraphView _augmented_graph_view; private AugmentedGraphView _augmented_graph_view;
private SaveOptions _options; private SaveOptions _options;
private List<Trackable> _trackable_objects;
private IList<Trackable> _trackable_objects;
private List<Trackable> _nodes; private List<Trackable> _nodes;
private Dictionary<Trackable, IEnumerable<TrackableReference>> _node_paths;
private Dictionary<Trackable, int> _node_ids;
private IDictionary<Trackable, IEnumerable<TrackableReference>> _node_paths;
private IDictionary<Trackable, int> _node_ids;
private IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>> private IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
_slot_variables; _slot_variables;
private Dictionary<Trackable, string> _object_names;
private IDictionary<Trackable, string> _object_names;
private List<object> _gradient_functions; // to be completed private List<object> _gradient_functions; // to be completed
private List<RegisteredGradient> _gradient_defs; // to be completed private List<RegisteredGradient> _gradient_defs; // to be completed
private List<ConcreteFunction> _concrete_functions; private List<ConcreteFunction> _concrete_functions;
@@ -45,7 +45,7 @@ public class SaveableView
{ {
get => _nodes; get => _nodes;
} }
public Dictionary<Trackable, int> NodeIds
public IDictionary<Trackable, int> NodeIds
{ {
get => _node_ids; get => _node_ids;
} }
@@ -53,7 +53,7 @@ public class SaveableView
{ {
get => _gradient_defs; get => _gradient_defs;
} }
public Dictionary<Trackable, IEnumerable<TrackableReference>> NodePaths
public IDictionary<Trackable, IEnumerable<TrackableReference>> NodePaths
{ {
get => _node_paths; get => _node_paths;
} }
@@ -84,7 +84,7 @@ public class SaveableView


private void initialize_nodes_and_concrete_functions() private void initialize_nodes_and_concrete_functions()
{ {
_nodes = _trackable_objects.ConvertAll(x => x); // deep copy
_nodes = _trackable_objects.ToList().ConvertAll(x => x); // deep copy
_gradient_functions = new(); _gradient_functions = new();
_gradient_defs = new(); _gradient_defs = new();


@@ -296,4 +296,4 @@ public class SaveableView
proto.Nodes.Add(object_proto); proto.Nodes.Add(object_proto);
} }
} }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs View File

@@ -7,4 +7,4 @@ public static class TagConstants
public static readonly string EVAL = "eval"; public static readonly string EVAL = "eval";
public static readonly string GPU = "gpu"; public static readonly string GPU = "gpu";
public static readonly string TPU = "tpu"; public static readonly string TPU = "tpu";
}
}

+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs View File

@@ -19,4 +19,4 @@ public class BuilderUtils
throw new NotImplementedException(); throw new NotImplementedException();
} }
} }
}
}

+ 3
- 3
src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs View File

@@ -81,8 +81,8 @@ public static partial class SavedModelUtils
return (saved_nodes, node_paths); return (saved_nodes, node_paths);
} }


private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>,
Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj,
private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, IList<Trackable>,
IDictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj,
ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null)
{ {
using (SaveContext.save_context(options)) using (SaveContext.save_context(options))
@@ -266,4 +266,4 @@ public static partial class SavedModelUtils
} }
} }
} }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs View File

@@ -104,4 +104,4 @@ public class SignatureMap: Trackable


return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value);
} }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs View File

@@ -54,4 +54,4 @@ public static partial class SavedModelUtils
{ {
return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY)); return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY));
} }
}
}

+ 10
- 11
src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs View File

@@ -136,9 +136,8 @@ namespace Tensorflow
{ {
full_name = name + "_" + attr; full_name = name + "_" + attr;
} }
if(factory.DataType == typeof(ResourceVariable))
if(factory.TryGet<BaseResourceVariable>(out var variable))
{ {
var variable = factory.GetValueA();
foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name))
{ {
yield return op; yield return op;
@@ -146,8 +145,8 @@ namespace Tensorflow
} }
else else
{ {
var variable = factory.GetValueB();
foreach (var op in saveable_objects_for_op(variable, variable.name))
var saveable = factory.GetValue<MySaveableObject>();
foreach (var op in saveable_objects_for_op(saveable, saveable.name))
{ {
yield return op; yield return op;
} }
@@ -236,14 +235,14 @@ namespace Tensorflow
string spec_name = name + TrackableUtils.escape_local_name(tensor_name); string spec_name = name + TrackableUtils.escape_local_name(tensor_name);


IDictionary<string, Tensor> internal_dict; IDictionary<string, Tensor> internal_dict;
if(maybe_tensor.DataType == typeof(Tensor))
if(maybe_tensor.TryGet<Tensor>(out var tensor))
{ {
internal_dict= new Dictionary<string, Tensor>(); internal_dict= new Dictionary<string, Tensor>();
internal_dict[""] = maybe_tensor.GetValueA();
internal_dict[""] = tensor;
} }
else else
{ {
internal_dict = maybe_tensor.GetValueB();
internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>();
} }


foreach(var item in internal_dict) foreach(var item in internal_dict)
@@ -287,7 +286,7 @@ namespace Tensorflow
var slice_spec = convert_to_string(spec.slice_spec); var slice_spec = convert_to_string(spec.slice_spec);
if (!string.IsNullOrEmpty(slice_spec)) if (!string.IsNullOrEmpty(slice_spec))
{ {
tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValueB()[slice_spec] = spec.tensor;
tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValue<IDictionary<string, Tensor>>()[slice_spec] = spec.tensor;
} }
else else
{ {
@@ -318,14 +317,14 @@ namespace Tensorflow


var maybe_tensor = restored_tensors[name]; var maybe_tensor = restored_tensors[name];
IDictionary<string, Tensor> dict; IDictionary<string, Tensor> dict;
if(maybe_tensor.DataType == typeof(Tensor))
if(maybe_tensor.TryGet<Tensor>(out var tensor))
{ {
dict = new Dictionary<string, Tensor>(); dict = new Dictionary<string, Tensor>();
dict[""] = maybe_tensor.GetValueA();
dict[""] = tensor;
} }
else else
{ {
dict = maybe_tensor.GetValueB();
dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>();
} }
saveable_restored_tensors.Add(dict[slice_spec]); saveable_restored_tensors.Add(dict[slice_spec]);
} }


+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs View File

@@ -38,4 +38,4 @@ public static class Constants
RNN_LAYER_IDENTIFIER, RNN_LAYER_IDENTIFIER,
SEQUENTIAL_IDENTIFIER SEQUENTIAL_IDENTIFIER
}; };
}
}

+ 0
- 11
src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs View File

@@ -1,11 +0,0 @@
namespace Tensorflow.Keras.Saving.SavedModel;

public class KerasObjectWrapper
{
}

public class KerasObjectWrapper<T>
{
public T Item { get; set; }
}

+ 58
- 8
src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs View File

@@ -3,19 +3,15 @@ using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using Google.Protobuf; using Google.Protobuf;
using ICSharpCode.SharpZipLib.Zip;
using Tensorflow.Checkpoint;
using Tensorflow.Contexts;
using Tensorflow.Functions; using Tensorflow.Functions;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using Tensorflow.ModelSaving; using Tensorflow.ModelSaving;
using Tensorflow.Train; using Tensorflow.Train;
using Tensorflow.Exceptions;
using Tensorflow.IO;
using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Optimizers;
using ThirdParty.Tensorflow.Python.Keras.Protobuf; using ThirdParty.Tensorflow.Python.Keras.Protobuf;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using Tensorflow.Training;



namespace Tensorflow.Keras.Saving.SavedModel; namespace Tensorflow.Keras.Saving.SavedModel;


@@ -108,5 +104,59 @@ public partial class KerasSavedModelUtils
return metadata; return metadata;
} }


}
public static bool should_skip_serialization(object layer)
{
return false;
}

/// <summary>
/// Returns extra trackable objects to attach to the serialized layer.
/// </summary>
/// <param name="layer"></param>
/// <param name="serialization_cache"></param>
/// <returns></returns>
public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{
// TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs.

// TODO: change the inherits of `Variable` and revise the implmentation.
var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x =>
{
if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
}));
var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x =>
{
if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
}));
var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x =>
{
if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
}));

Dictionary<string, Trackable> res = new();
res["variables"] = variables;
res["trainable_variables"] = trainable_variables;
res["non_trainable_variables"] = non_trainable_variables;
res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable()));

return res;
}

/// <summary>
/// Returns dict of wrapped layer call function and losses in tf.functions.
/// </summary>
/// <param name="layer"></param>
/// <param name="serialization_cache"></param>
/// <returns></returns>
public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{
// TODO: deal with type `RevivedLayer` and `Sequential`.

// skip the process because of lack of APIs of `Layer`.

return new Dictionary<string, Trackable>();
}
}

+ 0
- 66
src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs View File

@@ -1,66 +0,0 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Engine;
using Tensorflow.Train;
using Tensorflow.Training;

namespace Tensorflow.Keras.Saving.SavedModel;

public partial class KerasSavedModelUtils
{
public static bool should_skip_serialization(object layer)
{
return false;
}

/// <summary>
/// Returns extra trackable objects to attach to the serialized layer.
/// </summary>
/// <param name="layer"></param>
/// <param name="serialization_cache"></param>
/// <returns></returns>
public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{
// TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs.

// TODO: change the inherits of `Variable` and revise the implmentation.
var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x =>
{
if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
}));
var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x =>
{
if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
}));
var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x =>
{
if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
}));

Dictionary<string, Trackable> res = new();
res["variables"] = variables;
res["trainable_variables"] = trainable_variables;
res["non_trainable_variables"] = non_trainable_variables;
res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable()));

return res;
}

/// <summary>
/// Returns dict of wrapped layer call function and losses in tf.functions.
/// </summary>
/// <param name="layer"></param>
/// <param name="serialization_cache"></param>
/// <returns></returns>
public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{
// TODO: deal with type `RevivedLayer` and `Sequential`.

// skip the process because of lack of APIs of `Layer`.

return new Dictionary<string, Trackable>();
}
}

+ 1
- 2
src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs View File

@@ -34,5 +34,4 @@ public abstract class SavedModelSaver
return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value))
.ToDictionary(x => x.Key, x => x.Value); .ToDictionary(x => x.Key, x => x.Value);
} }
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs View File

@@ -162,4 +162,4 @@ public class InputLayerSavedModelSaver: SavedModelSaver
return JsonConvert.SerializeObject(info); return JsonConvert.SerializeObject(info);
} }
} }
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs View File

@@ -44,4 +44,4 @@ public class SaveOptionsContext: IDisposable
{ {
KerasSavedModelUtils.ShouldHaveTraces = _old_value; KerasSavedModelUtils.ShouldHaveTraces = _old_value;
} }
}
}

+ 4
- 4
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs View File

@@ -73,7 +73,7 @@ public class SequentialModelTest
{ {
TrainDir = "mnist", TrainDir = "mnist",
OneHot = false, OneHot = false,
ValidationSize = 10000,
ValidationSize = 50000,
}).Result; }).Result;


model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
@@ -119,13 +119,13 @@ public class SequentialModelTest
model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" });


var num_epochs = 1; var num_epochs = 1;
var batch_size = 16;
var batch_size = 8;


var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);


model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);


model.save("./pb_elex_sequential", save_format: "tf");
model.save("./pb_alex_sequential", save_format: "tf");


// The saved model can be test with the following python code: // The saved model can be test with the following python code:
#region alexnet_python_code #region alexnet_python_code
@@ -136,7 +136,7 @@ public class SequentialModelTest
// return -a // return -a


//if __name__ == '__main__': //if __name__ == '__main__':
// model = tf.keras.models.load_model("./pb_elex_sequential")
// model = tf.keras.models.load_model("./pb_alex_sequential")
// model.summary() // model.summary()


// num_classes = 5 // num_classes = 5


Loading…
Cancel
Save