#! https://zhuanlan.zhihu.com/p/517570917

8002: Programming languages for building loss functions

Static site: http://catsmile.info/8002-program-tools.html

Goals

  • Survey the JAX-related extension libraries

Python-based frameworks

  • python

    • jax

      • flax

      • haiku

    • tensorflow

      • sonnet

      • keras

    • pytorch

|                    | TensorFlow      | PyTorch                  | JAX                  |
|--------------------|-----------------|--------------------------|----------------------|
| Developed by       | Google          | Facebook                 | Google               |
| Flexible           | No              | Yes                      | Yes                  |
| Graph creation     | Static/Dynamic  | Dynamic                  | Static               |
| Target audience    | Researchers, Developers | Researchers, Developers | Researchers   |
| Low/High-level API | High level      | Both                     | Both                 |
| Development stage  | Mature (v2.4.1) | Mature (v1.8.0)          | Developing (v0.1.55) |

ref:https://www.askpython.com/python-modules/tensorflow-vs-pytorch-vs-jax

python-jax

python3.7 -m pip install dm-haiku jax jaxlib keras tensorflow # tensorflow_datasets

Comments: TBC

Functional programming: the function signature is king.

py-jax-vae, 137 lines: https://github.com/google/jax/blob/main/examples/mnist_vae.py
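
A minimal sketch of this pure-functional style (a toy example assumed here, not taken from the linked file): parameters enter through the function signature, and jax.grad returns a function with the same signature whose output is a matching pytree of gradients.

import jax
import jax.numpy as jnp

def mse_loss(params, x, y):
  # parameters are passed explicitly; nothing is hidden in object state
  pred = x @ params['w'] + params['b']
  return jnp.mean((pred - y) ** 2)

params = {'w': jnp.ones((3, 1)), 'b': jnp.zeros((1,))}
x, y = jnp.ones((8, 3)), jnp.zeros((8, 1))
grads = jax.grad(mse_loss)(params, x, y)   # same signature as mse_loss, gradients w.r.t. params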

py-jax-haiku

Comments: more function-oriented, carrying on the design spirit of JAX; more fine-grained and closer to the low level.

Core component

haiku.transform(f, *, apply_rng=True) -> haiku.Transformed

def transform(f, *, apply_rng=True) -> Transformed:
  """Transforms a function using Haiku modules into a pair of pure functions.
  For a function ``out = f(*a, **k)`` this function returns a pair of two pure
  functions that call ``f(*a, **k)`` explicitly collecting and injecting
  parameter values::
      params = init(rng, *a, **k)
      out = apply(params, rng, *a, **k)
      """
class Transformed(NamedTuple):
  """Holds a pair of pure functions.
  Attributes:
    init: A pure function: ``params = init(rng, *a, **k)``
    apply: A pure function: ``out = apply(params, rng, *a, **k)``
  """

  # Args: [Optional[PRNGKey], ...]
  init: Callable[..., hk.Params]

  # Args: [Params, Optional[PRNGKey], ...]
  apply: Callable[..., Any]

Since a JAX function is described entirely by its sub-functions, a JAX template module only needs to specify the form of those sub-functions; it does not need to specify meta attributes such as device placement.
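
A minimal usage sketch of the resulting init/apply pair (hypothetical example using only hk.Linear): params comes back as a plain nested dict of arrays, with no device or other meta attributes attached.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  return hk.Linear(4)(x)            # the sub-function, built from Haiku modules

forward_t = hk.transform(forward)   # -> Transformed(init, apply)
rng = jax.random.PRNGKey(0)
x = jnp.ones((2, 3))
params = forward_t.init(rng, x)     # e.g. {'linear': {'w': ..., 'b': ...}}
out = forward_t.apply(params, rng, x)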

Example

https://github.com/deepmind/dm-haiku#quickstart

Parameter initialization: model.init

import os,shutil,sys
import haiku as hk
import jax.numpy as jnp
import jax
# import tensorflow_datasets as tfds
import numpy as np
from pprint import pprint

# Load MNIST via keras, cache it next to this file as a .npy, and yield endless mini-batches.
def load_dataset(
    split: str,
    is_training: bool,
    batch_size: int,
):
  PKL = __file__+'.npy'
  if not os.path.exists(PKL):
      from keras.datasets import mnist
      v = mnist.load_data()
      np.save(PKL+'.temp.npy',v)
      shutil.move(PKL+'.temp.npy',PKL)
  else:
      v = np.load(PKL,allow_pickle=True)

  (x_train, y_train), (x_test, y_test) = v
  x   = x_train
  y   = y_train
  x   = x.reshape((-1,28**2))
  y   = y[:,None]
  if is_training:
      idx = np.random.permutation(range(len(x)))
      x   = x[idx]
      y   = y[idx]
  x   = jnp.array(x,dtype='float')
  y   = jnp.array(y,dtype='int8')

  def giter(it=(x,y),batch_size=batch_size):
      x,y = it
      L = len(x)
      i=-1
      while True:
          i+=1
          i = i %(L//batch_size)
          idx = slice(i*batch_size,(i+1)*batch_size)
          tup = x[idx],y[idx]
          yield tup
  return giter()


def softmax_cross_entropy(logits, labels):
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)

def loss_fn(images, labels):
  mlp = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  logits = mlp(images)
  return jnp.mean(softmax_cross_entropy(logits, labels))

#### core routine for training
loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

input_dataset = load_dataset("train", is_training=True, batch_size=100)
rng = jax.random.PRNGKey(42)
dummy_images, dummy_labels = next(input_dataset)
params = loss_fn_t.init(rng, dummy_images, dummy_labels)
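# (Illustration added here, not part of the original script: `params` is a plain
# nested dict of arrays; uncomment to print the shape of every leaf.)
# pprint(jax.tree_util.tree_map(lambda p: p.shape, params))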

def update_rule(param, update):
  # plain SGD with a fixed learning rate of 1e-3
  return param - 0.001 * update

max_iter = 100
print_interval = 10
i = -1
for images, labels in input_dataset:
  i+= 1
  grads = jax.grad(loss_fn_t.apply)(params, images, labels)
  params = jax.tree_map(update_rule, params, grads)
  loss = loss_fn_t.apply(params,images,labels)
  if i%print_interval==0:
      print(f'[B{i}]loss={loss:.3f}')
  if i>= max_iter:
      break
! python3.7 catsmile/c8002_haiku_example.py
/home/ubuntu/.local/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py:1806: UserWarning: Explicitly requested dtype float requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  lax_internal._check_user_dtype_supported(dtype, "array")
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[B0]loss=29.527
[B10]loss=5.281
[B20]loss=3.177
[B30]loss=2.624
[B40]loss=2.564
[B50]loss=2.472
[B60]loss=2.378
[B70]loss=2.402
[B80]loss=2.389
[B90]loss=2.347
[B100]loss=2.346
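
As a possible refinement (a sketch that assumes the script above has already defined loss_fn_t and params; the name train_step is made up here), the gradient computation and the update can be fused with jax.value_and_grad and compiled with jax.jit, so the loss is not recomputed in a separate apply call:

@jax.jit
def train_step(params, images, labels, lr=0.001):
  # one fused pass: loss value and gradients, then a plain SGD update
  loss, grads = jax.value_and_grad(loss_fn_t.apply)(params, images, labels)
  new_params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
  return new_params, loss

# usage inside the loop: params, loss = train_step(params, images, labels)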

py-jax-haiku-vae, 211 lines: https://github.com/deepmind/dm-haiku/blob/main/examples/vae.py

py-jax-flax

Comments: more object-oriented, providing the flexibility suited to large projects; it takes care of interface-level chores such as devices and flexible state.

Core component: flax.linen.Module

@dataclass_transform()
class Module:
  """Base class for all neural network modules. Layers and models should subclass this class.

  All Flax Modules are Python 3.7
  `dataclasses <https://docs.python.org/3/library/dataclasses.html>`_. Since
  dataclasses take over ``__init__``, you should instead override :meth:`setup`,
  which is automatically called to initialize the module.

  Modules can contain submodules, and in this way can be nested in a tree
  structure. Submodels can be assigned as regular attributes inside the
  :meth:`setup` method.

  You can define arbitrary "forward pass" methods on your Module subclass.
  While no methods are special-cased, ``__call__`` is a popular choice because
  it allows you to use module instances as if they are functions::

    from flax import linen as nn

    class Module(nn.Module):
      features: Tuple[int] = (16, 4)

      def setup(self):
        self.dense1 = Dense(self.features[0])
        self.dense2 = Dense(self.features[1])

      def __call__(self, x):
        return self.dense2(nn.relu(self.dense1(x)))

  Optionally, for more concise module implementations where submodules
  definitions are co-located with their usage, you can use the
  :meth:`compact` wrapper.
  """

Example: py-jax-flax-vae, 211 lines: https://github.com/google/flax/blob/main/examples/vae/train.py
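
A minimal sketch of the linen workflow (hypothetical example, not taken from the linked train.py): the module is a dataclass, and the parameters live in a separate variables dict returned by init.

import jax
import jax.numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
  features: int = 4

  @nn.compact
  def __call__(self, x):
    # submodules are declared inline thanks to the compact wrapper
    x = nn.relu(nn.Dense(self.features)(x))
    return nn.Dense(1)(x)

model = MLP()
x = jnp.ones((2, 3))
variables = model.init(jax.random.PRNGKey(0), x)   # {'params': {'Dense_0': ..., 'Dense_1': ...}}
out = model.apply(variables, x)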

Ref