From 22fc1b69ee640208d57b6d0298c6d44bc36e9f14 Mon Sep 17 00:00:00 2001 From: ebolam Date: Wed, 14 Sep 2022 18:25:47 -0400 Subject: [PATCH] Transformers fix --- aiserver.py | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 1 deletion(-) diff --git a/aiserver.py b/aiserver.py index 6ce1577e..a7c24a85 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1424,7 +1424,6 @@ def patch_causallm(model): def patch_transformers_download(): global transformers import copy, requests, tqdm, time - from typing import BinaryIO, Dict, Optional, Tuple, Union class Send_to_socketio(object): def write(self, bar): bar = bar.replace("\r", "").replace("\n", "") @@ -1482,6 +1481,87 @@ def patch_transformers_download(): temp_file.write(chunk) progress.close() + def http_get( + url: str, + temp_file: transformers.utils.hub.BinaryIO, + proxies=None, + resume_size=0, + headers: transformers.utils.hub.Optional[transformers.utils.hub.Dict[str, str]] = None, + file_name: transformers.utils.hub.Optional[str] = None, + ): + """ + Download remote file. Do not gobble up errors. + """ + headers = copy.deepcopy(headers) + if resume_size > 0: + headers["Range"] = f"bytes={resume_size}-" + r = requests.get(url, stream=True, proxies=proxies, headers=headers) + transformers.utils.hub._raise_for_status(r) + content_length = r.headers.get("Content-Length") + total = resume_size + int(content_length) if content_length is not None else None + # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()` + # and can be set using `utils.logging.enable/disable_progress_bar()` + if url[-11:] != 'config.json': + progress = tqdm.tqdm( + unit="B", + unit_scale=True, + unit_divisor=1024, + total=total, + initial=resume_size, + desc=f"Downloading {file_name}" if file_name is not None else "Downloading", + file=Send_to_socketio(), + ) + for chunk in r.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + if url[-11:] != 'config.json': + progress.update(len(chunk)) + temp_file.write(chunk) + if url[-11:] != 'config.json': + progress.close() + + # def http_get( + # url: str, + # temp_file: BinaryIO, + # *, + # proxies=None, + # resume_size=0, + # headers: Optional[Dict[str, str]] = None, + # timeout=10.0, + # max_retries=0, + # ): + # """ + # Donwload a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub. + # """ + # headers = copy.deepcopy(headers) + # if resume_size > 0: + # headers["Range"] = "bytes=%d-" % (resume_size,) + # r = _request_wrapper( + # method="GET", + # url=url, + # stream=True, + # proxies=proxies, + # headers=headers, + # timeout=timeout, + # max_retries=max_retries, + # ) + # hf_raise_for_status(r) + # content_length = r.headers.get("Content-Length") + # total = resume_size + int(content_length) if content_length is not None else None + # progress = tqdm( + # unit="B", + # unit_scale=True, + # total=total, + # initial=resume_size, + # desc="Downloading", + # file=Send_to_socketio(), + # disable=bool(logger.getEffectiveLevel() == logging.NOTSET), + # ) + # for chunk in r.iter_content(chunk_size=1024): + # if chunk: # filter out keep-alive new chunks + # progress.update(len(chunk)) + # temp_file.write(chunk) + # progress.close() + transformers.utils.hub.http_get = http_get