Merge pull request #359 from one-some/gptj-fix

GPT-J fix
This commit is contained in:
henk717
2023-05-12 08:40:00 +02:00
committed by GitHub
2 changed files with 19 additions and 10 deletions

View File

@@ -332,10 +332,13 @@ class HFTorchInferenceModel(HFInferenceModel):
raise raise
logger.warning(f"Fell back to GPT2LMHeadModel due to {e}") logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
logger.debug(traceback.format_exc())
try: try:
return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs) return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
except Exception as e: except Exception as e:
logger.warning(f"Fell back to GPTNeoForCausalLM due to {e}") logger.warning(f"Fell back to GPTNeoForCausalLM due to {e}")
logger.debug(traceback.format_exc())
return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs) return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
def get_hidden_size(self) -> int: def get_hidden_size(self) -> int:
@@ -462,19 +465,25 @@ class HFTorchInferenceModel(HFInferenceModel):
device_map: Dict[str, Union[str, int]] = {} device_map: Dict[str, Union[str, int]] = {}
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def get_original_key(key): def get_original_key(key) -> Optional[str]:
return max( key_candidates = [
( original_key
original_key for original_key in utils.module_names
for original_key in utils.module_names if original_key.endswith(key)
if original_key.endswith(key) ]
),
key=len, if not key_candidates:
) logger.debug(f"!!! No key candidates for {key}")
return None
return max(key_candidates, key=len)
for key, value in model_dict.items(): for key, value in model_dict.items():
original_key = get_original_key(key) original_key = get_original_key(key)
if not original_key:
continue
if isinstance(value, lazy_loader.LazyTensor) and not any( if isinstance(value, lazy_loader.LazyTensor) and not any(
original_key.startswith(n) for n in utils.layers_module_names original_key.startswith(n) for n in utils.layers_module_names
): ):

View File

@@ -4006,7 +4006,7 @@ function update_context(data) {
document.getElementById('world_info_'+entry.uid).classList.add("used_in_game"); document.getElementById('world_info_'+entry.uid).classList.add("used_in_game");
} }
break; break;
case 'memory': case 'genre':
genre_length += entry.tokens.length; genre_length += entry.tokens.length;
break; break;
case 'memory': case 'memory':