Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Modeling: Add seed parameter to raw_generate
Decouples seeding from koboldai_vars: `raw_generate` now accepts an explicit `seed` argument. This makes the generation test in `test_generation.py` pass and makes full determinism work outside of `core_generate`.
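For context, a minimal sketch of the call pattern this enables — assuming `model` is an already-loaded `InferenceModel` subclass (the prompt and numbers mirror the constants used in `test_generation.py`); this is illustration, not code from the commit:

```python
# Hypothetical usage sketch: two calls with the same explicit seed
# should return identical generations, without touching koboldai_vars.
prompt = "Once upon a time I found myself"

first = model.raw_generate(prompt, max_new=20, seed=1337)
second = model.raw_generate(prompt, max_new=20, seed=1337)

assert first.decoded == second.decoded
```

The hunks below walk the change through `core_generate`, the `raw_generate` wrapper, each backend, and the test.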
```diff
@@ -246,9 +246,6 @@ class InferenceModel:
         start_time = time.time()

         if utils.koboldai_vars.is_model_torch():
             # Torch stuff
-            if utils.koboldai_vars.full_determinism:
-                torch.manual_seed(utils.koboldai_vars.seed)
-
             if utils.koboldai_vars.sp is not None:
                 assert self.capabilties.embedding_manipulation
                 soft_tokens = torch.arange(
```
```diff
@@ -256,9 +253,6 @@ class InferenceModel:
                     self.model.config.vocab_size + utils.koboldai_vars.sp.shape[0],
                 )
                 gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
-        elif utils.koboldai_vars.use_colab_tpu:
-            if utils.koboldai_vars.full_determinism:
-                tpu_mtj_backend.set_rng_seed(utils.koboldai_vars.seed)

         logger.debug(
             "core_generate: Model Setup (SP, etc) time {}s".format(
```
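The two hunks above remove the global-state seeding from `core_generate`; the responsibility moves to the backends, which now receive an explicit seed. A small, self-contained sketch of that shift (hypothetical helper, not code from the diff):

```python
import torch
from typing import Optional

def seeded_generate(seed: Optional[int]) -> torch.Tensor:
    # Replaces the old "if koboldai_vars.full_determinism: torch.manual_seed(...)"
    # pattern: the caller decides on a seed, the generator just applies it.
    if seed is not None:
        torch.manual_seed(seed)
    return torch.rand(4)  # stand-in for the actual sampling work

assert torch.equal(seeded_generate(1337), seeded_generate(1337))
```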
```diff
@@ -329,6 +323,9 @@ class InferenceModel:
                     not utils.koboldai_vars.nogenmod
                     and utils.koboldai_vars.has_genmod
                 ),
+                seed=utils.koboldai_vars.seed
+                if utils.koboldai_vars.full_determinism
+                else None,
             )
             logger.debug(
                 "core_generate: run raw_generate pass {} {}s".format(
```
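This hunk preserves the old behaviour: `core_generate` still honours the story seed, but only when full determinism is enabled, and it now does so by passing a plain argument. The selection logic boils down to something like the following hypothetical helper:

```python
from typing import Optional

def pick_seed(full_determinism: bool, story_seed: int) -> Optional[int]:
    # Mirrors the conditional expression added to core_generate:
    # only forward a seed when the user asked for deterministic output.
    return story_seed if full_determinism else None

assert pick_seed(True, 1337) == 1337
assert pick_seed(False, 1337) is None
```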
```diff
@@ -481,6 +478,7 @@ class InferenceModel:
             gen_settings (GenerationSettings): State to pass in single-generation setting overrides
             single_line (bool, optional): Generate one line only. Defaults to False.
             batch_count (int, optional): How big of a batch to generate. Defaults to 1.
+            seed (int, optional): If not None, this seed will be used to make reproducible generations. Defaults to None.

         Returns:
             GenerationResult: The model's output
```
```diff
@@ -501,6 +499,7 @@ class InferenceModel:
         single_line: bool = False,
         found_entries: set = (),
         tpu_dynamic_inference: bool = False,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
```
```diff
@@ -562,6 +561,7 @@ class InferenceModel:
             gen_settings=gen_settings,
             single_line=single_line,
             tpu_dynamic_inference=tpu_dynamic_inference,
+            seed=seed,
         )

         time_end = round(time.time() - time_start, 2)
```
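Taken together, the signature, docstring, and call-site hunks make `raw_generate` a pass-through for the new keyword: it defaults to `None` and is handed unchanged to the backend `_raw_generate`. A compact sketch of that shape, using simplified stand-ins rather than the real classes:

```python
import random
from typing import List, Optional

class TinyModel:
    def _raw_generate(self, prompt: str, max_new: int, seed: Optional[int] = None) -> List[int]:
        # The backend decides how to apply the seed; None means "leave the RNG alone".
        rng = random.Random(seed)
        return [rng.randrange(50257) for _ in range(max_new)]

    def raw_generate(self, prompt: str, max_new: int, seed: Optional[int] = None, **kwargs) -> List[int]:
        # The wrapper only forwards the keyword, like the seed=seed line in the diff.
        return self._raw_generate(prompt, max_new, seed=seed)

m = TinyModel()
assert m.raw_generate("hi", 5, seed=7) == m.raw_generate("hi", 5, seed=7)
```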
```diff
@@ -5,7 +5,7 @@ import json
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union

 import utils
 from logger import logger
```
```diff
@@ -40,8 +40,15 @@ class APIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ):
+
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the APIInferenceModel. Seed will be ignored."
+            )
+
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

         # Store context in memory to use it for comparison with generated content
```
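Remote backends (this one, `BasicAPIInferenceModel`, `HordeInferenceModel`, `OpenAIAPIInferenceModel`) cannot reseed a server they do not control, so the commit warns and ignores the seed rather than failing. The pattern in isolation, as a generic sketch rather than the project's actual class:

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)

def remote_generate(prompt: str, seed: Optional[int] = None) -> str:
    # The request is served remotely, so a local seed cannot make it
    # deterministic; warn and proceed as if no seed was given.
    if seed is not None:
        logger.warning("Seed is unsupported on this backend. Seed will be ignored.")
    return prompt  # stand-in for the real HTTP round trip
```

Warning instead of raising keeps `seed=` safe to pass unconditionally from `core_generate`, whichever backend happens to be loaded.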
```diff
@@ -3,10 +3,10 @@ from __future__ import annotations
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union

 import utils
+from logger import logger
 from modeling.inference_model import (
     GenerationResult,
     GenerationSettings,
```
```diff
@@ -36,8 +36,14 @@ class BasicAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ):
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the BasicAPIInferenceModel. Seed will be ignored."
+            )
+
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

         # Store context in memory to use it for comparison with generated content
```
```diff
@@ -4,7 +4,7 @@ import os
 import torch
 import numpy as np
 from eventlet import tpool
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import utils
 import koboldai_settings
```
```diff
@@ -258,6 +258,7 @@ class HFMTJInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         soft_tokens = self.get_soft_tokens()
```
```diff
@@ -265,6 +266,9 @@ class HFMTJInferenceModel(HFInferenceModel):
         dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
         logger.info(f"dynamic_inference={dynamic_inference}")

+        if seed is not None:
+            tpu_mtj_backend.set_rng_seed(seed)
+
         if not dynamic_inference:
             genout = tpool.execute(
                 tpu_mtj_backend.infer_static,
```
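On the TPU path the seed goes to `tpu_mtj_backend.set_rng_seed(seed)` rather than `torch.manual_seed`. The effect is the same idea sketched below: reseeding the sampler's RNG right before generation makes the draws repeat (illustrative NumPy sketch, not the MTJ backend):

```python
import numpy as np

def sample_token_ids(seed: int, steps: int = 5, vocab: int = 50257) -> np.ndarray:
    # Fresh generator seeded per call, standing in for set_rng_seed(seed)
    # followed by the backend's sampling loop.
    rng = np.random.default_rng(seed)
    return rng.integers(0, vocab, size=steps)

assert (sample_token_ids(1337) == sample_token_ids(1337)).all()
```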
```diff
@@ -10,7 +10,7 @@ import itertools
 import traceback
 import contextlib
 from tqdm.auto import tqdm
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union

 import torch
 from torch.nn import Embedding
```
```diff
@@ -457,6 +457,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
```
```diff
@@ -469,6 +470,10 @@ class HFTorchInferenceModel(HFInferenceModel):

         additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []

+        if seed is not None:
+            print("seeding", seed)
+            torch.manual_seed(seed)
+
         with torch.no_grad():
             start_time = time.time()
             genout = self.model.generate(
```
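The torch backend reseeds the global RNG with `torch.manual_seed(seed)` immediately before the model's `generate` call. A scoped `torch.Generator` would keep the reseeding local instead of touching global state; the sketch below only illustrates that alternative for comparison and is not what the commit does:

```python
import torch

def sample_ids(seed: int, steps: int = 5, vocab: int = 50257) -> torch.Tensor:
    # A per-call generator keeps the reseeding local to this sampling call.
    gen = torch.Generator().manual_seed(seed)
    probs = torch.full((vocab,), 1.0 / vocab)
    return torch.multinomial(probs, num_samples=steps, replacement=True, generator=gen)

assert torch.equal(sample_ids(1337), sample_ids(1337))
```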
```diff
@@ -42,8 +42,14 @@ class HordeInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the HordeInferenceModel. Seed will be ignored."
+            )
+
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

         # Store context in memory to use it for comparison with generated content
```
```diff
@@ -1,9 +1,10 @@
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union

 import utils
+from logger import logger
 from modeling.inference_model import (
     GenerationResult,
     GenerationSettings,
```
```diff
@@ -29,9 +30,14 @@ class OpenAIAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
-        # Taken mainly from oairequest()
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the OpenAIAPIInferenceModel. Seed will be ignored."
+            )
+
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

```
```diff
@@ -10,6 +10,7 @@ model: InferenceModel
 TEST_MODEL_HF_ID = "EleutherAI/pythia-70m"
 TEST_PROMPT = "Once upon a time I found myself"
 TEST_GEN_TOKEN_COUNT = 20
+TEST_SEED = 1337


 def test_generic_hf_torch_load() -> None:
```
```diff
@@ -29,16 +30,12 @@ def test_generic_hf_torch_low_mem_load() -> None:


 def test_model_gen() -> None:
-    # This should probably be supported in the model interface!
-    koboldai_vars.full_determinism = True
-    koboldai_vars.seed = 1337
-
-    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)
     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
     assert len(x.encoded[0]) == 20, "Wrong token amount (requested 20)"

-    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)

     assert torch.equal(
         x.encoded[0], y.encoded[0]
```
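The reworked test exercises the public `seed=` argument instead of mutating `koboldai_vars` before each call. The same shape can be checked without loading a model at all; a toy stand-in mirroring the test's structure (hypothetical, for illustration only):

```python
import random
from typing import List, Optional

def fake_raw_generate(prompt: str, max_new: int, seed: Optional[int] = None) -> List[int]:
    # Deterministic stand-in for a seeded backend: same seed, same "tokens".
    rng = random.Random(seed)
    return [rng.randrange(50257) for _ in range(max_new)]

def test_seeded_generation_repeats() -> None:
    x = fake_raw_generate("Once upon a time I found myself", max_new=20, seed=1337)
    y = fake_raw_generate("Once upon a time I found myself", max_new=20, seed=1337)
    assert len(x) == 20, "Wrong token amount (requested 20)"
    assert x == y, "Seeded runs should produce identical output"

test_seeded_generation_repeats()
```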