Update aiserver.py

This commit is contained in:
catboxanon
2023-03-09 22:36:45 -05:00
committed by GitHub
parent f761c4dafa
commit 8c9ed55406

View File

@@ -87,6 +87,38 @@ from io import BytesIO
global tpu_mtj_backend
from transformers.models.llama.tokenization_llama import LLaMATokenizer
from repos.gptq.gptq import *
from repos.gptq.modelutils import *
from repos.gptq.quant import *
def load_quant(model, checkpoint, wbits):
from transformers import LLaMAConfig, LLaMAForCausalLM
config = LLaMAConfig.from_pretrained(model)
def noop(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = LLaMAForCausalLM(config)
torch.set_default_dtype(torch.float)
model = model.eval()
layers = find_layers(model)
for name in ['lm_head']:
if name in layers:
del layers[name]
make_quant(model, layers, wbits)
print('Loading model ...')
model.load_state_dict(torch.load(checkpoint))
model.seqlen = 2048
print('Done.')
return model
if lupa.LUA_VERSION[:2] != (5, 4):
logger.error(f"Please install lupa==1.10. You have lupa {lupa.__version__}.")
@@ -2886,7 +2918,10 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
@functools.lru_cache(maxsize=None)
def get_original_key(key):
try:
return max((original_key for original_key in utils.module_names if original_key.endswith(key)), key=len)
except ValueError:
return key
for key, value in model_dict.items():
original_key = get_original_key(key)
@@ -3083,22 +3118,24 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
if(koboldai_vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
lowmem = {}
if(os.path.isdir(koboldai_vars.custmodpth)):
tokenizer = LLaMATokenizer.from_pretrained(koboldai_vars.custmodpth)
# try:
# tokenizer = AutoTokenizer.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache", use_fast=False)
# except Exception as e:
# try:
# tokenizer = AutoTokenizer.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache")
# except Exception as e:
# try:
# tokenizer = GPT2Tokenizer.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache")
# except Exception as e:
# tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=koboldai_vars.revision, cache_dir="cache")
try:
tokenizer = AutoTokenizer.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache", use_fast=False)
except Exception as e:
try:
tokenizer = AutoTokenizer.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache")
except Exception as e:
try:
tokenizer = GPT2Tokenizer.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=koboldai_vars.revision, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache", **lowmem)
# model = AutoModelForCausalLM.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache", **lowmem)
model = load_quant(koboldai_vars.custmodpth, os.environ['LLAMA_30B_4BIT'], 4)
except Exception as e:
if("out of memory" in traceback.format_exc().lower()):
raise RuntimeError("One of your GPUs ran out of memory when KoboldAI tried to load your model.")
model = GPTNeoForCausalLM.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache", **lowmem)
# model = GPTNeoForCausalLM.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache", **lowmem)
elif(os.path.isdir("models/{}".format(koboldai_vars.model.replace('/', '_')))):
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(koboldai_vars.model.replace('/', '_')), revision=koboldai_vars.revision, cache_dir="cache", use_fast=False)