연합학습(Federated Learning)

Introduction of Federated Learning

재호아빠 2022. 11. 4. 17:55

1. 연합학습(Federated Learning)

연합학습(FL)은 많은 디바이스(이하 클라이언트)가 각자 보유하고 있는 데이터를 바탕으로 모델을 학습하여, 데이터를 직접 공유하지 않으면서, 중앙 서버 또는 마스터 노드의 관리(Orchestration) 하에 전체 모델(Global model)을 학습하는 기계학습 기법이다. 연합학습은 중앙 집중형 데이터 수집을 최소화하는 것을 원칙으로 구현하며, 전통적인 중앙 집중식 기계 학습으로 인해 발생하는 많은 개인 정보 보호 위험과 학습에 필요한 비용을 줄일 수 있다[1].

연합학습의 기본 메커니즘은 1) 중앙 서버가 Client들에게 모델을 보내주는 것부터 시작한다. 물론 이 때, 학습에 참여 가능한 client를 선별을 먼저 진행한다. 연합학습에서는 Client와 계속 통신이 가능한지, 학습에 필요한 시간을 충족하는 지 등 Communication Cost도 고려해야한다[2]. 2) 학습에 참여하는 각 클라이언트들은 보유하고있는 데이터를 바탕으로 서버로부터 받은 Global 모델에 독립적으로 학습을 진행한다. 3) 학습이 완료되면, 학습한 모델을 Server로 전송한다(경우에 따라서는 학습에 사용한 데이터의 크기도 함께 보낸다). 4) Server는 Client들로부터 받은 모델을 특정 전략에 따라 취합한다(e.g. FedAvg). 5) 과정 1부터 4까지를 반복하고, 이를 Global Epoch 또는 Global Iteration이라고 한다.

[그림1] 연합학습 Flow

이와 유사한 분산학습과 가장 차이가 있는 부분은 분산 학습은 I.I.D한 데이터를 컴퓨팅 파워를 병렬화 하는것을 목적으로한다. 반면, 연합학습은 각 클라이언트의 데이터 분포가 Non-i.i.d 하고, 데이터를 직접 공유하지 않는 환경에서 컴퓨팅 자원의 공유 없이 학습을 진행한다[3]. 언급된 연합학습의 Non-i.i.dness로 인하여 기존의 기계학습 방법이 잘 적용되지 않는 문제가 발생하고있다. 또한, 연합학습은 Global model update를 위해 네트워크 통신을 통해 정보를 자주 주고 받아야 한다는 특징이 있으며, 매 학습 시 참여하는 클라이언트가 달라 질 수 있음을 가정한다.

2. Non-Independent and Identically Distributed (Non-I.I.D)

기존의 중앙 집중형 학습 및 분산학습과 연합학습의 가장 큰 차이점은 데이터의 Non-Independent and Identically Distributed(Non-I.I.D)에 있다. Non-I.I.D는 데이터 분포가 독립적이지 않고, 동일하지 못하는 것을 말한다. 현재 기존의 다양한 인공지능 학습방법이 잘 적용되지 않는 문제의 원인으로 Non-I.I.D 이슈가 주요하다. Non-I.I.D 케이스는 5가지 정도로 정리할 수 있는데, 이는 아래 표와 같다.

Client Distribution 설명 예시 연구 방향
Feature Distribution Skew 클라이언트에서 보유하고 있는 X 데이터 분포 자체가 다른 경우 MNIST 데이터에서 어떤 클라이언트는 특정 숫자에 대한 데이터가 없음 (존재는 알고 있으나, 데이터를 본적은 없음) Optimization, Knowledge Distillation, Accumulation Method(e.g. FedAvg, FedProxy)
Label Distribution Skew 클라이언트에서 보유하고 있는 Y 데이터 분포 자체가 다른 경우 Classifier logit의 output vector 크기 자체가 다름 (존재 자체를 모름) Continual Learning
Same Label, 
Different Features
Y가 같더라도, X의 분포가 다름 Home이라는 기준이 지역 마다 다름. 미국은 단독주택을 의미하지만, 한국은 아파트의 비율이 높음 Personalized Federated Learning
Same Features,
Different Label
입력 X 분포가 같지만,
타겟 Y가 다름
같은 영화를 보고 서로 느끼는 바가 다른 경우 Personalized Federated Learning
Quantity Skew 보유하고 있는 데이터의 전체 양이 다르거나, 클래스의 비율이 다름 -  

