Learning Flax 01: Basic Usage
Install jax and jaxlib
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
Install flax
pip install flax
pip install --upgrade git+https://github.com/google/flax.git  # this one did not work for me, though
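To quickly confirm the install worked (a check I'd add myself; the exact output depends on your machine and versions):
python -c "import jax, flax; print(jax.__version__, flax.__version__); print(jax.devices())"  # lists GPU devices if the CUDA wheel is active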
Documentation
The docs live at https://flax.readthedocs.io/en/latest/index.html
Flax model parameters and initialization: let's look at the code of two models.
import math

import jax.numpy as jnp
import flax.linen as nn


class TokenLearnerModule(nn.Module):
  """TokenLearner module.

  This is the module used for the experiments in the paper.

  Attributes:
    num_tokens: Number of tokens.
  """
  num_tokens: int
  use_sum_pooling: bool = True

  @nn.compact
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Applies learnable tokenization to the 2D inputs.

    Args:
      inputs: Inputs of shape `[bs, h, w, c]` or `[bs, hw, c]`.

    Returns:
      Output of shape `[bs, n_token, c]`.
    """
    if inputs.ndim == 3:
      n, hw, c = inputs.shape
      h = int(math.sqrt(hw))
      inputs = jnp.reshape(inputs, [n, h, h, c])  # Bring the input back to [bs, h, w, c].

      if h * h != hw:
        raise ValueError('Only square inputs supported.')

    feature_shape = inputs.shape

    selected = inputs
    selected = nn.LayerNorm()(selected)

    for _ in range(3):  # Forward pass of the token-selection branch.
      selected = nn.Conv(
          self.num_tokens,
          kernel_size=(3, 3),
          strides=(1, 1),
          padding='SAME',
          use_bias=False)(selected)  # Shape: [bs, h, w, n_token].
      selected = nn.gelu(selected)

    selected = nn.Conv(
        self.num_tokens,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding='SAME',
        use_bias=False)(selected)  # Shape: [bs, h, w, n_token].

    selected = jnp.reshape(
        selected,
        [feature_shape[0], feature_shape[1] * feature_shape[2], -1])  # Shape: [bs, h*w, n_token].
    selected = jnp.transpose(selected, [0, 2, 1])  # Shape: [bs, n_token, h*w].
    selected = nn.sigmoid(selected)[..., None]  # Shape: [bs, n_token, h*w, 1].

    feat = inputs
    feat = jnp.reshape(
        feat,
        [feature_shape[0], feature_shape[1] * feature_shape[2], -1])[:, None, ...]  # Shape: [bs, 1, h*w, c].

    if self.use_sum_pooling:
      inputs = jnp.sum(feat * selected, axis=2)
    else:
      inputs = jnp.mean(feat * selected, axis=2)

    return inputs
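A minimal usage sketch (the batch size, spatial size, and channel count below are my own illustrative choices, not values from the original code):
import jax

tokenizer = TokenLearnerModule(num_tokens=8)
dummy = jnp.ones((2, 14, 14, 64))                         # [bs, h, w, c], arbitrary sizes
variables = tokenizer.init(jax.random.PRNGKey(0), dummy)  # initialize the conv / layernorm params
tokens = tokenizer.apply(variables, dummy)
print(tokens.shape)                                       # (2, 8, 64) -> [bs, n_token, c]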
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x


model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)
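To see what init actually produced, it helps to look at the output shape and the parameter tree (the shape comments are what this particular MLP([12, 8, 4]) yields on a (32, 10) batch):
print(output.shape)                                # (32, 4)
print(jax.tree_map(lambda p: p.shape, variables))
# Dense_0: kernel (10, 12), bias (12,); Dense_1: kernel (12, 8), bias (8,);
# Dense_2: kernel (8, 4), bias (4,)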
This example involves a few JAX concepts, and since I am not that familiar with JAX either, I went back to the JAX documentation: https://jax.readthedocs.io/en/latest/index.html
In the random section you find PRNGKey, short for "pseudorandom number generator key". You can think of it as producing a pair of numbers: key = random.PRNGKey(0) returns [0, 0]. This is a random key. Unlike the usual random APIs that simply hand you a number, every random function in JAX takes such a key; once you have one, random.uniform(key) gives you a sample from the uniform distribution.
Every random draw needs a key like this, but you don't have to keep calling random.PRNGKey. Instead, jax.random.split(key, num=2) splits the key into subkeys, each of which is used just like the original; the number of subkeys is given by the num argument, and the result unpacks like a tuple: k1, k2, k3 = jax.random.split(key, num=3).
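A short sketch of the key workflow described above:
from jax import random

key = random.PRNGKey(0)                  # a key is a pair of uint32s, e.g. [0 0]
u = random.uniform(key)                  # uniform sample drawn from this key
k1, k2, k3 = random.split(key, num=3)    # derive three independent subkeys
z = random.normal(k2, (4,))              # each subkey works like a fresh key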
Parameters
Parameters have to be initialized explicitly. Coming from PyTorch, where you define the layers in __init__ and then write the forward pass, this takes some getting used to. The docs state it plainly: "Parameters are not stored with the models themselves. You need to initialize parameters by calling the init function, using a PRNGKey and a dummy input parameter."
The shapes of the parameter matrices are inferred by the model itself; you don't compute them by hand. You just provide a dummy input, and the model works out the shape of every weight matrix.
from jax import random  # jax and flax.linen as nn are already imported above

# The regression example below assumes a single dense layer, which matches the
# {'kernel', 'bias'} parameter pytree used there.
model = nn.Dense(features=5)
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,))           # Dummy input.
params = model.init(key2, x)             # Initialization call: parameter shapes are inferred here.
jax.tree_map(lambda x: x.shape, params)  # Check the shapes; works much like Python's built-in map.
model.init_with_output
works the same way, but it also returns the output computed on the dummy input together with the initialized parameters.
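A small sketch, reusing the model and dummy input x from above:
out, init_vars = model.init_with_output(key2, x)  # (output, variables) in one call
print(out.shape)                                  # same as model.apply(init_vars, x).shape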
Forward pass
The forward pass also differs a lot from torch: model.apply(params, x)
is the forward-pass call.
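For example, with the params and x initialized above (the shape comment assumes the nn.Dense(features=5) model):
y = model.apply(params, x)
print(y.shape)   # (5,)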
Backward pass
Given samples \(\{(x_i,y_i),\ i\in \{1,\ldots, k\},\ x_i\in\mathbb{R}^n,\ y_i\in\mathbb{R}^m\}\), the goal is to find the optimal parameters \(W\in \mathcal{M}_{m,n}(\mathbb{R})\) and \(b\in\mathbb{R}^m\) that minimize the least-squares loss of the model output.
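Written out, the objective is the average of the per-sample half squared errors, which is exactly what the mse function further down computes:
\[
\min_{W,b}\ \frac{1}{k}\sum_{i=1}^{k}\frac{1}{2}\,\lVert W x_i + b - y_i\rVert_2^2
\]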
Prepare the data
from flax.core import freeze  # needed for the frozen parameter pytree below

# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5
# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a pytree.
true_params = freeze({'params': {'bias': b, 'kernel': W}})
# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
Forward pass using model.apply
# Same as JAX version but using model.apply().
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y - pred, y - pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
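jax.vmap is what turns the single-pair squared_error into a batched function. As a sanity check on what it does, here is an equivalent (but slower) explicit-loop version, written purely for illustration:
def mse_loop(params, x_batched, y_batched):
  # Unvectorized version of mse: loop over the samples one by one.
  losses = []
  for x, y in zip(x_batched, y_batched):
    pred = model.apply(params, x)
    losses.append(jnp.inner(y - pred, y - pred) / 2.0)
  return jnp.mean(jnp.array(losses))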
Gradient descent
learning_rate = 0.3 # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)
@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)
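After training, a rough check I'd add (not in the original post) is to compare the learned kernel and bias against the ground-truth W and b:
learned = params['params']
print(jnp.max(jnp.abs(learned['kernel'] - W)))  # should be small after 100 steps
print(jnp.max(jnp.abs(learned['bias'] - b)))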