mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge branch 'united' of https://github.com/henk717/KoboldAI into fixing-time
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
import os, sys
|
||||
from typing import Optional
|
||||
from transformers import AutoConfig
|
||||
try:
|
||||
from hf_bleeding_edge import AutoConfig
|
||||
except ImportError:
|
||||
from transformers import AutoConfig
|
||||
|
||||
import warnings
|
||||
import utils
|
||||
import json
|
||||
@@ -335,6 +339,11 @@ class HFInferenceModel(InferenceModel):
|
||||
if any(c in str(k) for c in "[]")
|
||||
]
|
||||
|
||||
try:
|
||||
self.badwordsids.remove([self.tokenizer.pad_token_id])
|
||||
except:
|
||||
pass
|
||||
|
||||
if utils.koboldai_vars.newlinemode == "n":
|
||||
self.badwordsids.append([self.tokenizer.eos_token_id])
|
||||
|
||||
@@ -387,7 +396,17 @@ class HFInferenceModel(InferenceModel):
|
||||
revision=utils.koboldai_vars.revision,
|
||||
cache_dir="cache",
|
||||
)
|
||||
|
||||
self.model_type = self.model_config.model_type
|
||||
|
||||
if "gptq_bits" in dir(self.model_config):
|
||||
self.gptq_model = True
|
||||
self.gptq_bits = self.model_config.gptq_bits
|
||||
self.gptq_groupsize = self.model_config.gptq_groupsize if getattr(self.model_config, "gptq_groupsize", False) else -1
|
||||
self.gptq_version = self.model_config.gptq_version if getattr(self.model_config, "gptq_version", False) else 1
|
||||
self.gptq_file = None
|
||||
else:
|
||||
self.gptq_model = False
|
||||
except ValueError:
|
||||
self.model_type = {
|
||||
"NeoCustom": "gpt_neo",
|
||||
@@ -398,4 +417,4 @@ class HFInferenceModel(InferenceModel):
|
||||
logger.warning(
|
||||
"No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)"
|
||||
)
|
||||
self.model_type = "gpt_neo"
|
||||
self.model_type = "gpt_neo"
|
||||
|
Reference in New Issue
Block a user