mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Modeling: Add seed parameter to raw_generate
Yahooo, decoupling from koboldai_vars. This makes the generation test pass in `test_generation.py`, and makes full determinism outside of core_generate work.
This commit is contained in:
@@ -10,7 +10,7 @@ import itertools
|
||||
import traceback
|
||||
import contextlib
|
||||
from tqdm.auto import tqdm
|
||||
from typing import Dict, List, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Embedding
|
||||
@@ -457,6 +457,7 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
gen_settings: GenerationSettings,
|
||||
single_line: bool = False,
|
||||
batch_count: int = 1,
|
||||
seed: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> GenerationResult:
|
||||
if not isinstance(prompt_tokens, torch.Tensor):
|
||||
@@ -469,6 +470,10 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
|
||||
additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
|
||||
|
||||
if seed is not None:
|
||||
print("seeding", seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
start_time = time.time()
|
||||
genout = self.model.generate(
|
||||
|
Reference in New Issue
Block a user