联邦学习 (FL) 不再是研究的好奇心,而是对棘手限制的实际回应:最有价值的数据通常是最不可动的数据。监管边界、数据主权规则和组织风险承受能力通常会阻止集中式聚合。与此同时,纯粹的数据引力使得即使是允许的大规模传输也会变得缓慢、昂贵且脆弱。

最新版本的 NVIDIA FLARE 通过联邦计算运行时解决了这一问题,该运行时可将训练逻辑迁移至数据所在位置,而原始数据无需移动。在高风险环境中,集中式数据聚合通常不可行或不现实,因此现代联邦学习平台必须将数据隔离合规性以及 隐私增强技术 作为首要考量。

历来减缓采用的不是 FL 的概念,而是开发者体验。如果从“我的本地脚本训练”到“我的工作在联邦站点中运行”的路径需要深度重构、新的类层次结构或易碎配置,许多项目会在试点后停止。

FLARE API 演进正是为了实现这一目标:将工作分为两个具体步骤,清晰地映射出团队构建和交付 ML 系统的实际方式,从而消除重构用度:

  • 第 1 步(客户端 API): 只需添加约 5 到 6 行代码,即可将现有的本地训练脚本转换为联邦学习客户端,且无需修改训练循环结构。
  • 第 2 步(作业方案):选择 FL 工作流并将其与客户端训练脚本绑定,然后仅通过更换执行环境,在仿真、PoC 和生产环境中运行同一作业。

“无数据复制”作为系统要求

在受监管或高灵敏度设置中,“只需集中数据集”的可能性越来越大。实用的联邦计算平台需要支持:

  • 无数据复制: 数据保持本地,只有模型更新 (或同等信号) 会移动。
  • 合规性态势: 支持主权和审计要求的部署和治理控制。
  • 隐私增强技术: 多层防御 (示例包括同态加密、差分隐私和机密计算) 。

Figure shows a before-and-after comparison of centralized versus federated computing. On the left (“before”), three separate data silos send their data into one centralized database where a model is trained. On the right (“after”), data remains in separate, locked databases at multiple sites while a shared model is coordinated across them, with arrows indicating that only model updates are exchanged rather than copying raw data. The middle shows data silos across different industries, such as finance, healthcare, and the public sector.

图 1. 联邦计算可确保数据就位,通过模型更新实现协作,同时支持合规性和隐私保护。

正在进行重构的悬崖:为什么 FL 项目会停滞

领航员完成试飞后,参赛队伍通常会击中两个悬崖之一:

  • 代码瓶颈: 将正在进行的 PyTorch/TensorFlow/Lightning 训练转换为联邦学习(FL)可能需要大规模重构,包括引入新的抽象、消息传递逻辑以及特定框架的适配代码。
  • 生命周期悬念: 即使在仿真工作时,也会通过作业重新定义、重新配置和特定于环境的分支来重写转向 PoC 和生产触发器的操作。

FLARE 通过将工作流程标准化为两个步骤,使这两个悬崖变得扁平化:

  1. 使脚本联合 (客户端 API)
  2. 将其作为可移植作业 (作业配方) 执行

预期体验是明确地将这些内容结合起来,以便您可以快速从零开始执行可操作的联合作业。

第 1 步:将本地训练脚本转换为联邦客户端 (客户端 API)

适用对象:使用现有训练代码且希望尽可能缩小差距的从业者和 ML 工程师。

思维模型有意简单化:

  1. 初始化客户端运行时
  2. 作业运行时的循环
  3. 接收当前的全局模型
  4. 本地训练 (您的代码)
  5. 将更新后的权重和指标发送回

FLARE 的客户端 API 旨在尽可能减少代码更改,并避免强制您进行繁重的“执行程序/ 学习者”继承 – 使用 FLModel 结构或简单的数据交换与运行时通信。

示例 1a:将 PyTorch 转换为 FLARE

