Huggingface GPT-J Support
Finetune's fork has unofficial support which we supported, but this is not compatible with models designed for the official version. In this update we let models decide which transformers backend to use, and fall back to Neo if they don't choose any. We also add the 6B to the menu and for the time being switch to the github version of transformers to be ahead of the waiting time. (Hopefully we can switch back to the conda version before merging upstream).
This commit is contained in:
parent
72669e0489
commit
7d35f825c6
16
aiserver.py
16
aiserver.py
|
@ -44,10 +44,11 @@ class colors:
|
|||
|
||||
# AI models
|
||||
modellist = [
|
||||
["Custom Neo (GPT-Neo / Converted GPT-J)", "NeoCustom", ""],
|
||||
["Custom GPT-2 (eg CloverEdition)", "GPT2Custom", ""],
|
||||
["GPT Neo 1.3B", "EleutherAI/gpt-neo-1.3B", "8GB"],
|
||||
["GPT Neo 2.7B", "EleutherAI/gpt-neo-2.7B", "16GB"],
|
||||
["Load a model from its directory", "NeoCustom", ""],
|
||||
["Load an old GPT-2 model (eg CloverEdition)", "GPT2Custom", ""],
|
||||
["GPT-Neo 1.3B", "EleutherAI/gpt-neo-1.3B", "8GB"],
|
||||
["GPT-Neo 2.7B", "EleutherAI/gpt-neo-2.7B", "16GB"],
|
||||
["GPT-J 6B (HF GIT Required)", "EleutherAI/gpt-j-6B", "24GB"],
|
||||
["GPT-2", "gpt2", "1GB"],
|
||||
["GPT-2 Med", "gpt2-medium", "2GB"],
|
||||
["GPT-2 Large", "gpt2-large", "4GB"],
|
||||
|
@ -401,7 +402,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
|||
|
||||
# If custom GPT Neo model was chosen
|
||||
if(vars.model == "NeoCustom"):
|
||||
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth)
|
||||
model_config = open(vars.custmodpth + "/config.json", "r")
|
||||
js = json.load(model_config)
|
||||
if("architectures" in js):
|
||||
model = vars.custmodpth
|
||||
else:
|
||||
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth)
|
||||
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
||||
if(vars.hascuda):
|
||||
|
|
|
@ -11,7 +11,8 @@ dependencies:
|
|||
- python=3.8.*
|
||||
- cudatoolkit=11.1
|
||||
- tensorflow-gpu
|
||||
- transformers
|
||||
- pip
|
||||
- git
|
||||
- pip:
|
||||
- flask-cloudflared
|
||||
- flask-cloudflared
|
||||
- git+https://github.com/huggingface/transformers#transformer
|
Loading…
Reference in New Issue