From 26abbf51d65e5c987474b6fe65adf1587c64e303 Mon Sep 17 00:00:00 2001
From: Simone Margaritelli <evilsocket@gmail.com>
Date: Mon, 30 Sep 2019 21:22:01 +0200
Subject: [PATCH] more fixes

---
 sdcard/rootfs/root/pwnagotchi/scripts/main.py |  2 +-
 .../pwnagotchi/scripts/pwnagotchi/agent.py    |  2 +-
 .../pwnagotchi/scripts/pwnagotchi/ai/train.py | 43 ++++++++++---------
 .../scripts/pwnagotchi/mesh/advertise.py      | 39 -----------------
 .../scripts/pwnagotchi/mesh/utils.py          | 42 ++++++++++++++++++
 .../scripts/pwnagotchi/ui/display.py          | 12 +++---
 6 files changed, 72 insertions(+), 68 deletions(-)
 create mode 100644 sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/mesh/utils.py

diff --git a/sdcard/rootfs/root/pwnagotchi/scripts/main.py b/sdcard/rootfs/root/pwnagotchi/scripts/main.py
index de44012..dcb6c21 100755
--- a/sdcard/rootfs/root/pwnagotchi/scripts/main.py
+++ b/sdcard/rootfs/root/pwnagotchi/scripts/main.py
@@ -24,7 +24,7 @@ args = parser.parse_args()
 
 if args.do_clear:
     print("clearing the display ...")
-    from pwnagotchi.ui.waveshare import EPD
+    from pwnagotchi.ui.waveshare.v2.waveshare import EPD
 
     epd = EPD()
     epd.init(epd.FULL_UPDATE)
diff --git a/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/agent.py b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/agent.py
index f8f1c41..1fa5112 100644
--- a/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/agent.py
+++ b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/agent.py
@@ -9,7 +9,7 @@ import _thread
 import core
 
 from bettercap.client import Client
-from pwnagotchi.mesh.advertise import AsyncAdvertiser
+from pwnagotchi.mesh.utils import AsyncAdvertiser
 from pwnagotchi.ai.train import AsyncTrainer
 
 RECOVERY_DATA_FILE = '/root/.pwnagotchi-recovery'
diff --git a/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ai/train.py b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ai/train.py
index 97e9a6d..917815d 100644
--- a/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ai/train.py
+++ b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ai/train.py
@@ -142,28 +142,29 @@ class AsyncTrainer(object):
     def _ai_worker(self):
         self._model = ai.load(self._config, self, self._epoch)
 
-        self.on_ai_ready()
+        if self._model:
+            self.on_ai_ready()
 
-        epochs_per_episode = self._config['ai']['epochs_per_episode']
+            epochs_per_episode = self._config['ai']['epochs_per_episode']
 
-        obs = None
-        while True:
-            self._model.env.render()
-            # enter in training mode?
-            if random.random() > self._config['ai']['laziness']:
-                core.log("[ai] learning for %d epochs ..." % epochs_per_episode)
-                try:
-                    self.set_training(True, epochs_per_episode)
-                    self._model.learn(total_timesteps=epochs_per_episode, callback=self.on_ai_training_step)
-                except Exception as e:
-                    core.log("[ai] error while training: %s" % e)
-                finally:
-                    self.set_training(False)
+            obs = None
+            while True:
+                self._model.env.render()
+                # enter in training mode?
+                if random.random() > self._config['ai']['laziness']:
+                    core.log("[ai] learning for %d epochs ..." % epochs_per_episode)
+                    try:
+                        self.set_training(True, epochs_per_episode)
+                        self._model.learn(total_timesteps=epochs_per_episode, callback=self.on_ai_training_step)
+                    except Exception as e:
+                        core.log("[ai] error while training: %s" % e)
+                    finally:
+                        self.set_training(False)
+                        obs = self._model.env.reset()
+                # init the first time
+                elif obs is None:
                     obs = self._model.env.reset()
-            # init the first time
-            elif obs is None:
-                obs = self._model.env.reset()
 
-            # run the inference
-            action, _ = self._model.predict(obs)
-            obs, _, _, _ = self._model.env.step(action)
+                # run the inference
+                action, _ = self._model.predict(obs)
+                obs, _, _, _ = self._model.env.step(action)
diff --git a/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/mesh/advertise.py b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/mesh/advertise.py
index c5313eb..9d5c788 100644
--- a/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/mesh/advertise.py
+++ b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/mesh/advertise.py
@@ -5,10 +5,8 @@ import threading
 from scapy.all import Dot11, Dot11FCS, Dot11Elt, RadioTap, sendp, sniff
 
 import core
-import pwnagotchi
 import pwnagotchi.ui.faces as faces
 
-from pwnagotchi.mesh import get_identity
 import pwnagotchi.mesh.wifi as wifi
 from pwnagotchi.mesh import new_session_id
 from pwnagotchi.mesh.peer import Peer
