aiserver.py
@@ -2804,6 +2804,36 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         global breakmodel
         import breakmodel
+    elif koboldai_vars.model in ["Colab", "API", "CLUSTER", "OAI"]:
+        # If we're running Colab or OAI, we still need a tokenizer.
+        if koboldai_vars.model == "API":
+            tokenizer_id = requests.get(
+                koboldai_vars.colaburl[:-8] + "/api/v1/model",
+            ).json()["result"]
+        else:
+            tokenizer_id = {
+                "Colab": "EleutherAI/gpt-neo-2.7B",
+                "CLUSTER": koboldai_vars.cluster_requested_models[0],
+                "OAI": "gpt2",
+            }[koboldai_vars.model]
+
+        # TODO: This should probably be a bit more robust of a check.
+        koboldai_vars.newlinemode = "n"
+        if "xglm" in tokenizer_id:
+            # Default to </s> newline mode if using XGLM
+            koboldai_vars.newlinemode = "s"
+        if "opt" in tokenizer_id or "bloom" in tokenizer_id:
+            # Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
+            koboldai_vars.newlinemode = "ns"
+
+        print(tokenizer_id, koboldai_vars.newlinemode)
+
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, revision=koboldai_vars.revision, cache_dir="cache")
+
+        loadsettings()
+        koboldai_vars.colaburl = url or koboldai_vars.colaburl
+        koboldai_vars.usegpu = False
+        koboldai_vars.breakmodel = False
     elif (not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["InferKit", "Colab", "API", "CLUSTER", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
         if(not koboldai_vars.noai):
             logger.init("Transformers", status='Starting')
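For reference, the substring checks above reduce to a small pure function. The sketch below is not part of the commit (the helper name and test IDs are illustrative); it isolates the same newline-mode selection so it can be exercised on a few tokenizer IDs:

    # Sketch only: mirrors the substring-based newline-mode selection above.
    # Per the comments in the diff: "n" keeps plain newlines, "s" defaults to
    # </s> newline mode (XGLM), "ns" handles </s> without converting newlines.
    def pick_newlinemode(tokenizer_id: str) -> str:
        mode = "n"
        if "xglm" in tokenizer_id:
            mode = "s"
        if "opt" in tokenizer_id or "bloom" in tokenizer_id:
            mode = "ns"
        return mode

    assert pick_newlinemode("EleutherAI/gpt-neo-2.7B") == "n"
    assert pick_newlinemode("facebook/xglm-4.5B") == "s"
    assert pick_newlinemode("facebook/opt-2.7b") == "ns"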
@@ -3273,19 +3303,8 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
             "rprange": int(koboldai_vars.rep_pen_range),
         }

-    # If we're running Colab or OAI, we still need a tokenizer.
-    if(koboldai_vars.model in ("Colab", "API", "CLUSTER")):
-        from transformers import GPT2Tokenizer
-        tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B", revision=koboldai_vars.revision, cache_dir="cache")
-        loadsettings()
-        koboldai_vars.colaburl = url if url is not None else koboldai_vars.colaburl
-    elif(koboldai_vars.model == "OAI"):
-        from transformers import GPT2Tokenizer
-        tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=koboldai_vars.revision, cache_dir="cache")
-        loadsettings()
-        koboldai_vars.colaburl = url if url is not None else koboldai_vars.colaburl
     # Load the TPU backend if requested
-    elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
+    if (koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
         global tpu_mtj_backend
         import tpu_mtj_backend

@@ -5837,7 +5856,7 @@ def cluster_raw_generate(
         'rep_pen_range': gen_settings.rep_pen_range,
         'temperature': gen_settings.temp,
         'top_p': gen_settings.top_p,
-        'top_k': gen_settings.top_k,
+        'top_k': int(gen_settings.top_k),
         'top_a': gen_settings.top_a,
         'tfs': gen_settings.tfs,
         'typical': gen_settings.typical,
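The only change in this hunk is the int() cast on top_k. A plausible motivation (not stated in the diff, values hypothetical): a top_k held as a float, e.g. from a UI slider, serializes as 40.0, which an integer-typed API field can reject.

    # Illustrative only: how the cast changes the serialized payload.
    import json

    top_k = 40.0                                 # hypothetical slider value
    print(json.dumps({"top_k": top_k}))          # {"top_k": 40.0}
    print(json.dumps({"top_k": int(top_k)}))     # {"top_k": 40}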
@@ -5872,6 +5891,8 @@ def cluster_raw_generate(
     elif not req.ok:
         errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
         logger.error(errmsg)
+        logger.error(f"HTTP {req.status_code}!!!")
+        logger.error(req.text)
         raise HordeException(errmsg)

     try:
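The "log the status, log the body, raise a HordeException" sequence recurs several times in this commit, including three times in the image hunk below. A hedged refactoring sketch (the helper and the stand-in definitions are illustrative, not part of the codebase):

    # Sketch only: one possible shared helper for the repeated Horde error handling.
    import logging

    logger = logging.getLogger(__name__)

    class HordeException(Exception):
        """Stand-in for the HordeException used in aiserver.py."""

    def raise_horde_error(req, message: str) -> None:
        # Log the failing response in full, then surface one user-facing error.
        logger.error(f"HTTP {req.status_code}, expected OK-ish")
        logger.error(req.text)
        logger.error(f"Response headers: {req.headers}")
        raise HordeException(message)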
@@ -9699,16 +9720,43 @@ def text2img_horde(prompt: str) -> Optional[Image.Image]:
         }
     }

-    cluster_headers = {'apikey': koboldai_vars.sh_apikey if koboldai_vars.sh_apikey != '' else "0000000000",}
+    cluster_headers = {"apikey": koboldai_vars.sh_apikey or "0000000000"}
+    id_req = requests.post("https://stablehorde.net/api/v2/generate/async", json=final_submit_dict, headers=cluster_headers)
+
+    if not id_req.ok:
+        logger.error(f"HTTP {id_req.status_code}, expected OK-ish")
+        logger.error(id_req.text)
+        logger.error(f"Response headers: {id_req.headers}")
+        raise HordeException("Image seeding failed. See console for more details.")

     logger.debug(final_submit_dict)
-    submit_req = requests.post('https://stablehorde.net/api/v2/generate/sync', json=final_submit_dict, headers=cluster_headers)
+    image_id = id_req.json()["id"]

-    if not submit_req.ok:
-        logger.error(submit_req.text)
-        return
+    while True:
+        poll_req = requests.get(f"https://stablehorde.net/api/v2/generate/check/{image_id}")
+        if not poll_req.ok:
+            logger.error(f"HTTP {poll_req.status_code}, expected OK-ish")
+            logger.error(poll_req.text)
+            logger.error(f"Response headers: {poll_req.headers}")
+            raise HordeException("Image polling failed. See console for more details.")
+        poll_j = poll_req.json()
+
+        if poll_j["finished"] > 0:
+            break
+
+        # This should always exist but if it doesn't 2 seems like a safe bet.
+        sleepy_time = int(poll_req.headers.get("retry-after", 2))
+        time.sleep(sleepy_time)
+
+    # Done generating, we can now fetch it.
+    gen_req = requests.get(f"https://stablehorde.net/api/v2/generate/status/{image_id}")
+    if not gen_req.ok:
+        logger.error(f"HTTP {gen_req.status_code}, expected OK-ish")
+        logger.error(gen_req.text)
+        logger.error(f"Response headers: {gen_req.headers}")
+        raise HordeException("Image fetching failed. See console for more details.")
+    results = gen_req.json()

-    results = submit_req.json()
     if len(results["generations"]) > 1:
         logger.warning(f"Got too many generations, discarding extras. Got {len(results['generations'])}, expected 1.")