Client Drift

연합학습은 매 Global round 마다 다른 클라이언트들과 통신을 통해 모델을 학습한다. 여기서 통신을 한다는 것은 많은 비용이 소요된다. 네트워크 통신을 위한 delay 뿐만 아니라, 클라이언트가 학습이 완료될때까지 기다려야하는 비용까지 포함한다. 물론, 학습이 너무 오래 걸리는 경우 또는 통신이 끊어진 경우 등을 시스템 개발 이슈로 해결 가능하지만, 연구 목적에서는 이러한 조치가 적절하지는 않다. 따라서 최대한 communication cost가 적게 드는 방법을 고려해야한다.

그 중 하나로 local iteration을 충분히 늘리는 방법으로 연구가 진행되고있다. 데이터가 I.I.D 한 상황에서는 이 방법이 효과적이다. 그러나, Non-I.I.D 한 상황에서 Client Drift 문제가 발생한다. Client Drift는 Local iteration을 많이 진행하면서, Client의 가중치가 Global 모델이 이상적으로 학습해야하는 가중치 방향과 다르게 학습되는 현상을 말한다[4].

[그림2] Client Drift Example

[그림2]를 통해 이 현상에 대해 더 직관적으로 잘 이해 할 수 있다. IID 셋팅에서는 Client마다 데이터 분포가 비슷하기 때문에 모델의 초기값이 같고, SGD와 같이 학습하는 step 사이즈가 같다면 Global 모델과 client 모델의 가중치가 크게 벗어나지 않는다. 그러나 Non-IID 셋팅처럼 각 클라이언트별 데이터 분포(데이터의 양, 또는 데이터 분포, 클래스 분포 등)가 다르다면, local step을 거칠 수록 Global모델과 client 모델의 가중치가 크게 다를 수 있고, 이로 인해 평균(FedAvg)을 했을 때 Global 모델이 central 환경에서 학습했을 때와 크게 다를 수 있다. 지금까지 연구들은 이 문제를 해결하기 위해 Global model과 client weight가 divergence 하지 않게 regularize하는 방법으로 진행하고있다[5-8]. 다만, 극단적인 Non-IID 환경의 환경에서는 아직까지 한계가 많다. 특히 정말 필요한 정보를 놓치게 되어 너무 General하게 학습하면서 전체 모델의 성능이 떨어지는 현상이 발생하고있다.

Hypothesis Conflict

Hypothesis는 인공지능에서 어떤 정보가 사상(mapping)되는 공간을 의미한다. 즉, 기계학습과 딥러닝에서는 모델을 의미하는데, 연합학습에서 Non-I.I.D 데이터로 학습을 진행하면 서로 다른 두 가설 공간이 충돌하는 것을 말한다[9]. 아래 [그림3]의 예시를 통하면 더 직관적으로 이해 가능하다. 2차원 평면위(hypothesis space)에 데이터가 (a)의 좌측 처럼 구분되어있다고 가정한다. 이 때, 전체 데이터를 관측이 가능하면, classifier는 (a)의 오른쪽 그림과 같이 결정경계(dicision boundary)를 학습 가능 할 것이다. 그러나 Non-IID한 연합학습 환경에서는 (b)와 같이 데이터의 일부만 관측 가능 할 가능성이 높다. 즉, Client 본인의 데이터에서 최적의 결정경계를 얻는다 하더라도, 다른 client들의 모델과 평균 할 경우, 두 모델의 결정경계가 상충(conflict)하여 제대로된 결정경계를 찾을 수 없게 된다. 이 문제 또한 극단적인 Non-IID 환경일수록 발생 가능성이 높아진다.

[그림3] Hypothesis conflict between two sites in a distributed setting

 

3. FedAvg의 문제

FedAvg는 클라이언트로부터 받은 가중치 또는 모델을 평균하는 방법이다[10]. FedAvg를 정의하고 있는 원 논문[11]에서는 클라이언트가 보유하고 있는 데이터 크기(\(n_k \over n\))의 비율대로 가중치(weight; \(w\))를 반영하는 것으로 정의한다.

 

$$ w_{t+1} \leftarrow \sum^K_{k=1} {n_k \over n} w^k_{t+1} $$

 

