|
- """
- Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
- Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
- """
-
- import importlib
- import jittor as jt
- from jittor import init
- from jittor import nn
-
-
def find_model_using_name(model_name):
    """Import ``models.<model_name>_model`` and return its model class.

    Given the option ``--model [modelname]``, the module
    ``models.modelname_model`` is imported and searched for a class whose
    name equals ``modelname`` (underscores removed) followed by ``model``,
    compared case-insensitively. The class must be a subclass of
    ``nn.Module`` (the ``nn`` imported from jittor at the top of this file).

    Args:
        model_name: Model identifier, e.g. the value passed to ``--model``.

    Returns:
        The matching class object (not an instance).

    Raises:
        ImportError: Propagated from ``importlib.import_module`` if the
            module cannot be imported.
        SystemExit: With exit status 1 if the module contains no matching
            ``nn.Module`` subclass; an explanatory message is printed first.
    """
    # Function-scope import: only needed on the failure path.
    import sys

    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)

    target_model_name = model_name.replace('_', '') + 'model'
    # Hoisted out of the loop: the lowercase target never changes.
    wanted = target_model_name.lower()

    model = None
    for name, cls in modellib.__dict__.items():
        # isinstance(cls, type) guards issubclass(), which raises TypeError
        # when handed a non-class attribute that merely has a matching name.
        if (name.lower() == wanted
                and isinstance(cls, type)
                and issubclass(cls, nn.Module)):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of nn.Module with class name that matches %s in lowercase." % (
            model_filename, target_model_name))
        # Non-zero status so shells and job schedulers see the failure.
        sys.exit(1)

    return model
-
-
def get_option_setter(model_name):
    """Return the ``modify_commandline_options`` attribute of a model class.

    The class is resolved from ``model_name`` by ``find_model_using_name``;
    the attribute is returned as-is, without being called.
    """
    return find_model_using_name(model_name).modify_commandline_options
-
-
def create_model(opt):
    """Instantiate the model class selected by ``opt.model``.

    The class is looked up with ``find_model_using_name(opt.model)`` and
    constructed with ``opt`` as its only argument. The class name of the
    new instance is printed before the instance is returned.
    """
    model_cls = find_model_using_name(opt.model)
    created = model_cls(opt)

    created_name = type(created).__name__
    print("model [%s] was created" % (created_name))
    return created
|