Fix the Lua tokenizer API

This commit is contained in:
Gnome Ann 2021-12-11 21:24:34 -05:00
parent 67974947b2
commit 8e6a62259e
2 changed files with 19 additions and 8 deletions

View File

@ -923,7 +923,12 @@ def load_lua_scripts():
# Decode tokens into a string using current tokenizer # Decode tokens into a string using current tokenizer
#==================================================================# #==================================================================#
def lua_decode(tokens):
    """Decode a collection of token ids into a string with the current tokenizer.

    tokens -- either a Python list of token ids, or a table-like object
              exposing ``.values()`` (presumably the Lua table handed over
              by the bridge -- confirm against the caller). For the
              table-like case the ids are taken in the order ``.values()``
              yields them.

    Returns the result of ``tokenizer.decode`` for the collected ids.

    Raises AssertionError if the argument cannot be reduced to a list.
    """
    global tokenizer
    # Plain lists are passed through untouched; anything else is expected to
    # behave like a mapping, so its values are flattened into a real list.
    if type(tokens) is not list:
        tokens = list(tokens.values())
    assert type(tokens) is list
    # No tokenizer has been set at module level yet: fall back to the stock
    # GPT-2 tokenizer and cache it as the module-level tokenizer.
    if "tokenizer" not in globals():
        from transformers import GPT2TokenizerFast
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    return tokenizer.decode(tokens)
#==================================================================# #==================================================================#
@ -931,6 +936,11 @@ def lua_decode(tokens):
#==================================================================# #==================================================================#
def lua_encode(string):
    """Encode a string into token ids with the current tokenizer.

    string -- the text to tokenize; must be exactly of type ``str``.

    Returns the result of ``tokenizer.encode`` called with truncation
    enabled and a max_length of int(4e9).

    Raises AssertionError if ``string`` is not a ``str``.
    """
    global tokenizer
    assert type(string) is str
    # No tokenizer has been set at module level yet: fall back to the stock
    # GPT-2 tokenizer and cache it as the module-level tokenizer.
    if "tokenizer" not in globals():
        from transformers import GPT2TokenizerFast
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    return tokenizer.encode(string, max_length=int(4e9), truncation=True)
#==================================================================# #==================================================================#

View File

@ -666,10 +666,8 @@ return function(_python, _bridged)
return return
end end
local encoded = {} local encoded = {}
local i = 1 for i, token in _python.enumerate(bridged.encode(str)) do
for token in _python.iter(bridged.encode(str)) do encoded[i+1] = math.tointeger(token)
encoded[i] = token
i = i + 1
end end
return encoded return encoded
end end
@ -681,17 +679,20 @@ return function(_python, _bridged)
error("`decode` takes a number or table of numbers as argument, but got a " .. type(tok)) error("`decode` takes a number or table of numbers as argument, but got a " .. type(tok))
return return
end end
if type(tok) ~= "number" then if type(tok) == "number" then
tok = {tok} tok = {tok}
end end
local _tok = {}
local _v
for k, v in ipairs(tok) do for k, v in ipairs(tok) do
tok[k] = math.tointeger(v) _v = math.tointeger(v)
if tok[k] == nil then if _v == nil then
error "`decode` got a table with one or more non-integer values" error "`decode` got a table with one or more non-integer values"
return return
end end
_tok[k] = _v
end end
return bridged.decode(_python.builtins.list(tok)) return bridged.decode(_tok)
end end
---@return nil ---@return nil