Now, let's see how to build a matching network in TensorFlow step by step; the complete code appears at the end.
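First, we import the libraries. Note that this code targets TensorFlow 1.x, because tf.contrib (used here for slim and rnn) was removed in TensorFlow 2.0: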
import tensorflow as tf
slim = tf.contrib.slim
rnn = tf.contrib.rnn
Next, we define a class called Matching_network, which holds our network:
class Matching_network():
Inside the class, we define the __init__ method, where we create the placeholders and initialize all of the variables:
    def __init__(self, lr, n_way, k_shot, batch_size=32):
        # placeholder for support set
        self.support_set_image = tf.placeholder(tf.float32, [None, n_way * k_shot, 28, 28, 1])
        self.support_set_label = tf.placeholder(tf.int32, [None, n_way * k_shot])
        # placeholder for query set
        self.query_image = tf.placeholder(tf.float32, [None, ...
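To make the support-set placeholder shapes concrete, here is a minimal NumPy sketch of how a batch of N-way K-shot support sets could be assembled so that it matches [None, n_way * k_shot, 28, 28, 1] for images and [None, n_way * k_shot] (int32) for labels. The helper name sample_support_set and the images_by_class dictionary are hypothetical, the data is random dummy data, and relabeling the chosen classes as 0..n_way-1 within each episode is an assumption rather than something stated in this section.

```python
import numpy as np


def sample_support_set(images_by_class, n_way, k_shot, batch_size, rng):
    """Hypothetical helper: build a batch of N-way K-shot support sets.

    images_by_class maps a class id to an array of shape [num_examples, 28, 28, 1].
    Returns float32 images of shape [batch_size, n_way * k_shot, 28, 28, 1] and
    int32 labels of shape [batch_size, n_way * k_shot], matching the
    support_set_image and support_set_label placeholders.
    """
    class_ids = list(images_by_class)
    images = np.zeros((batch_size, n_way * k_shot, 28, 28, 1), dtype=np.float32)
    labels = np.zeros((batch_size, n_way * k_shot), dtype=np.int32)
    for b in range(batch_size):
        # Pick n_way distinct classes for this episode.
        chosen = rng.choice(len(class_ids), size=n_way, replace=False)
        for episode_label, idx in enumerate(chosen):
            # Draw k_shot distinct examples from the chosen class.
            pool = images_by_class[class_ids[idx]]
            picks = rng.choice(len(pool), size=k_shot, replace=False)
            start = episode_label * k_shot
            images[b, start:start + k_shot] = pool[picks]
            # Assumption: classes are relabeled 0..n_way-1 within each episode.
            labels[b, start:start + k_shot] = episode_label
    return images, labels


rng = np.random.default_rng(0)
# Dummy data: 10 classes with 20 random 28x28 single-channel images each.
data = {c: rng.random((20, 28, 28, 1), dtype=np.float32) for c in range(10)}
support_images, support_labels = sample_support_set(data, n_way=5, k_shot=1, batch_size=32, rng=rng)
print(support_images.shape, support_labels.shape)  # (32, 5, 28, 28, 1) (32, 5)
```

In a TensorFlow 1.x session, arrays like these would be passed through feed_dict to self.support_set_image and self.support_set_label; the leading None dimension in the placeholders is what lets the batch size vary.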
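The Matching_network class exists to classify a query image by comparing it with the support set. In Matching Networks, the query embedding is compared with every support embedding by cosine similarity, a softmax over those similarities gives attention weights, and the prediction is the attention-weighted sum of the one-hot support labels. The following is a framework-independent NumPy sketch of that step, not the TensorFlow code built in this section: it assumes the embedding networks have already been applied, that each episode has a single query, and the function name matching_predict is hypothetical.

```python
import numpy as np


def matching_predict(query_emb, support_emb, support_labels, n_way):
    """Matching-network style prediction from precomputed embeddings.

    query_emb:      [batch, d], one query embedding per episode (assumption).
    support_emb:    [batch, n_way * k_shot, d] support-set embeddings.
    support_labels: [batch, n_way * k_shot] integer labels in 0..n_way-1.
    Returns [batch, n_way] class probabilities.
    """
    eps = 1e-8
    # L2-normalize so that dot products become cosine similarities.
    q = query_emb / (np.linalg.norm(query_emb, axis=-1, keepdims=True) + eps)
    s = support_emb / (np.linalg.norm(support_emb, axis=-1, keepdims=True) + eps)
    sims = np.einsum("bd,bkd->bk", q, s)             # [batch, n_way * k_shot]
    # Softmax over the support set gives the attention weights.
    sims = sims - sims.max(axis=-1, keepdims=True)  # for numerical stability
    attn = np.exp(sims)
    attn = attn / attn.sum(axis=-1, keepdims=True)
    one_hot = np.eye(n_way)[support_labels]         # [batch, n_way * k_shot, n_way]
    # Prediction: attention-weighted sum of the one-hot support labels.
    return np.einsum("bk,bkn->bn", attn, one_hot)


# One episode with three 2-D support embeddings, one per class.
support_emb = np.array([[[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]])
support_labels = np.array([[0, 1, 2]])
query_emb = np.array([[2.0, 0.0]])  # points in the same direction as support example 0
probs = matching_predict(query_emb, support_emb, support_labels, n_way=3)
print(probs.argmax(axis=-1))  # [0]
```

Because the output is a weighted sum of one-hot vectors with softmax weights, each row is already a probability distribution over the n_way classes, so it can be trained directly with a cross-entropy loss against the query label.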