From 7b84adc914219afb5eb4173ff80068c31c5b4cdd Mon Sep 17 00:00:00 2001 From: "jiaqi.sjq" Date: Wed, 26 Oct 2022 19:15:43 +0800 Subject: [PATCH] [to #42322933]Fix remove files in local model not take effect to remote repo after push_model Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10533214 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10533214 --- modelscope/hub/api.py | 8 ++++++++ tests/hub/test_hub_upload.py | 15 +++++++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index 00254f16..eacde64a 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -266,6 +266,14 @@ class HubApi: logger.info('Create new branch %s' % revision) git_wrapper.new_branch(tmp_dir, revision) git_wrapper.checkout(tmp_dir, revision) + files_in_repo = os.listdir(tmp_dir) + for f in files_in_repo: + if f[0] != '.': + src = os.path.join(tmp_dir, f) + if os.path.isfile(src): + os.remove(src) + else: + shutil.rmtree(src, ignore_errors=True) for f in files_to_save: if f[0] != '.': src = os.path.join(model_dir, f) diff --git a/tests/hub/test_hub_upload.py b/tests/hub/test_hub_upload.py index e1f61467..835aa62b 100644 --- a/tests/hub/test_hub_upload.py +++ b/tests/hub/test_hub_upload.py @@ -7,7 +7,7 @@ import uuid from modelscope.hub.api import HubApi from modelscope.hub.constants import Licenses, ModelVisibility -from modelscope.hub.errors import HTTPError, NotLoginException +from modelscope.hub.errors import GitError, HTTPError, NotLoginException from modelscope.hub.repository import Repository from modelscope.utils.constant import ModelFile from modelscope.utils.logger import get_logger @@ -97,6 +97,17 @@ class HubUploadTest(unittest.TestCase): revision='new_revision/version1') assert os.path.exists(os.path.join(add4_path, 'add4.py')) shutil.rmtree(self.repo_path, ignore_errors=True) + assert os.path.exists(os.path.join(self.finetune_path, 'add3.py')) + os.remove(os.path.join(self.finetune_path, 'add3.py')) + self.api.push_model( + model_id=self.create_model_name, + model_dir=self.finetune_path, + revision='new_revision/version1') + Repository( + model_dir=self.repo_path, + clone_from=self.create_model_name, + revision='new_revision/version1') + assert not os.path.exists(os.path.join(self.repo_path, 'add3.py')) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_upload_non_exists_repo(self): @@ -133,7 +144,7 @@ class HubUploadTest(unittest.TestCase): def test_upload_invalid_repo(self): logger.info('test upload to invalid repo!') self.api.login(TEST_ACCESS_TOKEN1) - with self.assertRaises(HTTPError): + with self.assertRaises((HTTPError, GitError)): self.api.push_model( model_id='%s/%s' % ('speech_tts', 'invalid_model_test'), model_dir=self.finetune_path,