Modeling: Add seed parameter to raw_generate

Yahooo, decoupling from koboldai_vars. This makes the generation test
in `test_generation.py` pass, and makes full determinism work outside
of core_generate.
Author: somebody
Date: 2023-03-12 21:49:10 -05:00
Parent: 38c4edac40
Commit: cd8ccf0a5e
8 changed files with 51 additions and 20 deletions
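
For orientation, a minimal usage sketch of what this commit enables (not part of the diff; it assumes an already-loaded torch-backed InferenceModel instance named `model`, as in test_generation.py):

# Sketch only: `model` is assumed to be a loaded InferenceModel,
# e.g. the one test_generation.py builds around EleutherAI/pythia-70m.
import torch

x = model.raw_generate("Once upon a time I found myself", max_new=20, seed=1337)
y = model.raw_generate("Once upon a time I found myself", max_new=20, seed=1337)

# Passing the same seed makes the two generations identical token-for-token;
# passing seed=None keeps the previous non-deterministic behaviour.
assert torch.equal(x.encoded[0], y.encoded[0])
print(x.decoded)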

View File

@@ -246,9 +246,6 @@ class InferenceModel:
         start_time = time.time()
         if utils.koboldai_vars.is_model_torch():
             # Torch stuff
-            if utils.koboldai_vars.full_determinism:
-                torch.manual_seed(utils.koboldai_vars.seed)
             if utils.koboldai_vars.sp is not None:
                 assert self.capabilties.embedding_manipulation
                 soft_tokens = torch.arange(
@@ -256,9 +253,6 @@ class InferenceModel:
                     self.model.config.vocab_size + utils.koboldai_vars.sp.shape[0],
                 )
                 gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
-        elif utils.koboldai_vars.use_colab_tpu:
-            if utils.koboldai_vars.full_determinism:
-                tpu_mtj_backend.set_rng_seed(utils.koboldai_vars.seed)
         logger.debug(
             "core_generate: Model Setup (SP, etc) time {}s".format(
@@ -329,6 +323,9 @@ class InferenceModel:
                     not utils.koboldai_vars.nogenmod
                     and utils.koboldai_vars.has_genmod
                 ),
+                seed=utils.koboldai_vars.seed
+                if utils.koboldai_vars.full_determinism
+                else None,
             )
             logger.debug(
                 "core_generate: run raw_generate pass {} {}s".format(
@@ -481,6 +478,7 @@ class InferenceModel:
             gen_settings (GenerationSettings): State to pass in single-generation setting overrides
             single_line (bool, optional): Generate one line only. Defaults to False.
             batch_count (int, optional): How big of a batch to generate. Defaults to 1.
+            seed (int, optional): If not None, this seed will be used to make reproducible generations. Defaults to None.
         Returns:
             GenerationResult: The model's output
@@ -501,6 +499,7 @@ class InferenceModel:
         single_line: bool = False,
         found_entries: set = (),
         tpu_dynamic_inference: bool = False,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -562,6 +561,7 @@ class InferenceModel:
             gen_settings=gen_settings,
             single_line=single_line,
             tpu_dynamic_inference=tpu_dynamic_inference,
+            seed=seed,
         )
         time_end = round(time.time() - time_start, 2)

View File

@@ -5,7 +5,7 @@ import json
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union
 import utils
 from logger import logger
@@ -40,8 +40,15 @@ class APIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ):
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the APIInferenceModel. Seed will be ignored."
+            )
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
         # Store context in memory to use it for comparison with generated content

View File

@@ -3,10 +3,10 @@ from __future__ import annotations
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union
 import utils
+from logger import logger
 from modeling.inference_model import (
     GenerationResult,
     GenerationSettings,
@@ -36,8 +36,14 @@ class BasicAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ):
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the APIInferenceModel. Seed will be ignored."
+            )
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
         # Store context in memory to use it for comparison with generated content

View File

@@ -4,7 +4,7 @@ import os
 import torch
 import numpy as np
 from eventlet import tpool
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 import utils
 import koboldai_settings
@@ -258,6 +258,7 @@ class HFMTJInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         soft_tokens = self.get_soft_tokens()
@@ -265,6 +266,9 @@ class HFMTJInferenceModel(HFInferenceModel):
         dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
         logger.info(f"dynamic_inference={dynamic_inference}")
+        if seed is not None:
+            tpu_mtj_backend.set_rng_seed(seed)
         if not dynamic_inference:
             genout = tpool.execute(
                 tpu_mtj_backend.infer_static,

View File

@@ -10,7 +10,7 @@ import itertools
 import traceback
 import contextlib
 from tqdm.auto import tqdm
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 import torch
 from torch.nn import Embedding
@@ -457,6 +457,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
@@ -469,6 +470,10 @@ class HFTorchInferenceModel(HFInferenceModel):
         additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
+        if seed is not None:
+            print("seeding", seed)
+            torch.manual_seed(seed)
         with torch.no_grad():
             start_time = time.time()
             genout = self.model.generate(
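
For reference, the reproducibility added above comes from seeding torch's global RNG right before sampling. A standalone sketch of that pattern with plain transformers follows (an illustration under assumed defaults, not KoboldAI code; reproducible on a single machine and device):

# Standalone illustration of the torch.manual_seed(seed) pattern used above.
# Assumes the transformers library is installed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
lm = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
ids = tok("Once upon a time I found myself", return_tensors="pt").input_ids

torch.manual_seed(1337)  # plays the role of the seed branch in _raw_generate
a = lm.generate(ids, do_sample=True, max_new_tokens=20)
torch.manual_seed(1337)
b = lm.generate(ids, do_sample=True, max_new_tokens=20)
assert torch.equal(a, b)  # same seed, same sampled tokens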

View File

@@ -42,8 +42,14 @@ class HordeInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the APIInferenceModel. Seed will be ignored."
+            )
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
         # Store context in memory to use it for comparison with generated content

View File

@@ -1,9 +1,10 @@
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union
 import utils
+from logger import logger
 from modeling.inference_model import (
     GenerationResult,
     GenerationSettings,
@@ -29,9 +30,14 @@ class OpenAIAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         # Taken mainly from oairequest()
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the OpenAIAPIInferenceModel. Seed will be ignored."
+            )
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

View File

@@ -10,6 +10,7 @@ model: InferenceModel
 TEST_MODEL_HF_ID = "EleutherAI/pythia-70m"
 TEST_PROMPT = "Once upon a time I found myself"
 TEST_GEN_TOKEN_COUNT = 20
+TEST_SEED = 1337
 def test_generic_hf_torch_load() -> None:
@@ -29,16 +30,12 @@ def test_generic_hf_torch_low_mem_load() -> None:
 def test_model_gen() -> None:
-    # This should probably be supported in the model interface!
-    koboldai_vars.full_determinism = True
-    koboldai_vars.seed = 1337
-    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)
     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
     assert len(x.encoded[0]) == 20, "Wrong token amount (requested 20)"
-    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     assert torch.equal(
         x.encoded[0], y.encoded[0]