mirror of
				https://github.com/KoboldAI/KoboldAI-Client.git
				synced 2025-06-05 21:59:24 +02:00 
			
		
		
		
	Added status bar for downloading models
This commit is contained in:
		
							
								
								
									
										52
									
								
								aiserver.py
									
									
									
									
									
								
							
							
						
						
									
										52
									
								
								aiserver.py
									
									
									
									
									
								
							| @@ -1305,9 +1305,60 @@ def patch_causallm(model): | |||||||
|     Embedding._koboldai_patch_causallm_model = model |     Embedding._koboldai_patch_causallm_model = model | ||||||
|     return model |     return model | ||||||
|  |  | ||||||
|  | def patch_transformers_download(): | ||||||
|  |     global transformers | ||||||
|  |     import copy, requests, tqdm, time | ||||||
|  |     class Send_to_socketio(object): | ||||||
|  |         def write(self, bar): | ||||||
|  |             bar = bar.replace("\r", "") | ||||||
|  |             try: | ||||||
|  |                 emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True) | ||||||
|  |                 eventlet.sleep(seconds=0) | ||||||
|  |             except: | ||||||
|  |                 pass | ||||||
|  |     def http_get( | ||||||
|  |         url: str, | ||||||
|  |         temp_file: transformers.utils.hub.BinaryIO, | ||||||
|  |         proxies=None, | ||||||
|  |         resume_size=0, | ||||||
|  |         headers: transformers.utils.hub.Optional[transformers.utils.hub.Dict[str, str]] = None, | ||||||
|  |         file_name: transformers.utils.hub.Optional[str] = None, | ||||||
|  |     ): | ||||||
|  |         """ | ||||||
|  |         Download remote file. Do not gobble up errors. | ||||||
|  |         """ | ||||||
|  |         headers = copy.deepcopy(headers) | ||||||
|  |         if resume_size > 0: | ||||||
|  |             headers["Range"] = f"bytes={resume_size}-" | ||||||
|  |         r = requests.get(url, stream=True, proxies=proxies, headers=headers) | ||||||
|  |         transformers.utils.hub._raise_for_status(r) | ||||||
|  |         content_length = r.headers.get("Content-Length") | ||||||
|  |         total = resume_size + int(content_length) if content_length is not None else None | ||||||
|  |         # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()` | ||||||
|  |         # and can be set using `utils.logging.enable/disable_progress_bar()` | ||||||
|  |         progress = tqdm.tqdm( | ||||||
|  |             unit="B", | ||||||
|  |             unit_scale=True, | ||||||
|  |             unit_divisor=1024, | ||||||
|  |             total=total, | ||||||
|  |             initial=resume_size, | ||||||
|  |             desc=f"Downloading {file_name}" if file_name is not None else "Downloading", | ||||||
|  |             file=Send_to_socketio(), | ||||||
|  |         ) | ||||||
|  |         for chunk in r.iter_content(chunk_size=1024): | ||||||
|  |             if chunk:  # filter out keep-alive new chunks | ||||||
|  |                 progress.update(len(chunk)) | ||||||
|  |                 temp_file.write(chunk) | ||||||
|  |         progress.close() | ||||||
|  |  | ||||||
|  |     transformers.utils.hub.http_get = http_get | ||||||
|  |      | ||||||
|  |  | ||||||
| def patch_transformers(): | def patch_transformers(): | ||||||
|     global transformers |     global transformers | ||||||
|  |      | ||||||
|  |     patch_transformers_download() | ||||||
|  |      | ||||||
|     old_from_pretrained = PreTrainedModel.from_pretrained.__func__ |     old_from_pretrained = PreTrainedModel.from_pretrained.__func__ | ||||||
|     @classmethod |     @classmethod | ||||||
|     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): |     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | ||||||
| @@ -6377,6 +6428,7 @@ if __name__ == "__main__": | |||||||
|                 vars.flaskwebgui = True |                 vars.flaskwebgui = True | ||||||
|                 FlaskUI(app, socketio=socketio, start_server="flask-socketio", maximized=True, close_server_on_exit=True).run() |                 FlaskUI(app, socketio=socketio, start_server="flask-socketio", maximized=True, close_server_on_exit=True).run() | ||||||
|             except: |             except: | ||||||
|  |                 pass | ||||||
|                 import webbrowser |                 import webbrowser | ||||||
|                 webbrowser.open_new('http://localhost:{0}'.format(port)) |                 webbrowser.open_new('http://localhost:{0}'.format(port)) | ||||||
|                 print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:{1}/{2}" |                 print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:{1}/{2}" | ||||||
|   | |||||||
							
								
								
									
										12
									
								
								utils.py
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								utils.py
									
									
									
									
									
								
							| @@ -172,6 +172,16 @@ def num_layers(config): | |||||||
| #==================================================================# | #==================================================================# | ||||||
| #  Downloads huggingface checkpoints using aria2c if possible | #  Downloads huggingface checkpoints using aria2c if possible | ||||||
| #==================================================================# | #==================================================================# | ||||||
|  | from flask_socketio import emit | ||||||
|  | class Send_to_socketio(object): | ||||||
|  |     def write(self, bar): | ||||||
|  |         print("should be emitting: ", bar, end="") | ||||||
|  |         time.sleep(0.01) | ||||||
|  |         try: | ||||||
|  |             emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True) | ||||||
|  |         except: | ||||||
|  |             pass | ||||||
|  |              | ||||||
| def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs): | def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs): | ||||||
|     import transformers |     import transformers | ||||||
|     import transformers.modeling_utils |     import transformers.modeling_utils | ||||||
| @@ -268,7 +278,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d | |||||||
|                     done = True |                     done = True | ||||||
|                     break |                     break | ||||||
|                 if bar is None: |                 if bar is None: | ||||||
|                     bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000) |                     bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio()) | ||||||
|                 visited = set() |                 visited = set() | ||||||
|                 for x in r: |                 for x in r: | ||||||
|                     filename = x["files"][0]["path"] |                     filename = x["files"][0]["path"] | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user