Browse Source

add code_generation files

master^2
shuaigezhu 3 years ago
parent
commit
028551cd62
4 changed files with 32 additions and 11 deletions
  1. +2
    -1
      modelscope/models/nlp/__init__.py
  2. +2
    -3
      modelscope/models/nlp/codegeex/codegeex_for_code_generation.py
  3. +10
    -3
      modelscope/pipelines/nlp/codegeex_code_generation_pipeline.py
  4. +18
    -4
      modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py

+ 2
- 1
modelscope/models/nlp/__init__.py View File

@@ -109,7 +109,8 @@ else:
'sentence_embedding': ['SentenceEmbedding'], 'sentence_embedding': ['SentenceEmbedding'],
'T5': ['T5ForConditionalGeneration'], 'T5': ['T5ForConditionalGeneration'],
'mglm': ['MGLMForTextSummarization'], 'mglm': ['MGLMForTextSummarization'],
'codegeex': ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
'codegeex':
['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
'gpt_neo': ['GPTNeoModel'], 'gpt_neo': ['GPTNeoModel'],
'bloom': ['BloomModel'], 'bloom': ['BloomModel'],
} }


+ 2
- 3
modelscope/models/nlp/codegeex/codegeex_for_code_generation.py View File

@@ -65,7 +65,7 @@ class CodeGeeXForCodeGeneration(TorchModel):
bad_ids = None bad_ids = None
lang = input['language'] lang = input['language']
prompt = input['prompt'] prompt = input['prompt']
prompt = f"# language: {lang}\n{prompt}"
prompt = f'# language: {lang}\n{prompt}'
logger = get_logger() logger = get_logger()
tokenizer = self.tokenizer tokenizer = self.tokenizer
model = self.model model = self.model
@@ -83,8 +83,7 @@ class CodeGeeXForCodeGeneration(TorchModel):
topk=1, topk=1,
topp=0.9, topp=0.9,
temperature=0.9, temperature=0.9,
greedy=True
)
greedy=True)
is_finished = [False for _ in range(micro_batch_size)] is_finished = [False for _ in range(micro_batch_size)]
for i, generated in enumerate(token_stream): for i, generated in enumerate(token_stream):
generated_tokens = generated[0] generated_tokens = generated[0]


+ 10
- 3
modelscope/pipelines/nlp/codegeex_code_generation_pipeline.py View File

@@ -21,7 +21,7 @@ class CodeGeeXCodeGenerationPipeline(Pipeline):
*args, *args,
**kwargs): **kwargs):
model = CodeGeeXForCodeGeneration(model) if isinstance(model, model = CodeGeeXForCodeGeneration(model) if isinstance(model,
str) else model
str) else model
self.model = model self.model = model
self.model.eval() self.model.eval()
self.model.half() self.model.half()
@@ -38,8 +38,15 @@ class CodeGeeXCodeGenerationPipeline(Pipeline):
for para in ['prompt', 'language']: for para in ['prompt', 'language']:
if para not in inputs: if para not in inputs:
raise Exception('Please check your input format.') raise Exception('Please check your input format.')
if inputs['language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa
raise Exception('Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa
if inputs['language'] not in [
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
'Pascal', 'R', 'Fortran', 'Lean'
]: # noqa
raise Exception(
'Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa
) # noqa


return self.model(inputs) return self.model(inputs)




+ 18
- 4
modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py View File

@@ -38,11 +38,25 @@ class CodeGeeXCodeTranslationPipeline(Pipeline):
for para in ['prompt', 'source language', 'target language']: for para in ['prompt', 'source language', 'target language']:
if para not in inputs: if para not in inputs:
raise Exception('please check your input format.') raise Exception('please check your input format.')
if inputs['source language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa
raise Exception('Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa
if inputs['source language'] not in [
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
'Pascal', 'R', 'Fortran', 'Lean'
]:
raise Exception(
'Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa
) # noqa


if inputs['target language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa
raise Exception('Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa
if inputs['target language'] not in [
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
'Pascal', 'R', 'Fortran', 'Lean'
]:
raise Exception(
'Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa
) # noqa


return self.model(inputs) return self.model(inputs)




Loading…
Cancel
Save