diff --git a/aiserver.py b/aiserver.py index 7d9d8a17..e79c0272 100644 --- a/aiserver.py +++ b/aiserver.py @@ -5053,6 +5053,8 @@ def tpu_raw_generate( max_length: int, batch_count: int, ): + + prompt_tokens = prompt_tokens[0] # Mostly lifted from apiactionsubmit_tpumtjgenerate soft_tokens = tpumtjgetsofttokens() __debug("we are generating with", prompt_tokens, "batch", batch_count, "soft tokens", soft_tokens)