@@ -181,40 +179,3 @@ class Advertiser(object):
 
                 for ident in stale:
                     del self._peers[ident]
-
-
-class AsyncAdvertiser(object):
-    def __init__(self, config, view):
-        self._config = config
-        self._view = view
-        self._public_key, self._identity = get_identity(config)
-        self._advertiser = None
-
-    def start_advertising(self):
-        _thread.start_new_thread(self._adv_worker, ())
-
-    def _adv_worker(self):
-        # this will take some time due to scapy being slow to be imported ...
-        from pwnagotchi.mesh.advertise import Advertiser
-
-        self._advertiser = Advertiser(
-            self._config['main']['iface'],
-            pwnagotchi.name(),
-            pwnagotchi.version,
-            self._identity,
-            period=0.3,
-            data=self._config['personality'])
-
-        self._advertiser.on_peer(self._on_new_unit, self._on_lost_unit)
-
-        if self._config['personality']['advertise']:
-            self._advertiser.start()
-            self._view.on_state_change('face', self._advertiser.on_face_change)
-        else:
-            core.log("advertising is disabled")
-
-    def _on_new_unit(self, peer):
-        self._view.on_new_peer(peer)
-
-    def _on_lost_unit(self, peer):
-        self._view.on_lost_peer(peer)
diff --git a/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/mesh/utils.py b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/mesh/utils.py
new file mode 100644
index 0000000..c99cd57
--- /dev/null
+++ b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/mesh/utils.py
@@ -0,0 +1,42 @@
+import _thread
+
+import core
+import pwnagotchi
+from pwnagotchi.mesh import get_identity
+
+
+class AsyncAdvertiser(object):
+    def __init__(self, config, view):
+        self._config = config
+        self._view = view
+        self._public_key, self._identity = get_identity(config)
+        self._advertiser = None
+
+    def start_advertising(self):
+        _thread.start_new_thread(self._adv_worker, ())
+
+    def _adv_worker(self):
+        # this will take some time due to scapy being slow to be imported ...
+        from pwnagotchi.mesh.advertise import Advertiser
+
+        self._advertiser = Advertiser(
+            self._config['main']['iface'],
+            pwnagotchi.name(),
+            pwnagotchi.version,
+            self._identity,
+            period=0.3,
+            data=self._config['personality'])
+
+        self._advertiser.on_peer(self._on_new_unit, self._on_lost_unit)
+
+        if self._config['personality']['advertise']:
+            self._advertiser.start()
+            self._view.on_state_change('face', self._advertiser.on_face_change)
+        else:
+            core.log("advertising is disabled")
+
+    def _on_new_unit(self, peer):
+        self._view.on_new_peer(peer)
+
+    def _on_lost_unit(self, peer):
+        self._view.on_lost_peer(peer)
diff --git a/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ui/display.py b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ui/display.py
index c9d72c4..4c693a6 100644
--- a/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ui/display.py
+++ b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ui/display.py
@@ -111,12 +111,14 @@ class Display(View):
 
     def _init_display(self):
         if self._is_inky():
+            core.log("initializing inky display")
             from inky import InkyPHAT
             self._display = InkyPHAT(self._display_color)
             self._display.set_border(InkyPHAT.BLACK)
             self._render_cb = self._inky_render
 
         elif self._is_papirus():
+            core.log("initializing papirus display")
             from pwnagotchi.ui.papirus.epd import EPD
             os.environ['EPD_SIZE'] = '2.0'
             self._display = EPD()
@@ -124,8 +126,8 @@ class Display(View):
             self._render_cb = self._papirus_render
 
         elif self._is_waveshare1():
+            core.log("initializing waveshare v1 display")
             from pwnagotchi.ui.waveshare.v1.epd2in13 import EPD
-            # core.log("display module started")
             self._display = EPD()
             self._display.init(self._display.lut_full_update)
             self._display.Clear(0xFF)
@@ -133,8 +135,8 @@ class Display(View):
             self._render_cb = self._waveshare_render
 
         elif self._is_waveshare2():
+            core.log("initializing waveshare v2 display")
             from pwnagotchi.ui.waveshare.v2.waveshare import EPD
-            # core.log("display module started")
             self._display = EPD()
             self._display.init(self._display.FULL_UPDATE)
             self._display.Clear(WHITE)
@@ -146,8 +148,6 @@ class Display(View):
 
         self.on_render(self._on_view_rendered)
 
-        core.log("display type '%s' initialized (color:%s)" % (self._display_type, self._display_color))
-
     def image(self):
         img = None
         if self.canvas is not None:
@@ -189,9 +189,9 @@ class Display(View):
 
     def _waveshare_render(self):
         buf = self._display.getbuffer(self.canvas)
-        if self._is_waveshare1:
+        if self._is_waveshare1():
             self._display.display(buf)
-        elif self._is_waveshare2:
+        elif self._is_waveshare2():
             self._display.displayPartial(buf)
 
     def _on_view_rendered(self, img):