more fixes

This commit is contained in:
Simone Margaritelli 2019-09-30 21:22:01 +02:00
parent 52ab525e9f
commit 26abbf51d6
6 changed files with 72 additions and 68 deletions

View File

@ -24,7 +24,7 @@ args = parser.parse_args()
if args.do_clear:
print("clearing the display ...")
from pwnagotchi.ui.waveshare import EPD
from pwnagotchi.ui.waveshare.v2.waveshare import EPD
epd = EPD()
epd.init(epd.FULL_UPDATE)

View File

@ -9,7 +9,7 @@ import _thread
import core
from bettercap.client import Client
from pwnagotchi.mesh.advertise import AsyncAdvertiser
from pwnagotchi.mesh.utils import AsyncAdvertiser
from pwnagotchi.ai.train import AsyncTrainer
RECOVERY_DATA_FILE = '/root/.pwnagotchi-recovery'

View File

@ -142,28 +142,29 @@ class AsyncTrainer(object):
def _ai_worker(self):
self._model = ai.load(self._config, self, self._epoch)
self.on_ai_ready()
if self._model:
self.on_ai_ready()
epochs_per_episode = self._config['ai']['epochs_per_episode']
epochs_per_episode = self._config['ai']['epochs_per_episode']
obs = None
while True:
self._model.env.render()
# enter in training mode?
if random.random() > self._config['ai']['laziness']:
core.log("[ai] learning for %d epochs ..." % epochs_per_episode)
try:
self.set_training(True, epochs_per_episode)
self._model.learn(total_timesteps=epochs_per_episode, callback=self.on_ai_training_step)
except Exception as e:
core.log("[ai] error while training: %s" % e)
finally:
self.set_training(False)
obs = None
while True:
self._model.env.render()
# enter in training mode?
if random.random() > self._config['ai']['laziness']:
core.log("[ai] learning for %d epochs ..." % epochs_per_episode)
try:
self.set_training(True, epochs_per_episode)
self._model.learn(total_timesteps=epochs_per_episode, callback=self.on_ai_training_step)
except Exception as e:
core.log("[ai] error while training: %s" % e)
finally:
self.set_training(False)
obs = self._model.env.reset()
# init the first time
elif obs is None:
obs = self._model.env.reset()
# init the first time
elif obs is None:
obs = self._model.env.reset()
# run the inference
action, _ = self._model.predict(obs)
obs, _, _, _ = self._model.env.step(action)
# run the inference
action, _ = self._model.predict(obs)
obs, _, _, _ = self._model.env.step(action)

View File

@ -5,10 +5,8 @@ import threading
from scapy.all import Dot11, Dot11FCS, Dot11Elt, RadioTap, sendp, sniff
import core
import pwnagotchi
import pwnagotchi.ui.faces as faces
from pwnagotchi.mesh import get_identity
import pwnagotchi.mesh.wifi as wifi
from pwnagotchi.mesh import new_session_id
from pwnagotchi.mesh.peer import Peer
@ -181,40 +179,3 @@ class Advertiser(object):
for ident in stale:
del self._peers[ident]
class AsyncAdvertiser(object):
def __init__(self, config, view):
self._config = config
self._view = view
self._public_key, self._identity = get_identity(config)
self._advertiser = None
def start_advertising(self):
_thread.start_new_thread(self._adv_worker, ())
def _adv_worker(self):
# this will take some time due to scapy being slow to be imported ...
from pwnagotchi.mesh.advertise import Advertiser
self._advertiser = Advertiser(
self._config['main']['iface'],
pwnagotchi.name(),
pwnagotchi.version,
self._identity,
period=0.3,
data=self._config['personality'])
self._advertiser.on_peer(self._on_new_unit, self._on_lost_unit)
if self._config['personality']['advertise']:
self._advertiser.start()
self._view.on_state_change('face', self._advertiser.on_face_change)
else:
core.log("advertising is disabled")
def _on_new_unit(self, peer):
self._view.on_new_peer(peer)
def _on_lost_unit(self, peer):
self._view.on_lost_peer(peer)

View File

@ -0,0 +1,42 @@
import _thread
import core
import pwnagotchi
from pwnagotchi.mesh import get_identity
class AsyncAdvertiser(object):
def __init__(self, config, view):
self._config = config
self._view = view
self._public_key, self._identity = get_identity(config)
self._advertiser = None
def start_advertising(self):
_thread.start_new_thread(self._adv_worker, ())
def _adv_worker(self):
# this will take some time due to scapy being slow to be imported ...
from pwnagotchi.mesh.advertise import Advertiser
self._advertiser = Advertiser(
self._config['main']['iface'],
pwnagotchi.name(),
pwnagotchi.version,
self._identity,
period=0.3,
data=self._config['personality'])
self._advertiser.on_peer(self._on_new_unit, self._on_lost_unit)
if self._config['personality']['advertise']:
self._advertiser.start()
self._view.on_state_change('face', self._advertiser.on_face_change)
else:
core.log("advertising is disabled")
def _on_new_unit(self, peer):
self._view.on_new_peer(peer)
def _on_lost_unit(self, peer):
self._view.on_lost_peer(peer)

View File

@ -111,12 +111,14 @@ class Display(View):
def _init_display(self):
if self._is_inky():
core.log("initializing inky display")
from inky import InkyPHAT
self._display = InkyPHAT(self._display_color)
self._display.set_border(InkyPHAT.BLACK)
self._render_cb = self._inky_render
elif self._is_papirus():
core.log("initializing papirus display")
from pwnagotchi.ui.papirus.epd import EPD
os.environ['EPD_SIZE'] = '2.0'
self._display = EPD()
@ -124,8 +126,8 @@ class Display(View):
self._render_cb = self._papirus_render
elif self._is_waveshare1():
core.log("initializing waveshare v1 display")
from pwnagotchi.ui.waveshare.v1.epd2in13 import EPD
# core.log("display module started")
self._display = EPD()
self._display.init(self._display.lut_full_update)
self._display.Clear(0xFF)
@ -133,8 +135,8 @@ class Display(View):
self._render_cb = self._waveshare_render
elif self._is_waveshare2():
core.log("initializing waveshare v2 display")
from pwnagotchi.ui.waveshare.v2.waveshare import EPD
# core.log("display module started")
self._display = EPD()
self._display.init(self._display.FULL_UPDATE)
self._display.Clear(WHITE)
@ -146,8 +148,6 @@ class Display(View):
self.on_render(self._on_view_rendered)
core.log("display type '%s' initialized (color:%s)" % (self._display_type, self._display_color))
def image(self):
img = None
if self.canvas is not None:
@ -189,9 +189,9 @@ class Display(View):
def _waveshare_render(self):
buf = self._display.getbuffer(self.canvas)
if self._is_waveshare1:
if self._is_waveshare1():
self._display.display(buf)
elif self._is_waveshare2:
elif self._is_waveshare2():
self._display.displayPartial(buf)
def _on_view_rendered(self, img):