diff --git a/modeling/inference_models/exllama/class.py b/modeling/inference_models/exllama/class.py
index 0160ed4b..db1728cf 100644
--- a/modeling/inference_models/exllama/class.py
+++ b/modeling/inference_models/exllama/class.py
@@ -9,6 +9,8 @@ import os
 import glob
 from pathlib import Path
 import re
+import warnings
+import gc
 
 import utils
 from logger import logger
@@ -26,8 +28,6 @@ from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
 from transformers import LlamaTokenizer
 from exllama.generator import ExLlamaGenerator
 
-import traceback
-
 model_backend_name = "ExLlama"
@@ -60,8 +60,10 @@ class model_backend(InferenceModel):
 
         self.model = None
         self.tokenizer = None
+        self.cache = None
+        self.generator = None
 
-        self.model_name = None
+        self.model_name = ""
         self.path = None
 
     def is_valid(self, model_name, model_path, menu_path):
@@ -84,7 +86,7 @@ class model_backend(InferenceModel):
 
     def _load(self, save_model: bool, initial_load: bool) -> None:
         self.model = self._get_model(self.get_local_model_path(), {})
-        self.tokenizer = self._get_tokenizer(os.path.join(self.get_local_model_path(), "tokenizer.model"))
+        self.tokenizer = self._get_tokenizer(self.get_local_model_path())
 
         self.cache = ExLlamaCache(self.model)
@@ -174,6 +176,33 @@ class model_backend(InferenceModel):
             return result
         object.__setattr__(self.tokenizer, '__call__', call_wrapper.__get__(self.tokenizer))
 
+    def unload(self):
+        self.model_config = None
+
+        self.model = None
+        self.tokenizer = None
+        self.cache = None
+        self.generator = None
+
+        self.model_name = ""
+        self.path = None
+
+        with torch.no_grad():
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
+                for tensor in gc.get_objects():
+                    try:
+                        if torch.is_tensor(tensor):
+                            tensor.set_(torch.tensor((), device=tensor.device, dtype=tensor.dtype))
+                    except:
+                        pass
+        gc.collect()
+        try:
+            with torch.no_grad():
+                torch.cuda.empty_cache()
+        except:
+            pass
+
     def _raw_generate(
         self,
         prompt_tokens: Union[List[int], torch.Tensor],
@@ -184,6 +213,9 @@ class model_backend(InferenceModel):
         seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
+        if seed:
+            torch.manual_seed(seed)
+
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
         else:
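
Reviewer note on the new `unload()`: it drops every reference the backend holds (model, tokenizer, cache, generator), shrinks any CUDA tensors still tracked by the garbage collector via `tensor.set_()`, and finally calls `torch.cuda.empty_cache()`. Below is a rough sketch of how that path might be exercised; `reload_cycle` is a hypothetical helper, not part of the patch, and it assumes an already-configured `model_backend` instance.

```python
# Illustrative only -- not part of the patch. `reload_cycle` is a hypothetical
# helper; it assumes `backend` is an already-configured model_backend instance.
import gc

import torch


def reload_cycle(backend) -> None:
    """Load the model, unload it, and report CUDA memory at each step."""

    def allocated() -> int:
        # Returns 0 on CPU-only machines so the sketch stays runnable anywhere.
        return torch.cuda.memory_allocated() if torch.cuda.is_available() else 0

    print("before load:", allocated())
    backend._load(save_model=False, initial_load=False)
    print("after load:", allocated())

    backend.unload()  # new method from this diff: drops refs, scrubs tensors, empties cache
    gc.collect()
    print("after unload:", allocated())
```

Because `unload()` zeroes out lingering tensors before calling `torch.cuda.empty_cache()`, the final `memory_allocated()` reading should drop back close to its pre-load value, which is a quick way to confirm the model can be swapped without restarting the process.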
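On the `_raw_generate` change: seeding torch's global RNG is what makes a generation request with a fixed `seed` reproducible, assuming the sampler draws from that global RNG. A standalone sketch of the property (the logits here are a stand-in, not the backend's real outputs):

```python
# Standalone sketch: reseeding the global RNG makes a sampling draw repeatable.
import torch


def sample_once(seed: int) -> torch.Tensor:
    torch.manual_seed(seed)          # same call the diff adds before sampling
    logits = torch.randn(32000)      # stand-in for model logits
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


assert torch.equal(sample_once(1234), sample_once(1234))  # identical draws
```

One small observation for the author: `if seed:` treats a seed of 0 as falsy, so `if seed is not None:` may be the safer guard if 0 is a legal seed value.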