Google JAX

JAX
Utvecklare Google
Förhandsgranska release
v0.3.13 / 16 maj 2022 ; 9 månader sedan ( 2022-05-16 )
Förvar github .com /google /jax
Skrivet i Python , C++
Operativ system Linux , macOS , Windows
Plattform Python , NumPy
Storlek 9,0 MB
Typ Maskininlärning
Licens Apache 2.0
Hemsida jax .readthedocs .io /sv /senaste /  Edit this on Wikidata

Google JAX är ett ramverk för maskininlärning för att transformera numeriska funktioner. Det beskrivs som att sammanföra en modifierad version av autograd (automatisk erhållande av gradientfunktionen genom differentiering av en funktion) och TensorFlows XLA (Accelerated Linear Algebra) . Den är utformad för att följa strukturen och arbetsflödet för NumPy så nära som möjligt och fungerar med olika befintliga ramverk som TensorFlow och PyTorch . De primära funktionerna i JAX är:

  1. grad: automatisk differentiering
  2. jit: sammanställning
  3. vmap: autovektorisering
  4. pmap: SPMD-programmering

grad

Koden nedan visar gradfunktionens automatiska differentiering.


   
   


   
         


  


     
 # importer  från  jax  import  grad  import  jax.numpy  som  jnp  # definierar logistikfunktionen  def  logistic  (  x  ):  returnera  jnp  .  exp  (  x  )  /  (  jnp  .  exp  (  x  )  +  1  )  # få gradientfunktionen för logistikfunktionen  grad_logistic  =  grad  (  logistic  )  # utvärdera gradienten för logistikfunktionen vid x = 1  grad_log_out  =  grad_logistic  (  1.0  )  print  (  grad_log_out  ) 

Den sista raden ska skriva utː

0,19661194

jit

Koden nedan visar jit -funktionens optimering genom fusion.


   
   


 
         


   


  



 # importer  från  jax  import  jit  import  jax.numpy  som  jnp  # definiera kubfunktionen  def  cube  (  x  ):  returnera  x  *  x  *  x  # generera data  x  =  jnp  .  ones  ((  10000  ,  10000  ))  # skapa jit-versionen av kubfunktionen  jit_cube  =  jit  (  cube  )  # tillämpa kub- och jit_cube-funktionerna på samma data för  hastighetsjämförelsekub  (  x  )  jit_cube  (  x  ) 

Beräkningstiden för jit_cube (rad nr 17 ) bör vara märkbart kortare än för kub (rad nr 16 ). Öka värdena på rad nr. 7 , kommer att öka skillnaden.

vmap

Koden nedan visar vmap -funktionens vektorisering.


   
   
   


  
       
      
      
      
         0  0
      # importer  från  functools  import  partiell  från  jax  import  vmap  import  jax.numpy  som  jnp  # definiera funktion  def  grads  (  self  ,  inputs  ) :  in_grad_partial  =  partial  (  self  .  _net_grads  ,  self  .  _net_params  )  grad_vmap  =  jax  .  vmap  (  in_grad_partial  )  rich_grads  =  grad_vmap  (  ingångar  )  flat_grads  =  np  .  asarray  (  self  .  _flatten_batch  (  rich_grads  ))  hävda  flat_grads  .  ndim  ==  2  och  flat_grads  .  form  [  ]  ==  ingångar  .  form  [  ]  returnera  flat_grads 

GIF:en till höger om detta avsnitt illustrerar begreppet vektoriserad addition.

Illustrationsvideo av vektoriserad addition

pmap

Koden nedan visar pmap -funktionens parallellisering för matrismultiplikation.


    
   


  0 
      


     


  
 # importera pmap och slumpmässigt från JAX; import JAX NumPy   från  jax  import  pmap  ,  slumpmässig  import  jax.numpy  som  jnp  # generera 2 slumpmässiga matriser med dimensionerna 5000 x 6000, en per enhet  random_keys  =  random  .  split  (  random  .  PRNGKey  (  ),  2  )  matriser  =  pmap  (  lambda  key  :  random  .  normal  (  key  ,  (  5000  ,  6000  )))(  random_keys  )  # utan dataöverföring, parallellt, utför en lokal matrismultiplikation på varje CPU/ GPU-  utgångar  =  pmap  (  lambda  x  :  jnp  .  dot  (  x  ,  x  .  T  ))(  matriser  )  # utan dataöverföring, parallellt, erhåll medelvärdet för båda matriserna på varje CPU/GPU separat  betyder  =  pmap  (  jnp  .  mean  ) (  utgångar  )  print  (  betyder  ) 

Den sista raden ska skriva ut värdenaː

  [  1,1566595  1,1805978  ] 

Bibliotek som använder Jax

Flera pythonbibliotek använder Jax som backend, inklusive:

Se även

externa länkar