mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-17 12:10:49 +01:00
Merge pull request #32 from VE-FORBRYDERNE/loader
Move the TPU backend code into this repository
This commit is contained in:
commit
50defbaa04
65
aiserver.py
65
aiserver.py
@ -179,7 +179,7 @@ def getmodelname():
|
|||||||
if(args.configname):
|
if(args.configname):
|
||||||
modelname = args.configname
|
modelname = args.configname
|
||||||
return modelname
|
return modelname
|
||||||
if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
|
if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ")):
|
||||||
modelname = os.path.basename(os.path.normpath(vars.custmodpth))
|
modelname = os.path.basename(os.path.normpath(vars.custmodpth))
|
||||||
return modelname
|
return modelname
|
||||||
else:
|
else:
|
||||||
@ -340,7 +340,7 @@ else:
|
|||||||
getModelSelection()
|
getModelSelection()
|
||||||
|
|
||||||
# If transformers model was selected & GPU available, ask to use CPU or GPU
|
# If transformers model was selected & GPU available, ask to use CPU or GPU
|
||||||
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||||
vars.allowsp = True
|
vars.allowsp = True
|
||||||
# Test for GPU support
|
# Test for GPU support
|
||||||
import torch
|
import torch
|
||||||
@ -530,7 +530,7 @@ socketio = SocketIO(app)
|
|||||||
print("{0}OK!{1}".format(colors.GREEN, colors.END))
|
print("{0}OK!{1}".format(colors.GREEN, colors.END))
|
||||||
|
|
||||||
# Start transformers and create pipeline
|
# Start transformers and create pipeline
|
||||||
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||||
if(not vars.noai):
|
if(not vars.noai):
|
||||||
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||||
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
|
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
|
||||||
@ -692,6 +692,13 @@ else:
|
|||||||
elif(vars.model == "OAI"):
|
elif(vars.model == "OAI"):
|
||||||
from transformers import GPT2Tokenizer
|
from transformers import GPT2Tokenizer
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||||
|
# Load the TPU backend if requested
|
||||||
|
elif(vars.model == "TPUMeshTransformerGPTJ"):
|
||||||
|
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||||
|
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
|
||||||
|
import tpu_mtj_backend
|
||||||
|
tpu_mtj_backend.load_model(vars.custmodpth)
|
||||||
|
tokenizer = tpu_mtj_backend.tokenizer
|
||||||
|
|
||||||
# Set up Flask routes
|
# Set up Flask routes
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
@ -1357,19 +1364,23 @@ def calcsubmit(txt):
|
|||||||
if(vars.model != "InferKit"):
|
if(vars.model != "InferKit"):
|
||||||
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
|
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
|
||||||
if(actionlen == 0):
|
if(actionlen == 0):
|
||||||
if(not vars.model in ["Colab", "OAI"]):
|
if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
|
||||||
generate(subtxt, min, max, found_entries=found_entries)
|
generate(subtxt, min, max, found_entries=found_entries)
|
||||||
elif(vars.model == "Colab"):
|
elif(vars.model == "Colab"):
|
||||||
sendtocolab(subtxt, min, max)
|
sendtocolab(subtxt, min, max)
|
||||||
elif(vars.model == "OAI"):
|
elif(vars.model == "OAI"):
|
||||||
oairequest(subtxt, min, max)
|
oairequest(subtxt, min, max)
|
||||||
|
elif(vars.model == "TPUMeshTransformerGPTJ"):
|
||||||
|
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
|
||||||
else:
|
else:
|
||||||
if(not vars.model in ["Colab", "OAI"]):
|
if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
|
||||||
generate(subtxt, min, max, found_entries=found_entries)
|
generate(subtxt, min, max, found_entries=found_entries)
|
||||||
elif(vars.model == "Colab"):
|
elif(vars.model == "Colab"):
|
||||||
sendtocolab(subtxt, min, max)
|
sendtocolab(subtxt, min, max)
|
||||||
elif(vars.model == "OAI"):
|
elif(vars.model == "OAI"):
|
||||||
oairequest(subtxt, min, max)
|
oairequest(subtxt, min, max)
|
||||||
|
elif(vars.model == "TPUMeshTransformerGPTJ"):
|
||||||
|
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
|
||||||
|
|
||||||
# For InferKit web API
|
# For InferKit web API
|
||||||
else:
|
else:
|
||||||
@ -1658,7 +1669,49 @@ def sendtocolab(txt, min, max):
|
|||||||
print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
|
print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
|
||||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||||
set_aibusy(0)
|
set_aibusy(0)
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Send text to TPU mesh transformer backend
|
||||||
|
#==================================================================#
|
||||||
|
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||||
|
if(found_entries is None):
|
||||||
|
found_entries = set()
|
||||||
|
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
|
||||||
|
|
||||||
|
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
|
||||||
|
|
||||||
|
# Submit input text to generator
|
||||||
|
try:
|
||||||
|
if(vars.sp is not None):
|
||||||
|
raise ValueError("Softprompts are not supported by the TPU backend yet")
|
||||||
|
if(vars.dynamicscan):
|
||||||
|
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
|
||||||
|
genout = tpu_mtj_backend.infer(
|
||||||
|
txt,
|
||||||
|
gen_len = maximum-minimum+1,
|
||||||
|
temp=vars.temp,
|
||||||
|
top_p=vars.top_p,
|
||||||
|
top_k=vars.top_k,
|
||||||
|
tfs=vars.tfs,
|
||||||
|
numseqs=vars.numseqs,
|
||||||
|
repetition_penalty=vars.rep_pen,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
|
||||||
|
print("{0}{1}{2}".format(colors.RED, e, colors.END))
|
||||||
|
set_aibusy(0)
|
||||||
|
return
|
||||||
|
|
||||||
|
genout = [{"generated_text": txt} for txt in genout]
|
||||||
|
|
||||||
|
if(len(genout) == 1):
|
||||||
|
genresult(genout[0]["generated_text"])
|
||||||
|
else:
|
||||||
|
genselect(genout)
|
||||||
|
|
||||||
|
set_aibusy(0)
|
||||||
|
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Replaces returns and newlines with HTML breaks
|
# Replaces returns and newlines with HTML breaks
|
||||||
|
13
requirements_mtj.txt
Normal file
13
requirements_mtj.txt
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
numpy
|
||||||
|
tqdm
|
||||||
|
requests
|
||||||
|
optax
|
||||||
|
dm-haiku >= 0.0.5, <= 0.0.9
|
||||||
|
ray[default]
|
||||||
|
jax == 0.2.12
|
||||||
|
transformers
|
||||||
|
progressbar2
|
||||||
|
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
|
||||||
|
flask
|
||||||
|
Flask-SocketIO
|
||||||
|
flask-cloudflared >= 0.0.5
|
334
tpu_mtj_backend.py
Normal file
334
tpu_mtj_backend.py
Normal file
@ -0,0 +1,334 @@
|
|||||||
|
import multiprocessing
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
import progressbar
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import random
|
||||||
|
import jax
|
||||||
|
from jax.config import config
|
||||||
|
from jax.experimental import maps
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
import optax
|
||||||
|
import haiku as hk
|
||||||
|
import transformers
|
||||||
|
from mesh_transformer.checkpoint import read_ckpt_lowmem
|
||||||
|
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard
|
||||||
|
|
||||||
|
|
||||||
|
params: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def show_spinner():
|
||||||
|
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
bar.update(i)
|
||||||
|
time.sleep(0.1)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
def apply_repetition_penalty(logits, tokens, repetition_penalty):
|
||||||
|
'''
|
||||||
|
This gets called by generate_scan_fn to apply repetition penalty
|
||||||
|
to the 1D array logits using the provided 1D array of tokens to penalize
|
||||||
|
'''
|
||||||
|
# Make a new array with the same length as the tokens array but with
|
||||||
|
# each element replaced by the value at the corresponding index in the
|
||||||
|
# logits array; e.g.
|
||||||
|
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
|
||||||
|
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
|
||||||
|
penalty_logits = jnp.take(logits, tokens)
|
||||||
|
# Divide positive values by repetition_penalty and multiply negative
|
||||||
|
# values by repetition_penalty (the academic publication that described
|
||||||
|
# this technique actually just only divided, but that would cause tokens
|
||||||
|
# with negative logits to become more likely, which is obviously wrong)
|
||||||
|
penalty_logits = jnp.where(
|
||||||
|
penalty_logits > 0,
|
||||||
|
penalty_logits/repetition_penalty,
|
||||||
|
penalty_logits*repetition_penalty,
|
||||||
|
)
|
||||||
|
# Finally, put those penalized logit values back into their original
|
||||||
|
# positions in the logits array
|
||||||
|
return logits.at[tokens].set(penalty_logits)
|
||||||
|
|
||||||
|
def kobold_sample(key, logits, _, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
||||||
|
'''
|
||||||
|
This gets called by generate_scan_fn to apply a series of 4 filters
|
||||||
|
to the logits (top-k, then top-p, then TFS, then temperature) before
|
||||||
|
picking one token using the modified logits
|
||||||
|
'''
|
||||||
|
# Top-k (keep only the k tokens with the highest logits and remove
|
||||||
|
# the rest, by setting their logits to negative infinity)
|
||||||
|
def top_k_filter(logits):
|
||||||
|
# After sorting the logits array in descending order,
|
||||||
|
# sorted_indices_to_remove is a 1D array that is True for tokens
|
||||||
|
# in the sorted logits array we want to remove and False for ones
|
||||||
|
# we want to keep, in this case the first top_k elements will be
|
||||||
|
# False and the rest will be True
|
||||||
|
sorted_indices_to_remove = jnp.arange(len(logits)) >= top_k
|
||||||
|
# Unsort the logits array back to its original configuration and
|
||||||
|
# remove tokens we need to remove
|
||||||
|
_, indices_to_remove = jax.lax.sort_key_val(
|
||||||
|
jnp.argsort(-logits),
|
||||||
|
sorted_indices_to_remove,
|
||||||
|
)
|
||||||
|
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
||||||
|
logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits)
|
||||||
|
# Top-p (after sorting the remaining tokens again in descending order of
|
||||||
|
# logit, remove the ones that have cumulative softmax probability
|
||||||
|
# greater than p)
|
||||||
|
def top_p_filter(logits):
|
||||||
|
# Sort the logits array in descending order, replace every element
|
||||||
|
# with e (Euler's number) to the power of that element, and divide
|
||||||
|
# each element of the new array by the sum of the elements in the
|
||||||
|
# new array
|
||||||
|
sorted_logits = -jnp.sort(-logits)
|
||||||
|
probabilities = jax.nn.softmax(sorted_logits)
|
||||||
|
# Calculate cumulative_probabilities as the prefix-sum array of
|
||||||
|
# probabilities
|
||||||
|
cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
|
||||||
|
# We want to remove tokens with cumulative probability higher
|
||||||
|
# than top_p
|
||||||
|
sorted_indices_to_remove = cumulative_probabilities > top_p
|
||||||
|
# Don't ever remove the token with the highest logit, even if
|
||||||
|
# the probability is higher than top_p
|
||||||
|
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
|
||||||
|
# Unsort and remove
|
||||||
|
_, indices_to_remove = jax.lax.sort_key_val(
|
||||||
|
jnp.argsort(-logits),
|
||||||
|
sorted_indices_to_remove,
|
||||||
|
)
|
||||||
|
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
||||||
|
logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits)
|
||||||
|
# Tail free sampling (basically top-p a second time on remaining tokens
|
||||||
|
# except it's the "cumulative normalized absolute second finite
|
||||||
|
# differences of the softmax probabilities" instead of just the
|
||||||
|
# cumulative softmax probabilities)
|
||||||
|
def tail_free_filter(logits):
|
||||||
|
# Sort in descending order
|
||||||
|
sorted_logits = -jnp.sort(-logits)
|
||||||
|
# Softmax again
|
||||||
|
probabilities = jax.nn.softmax(sorted_logits)
|
||||||
|
# Calculate the second finite differences of that array (i.e.
|
||||||
|
# calculate the difference array and then calculate the difference
|
||||||
|
# array of the difference array)
|
||||||
|
d2 = jnp.diff(jnp.diff(probabilities))
|
||||||
|
# Get the absolute values of all those second finite differences
|
||||||
|
d2 = jnp.abs(d2)
|
||||||
|
# Normalize (all elements in the array are divided by the sum of the
|
||||||
|
# array's elements)
|
||||||
|
d2 = d2 / d2.sum(axis=-1, keepdims=True)
|
||||||
|
# Get the prefix-sum array
|
||||||
|
cumulative_d2 = jnp.cumsum(d2, axis=-1)
|
||||||
|
# We will remove the tokens with a cumulative normalized absolute
|
||||||
|
# second finite difference larger than the TFS value
|
||||||
|
sorted_indices_to_remove = cumulative_d2 > tfs
|
||||||
|
# Don't remove the token with the highest logit
|
||||||
|
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
|
||||||
|
# Since the d2 array has two fewer elements than the logits array,
|
||||||
|
# we'll add two extra Trues to the end
|
||||||
|
sorted_indices_to_remove = jnp.pad(
|
||||||
|
sorted_indices_to_remove,
|
||||||
|
(0, 2),
|
||||||
|
constant_values=True,
|
||||||
|
)
|
||||||
|
# Unsort and remove
|
||||||
|
_, indices_to_remove = jax.lax.sort_key_val(
|
||||||
|
jnp.argsort(-logits),
|
||||||
|
sorted_indices_to_remove,
|
||||||
|
)
|
||||||
|
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
||||||
|
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
|
||||||
|
# Temperature (just divide the logits by the temperature)
|
||||||
|
def temp_filter(logits):
|
||||||
|
return logits / temp
|
||||||
|
logits = jax.lax.cond(True, temp_filter, lambda x: x, logits)
|
||||||
|
# Finally, pick one token using the softmax thingy again (it gives
|
||||||
|
# an array whose elements sum to 1 so it can be used nicely as a
|
||||||
|
# probability distribution)
|
||||||
|
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis], None
|
||||||
|
|
||||||
|
pad_token_id = 50256
|
||||||
|
|
||||||
|
class PenalizingCausalTransformer(CausalTransformer):
|
||||||
|
def __init__(self, config):
|
||||||
|
# Initialize
|
||||||
|
super().__init__(config)
|
||||||
|
def generate(state, key, ctx, ctx_length, aux, sampler_options):
|
||||||
|
gen_length = self.gen_length
|
||||||
|
# These are the tokens that we don't want the AI to ever write
|
||||||
|
self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
|
||||||
|
def generate_sample(context, ctx_length, aux):
|
||||||
|
# Give the initial context to the transformer
|
||||||
|
transformer = CausalTransformerShard(config)
|
||||||
|
_, initial_state = transformer.generate_initial(context, ctx_length)
|
||||||
|
# The "generated" array will contain the tokens from the
|
||||||
|
# context as well as the tokens picked by the sampler at
|
||||||
|
# each stage, padded with a bunch of 50256s, so we know
|
||||||
|
# which tokens have to be repetition penalized
|
||||||
|
generated = jnp.pad(context, (0, gen_length), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus gen_length 50256s which will be eventually filled with sampler-chosen tokens
|
||||||
|
generated_index = config["seq"]
|
||||||
|
# Add that information to generate_scan_fn's starting state
|
||||||
|
initial_state = (generated, generated_index) + initial_state
|
||||||
|
# Get repetition penalty from the arguments
|
||||||
|
repetition_penalty = sampler_options.pop('repetition_penalty', None)
|
||||||
|
def generate_scan_fn(carry, sampler_input):
|
||||||
|
# Unpack current generate_scan_fn state
|
||||||
|
generated, generated_index, next_token, decode_state, sample_key = carry
|
||||||
|
# Get the pseudo-random number generator key that will
|
||||||
|
# be used by kobold_sample to randomly pick a token
|
||||||
|
sample_key, new_key = jax.random.split(sample_key)
|
||||||
|
# Give the context to the model and get the logits it
|
||||||
|
# spits out
|
||||||
|
# (a 2D array with 1 row and 50400 columns representing
|
||||||
|
# how strongly it thinks each of the 50257 tokens in its
|
||||||
|
# vocabulary should be appended to the context, followed
|
||||||
|
# by 143 apparently useless columns ???)
|
||||||
|
logits, new_state = transformer.generate_once(next_token, decode_state)
|
||||||
|
# Verify that logits does indeed have that many rows and
|
||||||
|
# columns (if you get an error here, pray for mercy)
|
||||||
|
assert logits.shape == (1, config["n_vocab"])
|
||||||
|
# Flatten it into a 1D array to make it easier to use
|
||||||
|
logits = logits[0]
|
||||||
|
# Apply repetition penalty to all tokens that are
|
||||||
|
# currently inside the "generated" array
|
||||||
|
if repetition_penalty is not None:
|
||||||
|
logits = apply_repetition_penalty(
|
||||||
|
logits,
|
||||||
|
generated,
|
||||||
|
repetition_penalty
|
||||||
|
)
|
||||||
|
# Remove any tokens in the badwords list by setting
|
||||||
|
# their logits to negative infinity which effectively
|
||||||
|
# makes their probabilities of being chosen zero
|
||||||
|
logits = logits.at[self.badwords].set(-jnp.inf)
|
||||||
|
# Use the sampler (kobold_sample) to pick one token
|
||||||
|
# based on the logits array as a 1D array with 1 element
|
||||||
|
# (higher logit means higher probability of being
|
||||||
|
# picked, non-linearly)
|
||||||
|
next_token, sample_info = kobold_sample(
|
||||||
|
sample_key,
|
||||||
|
logits,
|
||||||
|
sampler_input,
|
||||||
|
**sampler_options,
|
||||||
|
)
|
||||||
|
# Remember what token was picked so we can repetition
|
||||||
|
# penalize it next time
|
||||||
|
generated = generated.at[generated_index].set(next_token[0])
|
||||||
|
generated_index += 1
|
||||||
|
# self.return_logits isn't used in this program, but
|
||||||
|
# for the sake of compatibility...
|
||||||
|
if self.return_logits:
|
||||||
|
output = (next_token, sample_info, logits[jnp.newaxis])
|
||||||
|
else:
|
||||||
|
output = (next_token, sample_info)
|
||||||
|
# Re-pack the current generate_scan_fn's state so we can
|
||||||
|
# get back the same variables the next time
|
||||||
|
new_carry = (generated, generated_index, next_token, new_state, new_key)
|
||||||
|
return new_carry, output
|
||||||
|
# jax.lax.scan is a function that calls generate_scan_fn
|
||||||
|
# gen_length times, each time passing a state object from
|
||||||
|
# its return value (new_carry) back into one of the
|
||||||
|
# function's arguments (carry), and of course gathering the
|
||||||
|
# token it generates each time into the "outputs" array;
|
||||||
|
# we have to use jax.lax.scan instead of a normal loop
|
||||||
|
# because of JAX's JIT-compilation shenanigans
|
||||||
|
final_state, outputs = jax.lax.scan(
|
||||||
|
generate_scan_fn,
|
||||||
|
initial_state,
|
||||||
|
xs=aux,
|
||||||
|
length=gen_length,
|
||||||
|
)
|
||||||
|
return final_state, outputs
|
||||||
|
generate_fn = hk.transform(generate_sample).apply
|
||||||
|
return generate_fn(state["params"], key, ctx, ctx_length, aux)
|
||||||
|
self.generate_xmap = jax.experimental.maps.xmap(fun=generate, in_axes=(["shard", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...]), out_axes=["batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'})
|
||||||
|
def generate(self, ctx, ctx_length, gen_length, sampler_options, return_logits=False):
|
||||||
|
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
|
||||||
|
batch_size = ctx.shape[0]
|
||||||
|
aux = jnp.zeros((batch_size, gen_length), dtype=jnp.uint32)
|
||||||
|
self.gen_length = gen_length
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.return_logits = return_logits
|
||||||
|
return self.generate_xmap(
|
||||||
|
self.state,
|
||||||
|
jnp.array(key.take(batch_size)),
|
||||||
|
ctx,
|
||||||
|
np.array(ctx_length, dtype=np.uint32),
|
||||||
|
aux,
|
||||||
|
sampler_options
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def infer(context, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, repetition_penalty=1.0, numseqs=1, gen_len=80) -> List[str]:
|
||||||
|
maps.thread_resources.env = thread_resources_env
|
||||||
|
total_batch = numseqs
|
||||||
|
tokens = tokenizer.encode(context, max_length=params["seq"], truncation=True)
|
||||||
|
provided_ctx = len(tokens)
|
||||||
|
pad_amount = seq - provided_ctx
|
||||||
|
padded_tokens = np.pad(np.asarray(tokens, dtype=np.uint32), ((pad_amount, 0),), constant_values=pad_token_id)
|
||||||
|
batched_tokens = np.array([padded_tokens] * total_batch)
|
||||||
|
length = np.ones(total_batch, dtype=np.uint32) * len(tokens)
|
||||||
|
samples = []
|
||||||
|
batched_generator_params = {
|
||||||
|
"temp": temp * np.ones(total_batch),
|
||||||
|
"top_p": top_p * np.ones(total_batch),
|
||||||
|
"tfs": tfs * np.ones(total_batch),
|
||||||
|
"repetition_penalty": repetition_penalty * np.ones(total_batch),
|
||||||
|
"top_k": np.full(total_batch, top_k, dtype=np.uint32)
|
||||||
|
}
|
||||||
|
output = network.generate(batched_tokens, length, gen_len, batched_generator_params)
|
||||||
|
decoded_tokens = output[1][0]
|
||||||
|
for o in decoded_tokens[:, :, 0]:
|
||||||
|
samples.append(tokenizer.decode(o))
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
|
||||||
|
global thread_resources_env, seq, tokenizer, network, params
|
||||||
|
|
||||||
|
default_params = {
|
||||||
|
"compat": "j",
|
||||||
|
"layers": 28,
|
||||||
|
"d_model": 4096,
|
||||||
|
"n_heads": 16,
|
||||||
|
"n_vocab": 50400,
|
||||||
|
"n_vocab_padding": 0,
|
||||||
|
"norm": "layernorm",
|
||||||
|
"pe": "rotary",
|
||||||
|
"pe_rotary_dims": 64,
|
||||||
|
"seq": 2048,
|
||||||
|
"cores_per_replica": 8,
|
||||||
|
}
|
||||||
|
params = kwargs
|
||||||
|
for param in default_params:
|
||||||
|
if param not in params:
|
||||||
|
params[param] = default_params[param]
|
||||||
|
|
||||||
|
print("Connecting to your Colab instance's TPU", flush=True)
|
||||||
|
spinner = multiprocessing.Process(target=show_spinner, args=())
|
||||||
|
spinner.start()
|
||||||
|
colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
|
||||||
|
url = f'http://{colab_tpu_addr}:8475/requestversion/{driver_version}'
|
||||||
|
requests.post(url)
|
||||||
|
spinner.terminate()
|
||||||
|
print()
|
||||||
|
config.FLAGS.jax_xla_backend = "tpu_driver"
|
||||||
|
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
|
||||||
|
|
||||||
|
cores_per_replica = params["cores_per_replica"]
|
||||||
|
seq = params["seq"]
|
||||||
|
params["optimizer"] = optax.scale(0)
|
||||||
|
mesh_shape = (1, cores_per_replica)
|
||||||
|
devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
|
||||||
|
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))
|
||||||
|
maps.thread_resources.env = thread_resources_env
|
||||||
|
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
|
||||||
|
|
||||||
|
if not path.endswith("/"):
|
||||||
|
path += "/"
|
||||||
|
|
||||||
|
network = PenalizingCausalTransformer(params)
|
||||||
|
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
|
||||||
|
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
|
Loading…
x
Reference in New Issue
Block a user