
Learning Flax 01: Basic Usage


Installing jax and jaxlib

pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Installing flax

pip install flax
pip install --upgrade git+https://github.com/google/flax.git  # install from GitHub; this one did not work for me
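
A quick sanity check after installing (a minimal sketch; what jax.devices() prints depends on your machine and which wheel was installed):

import jax
import flax
print(jax.__version__, flax.__version__)
print(jax.devices())  # should list GPU devices if the CUDA wheel was picked up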

文档

The documentation lives at https://flax.readthedocs.io/en/latest/index.html
For Flax model parameters and initialization, look at the code of the two models below.

import math

import jax.numpy as jnp
import flax.linen as nn


class TokenLearnerModule(nn.Module):
  """TokenLearner module.

  This is the module used for the experiments in the paper.

  Attributes:
    num_tokens: Number of tokens.
  """
  num_tokens: int
  use_sum_pooling: bool = True

  @nn.compact
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Applies learnable tokenization to the 2D inputs.

    Args:
      inputs: Inputs of shape `[bs, h, w, c]` or `[bs, hw, c]`.

    Returns:
      Output of shape `[bs, n_token, c]`.
    """
    if inputs.ndim == 3:
      n, hw, c = inputs.shape
      h = int(math.sqrt(hw))
      if h * h != hw:
        raise ValueError('Only square inputs supported.')
      # Reshape the flattened [bs, hw, c] tokens back into a square [bs, h, h, c] map.
      inputs = jnp.reshape(inputs, [n, h, h, c])

    feature_shape = inputs.shape

    selected = inputs
    selected = nn.LayerNorm()(selected)

    for _ in range(3):  # the forward computation: three conv + gelu layers
      selected = nn.Conv(
          self.num_tokens,
          kernel_size=(3, 3),
          strides=(1, 1),
          padding='SAME',
          use_bias=False)(selected)  # Shape: [bs, h, w, n_token].

      selected = nn.gelu(selected)

    selected = nn.Conv(
        self.num_tokens,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding='SAME',
        use_bias=False)(selected)  # Shape: [bs, h, w, n_token].

    selected = jnp.reshape(
        selected, [feature_shape[0], feature_shape[1] * feature_shape[2], -1
                  ])  # Shape: [bs, h*w, n_token].
    selected = jnp.transpose(selected, [0, 2, 1])  # Shape: [bs, n_token, h*w].
    selected = nn.sigmoid(selected)[..., None]  # Shape: [bs, n_token, h*w, 1].

    feat = inputs
    feat = jnp.reshape(
        feat, [feature_shape[0], feature_shape[1] * feature_shape[2], -1
              ])[:, None, ...]  # Shape: [bs, 1, h*w, c].

    if self.use_sum_pooling:
      inputs = jnp.sum(feat * selected, axis=2)
    else:
      inputs = jnp.mean(feat * selected, axis=2)

    return inputs
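
The snippet above never instantiates the module, so here is a minimal usage sketch; the batch size, spatial size and channel count are assumptions chosen purely for illustration:

import jax

tokenlearner = TokenLearnerModule(num_tokens=8)
dummy = jnp.ones((2, 16, 16, 32))                            # [bs, h, w, c]
variables = tokenlearner.init(jax.random.PRNGKey(0), dummy)  # initialize the conv/LayerNorm params
tokens = tokenlearner.apply(variables, dummy)                # shape (2, 8, 32), i.e. [bs, n_token, c]

The second model is a small MLP taken from the Flax docs:
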
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)

This snippet involves a few JAX concepts, and I am not that familiar with JAX either, so I went back to the JAX documentation: https://jax.readthedocs.io/en/latest/index.html
In the random module, PRNGKey is short for pseudorandom number generator key. You can think of it as producing two random numbers: it returns an array of two values at once, e.g. key = random.PRNGKey(0) returns [0, 0]. This is a random key. JAX apparently has no ordinary random functions that just hand you a number directly; every random draw in JAX requires such a key, and the key is produced by this function. With a key in hand you can call random.uniform(key) to get a number drawn from a uniform distribution.
Every random draw needs a key like this, but you do not have to call random.PRNGKey over and over. jax.random.split(key, num=2) splits the key into more subkeys, each of which can be used just like the original; the number of subkeys you need is given by the num argument, and the result unpacks like a tuple: k1, k2, k3 = jax.random.split(key, num=3).
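
A minimal sketch of how these pieces fit together (the concrete numbers drawn depend on your JAX version and backend):

from jax import random

key = random.PRNGKey(0)                  # a key: an array of two uint32 values
u = random.uniform(key)                  # one draw from the uniform distribution on [0, 1)
k1, k2, k3 = random.split(key, num=3)    # three independent subkeys
z = random.normal(k2, (4,))              # each subkey is used just like the original key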

Parameters

Parameters have to be initialized explicitly. For someone used to PyTorch, where you define the layers in __init__ and then write the forward pass, this is confusing at first. The documentation spells it out: Parameters are not stored with the models themselves. You need to initialize parameters by calling the init function, using a PRNGKey and a dummy input parameter.
The concrete shapes of the parameter matrices are inferred by the model automatically; you do not compute them yourself. You provide a dummy input and the model works out the shape of every weight matrix.

from jax import random

key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,))  # Dummy input, used only so the shapes can be inferred
params = model.init(key2, x)  # Initialization call: parameter shapes are computed automatically
jax.tree_map(lambda x: x.shape, params)  # Check the shapes; like Python's built-in map, but over the parameter pytree
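
If model is still the MLP([12, 8, 4]) defined earlier, the checked shapes come out roughly as follows (Flax auto-names the submodules Dense_0, Dense_1, ...):

FrozenDict({
    'params': {
        'Dense_0': {'bias': (12,), 'kernel': (10, 12)},
        'Dense_1': {'bias': (8,), 'kernel': (12, 8)},
        'Dense_2': {'bias': (4,), 'kernel': (8, 4)},
    },
})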

model.init_with_output does the same initialization but additionally returns the model output for the dummy input, together with the initialized variables.
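
A minimal sketch, reusing key2 and x from the snippet above:

y, params = model.init_with_output(key2, x)  # returns (output, initialized variables) in one call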

Forward pass

The forward pass is also quite different from torch: model.apply(params, x) is Flax's forward-pass call.
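
For example, reusing params and x from the initialization snippet above (a minimal sketch):

y = model.apply(params, x)           # forward pass: parameters in, prediction out
fast_apply = jax.jit(model.apply)    # apply is a pure function of (params, x), so it can be jitted
y = fast_apply(params, x)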

Backward pass

For samples \(\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n, y_i\in\mathbb{R}^m\}\), the goal is to find the optimal parameters \(W\in \mathcal{M}_{m,n}(\mathbb{R})\) and \(b\in\mathbb{R}^m\) that minimize the least-squares loss of the model output.
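
Written out (matching the mse function defined further down), the objective is

\[
\min_{W,\,b}\ \frac{1}{k}\sum_{i=1}^{k}\frac{1}{2}\,\lVert y_i - (W x_i + b)\rVert_2^2 .
\]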

Preparing the data

from flax.core import freeze

# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a pytree.
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
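
One thing the post leaves implicit: the loss and training code below follow the linear-regression example from the Flax docs, which uses a single dense layer rather than the MLP defined earlier (otherwise the output dimension would not match y_dim). A minimal sketch of the assumed model and its initialization (the PRNGKey value here is an arbitrary choice):

model = nn.Dense(features=y_dim)                   # a single linear layer: y = x @ W + b
params = model.init(random.PRNGKey(1), x_samples)  # kernel of shape (x_dim, y_dim), bias of shape (y_dim,)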

Using the forward pass

# Same as JAX version but using model.apply().
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y-pred, y-pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

Gradient descent

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)
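
As a final sanity check (not in the original post), the loss for the learned parameters should end up close to the loss printed above for the true W and b:

print('Loss for learned W,b: ', mse(params, x_samples, y_samples))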

Source: https://www.cnblogs.com/honosayaka/p/16309418.html