|
|
|
@@ -203,8 +203,9 @@ def get_policy_tokens(prompt_num_for_policy): |
|
|
|
|
|
|
|
# all special tokens definition |
|
|
|
def get_special_tokens(other_tokens): |
|
|
|
special_tokens = ['<go_r>', '<go_b>', '<go_a>', '<go_d>', |
|
|
|
'<eos_u>', '<eos_r>', '<eos_b>', '<eos_a>', '<eos_d>', '<eos_q>', |
|
|
|
'<sos_u>', '<sos_r>', '<sos_b>', '<sos_a>', '<sos_d>', '<sos_q>'] \ |
|
|
|
+ db_tokens + other_tokens |
|
|
|
special_tokens = [ |
|
|
|
'<go_r>', '<go_b>', '<go_a>', '<go_d>', '<eos_u>', '<eos_r>', |
|
|
|
'<eos_b>', '<eos_a>', '<eos_d>', '<eos_q>', '<sos_u>', '<sos_r>', |
|
|
|
'<sos_b>', '<sos_a>', '<sos_d>', '<sos_q>' |
|
|
|
] + db_tokens + other_tokens |
|
|
|
return special_tokens |