Merge branch 'united' into mkultra

This commit is contained in:
vfbd 2022-09-28 14:30:34 -04:00
commit 6758d5b538
19 changed files with 1243 additions and 458 deletions

File diff suppressed because it is too large Load Diff

View File

@ -7,7 +7,6 @@
"private_outputs": true,
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyPbwW79K9/RkYH9i9rkYFyj",
"include_colab_link": true
},
"kernelspec": {
@ -68,9 +67,9 @@
"#@title <b><-- Select your model below and then click this to start KoboldAI</b>\n",
"#@markdown You can find a description of the models below along with instructions on how to start KoboldAI.\n",
"\n",
"Model = \"Nerys 2.7B\" #@param [\"Nerys 2.7B\", \"Janeway 2.7B\", \"Picard 2.7B\", \"AID 2.7B\", \"Horni LN 2.7B\", \"Horni 2.7B\", \"Shinen 2.7B\", \"Neo 2.7B\"] {allow-input: true}\n",
"Model = \"Nerys 2.7B\" #@param [\"Nerys 2.7B\", \"AID 2.7B\", \"Erebus 2.7B\", \"Janeway 2.7B\", \"Picard 2.7B\", \"Horni LN 2.7B\", \"Horni 2.7B\", \"Shinen 2.7B\", \"Neo 2.7B\"] {allow-input: true}\n",
"Version = \"Official\" #@param [\"Official\", \"United\"] {allow-input: true}\n",
"Provider = \"Cloudflare\" #@param [\"Localtunnel\", \"Cloudflare\"]\n",
"Provider = \"Localtunnel\" #@param [\"Localtunnel\", \"Cloudflare\"]\n",
"\n",
"!nvidia-smi\n",
"from google.colab import drive\n",
@ -80,11 +79,15 @@
" Model = \"KoboldAI/fairseq-dense-2.7B-Nerys\"\n",
" path = \"\"\n",
" download = \"\"\n",
"elif Model == \"Erebus 2.7B\":\n",
" Model = \"KoboldAI/OPT-2.7B-Erebus\"\n",
" path = \"\"\n",
" download = \"\"\n",
"elif Model == \"Janeway 2.7B\":\n",
" Model = \"KoboldAI/GPT-Neo-2.7B-Janeway\"\n",
" path = \"\"\n",
" download = \"\"\n",
"elif Model == \"Picard 2.7B\":\n",
"elif Model == \"Picard 2.7B\":\n",
" Model = \"KoboldAI/GPT-Neo-2.7B-Picard\"\n",
" path = \"\"\n",
" download = \"\"\n",
@ -156,7 +159,7 @@
"| Adventure | These models are excellent for people willing to play KoboldAI like a Text Adventure game and are meant to be used with Adventure mode enabled. Even if you wish to use it as a Novel style model you should always have Adventure mode on and set it to story. These models typically have a strong bias towards the use of the word You and without Adventure mode enabled break the story flow and write actions on your behalf. |\n",
"| Generic | Generic models are not trained towards anything specific, typically used as a basis for other tasks and models. They can do everything the other models can do, but require much more handholding to work properly. Generic models are an ideal basis for tasks that we have no specific model for, or for experiencing a softprompt in its raw form. |\n",
"\n",
"---\n",
"---\n",
"# How to start KoboldAI in 7 simple steps\n",
"Using KoboldAI on Google Colab is easy! Simply follow these steps to get started:\n",
"1. Mobile phone? Tap the play button below next to \"<--- Tap this if you play on mobile\" to reveal an audio player, play the silent audio to keep the tab alive so Google will not shut you down when your using KoboldAI. If no audio player is revealed your phone browser does not support Google Colab in the mobile view, go to your browser menu and enable Desktop mode before you continue.\n",
@ -174,4 +177,4 @@
}
}
]
}
}

View File

