TFF 全称 tensorflow_federated,为谷歌的联邦学习框架。
在TFF官网的Building Your Own Federated Learning Algorithm界面中,介绍了如何尽可能多的利用现有的TensorFlow代码,构建一个TFF的模型。
本文的行文结构
本文共分为3个章节,其中第1章介绍了TFF的框架,然后给出了客户端和服务器的模型参数更新函数。第2章到主要介绍Federated core的内容。第3章主要把前2章的内容串起来,构建自己的TFF框架。
1.TFF框架的构成
1.1 TFF框架可以分成4个步骤:
- 服务器向客户端(server-to-client)传递初始模型参数
- 客户端更新模型参数
- 客户端向服务器(client-to-server)传递参数
- 服务器更新参数
1.2 这4个步骤又可以根据是否使用纯TensorFlow的代码分为两类:
第一类:全部使用TensorFlow代码构建
包括第2和第4步,客户端更新模型参数,服务器更新参数。
第二类:使用Federated Core代码构建
包括第1和第3步,server-to-client和client-to-server
下面就详细介绍纯TensorFlow环节的要点。需要Federated Core构建的放在后面第5部分。
1.2.1 客户端更新参数
可以分成两步:
(1)从服务器模型获取客户端模型参数,此处的服务器模型是经过第一步传递过来的模型
(2)客户端模型在客户端数据集上训练和更新参数
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
"""Performs training (using the server model weights) on the client's dataset."""
# Initialize the client model with the current server weights.
client_weights = model.trainable_variables
# Assign the server weights to the client model.
tf.nest.map_structure(lambda x, y: x.assign(y),
client_weights, server_weights)
# Use the client_optimizer to update the local model.
for batch in dataset:
with tf.GradientTape() as tape:
# Compute a forward pass on the batch of data
outputs = model.forward_pass(batch)
# Compute the corresponding gradient
grads = tape.gradient(outputs.loss, client_weights)
grads_and_vars = zip(grads, client_weights)
# Apply the gradient using a client optimizer.
client_optimizer.apply_gradients(grads_and_vars)
return client_weights
1.2.2 服务器更新参数
在服务器上更新参数主要涉及更新策略。此处采用了较为简单的”vanilla” 联合平均算法,直接取各个客户端模型参数的平均值作为服务器的模型参数。这里的模型参数只包括可训练的参数。
@tf.function
def server_update(model, mean_client_weights):
"""Updates the server model weights as the average of the client model weights."""
model_weights = model.trainable_variables
# Assign the mean client weights to the server model.
tf.nest.map_structure(lambda x, y: x.assign(y),
model_weights, mean_client_weights)
return model_weights
2.关于Federated Core (FC)
FC包含底层和顶层两个维度的API。具体而言,FC是服务于tff.learning
API的底层的接口。然而,FC又是一个顶层的开发环境,它提供了一种更加紧凑的程序逻辑,把TensorFlow的代码和分布式通信操作(包括分布式求和和广播)结合起来。
FC的目标是允许开发者明确地控制系统中的分布式通信(例如点对点的网络消息交换),而不需要了解实施的细节。
TFF的设计之初就是为了数据的隐秘性,所以FC允许用户明确的控制数据应该在哪一个层面,防止数据泄露。
2.1 联邦数据(Federated data)
TFF中的一个关键概念是“联合数据”(Federated data),它是指分布式系统中一组设备上托管的数据项的集合(例如,客户端数据集或服务器模型权重)。跨所有设备的整个值集合表示为单个联合值。
例如,假设存在客户端设备,每个客户端设备都有一个表示传感器温度的浮点。这些浮点可以通过下面的式子表示为联合浮点:
federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)
联合类型由其成员成分的变量类型“T”(例如“tf.float32”)和设备组“G”指定。通常,“G”是tff.CLIENTS
或tff.SERVER
。这种联合类型表示为{T}@G’,如下所示:
str(federated_float_on_clients)
# '{float32}@CLIENTS'
2.2 联邦计算(Federated computations)
TFF接受联合值作为输入,并且把联合值作为输出。例如,假设您想要平均客户端传感器上的温度。您可以定义以下内容(使用我们的联合浮点):
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
return tff.federated_mean(client_temperatures)
对于联邦计算,作者用了一句话来定义:
It is a specification of a distributed system in an internal platform-independent glue language.
它是一个分布式系统的规范,采用独立于内部平台的“粘合语言”
此处的tff.federated_computation
接受联合类型{float32}@CLIENTS
的参数,并返回联合类型{float 32}@SERVER
的值。联邦计算也可以从服务器到客户机、从客户机到客户端或从服务器到服务器。联邦计算也可以像普通函数一样组成,只要它们的类型签名匹配即可。
get_average_temperature([68.5, 70.3, 69.8])
# 相当于(68.5+70.3+69.8)/3
2.3 关于非渴望计算和TensorFlow
- TFF操作的是联合数值
- 每一个联合值都有一个联合类型
(Federated type)
,包括类型(type)
和分配(placement)
。 - 联合数值可以使用联合计算来传递,必须使用
tff.federated_computation
加上联合类型去修饰。 - TensorFlow code必须包含在
tff.tf_computation
的修饰块里面,然后可以将这些块合并到federated_computation
中
3.关于构建自己的联邦学习算法
我们定义了 initialize_fn
and next_fn
来完成联邦学习的步骤。
在1.2中介绍了服务器参数更新server_update
和客户端参数client_update
更新,都是由TensorFlow代码构成,
但是为了能够实现联邦计算,要把initialize_fn
and next_fn
变成一个tff.federated_computation
.
3.1 TFF blocks
3.1.1 创建初始化计算
使用model_fn来创建一个我们的模型,然后使用tff.tf_computation
来把TF的代码分开。
@tff.tf_computation
def server_init():
model = model_fn()
return model.trainable_variables
然后我们可以通过tff.federated_value
来把服务器初始化参数传递到联邦计算中:
@tff.federated_computation
def initialize_fn():
return tff.federated_value(server_init(), tff.SERVER)
3.1.2 创建next_fn函数
服务器和客户端更新代码可以用于编写实际算法。首先,需要把 client_update
转变成tff.tf_computation
,接收一个客户端的数据集和服务器的参数,并且输出一个更新后的客户端参数tensor。
需要对函数添加相应的变量类型修饰。幸运的是,服务器的权重类型可以直接通过模型导出。
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
让我们看一看数据集类型签名,假设采用的是Mnist数据集中的数据,里面的样本是是28*28像素的图片,可以展开成784,然后标签是1。
也可以通过server_init
函数提取权重类型:
model_weights_type = server_init.type_signature.result
然后通过str直接打印模型的结构:
str(model_weights_type)
# '<float32[784,10],float32[10]>'
现在,我们知道了tf_dataset_type
和model_weights_type
,然后我们可以为client_update
创建 tff.tf_computation
了:
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
model = model_fn()
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
return client_update(model, tf_dataset, server_weights, client_optimizer)
为 server update
创建 tff.tf_computation
的方式是类似的:
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
model = model_fn()
return server_update(model, mean_client_weights)
最后最重要的是:我们需要创建一个tff.federated_computation
把他们都放在一起。这个函数将会接收2个联邦数值(以及分配情况),一个是服务器相应的权重(分配给tff.SERVER
),另一个是对应的客户端的数据集(分配给tff.CLIENTS
)。
这两个变量的类型上面都已经定义了,也就是model_weights_type
和tf_dataset_type
。
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
3.1.3 构建federated_computation
至此,TFF所需要的组件都已经构建好了,下面就开始把各个组件放到一起。
按照联邦学习的4个步骤,构建next_fn
:
- 服务器参数传递
- 更新客户端参数
- 根据客户端参数计算服务器参数
- 根据服务器参数更新客户端参数
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
# Broadcast the server weights to the clients.
server_weights_at_client = tff.federated_broadcast(server_weights)
# Each client computes their updated weights.
client_weights = tff.federated_map(
client_update_fn, (federated_dataset, server_weights_at_client))
# The server averages these updates.
mean_client_weights = tff.federated_mean(client_weights)
# The server updates its model.
server_weights = tff.federated_map(server_update_fn, mean_client_weights)
return server_weights
3.1.4 tff.templates.IterativeProcess
为了完成我们的算法,还需要把initialize_fn和next_fn传给tff.templates.IterativeProcess
。
federated_algorithm = tff.templates.IterativeProcess(
initialize_fn=initialize_fn,
next_fn=next_fn
)
可以通过str查看federated_algorithm
的类型:
str(federated_algorithm.initialize.type_signature)
#'( -> <float32[784,10],float32[10]>@SERVER)'
str(federated_algorithm.next.type_signature)
# '(<server_weights=<float32[784,10],float32[10]>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'
->的左边是传入的参数结构,右边是输出的参数结构。可以清楚的看到,next_fn的参数传入的是服务器的参数,客户端的数据集,输出的是更新的服务器的参数。
3.2 评估算法
终于来到了机动人心的评估算法编写环节。
首先需要构建一个集中的评估数据集,需要对其做训练集一样的预处理:
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)
然后,我们需要写一个函数接收一个服务器的状态server_state
,并且使用keras在测试数据集上进行评估。
def evaluate(server_state):
keras_model = create_keras_model()
keras_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
)
keras_model.set_weights(server_state)
keras_model.evaluate(central_emnist_test)
然后,让我们初始化算法并且在测试集上进行评估。
server_state = federated_algorithm.initialize()
evaluate(server_state)
# 2042/2042 [==============================] - 8s 3ms/step - loss: 2.8479 - sparse_categorical_accuracy: 0.1027
然后我们进行联合训练15轮再次评估:
for round in range(15):
server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
# 2042/2042 [==============================] - 5s 2ms/step - loss: 2.5867 - sparse_categorical_accuracy: 0.0980
可以看到loss小幅度降低了。
3.3 构建自己的算法
经过上面的步骤,我们已经可以导入emnist数据集,基于keras构建模型,然后编写服务器更新函数和客户端更新函数,将其转换到TFF框架下,然后在TFF测试数据集进行训练的评估。
那么如果我们需要构建自己的模型,然后在自己的数据集上进行训练只需要把其中纯TensorFlow构建的模型部分进行修改就可以了。
今天的文章[联邦学习TFF]构建自己的联邦学习模型分享到此就结束了,感谢您的阅读。
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/85946.html