Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Fix exllama model unload
@@ -9,6 +9,8 @@ import os
 import glob
 from pathlib import Path
 import re
+import warnings
+import gc
 
 import utils
 from logger import logger
@@ -26,8 +28,6 @@ from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
 from transformers import LlamaTokenizer
 from exllama.generator import ExLlamaGenerator
 
-import warnings
-import gc
 import traceback
 
 model_backend_name = "ExLlama"
@@ -60,8 +60,10 @@ class model_backend(InferenceModel):
 
         self.model = None
         self.tokenizer = None
         self.cache = None
         self.generator = None
-        self.model_name = None
+
+        self.model_name = ""
+        self.path = None
 
     def is_valid(self, model_name, model_path, menu_path):
@@ -84,7 +86,7 @@ class model_backend(InferenceModel):
 
     def _load(self, save_model: bool, initial_load: bool) -> None:
         self.model = self._get_model(self.get_local_model_path(), {})
-        self.tokenizer = self._get_tokenizer(os.path.join(self.get_local_model_path(), "tokenizer.model"))
+        self.tokenizer = self._get_tokenizer(self.get_local_model_path())
 
         self.cache = ExLlamaCache(self.model)
 
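In the hunk above, the tokenizer is now built from the model directory rather than from an explicit tokenizer.model file path. _get_tokenizer itself is defined elsewhere in this file and is not part of the diff; a minimal sketch of the directory-based form, assuming it simply wraps the already-imported LlamaTokenizer, could look like:

# Hedged sketch only; the repo's actual _get_tokenizer is not shown in this diff.
from transformers import LlamaTokenizer

def _get_tokenizer(model_path: str) -> LlamaTokenizer:
    # from_pretrained accepts a directory and locates tokenizer.model itself,
    # so callers no longer need to os.path.join the filename by hand.
    return LlamaTokenizer.from_pretrained(model_path)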
@@ -174,6 +176,33 @@ class model_backend(InferenceModel):
             return result
         object.__setattr__(self.tokenizer, '__call__', call_wrapper.__get__(self.tokenizer))
 
+    def unload(self):
+        self.model_config = None
+
+        self.model = None
+        self.tokenizer = None
+        self.cache = None
+        self.generator = None
+
+        self.model_name = ""
+        self.path = None
+
+        with torch.no_grad():
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
+                for tensor in gc.get_objects():
+                    try:
+                        if torch.is_tensor(tensor):
+                            tensor.set_(torch.tensor((), device=tensor.device, dtype=tensor.dtype))
+                    except:
+                        pass
+        gc.collect()
+        try:
+            with torch.no_grad():
+                torch.cuda.empty_cache()
+        except:
+            pass
+
     def _raw_generate(
         self,
         prompt_tokens: Union[List[int], torch.Tensor],
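The core of the fix is the new unload(): clearing the backend's own attributes is not sufficient on its own, because any surviving reference elsewhere keeps a tensor's storage alive. The loop over gc.get_objects() therefore shrinks every live tensor's storage to zero elements in place, after which gc.collect() and torch.cuda.empty_cache() can actually return the memory. A standalone illustration of that mechanism (not code from the repo, just the trick in isolation):

import gc
import torch

t = torch.zeros(1024, 1024)   # ~4 MiB of float32 storage
alias = t                     # a stray reference that would otherwise keep it alive
# Swap in an empty storage in place; every name bound to this tensor object
# now sees zero elements, so the old buffer becomes collectable.
t.set_(torch.tensor((), device=t.device, dtype=t.dtype))
assert alias.numel() == 0
gc.collect()                  # collect the orphaned buffers
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # hand freed blocks back from the CUDA caching allocator

The warnings.catch_warnings() wrapper exists only to silence the torch.distributed.reduce_op deprecation warning that inspecting some tensor objects can trigger.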
@@ -184,6 +213,9 @@ class model_backend(InferenceModel):
         seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
+        if seed:
+            torch.manual_seed(seed)
+
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
         else:
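Seeding the global generator makes sampling reproducible: two runs with the same seed and the same inputs draw the same tokens. One caveat worth noting is that "if seed:" treats a seed of 0 as falsy, so an explicit zero seed is silently ignored; "if seed is not None:" would honour it. A quick demonstration of the reproducibility this buys:

import torch

torch.manual_seed(42)
first = torch.multinomial(torch.ones(10), 3)
torch.manual_seed(42)
second = torch.multinomial(torch.ones(10), 3)
assert torch.equal(first, second)  # identical seeds give identical draws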