Термин "прунинг" пришел к нам из нейробиологи. Это процесс сокращения числа синапсов или нейронов для повышения эффективности нашего мозга или, другими словами, удаление избыточных связей. Прунинг нейросетей построен на той же идее: мы хотим сжать сеть за счет удаления части параметров предобученной модели, тем самым уменьшив расход памяти и вычислительной сложности.
Методы прунинга делятся на структурированные (Structured), неструктурированные (Unstructured) и полу-структурированные (Semi-Structured). Структурированный прунинг удаляет целые конструкции: например, головы внимания или слои. Однако этот подход обязывает прикладывать большие усилия для последующего дообучения, так как после удаления целых блоков архитектуры модель естественным образом сильно теряет в точности. Неструктурированный прунинг обнуляет только отдельные веса. При этом модель не сильно теряет в производительности, но есть нюанс. Этот метод очень непредсказумый, и получается, что на выходе мы имеем нерегулярные шаблоны разреженности (это называется sparse patterns). Это в свою очередь ведет к тому, что метод сложно оптимизировать с точки зрения ускорения вычислений на железе, потому что произвольно разреженные матрицы требуют сложных операций индексации и хранения.
Semi-Structured старается взять лучшее из двух методов выше так, чтобы прунинг получился и эффективным с точки зрения компьюта, и перформансу исходной модели не сильно навредил. Такой прунинг основывается на "аппаратно-дружественных" шаблонах разреживания, таких как разреженность N:M (из M значений оставляем только N ненулевых). Тем самым такой подход гармонизирует аппаратную бережливость Structured и гибкость Unstructured. Однако и тут не обходится без определенных сложностей:
- Во-первых, для такого прунинга нам все равно нужно уметь оценивать, какие веса (или головы внимания, или слои) для модели наименее важны, и удалять именно их. Для этого используется так называемая калибровочная выборка. Проблема заключается в том, что такая выборка обычно слишком мала для того, чтобы полноценно оценить большую модель. Более того, масштабирумость тут работает плохо: это значит, что увеличение этого калибровочного сета не факт, что приведет к лучшим результатам.
- Во-вторых, критерии, по которым мы отбираем кандидатов на удаление во время калибровки (это обычно информация о градиентах), тоже не слишком надежные. Эксперименты показывают, что есть огромный разрыв между ожидаемым эффектом от удаления части сети и реальным, и возникает он именно из-за неправильно выбранных критериев оценки.
Для решения этих проблем авторы из Nvidia на днях предложили MaskLLM — обучаемый метод Semi-Structured прунинга. MaskLLM автоматизирует выбор маски, по которой модель будет прунится, и тем самым почти полностью нивелирует перечисленные проблемы. Выбор маски – это как раз тот самый выбор паттерна N:M, по которому будут обнуляться веса. В классическом прунинге этот выбор осуществляется из конечного дискретного множества, поэтому получается, что этот процесс недифференцируемый, и оптимизировать его с помощью обратного распространения ошибки нельзя. Но ученые из Nvidia придумали, как это обойти.
Вместо того чтобы выбирать маски "жестко", MaskLLM рассматривает выбор маски как вероятностный процесс. Это означает, что задача превращается в присвоение каждой маске некоторой вероятности, а такое присвоение уже можно обучать градиентным спуском. Для этого авторы используют Gumbel Softmax – сглаживание, которое как раз и параметризует "случайность" выбора в независимую переменную. А чтобы еще больше ускорить процесс и сделать его масштабируемым, авторы предлагают использовать априорные маски. Это заранее вычисленные маски, которые могут служить теплым стартом для дальнейшего обучения.
Разработчики тестировании метод на нескольких LLM, включая LLaMA-2 7B, LLaMA-2 13B и Nemotron-4 15B. Результаты тестов – ниже. Как видите, показывает себя метод очень неплохо. Например, по сравнению со SparseGPT, который переоформит перплексию 10,42 на прунинге LLaMA-2 7B, MaskLLM улучшает PPL до 6,72 без какого-либо дообучения самой LLM.
Но основной вклад работы не в метриках (хотя они впечатляющие), а именно в том, что ученые предложили метод обучаемого прунинга. К тому же, MaskLLM бережет железо, хорошо подходит для использования крупных калибровочных выборок и хорошо масштабируется на другие домены.
Полный текст статьи можно найти здесь.