fix: safer nn saving

This commit is contained in:
Simone Margaritelli 2019-10-01 11:29:11 +02:00
parent dfe83c68c3
commit 9620894b43

@ -66,8 +66,8 @@ class Stats(object):
def save(self): def save(self):
with self._lock: with self._lock:
core.log("[ai] saving %s" % self.path) core.log("[ai] saving %s" % self.path)
with open(self.path, 'wt') as fp:
json.dump({ data = json.dumps({
'born_at': self.born_at, 'born_at': self.born_at,
'epochs_lived': self.epochs_lived, 'epochs_lived': self.epochs_lived,
'epochs_trained': self.epochs_trained, 'epochs_trained': self.epochs_trained,
@ -75,7 +75,13 @@ class Stats(object):
'best': self.best_reward, 'best': self.best_reward,
'worst': self.worst_reward 'worst': self.worst_reward
} }
}, fp) })
temp = ".%s.tmp" % self.path
with open(temp, 'wt') as fp:
fp.write(data)
os.replace(temp, self.path)
class AsyncTrainer(object): class AsyncTrainer(object):
@ -103,7 +109,9 @@ class AsyncTrainer(object):
def _save_ai(self): def _save_ai(self):
core.log("[ai] saving model to %s ..." % self._nn_path) core.log("[ai] saving model to %s ..." % self._nn_path)
self._model.save(self._nn_path) temp = "%.%s.tmp" % self._nn_path
self._model.save(temp)
os.replace(temp, self._nn_path)
def on_ai_step(self): def on_ai_step(self):
self._model.env.render() self._model.env.render()