Source code for paddlespeech.s2t.modules.crf

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn

from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ['CRF']


class CRF(nn.Layer):
    """
    Linear-chain Conditional Random Field (CRF).

    Args:
        nb_labels (int): number of labels in your tagset, including special symbols.
        bos_tag_id (int): integer representing the beginning of sentence symbol in
            your tagset.
        eos_tag_id (int): integer representing the end of sentence symbol in
            your tagset.
        pad_tag_id (int, optional): integer representing the pad symbol in your tagset.
            If None, the model will treat the PAD as a normal tag. Otherwise, the model
            will apply constraints for PAD transitions.
        batch_first (bool): Whether the first dimension represents the batch dimension.
    """

    def __init__(self,
                 nb_labels: int,
                 bos_tag_id: int,
                 eos_tag_id: int,
                 pad_tag_id: int=None,
                 batch_first: bool=True):
        super().__init__()
        self.nb_labels = nb_labels
        self.BOS_TAG_ID = bos_tag_id
        self.EOS_TAG_ID = eos_tag_id
        self.PAD_TAG_ID = pad_tag_id
        self.batch_first = batch_first

        # initialize transitions from a random uniform distribution between -0.1 and 0.1
        self.transitions = self.create_parameter(
            [self.nb_labels, self.nb_labels],
            default_initializer=nn.initializer.Uniform(-0.1, 0.1))
        self.init_weights()
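
    # A minimal construction sketch (illustrative only; the label count and the
    # BOS/EOS tag ids below are hypothetical, not taken from a real tagset):
    #
    #   crf = CRF(nb_labels=5, bos_tag_id=3, eos_tag_id=4)
    #   crf.transitions.shape  # [5, 5]; transitions[i, j] scores moving from tag i to tag j
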
    def init_weights(self):
        # enforce constraints (rows=from, columns=to) with a big negative number
        # so exp(-10000) will tend to zero

        # no transitions allowed to the beginning of sentence
        self.transitions[:, self.BOS_TAG_ID] = -10000.0

        # no transitions allowed from the end of sentence
        self.transitions[self.EOS_TAG_ID, :] = -10000.0

        if self.PAD_TAG_ID is not None:
            # no transitions from padding
            self.transitions[self.PAD_TAG_ID, :] = -10000.0

            # no transitions to padding
            self.transitions[:, self.PAD_TAG_ID] = -10000.0

            # except if the end of sentence is reached
            # or we are already in a pad position
            self.transitions[self.PAD_TAG_ID, self.EOS_TAG_ID] = 0.0
            self.transitions[self.PAD_TAG_ID, self.PAD_TAG_ID] = 0.0
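
    # After init_weights, the constraints are visible in the matrix (a sketch,
    # reusing the hypothetical ids bos_tag_id=3, eos_tag_id=4 from above):
    #
    #   assert float(crf.transitions[0, 3]) == -10000.0  # nothing may enter BOS
    #   assert float(crf.transitions[4, 0]) == -10000.0  # nothing may leave EOS
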
    def forward(self,
                emissions: paddle.Tensor,
                tags: paddle.Tensor,
                mask: paddle.Tensor=None) -> paddle.Tensor:
        """Compute the negative log-likelihood. See `log_likelihood` method."""
        nll = -self.log_likelihood(emissions, tags, mask=mask)
        return nll
    def log_likelihood(self, emissions, tags, mask=None):
        """Compute the log-likelihood of a sequence of tags given a sequence of
        emission scores.

        Args:
            emissions (paddle.Tensor): Sequence of emissions for each label.
                Shape of (batch_size, seq_len, nb_labels) if batch_first is True,
                (seq_len, batch_size, nb_labels) otherwise.
            tags (paddle.LongTensor): Sequence of labels.
                Shape of (batch_size, seq_len) if batch_first is True,
                (seq_len, batch_size) otherwise.
            mask (paddle.FloatTensor, optional): Tensor representing valid positions.
                If None, all positions are considered valid.
                Shape of (batch_size, seq_len) if batch_first is True,
                (seq_len, batch_size) otherwise.

        Returns:
            paddle.Tensor: sum of the log-likelihoods for each sequence in the batch.
                Shape of ()
        """
        # fix tensor order by setting batch as the first dimension
        if not self.batch_first:
            emissions = emissions.transpose([1, 0, 2])
            tags = tags.transpose([1, 0])

        if mask is None:
            mask = paddle.ones(emissions.shape[:2], dtype='float32')

        scores = self._compute_scores(emissions, tags, mask=mask)
        partition = self._compute_log_partition(emissions, mask=mask)
        return paddle.sum(scores - partition)
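
    # In CRF terms, log_likelihood computes, for each sequence,
    #   log p(y | x) = score(x, y) - log Z(x)
    # where score(x, y) sums the emission and transition scores along the gold
    # tag path (see _compute_scores) and log Z(x) is the log-partition over all
    # paths (see _compute_log_partition). Since the gold path is one of the
    # paths summed in Z, `scores - partition` is <= 0, and forward() negates
    # the batch sum to obtain a loss.
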
    def decode(self, emissions, mask=None):
        """Find the most probable sequence of labels given the emissions using
        the Viterbi algorithm.

        Args:
            emissions (paddle.Tensor): Sequence of emissions for each label.
                Shape (batch_size, seq_len, nb_labels) if batch_first is True,
                (seq_len, batch_size, nb_labels) otherwise.
            mask (paddle.FloatTensor, optional): Tensor representing valid positions.
                If None, all positions are considered valid.
                Shape (batch_size, seq_len) if batch_first is True,
                (seq_len, batch_size) otherwise.

        Returns:
            paddle.Tensor: the viterbi score for each batch.
                Shape of (batch_size,)
            list of lists: the best viterbi sequence of labels for each batch. [B, T]
        """
        # fix tensor order by setting batch as the first dimension
        if not self.batch_first:
            emissions = emissions.transpose([1, 0, 2])

        if mask is None:
            mask = paddle.ones(emissions.shape[:2], dtype='float32')

        scores, sequences = self._viterbi_decode(emissions, mask)
        return scores, sequences
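
    # Viterbi replaces the forward algorithm's logsumexp with a max, i.e. the
    # recursion per step i and tag j is
    #   alphas_i[j] = max_k(alphas_{i-1}[k] + transitions[k, j]) + emissions_i[j]
    # with the argmax kept as a backpointer. A decoding sketch (illustrative;
    # shapes assume batch_first=True and the hypothetical 5-label crf above):
    #
    #   emissions = paddle.randn([2, 7, 5])    # (batch, seq_len, nb_labels)
    #   scores, paths = crf.decode(emissions)  # scores: shape [2]; paths: two lists of 7 tag ids
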
    def _compute_scores(self, emissions, tags, mask):
        """Compute the scores for a given batch of emissions with their tags.

        Args:
            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
            tags (paddle.LongTensor): (batch_size, seq_len)
            mask (paddle.FloatTensor): (batch_size, seq_len)

        Returns:
            paddle.Tensor: Scores for each batch. Shape of (batch_size,)
        """
        batch_size, seq_length = tags.shape
        scores = paddle.zeros([batch_size])

        # save first and last tags to be used later
        first_tags = tags[:, 0]
        last_valid_idx = mask.int().sum(1) - 1

        # TODO(Hui Zhang): not support fancy index.
        # last_tags = tags.gather(last_valid_idx.unsqueeze(1), axis=1).squeeze()
        batch_idx = paddle.arange(batch_size, dtype=last_valid_idx.dtype)
        gather_last_valid_idx = paddle.stack(
            [batch_idx, last_valid_idx], axis=-1)
        last_tags = tags.gather_nd(gather_last_valid_idx)

        # add the transition from BOS to the first tags for each batch
        # t_scores = self.transitions[self.BOS_TAG_ID, first_tags]
        t_scores = self.transitions[self.BOS_TAG_ID].gather(first_tags)

        # add the [unary] emission scores for the first tags for each batch
        # for all batches, the first word, see the correspondent emissions
        # for the first tags (which is a list of ids):
        # emissions[:, 0, [tag_1, tag_2, ..., tag_nblabels]]
        # e_scores = emissions[:, 0].gather(1, first_tags.unsqueeze(1)).squeeze()
        gather_first_tags_idx = paddle.stack([batch_idx, first_tags], axis=-1)
        e_scores = emissions[:, 0].gather_nd(gather_first_tags_idx)

        # the scores for a word is just the sum of both scores
        scores += e_scores + t_scores

        # now lets do this for each remaining word
        for i in range(1, seq_length):
            # we could: iterate over batches, check if we reached a mask symbol
            # and stop the iteration, but vectorizing is faster due to gpu,
            # so instead we perform an element-wise multiplication
            is_valid = mask[:, i]

            previous_tags = tags[:, i - 1]
            current_tags = tags[:, i]

            # calculate emission and transition scores as we did before
            # e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze()
            gather_current_tags_idx = paddle.stack(
                [batch_idx, current_tags], axis=-1)
            e_scores = emissions[:, i].gather_nd(gather_current_tags_idx)
            # t_scores = self.transitions[previous_tags, current_tags]
            gather_transitions_idx = paddle.stack(
                [previous_tags, current_tags], axis=-1)
            t_scores = self.transitions.gather_nd(gather_transitions_idx)

            # apply the mask
            e_scores = e_scores * is_valid
            t_scores = t_scores * is_valid

            scores += e_scores + t_scores

        # add the transition from the end tag to the EOS tag for each batch
        # scores += self.transitions[last_tags, self.EOS_TAG_ID]
        scores += self.transitions.gather(last_tags)[:, self.EOS_TAG_ID]

        return scores

    def _compute_log_partition(self, emissions, mask):
        """Compute the partition function in log-space using the forward algorithm.

        Args:
            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
            mask (paddle.FloatTensor): (batch_size, seq_len)

        Returns:
            paddle.Tensor: the partition scores for each batch.
                Shape of (batch_size,)
        """
        batch_size, seq_length, nb_labels = emissions.shape

        # in the first iteration, BOS will have all the scores
        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(
            0) + emissions[:, 0]

        for i in range(1, seq_length):
            # (bs, nb_labels) -> (bs, 1, nb_labels)
            e_scores = emissions[:, i].unsqueeze(1)

            # (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels)
            t_scores = self.transitions.unsqueeze(0)

            # (bs, nb_labels) -> (bs, nb_labels, 1)
            a_scores = alphas.unsqueeze(2)

            scores = e_scores + t_scores + a_scores
            new_alphas = paddle.logsumexp(scores, axis=1)

            # set alphas if the mask is valid, otherwise keep the current values
            is_valid = mask[:, i].unsqueeze(-1)
            alphas = is_valid * new_alphas + (1 - is_valid) * alphas

        # add the scores for the final transition
        last_transition = self.transitions[:, self.EOS_TAG_ID]
        end_scores = alphas + last_transition.unsqueeze(0)

        # return a *log* of sums of exps
        return paddle.logsumexp(end_scores, axis=1)

    def _viterbi_decode(self, emissions, mask):
        """Compute the viterbi algorithm to find the most probable sequence of labels
        given a sequence of emissions.

        Args:
            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
            mask (paddle.FloatTensor): (batch_size, seq_len)

        Returns:
            paddle.Tensor: the viterbi score for each batch.
                Shape of (batch_size,)
            list of lists of ints: the best viterbi sequence of labels for each batch
        """
        batch_size, seq_length, nb_labels = emissions.shape

        # in the first iteration, BOS will have all the scores and then, the max
        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(
            0) + emissions[:, 0]

        backpointers = []

        for i in range(1, seq_length):
            # (bs, nb_labels) -> (bs, 1, nb_labels)
            e_scores = emissions[:, i].unsqueeze(1)

            # (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels)
            t_scores = self.transitions.unsqueeze(0)

            # (bs, nb_labels) -> (bs, nb_labels, 1)
            a_scores = alphas.unsqueeze(2)

            # combine current scores with previous alphas
            scores = e_scores + t_scores + a_scores

            # so far is exactly like the forward algorithm,
            # but now, instead of calculating the logsumexp,
            # we will find the highest score and the tag associated with it
            # max_scores, max_score_tags = paddle.max(scores, axis=1)
            max_scores = paddle.max(scores, axis=1)
            max_score_tags = paddle.argmax(scores, axis=1)

            # set alphas if the mask is valid, otherwise keep the current values
            is_valid = mask[:, i].unsqueeze(-1)
            alphas = is_valid * max_scores + (1 - is_valid) * alphas

            # add the max_score_tags for our list of backpointers
            # max_scores has shape (batch_size, nb_labels) so we transpose it to
            # be compatible with our previous loopy version of viterbi
            backpointers.append(max_score_tags.t())

        # add the scores for the final transition
        last_transition = self.transitions[:, self.EOS_TAG_ID]
        end_scores = alphas + last_transition.unsqueeze(0)

        # get the final most probable score and the final most probable tag
        # max_final_scores, max_final_tags = paddle.max(end_scores, axis=1)
        max_final_scores = paddle.max(end_scores, axis=1)
        max_final_tags = paddle.argmax(end_scores, axis=1)

        # find the best sequence of labels for each sample in the batch
        best_sequences = []
        emission_lengths = mask.int().sum(axis=1)
        for i in range(batch_size):
            # recover the original sentence length for the i-th sample in the batch
            sample_length = emission_lengths[i].item()

            # recover the max tag for the last timestep
            sample_final_tag = max_final_tags[i].item()

            # limit the backpointers until the last but one
            # since the last corresponds to the sample_final_tag
            sample_backpointers = backpointers[:sample_length - 1]

            # follow the backpointers to build the sequence of labels
            sample_path = self._find_best_path(i, sample_final_tag,
                                               sample_backpointers)

            # add this path to the list of best sequences
            best_sequences.append(sample_path)

        return max_final_scores, best_sequences

    def _find_best_path(self, sample_id, best_tag, backpointers):
        """Auxiliary function to find the best path sequence for a specific sample.

        Args:
            sample_id (int): sample index in the range [0, batch_size)
            best_tag (int): tag which maximizes the final score
            backpointers (list of lists of tensors): list of pointers with
                shape (seq_len_i - 1, nb_labels, batch_size) where seq_len_i
                represents the length of the ith sample in the batch

        Returns:
            list of ints: a list of tag indexes representing the best path
        """
        # add the final best_tag to our best path
        best_path = [best_tag]

        # traverse the backpointers backwards
        for backpointers_t in reversed(backpointers):
            # recover the best_tag at this timestep
            best_tag = backpointers_t[best_tag][sample_id].item()

            # append to the beginning of the list so we don't need to reverse it later
            best_path.insert(0, best_tag)

        return best_path
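

# A minimal end-to-end sketch, kept behind __main__ so importing the module
# stays side-effect free. The tagset here (5 labels with hypothetical BOS/EOS
# ids 3 and 4) and all sizes are made up for illustration.
if __name__ == '__main__':
    NB_LABELS, BOS, EOS = 5, 3, 4
    crf = CRF(nb_labels=NB_LABELS, bos_tag_id=BOS, eos_tag_id=EOS)

    batch_size, seq_len = 2, 6
    emissions = paddle.randn([batch_size, seq_len, NB_LABELS])
    # int32 tags so they match the int32 indices built from mask.int() inside
    # _compute_scores; only "real" tags 0..2 appear in the gold sequences
    tags = paddle.randint(0, 3, shape=[batch_size, seq_len], dtype='int32')

    # training-style objective: negative log-likelihood of the gold tags
    nll = crf(emissions, tags)
    logger.info(f"nll: {float(nll)}")

    # inference: Viterbi scores and best label sequences
    scores, paths = crf.decode(emissions)
    logger.info(f"viterbi scores: {scores.numpy()}, best paths: {paths}")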