Fix exllama model unload

0cc4m
2023-06-05 18:43:57 +02:00
parent b35f61e987
commit 94520d5c80


@@ -9,6 +9,8 @@ import os
 import glob
 from pathlib import Path
 import re
+import warnings
+import gc
 import utils
 from logger import logger
@@ -26,8 +28,6 @@ from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
 from transformers import LlamaTokenizer
 from exllama.generator import ExLlamaGenerator
 import traceback

 model_backend_name = "ExLlama"
@@ -60,8 +60,10 @@ class model_backend(InferenceModel):
         self.model = None
         self.tokenizer = None
         self.cache = None
         self.generator = None

-        self.model_name = None
+        self.model_name = ""
         self.path = None

     def is_valid(self, model_name, model_path, menu_path):
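
Resetting model_name to an empty string rather than None keeps downstream string handling total: code that calls string methods on the name, or joins it into a path, no longer has to special-case None. An illustrative sketch (hypothetical usage, not taken from this diff):

    model_name = None
    # model_name.replace("/", "_")   # AttributeError: 'NoneType' object has no attribute 'replace'

    model_name = ""
    sanitized = model_name.replace("/", "_")  # fine: yields ""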
@@ -84,7 +86,7 @@ class model_backend(InferenceModel):
     def _load(self, save_model: bool, initial_load: bool) -> None:
         self.model = self._get_model(self.get_local_model_path(), {})

-        self.tokenizer = self._get_tokenizer(os.path.join(self.get_local_model_path(), "tokenizer.model"))
+        self.tokenizer = self._get_tokenizer(self.get_local_model_path())

         self.cache = ExLlamaCache(self.model)
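
The tokenizer loader now receives the model directory instead of an explicit path to tokenizer.model. Assuming _get_tokenizer ultimately defers to the Hugging Face loader (its body is not shown in this diff), the difference looks roughly like this; the /models/my-llama path is illustrative:

    from transformers import LlamaTokenizer

    # Old: point directly at the SentencePiece vocab file.
    tok = LlamaTokenizer(vocab_file="/models/my-llama/tokenizer.model")

    # New: pass the directory and let transformers resolve tokenizer.model
    # (plus tokenizer_config.json, if present) inside it.
    tok = LlamaTokenizer.from_pretrained("/models/my-llama")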
@@ -174,6 +176,33 @@ class model_backend(InferenceModel):
             return result

         object.__setattr__(self.tokenizer, '__call__', call_wrapper.__get__(self.tokenizer))

+    def unload(self):
+        self.model_config = None
+
+        self.model = None
+        self.tokenizer = None
+        self.cache = None
+        self.generator = None
+
+        self.model_name = ""
+        self.path = None
+
+        with torch.no_grad():
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
+                for tensor in gc.get_objects():
+                    try:
+                        if torch.is_tensor(tensor):
+                            tensor.set_(torch.tensor((), device=tensor.device, dtype=tensor.dtype))
+                    except:
+                        pass
+
+        gc.collect()
+        try:
+            with torch.no_grad():
+                torch.cuda.empty_cache()
+        except:
+            pass
+
     def _raw_generate(
         self,
         prompt_tokens: Union[List[int], torch.Tensor],
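
The core of the fix is the sweep in unload(): dropping the backend's own references is not enough if other parts of the app still hold tensors, so every tensor the garbage collector can see is shrunk in place to zero elements, which releases its backing CUDA storage while leaving the referencing objects with a valid empty tensor; torch.cuda.empty_cache() then hands the freed blocks back to the driver. A self-contained sketch of the same technique (the function name is mine):

    import gc
    import torch

    def release_all_cuda_memory():
        with torch.no_grad():
            for obj in gc.get_objects():
                try:
                    if torch.is_tensor(obj):
                        # Shrink in place: holders keep a valid, empty tensor,
                        # and the old storage becomes reclaimable.
                        obj.set_(torch.tensor((), device=obj.device, dtype=obj.dtype))
                except Exception:
                    pass
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # return cached blocks to the driver

This is deliberately a blunt instrument: it empties every tensor in the process, so it is only safe at a point where no live model state is needed, which is exactly the situation unload() is in.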
@@ -184,6 +213,9 @@ class model_backend(InferenceModel):
         seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
+        if seed:
+            torch.manual_seed(seed)
+
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
         else:
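
Seeding the global RNG before sampling makes generation reproducible for a given seed. One caveat worth noting: the guard "if seed:" also treats an explicit seed of 0 as "no seed", since 0 is falsy; "if seed is not None:" would accept it. A quick demonstration of the reproducibility itself:

    import torch

    probs = torch.tensor([0.25, 0.25, 0.25, 0.25])

    torch.manual_seed(1234)
    a = torch.multinomial(probs, 3, replacement=True)

    torch.manual_seed(1234)
    b = torch.multinomial(probs, 3, replacement=True)

    assert torch.equal(a, b)  # same seed, same samples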