Merge pull request #136 from VE-FORBRYDERNE/opt

Fix base OPT-125M and finetuned OPT models in Colab TPU instances
```diff
@@ -1225,13 +1225,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                 if utils.num_shards is not None:
                     utils.current_shard += 1
                 for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
+                    model_spec_key = max((k for k in model_spec.keys() if key.endswith(k)), key=len, default=None)
+
                     # Some model weights are used by transformers but not by MTJ.
                     # We have to materialize these weights anyways because
                     # transformers will throw a tantrum otherwise.  To attain
                     # the least possible memory usage, we create them as meta
                     # tensors, which don't take up any actual CPU or TPU memory.
-                    if key not in model_spec:
+                    if model_spec_key is None:
                         model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
                         utils.bar.update(1)
                         continue
@@ -1246,7 +1247,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                     if current_offset != model_dict[key].seek_offset:
                         f.read(model_dict[key].seek_offset - current_offset)
                         current_offset = model_dict[key].seek_offset
-                    spec = model_spec[key]
+                    spec = model_spec[model_spec_key]
                     transforms = set(spec.get("transforms", ()))
                     if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
                         error = f"Duplicate key {repr(key)}"
```
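The core of the fix is the suffix-based lookup into `model_spec`: an exact `key in model_spec` test misses tensors whose checkpoint names carry an extra prefix, which is what the `endswith` matching handles for both base OPT-125M and finetuned OPT checkpoints. A minimal standalone sketch of that lookup (the example keys and the `find_spec_key` helper are illustrative, not taken from the repository):

```python
# Illustrative spec; the real model_spec maps transformers weight names
# to MTJ parameters in the TPU backend.
model_spec = {
    "decoder.embed_tokens.weight": "embed_tokens",
    "decoder.layers.0.self_attn.q_proj.weight": "layer_0_q_proj",
}

def find_spec_key(checkpoint_key, spec):
    """Return the longest spec key that checkpoint_key ends with, or None."""
    # max(..., key=len) prefers the most specific (longest) suffix match.
    return max((k for k in spec if checkpoint_key.endswith(k)), key=len, default=None)

# The same tensor may be stored with or without an extra prefix depending on
# how the checkpoint was saved; both forms now resolve to the same spec entry.
assert find_spec_key("decoder.embed_tokens.weight", model_spec) == "decoder.embed_tokens.weight"
assert find_spec_key("model.decoder.embed_tokens.weight", model_spec) == "decoder.embed_tokens.weight"
# Keys with no spec entry still fall through to the meta-tensor path above.
assert find_spec_key("lm_head.weight", model_spec) is None
```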
```diff
@@ -183,8 +183,8 @@ function userscript.genmod()
             max_overlap[i] = 0
             local s = {}
             local z = {[0] = 0}
-            local l = 1
-            local r = 1
+            local l = 0
+            local r = 0
             local n_s = math.min(n_tokens, bias_entry.n_tokens)
             local j = 0
             for k = 1, n_s do
```
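The Lua hunk adjusts what looks like a Z-function scan used by the userscript to measure how far a bias entry already overlaps the generated tokens. The array is 0-indexed (`local z = {[0] = 0}`), and in the standard algorithm the bounds `l` and `r` of the rightmost match window also start at 0, which is what this change restores. For reference, a 0-based Z-function sketch in Python (not the repository's Lua code) showing the corrected initialization:

```python
def z_function(s):
    """0-based Z-array: z[i] = length of the longest common prefix of s and s[i:]."""
    n = len(s)
    z = [0] * n  # z[0] stays 0 by convention, mirroring the Lua snippet's z = {[0] = 0}
    l, r = 0, 0  # rightmost match window [l, r); both start at 0 in 0-based indexing
    for i in range(1, n):
        if i < r:
            # Reuse the previously computed value inside the current window.
            z[i] = min(r - i, z[i - l])
        # Extend the match naively past the window.
        while i + z[i] < n and s[z[i]] == s[i + z[i]]:
            z[i] += 1
        # Advance the window if this match reaches further right.
        if i + z[i] > r:
            l, r = i, i + z[i]
    return z

assert z_function("aabaab") == [0, 1, 0, 3, 1, 0]
```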