From cd8ccf0a5eebe999f1595fa9857a2f415465a0ae Mon Sep 17 00:00:00 2001
From: somebody
Date: Sun, 12 Mar 2023 21:49:10 -0500
Subject: [PATCH] Modeling: Add seed parameter to raw_generate

Yahooo, decoupling from koboldai_vars! This makes the generation test in
`test_generation.py` pass, and makes full determinism work outside of
core_generate.
---
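Usage note (kept below the `---` fold, so `git am` drops it): raw_generate()
can now be seeded directly instead of reading koboldai_vars.full_determinism.
A minimal sketch of the intended call pattern, mirroring
modeling/test_generation.py; the prompt, token count and seed are just the
test's example values, and `model` is assumed to be an already-loaded
InferenceModel:

    import torch

    # Two calls with the same seed should return identical tokens on the
    # backends that can actually seed themselves (hf_torch, hf_mtj). The
    # API-backed models (api, basic_api, horde, openai) log a warning and
    # ignore the seed.
    a = model.raw_generate("Once upon a time I found myself", max_new=20, seed=1337)
    b = model.raw_generate("Once upon a time I found myself", max_new=20, seed=1337)
    assert torch.equal(a.encoded[0], b.encoded[0])

    # Passing seed=None (the default) keeps the old non-deterministic behaviour.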

 modeling/inference_model.py            | 12 ++++++------
 modeling/inference_models/api.py       |  9 ++++++++-
 modeling/inference_models/basic_api.py | 10 ++++++++--
 modeling/inference_models/hf_mtj.py    |  6 +++++-
 modeling/inference_models/hf_torch.py  |  7 ++++++-
 modeling/inference_models/horde.py     |  6 ++++++
 modeling/inference_models/openai.py    | 10 ++++++++--
 modeling/test_generation.py            | 11 ++++-------
 8 files changed, 51 insertions(+), 20 deletions(-)

diff --git a/modeling/inference_model.py b/modeling/inference_model.py
index c0b20a7c..72fa3314 100644
--- a/modeling/inference_model.py
+++ b/modeling/inference_model.py
@@ -246,9 +246,6 @@ class InferenceModel:
         start_time = time.time()
         if utils.koboldai_vars.is_model_torch():
             # Torch stuff
-            if utils.koboldai_vars.full_determinism:
-                torch.manual_seed(utils.koboldai_vars.seed)
-
             if utils.koboldai_vars.sp is not None:
                 assert self.capabilties.embedding_manipulation
                 soft_tokens = torch.arange(
@@ -256,9 +253,6 @@ class InferenceModel:
                     self.model.config.vocab_size + utils.koboldai_vars.sp.shape[0],
                 )
                 gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
-        elif utils.koboldai_vars.use_colab_tpu:
-            if utils.koboldai_vars.full_determinism:
-                tpu_mtj_backend.set_rng_seed(utils.koboldai_vars.seed)
 
         logger.debug(
             "core_generate: Model Setup (SP, etc) time {}s".format(
@@ -329,6 +323,9 @@ class InferenceModel:
                     not utils.koboldai_vars.nogenmod
                     and utils.koboldai_vars.has_genmod
                 ),
+                seed=utils.koboldai_vars.seed
+                if utils.koboldai_vars.full_determinism
+                else None,
             )
             logger.debug(
                 "core_generate: run raw_generate pass {} {}s".format(
@@ -481,6 +478,7 @@ class InferenceModel:
             gen_settings (GenerationSettings): State to pass in single-generation setting overrides
             single_line (bool, optional): Generate one line only. Defaults to False.
             batch_count (int, optional): How big of a batch to generate. Defaults to 1.
+            seed (int, optional): If not None, this seed will be used to make reproducible generations. Defaults to None.
 
         Returns:
             GenerationResult: The model's output
@@ -501,6 +499,7 @@ class InferenceModel:
         single_line: bool = False,
         found_entries: set = (),
         tpu_dynamic_inference: bool = False,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -562,6 +561,7 @@ class InferenceModel:
             gen_settings=gen_settings,
             single_line=single_line,
             tpu_dynamic_inference=tpu_dynamic_inference,
+            seed=seed,
         )
 
         time_end = round(time.time() - time_start, 2)
diff --git a/modeling/inference_models/api.py b/modeling/inference_models/api.py
index 7f1f4ea8..f36e7205 100644
--- a/modeling/inference_models/api.py
+++ b/modeling/inference_models/api.py
@@ -5,7 +5,7 @@ import json
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union
 
 import utils
 from logger import logger
@@ -40,8 +40,15 @@ class APIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ):
+
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the APIInferenceModel. Seed will be ignored."
+            )
+
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
 
         # Store context in memory to use it for comparison with generated content
diff --git a/modeling/inference_models/basic_api.py b/modeling/inference_models/basic_api.py
index 9e6a6713..c96eb42c 100644
--- a/modeling/inference_models/basic_api.py
+++ b/modeling/inference_models/basic_api.py
@@ -3,10 +3,10 @@ from __future__ import annotations
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union
 
 import utils
-
+from logger import logger
 from modeling.inference_model import (
     GenerationResult,
     GenerationSettings,
@@ -36,8 +36,14 @@ class BasicAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ):
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the BasicAPIInferenceModel. Seed will be ignored."
+            )
+
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
 
         # Store context in memory to use it for comparison with generated content
diff --git a/modeling/inference_models/hf_mtj.py b/modeling/inference_models/hf_mtj.py
index 4a5d52e6..f8993f56 100644
--- a/modeling/inference_models/hf_mtj.py
+++ b/modeling/inference_models/hf_mtj.py
@@ -4,7 +4,7 @@ import os
 import torch
 import numpy as np
 from eventlet import tpool
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import utils
 import koboldai_settings
@@ -258,6 +258,7 @@ class HFMTJInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         soft_tokens = self.get_soft_tokens()
@@ -265,6 +266,9 @@ class HFMTJInferenceModel(HFInferenceModel):
         dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
         logger.info(f"dynamic_inference={dynamic_inference}")
 
+        if seed is not None:
+            tpu_mtj_backend.set_rng_seed(seed)
+
         if not dynamic_inference:
             genout = tpool.execute(
                 tpu_mtj_backend.infer_static,
diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index c6ee2027..3d4b7029 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -10,7 +10,7 @@ import itertools
 import traceback
 import contextlib
 from tqdm.auto import tqdm
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import torch
 from torch.nn import Embedding
@@ -457,6 +457,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
@@ -469,6 +470,10 @@ class HFTorchInferenceModel(HFInferenceModel):
 
         additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
 
+        if seed is not None:
+            logger.debug(f"_raw_generate: seeding torch with {seed}")
+            torch.manual_seed(seed)
+
         with torch.no_grad():
             start_time = time.time()
             genout = self.model.generate(
diff --git a/modeling/inference_models/horde.py b/modeling/inference_models/horde.py
index 9bdc62b2..3e54df07 100644
--- a/modeling/inference_models/horde.py
+++ b/modeling/inference_models/horde.py
@@ -42,8 +42,14 @@ class HordeInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the HordeInferenceModel. Seed will be ignored."
+            )
+
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
 
         # Store context in memory to use it for comparison with generated content
diff --git a/modeling/inference_models/openai.py b/modeling/inference_models/openai.py
index c6f07e0e..1441ae2f 100644
--- a/modeling/inference_models/openai.py
+++ b/modeling/inference_models/openai.py
@@ -1,9 +1,10 @@
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union
 
 import utils
+from logger import logger
 from modeling.inference_model import (
     GenerationResult,
     GenerationSettings,
@@ -29,9 +30,14 @@ class OpenAIAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
-        # Taken mainly from oairequest()
+
+        if seed is not None:
+            logger.warning(
+                "Seed is unsupported on the OpenAIAPIInferenceModel. Seed will be ignored."
+            )
 
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
 
diff --git a/modeling/test_generation.py b/modeling/test_generation.py
index 0d9173b5..a3efa7bc 100644
--- a/modeling/test_generation.py
+++ b/modeling/test_generation.py
@@ -10,6 +10,7 @@ model: InferenceModel
 TEST_MODEL_HF_ID = "EleutherAI/pythia-70m"
 TEST_PROMPT = "Once upon a time I found myself"
 TEST_GEN_TOKEN_COUNT = 20
+TEST_SEED = 1337
 
 
 def test_generic_hf_torch_load() -> None:
@@ -29,19 +30,15 @@ def test_generic_hf_torch_low_mem_load() -> None:
 
 
 def test_model_gen() -> None:
-    # This should probably be supported in the model interface!
-    koboldai_vars.full_determinism = True
-    koboldai_vars.seed = 1337
-
-    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)
 
     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
     assert len(x.encoded[0]) == 20, "Wrong token amount (requested 20)"
 
-    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     assert torch.equal(
         x.encoded[0], y.encoded[0]
     ), f"Faulty full determinism! {x.decoded} vs {y.decoded}"
 
-    print(x)
\ No newline at end of file
+    print(x)