|
|
|
@@ -32,7 +32,9 @@ class ActionRecognitionPipeline(Pipeline): |
|
|
|
config_path = osp.join(self.model, ModelFile.CONFIGURATION) |
|
|
|
logger.info(f'loading config from {config_path}') |
|
|
|
self.cfg = Config.from_file(config_path) |
|
|
|
self.infer_model = BaseVideoModel(cfg=self.cfg).cuda() |
|
|
|
self.device = torch.device( |
|
|
|
'cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device) |
|
|
|
self.infer_model.eval() |
|
|
|
self.infer_model.load_state_dict(torch.load(model_path)['model_state']) |
|
|
|
self.label_mapping = self.cfg.label_mapping |
|
|
|
@@ -40,7 +42,7 @@ class ActionRecognitionPipeline(Pipeline): |
|
|
|
|
|
|
|
def preprocess(self, input: Input) -> Dict[str, Any]: |
|
|
|
if isinstance(input, str): |
|
|
|
video_input_data = ReadVideoData(self.cfg, input).cuda() |
|
|
|
video_input_data = ReadVideoData(self.cfg, input).to(self.device) |
|
|
|
else: |
|
|
|
raise TypeError(f'input should be a str,' |
|
|
|
f' but got {type(input)}') |
|
|
|
|