mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Remove torch dependency and make the dim0 workaround more generic
Remove the torch dependency from hf.py. Make the workaround for dimension-zero values of token_ids more generic so that it handles every token, not just newlines.
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from transformers import AutoConfig
|
||||
import torch
|
||||
|
||||
import utils
|
||||
import koboldai_settings
|
||||
@@ -41,16 +40,24 @@ class HFInferenceModel(InferenceModel):
|
||||
original_decode = type(self.tokenizer.tokenizer).decode
|
||||
def decode_wrapper(self, token_ids, *args, **kwargs):
|
||||
first = None
|
||||
dim0 = False
|
||||
# Note, the code below that wraps single-value token_ids in a list
|
||||
# is to work around this wonky behavior:
|
||||
# >>> t.decode(13)
|
||||
# '<0x0A>'
|
||||
# >>> t.decode([13])
|
||||
# '\n'
|
||||
# Not doing this causes token streaming to receive <0x0A> characters
|
||||
# instead of newlines.
|
||||
if isinstance(token_ids, int):
|
||||
first = token_ids
|
||||
dim0 = True
|
||||
elif isinstance(token_ids, torch.Tensor):
|
||||
token_ids = [first]
|
||||
elif hasattr(token_ids, 'dim'): # Check for e.g. torch.Tensor
|
||||
# Tensors don't support the Python standard of 'empty is False'
|
||||
# and the special case of dimension 0 tensors also needs to be handled separately.
|
||||
# and the special case of dimension 0 tensors also needs to be
|
||||
# handled separately.
|
||||
if token_ids.dim() == 0:
|
||||
first = int(token_ids.item())
|
||||
dim0 = True
|
||||
token_ids = [first]
|
||||
elif len(token_ids) > 0:
|
||||
first = int(token_ids[0])
|
||||
elif token_ids:
|
||||
@@ -58,14 +65,6 @@ class HFInferenceModel(InferenceModel):
|
||||
result = original_decode(self, token_ids, *args, **kwargs)
|
||||
if first is not None and first in has_prefix_space:
|
||||
result = " " + result
|
||||
if dim0:
|
||||
# Work around this wonky behavior:
|
||||
# >>> t.decode(13)
|
||||
# '<0x0A>'
|
||||
# >>> t.decode([13])
|
||||
# '\n'
|
||||
# Not doing this causes token streaming to receive <0x0A> characters instead of newlines.
|
||||
result = result.replace('<0x0A>', '\n')
|
||||
return result
|
||||
# GenericTokenizer overrides __setattr__ so we need to use object.__setattr__ to bypass it
|
||||
object.__setattr__(self.tokenizer, 'decode', decode_wrapper.__get__(self.tokenizer))
|
||||
|
Reference in New Issue
Block a user