Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Modeling: Add seed parameter to raw_generate
Yahooo, decoupling from koboldai_vars. This makes the generation test in `test_generation.py` pass, and makes full determinism work outside of `core_generate`.
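For callers, the practical effect is that reproducibility no longer requires flipping `koboldai_vars.full_determinism`; a seed can be passed straight to `raw_generate`. A minimal usage sketch, assuming a loaded `InferenceModel` instance named `model` (mirroring the updated `test_generation.py`):

import torch

# Two generations with the same prompt and the same seed should produce
# identical token sequences.
x = model.raw_generate("Once upon a time I found myself", max_new=20, seed=1337)
y = model.raw_generate("Once upon a time I found myself", max_new=20, seed=1337)

assert torch.equal(x.encoded[0], y.encoded[0])
print(x.decoded)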
@@ -246,9 +246,6 @@ class InferenceModel:
         start_time = time.time()
         if utils.koboldai_vars.is_model_torch():
             # Torch stuff
-            if utils.koboldai_vars.full_determinism:
-                torch.manual_seed(utils.koboldai_vars.seed)
-
             if utils.koboldai_vars.sp is not None:
                 assert self.capabilties.embedding_manipulation
                 soft_tokens = torch.arange(
@@ -256,9 +253,6 @@ class InferenceModel:
                     self.model.config.vocab_size + utils.koboldai_vars.sp.shape[0],
                 )
                 gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
-        elif utils.koboldai_vars.use_colab_tpu:
-            if utils.koboldai_vars.full_determinism:
-                tpu_mtj_backend.set_rng_seed(utils.koboldai_vars.seed)

         logger.debug(
             "core_generate: Model Setup (SP, etc) time {}s".format(
@@ -329,6 +323,9 @@ class InferenceModel:
                     not utils.koboldai_vars.nogenmod
                     and utils.koboldai_vars.has_genmod
                 ),
+                seed=utils.koboldai_vars.seed
+                if utils.koboldai_vars.full_determinism
+                else None,
             )
             logger.debug(
                 "core_generate: run raw_generate pass {} {}s".format(
@@ -481,6 +478,7 @@ class InferenceModel:
             gen_settings (GenerationSettings): State to pass in single-generation setting overrides
             single_line (bool, optional): Generate one line only. Defaults to False.
             batch_count (int, optional): How big of a batch to generate. Defaults to 1.
+            seed (int, optional): If not None, this seed will be used to make reproducible generations. Defaults to None.

         Returns:
             GenerationResult: The model's output
@@ -501,6 +499,7 @@ class InferenceModel:
         single_line: bool = False,
         found_entries: set = (),
         tpu_dynamic_inference: bool = False,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -562,6 +561,7 @@ class InferenceModel:
             gen_settings=gen_settings,
             single_line=single_line,
             tpu_dynamic_inference=tpu_dynamic_inference,
+            seed=seed,
         )

         time_end = round(time.time() - time_start, 2)
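The per-backend hunks below all implement the same contract: `_raw_generate` gains a `seed` parameter, and each backend either seeds its own RNG or warns that the seed is ignored. A rough skeleton of that override, using a hypothetical backend class (parameter names before `gen_settings` are assumed from context; the generation body is elided):

from typing import List, Optional

import torch

from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
    InferenceModel,
)


class ExampleInferenceModel(InferenceModel):  # hypothetical backend, for illustration
    def _raw_generate(
        self,
        prompt_tokens: List[int],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        seed: Optional[int] = None,
        **kwargs,
    ) -> GenerationResult:
        if seed is not None:
            # Local backends seed their RNG here (torch.manual_seed or
            # tpu_mtj_backend.set_rng_seed); remote backends log a warning
            # and ignore the seed instead.
            torch.manual_seed(seed)
        ...  # backend-specific generation, elided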
@@ -5,7 +5,7 @@ import json
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union

 import utils
 from logger import logger
@@ -40,8 +40,15 @@ class APIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ):
+
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the APIInferenceModel. Seed will be ignored."
+            )
+
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

         # Store context in memory to use it for comparison with generated content
@@ -3,10 +3,10 @@ from __future__ import annotations
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union

 import utils
 from logger import logger
 from modeling.inference_model import (
     GenerationResult,
     GenerationSettings,
@@ -36,8 +36,14 @@ class BasicAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ):
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the APIInferenceModel. Seed will be ignored."
+            )
+
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

         # Store context in memory to use it for comparison with generated content
@@ -4,7 +4,7 @@ import os
 import torch
 import numpy as np
 from eventlet import tpool
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import utils
 import koboldai_settings
@@ -258,6 +258,7 @@ class HFMTJInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         soft_tokens = self.get_soft_tokens()
@@ -265,6 +266,9 @@ class HFMTJInferenceModel(HFInferenceModel):
         dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
         logger.info(f"dynamic_inference={dynamic_inference}")

+        if seed is not None:
+            tpu_mtj_backend.set_rng_seed(seed)
+
         if not dynamic_inference:
             genout = tpool.execute(
                 tpu_mtj_backend.infer_static,
@@ -10,7 +10,7 @@ import itertools
 import traceback
 import contextlib
 from tqdm.auto import tqdm
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union

 import torch
 from torch.nn import Embedding
@@ -457,6 +457,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
@@ -469,6 +470,10 @@ class HFTorchInferenceModel(HFInferenceModel):

         additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []

+        if seed is not None:
+            print("seeding", seed)
+            torch.manual_seed(seed)
+
         with torch.no_grad():
             start_time = time.time()
             genout = self.model.generate(
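On the torch path, reproducibility comes entirely from seeding the global RNG immediately before the stochastic call, as the hunk above does before `model.generate`. A standalone sketch of that principle with plain torch (no model involved, values chosen arbitrarily):

import torch


def sample_ids(seed=None, vocab_size=50257, n_tokens=20):
    # Seed right before the random draw, mirroring the pattern above.
    if seed is not None:
        torch.manual_seed(seed)
    probs = torch.full((1, vocab_size), 1.0 / vocab_size)  # dummy uniform distribution
    return torch.multinomial(probs, num_samples=n_tokens, replacement=True)


a = sample_ids(seed=1337)
b = sample_ids(seed=1337)
assert torch.equal(a, b), "same seed, same draws"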
@@ -42,8 +42,14 @@ class HordeInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the APIInferenceModel. Seed will be ignored."
+            )
+
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

         # Store context in memory to use it for comparison with generated content
@@ -1,9 +1,10 @@
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union

 import utils
+from logger import logger
 from modeling.inference_model import (
     GenerationResult,
     GenerationSettings,
@@ -29,9 +30,14 @@ class OpenAIAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         # Taken mainly from oairequest()

+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the OpenAIAPIInferenceModel. Seed will be ignored."
+            )
+
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
@@ -10,6 +10,7 @@ model: InferenceModel
 TEST_MODEL_HF_ID = "EleutherAI/pythia-70m"
 TEST_PROMPT = "Once upon a time I found myself"
 TEST_GEN_TOKEN_COUNT = 20
+TEST_SEED = 1337


 def test_generic_hf_torch_load() -> None:
@@ -29,19 +30,15 @@ def test_generic_hf_torch_low_mem_load() -> None:


 def test_model_gen() -> None:
-    # This should probably be supported in the model interface!
-    koboldai_vars.full_determinism = True
-    koboldai_vars.seed = 1337
-
-    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)
     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
     assert len(x.encoded[0]) == 20, "Wrong token amount (requested 20)"

-    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)

     assert torch.equal(
         x.encoded[0], y.encoded[0]
     ), f"Faulty full determinism! {x.decoded} vs {y.decoded}"

     print(x)