mirror of
				https://github.com/KoboldAI/KoboldAI-Client.git
				synced 2025-06-05 21:59:24 +02:00 
			
		
		
		
	Fix the Lua tokenizer API
This commit is contained in:
		
							
								
								
									
										10
									
								
								aiserver.py
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								aiserver.py
									
									
									
									
									
								
							| @@ -923,7 +923,12 @@ def load_lua_scripts(): | |||||||
| #  Decode tokens into a string using current tokenizer | #  Decode tokens into a string using current tokenizer | ||||||
| #==================================================================# | #==================================================================# | ||||||
| def lua_decode(tokens): | def lua_decode(tokens): | ||||||
|  |     tokens = list(tokens.values()) | ||||||
|     assert type(tokens) is list |     assert type(tokens) is list | ||||||
|  |     if("tokenizer" not in globals()): | ||||||
|  |         from transformers import GPT2TokenizerFast | ||||||
|  |         global tokenizer | ||||||
|  |         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") | ||||||
|     return tokenizer.decode(tokens) |     return tokenizer.decode(tokens) | ||||||
|  |  | ||||||
| #==================================================================# | #==================================================================# | ||||||
| @@ -931,6 +936,11 @@ def lua_decode(tokens): | |||||||
| #==================================================================# | #==================================================================# | ||||||
| def lua_encode(string): | def lua_encode(string): | ||||||
|     assert type(string) is str |     assert type(string) is str | ||||||
|  |     if("tokenizer" not in globals()): | ||||||
|  |         thinking = False | ||||||
|  |         from transformers import GPT2TokenizerFast | ||||||
|  |         global tokenizer | ||||||
|  |         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") | ||||||
|     return tokenizer.encode(string, max_length=int(4e9), truncation=True) |     return tokenizer.encode(string, max_length=int(4e9), truncation=True) | ||||||
|  |  | ||||||
| #==================================================================# | #==================================================================# | ||||||
|   | |||||||
							
								
								
									
										17
									
								
								bridge.lua
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								bridge.lua
									
									
									
									
									
								
							| @@ -666,10 +666,8 @@ return function(_python, _bridged) | |||||||
|             return |             return | ||||||
|         end |         end | ||||||
|         local encoded = {} |         local encoded = {} | ||||||
|         local i = 1 |         for i, token in _python.enumerate(bridged.encode(str)) do | ||||||
|         for token in _python.iter(bridged.encode(str)) do |             encoded[i+1] = math.tointeger(token) | ||||||
|             encoded[i] = token |  | ||||||
|             i = i + 1 |  | ||||||
|         end |         end | ||||||
|         return encoded |         return encoded | ||||||
|     end |     end | ||||||
| @@ -681,17 +679,20 @@ return function(_python, _bridged) | |||||||
|             error("`decode` takes a number or table of numbers as argument, but got a " .. type(tok)) |             error("`decode` takes a number or table of numbers as argument, but got a " .. type(tok)) | ||||||
|             return |             return | ||||||
|         end |         end | ||||||
|         if type(tok) ~= "number" then |         if type(tok) == "number" then | ||||||
|             tok = {tok} |             tok = {tok} | ||||||
|         end |         end | ||||||
|  |         local _tok = {} | ||||||
|  |         local _v | ||||||
|         for k, v in ipairs(tok) do |         for k, v in ipairs(tok) do | ||||||
|             tok[k] = math.tointeger(v) |             _v = math.tointeger(v) | ||||||
|             if tok[k] == nil then |             if _v == nil then | ||||||
|                 error "`decode` got a table with one or more non-integer values" |                 error "`decode` got a table with one or more non-integer values" | ||||||
|                 return |                 return | ||||||
|             end |             end | ||||||
|  |             _tok[k] = _v | ||||||
|         end |         end | ||||||
|         return bridged.decode(_python.builtins.list(tok)) |         return bridged.decode(_tok) | ||||||
|     end |     end | ||||||
|  |  | ||||||
|     ---@return nil |     ---@return nil | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Gnome Ann
					Gnome Ann