diff --git a/aiserver.py b/aiserver.py index ec69fbc2..53fb5b87 100644 --- a/aiserver.py +++ b/aiserver.py @@ -458,9 +458,9 @@ def move_model_to_devices(model): gc.collect() GPTNeoModel.forward = breakmodel.new_forward_neo if("GPTJModel" in globals()): - GPTJModel.forward = breakmodel.new_forward_neo + GPTJModel.forward = breakmodel.new_forward_neo # type: ignore if("XGLMModel" in globals()): - XGLMModel.forward = breakmodel.new_forward_xglm + XGLMModel.forward = breakmodel.new_forward_xglm # type: ignore generator = model.generate if(hasattr(model, "transformer")): breakmodel.move_hidden_layers(model.transformer)