
[to #42322933] use gpu when available

ofa/caption: add a feature so that the GPU is used by default when one is available.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9228113
master
yichang.zyc (committed by yingda.chen), 3 years ago
commit 12cc394a95
1 changed file with 13 additions and 7 deletions:
  +13 -7  modelscope/models/multi_modal/image_captioning_model.py
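The change boils down to the standard PyTorch device-selection idiom: probe torch.cuda.is_available() once, build a torch.device from the result, and honour fp16 only when CUDA is actually present. A minimal, self-contained sketch of that idiom, outside the repo (the helper select_device_and_fp16 and the torch.nn.Linear stand-in model are illustrative, not ModelScope API):

import torch


def select_device_and_fp16(**kwargs):
    # Prefer the GPU whenever CUDA is visible; otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Honour fp16 only on CUDA, mirroring the patch's
    # "'use_fp16' in kwargs and torch.cuda.is_available()" gate.
    use_fp16 = bool(kwargs.get('use_fp16', False)) and torch.cuda.is_available()
    return device, use_fp16


device, use_fp16 = select_device_and_fp16(use_fp16=True)
model = torch.nn.Linear(4, 4).eval().to(device)  # stand-in for the OFA ensemble
if use_fp16:
    model.half()

With this in place the old opt-in use_cuda kwarg becomes redundant: callers get the GPU automatically whenever one is visible.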

modelscope/models/multi_modal/image_captioning_model.py

@@ -1,6 +1,7 @@
 import os.path as osp
 from typing import Any, Dict
 
+import torch.cuda
 from PIL import Image
 
 from modelscope.metainfo import Models
@@ -26,9 +27,13 @@ class OfaForImageCaptioning(Model):
         self.eval_caption = eval_caption
 
         tasks.register_task('caption', CaptionTask)
-        use_cuda = kwargs['use_cuda'] if 'use_cuda' in kwargs else False
-        use_fp16 = kwargs[
-            'use_fp16'] if 'use_fp16' in kwargs and use_cuda else False
+        if torch.cuda.is_available():
+            self._device = torch.device('cuda')
+        else:
+            self._device = torch.device('cpu')
+        self.use_fp16 = kwargs[
+            'use_fp16'] if 'use_fp16' in kwargs and torch.cuda.is_available()\
+            else False
         overrides = {
             'bpe_dir': bpe_dir,
             'eval_cider': False,
@@ -39,13 +44,11 @@ class OfaForImageCaptioning(Model):
         }
         models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
             utils.split_paths(local_model), arg_overrides=overrides)
-
         # Move models to GPU
         for model in models:
             model.eval()
-            if use_cuda:
-                model.cuda()
-            if use_fp16:
+            model.to(self._device)
+            if self.use_fp16:
                 model.half()
             model.prepare_for_inference_(cfg)
         self.models = models
@@ -68,6 +71,9 @@ class OfaForImageCaptioning(Model):
         self.task = task
 
     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        import fairseq.utils
+        if torch.cuda.is_available():
+            input = fairseq.utils.move_to_cuda(input, device=self._device)
         results, _ = self.eval_caption(self.task, self.generator, self.models,
                                        input)
         return {
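
The new lines in forward defer to fairseq.utils.move_to_cuda, which walks the (possibly nested) input sample and transfers every tensor it finds. A rough, dependency-free sketch of that recursive move, assuming inputs are plain dicts/lists/tuples of tensors (move_to_device is a hypothetical stand-in, not fairseq's implementation):

import torch


def move_to_device(sample, device):
    # Recursively transfer every tensor inside nested dicts/lists/tuples;
    # roughly what fairseq.utils.move_to_cuda does for a CUDA device.
    if torch.is_tensor(sample):
        return sample.to(device)
    if isinstance(sample, dict):
        return {k: move_to_device(v, device) for k, v in sample.items()}
    if isinstance(sample, (list, tuple)):
        return type(sample)(move_to_device(v, device) for v in sample)
    return sample  # non-tensor leaves (ints, strings, None) pass through


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch = {'net_input': {'src_tokens': torch.zeros(1, 8, dtype=torch.long)}}
batch = move_to_device(batch, device)

The is_available() guard in the diff means CPU-only hosts skip the transfer entirely, so the input passes through to eval_caption unchanged.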

