mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Modeling: Add seed parameter to raw_generate
Yahooo, decoupling from koboldai_vars. This makes the generation test pass in `test_generation.py`, and makes full determinism outside of core_generate work.
This commit is contained in:
@@ -4,7 +4,7 @@ import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from eventlet import tpool
|
||||
from typing import List, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import utils
|
||||
import koboldai_settings
|
||||
@@ -258,6 +258,7 @@ class HFMTJInferenceModel(HFInferenceModel):
|
||||
gen_settings: GenerationSettings,
|
||||
single_line: bool = False,
|
||||
batch_count: int = 1,
|
||||
seed: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> GenerationResult:
|
||||
soft_tokens = self.get_soft_tokens()
|
||||
@@ -265,6 +266,9 @@ class HFMTJInferenceModel(HFInferenceModel):
|
||||
dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
|
||||
logger.info(f"dynamic_inference={dynamic_inference}")
|
||||
|
||||
if seed is not None:
|
||||
tpu_mtj_backend.set_rng_seed(seed)
|
||||
|
||||
if not dynamic_inference:
|
||||
genout = tpool.execute(
|
||||
tpu_mtj_backend.infer_static,
|
||||
|
Reference in New Issue
Block a user