Fix the Lua tokenizer API

This commit is contained in:
Gnome Ann 2021-12-11 21:24:34 -05:00
parent 67974947b2
commit 8e6a62259e
2 changed files with 19 additions and 8 deletions

View File

@ -923,7 +923,12 @@ def load_lua_scripts():
# Decode tokens into a string using current tokenizer
#==================================================================#
def lua_decode(tokens):
    """Decode a collection of token IDs into a string with the current tokenizer.

    Args:
        tokens: An object exposing ``.values()`` whose values are the token
            IDs to decode (presumably a Lua table proxy handed over by the
            Lua bridge -- confirm against the caller).

    Returns:
        Whatever ``tokenizer.decode`` returns for the list of token IDs.

    Side effects:
        If no module-level ``tokenizer`` exists yet, a GPT-2 fast tokenizer
        is loaded and stored in the module global ``tokenizer``.
    """
    # Declared up front: a `global` statement applies to the whole function
    # body, so keeping it at the top makes the rebinding below obvious.
    global tokenizer

    # Materialize the values into a plain Python list before decoding.
    tokens = list(tokens.values())
    assert type(tokens) is list  # always holds after list(); kept as a sanity check

    # Fall back to the stock GPT-2 tokenizer when none has been set up yet.
    if "tokenizer" not in globals():
        from transformers import GPT2TokenizerFast
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    return tokenizer.decode(tokens)
#==================================================================#
@ -931,6 +936,11 @@ def lua_decode(tokens):
#==================================================================#
def lua_encode(string):
    """Encode a string into token IDs with the current tokenizer.

    Args:
        string: The text to encode. Must be exactly of type ``str``.

    Returns:
        Whatever ``tokenizer.encode`` returns for ``string`` when called with
        ``max_length=int(4e9)`` and ``truncation=True``.

    Raises:
        AssertionError: If ``string`` is not a ``str``.

    Side effects:
        If no module-level ``tokenizer`` exists yet, a GPT-2 fast tokenizer
        is loaded and stored in the module global ``tokenizer``.
    """
    # Declared up front: a `global` statement applies to the whole function
    # body, so keeping it at the top makes the rebinding below obvious.
    global tokenizer

    assert type(string) is str

    # Fall back to the stock GPT-2 tokenizer when none has been set up yet.
    if "tokenizer" not in globals():
        from transformers import GPT2TokenizerFast
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    # The very large max_length together with truncation=True is passed
    # through unchanged from the original call.
    return tokenizer.encode(string, max_length=int(4e9), truncation=True)
#==================================================================#

View File

@ -666,10 +666,8 @@ return function(_python, _bridged)
return
end
local encoded = {}
local i = 1
for token in _python.iter(bridged.encode(str)) do
encoded[i] = token
i = i + 1
for i, token in _python.enumerate(bridged.encode(str)) do
encoded[i+1] = math.tointeger(token)
end
return encoded
end
@ -681,17 +679,20 @@ return function(_python, _bridged)
error("`decode` takes a number or table of numbers as argument, but got a " .. type(tok))
return
end
if type(tok) ~= "number" then
if type(tok) == "number" then
tok = {tok}
end
local _tok = {}
local _v
for k, v in ipairs(tok) do
tok[k] = math.tointeger(v)
if tok[k] == nil then
_v = math.tointeger(v)
if _v == nil then
error "`decode` got a table with one or more non-integer values"
return
end
_tok[k] = _v
end
return bridged.decode(_python.builtins.list(tok))
return bridged.decode(_tok)
end
---@return nil