Remove torch dependency and more generic dim0 workaround
Remove the torch dependency from hf.py. Make the workaround for dimension-zero values of token_ids more generic, so that every token is handled correctly, not just newlines.
hf.py

@@ -1,7 +1,6 @@
 import os
 from typing import Optional
 from transformers import AutoConfig
-import torch
 
 import utils
 import koboldai_settings
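The `hasattr(token_ids, 'dim')` check in the next hunk is what lets `import torch` go away: instead of naming `torch.Tensor` explicitly, the wrapper duck-types on the `.dim()` method that tensors expose. A minimal sketch of the pattern, with an illustrative function name that is not from the commit:

def normalize_token_ids(token_ids):
    """Coerce a scalar token id into a one-element list (illustrative sketch)."""
    if isinstance(token_ids, int):
        # Plain Python int: wrap it so decode() sees a sequence.
        return [token_ids]
    if hasattr(token_ids, "dim") and token_ids.dim() == 0:
        # Duck-typed tensor check: anything exposing .dim() (e.g. a
        # torch.Tensor) is treated as a tensor, so torch itself never
        # needs to be imported.
        return [int(token_ids.item())]
    return token_ids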
@@ -41,16 +40,24 @@ class HFInferenceModel(InferenceModel):
         original_decode = type(self.tokenizer.tokenizer).decode
         def decode_wrapper(self, token_ids, *args, **kwargs):
             first = None
-            dim0 = False
+            # Note, the code below that wraps single-value token_ids in a list
+            # is to work around this wonky behavior:
+            # >>> t.decode(13)
+            # '<0x0A>'
+            # >>> t.decode([13])
+            # '\n'
+            # Not doing this causes token streaming to receive <0x0A> characters
+            # instead of newlines.
             if isinstance(token_ids, int):
                 first = token_ids
-                dim0 = True
+                token_ids = [first]
-            elif isinstance(token_ids, torch.Tensor):
+            elif hasattr(token_ids, 'dim'):  # Check for e.g. torch.Tensor
                 # Tensors don't support the Python standard of 'empty is False'
-                # and the special case of dimension 0 tensors also needs to be handled separately.
+                # and the special case of dimension 0 tensors also needs to be
+                # handled separately.
                 if token_ids.dim() == 0:
                     first = int(token_ids.item())
-                    dim0 = True
+                    token_ids = [first]
                 elif len(token_ids) > 0:
                     first = int(token_ids[0])
             elif token_ids:
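The wonky behavior described by the new comment block is easy to reproduce with a SentencePiece tokenizer that uses byte-fallback pieces, where id 13 is the newline byte `<0x0A>`. A minimal repro, assuming a Llama-style tokenizer; the checkpoint name below is only an example, not one the commit uses:

from transformers import AutoTokenizer

# Any Llama-style SentencePiece tokenizer with byte-fallback pieces
# shows the asymmetry; this checkpoint name is illustrative.
t = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

print(repr(t.decode(13)))    # '<0x0A>'  - the raw byte-fallback piece
print(repr(t.decode([13])))  # '\n'      - the actual newline character

Wrapping the scalar in a list before calling the original decode fixes every byte-fallback token at once, which is why the old `dim0` flag and the post-hoc `<0x0A>` string replacement (removed in the next hunk) are no longer needed.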
@@ -58,14 +65,6 @@ class HFInferenceModel(InferenceModel):
             result = original_decode(self, token_ids, *args, **kwargs)
             if first is not None and first in has_prefix_space:
                 result = " " + result
-            if dim0:
-                # Work around this wonky behavior:
-                # >>> t.decode(13)
-                # '<0x0A>'
-                # >>> t.decode([13])
-                # '\n'
-                # Not doing this causes token streaming to receive <0x0A> characters instead of newlines.
-                result = result.replace('<0x0A>', '\n')
             return result
         # GenericTokenizer overrides __setattr__ so we need to use object.__setattr__ to bypass it
         object.__setattr__(self.tokenizer, 'decode', decode_wrapper.__get__(self.tokenizer))
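The `object.__setattr__` / `__get__` dance in the last two context lines patches `decode` on this one tokenizer instance rather than on its class: `__get__` turns the plain function into a method bound to `self.tokenizer`, and `object.__setattr__` bypasses `GenericTokenizer`'s custom `__setattr__`. A self-contained sketch of the pattern, using a stand-in class rather than KoboldAI's:

class Toy:
    def decode(self, token_ids):
        return f"decoded {token_ids}"

original_decode = Toy.decode

def decode_wrapper(self, token_ids, *args, **kwargs):
    # Normalize scalars to a list, then delegate to the unbound original.
    if isinstance(token_ids, int):
        token_ids = [token_ids]
    return original_decode(self, token_ids, *args, **kwargs)

t = Toy()
# __get__(t) binds the function to this instance, so only `t` is patched;
# object.__setattr__ also sidesteps any custom __setattr__ defined on Toy.
object.__setattr__(t, "decode", decode_wrapper.__get__(t))
print(t.decode(13))  # decoded [13]

Patching per-instance keeps the wrapper local to the tokenizer KoboldAI actually holds, instead of mutating behavior for every tokenizer of that class.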