MASKER: Masked Keyword Regularization for Reliable Text Classification
DOI:
https://doi.org/10.1609/aaai.v35i15.17601
Keywords:
Text Classification & Sentiment Analysis
Abstract
Pre-trained language models have achieved state-of-the-art accuracies on various text classification tasks, e.g., sentiment analysis, natural language inference, and semantic textual similarity. However, the reliability of fine-tuned text classifiers is an often overlooked performance criterion. For instance, one may desire a model that can detect out-of-distribution (OOD) samples (drawn far from the training distribution) or be robust against domain shifts. We claim that one central obstacle to reliability is the over-reliance of the model on a limited number of keywords, instead of looking at the whole context. In particular, we find that (a) OOD samples often contain in-distribution keywords, while (b) cross-domain samples may not always contain keywords; over-relying on the keywords can be problematic for both cases. In light of this observation, we propose a simple yet effective fine-tuning method, coined masked keyword regularization (MASKER), that facilitates context-based prediction. MASKER regularizes the model to reconstruct the keywords from the rest of the words and to make low-confidence predictions without enough context. Applying MASKER to various pre-trained language models (e.g., BERT, RoBERTa, and ALBERT), we demonstrate that it improves OOD detection and cross-domain generalization without degrading classification accuracy. Code is available at https://github.com/alinlab/MASKER.
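The two regularizers described in the abstract can be sketched as follows. This is a minimal, framework-free illustration under stated assumptions, not the authors' implementation (the linked repository holds that). The assumptions: the keyword set is supplied by hand rather than chosen by the paper's own selection procedure, which this abstract does not describe; model outputs are passed in as plain lists of logits; the "low-confidence without context" term is written as a KL divergence from the uniform distribution to the prediction; and `mask_views`, `masker_loss`, `lam_mkr`, and `lam_mer` are hypothetical names.

```python
import math

MASK = "[MASK]"  # placeholder token; a real model would use its tokenizer's mask id


def mask_views(tokens, keywords):
    """Build two corrupted views: keywords hidden vs. context hidden."""
    keyword_masked = [MASK if t in keywords else t for t in tokens]
    context_masked = [t if t in keywords else MASK for t in tokens]
    return keyword_masked, context_masked


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def cross_entropy(logits, target):
    """Negative log-likelihood of index `target` under softmax(logits)."""
    return -log_softmax(logits)[target]


def kl_uniform_to_pred(logits):
    """KL(U || p): zero exactly when the predicted distribution p is uniform."""
    log_p = log_softmax(logits)
    c = len(log_p)
    return sum((1.0 / c) * (math.log(1.0 / c) - lp) for lp in log_p)


def masker_loss(cls_logits, label, mkr_logits, mkr_targets, mer_logits,
                lam_mkr, lam_mer):
    """Classification loss plus the two MASKER-style regularizers (sketch).

    cls_logits:  class logits on the clean input
    mkr_logits:  vocabulary logits at each masked-keyword position
                 (computed on the keyword-masked view)
    mkr_targets: vocabulary ids of the hidden keywords
    mer_logits:  class logits on the context-masked view
    lam_mkr, lam_mer: hypothetical weighting coefficients
    """
    l_cls = cross_entropy(cls_logits, label)
    # Keyword reconstruction: recover each hidden keyword from the remaining words.
    l_mkr = sum(cross_entropy(lg, t) for lg, t in zip(mkr_logits, mkr_targets))
    l_mkr /= max(len(mkr_targets), 1)
    # Entropy regularization: with only keywords visible, push toward a uniform prediction.
    l_mer = kl_uniform_to_pred(mer_logits)
    return l_cls + lam_mkr * l_mkr + lam_mer * l_mer


if __name__ == "__main__":
    tokens = "the plot was great and the acting superb".split()
    kw_view, ctx_view = mask_views(tokens, keywords={"great", "superb"})
    print(kw_view)   # ['the', 'plot', 'was', '[MASK]', 'and', 'the', 'acting', '[MASK]']
    print(ctx_view)  # ['[MASK]', '[MASK]', '[MASK]', 'great', '[MASK]', '[MASK]', '[MASK]', 'superb']
```

The keyword-masked view feeds the reconstruction term, so the model must learn to infer keywords from the surrounding context. The context-masked view feeds the uniformity term, so keywords alone should not yield a confident class prediction. Together these match the two failure modes the abstract names: OOD text that happens to contain in-distribution keywords, and cross-domain text that lacks them.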
Published
2021-05-18
How to Cite
Moon, S. J., Mo, S., Lee, K., Lee, J., & Shin, J. (2021). MASKER: Masked Keyword Regularization for Reliable Text Classification. Proceedings of the AAAI Conference on Artificial Intelligence, 35(15), 13578-13586. https://doi.org/10.1609/aaai.v35i15.17601
Issue
Vol. 35 No. 15 (2021)
Section
AAAI Technical Track on Speech and Natural Language Processing II