Modeling: Add seed parameter to raw_generate

Yahooo, decoupling from koboldai_vars. This makes the generation test in
`test_generation.py` pass, and makes full determinism work outside of
core_generate.
somebody
2023-03-12 21:49:10 -05:00
parent 38c4edac40
commit cd8ccf0a5e
8 changed files with 51 additions and 20 deletions
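
In short, core_generate no longer seeds each backend from koboldai_vars itself; it resolves the seed once and hands it to raw_generate. A minimal sketch of that resolution (a hypothetical standalone helper, not code from this commit; the expression comes from the core_generate hunk below):

def resolve_seed(full_determinism: bool, configured_seed: int):
    """Return the seed to pass to raw_generate, or None for nondeterministic output."""
    return configured_seed if full_determinism else None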

View File

@@ -246,9 +246,6 @@ class InferenceModel:
start_time = time.time()
if utils.koboldai_vars.is_model_torch():
# Torch stuff
if utils.koboldai_vars.full_determinism:
torch.manual_seed(utils.koboldai_vars.seed)
if utils.koboldai_vars.sp is not None:
assert self.capabilties.embedding_manipulation
soft_tokens = torch.arange(
@@ -256,9 +253,6 @@ class InferenceModel:
self.model.config.vocab_size + utils.koboldai_vars.sp.shape[0],
)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
elif utils.koboldai_vars.use_colab_tpu:
if utils.koboldai_vars.full_determinism:
tpu_mtj_backend.set_rng_seed(utils.koboldai_vars.seed)
logger.debug(
"core_generate: Model Setup (SP, etc) time {}s".format(
@@ -329,6 +323,9 @@ class InferenceModel:
not utils.koboldai_vars.nogenmod
and utils.koboldai_vars.has_genmod
),
seed=utils.koboldai_vars.seed
if utils.koboldai_vars.full_determinism
else None,
)
logger.debug(
"core_generate: run raw_generate pass {} {}s".format(
@@ -481,6 +478,7 @@ class InferenceModel:
gen_settings (GenerationSettings): State to pass in single-generation setting overrides
single_line (bool, optional): Generate one line only. Defaults to False.
batch_count (int, optional): How big of a batch to generate. Defaults to 1.
seed (int, optional): If not None, this seed is used to make generations reproducible. Defaults to None.
Returns:
GenerationResult: The model's output
@@ -501,6 +499,7 @@ class InferenceModel:
single_line: bool = False,
found_entries: set = (),
tpu_dynamic_inference: bool = False,
seed: Optional[int] = None,
**kwargs,
) -> GenerationResult:
"""A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -562,6 +561,7 @@ class InferenceModel:
gen_settings=gen_settings,
single_line=single_line,
tpu_dynamic_inference=tpu_dynamic_inference,
seed=seed,
)
time_end = round(time.time() - time_start, 2)
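
For reference, a minimal usage sketch of the new parameter (it mirrors test_generation.py at the bottom of this diff; it assumes a model has already been loaded into `model`, so it is not a standalone script):

prompt = "Once upon a time I found myself"
# Two generations with the same seed should produce identical output.
first = model.raw_generate(prompt, max_new=20, seed=1337)
second = model.raw_generate(prompt, max_new=20, seed=1337)
assert first.decoded == second.decoded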

View File

@@ -5,7 +5,7 @@ import json
import torch
import requests
import numpy as np
from typing import List, Union
from typing import List, Optional, Union
import utils
from logger import logger
@@ -40,8 +40,15 @@ class APIInferenceModel(InferenceModel):
gen_settings: GenerationSettings,
single_line: bool = False,
batch_count: int = 1,
seed: Optional[int] = None,
**kwargs,
):
if seed is not None:
logger.warning(
"Seed is unsupported on the APIInferenceModel. Seed will be ignored."
)
decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
# Store context in memory to use it for comparison with generated content

View File

@@ -3,10 +3,10 @@ from __future__ import annotations
import torch
import requests
import numpy as np
from typing import List, Union
from typing import List, Optional, Union
import utils
from logger import logger
from modeling.inference_model import (
GenerationResult,
GenerationSettings,
@@ -36,8 +36,14 @@ class BasicAPIInferenceModel(InferenceModel):
gen_settings: GenerationSettings,
single_line: bool = False,
batch_count: int = 1,
seed: Optional[int] = None,
**kwargs,
):
if seed is not None:
logger.warning(
"Seed is unsupported on the APIInferenceModel. Seed will be ignored."
)
decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
# Store context in memory to use it for comparison with generated content

View File

@@ -4,7 +4,7 @@ import os
import torch
import numpy as np
from eventlet import tpool
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union
import utils
import koboldai_settings
@@ -258,6 +258,7 @@ class HFMTJInferenceModel(HFInferenceModel):
gen_settings: GenerationSettings,
single_line: bool = False,
batch_count: int = 1,
seed: Optional[int] = None,
**kwargs,
) -> GenerationResult:
soft_tokens = self.get_soft_tokens()
@@ -265,6 +266,9 @@ class HFMTJInferenceModel(HFInferenceModel):
dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
logger.info(f"dynamic_inference={dynamic_inference}")
if seed is not None:
tpu_mtj_backend.set_rng_seed(seed)
if not dynamic_inference:
genout = tpool.execute(
tpu_mtj_backend.infer_static,

View File

@@ -10,7 +10,7 @@ import itertools
import traceback
import contextlib
from tqdm.auto import tqdm
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union
import torch
from torch.nn import Embedding
@@ -457,6 +457,7 @@ class HFTorchInferenceModel(HFInferenceModel):
gen_settings: GenerationSettings,
single_line: bool = False,
batch_count: int = 1,
seed: Optional[int] = None,
**kwargs,
) -> GenerationResult:
if not isinstance(prompt_tokens, torch.Tensor):
@@ -469,6 +470,10 @@ class HFTorchInferenceModel(HFInferenceModel):
additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
if seed is not None:
# Seed the PyTorch RNG so sampling for this generation is reproducible
torch.manual_seed(seed)
with torch.no_grad():
start_time = time.time()
genout = self.model.generate(
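
Why this works: re-seeding PyTorch's global RNG immediately before generation makes the subsequent random draws repeat. A standalone illustration in plain PyTorch (not KoboldAI code):

import torch

weights = torch.tensor([0.25, 0.25, 0.25, 0.25])

torch.manual_seed(1337)
a = torch.multinomial(weights, num_samples=5, replacement=True)

torch.manual_seed(1337)
b = torch.multinomial(weights, num_samples=5, replacement=True)

# Identical draws after re-seeding with the same value.
assert torch.equal(a, b)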

View File

@@ -42,8 +42,14 @@ class HordeInferenceModel(InferenceModel):
gen_settings: GenerationSettings,
single_line: bool = False,
batch_count: int = 1,
seed: Optional[int] = None,
**kwargs,
) -> GenerationResult:
if seed is not None:
logger.warning(
"Seed is unsupported on the APIInferenceModel. Seed will be ignored."
)
decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
# Store context in memory to use it for comparison with generated content

View File

@@ -1,9 +1,10 @@
import torch
import requests
import numpy as np
from typing import List, Union
from typing import List, Optional, Union
import utils
from logger import logger
from modeling.inference_model import (
GenerationResult,
GenerationSettings,
@@ -29,9 +30,14 @@ class OpenAIAPIInferenceModel(InferenceModel):
gen_settings: GenerationSettings,
single_line: bool = False,
batch_count: int = 1,
seed: Optional[int] = None,
**kwargs,
) -> GenerationResult:
# Taken mainly from oairequest()
if seed is not None:
logger.warning(
"Seed is unsupported on the OpenAIAPIInferenceModel. Seed will be ignored."
)
decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

View File

@@ -10,6 +10,7 @@ model: InferenceModel
TEST_MODEL_HF_ID = "EleutherAI/pythia-70m"
TEST_PROMPT = "Once upon a time I found myself"
TEST_GEN_TOKEN_COUNT = 20
TEST_SEED = 1337
def test_generic_hf_torch_load() -> None:
@@ -29,19 +30,15 @@ def test_generic_hf_torch_low_mem_load() -> None:
def test_model_gen() -> None:
# This should probably be supported in the model interface!
koboldai_vars.full_determinism = True
koboldai_vars.seed = 1337
x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
print(x.decoded)
assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
assert len(x.encoded[0]) == 20, "Wrong token amount (requested 20)"
y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
assert torch.equal(
x.encoded[0], y.encoded[0]
), f"Faulty full determinism! {x.decoded} vs {y.decoded}"
print(x)