diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index 00254f16..eacde64a 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -266,6 +266,14 @@ class HubApi: logger.info('Create new branch %s' % revision) git_wrapper.new_branch(tmp_dir, revision) git_wrapper.checkout(tmp_dir, revision) + files_in_repo = os.listdir(tmp_dir) + for f in files_in_repo: + if f[0] != '.': + src = os.path.join(tmp_dir, f) + if os.path.isfile(src): + os.remove(src) + else: + shutil.rmtree(src, ignore_errors=True) for f in files_to_save: if f[0] != '.': src = os.path.join(model_dir, f) diff --git a/tests/hub/test_hub_upload.py b/tests/hub/test_hub_upload.py index e1f61467..835aa62b 100644 --- a/tests/hub/test_hub_upload.py +++ b/tests/hub/test_hub_upload.py @@ -7,7 +7,7 @@ import uuid from modelscope.hub.api import HubApi from modelscope.hub.constants import Licenses, ModelVisibility -from modelscope.hub.errors import HTTPError, NotLoginException +from modelscope.hub.errors import GitError, HTTPError, NotLoginException from modelscope.hub.repository import Repository from modelscope.utils.constant import ModelFile from modelscope.utils.logger import get_logger @@ -97,6 +97,17 @@ class HubUploadTest(unittest.TestCase): revision='new_revision/version1') assert os.path.exists(os.path.join(add4_path, 'add4.py')) shutil.rmtree(self.repo_path, ignore_errors=True) + assert os.path.exists(os.path.join(self.finetune_path, 'add3.py')) + os.remove(os.path.join(self.finetune_path, 'add3.py')) + self.api.push_model( + model_id=self.create_model_name, + model_dir=self.finetune_path, + revision='new_revision/version1') + Repository( + model_dir=self.repo_path, + clone_from=self.create_model_name, + revision='new_revision/version1') + assert not os.path.exists(os.path.join(self.repo_path, 'add3.py')) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_upload_non_exists_repo(self): @@ -133,7 +144,7 @@ class HubUploadTest(unittest.TestCase): def test_upload_invalid_repo(self): logger.info('test upload to invalid repo!') self.api.login(TEST_ACCESS_TOKEN1) - with self.assertRaises(HTTPError): + with self.assertRaises((HTTPError, GitError)): self.api.push_model( model_id='%s/%s' % ('speech_tts', 'invalid_model_test'), model_dir=self.finetune_path,