def _get_pairwise_mask(labels, ids):"""Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.A triplet (i, j, k) is valid if:- i, j, k are distinct- labels[i] == labels[j] and labels[i] != labels[k]- id[i] == id[j] == id[k]Args:labels: tf.int32 `Tensor` with shape [batch_size]"""# Check that i, j are distinctlabels = tf.reshape(labels, shape=(-1,))ids = tf.reshape(ids, shape=(-1,))# check id i,j,k are equalid_equal = tf.equal(tf.expand_dims(ids, 0), tf.expand_dims(ids, 1))# Check if labels[i] != labels[k]# check if labels[i] > labels[j]label_less = tf.less(tf.expand_dims(labels, 1), tf.expand_dims(labels, 0))# Combine the two masksmask = tf.logical_and(id_equal, label_less)return mask
- 构建一个triplet(a,p,n)需要满足三个条件,如上所示。
- i、j、k are distinct:都代表商品,他们是三个不同的商品。
- labels[i] == labels[j] and labels[i] != labels[k]:i和j同标签,i和k不同标签。
- id[i] == id[j] == id[k]:来自统一个用户。
- label_less:这里用less是为了得到label不同的情况,由于label只有0和1,避免另两个label不同的位置计算两次,所以只取label小于另一个label的情况。
- 这个方式返回的mask,只满足了条件2和条件3.