Model: Port rest of models over

Generation's still broken but it's a start
somebody
2023-02-25 16:05:56 -06:00
parent f8c4158ebc
commit 6b4905de30
3 changed files with 1987 additions and 1760 deletions


@@ -8,6 +8,7 @@ from urllib.error import HTTPError
import requests
import requests.adapters
import time
import breakmodel
from transformers import __version__ as transformers_version
from transformers import PreTrainedModel
import packaging.version
@@ -637,6 +638,7 @@ def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[str]:
class UIProgressBarFile(object):
    """Write TQDM progress to the UI."""
    def write(self, bar):
        bar = bar.replace("\r", "").replace("\n", "").replace(chr(0), "")
        if bar != "" and [ord(num) for num in bar] != [27, 91, 65]:  # Ignore the stray ANSI cursor-up escape sequence (ESC [ A, i.e. 27, 91, 65) so we can move on
@@ -649,4 +651,32 @@ class UIProgressBarFile(object):
                pass
    def flush(self):
        pass
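For reference, a file-like object such as this is normally handed to TQDM through its `file` argument, so progress lines get rerouted to the UI instead of stderr. A minimal sketch, assuming UIProgressBarFile is importable; the loop body is purely illustrative and not part of this diff:

from tqdm import tqdm

# Hypothetical wiring: TQDM writes its progress bar through
# UIProgressBarFile.write() instead of printing to stderr.
for _ in tqdm(range(100), file=UIProgressBarFile()):
    pass  # e.g. loading model shards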
def get_auxilary_device():
    """Get the device that auxiliary tensors, like inputs, should be stored on."""
    # NOTE: TPU isn't a torch device, so TPU stuff gets sent to the CPU.
    if koboldai_vars.hascuda and koboldai_vars.usegpu:
        return koboldai_vars.gpu_device
    elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
        return breakmodel.primary_device
    return "cpu"
#==================================================================#
# Strips submitted text from the text returned by the AI
#==================================================================#
def getnewcontent(txt, tokenizer):
    # If the submitted context was blank, then everything is new
    if koboldai_vars.lastctx == "":
        return txt

    # Tokenize the last context and the generated content
    ctxtokens = tokenizer.encode(encodenewlines(koboldai_vars.lastctx), max_length=int(2e9), truncation=True)
    txttokens = tokenizer.encode(encodenewlines(txt), max_length=int(2e9), truncation=True)

    # Negative offset: the last (len(txttokens) - len(ctxtokens)) tokens are the new ones
    dif = (len(txttokens) - len(ctxtokens)) * -1

    # Remove the context from the returned text
    newtokens = txttokens[dif:]
    return decodenewlines(tokenizer.decode(newtokens))
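To make the negative-index slice concrete: if the context tokenizes to 5 tokens and the full return to 8, dif is -3, so txttokens[-3:] keeps only the 3 newly generated tokens. A self-contained sketch of the same trick with a stock Hugging Face tokenizer (the strings and model name are only examples, and this assumes the full output tokenizes to a sequence that begins with the context's tokens):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

context = "Once upon a time"
full_output = "Once upon a time there was a knight."

ctxtokens = tokenizer.encode(context)
txttokens = tokenizer.encode(full_output)

# Same arithmetic as getnewcontent: a negative offset into the token list
dif = (len(txttokens) - len(ctxtokens)) * -1
print(tokenizer.decode(txttokens[dif:]))  # -> " there was a knight."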