Reducing Transformer Key-Value Cache Size with Cross-Layer Attention
KV Cache поистине прекрасная идея, которая уже успела повлиять на нашу область. однако сколько же по памяти занимает это чудо?
сколько же нужно элемент на один токен для каждого слоя? эмбеддинг размерность х количество голов как для К так и для V ⇒ получаем 2 * num_heads * embedding_dim на *каждый* слой.
но уже относительно давно известны методы по группировке модулей (Grouped-Query Attention, GQA), которые обрабатывают запросы. и каждый из модулей внутри этой группы разделяет общие модули но ключам/значениям ⇒ получаем снижение 2 * num_groups * embedding_dim на *каждый* слой
авторы из MIT решили пойти дальше и уже снизить потребление по памяти не внутри одного слоя, а объединив модули между слоями. потому и называется статься cross-layer KV Cache
что же придумали? да все просто - давайте группировать слои так, чтобы между ними были KV значения только из одного слоя внутри этой группы: и группы составлять через каждые 2, 3… N слоев. просто? - просто. сокращает память и работает! ( в силу тех ресурсов которые были у ресерчеров ) + оставляет возможность для совместимости с другими модификациями как GQA & не ставит никаких преград для параллелизации
правда есть вопросы по тому поводу, что на разных слоях происходят проекции по разным семантическим пространствам, что может не очень хорошо сказываться, когда запросы одной “природы”, а ключи/значения на инференсе другой “природы”
энивей, на скейле моделек 1В и 3В видится заметное снижение по памяти с приемлемым снижением качества (смотря какая задача). но я бы спекулировал, что на моделях большей размерности из-за гетерогенности Q vs KV перформанс будет заметно хуже
👀LINK
KV Cache поистине прекрасная идея, которая уже успела повлиять на нашу область. однако сколько же по памяти занимает это чудо?
сколько же нужно элемент на один токен для каждого слоя? эмбеддинг размерность х количество голов как для К так и для V ⇒ получаем 2 * num_heads * embedding_dim на *каждый* слой.
но уже относительно давно известны методы по группировке модулей (Grouped-Query Attention, GQA), которые обрабатывают запросы. и каждый из модулей внутри этой группы разделяет общие модули но ключам/значениям ⇒ получаем снижение 2 * num_groups * embedding_dim на *каждый* слой
авторы из MIT решили пойти дальше и уже снизить потребление по памяти не внутри одного слоя, а объединив модули между слоями. потому и называется статься cross-layer KV Cache
что же придумали? да все просто - давайте группировать слои так, чтобы между ними были KV значения только из одного слоя внутри этой группы: и группы составлять через каждые 2, 3… N слоев. просто? - просто. сокращает память и работает! ( в силу тех ресурсов которые были у ресерчеров ) + оставляет возможность для совместимости с другими модификациями как GQA & не ставит никаких преград для параллелизации
правда есть вопросы по тому поводу, что на разных слоях происходят проекции по разным семантическим пространствам, что может не очень хорошо сказываться, когда запросы одной “природы”, а ключи/значения на инференсе другой “природы”
энивей, на скейле моделек 1В и 3В видится заметное снижение по памяти с приемлемым снижением качества (смотря какая задача). но я бы спекулировал, что на моделях большей размерности из-за гетерогенности Q vs KV перформанс будет заметно хуже
👀LINK
Surgical Robot Transformer (SRT): Imitation Learning for Surgical Tasks
заимплементить трансформер под хирургического робота так, чтобы он мог зашивать раны и не только? да!
авторы из стенфорда решили такое сделать, и получилось очень даже круто. при том они не используют никакие данные из кинематики, а только картиночные инпуты (которые даже сильно даунсемплят до 224х224х3, но результат все равно очень даже крутой). и это довольно нетривиально для такого рода работ
реализуют на основе имитейшн лернинга и Action-Chunking Transformer, который группирует небольшой чанк действий в одну группу и третирует их как один юнит. сделано это для того, чтобы нивелировать момент накопления ошибок (который довольно часто случается во время инференса при сетапе имитейшн лернинга)
еще есть залипательные видосы с демонстрацией работы трансформера
ждем теперь, когда такое будет возможно делать в более-менее сносном темпе
👀LINK
заимплементить трансформер под хирургического робота так, чтобы он мог зашивать раны и не только? да!
авторы из стенфорда решили такое сделать, и получилось очень даже круто. при том они не используют никакие данные из кинематики, а только картиночные инпуты (которые даже сильно даунсемплят до 224х224х3, но результат все равно очень даже крутой). и это довольно нетривиально для такого рода работ
реализуют на основе имитейшн лернинга и Action-Chunking Transformer, который группирует небольшой чанк действий в одну группу и третирует их как один юнит. сделано это для того, чтобы нивелировать момент накопления ошибок (который довольно часто случается во время инференса при сетапе имитейшн лернинга)
еще есть залипательные видосы с демонстрацией работы трансформера
ждем теперь, когда такое будет возможно делать в более-менее сносном темпе
👀LINK
Efficient World Models with Context-Aware Tokenization
какое-то время назад понятие “модель мира” начало всплывать в контексте ЛЛМ, в то время как свои корни оно имеет из рл. так вот с ней в рл тоже часто не всегда все понятно
что же такое модель мира в сетапе рл? есть ответвление, как model-based алгоритмы, которые помимо основной модели полиси (которая притворяет действия по данным ею состояниям) учат еще и модель среды, или же динамику среды, которая ловит паттерны по истории и предиктит, а что же произойдет на следующем таймстепе. зачем это нужно? при достаточно качественной модели динамики, можно получать данные при помощи нее, не действуя в среде ⇒ экономия времени, ресурсов, да и безопасность повышается, если ошибки на реальной задаче очень как критичны
так же такую штуку называют learning in imagination. так вот авторы берут один из сота методов, iris, который работает через дискретный автоэнкодер + каузальный трансформер для предикта динамики по заданному контексту. и улучшают его! а иначе никак
а что можно улучшить? ирис перед этапом энкодинга разбивает историю на отдельные токены, которые затем подаются в трансформер. а такая отдельность не всегда оптимальна, поскольку, как говорят авторы и что интуитивно понятно, независимое разбиение на токены не всегда оптимально, когда не берется контекст прошлого, ибо выгоднее порой обращать внимание не на само состояние, сколько на *разницу,* которая успела произойти за это время
вот авторы такое и имплементят, просто вставляя на этапе энкодинга и декодинга кусок предыдущей истории. дальше над этим оперирует трансформер, который является моделью мира для актор-критика. в остальном сетап довольно классичен. и происходит норм такой буст по результатам
правда проевалили пока только на атари и крафтере, а что в том же майнкрафте происходит непонятно (где очень крут DreamerV3) - там наверняка надо брать горизонт побольше из-за специфики движений и состояний стива, что уже будет вызывать трудности у метода имхо.
👀LINK
UPD: надо еще учитывать что майнкрафт сильно упрощен на дримере, так что все может быть ок с дельта-ирис
какое-то время назад понятие “модель мира” начало всплывать в контексте ЛЛМ, в то время как свои корни оно имеет из рл. так вот с ней в рл тоже часто не всегда все понятно
что же такое модель мира в сетапе рл? есть ответвление, как model-based алгоритмы, которые помимо основной модели полиси (которая притворяет действия по данным ею состояниям) учат еще и модель среды, или же динамику среды, которая ловит паттерны по истории и предиктит, а что же произойдет на следующем таймстепе. зачем это нужно? при достаточно качественной модели динамики, можно получать данные при помощи нее, не действуя в среде ⇒ экономия времени, ресурсов, да и безопасность повышается, если ошибки на реальной задаче очень как критичны
так же такую штуку называют learning in imagination. так вот авторы берут один из сота методов, iris, который работает через дискретный автоэнкодер + каузальный трансформер для предикта динамики по заданному контексту. и улучшают его! а иначе никак
а что можно улучшить? ирис перед этапом энкодинга разбивает историю на отдельные токены, которые затем подаются в трансформер. а такая отдельность не всегда оптимальна, поскольку, как говорят авторы и что интуитивно понятно, независимое разбиение на токены не всегда оптимально, когда не берется контекст прошлого, ибо выгоднее порой обращать внимание не на само состояние, сколько на *разницу,* которая успела произойти за это время
вот авторы такое и имплементят, просто вставляя на этапе энкодинга и декодинга кусок предыдущей истории. дальше над этим оперирует трансформер, который является моделью мира для актор-критика. в остальном сетап довольно классичен. и происходит норм такой буст по результатам
правда проевалили пока только на атари и крафтере, а что в том же майнкрафте происходит непонятно (где очень крут DreamerV3) - там наверняка надо брать горизонт побольше из-за специфики движений и состояний стива, что уже будет вызывать трудности у метода имхо.
👀LINK
UPD: надо еще учитывать что майнкрафт сильно упрощен на дримере, так что все может быть ок с дельта-ирис
🔥4❤1 1 1
Diffusion for World Modeling: Visual Details Matter in Atari
в продолжение темы про модели мира
в основном используют дискретные автоэнкодеры, ибо с дискретными латентами снижается проблема накопительной ошибки при процессинге данных. но в более сложных задачах дискретизация может быть чревата слишком сильной потерей инфрмации, отчего все плоховато. в принципе можно тогда увеличивать количество дискретных эмбеддингов, но и повышаются требования на компьют. а какая есть альтернатива - диффузия🤑🤑🤑
и в качестве решения используют score-based diffusion, а именно EDM (есть так же сравнение результатов с обычным DDPM), где таргет адаптивно миксует signal-to-noise ratio в соотношении с нойз шедулингом. в качестве чистого таргета выступает фиксированная последовательность из предыдущих интеракций. а интуитивно такие трюки с адаптивным таргетом нужны для того, чтобы аутпуты модели оставались вариативными (в силу сложности задачи) когда шума мало.
правда авторы так же указывают, что для предикта реварда и флага терминации используются отдельные модели → диффузия не оч хороша в моделировании таких скаляров вместе с динамикой по состояниям, либо авторы не смогли нормально это прикрутить
по mean human normalized score обыгрываем все методы, на одну сотую только проигрывает по interquantile mean методу выше
👀LINK
в продолжение темы про модели мира
в основном используют дискретные автоэнкодеры, ибо с дискретными латентами снижается проблема накопительной ошибки при процессинге данных. но в более сложных задачах дискретизация может быть чревата слишком сильной потерей инфрмации, отчего все плоховато. в принципе можно тогда увеличивать количество дискретных эмбеддингов, но и повышаются требования на компьют. а какая есть альтернатива - диффузия🤑🤑🤑
и в качестве решения используют score-based diffusion, а именно EDM (есть так же сравнение результатов с обычным DDPM), где таргет адаптивно миксует signal-to-noise ratio в соотношении с нойз шедулингом. в качестве чистого таргета выступает фиксированная последовательность из предыдущих интеракций. а интуитивно такие трюки с адаптивным таргетом нужны для того, чтобы аутпуты модели оставались вариативными (в силу сложности задачи) когда шума мало.
правда авторы так же указывают, что для предикта реварда и флага терминации используются отдельные модели → диффузия не оч хороша в моделировании таких скаляров вместе с динамикой по состояниям, либо авторы не смогли нормально это прикрутить
по mean human normalized score обыгрываем все методы, на одну сотую только проигрывает по interquantile mean методу выше
👀LINK
👍5❤1🔥1 1
SLiC-HF: Sequence Likelihood Calibration with Human Feedback
есть очень классная статья с простой идей. SLiC помогает улучшить генерации ЛЛМ посредством калибровки своих же аутпутов с таргет последовательностями. и авторы этой статьи заметили, что такое легко можно переложить и на алаймент
в качестве задачи взяли суммаризацию на реддите
что же они сделали? добавили калибровочный лосс на перевес лайклихуда позитивного семлпа над негативным с некоторым марджином + регуляризацию на повышение лайклихуда таргетов из СФТ датасета
представили 2 способа калибровки (выбора позитивного/негативного семпла) - с обучением ранкинг/ревард модели или напрямую через преференции, данные в датасете. но как вы можете понять, последнее составляет основной контрибьюшн
правда не очень понятно почему решили сравниться с тем, чтобы продолжать файнтюн моделей на правильных ответах без преференций. да, там есть разные варианты подбора таргета, которые использовали ранкинг модели, но по этой работе теперь мы понимаем, что это не оч + было бы неплохо в принципе сделать рлхф-ппо эксперименты, а не просто предоставлять сравнительную таблицу (чтобы подкреплять свои доводы эмпирикой)
👀LINK
есть очень классная статья с простой идей. SLiC помогает улучшить генерации ЛЛМ посредством калибровки своих же аутпутов с таргет последовательностями. и авторы этой статьи заметили, что такое легко можно переложить и на алаймент
в качестве задачи взяли суммаризацию на реддите
что же они сделали? добавили калибровочный лосс на перевес лайклихуда позитивного семлпа над негативным с некоторым марджином + регуляризацию на повышение лайклихуда таргетов из СФТ датасета
представили 2 способа калибровки (выбора позитивного/негативного семпла) - с обучением ранкинг/ревард модели или напрямую через преференции, данные в датасете. но как вы можете понять, последнее составляет основной контрибьюшн
правда не очень понятно почему решили сравниться с тем, чтобы продолжать файнтюн моделей на правильных ответах без преференций. да, там есть разные варианты подбора таргета, которые использовали ранкинг модели, но по этой работе теперь мы понимаем, что это не оч + было бы неплохо в принципе сделать рлхф-ппо эксперименты, а не просто предоставлять сравнительную таблицу (чтобы подкреплять свои доводы эмпирикой)
👀LINK
👍3 1 1
A General Theoretical Paradigm to Understand Learning from Human Preferences
как-то пропустили от дипмаинда папиру насчет теоретической формалзации рлхф + дпо
на чем строятся эти основные фреймворки в алайменте?
1) из датасета пар преференций можно составить поточечную ревард функцию (Bradley-Terry model )
2) эту ревард функцию можно аппроксимировать какой-то нейронкой так, что на ООД семплах она все равно будет выдавать адекватные результаты
у рлхф 2 предположения, в то время как дпо отказывается от второго, но сохраняет 1
и авторы обобщают эти две концепции, называя более общий метод ΨPO (хз почему именно так назвали) с введением через неубывающую нелинейную функцию
но это еще не все - они пробуют ввести линейный маппинг в эту обобщенную концепцию и получают IPO. а что значит линейный маппинг, когда вроде вот ета h функция все равно нелинейная?
я не понял до конца 😐. но следствия такие, что обучение происходит напрямую на преференциях, при том отличие от дпо в том, что отсутствует предположение 1 чтобы не произошел оверфит попы (хотя тут все равно остаются вопросы оптимален ли практический семплинг чтобы предотвратить оверфиттинг в реализации)
на бандитах работает неплохо, но в рл зачастую это еще ничего не значит. ждем что будет на большом скейле и на реальной задаче языкового моделирования
👀LINK
как-то пропустили от дипмаинда папиру насчет теоретической формалзации рлхф + дпо
на чем строятся эти основные фреймворки в алайменте?
1) из датасета пар преференций можно составить поточечную ревард функцию (Bradley-Terry model )
2) эту ревард функцию можно аппроксимировать какой-то нейронкой так, что на ООД семплах она все равно будет выдавать адекватные результаты
у рлхф 2 предположения, в то время как дпо отказывается от второго, но сохраняет 1
и авторы обобщают эти две концепции, называя более общий метод ΨPO (хз почему именно так назвали) с введением через неубывающую нелинейную функцию
но это еще не все - они пробуют ввести линейный маппинг в эту обобщенную концепцию и получают IPO. а что значит линейный маппинг, когда вроде вот ета h функция все равно нелинейная?
я не понял до конца 😐. но следствия такие, что обучение происходит напрямую на преференциях, при том отличие от дпо в том, что отсутствует предположение 1 чтобы не произошел оверфит попы (хотя тут все равно остаются вопросы оптимален ли практический семплинг чтобы предотвратить оверфиттинг в реализации)
на бандитах работает неплохо, но в рл зачастую это еще ничего не значит. ждем что будет на большом скейле и на реальной задаче языкового моделирования
👀LINK