Fix the Lua tokenizer API
parent 67974947b2
commit 8e6a62259e

aiserver.py | 10

@@ -923,7 +923,12 @@ def load_lua_scripts():
 # Decode tokens into a string using current tokenizer
 #==================================================================#
 def lua_decode(tokens):
+    tokens = list(tokens.values())
     assert type(tokens) is list
+    if("tokenizer" not in globals()):
+        from transformers import GPT2TokenizerFast
+        global tokenizer
+        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
     return tokenizer.decode(tokens)
 
 #==================================================================#
@@ -931,6 +936,11 @@ def lua_decode(tokens):
 #==================================================================#
 def lua_encode(string):
     assert type(string) is str
+    if("tokenizer" not in globals()):
+        thinking = False
+        from transformers import GPT2TokenizerFast
+        global tokenizer
+        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
     return tokenizer.encode(string, max_length=int(4e9), truncation=True)
 
 #==================================================================#
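Both aiserver.py hunks apply the same lazy-initialization idea: if no tokenizer global exists yet (for example, because no model has been loaded), the Lua-facing helpers fall back to loading GPT2TokenizerFast the first time they are called. Below is a rough standalone sketch of that pattern; the get_tokenizer and encode_for_lua names are illustrative stand-ins, not functions from the actual server code.

    # Minimal sketch of the lazy-tokenizer fallback shared by lua_encode and
    # lua_decode; assumes the transformers package is installed and that the
    # "gpt2" weights can be fetched on first use.
    def get_tokenizer():
        if "tokenizer" not in globals():
            from transformers import GPT2TokenizerFast
            global tokenizer
            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        return tokenizer

    def encode_for_lua(string):
        assert type(string) is str
        # A huge max_length keeps truncation=True effectively inert, matching
        # the encode call in the hunk above.
        return get_tokenizer().encode(string, max_length=int(4e9), truncation=True)
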
bridge.lua | 17

@@ -666,10 +666,8 @@ return function(_python, _bridged)
             return
         end
         local encoded = {}
-        local i = 1
-        for token in _python.iter(bridged.encode(str)) do
-            encoded[i] = token
-            i = i + 1
+        for i, token in _python.enumerate(bridged.encode(str)) do
+            encoded[i+1] = math.tointeger(token)
         end
         return encoded
     end
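The rewritten encode loop relies on `_python.enumerate` (presumably a thin wrapper over Python's built-in enumerate) yielding 0-based indices while Lua sequences start at 1, hence the store into `encoded[i+1]`, and on `math.tointeger` turning each bridged Python int into a native Lua integer. A small Python-side illustration of what that loop iterates over, assuming `bridged.encode` simply wraps the tokenizer's `encode` call as in the aiserver.py hunks:

    from transformers import GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    # Stand-in for bridged.encode(str): a plain Python list of token ids.
    tokens = tokenizer.encode("Hello world", max_length=int(4e9), truncation=True)

    # enumerate starts at 0, which is why the Lua loop writes to encoded[i+1].
    for i, token in enumerate(tokens):
        print(i, token)
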
@@ -681,17 +679,20 @@ return function(_python, _bridged)
             error("`decode` takes a number or table of numbers as argument, but got a " .. type(tok))
             return
         end
-        if type(tok) ~= "number" then
+        if type(tok) == "number" then
             tok = {tok}
         end
+        local _tok = {}
+        local _v
         for k, v in ipairs(tok) do
-            tok[k] = math.tointeger(v)
-            if tok[k] == nil then
+            _v = math.tointeger(v)
+            if _v == nil then
                 error "`decode` got a table with one or more non-integer values"
                 return
             end
+            _tok[k] = _v
         end
-        return bridged.decode(_python.builtins.list(tok))
+        return bridged.decode(_tok)
     end
 
     ---@return nil
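With this change, `decode` collects the validated integers into a fresh `_tok` table instead of mutating the caller's argument, and passes that table straight to `bridged.decode`; on the Python side, the new `tokens = list(tokens.values())` line is what flattens the wrapped Lua table back into a list. A rough round-trip sketch of that hand-off, under the assumption that the Lua table reaches Python as a mapping exposing `.values()` (a plain dict stands in for the Lua table wrapper here):

    from transformers import GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    def lua_decode(tokens):
        # In the real bridge, `tokens` is a wrapped Lua table; .values() yields
        # its entries and list() flattens them into a plain Python list.
        tokens = list(tokens.values())
        assert type(tokens) is list
        return tokenizer.decode(tokens)

    # Stand-in for the 1-indexed `_tok` table built by bridge.lua's decode.
    ids = tokenizer.encode("Hello world")
    fake_lua_table = {i + 1: t for i, t in enumerate(ids)}
    print(lua_decode(fake_lua_table))  # -> Hello world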