Google JAX
Utvecklare | |
---|---|
Förhandsgranska release | v0.3.13 / 16 maj 2022
|
Förvar | |
Skrivet i | Python , C++ |
Operativ system | Linux , macOS , Windows |
Plattform | Python , NumPy |
Storlek | 9,0 MB |
Typ | Maskininlärning |
Licens | Apache 2.0 |
Hemsida |
|
Google JAX är ett ramverk för maskininlärning för att transformera numeriska funktioner. Det beskrivs som att sammanföra en modifierad version av autograd (automatisk erhållande av gradientfunktionen genom differentiering av en funktion) och TensorFlows XLA (Accelerated Linear Algebra) . Den är utformad för att följa strukturen och arbetsflödet för NumPy så nära som möjligt och fungerar med olika befintliga ramverk som TensorFlow och PyTorch . De primära funktionerna i JAX är:
- grad: automatisk differentiering
- jit: sammanställning
- vmap: autovektorisering
- pmap: SPMD-programmering
grad
Koden nedan visar gradfunktionens automatiska differentiering.
# importer från jax import grad import jax.numpy som jnp # definierar logistikfunktionen def logistic ( x ): returnera jnp . exp ( x ) / ( jnp . exp ( x ) + 1 ) # få gradientfunktionen för logistikfunktionen grad_logistic = grad ( logistic ) # utvärdera gradienten för logistikfunktionen vid x = 1 grad_log_out = grad_logistic ( 1.0 ) print ( grad_log_out )
Den sista raden ska skriva utː
0,19661194
jit
Koden nedan visar jit -funktionens optimering genom fusion.
# importer från jax import jit import jax.numpy som jnp # definiera kubfunktionen def cube ( x ): returnera x * x * x # generera data x = jnp . ones (( 10000 , 10000 )) # skapa jit-versionen av kubfunktionen jit_cube = jit ( cube ) # tillämpa kub- och jit_cube-funktionerna på samma data för hastighetsjämförelsekub ( x ) jit_cube ( x )
Beräkningstiden för jit_cube (rad nr 17 ) bör vara märkbart kortare än för kub (rad nr 16 ). Öka värdena på rad nr. 7 , kommer att öka skillnaden.
vmap
Koden nedan visar vmap -funktionens vektorisering.
0 0
# importer från functools import partiell från jax import vmap import jax.numpy som jnp # definiera funktion def grads ( self , inputs ) : in_grad_partial = partial ( self . _net_grads , self . _net_params ) grad_vmap = jax . vmap ( in_grad_partial ) rich_grads = grad_vmap ( ingångar ) flat_grads = np . asarray ( self . _flatten_batch ( rich_grads )) hävda flat_grads . ndim == 2 och flat_grads . form [ ] == ingångar . form [ ] returnera flat_grads
GIF:en till höger om detta avsnitt illustrerar begreppet vektoriserad addition.
pmap
Koden nedan visar pmap -funktionens parallellisering för matrismultiplikation.
0
# importera pmap och slumpmässigt från JAX; import JAX NumPy från jax import pmap , slumpmässig import jax.numpy som jnp # generera 2 slumpmässiga matriser med dimensionerna 5000 x 6000, en per enhet random_keys = random . split ( random . PRNGKey ( ), 2 ) matriser = pmap ( lambda key : random . normal ( key , ( 5000 , 6000 )))( random_keys ) # utan dataöverföring, parallellt, utför en lokal matrismultiplikation på varje CPU/ GPU- utgångar = pmap ( lambda x : jnp . dot ( x , x . T ))( matriser ) # utan dataöverföring, parallellt, erhåll medelvärdet för båda matriserna på varje CPU/GPU separat betyder = pmap ( jnp . mean ) ( utgångar ) print ( betyder )
Den sista raden ska skriva ut värdenaː
[ 1,1566595 1,1805978 ]
Bibliotek som använder Jax
Flera pythonbibliotek använder Jax som backend, inklusive:
- Flax, ett neuralt nätverksbibliotek på hög nivå som ursprungligen utvecklades av Google Brain .
- Haiku, ett objektorienterat bibliotek för neurala nätverk utvecklat av DeepMind .
- Equinox, ett bibliotek som kretsar kring idén att representera parametriserade funktioner (inklusive neurala nätverk) som PyTrees. Den skapades av Patrick Kidger.
- Optax, ett bibliotek för gradientbearbetning och optimering utvecklat av DeepMind .
- RLax, ett bibliotek för utveckling av förstärkningsinlärningsmedel utvecklat av DeepMind .
Se även
- NumPy
- TensorFlow
- PyTorch
- CUDA
- Automatisk differentiering
- Just-in-time sammanställning
- Vektorisering
- Automatisk parallellisering
externa länkar
- Dokumentationː
jax .readthedocs .io - Colab ( Jupyter /iPython) snabbstartsguideː
colab .research .google .com /github /google /jax /blob /main /docs /notebooks /quickstart .ipynb -
TensorFlows XLAː
www .tensorflow .org /xla (Accelerated Linear Algebra) - på YouTube
- Originalpapperː
mlsys .org /Conferences /doc /2018 /146 .pdf