다만 논문에 따라서 FedAvg를 약간 변형하여 적용하는 경우도 있다. 평균하는데 있어서 데이터 크기가 아닌 클라이언트 수로 나누거나, 가중치의 평균이 아닌 gradient의 평균을 구하는 방법도 존재한다. Weight의 평균을 구하는 것 대신 Gradient의 평균을 구하는 것은 동일하다는 것이 수학적으로 증명되었으나[11], 데이터 크기 대신에 클라이언트 수로 평균하는 것이 동일하다는 것은 증명되지 않았다(적어도 저는 못찾았습니다).

FedAvg는 연합학습에서 가장 보편적으로 사용되고 있는 aggregation method지만, 위에서 언급했던 Client drift와 Hypothesis conflict 문제를 야기하는 주요한 원인이다. 따라서, 현재 연합학습 연구에서 client drifit와 hypothesis conflict를 회피하려는 연구가 굉장히 많이 등장하고 있다. 모델을 평균하는 과정에 있어서 regularization 방법을 적용하여 client drift를 방지하는 연구[5-8]와 학습 과정에서 모델의 가중치를 보정하여 평균을 진행 했을때 Hypothesis가 충돌하지 않게 방지하는 연구[12-14]가 주로 진행되고있다.

4. 정리

본 글에서는 연합학습에 대해 간단한 정의를 내리고, 현재 연합학습에서 발생하고 있는 다양한 문제 중 Non-IID 환경에 따라 발생하는 문제에 대해 정리했다. 물론 이 글에서 언급했던것 외에 더 많은 문제가 존재하고, 이를 해결하기 위한 다양한 방법들이 제시되고 있으나 현재 필자가 연구하는 방향에 국한하여 자료조사를 수행하였고 그에 대한 정리 자료 정도로 이해해주면 좋을 것 같다. 특히나 연합학습에서 FedAvg를 함으로써 발생하는 문제가 너무나도 다양하고, 이를 해결 가능한 뚜렷한 방법이 등장하고 있지 않으며 모델이나 데이터 종류에 따라서 기존에 제시되었던 방법들도 잘 활용하기 어려운 점 등 문제가 무궁무진하다. 이를 잘 정의하고 해결하려고 하는 문제를 뚜렷하게 정의하여 제약사항을 잘 고려해야 할 듯하다. 본 글에서 사실과 잘못된 정보가 있거나, 추가 설명이 필요한 경우에는 댓글로 남겨주시기 바란다.

References

[1] Kairouz, Peter, et al. "Advances and open problems in federated learning." Foundations and Trends® in Machine Learning 14.1–2 (2021): 1-210.

[2] Yang, Qiang, et al. "Federated machine learning: Concept and applications." ACM Transactions on Intelligent Systems and Technology (TIST) 10.2 (2019): 1-19.

[3] Konečný, Jakub, Brendan McMahan, and Daniel Ramage. "Federated optimization: Distributed optimization beyond the datacenter." arXiv preprint arXiv:1511.03575 (2015).

[4] Zhao, Yue, et al. "Federated learning with non-iid data." arXiv preprint arXiv:1806.00582 (2018).

[5] Li, Tian, et al. "Federated optimization in heterogeneous networks." Proceedings of Machine Learning and Systems 2 (2020): 429-450.

[6] Karimireddy, Sai Praneeth, et al. "Scaffold: Stochastic controlled averaging for federated learning." International Conference on Machine Learning. PMLR, 2020.

[7] Chen, Mingzhe, et al. "Communication-efficient federated learning." Proceedings of the National Academy of Sciences118.17 (2021): e2024789118.

[8] Lin, Tao, et al. "Ensemble distillation for robust model fusion in federated learning." Advances in Neural Information Processing Systems 33 (2020): 2351-2363.

[9] Ozdayi, Mustafa Safa, Murat Kantarcioglu, and Rishabh Iyer. "Improving accuracy of federated learning in non-iid settings." arXiv preprint arXiv:2010.15582 (2020).

[10] Konečný, Jakub, et al. "Federated optimization: Distributed machine learning for on-device intelligence." arXiv preprint arXiv:1610.02527 (2016).

[11] McMahan, Brendan, et al. "Communication-efficient learning of deep networks from decentralized data." Artificial intelligence and statistics. PMLR, 2017.

[12] Li, Qinbin, Bingsheng He, and Dawn Song. "Model-contrastive federated learning." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021.

[13] Yoon, Jaehong, et al. "Federated continual learning with adaptive parameter communication." (2020).

[14] Tenison, Irene, et al. "Gradient Masked Averaging for Federated Learning." arXiv preprint arXiv:2201.11986 (2022).