From 9620894b43c9758d8a10715f6df10db631f8305d Mon Sep 17 00:00:00 2001 From: Simone Margaritelli <evilsocket@gmail.com> Date: Tue, 1 Oct 2019 11:29:11 +0200 Subject: [PATCH] fix: safer nn saving --- .../pwnagotchi/scripts/pwnagotchi/ai/train.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ai/train.py b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ai/train.py index 917815d..9bcefad 100644 --- a/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ai/train.py +++ b/sdcard/rootfs/root/pwnagotchi/scripts/pwnagotchi/ai/train.py @@ -66,8 +66,8 @@ class Stats(object): def save(self): with self._lock: core.log("[ai] saving %s" % self.path) - with open(self.path, 'wt') as fp: - json.dump({ + + data = json.dumps({ 'born_at': self.born_at, 'epochs_lived': self.epochs_lived, 'epochs_trained': self.epochs_trained, @@ -75,7 +75,13 @@ class Stats(object): 'best': self.best_reward, 'worst': self.worst_reward } - }, fp) + }) + + temp = ".%s.tmp" % self.path + with open(temp, 'wt') as fp: + fp.write(data) + + os.replace(temp, self.path) class AsyncTrainer(object): @@ -103,7 +109,9 @@ class AsyncTrainer(object): def _save_ai(self): core.log("[ai] saving model to %s ..." % self._nn_path) - self._model.save(self._nn_path) + temp = "%.%s.tmp" % self._nn_path + self._model.save(temp) + os.replace(temp, self._nn_path) def on_ai_step(self): self._model.env.render()