Remove torch dependency and make the dim0 workaround more generic

Remove the torch dependency from hf.py.
Make the workaround for dimension-zero values of token_ids
more generic so that it handles every token, not just newlines.
This commit is contained in:
Llama
2023-05-03 09:48:16 -07:00
parent 3768848548
commit 35d344b951

View File

@@ -1,7 +1,6 @@
import os
from typing import Optional
from transformers import AutoConfig
import torch
import utils
import koboldai_settings
@@ -41,16 +40,24 @@ class HFInferenceModel(InferenceModel):
original_decode = type(self.tokenizer.tokenizer).decode
def decode_wrapper(self, token_ids, *args, **kwargs):
first = None
dim0 = False
# Note, the code below that wraps single-value token_ids in a list
# is to work around this wonky behavior:
# >>> t.decode(13)
# '<0x0A>'
# >>> t.decode([13])
# '\n'
# Not doing this causes token streaming to receive <0x0A> characters
# instead of newlines.
if isinstance(token_ids, int):
first = token_ids
dim0 = True
elif isinstance(token_ids, torch.Tensor):
token_ids = [first]
elif hasattr(token_ids, 'dim'): # Check for e.g. torch.Tensor
# Tensors don't support the Python standard of 'empty is False'
# and the special case of dimension 0 tensors also needs to be handled separately.
# and the special case of dimension 0 tensors also needs to be
# handled separately.
if token_ids.dim() == 0:
first = int(token_ids.item())
dim0 = True
token_ids = [first]
elif len(token_ids) > 0:
first = int(token_ids[0])
elif token_ids:
@@ -58,14 +65,6 @@ class HFInferenceModel(InferenceModel):
result = original_decode(self, token_ids, *args, **kwargs)
if first is not None and first in has_prefix_space:
result = " " + result
if dim0:
# Work around this wonky behavior:
# >>> t.decode(13)
# '<0x0A>'
# >>> t.decode([13])
# '\n'
# Not doing this causes token streaming to receive <0x0A> characters instead of newlines.
result = result.replace('<0x0A>', '\n')
return result
# GenericTokenizer overrides __setattr__ so we need to use object.__setattr__ to bypass it
object.__setattr__(self.tokenizer, 'decode', decode_wrapper.__get__(self.tokenizer))