Next, a function is created to perform inference on the test data. The model was saved as a checkpoint in the preceding step and is restored here for inference. The placeholders for the input data are defined, along with a saver object, as follows:
def inference(test_x1, max_sent_len, batch_size=1024):
    with tf.name_scope('Placeholders'):
        x_pls1 = tf.placeholder(tf.int32, shape=[None, max_sent_len])
        keep_prob = tf.placeholder(tf.float32)  # Dropout
    predict = model(x_pls1, keep_prob)
    saver = tf.train.Saver()
    ckpt_path = tf.train.latest_checkpoint('.')
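The `batch_size` argument implies that the test matrix is fed through the graph in consecutive chunks rather than all at once, which is also why `x_pls1` is declared with a leading dimension of `None`. A minimal sketch of that batching step is shown below; the helper name `make_batches` is an assumption for illustration, not part of the book's code:

```python
import numpy as np

def make_batches(test_x1, batch_size=1024):
    """Split the padded test matrix into consecutive batches.

    The final batch may be smaller than batch_size, which the
    [None, max_sent_len] placeholder shape accommodates.
    NOTE: helper name is hypothetical, for illustration only.
    """
    batches = []
    for start in range(0, len(test_x1), batch_size):
        batches.append(test_x1[start:start + batch_size])
    return batches

# Example: 2,500 test sentences, each padded to length 50
data = np.zeros((2500, 50), dtype=np.int32)
batches = make_batches(data, batch_size=1024)
print([len(b) for b in batches])  # [1024, 1024, 452]
```

Each batch would then be bound to `x_pls1` through the session's `feed_dict` when `predict` is evaluated.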
Next, a session is created and the model is restored from the checkpoint:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, ckpt_path)
    print("Model restored")
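Inside the restored session, each batch is evaluated through `predict` (with the dropout `keep_prob` fed as 1.0, since no units should be dropped at inference time), and the resulting logits are typically reduced to class labels with an argmax over the class axis. The sketch below shows only that post-processing step with NumPy; the function name `logits_to_labels` and the two-class shape are assumptions for illustration:

```python
import numpy as np

def logits_to_labels(logits):
    """Reduce a [batch, num_classes] logits array to predicted class ids.

    Stands in for the post-processing of each sess.run(predict, ...)
    result; the name and shapes here are assumptions, not the book's code.
    """
    return np.argmax(logits, axis=1)

# Example: 3 test sentences scored over 2 classes
logits = np.array([[0.2, 1.5],
                   [2.0, -0.3],
                   [0.1, 0.4]])
labels = logits_to_labels(logits)
print(labels)  # [1 0 1]
```

Concatenating the per-batch label arrays then yields predictions for the full test set.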