Fix exllama model unload

Author: 0cc4m
Date: 2023-06-05 18:43:57 +02:00
parent b35f61e987
commit 94520d5c80


@@ -9,6 +9,8 @@ import os
 import glob
 from pathlib import Path
 import re
+import warnings
+import gc
 import utils
 from logger import logger
@@ -26,8 +28,6 @@ from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
 from transformers import LlamaTokenizer
 from exllama.generator import ExLlamaGenerator
-import traceback

 model_backend_name = "ExLlama"
@@ -60,8 +60,10 @@ class model_backend(InferenceModel):
         self.model = None
         self.tokenizer = None
-        self.model_name = None
+        self.cache = None
+        self.generator = None
+        self.model_name = ""
         self.path = None

     def is_valid(self, model_name, model_path, menu_path):
@@ -84,7 +86,7 @@ class model_backend(InferenceModel):
     def _load(self, save_model: bool, initial_load: bool) -> None:
         self.model = self._get_model(self.get_local_model_path(), {})
-        self.tokenizer = self._get_tokenizer(os.path.join(self.get_local_model_path(), "tokenizer.model"))
+        self.tokenizer = self._get_tokenizer(self.get_local_model_path())

         self.cache = ExLlamaCache(self.model)
@@ -174,6 +176,33 @@ class model_backend(InferenceModel):
             return result

         object.__setattr__(self.tokenizer, '__call__', call_wrapper.__get__(self.tokenizer))

+    def unload(self):
+        self.model_config = None
+
+        self.model = None
+        self.tokenizer = None
+        self.cache = None
+        self.generator = None
+
+        self.model_name = ""
+        self.path = None
+
+        with torch.no_grad():
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
+                for tensor in gc.get_objects():
+                    try:
+                        if torch.is_tensor(tensor):
+                            tensor.set_(torch.tensor((), device=tensor.device, dtype=tensor.dtype))
+                    except:
+                        pass
+            gc.collect()
+        try:
+            with torch.no_grad():
+                torch.cuda.empty_cache()
+        except:
+            pass
+
     def _raw_generate(
         self,
         prompt_tokens: Union[List[int], torch.Tensor],
@@ -184,6 +213,9 @@ class model_backend(InferenceModel):
         seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
+        if seed:
+            torch.manual_seed(seed)
+
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
         else:
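For readers skimming the diff: the heart of the fix is the new unload() method. Dropping the backend's own references does not reliably return VRAM, because the model's tensors can still be reachable through other objects, so the commit additionally walks everything the garbage collector tracks and shrinks any surviving tensor to zero elements before emptying the CUDA allocator cache. Below is a minimal standalone sketch of that same technique; the force_free_cuda_tensors name and the usage comments are illustrative, not part of the commit.

import gc
import warnings

import torch


def force_free_cuda_tensors():
    # Sketch of the commit's unload technique: shrink every tensor the
    # garbage collector can still see to zero elements, then collect and
    # empty the CUDA allocator cache.
    with torch.no_grad():
        with warnings.catch_warnings():
            # Probing arbitrary tracked objects can trip deprecation warnings.
            warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
            for obj in gc.get_objects():
                try:
                    if torch.is_tensor(obj):
                        # set_() with an empty tensor releases the underlying
                        # storage while keeping the Python object alive.
                        obj.set_(torch.tensor((), device=obj.device, dtype=obj.dtype))
                except Exception:
                    pass
        gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

The design point is that set_() frees the CUDA storage in place, so memory is reclaimed even for tensors that cannot be collected because some other object still references them; a plain del (or setting attributes to None, as the first half of unload() does) only helps once every reference is gone.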