Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Model: WIP horde and API tests
@@ -23,10 +23,12 @@ class APIException(Exception):
 class APIInferenceModel(InferenceModel):
+    def __init__(self, base_url: str = "http://localhost:5000") -> None:
+        super().__init__()
+        self.base_url = base_url
+
     def _load(self, save_model: bool, initial_load: bool) -> None:
-        tokenizer_id = requests.get(
-            utils.koboldai_vars.colaburl[:-8] + "/api/v1/model",
-        ).json()["result"]
+        tokenizer_id = requests.get(f"{self.base_url}/api/v1/model").json()["result"]
         self.tokenizer = self._get_tokenizer(tokenizer_id)
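For context, a minimal sketch of what fetching the served model's identifier from a standalone KoboldAI API server might look like. The endpoint and the "result" key come from the hunk above; the timeout and error handling are illustrative assumptions, not part of this commit.

import requests

# Ask a running KoboldAI API server which model it is serving.
# Endpoint and "result" key are taken from the diff above; timeout and
# raise_for_status are added assumptions.
def fetch_remote_model_id(base_url: str = "http://localhost:5000") -> str:
    resp = requests.get(f"{base_url}/api/v1/model", timeout=10)
    resp.raise_for_status()
    return resp.json()["result"]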
@@ -73,13 +75,10 @@ class APIInferenceModel(InferenceModel):
         # Create request
         while True:
-            req = requests.post(
-                utils.koboldai_vars.colaburl[:-8] + "/api/v1/generate",
-                json=reqdata,
-            )
-            if (
-                req.status_code == 503
-            ):  # Server is currently generating something else so poll until it's our turn
+            req = requests.post(f"{self.base_url}/api/v1/generate", json=reqdata)
+            if req.status_code == 503:
+                # Server is currently generating something else so poll until it's our turn
                 time.sleep(1)
                 continue
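The new version collapses the request into one call while keeping the poll-until-free behaviour. A standalone sketch of that pattern, with the endpoint and 503 handling taken from the hunk above and the payload, timeout, and error handling assumed:

import time
import requests

# Poll-until-free pattern: a 503 means the server is busy generating for
# someone else, so wait a second and retry. reqdata fields and the timeout
# are illustrative assumptions.
def post_generate(base_url: str, reqdata: dict) -> dict:
    while True:
        req = requests.post(f"{base_url}/api/v1/generate", json=reqdata, timeout=60)
        if req.status_code == 503:
            time.sleep(1)
            continue
        req.raise_for_status()
        return req.json()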
@@ -471,7 +471,6 @@ class HFTorchInferenceModel(HFInferenceModel):
         additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
 
         if seed is not None:
-            print("seeding", seed)
             torch.manual_seed(seed)
 
         with torch.no_grad():
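Dropping the debug print leaves only the seeding call. For reference, seeding the global torch RNG before sampling is what makes two runs with the same seed reproduce the same tokens; a toy, model-free sketch of that property (entirely illustrative, not code from this commit):

import torch

# Toy illustration: with the same seed, the same multinomial draws come out,
# which is the determinism the tests later in this commit assert on.
def sample_ids(seed: int, steps: int = 5) -> list:
    torch.manual_seed(seed)
    probs = torch.softmax(torch.randn(50), dim=-1)
    return [int(torch.multinomial(probs, 1)) for _ in range(steps)]

assert sample_ids(1337) == sample_ids(1337)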
@@ -4,7 +4,7 @@ import time
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union
 
 import utils
 from logger import logger
@@ -87,7 +87,7 @@ class HordeInferenceModel(InferenceModel):
         try:
             # Create request
             req = requests.post(
-                utils.koboldai_vars.colaburl[:-8] + "/api/v2/generate/text/async",
+                f"{utils.koboldai_vars.horde_url}/api/v2/generate/text/async",
                 json=cluster_metadata,
                 headers=cluster_headers,
             )
@@ -102,8 +102,8 @@ class HordeInferenceModel(InferenceModel):
             raise HordeException(errmsg)
         elif not req.ok:
             errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
+            logger.error(req.url)
             logger.error(errmsg)
-            logger.error(f"HTTP {req.status_code}!!!")
             logger.error(req.text)
             raise HordeException(errmsg)
@@ -125,7 +125,7 @@ class HordeInferenceModel(InferenceModel):
         while not finished:
             try:
                 req = requests.get(
-                    f"{utils.koboldai_vars.colaburl[:-8]}/api/v2/generate/text/status/{request_id}",
+                    f"{utils.koboldai_vars.horde_url}/api/v2/generate/text/status/{request_id}",
                     headers=cluster_agent_headers,
                 )
             except requests.exceptions.ConnectionError:
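Taken together, the two URL changes point both the async submit and the status poll at koboldai_vars.horde_url instead of the colab URL. A rough, self-contained sketch of that submit-then-poll flow follows; the endpoint paths match the hunks above, while the payload fields, header names, and response keys are assumptions about the AI Horde API rather than something this commit defines.

import time
import requests

# Rough sketch of the horde flow: submit an async text generation job, then
# poll its status until it reports done. Field and header names below are
# assumptions for illustration only.
def horde_generate(horde_url: str, prompt: str, api_key: str = "0000000000") -> dict:
    submit = requests.post(
        f"{horde_url}/api/v2/generate/text/async",
        json={"prompt": prompt, "params": {"max_length": 20}},
        headers={"apikey": api_key},
    )
    submit.raise_for_status()
    request_id = submit.json()["id"]

    while True:
        status = requests.get(
            f"{horde_url}/api/v2/generate/text/status/{request_id}",
            headers={"apikey": api_key},
        )
        status.raise_for_status()
        data = status.json()
        if data.get("done"):
            return data
        time.sleep(1)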
@@ -1,8 +1,12 @@
-# We have to go through aiserver to initalize koboldai_vars :(
 import torch
 
+# We have to go through aiserver to initalize koboldai_vars :(
 from aiserver import GenericHFTorchInferenceModel
 from aiserver import koboldai_vars
 
 from modeling.inference_model import InferenceModel
+from modeling.inference_models.api import APIInferenceModel
+from modeling.inference_models.horde import HordeInferenceModel
 
 model: InferenceModel
@@ -12,6 +16,7 @@ TEST_PROMPT = "Once upon a time I found myself"
 TEST_GEN_TOKEN_COUNT = 20
 TEST_SEED = 1337
 
+# HF Torch
 
 def test_generic_hf_torch_load() -> None:
     global model
@@ -29,11 +34,11 @@ def test_generic_hf_torch_low_mem_load() -> None:
     GenericHFTorchInferenceModel(TEST_MODEL_HF_ID, lazy_load=False, low_mem=True).load()
 
 
-def test_model_gen() -> None:
+def test_torch_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)
     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
-    assert len(x.encoded[0]) == 20, "Wrong token amount (requested 20)"
+    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
 
     y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
@@ -42,3 +47,28 @@ def test_model_gen() -> None:
     ), f"Faulty full determinism! {x.decoded} vs {y.decoded}"
 
     print(x)
+
+
+# Horde
+def test_horde_load() -> None:
+    global model
+    # TODO: Make this a property and sync it with kaivars
+    koboldai_vars.cluster_requested_models = []
+    model = HordeInferenceModel()
+    model.load()
+
+
+def test_horde_inference() -> None:
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
+    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    print(x)
+
+
+# API
+def test_api_load() -> None:
+    global model
+    model = APIInferenceModel()
+    model.load()
+
+
+def test_api_inference() -> None:
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
+    # NOTE: Below test is flakey due to Horde worker-defined constraints
+    # assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    print(x)
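With the new Horde and API tests in place, one hedged way the selection might be driven from Python, assuming the tests live in a file such as tests/test_model.py (the diff does not show filenames, so the path is an assumption):

import pytest

# Hypothetical invocation; the test file path is assumed since the diff hides
# filenames. "-x" stops at the first failure, "-k" selects matching tests.
pytest.main(["-x", "tests/test_model.py", "-k", "horde or api"])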