misc: small fix or general refactoring i did not bother commenting
This commit is contained in:
parent
06d8cc63fb
commit
7954bb5fcf
@ -15,41 +15,46 @@ def load(config, agent, epoch, from_disk=True):
|
||||
logging.info("ai disabled")
|
||||
return False
|
||||
|
||||
logging.info("[ai] bootstrapping dependencies ...")
|
||||
try:
|
||||
logging.info("[ai] bootstrapping dependencies ...")
|
||||
|
||||
start = time.time()
|
||||
from stable_baselines import A2C
|
||||
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
from stable_baselines.common.policies import MlpLstmPolicy
|
||||
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
from stable_baselines.common.vec_env import DummyVecEnv
|
||||
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
import pwnagotchi.ai.gym as wrappers
|
||||
logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start))
|
||||
|
||||
env = wrappers.Environment(agent, epoch)
|
||||
env = DummyVecEnv([lambda: env])
|
||||
|
||||
logging.info("[ai] creating model ...")
|
||||
|
||||
start = time.time()
|
||||
a2c = A2C(MlpLstmPolicy, env, **config['params'])
|
||||
logging.debug("[ai] A2C crated in %.2fs" % (time.time() - start))
|
||||
|
||||
if from_disk and os.path.exists(config['path']):
|
||||
logging.info("[ai] loading %s ..." % config['path'])
|
||||
start = time.time()
|
||||
a2c.load(config['path'], env)
|
||||
logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start))
|
||||
else:
|
||||
logging.info("[ai] model created:")
|
||||
for key, value in config['params'].items():
|
||||
logging.info(" %s: %s" % (key, value))
|
||||
from stable_baselines import A2C
|
||||
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
|
||||
|
||||
return a2c
|
||||
start = time.time()
|
||||
from stable_baselines.common.policies import MlpLstmPolicy
|
||||
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
from stable_baselines.common.vec_env import DummyVecEnv
|
||||
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
import pwnagotchi.ai.gym as wrappers
|
||||
logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start))
|
||||
|
||||
env = wrappers.Environment(agent, epoch)
|
||||
env = DummyVecEnv([lambda: env])
|
||||
|
||||
logging.info("[ai] creating model ...")
|
||||
|
||||
start = time.time()
|
||||
a2c = A2C(MlpLstmPolicy, env, **config['params'])
|
||||
logging.debug("[ai] A2C crated in %.2fs" % (time.time() - start))
|
||||
|
||||
if from_disk and os.path.exists(config['path']):
|
||||
logging.info("[ai] loading %s ..." % config['path'])
|
||||
start = time.time()
|
||||
a2c.load(config['path'], env)
|
||||
logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start))
|
||||
else:
|
||||
logging.info("[ai] model created:")
|
||||
for key, value in config['params'].items():
|
||||
logging.info(" %s: %s" % (key, value))
|
||||
|
||||
return a2c
|
||||
except Exception as e:
|
||||
logging.exception("error while starting AI")
|
||||
|
||||
return False
|
||||
|
Loading…
x
Reference in New Issue
Block a user