TRAM: Bridging Trust Regions and Sharpness Aware Minimization

Part of International Conference on Learning Representations 2024 (ICLR 2024)

Authors

Tom Sherborne, Naomi Saphra, Pradeep Dasigi, Hao Peng

Abstract

Sharpness-aware minimization (SAM) reports improving domain generalization by reducing the loss surface curvature in the parameter space. However, generalization during fine-tuning is often more dependent on the transferability of representations in the function space. Trust-region methods (TR) target this goal by regularizing representation curvature to reduce catastrophic forgetting of pre-trained task-agnostic information while adopting task-specific skills. We consider unifying these strategies for low curvature in both parameter space and function space to improve out-of-domain (OOD) generalization. We propose Trust Region Aware Minimization (TRAM), a SAM algorithm fine-tuning for low parameter sharpness and smooth, informative representations preserving pre-trained structure. TRAM uses a trust region bound to inform the SAM adversarial neighborhood, introducing an awareness of function curvature within optimization for flatter minima. We empirically validate TRAM in vision (cross-dataset adaptation) and text (OOD language modeling, zero-shot cross-lingual transfer) tasks where robust domain transfer and representation generality are critical. TRAM outperforms SAM- and TR-based optimization across all tasks, notably surpassing competing methods for hard transfer between anticorrelated domains. TRAM establishes a novel standard in fine-tuning for domain-generalizable models with minimal additional computation over previous sharpness-aware methods.
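
As a rough illustration of the idea of a SAM step whose adversarial neighborhood size is supplied from outside, the sketch below shows a generic SAM-style update on a toy quadratic loss. It is not the authors' implementation: the names `sam_step`, `toy_grad`, and the radius argument `rho` are hypothetical, and the abstract does not specify how TRAM computes its trust region bound, so the radius is simply passed in as a number here.

```python
import numpy as np

def toy_grad(w):
    # Gradient of the toy loss 0.5 * (w[0]**2 + 10 * w[1]**2).
    # This loss is an assumption chosen only to make the sketch runnable.
    return np.array([1.0, 10.0]) * w

def sam_step(w, grad_fn, lr, rho):
    """One generic SAM-style update with neighborhood radius rho.

    In a TRAM-like method, rho would be informed by a trust region
    bound; how that bound is measured is NOT shown here (assumption).
    """
    g = grad_fn(w)
    # Move to the approximate worst-case point within radius rho.
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    # Descend using the gradient evaluated at the perturbed point.
    g_adv = grad_fn(w + eps)
    return w - lr * g_adv

w = np.array([1.0, 1.0])
for _ in range(5):
    w = sam_step(w, toy_grad, lr=0.05, rho=0.1)
```

With `rho = 0` the update reduces to ordinary gradient descent; a larger radius makes each step respond to the gradient at a more distant, higher-loss neighbor of the current parameters.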