diff --git a/modeling/inference_models/hf.py b/modeling/inference_models/hf.py
index 3f98f381..61f030b1 100644
--- a/modeling/inference_models/hf.py
+++ b/modeling/inference_models/hf.py
@@ -1,7 +1,6 @@
 import os
 from typing import Optional
 from transformers import AutoConfig
-import torch
 
 import utils
 import koboldai_settings
@@ -41,16 +40,24 @@ class HFInferenceModel(InferenceModel):
         original_decode = type(self.tokenizer.tokenizer).decode
         def decode_wrapper(self, token_ids, *args, **kwargs):
             first = None
-            dim0 = False
+            # Note, the code below that wraps single-value token_ids in a list
+            # is to work around this wonky behavior:
+            # >>> t.decode(13)
+            # '<0x0A>'
+            # >>> t.decode([13])
+            # '\n'
+            # Not doing this causes token streaming to receive <0x0A> characters
+            # instead of newlines.
             if isinstance(token_ids, int):
                 first = token_ids
-                dim0 = True
-            elif isinstance(token_ids, torch.Tensor):
+                token_ids = [first]
+            elif hasattr(token_ids, 'dim'):  # Check for e.g. torch.Tensor
                 # Tensors don't support the Python standard of 'empty is False'
-                # and the special case of dimension 0 tensors also needs to be handled separately.
+                # and the special case of dimension 0 tensors also needs to be
+                # handled separately.
                 if token_ids.dim() == 0:
                     first = int(token_ids.item())
-                    dim0 = True
+                    token_ids = [first]
                 elif len(token_ids) > 0:
                     first = int(token_ids[0])
             elif token_ids:
@@ -58,14 +65,6 @@ class HFInferenceModel(InferenceModel):
             result = original_decode(self, token_ids, *args, **kwargs)
             if first is not None and first in has_prefix_space:
                 result = " " + result
-            if dim0:
-                # Work around this wonky behavior:
-                # >>> t.decode(13)
-                # '<0x0A>'
-                # >>> t.decode([13])
-                # '\n'
-                # Not doing this causes token streaming to receive <0x0A> characters instead of newlines.
-                result = result.replace('<0x0A>', '\n')
             return result
         # GenericTokenizer overrides __setattr__ so we need to use object.__setattr__ to bypass it
         object.__setattr__(self.tokenizer, 'decode', decode_wrapper.__get__(self.tokenizer))
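
For reference, here is a minimal sketch of the normalization strategy the patch adopts, kept outside the diff so the patch above stays intact. Instead of decoding a scalar token and replacing `<0x0A>` in the output string afterwards, the input is wrapped in a list before decoding. The helper name `normalize_token_ids` and the standalone form are hypothetical, used only for illustration; the decode return values shown in the comments are the ones documented in the diff's own notes.

```python
def normalize_token_ids(token_ids):
    """Wrap scalar inputs (a bare int or a 0-dim tensor) in a list so that
    decode() treats them as a sequence rather than a single byte token."""
    if isinstance(token_ids, int):
        return [token_ids]
    # Duck-typed tensor check, mirroring hasattr(token_ids, 'dim') in the
    # diff, so torch does not need to be imported just for isinstance().
    if hasattr(token_ids, "dim") and token_ids.dim() == 0:
        return [int(token_ids.item())]
    return token_ids


# Illustrative usage, with values taken from the diff's comment:
#   tok.decode(13)                      -> '<0x0A>'
#   tok.decode(normalize_token_ids(13)) -> '\n'
```

The `hasattr(token_ids, 'dim')` check is what lets the top-level `import torch` be dropped: the wrapper only needs tensor-like behavior, not the torch type itself.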