You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_hub_operation.py 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import uuid

from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
  11. USER_NAME = 'maasadmin'
  12. PASSWORD = '12345678'
  13. model_chinese_name = '达摩卡通化模型'
  14. model_org = 'unittest'
  15. DEFAULT_GIT_PATH = 'git'
  16. download_model_file_name = 'test.bin'
  17. class HubOperationTest(unittest.TestCase):
  18. def setUp(self):
  19. self.old_cwd = os.getcwd()
  20. self.api = HubApi()
  21. # note this is temporary before official account management is ready
  22. self.api.login(USER_NAME, PASSWORD)
  23. self.model_name = uuid.uuid4().hex
  24. self.model_id = '%s/%s' % (model_org, self.model_name)
  25. self.api.create_model(
  26. model_id=self.model_id,
  27. chinese_name=model_chinese_name,
  28. visibility=ModelVisibility.PUBLIC,
  29. license=Licenses.APACHE_V2)
  30. temporary_dir = tempfile.mkdtemp()
  31. self.model_dir = os.path.join(temporary_dir, self.model_name)
  32. repo = Repository(self.model_dir, clone_from=self.model_id)
  33. os.chdir(self.model_dir)
  34. os.system("echo 'testtest'>%s"
  35. % os.path.join(self.model_dir, 'test.bin'))
  36. repo.push('add model', all_files=True)
  37. def tearDown(self):
  38. os.chdir(self.old_cwd)
  39. self.api.delete_model(model_id=self.model_id)
  40. def test_model_repo_creation(self):
  41. # change to proper model names before use
  42. try:
  43. info = self.api.get_model(model_id=self.model_id)
  44. assert info['Name'] == self.model_name
  45. except KeyError as ke:
  46. if ke.args[0] == 'name':
  47. print(f'model {self.model_name} already exists, ignore')
  48. else:
  49. raise
  50. def test_download_single_file(self):
  51. downloaded_file = model_file_download(
  52. model_id=self.model_id, file_path=download_model_file_name)
  53. assert os.path.exists(downloaded_file)
  54. mdtime1 = os.path.getmtime(downloaded_file)
  55. # download again
  56. downloaded_file = model_file_download(
  57. model_id=self.model_id, file_path=download_model_file_name)
  58. mdtime2 = os.path.getmtime(downloaded_file)
  59. assert mdtime1 == mdtime2
  60. def test_snapshot_download(self):
  61. snapshot_path = snapshot_download(model_id=self.model_id)
  62. downloaded_file_path = os.path.join(snapshot_path,
  63. download_model_file_name)
  64. assert os.path.exists(downloaded_file_path)
  65. mdtime1 = os.path.getmtime(downloaded_file_path)
  66. # download again
  67. snapshot_path = snapshot_download(model_id=self.model_id)
  68. mdtime2 = os.path.getmtime(downloaded_file_path)
  69. assert mdtime1 == mdtime2
  70. if __name__ == '__main__':
  71. unittest.main()