DLRM代码理解

2021-06-07  本文已影响0人  CPinging

在DLRM中有对训练集做处理的函数,我们对训练序列做了研究,

    def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
        # WARNING: notice that we are processing the batch at once. We implicitly
        # assume that the data is laid out such that:
        # 1. each embedding is indexed with a group of sparse indices,
        #   corresponding to a single lookup
        # 2. for each embedding the lookups are further organized into a batch
        # 3. for a list of embedding tables there is a list of batched lookups

        ly = []
        for k, sparse_index_group_batch in enumerate(lS_i):
            sparse_offset_group_batch = lS_o[k]

            # embedding lookup
            # We are using EmbeddingBag, which implicitly uses sum operator.
            # The embeddings are represented as tall matrices, with sum
            # happening vertically across 0 axis, resulting in a row vector
            # E = emb_l[k]

            if v_W_l[k] is not None:
                per_sample_weights = v_W_l[k].gather(0, sparse_index_group_batch)
            else:
                per_sample_weights = None

            if:
                ....
            else:
                E = emb_l[k]
                V = E(
                    sparse_index_group_batch,
                    sparse_offset_group_batch,
                    per_sample_weights=per_sample_weights,
                )
            
                ly.append(V)

重点是这个地方,其中E是所有打包好的Embedding:

image.png

其中第一维为这个Embedding table中包括的vector的数量,第二维64为vector的维度(有64个float)。

sparse_index_group_batch以及sparse_offset_group_batch为训练时需要的index以及offset,Embedding table会根据index找具体的vector。

offset需要注意,offset = torch.LongTensor([0,1,4]).to(0)代表三个样本,第一个样本是0 ~ 1,第二个是1 ~ 4,第三个是4(网上解释的都不够清楚,所以我这里通过代码实际跑了一下测出来是这个结果) 。且左闭右开[0,1)这种形式取整数(已经根据代码进行过验证)。

详细解释一下流程:

首先在apply_emb函数中每次循环会取出当前第k个Emb table:E = emb_l[k],其中k是当前所在轮数。

对于index数组与offset数组:

image.png

我们能看到,第一个tensor是index,有五个元素,代表我要取的当前table中的vector的编号(共5个)。

而后面的offset就代表我取出来的这5个数组哪些要进行reduce操作(加和等)。

例如我如果取offset为[0,3],则代表0,1,2相加进行reduce,3,4进行reduce。所以最终出来的数字个数就是offset的size。

IS_I以及IS_O生成的位置

在dlrm_data_pytorch.py中的collate_wrapper_criteo_offset()函数里:

def collate_wrapper_criteo_offset(list_of_tuples):
    # where each tuple is (X_int, X_cat, y)
    transposed_data = list(zip(*list_of_tuples))
    X_int = torch.log(torch.tensor(transposed_data[0], dtype=torch.float) + 1)
    X_cat = torch.tensor(transposed_data[1], dtype=torch.long)
    T = torch.tensor(transposed_data[2], dtype=torch.float32).view(-1, 1)

    batchSize = X_cat.shape[0]
    featureCnt = X_cat.shape[1]
    lS_i = [X_cat[:, i] for i in range(featureCnt)]
    lS_o = [torch.tensor(range(batchSize)) for _ in range(featureCnt)]
    return X_int, torch.stack(lS_o), torch.stack(lS_i), T

在这里生成访问序列,首先将传入的数据解析为X_cat,当bs=2时,X_cat为:

tensor([[    0,    17, 36684, 11838,     1,     0,   145,     9,     0,  1176,
            24, 34569,    24,     5,    24, 15109,     0,    19,    14,     3,
         32351,     0,     1,  4159,    32,  5050],
        [    3,    12, 33818, 19987,     0,     5,  1426,     1,     0,  8616,
           729, 31879,   658,     1,    50, 26833,     1,    12,    89,     0,
         29850,     0,     1,  1637,     3,  1246]])

其中每一个tensor有26个数字,代表26个Embedding table。每一个数字代表其中每个table需要访问的vector。(比如0代表访问第一个table的0号vector)

下面将访问序列打包,IS_i为:

[tensor([0, 3]), tensor([17, 12]), tensor([36684, 33818]), tensor([11838, 19987]), tensor([1, 0]), tensor([0, 5]), tensor([ 145, 1426]), tensor([9, 1]), tensor([0, 0]), tensor([1176, 8616]), tensor([ 24, 729]), tensor([34569, 31879]), tensor([ 24, 658]), tensor([5, 1]), tensor([24, 50]), tensor([15109, 26833]), tensor([0, 1]), tensor([19, 12]), tensor([14, 89]), tensor([3, 0]), tensor([32351, 29850]), tensor([0, 0]), tensor([1, 1]), tensor([4159, 1637]), tensor([32,  3]), tensor([5050, 1246])]

这里bs为2,所以[tensor([0, 3])代表访问第一个table的0,3个vactor。

这里我们要再次理解一下数据集的含义,这里每一个table都是用户的一个特征(所在城市、年龄等),所以每一个用户也就是每个table拥有一个数值,所以当bs=2时,这里的tensor[0,3]代表对两个用户进行训练,其中第一个用户的第一个table取值是0号vector,第二个用户第一个table取值是3号vector。

上一篇下一篇

猜你喜欢

热点阅读