AI Для Всех
12.8K subscribers
1.17K photos
152 videos
10 files
1.37K links
Канал, в котором мы говорим про искусственный интеллект простыми словами

Главный редактор и по рекламе: @crimeacs

Иногда пишут в канал: @GingerSpacetail, @innovationitsme
Download Telegram
JAX - что это такое и с чем его едят?

В последнее время много новинок от Google и DeepMind выходит на JAX, вместо привычного PyTorch или TF.

JAX - это новая библиотека в мире машинного обучения (ML), которая обещает сделать программирование ML более интуитивным, структурированным и чистым.

Основная и единственная цель JAX - выполнение числовых операций в высокопроизводительной форме. Это означает, что синтаксис практически идентичен Numpy.

Одним из главных преимуществ JAX является то, что мы можем запускать одну и ту же программу без каких-либо изменений на аппаратных ускорителях, таких как GPU и TPU.

Другой важный момент - это скорость. JAX быстрее. Намного быстрее. Например перемножение двух матриц (1000,1000) в NumPy занимает ~50ms, а в JAX ~1.5ms (на GPU).

В библиотеку встроен автоград. JAX способен дифференцировать всевозможные функции python и NumPy, включая циклы, ветвления, рекурсии и многое другое.

Факторы, делающие JAX таким быстрым:
* ускоренная линейная алгебра (Accelerated Linear Algebra или XLA).
* Just in time compilation (jit) - способ выполнения компьютерного кода, который предполагает компиляцию программы - во время выполнения - а не перед выполнением.
* Репликация вычислений между устройствами с помощью pmap - еще одно преобразование, которое позволяет нам реплицировать вычисления на несколько ядер или устройств и выполнять их параллельно (p в pmap означает parallel).
И ещё много различных трюков и улучшений.

Ещё одной особенностью JAX (и возможно даже более важной чем скорость) является Pseudo-Random number generator. В отличие от NumPy или PyTorch, в JAX состояния случайности должны быть поданы пользователем в качестве аргумента (что делает JAX по умолчанию намного более воспроизводимым).

Ещё больше деталей и примеров кода
Официальный GitHub

#gpu #code #jax
Введение в JAX (рекомендовано Szegedy)

Этот туториал (colab) знакомит с важными концепциями JAX (autograd, pytree, JIT и др.), реализуя при этом простой алгоритм градиентного спуска.

#jax #basics #tutorial