Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Merge pull request #290 from one-some/ui2-logit-things-are-just-too-invasive-get-them-away-NOW
Add context manager for sandboxing invasive transformers patches
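In short: instead of permanently monkey-patching `transformers.generation_utils.GenerationMixin` at startup, the patches for logit processing, sampling and stopping criteria are now swapped in on `__enter__` and swapped back out on `__exit__`, so they only apply while the core text model is generating. A minimal, self-contained sketch of that pattern (the `Generator` and `scoped_patch` names here are hypothetical stand-ins; the real class added by this diff is `use_core_manipulations` in aiserver.py):

```python
# Illustrative sketch only: Generator and patched_sample are made-up stand-ins
# for GenerationMixin and the KoboldAI replacement hooks.

class Generator:
    def sample(self):
        return "stock sample"


def patched_sample(self):
    return "kobold sample"


class scoped_patch:
    """Swap Generator.sample on enter, restore the original on exit."""

    def __enter__(self):
        # Remember the original method, then install the replacement.
        scoped_patch.old_sample = Generator.sample
        Generator.sample = patched_sample
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Put the original back once the block is left.
        Generator.sample = scoped_patch.old_sample


with scoped_patch():
    assert Generator().sample() == "kobold sample"  # patch active only here
assert Generator().sample() == "stock sample"       # original restored
```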
Changed file: aiserver.py (148 changed lines)
@@ -113,6 +113,34 @@ def new_pretrainedtokenizerbase_from_pretrained(cls, *args, **kwargs):
     return tokenizer
 PreTrainedTokenizerBase.from_pretrained = new_pretrainedtokenizerbase_from_pretrained
 
+# We only want to use logit manipulations and such on our core text model
+class use_core_manipulations:
+    # These must be set by wherever they get setup
+    get_logits_processor: callable
+    sample: callable
+    get_stopping_criteria: callable
+
+    # We set these automatically
+    old_get_logits_processor: callable
+    old_sample: callable
+    old_get_stopping_criteria: callable
+
+    def __enter__(self):
+        use_core_manipulations.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
+        transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.get_logits_processor
+
+        use_core_manipulations.old_sample = transformers.generation_utils.GenerationMixin.sample
+        transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.sample
+
+        use_core_manipulations.old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.get_stopping_criteria
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.old_get_logits_processor
+        transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.old_sample
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.old_get_stopping_criteria
+
 #==================================================================#
 # Variables & Storage
 #==================================================================#
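Note that the state lives in class attributes rather than on instances, so the manager is used as a bare `with use_core_manipulations():` with no arguments. The `get_logits_processor`, `sample` and `get_stopping_criteria` hooks have to be assigned (by `patch_transformers()`, in the hunks below) before the block is first entered, and `__exit__` puts the stock `GenerationMixin` methods back unconditionally.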
@@ -1910,8 +1938,6 @@ def patch_transformers_download():
 
 def patch_transformers():
     global transformers
-    global old_transfomers_functions
-    old_transfomers_functions = {}
 
     patch_transformers_download()
 
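The `old_transfomers_functions` dict existed so other code paths could temporarily restore the stock transformers methods (its remaining use is in the `summarize()` hunk at the bottom of this diff). With the patches now scoped by `use_core_manipulations`, that bookkeeping is dropped here and in the hunks below.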
@@ -1933,7 +1959,6 @@ def patch_transformers():
         PreTrainedModel._kai_patched = True
     if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
         old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
-        old_transfomers_functions['modeling_utils.get_checkpoint_shard_files'] = old_get_checkpoint_shard_files
         def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
             utils.num_shards = utils.get_num_shards(index_filename)
             utils.from_pretrained_index_filename = index_filename
@@ -1961,7 +1986,6 @@ def patch_transformers():
             if max_pos > self.weights.size(0):
                 self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
             return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
-        old_transfomers_functions['XGLMSinusoidalPositionalEmbedding.forward'] = XGLMSinusoidalPositionalEmbedding.forward
         XGLMSinusoidalPositionalEmbedding.forward = new_forward
 
 
@@ -1981,7 +2005,6 @@ def patch_transformers():
             self.model = OPTModel(config)
             self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
             self.post_init()
-        old_transfomers_functions['OPTForCausalLM.__init__'] = OPTForCausalLM.__init__
         OPTForCausalLM.__init__ = new_init
 
 
@@ -2170,8 +2193,8 @@ def patch_transformers():
         processors.append(PhraseBiasLogitsProcessor())
         processors.append(ProbabilityVisualizerLogitsProcessor())
         return processors
+    use_core_manipulations.get_logits_processor = new_get_logits_processor
     new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
-    transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
 
     class KoboldLogitsWarperList(LogitsProcessorList):
         def __init__(self, beams: int = 1, **kwargs):
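The swap of `_get_logits_processor` itself no longer happens here: the replacement is only registered on `use_core_manipulations`, while `new_get_logits_processor.old_get_logits_processor` still captures the stock implementation so the wrapper can delegate to it. The actual patching is deferred to `use_core_manipulations.__enter__`.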
@@ -2204,9 +2227,9 @@ def patch_transformers():
             kwargs["eos_token_id"] = -1
             kwargs.setdefault("pad_token_id", 2)
         return new_sample.old_sample(self, *args, **kwargs)
-    new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
-    transformers.generation_utils.GenerationMixin.sample = new_sample
 
+    new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
+    use_core_manipulations.sample = new_sample
 
     # Allow bad words filter to ban <|endoftext|> token
     import transformers.generation_logits_process
@@ -2374,7 +2397,7 @@ def patch_transformers():
 
 
     old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
-    old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] = old_get_stopping_criteria
+
    def new_get_stopping_criteria(self, *args, **kwargs):
        global tokenizer
        stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
@@ -2392,7 +2415,7 @@ def patch_transformers():
            stopping_criteria.insert(0, token_streamer)
            stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
        return stopping_criteria
-    transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
+    use_core_manipulations.get_stopping_criteria = new_get_stopping_criteria
 
 def reset_model_settings():
     koboldai_vars.socketio = socketio
@@ -5395,56 +5418,57 @@ def raw_generate(
     result: GenerationResult
     time_start = time.time()
 
-    if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
-        batch_encoded = tpu_raw_generate(
-            prompt_tokens=prompt_tokens,
-            max_new=max_new,
-            batch_count=batch_count,
-            gen_settings=gen_settings
-        )
-        result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
-        )
-    elif koboldai_vars.model in model_functions:
-        batch_encoded = model_functions[koboldai_vars.model](
-            prompt_tokens=prompt_tokens,
-            max_new=max_new,
-            batch_count=batch_count,
-            gen_settings=gen_settings
-        )
-        result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
-        )
-    elif koboldai_vars.model.startswith("RWKV"):
-        batch_encoded = rwkv_raw_generate(
-            prompt_tokens=prompt_tokens,
-            max_new=max_new,
-            batch_count=batch_count,
-            gen_settings=gen_settings
-        )
-        result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
-        )
-    else:
-        # Torch HF
-        start_time = time.time()
-        batch_encoded = torch_raw_generate(
-            prompt_tokens=prompt_tokens,
-            max_new=max_new if not bypass_hf_maxlength else int(2e9),
-            do_streaming=do_streaming,
-            do_dynamic_wi=do_dynamic_wi,
-            batch_count=batch_count,
-            gen_settings=gen_settings
-        )
-        logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
-        start_time = time.time()
-        result = GenerationResult(
-            out_batches=batch_encoded,
-            prompt=prompt_tokens,
-            is_whole_generation=False,
-            output_includes_prompt=True,
-        )
-        logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time))
+    with use_core_manipulations():
+        if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
+            batch_encoded = tpu_raw_generate(
+                prompt_tokens=prompt_tokens,
+                max_new=max_new,
+                batch_count=batch_count,
+                gen_settings=gen_settings
+            )
+            result = GenerationResult(
+                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
+            )
+        elif koboldai_vars.model in model_functions:
+            batch_encoded = model_functions[koboldai_vars.model](
+                prompt_tokens=prompt_tokens,
+                max_new=max_new,
+                batch_count=batch_count,
+                gen_settings=gen_settings
+            )
+            result = GenerationResult(
+                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
+            )
+        elif koboldai_vars.model.startswith("RWKV"):
+            batch_encoded = rwkv_raw_generate(
+                prompt_tokens=prompt_tokens,
+                max_new=max_new,
+                batch_count=batch_count,
+                gen_settings=gen_settings
+            )
+            result = GenerationResult(
+                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
+            )
+        else:
+            # Torch HF
+            start_time = time.time()
+            batch_encoded = torch_raw_generate(
+                prompt_tokens=prompt_tokens,
+                max_new=max_new if not bypass_hf_maxlength else int(2e9),
+                do_streaming=do_streaming,
+                do_dynamic_wi=do_dynamic_wi,
+                batch_count=batch_count,
+                gen_settings=gen_settings
+            )
+            logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
+            start_time = time.time()
+            result = GenerationResult(
+                out_batches=batch_encoded,
+                prompt=prompt_tokens,
+                is_whole_generation=False,
+                output_includes_prompt=True,
+            )
+            logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time))
 
     time_end = round(time.time() - time_start, 2)
     tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
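This is the payoff hunk: the whole generation dispatch now runs inside `with use_core_manipulations():`, so the invasive patches are active only for the core model, and `__exit__` restores the stock methods even if generation raises. A small self-contained check of that restore-on-exception behaviour, using a hypothetical stand-in class rather than the real `GenerationMixin`:

```python
# Hypothetical stand-ins; mirrors the enter/exit shape of use_core_manipulations.

class Dummy:
    def sample(self):
        return "stock"


class patch_dummy:
    def __enter__(self):
        # Save the original and install the replacement on the class.
        patch_dummy.old_sample = Dummy.sample
        Dummy.sample = lambda self: "patched"
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Runs even when the body raised, so the patch never leaks out.
        Dummy.sample = patch_dummy.old_sample


try:
    with patch_dummy():
        assert Dummy().sample() == "patched"
        raise RuntimeError("generation failed")
except RuntimeError:
    pass

assert Dummy().sample() == "stock"  # restored despite the exception
```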
@@ -9384,14 +9408,10 @@ def summarize(text, max_length=100, min_length=30, unload=True):
 
     #Actual sumarization
     start_time = time.time()
-    global old_transfomers_functions
-    temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
-    transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
     #make sure text is less than 1024 tokens, otherwise we'll crash
     if len(koboldai_vars.summary_tokenizer.encode(text)) > 1000:
         text = koboldai_vars.summary_tokenizer.decode(koboldai_vars.summary_tokenizer.encode(text)[:1000])
     output = tpool.execute(summarizer, text, max_length=max_length, min_length=min_length, do_sample=False)[0]['summary_text']
-    transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
     logger.debug("Time to summarize: {}".format(time.time()-start_time))
     #move model back to CPU to save precious vram
     torch.cuda.empty_cache()
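Because the stopping-criteria patch is no longer installed globally, `summarize()` already sees the stock `_get_stopping_criteria` outside any `with use_core_manipulations():` block, so the temporary save-and-restore around the summarizer call (and the `old_transfomers_functions` lookup it depended on) can simply be removed.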
|