JAX - что это такое и с чем его едят?
В последнее время много новинок от Google и DeepMind выходит на JAX, вместо привычного PyTorch или TF.
JAX - это новая библиотека в мире машинного обучения (ML), которая обещает сделать программирование ML более интуитивным, структурированным и чистым.
Основная и единственная цель JAX - выполнение числовых операций в высокопроизводительной форме. Это означает, что синтаксис практически идентичен Numpy.
Одним из главных преимуществ JAX является то, что мы можем запускать одну и ту же программу без каких-либо изменений на аппаратных ускорителях, таких как GPU и TPU.
Другой важный момент - это скорость. JAX быстрее. Намного быстрее. Например перемножение двух матриц (1000,1000) в NumPy занимает ~50ms, а в JAX ~1.5ms (на GPU).
В библиотеку встроен автоград. JAX способен дифференцировать всевозможные функции python и NumPy, включая циклы, ветвления, рекурсии и многое другое.
Факторы, делающие JAX таким быстрым:
* ускоренная линейная алгебра (Accelerated Linear Algebra или XLA).
* Just in time compilation (jit) - способ выполнения компьютерного кода, который предполагает компиляцию программы - во время выполнения - а не перед выполнением.
* Репликация вычислений между устройствами с помощью pmap - еще одно преобразование, которое позволяет нам реплицировать вычисления на несколько ядер или устройств и выполнять их параллельно (p в pmap означает parallel).
И ещё много различных трюков и улучшений.
Ещё одной особенностью JAX (и возможно даже более важной чем скорость) является Pseudo-Random number generator. В отличие от NumPy или PyTorch, в JAX состояния случайности должны быть поданы пользователем в качестве аргумента (что делает JAX по умолчанию намного более воспроизводимым).
Ещё больше деталей и примеров кода
Официальный GitHub
#gpu #code #jax
В последнее время много новинок от Google и DeepMind выходит на JAX, вместо привычного PyTorch или TF.
JAX - это новая библиотека в мире машинного обучения (ML), которая обещает сделать программирование ML более интуитивным, структурированным и чистым.
Основная и единственная цель JAX - выполнение числовых операций в высокопроизводительной форме. Это означает, что синтаксис практически идентичен Numpy.
Одним из главных преимуществ JAX является то, что мы можем запускать одну и ту же программу без каких-либо изменений на аппаратных ускорителях, таких как GPU и TPU.
Другой важный момент - это скорость. JAX быстрее. Намного быстрее. Например перемножение двух матриц (1000,1000) в NumPy занимает ~50ms, а в JAX ~1.5ms (на GPU).
В библиотеку встроен автоград. JAX способен дифференцировать всевозможные функции python и NumPy, включая циклы, ветвления, рекурсии и многое другое.
Факторы, делающие JAX таким быстрым:
* ускоренная линейная алгебра (Accelerated Linear Algebra или XLA).
* Just in time compilation (jit) - способ выполнения компьютерного кода, который предполагает компиляцию программы - во время выполнения - а не перед выполнением.
* Репликация вычислений между устройствами с помощью pmap - еще одно преобразование, которое позволяет нам реплицировать вычисления на несколько ядер или устройств и выполнять их параллельно (p в pmap означает parallel).
И ещё много различных трюков и улучшений.
Ещё одной особенностью JAX (и возможно даже более важной чем скорость) является Pseudo-Random number generator. В отличие от NumPy или PyTorch, в JAX состояния случайности должны быть поданы пользователем в качестве аргумента (что делает JAX по умолчанию намного более воспроизводимым).
Ещё больше деталей и примеров кода
Официальный GitHub
#gpu #code #jax
OpenXLA Project