Add NS mode
OPT supports newlines, but it also needs some of the behavior we use in S mode. NS mode is a more limited version of S mode that still handles the </s> token, but instead of replacing it with a newline we replace it with an empty string, and newlines are not converted. In the future, if your Fairseq-style model has newline support, use NS mode; if it needs artificially inserted newlines, use S mode. This also means that people finetuning Fairseq models to include newlines might benefit from testing their models in NS mode.
This commit is contained in:
parent
5c4a087970
commit
8376f12e21
|
@ -144,7 +144,7 @@ optlist = [
|
||||||
["OPT 6.7B", "facebook/opt-6.7b", "16GB"],
|
["OPT 6.7B", "facebook/opt-6.7b", "16GB"],
|
||||||
["OPT 2.7B", "facebook/opt-2.7b", "8GB"],
|
["OPT 2.7B", "facebook/opt-2.7b", "8GB"],
|
||||||
["OPT 1.3B", "facebook/opt-1.3b", "4GB"],
|
["OPT 1.3B", "facebook/opt-1.3b", "4GB"],
|
||||||
["OPT 355M", "facebook/opt-350m", "2GB"],
|
["OPT 350M", "facebook/opt-350m", "2GB"],
|
||||||
["OPT 125M", "facebook/opt-125m", "1GB"],
|
["OPT 125M", "facebook/opt-125m", "1GB"],
|
||||||
["Return to Main Menu", "Return", ""],
|
["Return to Main Menu", "Return", ""],
|
||||||
]
|
]
|
||||||
|
@ -529,6 +529,8 @@ def loadmodelsettings():
|
||||||
js = {}
|
js = {}
|
||||||
if vars.model_type == "xglm" or js.get("compat", "j") == "fairseq_lm":
|
if vars.model_type == "xglm" or js.get("compat", "j") == "fairseq_lm":
|
||||||
vars.newlinemode = "s" # Default to </s> newline mode if using XGLM
|
vars.newlinemode = "s" # Default to </s> newline mode if using XGLM
|
||||||
|
if vars.model_type == "opt":
|
||||||
|
vars.newlinemode = "ns" # Default to </s> newline mode if using XGLM
|
||||||
vars.modelconfig = js
|
vars.modelconfig = js
|
||||||
if("badwordsids" in js):
|
if("badwordsids" in js):
|
||||||
vars.badwordsids = js["badwordsids"]
|
vars.badwordsids = js["badwordsids"]
|
||||||
|
@ -1345,7 +1347,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
||||||
kwargs["logits_warper"] = new_get_logits_warper(
|
kwargs["logits_warper"] = new_get_logits_warper(
|
||||||
beams=1,
|
beams=1,
|
||||||
)
|
)
|
||||||
if(vars.newlinemode == "s"):
|
if(vars.newlinemode == "s") or (vars.newlinemode == "ns"):
|
||||||
kwargs["eos_token_id"] = -1
|
kwargs["eos_token_id"] = -1
|
||||||
kwargs.setdefault("pad_token_id", 2)
|
kwargs.setdefault("pad_token_id", 2)
|
||||||
return new_sample.old_sample(self, *args, **kwargs)
|
return new_sample.old_sample(self, *args, **kwargs)
|
||||||
|
|
2
utils.py
2
utils.py
|
@ -130,6 +130,8 @@ def encodenewlines(txt):
|
||||||
def decodenewlines(txt):
|
def decodenewlines(txt):
|
||||||
if(vars.newlinemode == "s"):
|
if(vars.newlinemode == "s"):
|
||||||
return txt.replace("</s>", '\n')
|
return txt.replace("</s>", '\n')
|
||||||
|
if(vars.newlinemode == "ns"):
|
||||||
|
return txt.replace("</s>", '')
|
||||||
return txt
|
return txt
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
|
Loading…
Reference in New Issue