以下是您可以应用于许多脚本的具体模式。关键接触点是:flare.init()flare.receive()、加载模型权重,以及 flare.send() 与更新的权重和指标。

左侧显示本地训练代码,右侧显示联邦版本,并突出显示:import、flare.init()receive()send()

train.py

# train.py

import torch
import torchvision
import torchvision.transforms as transforms

from model import Net

batch_size = 4
epochs = 1
lr = 0.01
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
transform = transforms.Compose(
   [
       transforms.ToTensor(),
       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
   ]
)

train_dataset = torchvision.datasets.CIFAR10(
   root="/tmp/data/cifar10", transform=transform, download=True, train=True
)

trainloader = torch.utils.data.DataLoader(
   train_dataset, batch_size=batch_size, shuffle=True
)

model.to(device)

for epoch in range(epochs):
   running_loss = 0.0

   for i, batch in enumerate(trainloader):
       images, labels = batch[0].to(device), batch[1].to(device)

       optimizer.zero_grad()

       predictions = model(images)
       cost = loss(predictions, labels)
       cost.backward()
       optimizer.step()

       running_loss += cost.cpu().detach().numpy() / batch_size

       if i % 3000 == 2999:
           print(
               f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / 3000}"
           )
           running_loss = 0.0

   print(
       f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / (i + 1)}"
   )

print("Finished Training")

torch.save(model.state_dict(), "./cifar_net.pth")
client.py

# client.py

# 1. Import client API
import nvflare.client as flare
import torch
import torchvision
import torchvision.transforms as transforms

from model import Net

batch_size = 4
epochs = 1
lr = 0.01
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
transform = transforms.Compose(
   [
       transforms.ToTensor(),
       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
   ]
)

train_dataset = torchvision.datasets.CIFAR10(
   root="/tmp/data/cifar10", transform=transform, download=True, train=True
)

trainloader = torch.utils.data.DataLoader(
   train_dataset, batch_size=batch_size, shuffle=True
)

# 2. Initialize FLARE
flare.init()

# At each round while FLARE is running
while flare.is_running():
   # 3. Receive the global model
   input_model = flare.receive()

   # 4. Load global model
   model.load_state_dict(input_model.params)
   model.to(device)

   for epoch in range(epochs):
       running_loss = 0.0

       for i, batch in enumerate(trainloader):
           images, labels = batch[0].to(device), batch[1].to(device)

           optimizer.zero_grad()

           predictions = model(images)
           cost = loss(predictions, labels)
           cost.backward()
           optimizer.step()

           running_loss += cost.cpu().detach().numpy() / batch_size

           if i % 3000 == 2999:
               print(
                   f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / 3000}"
               )
               running_loss = 0.0

       print(
           f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / (i + 1)}"
       )

   print("Finished Training")

   torch.save(model.state_dict(), "./cifar_net.pth")

   # 5. Send back the updated model
   output_model = flare.FLModel(
       params=model.cpu().state_dict(),
       meta={"NUM_STEPS_CURRENT_ROUND": len(trainloader) * epochs},
   )
   flare.send(output_model)

示例 1b:PyTorch Lightning 客户端 Lightning 集成保持不变

Lightning 集成保持相同的目标(接收全局模型、训练、发送更新),但以符合 Lightning 风格的方式呈现:导入 Lightning 客户端适配器并对 Trainer 进行修补。

典型流程为:导入、修补、(可选)验证,然后正常训练。

# lightning_client.py

import pytorch_lightning as pl

from pytorch_lightning import Trainer

import nvflare.client.lightning as flare  # Lightning Client API 

from model import LitNet

from data import CIFAR10DataModule

def main():

   model = LitNet()

   dm = CIFAR10DataModule()

   trainer = Trainer(max_epochs=1, accelerator="gpu", devices=1)

   # Patch trainer to participate in FL

   flare.patch(trainer)

   while flare.is_running():

       # Optional: validate current global model (useful for server-side selection flows)

       trainer.validate(model, datamodule=dm)

       # Train starting from received global model (handled internally after patch)

       trainer.fit(model, datamodule=dm)

