diff --git a/aiserver.py b/aiserver.py
index 0e98a53e..5e7e9ef6 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -53,6 +53,7 @@ class vars:
     url     = "https://api.inferkit.com/v1/models/standard/generate" # InferKit API URL
     apikey  = ""
     savedir = getcwd()+"\stories\\newstory.json"
+    hascuda = False
 
 #==================================================================#
 # Startup
@@ -107,8 +108,16 @@ if(vars.model != "InferKit"):
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.HEADER, colors.ENDC))
         from transformers import pipeline, GPT2Tokenizer
+        import torch
+
+        # Is CUDA available? If so, use GPU, otherwise fall back to CPU
+        vars.hascuda = torch.cuda.is_available()
 
-        generator = pipeline('text-generation', model=vars.model, device=0)
+        if(vars.hascuda):
+            generator = pipeline('text-generation', model=vars.model, device=0)
+        else:
+            generator = pipeline('text-generation', model=vars.model)
+
         tokenizer = GPT2Tokenizer.from_pretrained(vars.model)
         print("{0}OK! {1} pipeline created!{2}".format(colors.OKGREEN, vars.model, colors.ENDC))
     else: