From 9620894b43c9758d8a10715f6df10db631f8305d Mon Sep 17 00:00:00 2001
From: Simone Margaritelli <evilsocket@gmail.com>
Date: Tue, 1 Oct 2019 11:29:11 +0200
Subject: [PATCH] fix: safer nn saving

---
 .../pwnagotchi/scripts/pwnagotchi/ai/train.py    | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ai/train.py b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ai/train.py
index 917815d..9bcefad 100644
--- a/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ai/train.py
+++ b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ai/train.py
@@ -66,8 +66,8 @@ class Stats(object):
     def save(self):
         with self._lock:
             core.log("[ai] saving %s" % self.path)
-            with open(self.path, 'wt') as fp:
-                json.dump({
+
+            data = json.dumps({
                     'born_at': self.born_at,
                     'epochs_lived': self.epochs_lived,
                     'epochs_trained': self.epochs_trained,
@@ -75,7 +75,13 @@ class Stats(object):
                         'best': self.best_reward,
                         'worst': self.worst_reward
                     }
-                }, fp)
+                })
+
+            temp = ".%s.tmp" % self.path
+            with open(temp, 'wt') as fp:
+                fp.write(data)
+
+            os.replace(temp, self.path)
 
 
 class AsyncTrainer(object):
@@ -103,7 +109,9 @@ class AsyncTrainer(object):
 
     def _save_ai(self):
         core.log("[ai] saving model to %s ..." % self._nn_path)
-        self._model.save(self._nn_path)
+        temp = "%.%s.tmp" % self._nn_path
+        self._model.save(temp)
+        os.replace(temp, self._nn_path)
 
     def on_ai_step(self):
         self._model.env.render()