diff --git a/modelscope/pipelines/cv/ocr_recognition_pipeline.py b/modelscope/pipelines/cv/ocr_recognition_pipeline.py index 4b095042..c20d020c 100644 --- a/modelscope/pipelines/cv/ocr_recognition_pipeline.py +++ b/modelscope/pipelines/cv/ocr_recognition_pipeline.py @@ -91,7 +91,8 @@ class OCRRecognitionPipeline(Pipeline): data.append(mask) data = torch.FloatTensor(data).view( - len(data), 1, IMG_HEIGHT, IMG_WIDTH).cuda() / 255. + len(data), 1, IMG_HEIGHT, IMG_WIDTH) / 255. + data = data.to(self.device) result = {'img': data}