misc: added debug logs for AI loading times
This commit is contained in:
parent
c4ae3c15bd
commit
06d8cc63fb
@ -1,14 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
import logging
|
||||||
|
|
||||||
# https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709
|
# https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'}
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'}
|
||||||
import warnings
|
|
||||||
|
|
||||||
# https://stackoverflow.com/questions/15777951/how-to-suppress-pandas-future-warning
|
# https://stackoverflow.com/questions/15777951/how-to-suppress-pandas-future-warning
|
||||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
def load(config, agent, epoch, from_disk=True):
|
def load(config, agent, epoch, from_disk=True):
|
||||||
config = config['ai']
|
config = config['ai']
|
||||||
@ -18,25 +17,39 @@ def load(config, agent, epoch, from_disk=True):
|
|||||||
|
|
||||||
logging.info("[ai] bootstrapping dependencies ...")
|
logging.info("[ai] bootstrapping dependencies ...")
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
from stable_baselines import A2C
|
from stable_baselines import A2C
|
||||||
from stable_baselines.common.policies import MlpLstmPolicy
|
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
|
||||||
from stable_baselines.common.vec_env import DummyVecEnv
|
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
from stable_baselines.common.policies import MlpLstmPolicy
|
||||||
|
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
from stable_baselines.common.vec_env import DummyVecEnv
|
||||||
|
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
import pwnagotchi.ai.gym as wrappers
|
import pwnagotchi.ai.gym as wrappers
|
||||||
|
logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start))
|
||||||
|
|
||||||
env = wrappers.Environment(agent, epoch)
|
env = wrappers.Environment(agent, epoch)
|
||||||
env = DummyVecEnv([lambda: env])
|
env = DummyVecEnv([lambda: env])
|
||||||
|
|
||||||
logging.info("[ai] bootstrapping model ...")
|
logging.info("[ai] creating model ...")
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
a2c = A2C(MlpLstmPolicy, env, **config['params'])
|
a2c = A2C(MlpLstmPolicy, env, **config['params'])
|
||||||
|
logging.debug("[ai] A2C crated in %.2fs" % (time.time() - start))
|
||||||
|
|
||||||
if from_disk and os.path.exists(config['path']):
|
if from_disk and os.path.exists(config['path']):
|
||||||
logging.info("[ai] loading %s ..." % config['path'])
|
logging.info("[ai] loading %s ..." % config['path'])
|
||||||
|
start = time.time()
|
||||||
a2c.load(config['path'], env)
|
a2c.load(config['path'], env)
|
||||||
|
logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start))
|
||||||
else:
|
else:
|
||||||
logging.info("[ai] model created:")
|
logging.info("[ai] model created:")
|
||||||
for key, value in config['params'].items():
|
for key, value in config['params'].items():
|
||||||
logging.info(" %s: %s" % (key, value))
|
logging.info(" %s: %s" % (key, value))
|
||||||
|
|
||||||
return a2c
|
return a2c
|
||||||
|
Loading…
x
Reference in New Issue
Block a user