import _thread
import threading
import time
import random
import os
import json
import logging

import pwnagotchi.plugins as plugins
import pwnagotchi.ai as ai


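# Stats keeps lifetime counters for the AI (epochs lived, epochs trained,
# best/worst reward seen) and persists them as JSON next to the model file.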
class Stats(object):
    def __init__(self, path, events_receiver):
        self._lock = threading.Lock()
        self._receiver = events_receiver

        self.path = path
        self.born_at = time.time()
        # total epochs lived (trained + just eval)
        self.epochs_lived = 0
        # total training epochs
        self.epochs_trained = 0

        self.worst_reward = 0.0
        self.best_reward = 0.0

        self.load()

    def on_epoch(self, data, training):
        best_r = False
        worst_r = False
        with self._lock:
            reward = data['reward']
            if reward < self.worst_reward:
                self.worst_reward = reward
                worst_r = True

            elif reward > self.best_reward:
                best_r = True
                self.best_reward = reward

            self.epochs_lived += 1
            if training:
                self.epochs_trained += 1

        # save() acquires the lock itself, so call it outside of the with block.
        self.save()

        if best_r:
            self._receiver.on_ai_best_reward(reward)
        elif worst_r:
            self._receiver.on_ai_worst_reward(reward)

    def load(self):
        with self._lock:
            if os.path.exists(self.path) and os.path.getsize(self.path) > 0:
                logging.info("[ai] loading %s" % self.path)
                with open(self.path, 'rt') as fp:
                    obj = json.load(fp)

                self.born_at = obj['born_at']
                self.epochs_lived, self.epochs_trained = obj['epochs_lived'], obj['epochs_trained']
                self.best_reward, self.worst_reward = obj['rewards']['best'], obj['rewards']['worst']

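    # save() writes the JSON to a temporary file and then renames it over the
    # real path, so a crash mid-write cannot leave a truncated stats file.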
    def save(self):
        with self._lock:
            logging.info("[ai] saving %s" % self.path)

            data = json.dumps({
                'born_at': self.born_at,
                'epochs_lived': self.epochs_lived,
                'epochs_trained': self.epochs_trained,
                'rewards': {
                    'best': self.best_reward,
                    'worst': self.worst_reward
                }
            })

            temp = "%s.tmp" % self.path
            with open(temp, 'wt') as fp:
                fp.write(data)

            os.replace(temp, self.path)


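# AsyncTrainer runs the reinforcement learning loop on a background thread.
# It is written as a mixin: the class it is combined with is expected to
# provide self._epoch, self._view and self.run(), which are used below but
# never defined here.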
class AsyncTrainer(object):
    def __init__(self, config):
        self._config = config
        self._model = None
        self._is_training = False
        self._training_epochs = 0
        self._nn_path = self._config['ai']['path']
        self._stats = Stats("%s.json" % os.path.splitext(self._nn_path)[0], self)

    def set_training(self, training, for_epochs=0):
        self._is_training = training
        self._training_epochs = for_epochs

        if training:
            plugins.on('ai_training_start', self, for_epochs)
        else:
            plugins.on('ai_training_end', self)

    def is_training(self):
        return self._is_training

    def training_epochs(self):
        return self._training_epochs

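    # start_ai() hands _ai_worker off to a background thread via _thread, so
    # loading and training the model never block the caller.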
    def start_ai(self):
        _thread.start_new_thread(self._ai_worker, ())

    def _save_ai(self):
        logging.info("[ai] saving model to %s ..." % self._nn_path)
        temp = "%s.tmp" % self._nn_path
        self._model.save(temp)
        os.replace(temp, self._nn_path)

    def on_ai_step(self):
        self._model.env.render()

        if self._is_training:
            self._save_ai()

        self._stats.on_epoch(self._epoch.data(), self._is_training)

    def on_ai_training_step(self, _locals, _globals):
        self._model.env.render()
        plugins.on('ai_training_step', self, _locals, _globals)

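    # on_ai_policy is invoked when the model proposes new personality values:
    # known parameters are copied into the config, unknown ones are logged as
    # errors, and the resulting wifi TTL/RSSI settings are pushed out via
    # self.run().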
    def on_ai_policy(self, new_params):
        plugins.on('ai_policy', self, new_params)
        logging.info("[ai] setting new policy:")
        for name, value in new_params.items():
            if name in self._config['personality']:
                curr_value = self._config['personality'][name]
                if curr_value != value:
                    logging.info("[ai] ! %s: %s -> %s" % (name, curr_value, value))
                    self._config['personality'][name] = value
            else:
                logging.error("[ai] param %s not in personality configuration!" % name)

        self.run('set wifi.ap.ttl %d' % self._config['personality']['ap_ttl'])
        self.run('set wifi.sta.ttl %d' % self._config['personality']['sta_ttl'])
        self.run('set wifi.rssi.min %d' % self._config['personality']['min_rssi'])

    def on_ai_ready(self):
        self._view.on_ai_ready()
        plugins.on('ai_ready', self)

    def on_ai_best_reward(self, r):
        logging.info("[ai] best reward so far: %s" % r)
        self._view.on_motivated(r)
        plugins.on('ai_best_reward', self, r)

    def on_ai_worst_reward(self, r):
        logging.info("[ai] worst reward so far: %s" % r)
        self._view.on_demotivated(r)
        plugins.on('ai_worst_reward', self, r)

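    # _ai_worker is the background loop: it loads the model, then on every
    # iteration either runs a training episode (with probability
    # 1 - ai.laziness) or keeps stepping the environment with the current
    # policy in inference mode.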
    def _ai_worker(self):
        self._model = ai.load(self._config, self, self._epoch)

        if self._model:
            self.on_ai_ready()

            epochs_per_episode = self._config['ai']['epochs_per_episode']

            obs = None
            while True:
                self._model.env.render()
                # enter training mode?
                if random.random() > self._config['ai']['laziness']:
                    logging.info("[ai] learning for %d epochs ..." % epochs_per_episode)
                    try:
                        self.set_training(True, epochs_per_episode)
                        self._model.learn(total_timesteps=epochs_per_episode, callback=self.on_ai_training_step)
                    except Exception:
                        logging.exception("[ai] error while training")
                    finally:
                        self.set_training(False)
                        obs = self._model.env.reset()
                # init the first time
                elif obs is None:
                    obs = self._model.env.reset()

                # run the inference
                action, _ = self._model.predict(obs)
                obs, _, _, _ = self._model.env.step(action)
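

# Note: this module is not meant to be run on its own. In pwnagotchi the Agent
# class mixes AsyncTrainer in (providing _epoch, _view and run()) and calls
# start_ai() once the main loop is up.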