Source code for paddlespeech.text.exps.ernie_linear.test

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

import numpy as np
import paddle
import pandas as pd
import yaml
from paddle import nn
from paddle.io import DataLoader
from sklearn.metrics import classification_report
from sklearn.metrics import precision_recall_fscore_support
from yacs.config import CfgNode

from paddlespeech.t2s.utils import str2bool
from paddlespeech.text.models.ernie_linear import ErnieLinear
from paddlespeech.text.models.ernie_linear import PuncDataset
from paddlespeech.text.models.ernie_linear import PuncDatasetFromErnieTokenizer

# Model classes, looked up by the config's "model_type" entry.
DefinedClassifier = dict(ErnieLinear=ErnieLinear)

# Loss classes keyed by their short name.
DefinedLoss = dict(ce=nn.CrossEntropyLoss)

# Dataset classes, looked up by the config's "dataset_type" entry.
DefinedDataset = dict(
    Punc=PuncDataset,
    Ernie=PuncDatasetFromErnieTokenizer, )


def evaluation(y_pred, y_test):
    """Build a precision / recall / F1 table for the punctuation classes.

    Only label ids 1, 2 and 3 are scored; they are reported under the
    column names ``COMMA``, ``PERIOD`` and ``QUESTION``. Label id 0 (``O``)
    is left out of both the per-class columns and the macro average.

    Note the argument order: predictions come first, ground truth second.

    Args:
        y_pred: Predicted label ids.
        y_test: Ground-truth label ids (forwarded to sklearn as ``y_true``).

    Returns:
        A ``pandas.DataFrame`` with rows ``Precision``, ``Recall`` and ``F1``,
        one column per scored class, and an ``OVERALL`` column holding the
        macro-averaged values.
    """
    scored_ids = [1, 2, 3]
    # Drop the leading 'O' so the names line up with ``scored_ids``.
    scored_names = ['O', 'COMMA', 'PERIOD', 'QUESTION'][1:]

    per_class = precision_recall_fscore_support(
        y_test, y_pred, average=None, labels=scored_ids)
    macro = precision_recall_fscore_support(
        y_test, y_pred, average='macro', labels=scored_ids)

    # Both results are (precision, recall, f1, support); support is unused.
    result = pd.DataFrame(
        np.array(per_class[:3]),
        columns=scored_names,
        index=['Precision', 'Recall', 'F1'])
    result['OVERALL'] = macro[:3]
    return result
def test(args):
    """Evaluate a trained punctuation model on the configured test set.

    Reads the YAML config at ``args.config``, builds the dataset and model
    it selects, restores the ``"main_params"`` entry of the checkpoint at
    ``args.checkpoint``, runs the model over every test batch and prints
    sklearn's classification report. When ``args.print_eval`` is truthy, the
    table returned by ``evaluation`` is printed as well.

    Args:
        args: Parsed command-line namespace providing ``config``,
            ``checkpoint`` and ``print_eval``.
    """
    with open(args.config) as f:
        config = CfgNode(yaml.safe_load(f))

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)

    # The dataset classes receive the data file through their ``train_path``
    # keyword; here it is pointed at the test split.
    test_dataset = DefinedDataset[config["dataset_type"]](
        train_path=config["test_path"], **config["data_params"])
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False)

    model = DefinedClassifier[config["model_type"]](**config["model"])
    state_dict = paddle.load(args.checkpoint)
    model.set_state_dict(state_dict["main_params"])
    model.eval()

    # Class names ordered by label id, used as row names in the report.
    id2punc = test_dataset.id2punc
    punc_list = [id2punc[idx] for idx in range(len(id2punc))]

    test_total_label = []
    test_total_predict = []
    # Inference only: skip gradient bookkeeping while iterating the test set.
    with paddle.no_grad():
        for inputs, label in test_loader:
            label = paddle.reshape(label, shape=[-1])
            _, logit = model(inputs)
            pred = paddle.argmax(logit, axis=1)
            test_total_label.extend(label.numpy().tolist())
            test_total_predict.extend(pred.numpy().tolist())

    report = classification_report(
        test_total_label, test_total_predict, target_names=punc_list)
    print(report)

    if args.print_eval:
        # evaluation() is declared as (y_pred, y_test). Pass by keyword so
        # predictions and ground truth cannot be transposed, which would
        # swap the Precision and Recall rows.
        summary = evaluation(
            y_pred=test_total_predict, y_test=test_total_label)
        print('=========================================================')
        print(summary)
def main():
    """Command-line entry point: parse arguments, pick a device, run ``test``.

    ``--ngpu 0`` selects the CPU and any positive value selects the GPU.
    A negative value only prints a message; the evaluation still runs on
    whatever device is already active.
    """
    parser = argparse.ArgumentParser(description="Test a ErnieLinear model.")
    parser.add_argument("--config", type=str, help="ErnieLinear config file.")
    parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
    parser.add_argument("--print_eval", type=str2bool, default=True)
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
    args = parser.parse_args()

    if args.ngpu < 0:
        print("ngpu should >= 0 !")
    else:
        paddle.set_device("cpu" if args.ngpu == 0 else "gpu")

    test(args)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()