Tensorflow Federated(TFF)框架整理(下)


The methods shown so far expose no backpropagation/optimization step at all; tff.templates.IterativeProcess handles all of that for us, and each call with the current state and the training data simply returns a new state and metrics. To customize the optimization procedure, we need to build our own tff.templates.IterativeProcess: implement its initialize and next methods and define the optimization ourselves.
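To make the initialize/next contract concrete, here is a minimal, TFF-free sketch of the IterativeProcess calling pattern in plain Python. SimpleIterativeProcess and the running-mean "state" are illustrative stand-ins, not TFF code:

```python
# A minimal, TFF-free sketch of the IterativeProcess pattern:
# initialize() builds the state, next() consumes state + data and returns a new state.
class SimpleIterativeProcess:
    def __init__(self, initialize_fn, next_fn):
        self.initialize = initialize_fn
        self.next = next_fn

# Toy example: the "state" tracks a running mean over the batches seen so far.
def initialize_fn():
    return {'round': 0, 'mean': 0.0}

def next_fn(state, batch):
    new_round = state['round'] + 1
    batch_mean = sum(batch) / len(batch)
    # Incremental average across rounds.
    new_mean = state['mean'] + (batch_mean - state['mean']) / new_round
    return {'round': new_round, 'mean': new_mean}

process = SimpleIterativeProcess(initialize_fn, next_fn)
state = process.initialize()
for batch in [[1.0, 3.0], [5.0, 7.0]]:
    state = process.next(state, batch)
print(state)  # {'round': 2, 'mean': 4.0}
```

The real tff.templates.IterativeProcess wraps federated computations instead of plain functions, but the calling convention is the same.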

Data types

Federated Core provides the following kinds of types:

The following types address the distributed-systems aspects of TFF computations:

  • Placement type. Beyond two literals, tff.SERVER and tff.CLIENTS (which you can think of as constants of this type), the type itself is not yet exposed in the public API; it is used internally, but will be introduced in future public API releases. Its compact notation is placement. A placement represents a collective of system participants playing a particular role. The initial release targets client-server computations, with two groups of participants: the clients and the server (the latter can be viewed as a singleton group). In more elaborate architectures, however, there may be other roles, such as intermediate aggregators in a multi-tiered system, which might perform different types of aggregation, or use different types of data compression/decompression, than those used by either the server or the clients. The primary purpose of defining the notion of placement is as a basis for defining federated types.
  • Federated types (tff.FederatedType). A value of a federated type is one hosted by a group of system participants defined by a specific placement (such as tff.SERVER or tff.CLIENTS). A federated type is defined by its placement value (so it is a dependent type), its member constituents (the type of content each participant hosts locally), and an additional bit all_equal specifying whether all participants locally host the same item. For values containing items of type T (member constituents), each hosted by group (placement) G, the compact notation is T@G or {T}@G, with the all_equal bit set or unset, respectively. For example, {int32}@CLIENTS represents a collection of potentially distinct integers, one per client; {<X=float32,Y=float32>*}@CLIENTS represents a federated dataset; and <weights=float32[10,5],bias=float32[5]>@SERVER represents a named tuple of weight and bias tensors on the server, where the omitted curly braces indicate that the all_equal bit is set.
federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)  # '{float32}@CLIENTS'

函数

The language of Federated Core is a form of λ-calculus; it provides the following programming abstractions, currently exposed in the public API:

# tensor computations are not allowed directly inside a tff.federated_computation;
# they should be wrapped in a tff.tf_computation, like this:
import numpy as np

@tff.tf_computation(tff.SequenceType(tf.int32))
def add_up_integers(x):
  return x.reduce(np.int32(0), lambda x, y: x + y)

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)
# '({float32}@CLIENTS -> float32@SERVER)'
# the biggest difference between tff.tf_computation and tff.federated_computation is the placement
@tff.tf_computation(tf.float32)
def add_half(x):
  return tf.add(x, 0.5)

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)
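Conceptually, tff.federated_map applies a tff.tf_computation pointwise to the value held by each client. A TFF-free analogue in plain Python, modeling a {float32}@CLIENTS value as a list with one entry per client (illustrative only):

```python
# Plain-Python analogue of tff.federated_map: apply a function to the
# value held by each client. A {float32}@CLIENTS value is modeled here
# as a list with one entry per client.
def add_half(x):
    return x + 0.5

def federated_map(fn, values_at_clients):
    return [fn(v) for v in values_at_clients]

client_values = [1.0, 2.0, 3.0]
print(federated_map(add_half, client_values))  # [1.5, 2.5, 3.5]
```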

Workflow

A typical FL research codebase consists of three main kinds of logic: data preparation, model preparation, and the federated training loop itself.

# data preparation
import nest_asyncio
nest_asyncio.apply()

import tensorflow as tf
import tensorflow_federated as tff

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]), 
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]
# model preparation
def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=federated_train_data[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Building our own federated learning algorithm involves four main components:

  1. A server-to-clients broadcast step
  2. A local client update step
  3. A client-to-server upload step
  4. A server update step

Meanwhile, we need to implement our own initialize and next functions.
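The four steps above can be sketched, without TFF, in a few lines of plain Python; the one-parameter least-squares model and the learning rate below are illustrative choices, not part of the tutorial's code:

```python
# A TFF-free sketch of one FedAvg round covering the four steps above.
# The 1-D least-squares model and learning rate are illustrative choices.

def client_update(server_w, data, lr=0.1, epochs=5):
    """Steps 1-2: start from the broadcast server weight, train locally."""
    w = server_w
    x, y = data
    for _ in range(epochs):
        grad = 2 * x * (w * x - y)  # d/dw of (w*x - y)^2
        w = w - lr * grad
    return w

def server_update(old_w, mean_client_w):
    """Step 4: replace the server weight with the client mean."""
    return mean_client_w

server_w = 0.0
client_data = [(1.0, 2.0), (1.0, 4.0)]  # each client holds one (x, y) pair

client_ws = [client_update(server_w, d) for d in client_data]  # steps 1-2
mean_w = sum(client_ws) / len(client_ws)                       # step 3: upload + mean
server_w = server_update(server_w, mean_w)                     # step 4
print(server_w)  # moves from 0.0 toward 3.0, the mean-squared-error optimum
```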

Method_1

Local training

Local training requires no TFF at all:

# step 2 local training
# return client model weights
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
    client_weights = model.trainable_variables
    
    # Clone server_weights (this is exactly the `state` from the earlier code).
    tf.nest.map_structure(lambda x, y: x.assign(y), client_weights, server_weights)
	
    # optimization
    for batch in dataset:
        with tf.GradientTape() as tape:
            outputs = model.forward_pass(batch)
    
        grads = tape.gradient(outputs.loss, client_weights)
        grad_and_vars = zip(grads, client_weights)

        client_optimizer.apply_gradients(grad_and_vars)  # update
    
    return client_weights

The input parameters are model, dataset, server_weights, and client_optimizer. Why so many parameters? Because a tf.function carries no information about data placement; everything related to placement is left entirely to TFF.

Server update

Like the client update, the server-side update needs no TFF either:

# step4
@tf.function
def server_update(model, mean_client_weights):
    model_weights = model.trainable_variables
    tf.nest.map_structure(lambda x,y: x.assign(y), model_weights, mean_client_weights)
    return model_weights

TFF snippet

Now TFF is needed to stitch together data with different placements, and to implement the two methods of tff.templates.IterativeProcess.

# initialize method
@tff.tf_computation
def server_init():
    model = model_fn()
    return model.trainable_variables

@tff.federated_computation
def initialize_fn():
    return tff.federated_value(server_init(), tff.SERVER)  # a federated value with the given placement, with the member constituent equal at all locations

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)  # input specification
model_weights_type = server_init.type_signature.result  # output specification

# with multiple typed inputs, pass the types explicitly to the tff.tf_computation decorator
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
    model = model_fn()
    client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
    return client_update(model, tf_dataset, server_weights, client_optimizer)

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)


# rewrite next function
# state is server_weights.
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
    # step1. broadcast
    server_weights_at_client = tff.federated_broadcast(server_weights)
	
    # step2. local update
    client_weights = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_at_client))
	
    # step3. uploading
    mean_client_weights = tff.federated_mean(client_weights)
	
    # step4. server update
    server_weights = tff.federated_map(server_update_fn, mean_client_weights)

    return server_weights

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)

server_state = federated_algorithm.initialize()
evaluate(server_state)
for round_num in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)

Method_2

In the second method, an optimizer from tff.learning.optimizers, which exposes initialize(<TensorSpec>) and next functions, supersedes the Keras optimizer used above.
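To make this contract concrete, here is a TFF-free sketch of a functional SGD-with-momentum optimizer with the same initialize/next shape; build_sgdm below is a hypothetical stand-in, not the real tff.learning.optimizers implementation:

```python
# Sketch of the functional optimizer contract used by tff.learning.optimizers:
# initialize() builds the optimizer state, next() returns (new_state, new_weights)
# without mutating anything. Hypothetical stand-in, not the real TFF code.

def build_sgdm(learning_rate=0.01, momentum=0.0):
    class SGDM:
        def initialize(self, weight_specs):
            # One zero velocity slot per weight (real specs carry shape/dtype;
            # here placeholders suffice).
            return [0.0 for _ in weight_specs]

        def next(self, state, weights, grads):
            new_state = [momentum * v + g for v, g in zip(state, grads)]
            new_weights = [w - learning_rate * v
                           for w, v in zip(weights, new_state)]
            return new_state, new_weights
    return SGDM()

opt = build_sgdm(learning_rate=0.1, momentum=0.9)
state = opt.initialize([None, None])  # two scalar "weights"
weights, grads = [1.0, 2.0], [0.5, 0.5]
state, weights = opt.next(state, weights, grads)
print(weights)  # each weight decreased by learning_rate * grad
```

Because state flows in and out of next explicitly, the same optimizer object can be used inside tf.function-decorated code without holding tf.Variable slots.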

TF snippet

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
    client_weights = model.trainable_weights
    tf.nest.map_structure(lambda x, y: x.assign(y), client_weights, server_weights)
    
    trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
    optimizer_state = client_optimizer.initialize(trainable_tensor_specs)
    
    for batch in iter(dataset):
        with tf.GradientTape() as tape:
            outputs = model.forward_pass(batch)
        grads = tape.gradient(outputs.loss, client_weights)
        optimizer_state, update_weights = client_optimizer.next(
            optimizer_state, client_weights, grads)
        tf.nest.map_structure(lambda a, b: a.assign(b), client_weights, update_weights)
    return tf.nest.map_structure(tf.subtract, client_weights, server_weights)  # return the cumulative gradient

# container, collecting server weights and server optimizer state
import attr

@attr.s(eq=False, frozen=True, slots=True)
class ServerState(object):
    trainable_weights = attr.ib()
    optimizer_state = attr.ib()

@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
    negative_weights_delta = tf.nest.map_structure(
        lambda w: -1.0 * w, mean_model_delta)
    new_optimizer_state, updated_weights = server_optimizer.next(
        server_state.optimizer_state, server_state.trainable_weights, negative_weights_delta)
    return tff.structure.update_struct(
        server_state, 
        trainable_weights = updated_weights, 
        optimizer_state = new_optimizer_state)

TFF snippet

server_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.05, momentum=0.9)
client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)

@tff.tf_computation
def server_init():
    model = model_fn()
    trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), model.trainable_variables)
    optimizer_state = server_optimizer.initialize(trainable_tensor_specs)
    return ServerState(
        trainable_weights=model.trainable_variables,
        optimizer_state=optimizer_state)

@tff.federated_computation
def server_init_tff():
    return tff.federated_value(server_init(), tff.SERVER)

server_state_type = server_init.type_signature.result
trainable_weights_type = server_state_type.trainable_weights

@tff.tf_computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
    return server_update(server_state, model_delta, server_optimizer)

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

@tff.tf_computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
    model = model_fn()
    return client_update(model, dataset, server_weights, client_optimizer)

federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
    server_weights_at_client = tff.federated_broadcast(
      server_state.trainable_weights)
    
    model_deltas = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))
    
    mean_model_delta = tff.federated_mean(model_deltas)
    
    server_state = tff.federated_map(
      server_update_fn, (server_state, mean_model_delta))
    return server_state

fedavg_process = tff.templates.IterativeProcess(
    initialize_fn=server_init_tff, next_fn=run_one_round)

Summary

The process of customizing our own tff.templates.IterativeProcess class:

  1. First, ignoring placement constraints entirely, write the TensorFlow code for the client update and the server update. The client update function usually takes model, dataset, server_weights, and an optimizer, and returns either the cumulative gradients (model delta) or the new client trainable variables. The server update is simpler: it takes the current server state and the newly aggregated changes, and returns the new server state. Depending on your design, the server state may be just the model trainable variables or may contain other items as well. Both functions are decorated with tf.function;
  2. Second, write server_init_fn, client_update_fn, and server_update_fn, all decorated with tff.tf_computation; the decoration indicates that these are local TensorFlow computations whose inputs carry no placement information. server_init_fn outputs a fresh state. client_update_fn takes the dataset and server_weights (note: server_weights is the copy placed at tff.CLIENTS by tff.federated_broadcast) and calls the client update function from step 1. server_update_fn takes server_state and the aggregated changes (note: they are aggregated by tff.federated_mean and placed at tff.SERVER) and calls the server update function from step 1;
  3. Third, create server_init_tff and next_fn, both decorated with tff.federated_computation to resolve the placement issues. server_init_tff places the output of server_init at tff.SERVER via tff.federated_value. next_fn carries out the four steps of the workflow.

Source: https://www.cnblogs.com/DemonHunter/p/15591392.html