if __name__ == "__main__":

   main()

要点:Lightning 用户不必进入自定义联邦消息传递,他们会保留 Trainer 的抽象,并且仍然正确地参与 FL 轮次。

第 2 步:在任何位置打包并执行联合作业 (作业方案)

目标受众:希望代码优先的工作定义在各种环境中保持稳定的数据科学家和应用团队。

第 1 步之后,您将拥有一个联邦客户端脚本。第 2 步使其成为联合作业,您可以重复运行,并在整个生命周期中保持整洁。

作业方案旨在将基于 JSON 的作业配置替换为基于 Python 的作业定义:

  • 代码优先: 在 Python 中定义完整的 FL 作业,而非复杂的配置文件
  • 只需编写一次,即可在任意位置运行:相同的方法可在模拟器、PoC 或生产环境中运行
  • 加快部署速度: 从实验转向部署,无需更改代码结构

示例 2a:在仿真中执行 FedAvg recipe

关键关联在于,您的 recipe 会引用您在步骤 1 中创建的客户端训练脚本 (例如 train_script="client.py") ,然后在环境中执行。

# job.py

from nvflare.app_common.workflows.job import FedAvgRecipe

from nvflare.job_config import SimEnv  # exact import path can vary by NVFlare version

from model import SimpleNetwork

def main():

   n_clients = 3

   num_rounds = 5

   batch_size = 32

   recipe = FedAvgRecipe(

       name="hello-pt",

       min_clients=n_clients,

       num_rounds=num_rounds,

       model=SimpleNetwork(),

       train_script="client.py",  # <-- Step A script

       train_args=f"--batch_size {batch_size} --epochs 1",

   )

   env = SimEnv(num_clients=n_clients, num_threads=n_clients)

   recipe.execute(env=env)

if __name__ == "__main__":

   main()

这是在实践中的“一次性写入”理念:一旦 recipe 正确引用了您的客户端脚本,其余部分就会成为执行问题。

示例 2b:通过环境交换从仿真过渡到现实世界。

工作方案通过交换执行环境,将渐进式工作流程规范化:

  1. SimEnv (模拟): 易于开发、快速调试
  2. PocEnv (概念验证): 本地运行时、多进程、真实测试
  3. ProdEnv (生产): 在安全、可扩展的基础设施上进行分布式部署

Alt text: Figure shows a three-stage JobRecipe pipeline flowing into three execution environments. A box labeled “JobRecipe” at the top splits into three arrows pointing to side-by-side panels: SimEnv (Simulation) for easy development and rapid debugging, PocEnv (Proof-of-Concept) for realistic multi-process testing in a local runtime, and ProdEnv (Production) for secure distributed deployment.

图 2. 一个 JobRecipe,多个执行环境:在 SimEnv 中调试,在 PocEnv 中验证,然后在 ProdEnv 中部署,而无需重写作业定义

开始使用

  • 从您已信任的脚本开始。
  • 第 1 步:添加客户端 API 握手 (或修补 Lightning Trainer) 。
  • 第 2 步:将其包装在作业配方中,先在仿真中执行,然后在 PoC 中执行,然后通过交换环境进行生产。

FLARE 新闻报道

FLARE 已在实际场景中逐步部署,包括礼来 TuneLab 的联邦学习平台(基于 Rhino Federated Computing 构建的 NVFlare)到 台湾 MOHW 的国家医疗健康联邦学习计划,以及 Tri-labs(Sandia/LANL/LLNL) 在敏感数据集之间开展的联邦 AI 试点项目。

深入了解

从您已信任的脚本开始。添加最小的 FLARE 客户端握手 (接收+ 训练+ 发送) 。准备就绪后,即可从单节点仿真扩展到多站点部署。

    Logo

    分享最新的 NVIDIA AI Software 资源以及活动/会议信息,精选收录AI相关技术内容,欢迎大家加入社区并参与讨论。

    更多推荐