#! https://zhuanlan.zhihu.com/p/517570917
8002: Programming Languages for Building Loss Functions
Static site: http://catsmile.info/8002-program-tools.html
Goal
Survey the extension libraries related to jax.
Python-based frameworks
python
jax
flax
haiku
tensorflow
sonnet
pytorch
keras
|                    | Tensorflow              | PyTorch                 | Jax                  |
|--------------------|-------------------------|-------------------------|----------------------|
| Developed by       | Google                  | Facebook                | Google               |
| Flexible           | No                      | Yes                     | Yes                  |
| Graph-Creation     | Static/Dynamic          | Dynamic                 | Static               |
| Target Audience    | Researchers, Developers | Researchers, Developers | Researchers          |
| Low/High-level API | High Level              | Both                    | Both                 |
| Development Stage  | Mature (v2.4.1)         | Mature (v1.8.0)         | Developing (v0.1.55) |

Ref: https://www.askpython.com/python-modules/tensorflow-vs-pytorch-vs-jax
python-jax
```bash
python3.7 -m pip install dm-haiku jax jaxlib keras tensorflow  # tensorflow_datasets
```
Comments: TBC
Functional programming: the function signature is king.
py-jax-vae 137 lines https://github.com/google/jax/blob/main/examples/mnist_vae.py
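To make "the signature is king" concrete, here is a minimal sketch in plain jax (independent of the mnist_vae example above); the toy parameter pytree and data are made up for illustration. All state enters the loss through its arguments, and `jax.grad` differentiates with respect to the first argument.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Pure function: all state (the weights) enters through the signature.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
x = jnp.ones((8, 3))
y = jnp.ones((8, 1))

# jax.grad differentiates w.r.t. the first argument and returns a
# gradient pytree with the same structure as `params`.
grads = jax.grad(loss_fn)(params, x, y)
print(jax.tree_util.tree_map(jnp.shape, grads))
```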
py-jax-haiku
Comments: more function-oriented, carrying on the design spirit of jax; more fine-grained and closer to the underlying primitives.
Core component: `haiku.transform(f, *, apply_rng=True) -> haiku.Transformed` (docstring excerpted below):
```python
def transform(f, *, apply_rng=True) -> Transformed:
  """Transforms a function using Haiku modules into a pair of pure functions.

  For a function ``out = f(*a, **k)`` this function returns a pair of two pure
  functions that call ``f(*a, **k)`` explicitly collecting and injecting
  parameter values::

      params = init(rng, *a, **k)
      out = apply(params, rng, *a, **k)
  """


class Transformed(NamedTuple):
  """Holds a pair of pure functions.

  Attributes:
    init: A pure function: ``params = init(rng, *a, **k)``
    apply: A pure function: ``out = apply(params, rng, *a, **k)``
  """

  # Args: [Optional[PRNGKey], ...]
  init: Callable[..., hk.Params]

  # Args: [Params, Optional[PRNGKey], ...]
  apply: Callable[..., Any]
```
Since a jax function is described entirely by its sub-functions, a jax template module should only specify the form of those sub-functions, without prescribing meta-attributes such as the device.
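A minimal sketch of the resulting init/apply pair, assuming only the documented `hk.transform` and `hk.Linear` API (the forward function and shapes here are made up for illustration):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Haiku modules may only be instantiated inside a transformed function.
    return hk.Linear(10)(x)

forward_t = hk.transform(forward)  # -> Transformed(init, apply)

rng = jax.random.PRNGKey(0)
x = jnp.ones((4, 28 * 28))

params = forward_t.init(rng, x)        # pure: collects parameter values
out = forward_t.apply(params, rng, x)  # pure: injects parameter values
print(out.shape)                       # (4, 10)
```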
Example
https://github.com/deepmind/dm-haiku#quickstart
Parameter initialization with `model.init`:
```python
import os, shutil, sys
import haiku as hk
import jax.numpy as jnp
import jax
# import tensorflow_datasets as tfds
import numpy as np
from pprint import pprint


def load_dataset(
    split: str,
    is_training: bool,
    batch_size: int,
):
    # Load MNIST via keras and cache the raw arrays next to this script.
    PKL = __file__ + '.npy'
    if not os.path.exists(PKL):
        from keras.datasets import mnist
        v = mnist.load_data()
        np.save(PKL + '.temp.npy', v)
        shutil.move(PKL + '.temp.npy', PKL)
    else:
        v = np.load(PKL, allow_pickle=True)
    (x_train, y_train), (x_test, y_test) = v
    x = x_train
    y = y_train
    x = x.reshape((-1, 28 ** 2))
    y = y[:, None]  # labels kept with shape (N, 1)
    if is_training:
        idx = np.random.permutation(range(len(x)))
        x = x[idx]
        y = y[idx]
    x = jnp.array(x, dtype='float')
    y = jnp.array(y, dtype='int8')

    def giter(it=(x, y), batch_size=batch_size):
        # Infinite generator cycling over the data in fixed-size batches.
        x, y = it
        L = len(x)
        i = -1
        while True:
            i += 1
            i = i % (L // batch_size)
            idx = slice(i * batch_size, (i + 1) * batch_size)
            tup = x[idx], y[idx]
            yield tup
    return giter()


def softmax_cross_entropy(logits, labels):
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)


def loss_fn(images, labels):
    # Modules are instantiated inside the (impure) loss function;
    # hk.transform below turns it into a pure (init, apply) pair.
    mlp = hk.Sequential([
        hk.Linear(300), jax.nn.relu,
        hk.Linear(100), jax.nn.relu,
        hk.Linear(10),
    ])
    logits = mlp(images)
    return jnp.mean(softmax_cross_entropy(logits, labels))


#### core routine for training
loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

input_dataset = load_dataset("train", is_training=True, batch_size=100)
rng = jax.random.PRNGKey(42)
dummy_images, dummy_labels = next(input_dataset)
# Parameter initialization: collect the params pytree from a dummy batch.
params = loss_fn_t.init(rng, dummy_images, dummy_labels)


def update_rule(param, update):
    # Plain SGD step, applied leaf-wise over the parameter pytree.
    return param - 0.001 * update


max_iter = 100
print_interval = 10
i = -1
for images, labels in input_dataset:
    i += 1
    grads = jax.grad(loss_fn_t.apply)(params, images, labels)
    params = jax.tree_map(update_rule, params, grads)
    loss = loss_fn_t.apply(params, images, labels)
    if i % print_interval == 0:
        print(f'[B{i}]loss={loss:.3f}')
    if i >= max_iter:
        break
```
```
! python3.7 catsmile/c8002_haiku_example.py
/home/ubuntu/.local/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py:1806: UserWarning: Explicitly requested dtype float requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  lax_internal._check_user_dtype_supported(dtype, "array")
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[B0]loss=52.627
[B10]loss=4.963
[B20]loss=3.043
[B30]loss=2.578
[B40]loss=2.498
[B50]loss=2.518
[B60]loss=2.454
[B70]loss=2.335
[B80]loss=2.359
[B90]loss=2.335
[B100]loss=2.341
```
py-jax-haiku-vae, 211 lines https://github.com/deepmind/dm-haiku/blob/main/examples/vae.py
py-jax-flax
Comments: more object-oriented, providing the flexibility suited to large engineering projects; it takes care of interface-level chores such as devices and flexible state.
Core component: `flax.linen.Module` (docstring excerpted below):
```python
@dataclass_transform()
class Module:
  """Base class for all neural network modules. Layers and models should subclass this class.

  All Flax Modules are Python 3.7
  `dataclasses <https://docs.python.org/3/library/dataclasses.html>`_. Since
  dataclasses take over ``__init__``, you should instead override :meth:`setup`,
  which is automatically called to initialize the module.

  Modules can contain submodules, and in this way can be nested in a tree
  structure. Submodules can be assigned as regular attributes inside the
  :meth:`setup` method.

  You can define arbitrary "forward pass" methods on your Module subclass.
  While no methods are special-cased, ``__call__`` is a popular choice because
  it allows you to use module instances as if they are functions::

    from flax import linen as nn

    class Module(nn.Module):
      features: Tuple[int] = (16, 4)

      def setup(self):
        self.dense1 = Dense(self.features[0])
        self.dense2 = Dense(self.features[1])

      def __call__(self, x):
        return self.dense2(nn.relu(self.dense1(x)))

  Optionally, for more concise module implementations where submodules
  definitions are co-located with their usage, you can use the
  :meth:`compact` wrapper.
  """
```
Example: py-jax-flax-vae, 211 lines https://github.com/google/flax/blob/main/examples/vae/train.py