diff --git a/pwnagotchi/ai/__init__.py b/pwnagotchi/ai/__init__.py index 5493342..a031db4 100644 --- a/pwnagotchi/ai/__init__.py +++ b/pwnagotchi/ai/__init__.py @@ -45,8 +45,17 @@ def load(config, agent, epoch, from_disk=True): if from_disk and os.path.exists(config['path']): logging.info("[ai] loading %s ..." % config['path']) start = time.time() - a2c.load(config['path'], env) - logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start)) + try: + a2c.load(config['path'], env) + except AssertionError as as_err: + from fnmatch import fnmatch + # Sometimes the model breaks... + if not fnmatch(str(as_err), '* same * space as the model *'): + raise as_err + else: + logging.warning("[ai] Model could not be loaded. Using new model.") + else: + logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start)) else: logging.info("[ai] model created:") for key, value in config['params'].items():