Now, we implement a method that generates a child network using the Controller:
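def generate_child_network(self, child_network_architecture):
    # Evaluate the Controller's output tensor inside the Controller's own graph,
    # feeding the given architecture into its input placeholder.
    with self.graph.as_default():
        return self.sess.run(
            self.cnn_dna_output,
            {self.child_network_architectures: child_network_architecture})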
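Because generate_child_network takes an architecture as input and returns a new one from the Controller, the Controller can be run repeatedly, each time conditioned on the previous child. The sketch below shows one way this step might be chained with training. It is a minimal illustration under stated assumptions, not the author's code: the names `run_search`, `controller_step`, and `train_child` are hypothetical, the architecture is assumed to be a flat list of integers, and toy stubs stand in for the real TensorFlow calls so the loop runs without a Controller.

```python
from typing import Callable, List, Tuple

# Hypothetical flat integer encoding of a child CNN's "DNA" (an assumption for this sketch).
Architecture = List[int]


def run_search(
    controller_step: Callable[[Architecture], Architecture],
    train_child: Callable[[Architecture, int], float],
    initial_architecture: Architecture,
    num_children: int,
) -> List[Tuple[Architecture, float]]:
    """Alternate between generating and training child networks.

    controller_step stands in for generate_child_network and train_child
    stands in for train_child_network; both are injected so the loop can be
    exercised without TensorFlow.
    """
    history: List[Tuple[Architecture, float]] = []
    architecture = initial_architecture
    for child_id in range(num_children):
        # Condition the Controller on the previous architecture to get a new one.
        architecture = controller_step(architecture)
        # Train the child and record its validation accuracy as the reward.
        reward = train_child(architecture, child_id)
        history.append((architecture, reward))
    return history


# Usage with toy stubs in place of the real Controller and trainer.
history = run_search(
    controller_step=lambda arch: [gene + 1 for gene in arch],
    train_child=lambda arch, child_id: sum(arch) / 100,
    initial_architecture=[1, 2],
    num_children=3,
)
best_architecture, best_reward = max(history, key=lambda item: item[1])
print(best_architecture, best_reward)
```

Injecting the two steps as callables keeps the loop independent of how the Controller and the child networks are built; in the real class they would be bound methods that use `self.sess` and a per-child session respectively.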
Once we have generated a child network, we call the train_child_network function to train it. This function takes cnn_dna and child_id as arguments and returns the validation accuracy that the child network achieves, which serves as its reward. First, we instantiate a new tf.Graph() and a new tf.Session() so that the child network lives in a graph separate from the Controller's:
def train_child_network(self, cnn_dna, child_id):
    """
    Trains a child network and returns reward, ...