Browse Source

add code_generation files

master^2
shuaigezhu 3 years ago
parent
commit
028551cd62
4 changed files with 32 additions and 11 deletions
  1. +2
    -1
      modelscope/models/nlp/__init__.py
  2. +2
    -3
      modelscope/models/nlp/codegeex/codegeex_for_code_generation.py
  3. +10
    -3
      modelscope/pipelines/nlp/codegeex_code_generation_pipeline.py
  4. +18
    -4
      modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py

+ 2
- 1
modelscope/models/nlp/__init__.py View File

@@ -109,7 +109,8 @@ else:
'sentence_embedding': ['SentenceEmbedding'],
'T5': ['T5ForConditionalGeneration'],
'mglm': ['MGLMForTextSummarization'],
'codegeex': ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
'codegeex':
['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
'gpt_neo': ['GPTNeoModel'],
'bloom': ['BloomModel'],
}


+ 2
- 3
modelscope/models/nlp/codegeex/codegeex_for_code_generation.py View File

@@ -65,7 +65,7 @@ class CodeGeeXForCodeGeneration(TorchModel):
bad_ids = None
lang = input['language']
prompt = input['prompt']
prompt = f"# language: {lang}\n{prompt}"
prompt = f'# language: {lang}\n{prompt}'
logger = get_logger()
tokenizer = self.tokenizer
model = self.model
@@ -83,8 +83,7 @@ class CodeGeeXForCodeGeneration(TorchModel):
topk=1,
topp=0.9,
temperature=0.9,
greedy=True
)
greedy=True)
is_finished = [False for _ in range(micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]


+ 10
- 3
modelscope/pipelines/nlp/codegeex_code_generation_pipeline.py View File

@@ -21,7 +21,7 @@ class CodeGeeXCodeGenerationPipeline(Pipeline):
*args,
**kwargs):
model = CodeGeeXForCodeGeneration(model) if isinstance(model,
str) else model
str) else model
self.model = model
self.model.eval()
self.model.half()
@@ -38,8 +38,15 @@ class CodeGeeXCodeGenerationPipeline(Pipeline):
for para in ['prompt', 'language']:
if para not in inputs:
raise Exception('Please check your input format.')
if inputs['language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa
raise Exception('Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa
if inputs['language'] not in [
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
'Pascal', 'R', 'Fortran', 'Lean'
]: # noqa
raise Exception(
'Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa
) # noqa

return self.model(inputs)



+ 18
- 4
modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py View File

@@ -38,11 +38,25 @@ class CodeGeeXCodeTranslationPipeline(Pipeline):
for para in ['prompt', 'source language', 'target language']:
if para not in inputs:
raise Exception('please check your input format.')
if inputs['source language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa
raise Exception('Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa
if inputs['source language'] not in [
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
'Pascal', 'R', 'Fortran', 'Lean'
]:
raise Exception(
'Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa
) # noqa

if inputs['target language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa
raise Exception('Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa
if inputs['target language'] not in [
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
'Pascal', 'R', 'Fortran', 'Lean'
]:
raise Exception(
'Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa
) # noqa

return self.model(inputs)



Loading…
Cancel
Save