diff --git a/aiserver.py b/aiserver.py index 6614b56e..c1db53e3 100644 --- a/aiserver.py +++ b/aiserver.py @@ -923,7 +923,12 @@ def load_lua_scripts(): # Decode tokens into a string using current tokenizer #==================================================================# def lua_decode(tokens): + tokens = list(tokens.values()) assert type(tokens) is list + if("tokenizer" not in globals()): + from transformers import GPT2TokenizerFast + global tokenizer + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") return tokenizer.decode(tokens) #==================================================================# @@ -931,6 +936,11 @@ def lua_decode(tokens): #==================================================================# def lua_encode(string): assert type(string) is str + if("tokenizer" not in globals()): + thinking = False + from transformers import GPT2TokenizerFast + global tokenizer + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") return tokenizer.encode(string, max_length=int(4e9), truncation=True) #==================================================================# diff --git a/bridge.lua b/bridge.lua index 31423f85..3d589115 100644 --- a/bridge.lua +++ b/bridge.lua @@ -666,10 +666,8 @@ return function(_python, _bridged) return end local encoded = {} - local i = 1 - for token in _python.iter(bridged.encode(str)) do - encoded[i] = token - i = i + 1 + for i, token in _python.enumerate(bridged.encode(str)) do + encoded[i+1] = math.tointeger(token) end return encoded end @@ -681,17 +679,20 @@ return function(_python, _bridged) error("`decode` takes a number or table of numbers as argument, but got a " .. type(tok)) return end - if type(tok) ~= "number" then + if type(tok) == "number" then tok = {tok} end + local _tok = {} + local _v for k, v in ipairs(tok) do - tok[k] = math.tointeger(v) - if tok[k] == nil then + _v = math.tointeger(v) + if _v == nil then error "`decode` got a table with one or more non-integer values" return end + _tok[k] = _v end - return bridged.decode(_python.builtins.list(tok)) + return bridged.decode(_tok) end ---@return nil