Notes on the seq3 Code

2019-06-09  VanJordan
To pretrain the sentence-level language model (the LM prior), run:

python models/sent_lm.py --config model_configs/camera/lm_prior.yaml

sent_lm

dataset: data loading

import collections.abc
import functools
import hashlib
import os
import pickle

from tqdm import tqdm

# NOTE: BASE_DIR and args_to_str come from the repo's own utils
def disk_memoize(func):
    cache_dir = os.path.join(BASE_DIR, "_cache")
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    @functools.wraps(func)
    def wrapper_decorator(*args, **kwargs):
        # check fn arguments
        args_str = ''.join(args_to_str(args))
        key = hashlib.md5(args_str.encode()).hexdigest()
        cache_file = os.path.join(cache_dir, key)

        if os.path.exists(cache_file):
            print(f"Loading {cache_file} from cache!")
            with open(cache_file, 'rb') as f:
                return pickle.load(f)
        else:
            print(f"No cache file for {cache_file}...")
            data = func(*args, **kwargs)

            with open(cache_file, 'wb') as pickle_file:
                pickle.dump(data, pickle_file)

            return data

    return wrapper_decorator
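
A usage sketch (the function below is hypothetical, and it assumes the repo's args_to_str can stringify the positional arguments):

@disk_memoize
def expensive_line_count(path):
    # stand-in for an expensive computation; the result gets pickled
    with open(path) as f:
        return sum(1 for _ in f)

n = expensive_line_count("data/train.txt")  # computed, written to _cache/
n = expensive_line_count("data/train.txt")  # loaded straight from the cache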

# @disk_memoize
def read_corpus(file, tokenize):
    _vocab = Vocab()

    _data = []
    for line in iterate_data(file):
        tokens = tokenize(line)
        _vocab.read_sequence(tokens)
        _data.append(tokens)

    return _vocab, _data

from subprocess import check_output

def wc(filename):
    # count the lines in a file with `wc -l` (used as the tqdm total)
    return int(check_output(["wc", "-l", filename]).split()[0])

def iterate_data(data):
    if isinstance(data, str):
        assert os.path.exists(data), f"path `{data}` does not exist!"
        with open(data, "r") as f:
            for line in tqdm(f, total=wc(data), desc=f"Reading {data}..."):
                if len(line.strip()) > 0:
                    yield line

    # accept any iterable of lines, not just file paths
    elif isinstance(data, collections.abc.Iterable):
        for x in data:
            yield x
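
A quick sketch of how these helpers compose (the path and the whitespace tokenizer are placeholders):

vocab, data = read_corpus("data/train.txt", tokenize=lambda s: s.split())

# iterate_data also accepts any iterable of lines, not just file paths
for line in iterate_data(["hello world", "foo bar"]):
    print(line)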
In the repo, dataset properties are rendered as a table with tabulate:

return tabulate([[x[1] for x in props]], headers=[x[0] for x in props])

A standalone demo:
from tabulate import tabulate

props = [('ni','7889'),('taa','890988'),('you','safdkk')]
print(tabulate([[x[1] for x in props]], headers=[x[0] for x in props])) 
---------------------------------------Output----------------------------------
  ni     taa  you
----  ------  ------
7889  890988  safdkk

model: building the model

torch.numel returns the total number of elements in a tensor:

>>> a = torch.randn(1, 2, 3, 4, 5)
>>> torch.numel(a)
120

The RNNModule in model

class RNNModule(nn.Module, RecurrentHelper):
    def __init__(self, input_size,
                 rnn_size,
                 num_layers=1,
                 bidirectional=False,
                 dropout=0.,
                 pack=True, last=False, countdown=False):
        """
        A simple RNN Encoder, which produces a fixed vector representation
        for a variable length sequence of feature vectors, using the output
        at the last timestep of the RNN.
        Args:
            input_size (int): the size of the input features
            rnn_size (int): the size of the RNN's hidden state
            num_layers (int): the number of stacked RNN layers
            bidirectional (bool): whether to use a bidirectional RNN
            dropout (float): the dropout applied to the RNN outputs
        """
        super(RNNModule, self).__init__()

        self.pack = pack
        self.last = last
        self.countdown = countdown

        if self.countdown:
            self.Wt = nn.Parameter(torch.rand(1))
            input_size += 1

        self.rnn = nn.LSTM(input_size=input_size,
                           hidden_size=rnn_size,
                           num_layers=num_layers,
                           bidirectional=bidirectional,
                           batch_first=True)

        # the dropout "layer" for the output of the RNN
        self.dropout = nn.Dropout(dropout)

        # define output feature size
        self.feature_size = rnn_size

        # double if bidirectional
        if bidirectional:
            self.feature_size *= 2

    @staticmethod
    def reorder_hidden(hidden, order):
        if isinstance(hidden, tuple):
            hidden = hidden[0][:, order, :], hidden[1][:, order, :]
        else:
            hidden = hidden[:, order, :]

        return hidden

    def forward(self, x, hidden=None, lengths=None):

        batch, max_length, feat_size = x.size()

        if lengths is not None and self.pack:

            ###############################################
            # sorting
            ###############################################
            lengths_sorted, sorted_i = lengths.sort(descending=True)
            _, reverse_i = sorted_i.sort()

            x = x[sorted_i]

            if hidden is not None:
                hidden = self.reorder_hidden(hidden, sorted_i)

            ###############################################
            # forward
            ###############################################

            if self.countdown:
                ticks = length_countdown(lengths_sorted).float() * self.Wt
                x = torch.cat([x, ticks.unsqueeze(-1)], -1)

            packed = pack_padded_sequence(x, lengths_sorted, batch_first=True)

            self.rnn.flatten_parameters()
            out_packed, hidden = self.rnn(packed, hidden)

            out_unpacked, _lengths = pad_packed_sequence(out_packed,
                                                         batch_first=True,
                                                         total_length=max_length)

            out_unpacked = self.dropout(out_unpacked)

            ###############################################
            # un-sorting
            ###############################################
            outputs = out_unpacked[reverse_i]
            hidden = self.reorder_hidden(hidden, reverse_i)

        else:
            # todo: make hidden return the true last states
            self.rnn.flatten_parameters()
            outputs, hidden = self.rnn(x, hidden)
            outputs = self.dropout(outputs)

        if self.last:
            return outputs, hidden, self.last_timestep(outputs, lengths,
                                                       self.rnn.bidirectional)

        return outputs, hidden
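
A shape-level smoke test for RNNModule (a sketch; last=False, so nothing from RecurrentHelper, which is not shown here, is needed):

import torch

rnn = RNNModule(input_size=8, rnn_size=16)
x = torch.randn(3, 5, 8)               # (batch, max_len, features)
lengths = torch.LongTensor([5, 3, 2])  # true length of each sequence

outputs, hidden = rnn(x, lengths=lengths)
print(outputs.shape)                   # torch.Size([3, 5, 16])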

SeqReader

# index 0 (the padding token) is excluded from the loss
loss_function = nn.CrossEntropyLoss(ignore_index=0)
# optimize only the parameters that require gradients
parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(parameters,
                       lr=config["lr"], weight_decay=config["weight_decay"])

trainer

best_loss = None
for epoch in range(config["epochs"]):
    train_loss = trainer.train_epoch()
    val_loss = trainer.eval_epoch()

    if config["scheduler"] == "plateau":
        scheduler.step(val_loss)

    elif config["scheduler"] == "cosine":
        scheduler.step()
    elif config["scheduler"] == "step":
        scheduler.step()

    exp.update_metric("lr", optimizer.param_groups[0]['lr'])

    exp.update_metric("ep_loss", train_loss, "TRAIN")
    exp.update_metric("ep_loss", val_loss, "VAL")
    exp.update_metric("ep_ppl", math.exp(train_loss), "TRAIN")
    exp.update_metric("ep_ppl", math.exp(val_loss), "VAL")

    print()
    epoch_log = exp.log_metrics(["ep_loss", "ep_ppl"])
    print(epoch_log)
    exp.update_value("epoch", epoch_log)

    # Save the model if the validation loss is the best we've seen so far.
    if not best_loss or val_loss < best_loss:
        best_loss = val_loss
        trainer.checkpoint()

    print("\n" * 2)

    exp.save()

class LMTrainer(Trainer):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _seq_loss(self, predictions, labels):
        _labels = labels.contiguous().view(-1)

        # predictions[0] holds the logits; flatten to (batch * seq_len, vocab)
        _logits = predictions[0]
        _logits = _logits.contiguous().view(-1, _logits.size(-1))
        loss = self.criterion(_logits, _labels)

        return loss

    def _process_batch(self, inputs, labels, lengths):
        predictions = self.model(inputs, None, lengths)

        loss = self._seq_loss(predictions, labels)
        # the predictions are not needed downstream; drop them to free memory
        del predictions
        predictions = None

        return loss, predictions

    def get_state(self):
        if self.train_loader.dataset.subword:
            _vocab = self.train_loader.dataset.subword_path
        else:
            _vocab = self.train_loader.dataset.vocab

        state = {
            "config": self.config,
            "epoch": self.epoch,
            "step": self.step,
            "model": self.model.state_dict(),
            "model_class": self.model.__class__.__name__,
            "optimizers": [x.state_dict() for x in self.optimizers],
            "vocab": _vocab,
        }

        return state
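
The flattening in _seq_loss is the standard trick for token-level cross-entropy; a standalone shape sketch:

import torch
import torch.nn as nn

B, T, V = 2, 4, 10                     # batch, seq_len, vocab size
logits = torch.randn(B, T, V)
labels = torch.randint(0, V, (B, T))

criterion = nn.CrossEntropyLoss(ignore_index=0)
loss = criterion(logits.view(-1, V),   # (B*T, V)
                 labels.view(-1))      # (B*T,)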

seq3

AEDataset

Model for seq3

def length_countdown(lengths):
    # for each sequence, produce a per-timestep countdown that reaches 0
    # at the last real token (and goes negative afterwards)
    batch_size = lengths.size(0)
    max_length = max(lengths)
    desired_lengths = lengths - 1

    _range = torch.arange(0, -max_length, -1, device=lengths.device)
    _range = _range.repeat(batch_size, 1)
    _countdown = _range + desired_lengths.unsqueeze(-1)

    return _countdown

lengths = torch.LongTensor([1,2,3])
print(length_countdown(lengths))
---------------------------------------Output------------------------------------------
tensor([[ 0, -1, -2],
        [ 1,  0, -1],
        [ 2,  1,  0]])

train seq3.py

trainer seq3

topic loss
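
In the seq3 paper, the topic loss pulls the compressed sequence toward the topic of the input: each sequence is pooled into a single vector of word embeddings and the two vectors are compared with cosine distance. A minimal sketch of the idea (masked mean pooling; any word weighting the repo applies is omitted here):

import torch
import torch.nn.functional as F

def topic_loss(src_emb, hyp_emb, src_mask, hyp_mask):
    # masked mean-pooling over timesteps -> one topic vector per sequence
    src_vec = (src_emb * src_mask.unsqueeze(-1)).sum(1) / src_mask.sum(1, keepdim=True)
    hyp_vec = (hyp_emb * hyp_mask.unsqueeze(-1)).sum(1) / hyp_mask.sum(1, keepdim=True)
    # cosine distance: 0 when the two topic vectors align perfectly
    return (1 - F.cosine_similarity(src_vec, hyp_vec, dim=-1)).mean()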

lm loss
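
The LM loss penalizes compressed sentences that the pretrained sent_lm finds unlikely, which is why the LM prior is trained first. A common formulation, sketched here (not necessarily the repo's exact code), is a masked KL divergence between the decoder's distribution and the LM's:

import torch
import torch.nn.functional as F

def lm_prior_loss(dec_logits, lm_logits, mask):
    # KL(p_decoder || p_LM): F.kl_div expects log-probs as input
    # and probs as target
    kl = F.kl_div(F.log_softmax(lm_logits, dim=-1),
                  F.softmax(dec_logits, dim=-1),
                  reduction="none").sum(-1)      # (batch, seq_len)
    return (kl * mask).sum() / mask.sum()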

_process_batch

# move every tensor in the batch to the training device
batch = list(map(lambda x: x.to(self.device), batch))

In the forward pass a normal (hard) sample is drawn, while the backward pass uses the soft probability approximation, so that gradients can propagate through the sampling step differentiably.
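
This is the straight-through estimator. A generic sketch with Gumbel-softmax (PyTorch's F.gumbel_softmax(..., hard=True) performs the same re-parameterization internally):

import torch
import torch.nn.functional as F

def st_gumbel_softmax(logits, tau=1.0):
    # soft sample: differentiable probabilities over the vocabulary
    y_soft = F.gumbel_softmax(logits, tau=tau, hard=False)
    # hard sample: one-hot of the argmax, what the forward pass "sees"
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    # forward value is y_hard; the gradient flows through y_soft
    return y_hard - y_soft.detach() + y_soft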
