From 7954bb5fcfc4d56d7d5f35877d790953f2a6b2b2 Mon Sep 17 00:00:00 2001 From: Simone Margaritelli Date: Fri, 25 Oct 2019 12:15:56 +0200 Subject: [PATCH] misc: small fix or general refactoring i did not bother commenting --- pwnagotchi/ai/__init__.py | 75 +++++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 35 deletions(-) diff --git a/pwnagotchi/ai/__init__.py b/pwnagotchi/ai/__init__.py index 7390193..4dbdc9b 100644 --- a/pwnagotchi/ai/__init__.py +++ b/pwnagotchi/ai/__init__.py @@ -15,41 +15,46 @@ def load(config, agent, epoch, from_disk=True): logging.info("ai disabled") return False - logging.info("[ai] bootstrapping dependencies ...") + try: + logging.info("[ai] bootstrapping dependencies ...") - start = time.time() - from stable_baselines import A2C - logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start)) - - start = time.time() - from stable_baselines.common.policies import MlpLstmPolicy - logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start)) - - start = time.time() - from stable_baselines.common.vec_env import DummyVecEnv - logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start)) - - start = time.time() - import pwnagotchi.ai.gym as wrappers - logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start)) - - env = wrappers.Environment(agent, epoch) - env = DummyVecEnv([lambda: env]) - - logging.info("[ai] creating model ...") - - start = time.time() - a2c = A2C(MlpLstmPolicy, env, **config['params']) - logging.debug("[ai] A2C crated in %.2fs" % (time.time() - start)) - - if from_disk and os.path.exists(config['path']): - logging.info("[ai] loading %s ..." % config['path']) start = time.time() - a2c.load(config['path'], env) - logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start)) - else: - logging.info("[ai] model created:") - for key, value in config['params'].items(): - logging.info(" %s: %s" % (key, value)) + from stable_baselines import A2C + logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start)) - return a2c + start = time.time() + from stable_baselines.common.policies import MlpLstmPolicy + logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start)) + + start = time.time() + from stable_baselines.common.vec_env import DummyVecEnv + logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start)) + + start = time.time() + import pwnagotchi.ai.gym as wrappers + logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start)) + + env = wrappers.Environment(agent, epoch) + env = DummyVecEnv([lambda: env]) + + logging.info("[ai] creating model ...") + + start = time.time() + a2c = A2C(MlpLstmPolicy, env, **config['params']) + logging.debug("[ai] A2C crated in %.2fs" % (time.time() - start)) + + if from_disk and os.path.exists(config['path']): + logging.info("[ai] loading %s ..." % config['path']) + start = time.time() + a2c.load(config['path'], env) + logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start)) + else: + logging.info("[ai] model created:") + for key, value in config['params'].items(): + logging.info(" %s: %s" % (key, value)) + + return a2c + except Exception as e: + logging.exception("error while starting AI") + + return False