@ -66,7 +66,7 @@
"#@title <b><-- Select your model below and then click this to start KoboldAI</b>\n",
"#@markdown You can find a description of the models below along with instructions on how to start KoboldAI.\n",
"\n",
"Model = \"Nerys 13B V2\" #@param [\"Nerys 13B V2\", \"Janeway 13B\", \"Shinen 13B\", \"Skein 20B\", \"Skein 6B\", \"Janeway 6B\", \"Adventure 6B\", \"Shinen 6B\", \"Lit 6B\", \"NeoX 20B\", \"OPT 13B\", \"Fairseq Dense 13B\", \"GPT-J-6B\"] {allow-input: true}\n",
"Model = \"Nerys 13B V2\" #@param [\"Nerys 13B V2\", \"Erebus 13B\", \"Janeway 13B\", \"Shinen 13B\", \"Skein 20B\", \"Erebus 20B\", \"Skein 6B\", \"Janeway 6B\", \"Adventure 6B\", \"Shinen 6B\", \"Lit V2 6B\", \"Lit 6B\", \"NeoX 20B\", \"OPT 13B\", \"Fairseq Dense 13B\", \"GPT-J-6B\"] {allow-input: true}\n",
"Version = \"Official\" #@param [\"Official\", \"United\"] {allow-input: true}\n",
"Provider = \"Cloudflare\" #@param [\"Localtunnel\", \"Cloudflare\"]\n",
"\n",
@ -86,13 +86,21 @@
" path = \"\"\n",
" download = \"\"\n",
"elif Model == \"Nerys 13B V2\":\n",
" Model = \"KoboldAI/fairseq-dense-13B-Nerys-v2\"\n",
" Model = \"KoboldAI/OPT-13B-Nerys-v2\"\n",
" path = \"\"\n",
" download = \"\"\n",
"elif Model == \"Erebus 13B\":\n",
" Model = \"KoboldAI/OPT-13B-Erebus\"\n",
" path = \"\"\n",
" download = \"\"\n",
"elif Model == \"Shinen 13B\":\n",
" Model = \"KoboldAI/fairseq-dense-13B-Shinen\"\n",
" path = \"\"\n",
" download = \"\"\n",
"elif Model == \"Erebus 20B\":\n",
" Model = \"KoboldAI/GPT-NeoX-20B-Erebus\"\n",
" path = \"\"\n",
" download = \"\"\n",
"elif Model == \"Skein 20B\":\n",
" Model = \"KoboldAI/GPT-NeoX-20B-Skein\"\n",
" path = \"\"\n",
@ -113,6 +121,10 @@
" Model = \"KoboldAI/GPT-J-6B-Adventure\"\n",
" path = \"\"\n",
" download = \"\"\n",
"elif Model == \"Lit V2 6B\":\n",
" Model = \"hakurei/litv2-6B-rev3\"\n",
" path = \"\"\n",
" download = \"\"\n",
"elif Model == \"Lit 6B\":\n",
" Model = \"hakurei/lit-6B\"\n",
" path = \"\"\n",
@ -200,7 +212,7 @@
"source": [
"#@title <b>Model Cleaner</b>\n",
"#@markdown Out of space? Run this to remove all cached models (Google Drive models are not effected).\n",
"!rm /content/KoboldAI-Client/cache/*\n"
"!rm -rf /content/KoboldAI-Client/cache/*\n"
],
"metadata": {
"cellView": "form",

0
docker-standalone/docker-helper.sh Normal file → Executable file
View File

View File

@ -18,6 +18,7 @@ dependencies:
- git=2.35.1
- marshmallow>=3.13
- apispec-webframeworks
- loguru
- pip:
- git+https://github.com/finetuneanon/transformers@gpt-neo-localattention3-rp-b
- flask-cloudflared

View File

@ -19,6 +19,7 @@ dependencies:
- protobuf
- marshmallow>=3.13
- apispec-webframeworks
- loguru
- pip:
- flask-cloudflared
- flask-ngrok

View File

@ -14,6 +14,7 @@ dependencies:
- git=2.35.1
- marshmallow>=3.13
- apispec-webframeworks
- loguru
- pip:
- --find-links https://download.pytorch.org/whl/rocm4.2/torch_stable.html
- torch

View File

@ -16,9 +16,10 @@ dependencies:
- protobuf
- marshmallow>=3.13
- apispec-webframeworks
- loguru
- pip:
- --find-links https://download.pytorch.org/whl/rocm4.2/torch_stable.html
- torch==1.10.*
- --extra-index-url https://download.pytorch.org/whl/rocm5.1.1
- torch
- torchvision
- flask-cloudflared
- flask-ngrok

View File

@ -3,6 +3,7 @@ from typing import Tuple, Union, Optional
import os
import json
import zipfile
from logger import logger
#==================================================================#
# Generic Method for prompting for file path
@ -149,16 +150,16 @@ def getspfiles(model_dimension: int):
continue
z, version, shape, fortran_order, dtype = checksp(file, model_dimension)
if z == 1:
print(f"Browser SP loading error: {file} is malformed or not a soft prompt ZIP file.")
logger.warning(f"Softprompt {file} is malformed or not a soft prompt ZIP file.")
continue
if z == 2:
print(f"Browser SP loading error: {file} tensor.npy has unsupported dtype '{dtype.name}'.")
logger.warning(f"Softprompt {file} tensor.npy has unsupported dtype '{dtype.name}'.")
continue
if z == 3:
print(f"Browser SP loading error: {file} tensor.npy has model dimension {shape[1]} which does not match your model's model dimension of {model_dimension}. This usually means this soft prompt is not compatible with your model.")
logger.debug(f"Softprompt {file} tensor.npy has model dimension {shape[1]} which does not match your model's model dimension of {model_dimension}. This usually means this soft prompt is not compatible with your model.")
continue
if z == 4:
print(f"Browser SP loading error: {file} tensor.npy has {shape[0]} tokens but it is supposed to have less than 2048 tokens.")
logger.warning(f"Softprompt {file} tensor.npy has {shape[0]} tokens but it is supposed to have less than 2048 tokens.")
continue
assert isinstance(z, zipfile.ZipFile)
try:

View File

@ -241,17 +241,6 @@ gensettingstf = [
"default": 0,
"tooltip": "Causes generation to be fully deterministic -- the model will always output the same thing as long as your story, settings and RNG seed are the same. If this is off, only the sequence of outputs that the model makes will be deterministic."
},
{
"uitype": "toggle",
"unit": "bool",
"label": "Debug",
"id": "debug",
"min": 0,
"max": 1,
"step": 1,
"default": 0,
"tooltip": "Show debug info"
},
{
"uitype": "toggle",
"unit": "bool",
@ -285,6 +274,17 @@ gensettingstf = [
"default": 0,
"tooltip": "Shows token usage when typing in relevant text boxes. <b>May lag slower devices.</b>"
},
{
"uitype": "toggle",
"unit": "bool",
"label": "Debug",
"id": "debug",
"min": 0,
"max": 1,
"step": 1,
"default": 0,
"tooltip": "Show debug info"
},
]
gensettingsik =[{

99
logger.py Normal file
View File

@ -0,0 +1,99 @@
import sys
from functools import partialmethod
from loguru import logger
STDOUT_LEVELS = ["GENERATION", "PROMPT"]
INIT_LEVELS = ["INIT", "INIT_OK", "INIT_WARN", "INIT_ERR"]
MESSAGE_LEVELS = ["MESSAGE"]
# By default we're at error level or higher
verbosity = 20
quiet = 0
def set_logger_verbosity(count):
global verbosity
# The count comes reversed. So count = 0 means minimum verbosity
# While count 5 means maximum verbosity
# So the more count we have, the lowe we drop the versbosity maximum
verbosity = 20 - (count * 10)
def quiesce_logger(count):
global quiet
# The bigger the count, the more silent we want our logger
quiet = count * 10
def is_stdout_log(record):
if record["level"].name not in STDOUT_LEVELS:
return(False)
if record["level"].no < verbosity + quiet:
return(False)
return(True)
def is_init_log(record):
if record["level"].name not in INIT_LEVELS:
return(False)
if record["level"].no < verbosity + quiet:
return(False)
return(True)
def is_msg_log(record):
if record["level"].name not in MESSAGE_LEVELS:
return(False)
if record["level"].no < verbosity + quiet:
return(False)
return(True)
def is_stderr_log(record):
if record["level"].name in STDOUT_LEVELS + INIT_LEVELS + MESSAGE_LEVELS:
return(False)
if record["level"].no < verbosity + quiet:
return(False)
return(True)
def test_logger():
logger.generation("This is a generation message\nIt is typically multiline\nThee Lines".encode("unicode_escape").decode("utf-8"))
logger.prompt("This is a prompt message")
logger.debug("Debug Message")
logger.info("Info Message")
logger.warning("Info Warning")
logger.error("Error Message")
logger.critical("Critical Message")
logger.init("This is an init message", status="Starting")
logger.init_ok("This is an init message", status="OK")
logger.init_warn("This is an init message", status="Warning")
logger.init_err("This is an init message", status="Error")
logger.message("This is user message")
sys.exit()
logfmt = "<level>{level: <10}</level> | <green>{name}</green>:<green>{function}</green>:<green>{line}</green> - <level>{message}</level>"
genfmt = "<level>{level: <10}</level> @ <green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{message}</level>"
initfmt = "<magenta>INIT </magenta> | <level>{extra[status]: <10}</level> | <magenta>{message}</magenta>"
msgfmt = "<level>{level: <10}</level> | <level>{message}</level>"
logger.level("GENERATION", no=24, color="<cyan>")
logger.level("PROMPT", no=23, color="<yellow>")
logger.level("INIT", no=31, color="<white>")
logger.level("INIT_OK", no=31, color="<green>")
logger.level("INIT_WARN", no=31, color="<yellow>")
logger.level("INIT_ERR", no=31, color="<red>")
# Messages contain important information without which this application might not be able to be used
# As such, they have the highest priority
logger.level("MESSAGE", no=61, color="<green>")
logger.__class__.generation = partialmethod(logger.__class__.log, "GENERATION")
logger.__class__.prompt = partialmethod(logger.__class__.log, "PROMPT")
logger.__class__.init = partialmethod(logger.__class__.log, "INIT")
logger.__class__.init_ok = partialmethod(logger.__class__.log, "INIT_OK")
logger.__class__.init_warn = partialmethod(logger.__class__.log, "INIT_WARN")
logger.__class__.init_err = partialmethod(logger.__class__.log, "INIT_ERR")
logger.__class__.message = partialmethod(logger.__class__.log, "MESSAGE")
config = {
"handlers": [
{"sink": sys.stderr, "format": logfmt, "colorize":True, "filter": is_stderr_log},
{"sink": sys.stdout, "format": genfmt, "level": "PROMPT", "colorize":True, "filter": is_stdout_log},
{"sink": sys.stdout, "format": initfmt, "level": "INIT", "colorize":True, "filter": is_init_log},
{"sink": sys.stdout, "format": msgfmt, "level": "MESSAGE", "colorize":True, "filter": is_msg_log}
],
}
logger.configure(**config)

View File

@ -50,30 +50,35 @@ Each edition features different models and requires different hardware to run, t
## [TPU Edition Model Descriptions](https://colab.research.google.com/github/KoboldAI/KoboldAI-Client/blob/main/colab/TPU.ipynb)
| Model | Size | Style | Description |
| --- | --- | --- | --- |
| [Nerys](https://huggingface.co/KoboldAI/fairseq-dense-13B-Nerys) by Mr Seeker | 13B | Novel/Adventure | Nerys is a hybrid model based on Pike (A newer Janeway), on top of the Pike dataset you also get some Light Novels, Adventure mode support and a little bit of Shinen thrown in the mix. The end result is a very diverse model that is heavily biased towards SFW novel writing, but one that can go beyond its novel training and make for an excellent adventure model to. Adventure mode is best played from a second person perspective, but can be played in first or third person as well. Novel writing can be done best from the first or third person. |
| [Janeway](https://huggingface.co/KoboldAI/fairseq-dense-13B-Janeway) by Mr Seeker | 13B | Novel | Janeway is a model created from Picard's dataset combined with a brand new collection of ebooks. This model is trained on 20% more content than Picard and has been trained on literature from various genres. Although the model is mainly focussed on SFW, romantic scenes might involve a degree of nudity. |
| [Shinen](https://huggingface.co/KoboldAI/fairseq-dense-13B-Shinen) by Mr Seeker | 13B | NSFW | Shinen is an NSFW model designed to be more explicit. Trained on a variety of stories from the website Sexstories it contains many different kinks. |
| [Skein](https://huggingface.co/KoboldAI/GPT-J-6B-Skein) by VE\_FORBRYDERNE | 6B | Adventure | Skein is best used with Adventure mode enabled, it consists of a 4 times larger adventure dataset than the Adventure model making it excellent for text adventure gaming. On top of that it also consists of light novel training further expanding its knowledge and writing capabilities. It can be used with the You filter bias if you wish to write Novels with it, but dedicated Novel models can perform better for this task. |
| [Adventure](https://huggingface.co/KoboldAI/GPT-J-6B-Adventure) by VE\_FORBRYDERNE | 6B | Adventure | Adventure is a 6B model designed to mimick the behavior of AI Dungeon. It is exclusively for Adventure Mode and can take you on the epic and wackey adventures that AI Dungeon players love. It also features the many tropes of AI Dungeon as it has been trained on very similar data. It must be used in second person (You). |
| [Lit](https://huggingface.co/hakurei/lit-6B) by Haru | 6B | NSFW | Lit is a great NSFW model trained by Haru on both a large set of Literotica stories and high quality novels along with tagging support. Creating a high quality model for your NSFW stories. This model is exclusively a novel model and is best used in third person. |
| Neo(X) by EleutherAI | 20B | Generic | NeoX is the largest EleutherAI model currently available, being a generic model it is not particularly trained towards anything and can do a variety of writing, Q&A and coding tasks. 20B's performance is closely compared to the 13B models and it is worth trying both especially if you have a task that does not involve english writing. Its behavior will be similar to the GPT-J-6B model since they are trained on the same dataset but with more sensitivity towards repetition penalty and with more knowledge. |
| [Fairseq Dense](https://huggingface.co/KoboldAI/fairseq-dense-13B) | 13B | Generic | Trained by Facebook Researchers this model stems from the MOE research project within Fairseq. This particular version has been converted by us for use in KoboldAI. It is known to be on par with the larger 20B model from EleutherAI and considered as better for pop culture and language tasks. Because the model has never seen a new line (enter) it may perform worse on formatting and paragraphing. |
| [GPT-J-6B](https://huggingface.co/EleutherAI/gpt-j-6B) by EleutherAI | 6B | Generic | This model serves as the basis for most other 6B models (Some being based on Fairseq Dense instead). Being trained on the Pile and not biased towards anything in particular it is suitable for a variety of tasks such as writing, Q&A and coding tasks. You will likely get better result with larger generic models or finetuned models. |
| Model | Style | Description |
| --- | --- | --- |
| [Nerys](https://huggingface.co/KoboldAI/fairseq-dense-13B-Nerys) by Mr Seeker | Novel/Adventure | Nerys is a hybrid model based on Pike (A newer Janeway), on top of the Pike dataset you also get some Light Novels, Adventure mode support and a little bit of Shinen thrown in the mix. The end result is a very diverse model that is heavily biased towards SFW novel writing, but one that can go beyond its novel training and make for an excellent adventure model to. Adventure mode is best played from a second person perspective, but can be played in first or third person as well. Novel writing can be done best from the first or third person. |
| [Erebus](https://huggingface.co/KoboldAI/OPT-13B-Erebus) by Mr Seeker | NSFW | Erebus is our community's flagship NSFW model, being a combination of multiple large datasets that include Literotica, Shinen and erotic novels from Nerys and featuring thourough tagging support it covers the vast majority of erotic writing styles. This model is capable of replacing both the Lit and Shinen models in terms of content and style and has been well received as (one of) the best NSFW models out there. If you wish to use this model for commercial or non research usage we recommend choosing the 20B version as that one is not subject to the restrictive OPT license. |
| [Janeway](https://huggingface.co/KoboldAI/fairseq-dense-13B-Janeway) by Mr Seeker | Novel | Janeway is a model created from Picard's dataset combined with a brand new collection of ebooks. This model is trained on 20% more content than Picard and has been trained on literature from various genres. Although the model is mainly focussed on SFW, romantic scenes might involve a degree of nudity. |
| [Shinen](https://huggingface.co/KoboldAI/fairseq-dense-13B-Shinen) by Mr Seeker | NSFW | Shinen is an NSFW model trained on a variety of stories from the website Sexstories it contains many different kinks. It has been merged into the larger (and better) Erebus model. |
| [Skein](https://huggingface.co/KoboldAI/GPT-J-6B-Skein) by VE\_FORBRYDERNE | Adventure | Skein is best used with Adventure mode enabled, it consists of a 4 times larger adventure dataset than the Adventure model making it excellent for text adventure gaming. On top of that it also consists of light novel training further expanding its knowledge and writing capabilities. It can be used with the You filter bias if you wish to write Novels with it, but dedicated Novel models can perform better for this task. |
| [Adventure](https://huggingface.co/KoboldAI/GPT-J-6B-Adventure) by VE\_FORBRYDERNE | Adventure | Adventure is a 6B model designed to mimick the behavior of AI Dungeon. It is exclusively for Adventure Mode and can take you on the epic and wackey adventures that AI Dungeon players love. It also features the many tropes of AI Dungeon as it has been trained on very similar data. It must be used in second person (You). |
| [Lit](https://huggingface.co/hakurei/lit-6B) ([V2](https://huggingface.co/hakurei/litv2-6B-rev3)) by Haru | NSFW | Lit is a great NSFW model trained by Haru on both a large set of Literotica stories and high quality novels along with tagging support. Creating a high quality model for your NSFW stories. This model is exclusively a novel model and is best used in third person. |
| [OPT](https://huggingface.co/facebook/opt-13b) by Metaseq | Generic | OPT is considered one of the best base models as far as content goes, its behavior has the strengths of both GPT-Neo and Fairseq Dense. Compared to Neo duplicate and unnecessary content has been left out, while additional literature was added in similar to the Fairseq Dense model. The Fairseq Dense model however lacks the broader data that OPT does have. The biggest downfall of OPT is its license, which prohibits any commercial usage, or usage beyond research purposes. |
| [Neo(X)](https://huggingface.co/EleutherAI/gpt-neox-20b) by EleutherAI | Generic | NeoX is the largest EleutherAI model currently available, being a generic model it is not particularly trained towards anything and can do a variety of writing, Q&A and coding tasks. 20B's performance is closely compared to the 13B models and it is worth trying both especially if you have a task that does not involve english writing. Its behavior will be similar to the GPT-J-6B model since they are trained on the same dataset but with more sensitivity towards repetition penalty and with more knowledge. |
| [Fairseq Dense](https://huggingface.co/KoboldAI/fairseq-dense-13B) | Generic | Trained by Facebook Researchers this model stems from the MOE research project within Fairseq. This particular version has been converted by us for use in KoboldAI. It is known to be on par with the larger 20B model from EleutherAI and considered as better for pop culture and language tasks. Because the model has never seen a new line (enter) it may perform worse on formatting and paragraphing. Compared to other models the dataset focuses primarily on literature and contains little else. |
| [GPT-J-6B](https://huggingface.co/EleutherAI/gpt-j-6B) by EleutherAI | Generic | This model serves as the basis for most other 6B models (Some being based on Fairseq Dense instead). Being trained on the Pile and not biased towards anything in particular it is suitable for a variety of tasks such as writing, Q&A and coding tasks. You will likely get better result with larger generic models or finetuned models. |
## [GPU Edition Model Descriptions](https://colab.research.google.com/github/KoboldAI/KoboldAI-Client/blob/main/colab/GPU.ipynb)
| Model | Size | Style | Description |
| --- | --- | --- | --- |
| [Nerys 2.7B](https://huggingface.co/KoboldAI/fairseq-dense-2.7B-Nerys) by Mr Seeker | 2.7B | Novel/Adventure | Nerys is a hybrid model based on Pike (A newer Janeway), on top of the Pike dataset you also get some Light Novels, Adventure mode support and a little bit of Shinen thrown in the mix. The end result is a very diverse model that is heavily biased towards SFW novel writing, but one that can go beyond its novel training and make for an excellent adventure model to. Adventure mode is best played from a second person perspective, but can be played in first or third person as well. Novel writing can be done best from the first or third person. |
| [Janeway 2.7B](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Janeway) by Mr Seeker | 2.7B | Novel | Janeway is a model created from Picard's dataset combined with a brand new collection of ebooks. This model is trained on 20% more content than Picard and has been trained on literature from various genres. Although the model is mainly focussed on SFW, romantic scenes might involve a degree of nudity. |
| [Picard 2.7B](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Picard) by Mr Seeker | 2.7B | Novel | Picard is a model trained for SFW Novels based on Neo 2.7B. It is focused on Novel style writing without the NSFW bias. While the name suggests a sci-fi model this model is designed for Novels of a variety of genre's. It is meant to be used in KoboldAI's regular mode. |
| [AID 2.7B](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-AID) by melastacho | 2.7B | Adventure | Also know as Adventure 2.7B this is a clone of the AI Dungeon Classic model and is best known for the epic wackey adventures that AI Dungeon Classic players love. |
| [Horni LN 2.7B](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Horni-LN) by finetune | 2.7B | Novel | This model is based on Horni 2.7B and retains its NSFW knowledge, but was then further biased towards SFW novel stories. If you seek a balance between a SFW Novel model and a NSFW model this model should be a good choice. |
| [Horni 2.7B](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Horni) by finetune | 2.7B | NSFW | This model is tuned on Literotica to produce a Novel style model biased towards NSFW content. Can still be used for SFW stories but will have a bias towards NSFW content. It is meant to be used in KoboldAI's regular mode. |
| [Shinen 2.7B ](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Shinen) by Mr Seeker | 2.7B | NSFW | Shinen is an alternative to the Horni model designed to be more explicit. If Horni is to tame for you Shinen might produce better results. While it is a Novel model it is unsuitable for SFW stories due to its heavy NSFW bias. Shinen will not hold back. It is meant to be used in KoboldAI's regular mode. |
| [Neo 2.7B](https://huggingface.co/EleutherAI/gpt-neo-2.7B) by EleutherAI | 2.7B | Generic | This is the base model for all the other 2.7B models, it is best used when you have a use case that we have no other models available for, such as writing blog articles or programming. It can also be a good basis for the experience of some of the softprompts if your softprompt is not about a subject the other models cover. |
| Model | Style | Description |
| --- | --- | --- |
| [Nerys](https://huggingface.co/KoboldAI/fairseq-dense-2.7B-Nerys) by Mr Seeker | Novel/Adventure | Nerys is a hybrid model based on Pike (A newer Janeway), on top of the Pike dataset you also get some Light Novels, Adventure mode support and a little bit of Shinen thrown in the mix. The end result is a very diverse model that is heavily biased towards SFW novel writing, but one that can go beyond its novel training and make for an excellent adventure model to. Adventure mode is best played from a second person perspective, but can be played in first or third person as well. Novel writing can be done best from the first or third person. |
| [Erebus](https://huggingface.co/KoboldAI/OPT-2.7B-Erebus) by Mr Seeker | NSFW | Erebus is our community's flagship NSFW model, being a combination of multiple large datasets that include Literotica, Shinen and erotic novels from Nerys and featuring thourough tagging support it covers the vast majority of erotic writing styles. This model is capable of replacing both the Lit and Shinen models in terms of content and style and has been well received as (one of) the best NSFW models out there. If you wish to use this model for commercial or non research usage we recommend choosing the 20B version as that one is not subject to the restrictive OPT license. |
| [Janeway](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Janeway) by Mr Seeker | Novel | Janeway is a model created from Picard's dataset combined with a brand new collection of ebooks. This model is trained on 20% more content than Picard and has been trained on literature from various genres. Although the model is mainly focussed on SFW, romantic scenes might involve a degree of nudity. |
| [Picard](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Picard) by Mr Seeker | Novel | Picard is a model trained for SFW Novels based on Neo 2.7B. It is focused on Novel style writing without the NSFW bias. While the name suggests a sci-fi model this model is designed for Novels of a variety of genre's. It is meant to be used in KoboldAI's regular mode. |
| [AID](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-AID) by melastacho | Adventure | Also know as Adventure 2.7B this is a clone of the AI Dungeon Classic model and is best known for the epic wackey adventures that AI Dungeon Classic players love. |
| [Horni LN](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Horni-LN) by finetune | Novel | This model is based on Horni 2.7B and retains its NSFW knowledge, but was then further biased towards SFW novel stories. If you seek a balance between a SFW Novel model and a NSFW model this model should be a good choice. |
| [Horni](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Horni) by finetune | NSFW | This model is tuned on Literotica to produce a Novel style model biased towards NSFW content. Can still be used for SFW stories but will have a bias towards NSFW content. It is meant to be used in KoboldAI's regular mode. |
| [Shinen](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Shinen) by Mr Seeker | NSFW | Shinen is an alternative to the Horni model designed to be more explicit. If Horni is to tame for you Shinen might produce better results. While it is a Novel model it is unsuitable for SFW stories due to its heavy NSFW bias. Shinen will not hold back. It is meant to be used in KoboldAI's regular mode. |
| [OPT](https://huggingface.co/facebook/opt-2.7b) by Metaseq | Generic | OPT is considered one of the best base models as far as content goes, its behavior has the strengths of both GPT-Neo and Fairseq Dense. Compared to Neo duplicate and unnecessary content has been left out, while additional literature was added in similar to the Fairseq Dense model. The Fairseq Dense model however lacks the broader data that OPT does have. The biggest downfall of OPT is its license, which prohibits any commercial usage, or usage beyond research purposes. |
| [Fairseq Dense](https://huggingface.co/KoboldAI/fairseq-dense-2.7B) | Generic | Trained by Facebook Researchers this model stems from the MOE research project within Fairseq. This particular version has been converted by us for use in KoboldAI. It is known to be on par with the larger models from EleutherAI and considered as better for pop culture and language tasks. Because the model has never seen a new line (enter) it may perform worse on formatting and paragraphing. Compared to other models the dataset focuses primarily on literature and contains little else. |
| [Neo](https://huggingface.co/EleutherAI/gpt-neo-2.7B) by EleutherAI | Generic | This is the base model for all the other 2.7B models, it is best used when you have a use case that we have no other models available for, such as writing blog articles or programming. It can also be a good basis for the experience of some of the softprompts if your softprompt is not about a subject the other models cover. |
### Styles
@ -192,14 +197,21 @@ Lastly the all the features of our userscript API are documented inside the API
For our TPU versions keep in mind that scripts modifying AI behavior relies on a different way of processing that is slower than if you leave these userscripts disabled even if your script only sporadically uses this modifier. If you want to partially use a script at its full speed than you can enable "No Gen Modifiers" to ensure that the parts that would make the TPU slow are not active.
## API
KoboldAI has a REST API that can be accessed by adding /api to the URL that Kobold provides you (For example http://127.0.0.1:5000/api).
When accessing this link in a browser you will be taken to the interactive documentation.
## Contributors
This project contains work from the following contributors :
* The Gantian - Creator of KoboldAI, has created most features such as the interface, the different AI model / API integrations and in general the largest part of the project.
* VE FORBRYDERNE - Contributed many features such as the Editing overhaul, Adventure Mode, expansions to the world info section, breakmodel integration, scripting support, softpromtps and much more. As well as vastly improving the TPU compatibility and integrating external code into KoboldAI so we could use official versions of Transformers with virtually no downsides.
* VE FORBRYDERNE - Contributed many features such as the Editing overhaul, Adventure Mode, expansions to the world info section, breakmodel integration, scripting support, API, softpromtps and much more. As well as vastly improving the TPU compatibility and integrating external code into KoboldAI so we could use official versions of Transformers with virtually no downsides.
* Henk717 - Contributed the installation scripts, this readme, random story generator, the docker scripts, the foundation for the commandline interface and other smaller changes as well as integrating multiple parts of the code of different forks to unite it all. He also optimized the model loading so that downloaded models get converted to efficient offline models and that in future models are more likely to work out of the box. Not all code Github attributes to Henk717 is by Henk717 as some of it has been integrations of other people's work. We try to clarify this in the contributors list as much as we can.
* Ebolam - Automatic Saving, back/redo, pinning, web loading of models
* one-some, Logits Viewer and Token Streaming
* db0, KoboldAI Horde
* Frogging101 - top\_k / tfs support (Part of this support was later redone by VE to integrate what was originally inside of finetuneanon's transformers)
* UWUplus (Ralf) - Contributed storage systems for community colabs, as well as cleaning up and integrating the website dependencies/code better. He is also the maintainer of flask-cloudflared which we use to generate the cloudflare links.
* Javalar - Initial Performance increases on the story\_refresh
@ -216,4 +228,4 @@ Did we miss your contribution? Feel free to issue a commit adding your name to t
KoboldAI is licensed with a AGPL license, in short this means that it can be used by anyone for any purpose. However, if you decide to make a publicly available instance your users are entitled to a copy of the source code including all modifications that you have made (which needs to be available trough an interface such as a button on your website), you may also not distribute this project in a form that does not contain the source code (Such as compiling / encrypting the code and distributing this version without also distributing the source code that includes the changes that you made. You are allowed to distribute this in a closed form if you also provide a separate archive with the source code.).
umamba.exe is bundled for convenience because we observed that many of our users had trouble with command line download methods, it is not part of our project and does not fall under the AGPL license. It is licensed under the BSD-3-Clause license. Other files with differing licenses will have a reference or embedded version of this license within the file.
umamba.exe is bundled for convenience because we observed that many of our users had trouble with command line download methods, it is not part of our project and does not fall under the AGPL license. It is licensed under the BSD-3-Clause license. Other files with differing licenses will have a reference or embedded version of this license within the file. It has been sourced from https://anaconda.org/conda-forge/micromamba/files and its source code can be found here : https://github.com/mamba-org/mamba/tree/master/micromamba

View File

@ -2,7 +2,7 @@ transformers>=4.20.1
Flask
Flask-SocketIO
requests
torch==1.11
torch >= 1.9, < 1.13
flask-cloudflared
flask-ngrok
eventlet
@ -15,3 +15,4 @@ accelerate
flask-session
marshmallow>=3.13
apispec-webframeworks
loguru

View File

@ -1,12 +1,11 @@
torch >= 1.9, <= 1.11
torch >= 1.9, < 1.13
numpy
tqdm
requests
optax >= 0.0.5, <= 0.0.9
dm-haiku == 0.0.5
jax == 0.2.21
jaxlib >= 0.1.69, <= 0.3.7
transformers >= 4.19
transformers >= 4.20.1
progressbar2
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
flask
@ -20,3 +19,4 @@ bleach==4.1.0
flask-session
marshmallow>=3.13
apispec-webframeworks
loguru

View File

@ -107,6 +107,9 @@ var modelname = null;
var model = "";
var ignore_stream = false;
//timer for loading CLUSTER models
var online_model_timmer;
// This is true iff [we're in macOS and the browser is Safari] or [we're in iOS]
var using_webkit_patch = true;
@ -1003,7 +1006,7 @@ function hideSaveAsPopup() {
}
function sendSaveAsRequest() {
socket.send({'cmd': 'saveasrequest', 'data': {"name": saveasinput.val(), "pins": savepins.val()}});
socket.send({'cmd': 'saveasrequest', 'data': {"name": saveasinput.val(), "pins": savepins.prop('checked')}});
}
function showLoadModelPopup() {
@ -1643,26 +1646,29 @@ function chunkOnBeforeInput(event) {
if(buildChunkSetFromNodeArray(getSelectedNodes()).size === 0) {
var s = rangy.getSelection();
var r = s.getRangeAt(0);
var rand = Math.random();
if(document.queryCommandSupported && document.execCommand && document.queryCommandSupported('insertHTML')) {
document.execCommand('insertHTML', false, '<span id="_EDITOR_SENTINEL_">|</span>');
document.execCommand('insertHTML', false, '<span id="_EDITOR_SENTINEL_' + rand + '_">|</span>');
} else {
var t = document.createTextNode('|');
var b = document.createElement('span');
b.id = "_EDITOR_SENTINEL_";
b.id = "_EDITOR_SENTINEL_" + rand + "_";
b.insertNode(t);
r.insertNode(b);
}
var sentinel = document.getElementById("_EDITOR_SENTINEL_");
if(sentinel.nextSibling && sentinel.nextSibling.tagName === "CHUNK") {
r.selectNodeContents(sentinel.nextSibling);
r.collapse(true);
} else if(sentinel.previousSibling && sentinel.previousSibling.tagName === "CHUNK") {
r.selectNodeContents(sentinel.previousSibling);
r.collapse(false);
}
s.removeAllRanges();
s.addRange(r);
sentinel.parentNode.removeChild(sentinel);
setTimeout(function() {
var sentinel = document.getElementById("_EDITOR_SENTINEL_" + rand + "_");
if(sentinel.nextSibling && sentinel.nextSibling.tagName === "CHUNK") {
r.selectNodeContents(sentinel.nextSibling);
r.collapse(true);
} else if(sentinel.previousSibling && sentinel.previousSibling.tagName === "CHUNK") {
r.selectNodeContents(sentinel.previousSibling);
r.collapse(false);
}
s.removeAllRanges();
s.addRange(r);
sentinel.parentNode.removeChild(sentinel);
}, 1);
}
}
@ -2708,6 +2714,9 @@ $(document).ready(function(){
} else if(msg.cmd == "updateoutputstreaming") {
// Update toggle state
$("#setoutputstreaming").prop('checked', msg.data).change();
} else if(msg.cmd == "updateshowbudget") {
// Update toggle state
$("#setshowbudget").prop('checked', msg.data).change();
} else if(msg.cmd == "updateshowprobs") {
$("#setshowprobs").prop('checked', msg.data).change();
@ -2847,17 +2856,17 @@ $(document).ready(function(){
chat_name.val(msg.data);
} else if(msg.cmd == "setlabelnumseq") {
// Update setting label with value from server
$("#setnumseqcur").html(msg.data);
$("#setnumseqcur").val(msg.data);
} else if(msg.cmd == "updatenumseq") {
// Send current max tokens value to input
$("#setnumseqcur").html(msg.data);
$("#setnumseqcur").val(msg.data);
$("#setnumseq").val(parseInt(msg.data)).trigger("change");
} else if(msg.cmd == "setlabelwidepth") {
// Update setting label with value from server
$("#setwidepthcur").html(msg.data);
$("#setwidepthcur").val(msg.data);
} else if(msg.cmd == "updatewidepth") {
// Send current max tokens value to input
$("#setwidepthcur").html(msg.data);
$("#setwidepthcur").val(msg.data);
$("#setwidepth").val(parseInt(msg.data)).trigger("change");
} else if(msg.cmd == "updateuseprompt") {
// Update toggle state
@ -2912,20 +2921,48 @@ $(document).ready(function(){
$("#oaimodel").addClass("hidden")
buildLoadModelList(msg.data, msg.menu, msg.breadcrumbs, msg.showdelete);
} else if(msg.cmd == 'selected_model_info') {
console.log(msg);
enableButtons([load_model_accept]);
$("#oaimodel").addClass("hidden")
$("#oaimodel")[0].options[0].selected = true;
if (msg.key) {
$("#modelkey").removeClass("hidden");
$("#modelkey")[0].value = msg.key_value;
if (msg.models_on_url) {
$("#modelkey")[0].oninput = function() {clearTimeout(online_model_timmer);
online_model_timmer = setTimeout(function() {
socket.send({'cmd': 'Cluster_Key_Update', 'key': document.getElementById("modelkey").value,
'url': document.getElementById("modelurl").value});
}, 1000);
}
$("#modelkey")[0].onblur = function () {socket.send({'cmd': 'Cluster_Key_Update', 'key': this.value, 'url': document.getElementById("modelurl").value});};
$("#modelurl")[0].onblur = function () {socket.send({'cmd': 'Cluster_Key_Update', 'key': document.getElementById("modelkey").value, 'url': this.value});};
} else {
$("#modelkey")[0].onblur = function () {socket.send({'cmd': 'OAI_Key_Update', 'key': $('#modelkey')[0].value});};
$("#modelurl")[0].onblur = null;
}
//if we're in the API list, disable to load button until the model is selected (after the API Key is entered)
disableButtons([load_model_accept]);
} else {
$("#modelkey").addClass("hidden");
}
console.log(msg.multi_online_models);
if (msg.multi_online_models) {
$("#oaimodel")[0].setAttribute("multiple", "");
$("#oaimodel")[0].options[0].textContent = "All"
} else {
$("#oaimodel")[0].removeAttribute("multiple");
$("#oaimodel")[0].options[0].textContent = "Select Model(s)"
}
if (msg.url) {
$("#modelurl").removeClass("hidden");
if (msg.default_url != null) {
document.getElementById("modelurl").value = msg.default_url;
}
} else {
$("#modelurl").addClass("hidden");
}
@ -3286,7 +3323,11 @@ $(document).ready(function(){
}
}
var disk_layers = $("#disk_layers").length > 0 ? $("#disk_layers")[0].value : 0;
message = {'cmd': 'load_model', 'use_gpu': $('#use_gpu')[0].checked, 'key': $('#modelkey')[0].value, 'gpu_layers': gpu_layers.slice(0, -1), 'disk_layers': disk_layers, 'url': $('#modelurl')[0].value, 'online_model': $('#oaimodel')[0].value};
models = getSelectedOptions(document.getElementById('oaimodel'));
if (models.length == 1) {
models = models[0];
}
message = {'cmd': 'load_model', 'use_gpu': $('#use_gpu')[0].checked, 'key': $('#modelkey')[0].value, 'gpu_layers': gpu_layers.slice(0, -1), 'disk_layers': disk_layers, 'url': $('#modelurl')[0].value, 'online_model': models};
socket.send(message);
loadmodelcontent.html("");
hideLoadModelPopup();
@ -3732,3 +3773,27 @@ function upload_file(file_box) {
}
}
function getSelectedOptions(element) {
// validate element
if(!element || !element.options)
return []; //or null?
// return HTML5 implementation of selectedOptions instead.
if (element.selectedOptions) {
selectedOptions = element.selectedOptions;
} else {
// you are here because your browser doesn't have the HTML5 selectedOptions
var opts = element.options;
var selectedOptions = [];
for(var i = 0; i < opts.length; i++) {
if(opts[i].selected) {
selectedOptions.push(opts[i]);
}
}
}
output = []
for (item of selectedOptions) {
output.push(item.value);
}
return output;
}

View File

@ -1,20 +0,0 @@
async function getRequest(url='') {
const response = await fetch(url, {
method: 'GET',
cache: 'no-cache'
})
}
document.addEventListener('DOMContentLoaded', function() {
let url = document.location
let route = "/flaskwebgui-keep-server-alive";
let interval_request = 3 * 1000; //sec
function keep_alive_server(){
getRequest(url + route);
}
setInterval(keep_alive_server, interval_request);
})

View File

@ -18,11 +18,8 @@
<script src="static/bootstrap.min.js"></script>
<script src="static/bootstrap-toggle.min.js"></script>
<script src="static/rangy-core.min.js"></script>
<script src="static/application.js?ver=1.18.1e"></script>
<script src="static/application.js?ver=1.18.1f"></script>
<script src="static/favicon.js"></script>
{% if flaskwebgui %}
<script src="static/flask_web_gui.js"></script>
{% endif %}
</head>
<body>
<input type="file" id="remote-save-select" accept="application/json" style="display:none">
@ -295,12 +292,12 @@
<div id="loadmodellistcontent" style="overflow: auto; height: 300px;">
</div>
<div class="popupfooter">
<input class="form-control hidden" type="text" placeholder="key" id="modelkey" onblur="socket.send({'cmd': 'OAI_Key_Update', 'key': $('#modelkey')[0].value});">
<input class="form-control hidden" type="text" placeholder="Enter the URL of the server (For example a trycloudflare link)" id="modelurl" onchange="check_enable_model_load()">
<input class="form-control hidden" type="text" placeholder="key" id="modelkey" onblur="socket.send({'cmd': 'OAI_Key_Update', 'key': $('#modelkey')[0].value});">
<input class="form-control hidden" type="text" placeholder="Model Path or Hugging Face Name" id="custommodelname" menu="" onblur="socket.send({'cmd': 'selectmodel', 'data': $(this).attr('menu'), 'path_modelname': $('#custommodelname')[0].value});">
</div>
<div class="popupfooter">
<select class="form-control hidden" id="oaimodel"><option value="">Select OAI Model</option></select>
<select class="form-control hidden" id="oaimodel"><option value="">Select Model(s)</option></select>
</div>
<div class="popupfooter hidden" id=modellayers>
<div class='settingitem' style="width:100%">

View File

@ -30,7 +30,7 @@ SOFTWARE.
import utils
import multiprocessing
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar
import progressbar
import time
import os
@ -45,9 +45,8 @@ from jax.config import config
from jax.experimental import maps
import jax.numpy as jnp
import numpy as np
import optax
import haiku as hk
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
from transformers import AutoTokenizer, GPT2Tokenizer, AutoModelForCausalLM, GPTNeoForCausalLM
from tokenizers import Tokenizer
from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
@ -136,6 +135,14 @@ def __batch_xmap(shard_dim=1):
return inner
class _EmptyState(NamedTuple):
pass
class _DummyOptimizer:
def init(*args, **kwargs):
return _EmptyState()
def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
'''
This gets called by generate_loop_fn to apply repetition penalty
@ -533,7 +540,7 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
gen_length,
rpslope,
rprange,
)
),
**sampler_options,
)
# Remember what token was picked
@ -1054,7 +1061,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
"pe_rotary_dims": 64,
"seq": 2048,
"cores_per_replica": 8,
"tokenizer_class": "GPT2TokenizerFast",
"tokenizer_class": "GPT2Tokenizer",
"tokenizer": "gpt2",
}
params = kwargs
@ -1072,7 +1079,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
"pe_rotary_dims": 24,
"seq": 2048,
"cores_per_replica": 8,
"tokenizer_class": "GPT2TokenizerFast",
"tokenizer_class": "GPT2Tokenizer",
"tokenizer": "gpt2",
}
@ -1167,7 +1174,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]
params["optimizer"] = optax.scale(0)
params["optimizer"] = _DummyOptimizer()
mesh_shape = (1, cores_per_replica)
devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
@ -1343,49 +1350,46 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
print("\n", flush=True)
with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
if(os.path.isdir(vars.custmodpth)):
try:
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
except Exception as e:
pass
try:
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache", use_fast=False)
except Exception as e:
try:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
try:
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
except Exception as e:
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
except Exception as e:
pass
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache", use_fast=False)
except Exception as e:
try:
tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
try:
tokenizer = GPT2Tokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
except Exception as e:
model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
else:
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
except Exception as e:
pass
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache", use_fast=False)
except Exception as e:
try:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
tokenizer = AutoTokenizer.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
try:
tokenizer = GPT2Tokenizer.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
except Exception as e:

397
utils.py
View File

@ -4,6 +4,7 @@ import shutil
import json
import subprocess
import tempfile
from urllib.error import HTTPError
import requests
import requests.adapters
import time
@ -13,6 +14,10 @@ import packaging.version
from tqdm.auto import tqdm
import os
import itertools
import hashlib
import huggingface_hub
import packaging.version
from pathlib import Path
from typing import List, Optional
HAS_ACCELERATE = packaging.version.parse(transformers_version) >= packaging.version.parse("4.20.0.dev0")
@ -176,91 +181,28 @@ def num_layers(config):
# Downloads huggingface checkpoints using aria2c if possible
#==================================================================#
from flask_socketio import emit
class Send_to_socketio(object):
def write(self, bar):
time.sleep(0.01)
try:
print(bar)
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True)
except:
pass
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
def _download_with_aria2(aria2_config: str, total_length: int, directory: str = ".", user_agent=None, force_download=False, use_auth_token=None):
class Send_to_socketio(object):
def write(self, bar):
bar = bar.replace("\r", "").replace("\n", "")
if bar != "":
try:
print('\r' + bar, end='')
try:
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True)
except:
pass
eventlet.sleep(seconds=0)
except:
pass
def flush(self):
pass
import transformers
import transformers.modeling_utils
from huggingface_hub import HfFolder
if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
return
if local_files_only: # If local_files_only is true, we obviously don't need to download anything
return
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
return
if proxies:
print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
return
if use_auth_token:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
aria2_port = 6799 if vars is None else vars.aria2_port
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
sharded = False
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
if use_auth_token:
headers["authorization"] = f"Bearer {use_auth_token}"
def is_cached(url):
try:
transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, local_files_only=True)
except (FileNotFoundError, transformers.file_utils.EntryNotFoundError):
return False
return True
while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
try:
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
except AttributeError:
return
url = transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, filename, revision=revision, mirror=mirror)
if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
break
if sharded:
return
else:
sharded = True
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
with open(map_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
if not force_download:
urls = [u for u in urls if not is_cached(u)]
if not urls:
return
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
for n in filenames:
path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, "kai-tempfile." + n)
if os.path.exists(path):
os.remove(path)
if force_download:
path = os.path.join(_cache_dir, n + ".json")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, n)
if os.path.exists(path):
os.remove(path)
total_length = sum(int(h["Content-Length"]) for h in headers)
lengths = {}
aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
s = requests.Session()
s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
bar = None
@ -270,7 +212,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
f.write(aria2_config)
f.flush()
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", directory, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {use_auth_token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
while p.poll() is None:
r = s.post(f"http://localhost:{aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
if not r:
@ -306,6 +248,291 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
code = p.wait()
if not done and code:
raise OSError(f"aria2 exited with exit code {code}")
def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
import transformers
import transformers.modeling_utils
from huggingface_hub import HfFolder
if use_auth_token:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
_revision = revision if revision is not None else huggingface_hub.constants.DEFAULT_REVISION
sharded = False
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
if use_auth_token:
headers["authorization"] = f"Bearer {use_auth_token}"
storage_folder = os.path.join(_cache_dir, huggingface_hub.file_download.repo_folder_name(repo_id=pretrained_model_name_or_path, repo_type="model"))
os.makedirs(storage_folder, exist_ok=True)
def is_cached(filename):
try:
huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, local_files_only=True)
except ValueError:
return False
return True
while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
try:
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
except AttributeError:
return
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
break
if sharded:
return
else:
sharded = True
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
with open(map_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
if not force_download:
urls = [u for u, n in zip(urls, filenames) if not is_cached(n)]
if not urls:
return
blob_paths = []
# This section is a modified version of hf_hub_download from huggingface_hub
# See https://github.com/huggingface/huggingface_hub/blob/main/LICENSE for license
for u, n in zip(urls, filenames):
relative_filename = os.path.join(*n.split("/"))
if not local_files_only:
try:
r = huggingface_hub.file_download._request_wrapper(
method="HEAD",
url=u,
headers=headers,
allow_redirects=False,
follow_relative_redirects=True,
proxies=proxies,
timeout=10,
)
try:
r.raise_for_status()
except HTTPError as e:
error_code = r.headers.get("X-Error-Code")
if error_code != "EntryNotFound":
raise RuntimeError(f"HEAD {u} failed with error code {r.status_code}")
commit_hash = r.headers.get(huggingface_hub.file_download.HUGGINGFACE_HEADER_X_REPO_COMMIT)
if commit_hash is not None:
no_exist_file_path = (
Path(storage_folder)
/ ".no_exist"
/ commit_hash
/ relative_filename
)
no_exist_file_path.parent.mkdir(parents=True, exist_ok=True)
no_exist_file_path.touch()
huggingface_hub.file_download._cache_commit_hash_for_specific_revision(
storage_folder, _revision, commit_hash
)
raise
commit_hash = r.headers[huggingface_hub.file_download.HUGGINGFACE_HEADER_X_REPO_COMMIT]
if commit_hash is None:
raise OSError(
"Distant resource does not seem to be on huggingface.co (missing"
" commit header)."
)
etag = r.headers.get(huggingface_hub.file_download.HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get(
"ETag"
)
# We favor a custom header indicating the etag of the linked resource, and
# we fallback to the regular etag header.
# If we don't have any of those, raise an error.
if etag is None:
raise OSError(
"Distant resource does not have an ETag, we won't be able to"
" reliably ensure reproducibility."
)
etag = huggingface_hub.file_download._normalize_etag(etag)
# In case of a redirect, save an extra redirect on the request.get call,
# and ensure we download the exact atomic version even if it changed
# between the HEAD and the GET (unlikely, but hey).
# Useful for lfs blobs that are stored on a CDN.
if 300 <= r.status_code <= 399:
url_to_download = r.headers["Location"]
if (
"lfs.huggingface.co" in url_to_download
or "lfs-staging.huggingface.co" in url_to_download
):
# Remove authorization header when downloading a LFS blob
headers.pop("authorization", None)
except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
# Actually raise for those subclasses of ConnectionError
raise
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
huggingface_hub.file_download.OfflineModeIsEnabled,
):
# Otherwise, our Internet connection is down.
# etag is None
pass
if etag is None:
# In those cases, we cannot force download.
if force_download:
raise ValueError(
"We have no connection or you passed local_files_only, so"
" force_download is not an accepted option."
)
if huggingface_hub.file_download.REGEX_COMMIT_HASH.match(_revision):
commit_hash = _revision
else:
ref_path = os.path.join(storage_folder, "refs", _revision)
with open(ref_path) as f:
commit_hash = f.read()
pointer_path = os.path.join(
storage_folder, "snapshots", commit_hash, relative_filename
)
if os.path.exists(pointer_path):
return pointer_path
# If we couldn't find an appropriate file on disk,
# raise an error.
# If files cannot be found and local_files_only=True,
# the models might've been found if local_files_only=False
# Notify the user about that
if local_files_only:
raise huggingface_hub.file_download.LocalEntryNotFoundError(
"Cannot find the requested files in the disk cache and"
" outgoing traffic has been disabled. To enable hf.co look-ups"
" and downloads online, set 'local_files_only' to False."
)
else:
raise huggingface_hub.file_download.LocalEntryNotFoundError(
"Connection error, and we cannot find the requested files in"
" the disk cache. Please try again or make sure your Internet"
" connection is on."
)
# From now on, etag and commit_hash are not None.
blob_path = os.path.join(storage_folder, "blobs", etag)
pointer_path = os.path.join(
storage_folder, "snapshots", commit_hash, relative_filename
)
os.makedirs(os.path.dirname(blob_path), exist_ok=True)
os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
# if passed revision is not identical to commit_hash
# then revision has to be a branch name or tag name.
# In that case store a ref.
huggingface_hub.file_download._cache_commit_hash_for_specific_revision(storage_folder, _revision, commit_hash)
if os.path.exists(pointer_path) and not force_download:
return pointer_path
if os.path.exists(blob_path) and not force_download:
# we have the blob already, but not the pointer
huggingface_hub.file_download.logger.info("creating pointer to %s from %s", blob_path, pointer_path)
huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
return pointer_path
# Some Windows versions do not allow for paths longer than 255 characters.
# In this case, we must specify it is an extended path by using the "\\?\" prefix.
if os.name == "nt" and len(os.path.abspath(blob_path)) > 255:
blob_path = "\\\\?\\" + os.path.abspath(blob_path)
blob_paths.append(blob_path)
filenames = blob_paths
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
for n in filenames:
prefix, suffix = n.rsplit(os.sep, 1)
path = os.path.join(prefix, "kai-tempfile." + suffix + ".aria2")
if os.path.exists(path):
os.remove(path)
path = os.path.join(prefix, "kai-tempfile." + suffix)
if os.path.exists(path):
os.remove(path)
total_length = sum(int(h["Content-Length"]) for h in headers)
aria2_config = "\n".join(f"{u}\n out={os.path.join(prefix, 'kai-tempfile.' + suffix)}" for u, n in zip(urls, filenames) for prefix, suffix in [n.rsplit(os.sep, 1)]).encode()
_download_with_aria2(aria2_config, total_length, use_auth_token=token if use_auth_token else None, user_agent=user_agent, force_download=force_download)
for u, n in zip(urls, filenames):
prefix, suffix = n.rsplit(os.sep, 1)
os.rename(os.path.join(prefix, "kai-tempfile." + suffix), os.path.join(prefix, suffix))
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
import transformers
import transformers.modeling_utils
from huggingface_hub import HfFolder
if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
return
if local_files_only: # If local_files_only is true, we obviously don't need to download anything
return
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
return
if proxies:
print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
return
if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.22.0.dev0"):
return _transformers22_aria2_hook(pretrained_model_name_or_path, force_download=force_download, cache_dir=cache_dir, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, **kwargs)
if use_auth_token:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
sharded = False
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
if use_auth_token:
headers["authorization"] = f"Bearer {use_auth_token}"
def is_cached(url):
try:
huggingface_hub.cached_download(url, cache_dir=cache_dir, local_files_only=True)
except ValueError:
return False
return True
while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
try:
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
except AttributeError:
return
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
break
if sharded:
return
else:
sharded = True
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
map_filename = huggingface_hub.cached_download(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
with open(map_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
if not force_download:
urls = [u for u in urls if not is_cached(u)]
if not urls:
return
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
filenames = [hashlib.sha256(u.encode("utf-8")).hexdigest() + "." + hashlib.sha256(t.encode("utf-8")).hexdigest() for u, t in zip(urls, etags)]
for n in filenames:
path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, "kai-tempfile." + n)
if os.path.exists(path):
os.remove(path)
if force_download:
path = os.path.join(_cache_dir, n + ".json")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, n)
if os.path.exists(path):
os.remove(path)
total_length = sum(int(h["Content-Length"]) for h in headers)
aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
_download_with_aria2(aria2_config, total_length, directory=_cache_dir, use_auth_token=token if use_auth_token else None, user_agent=user_agent, force_download=force_download)
for u, t, n in zip(urls, etags, filenames):
os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
@ -325,10 +552,10 @@ def get_num_shards(filename):
# pytorch_model.bin.index.json, returns a list of weight names in the
# sharded model. Requires lazy loader to be enabled to work properl
#==================================================================#
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
import transformers.modeling_utils
import torch
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
#==================================================================#
@ -379,4 +606,4 @@ def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[s
else:
recurse(c[1], head=name + ".")
recurse(model)
return missing_names
return missing_names