1 Introduction
1.1 Background
Increasingly, cell phones and tablets are the primary computing devices for many people. The powerful sensors on these devices (including cameras, microphones, and GPS), coupled with the fact that they are carried almost everywhere, mean that they have access to an unprecedented amount of data, most of it essentially private. Models learned from this data hold the promise of greatly improving usability by supporting smarter applications, but the sensitivity of the data means there are risks and liabilities associated with storing it in a centralized location.
1.2 Contribution of this paper
The main contributions of this paper are:
- The problem of training on decentralized data from mobile devices (federated learning) is identified as an important research direction;
- A simple and practical algorithm, FedAvg, is selected that can be applied to this setting;
- An extensive empirical evaluation of the proposed approach is provided.
More specifically, this paper presents the FedAvg algorithm, which combines local stochastic gradient descent (SGD) on each client with a server that performs model averaging. The paper conducts extensive experiments with the algorithm and demonstrates that it is robust to unbalanced and non-IID data distributions, and that it can reduce the number of communication rounds required to train deep networks on decentralized data by orders of magnitude.
1.3 Ideal problems for federated learning
- Training on data from real-world mobile devices has significant advantages over training on proxy data available in data centers;
- The data is private or the amount of data is large;
- For supervised tasks, labels can be naturally inferred from user interactions.
Examples:
- Image classification: predict which photos are most likely to be viewed or shared multiple times in the future. The photos a user takes are private and stay on the device, while local user behavior such as deleting or sharing photos provides the inferred labels.
- Word prediction: the input method predicts the next word as the user types on the phone. The typed text is private, and the next word the user actually selects is the inferred label.
1.4 Federated Learning vs. Distributed Optimization
- Non-IID: different users use their mobile devices differently, so the data across clients is not independently and identically distributed.
- Imbalance: some users will use the service or application more frequently than others, resulting in different amounts of local training data.
- Massively distributed: the number of clients involved in the optimization is expected to be much larger than the average number of instances per client.
- Communication is limited: mobile devices are sometimes offline or on slow and expensive connections.
2 FedAvg
2.1 Loss function
For a machine learning problem, let the loss on sample \((x_i,y_i)\) be \(f_i(w)\). With \(n\) samples in total, the global objective is defined as:

\[f(w)=\frac{1}{n}\sum_{i=1}^{n}f_i(w)\]

In the federated learning setting, assume there are \(K\) clients, where the \(k\)-th client holds dataset \(P_k\) of size \(n_k=|P_k|\). The loss over client \(k\)'s data is:

\[F_k(w)=\frac{1}{n_k}\sum_{i\in P_k}f_i(w)\]

The global loss function is then defined as a weighted average of the client losses:

\[f(w)=\sum_{k=1}^{K}\frac{n_k}{n}F_k(w)\]
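As a minimal illustration (the client sizes and loss values below are made up), the weighted global loss can be computed from per-client losses as follows:

```python
import numpy as np

# Hypothetical per-client dataset sizes n_k and per-client losses F_k(w)
client_sizes = np.array([600, 300, 100])      # n_k for K = 3 illustrative clients
client_losses = np.array([0.42, 0.55, 0.31])  # F_k(w) evaluated on each client's data

n = client_sizes.sum()
# Global objective f(w) = sum_k (n_k / n) * F_k(w)
global_loss = np.sum(client_sizes / n * client_losses)
print(global_loss)
```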
2.2 Communication costs versus computational costs
When data is centralized in a data center, the communication cost is relatively small, while the computational cost dominates because of the large amount of data.
Communication cost refers to the cost of transferring data between the clients and the central server. In federated learning, mobile devices have limited bandwidth, and clients are usually willing to participate in optimization only when they are charging and on WiFi, so the communication cost is large. In contrast, each device holds relatively little data and phones increasingly have GPUs, so the computational cost is comparatively small.
Methods to reduce the communication cost:
- Increase parallelism, using more clients per round (corresponding to the "clients are usually only willing to participate in optimization if they have power, WiFi, etc." restriction).
- Each client performs more complex computations between each communication round rather than performing simple computations like gradient calculations.
2.3 Related work
Previous work has not considered unbalanced and non-IID data, and has only dealt with settings involving a small number of clients.
2.4 FedSGD
In FedSGD, each client \(k\) computes the gradient of its local loss at the current model \(w_t\): \(g_k=\nabla F_k(w_t)\). By the definition of the global loss, the global gradient is:

\[\nabla f(w_t)=\sum_{k=1}^{K}\frac{n_k}{n}g_k\]

The central server then aggregates these gradients and updates the model as:

\[w_{t+1}\leftarrow w_t-\eta\sum_{k=1}^{K}\frac{n_k}{n}g_k\]

The above update is equivalent to each client first taking one local gradient step, \(w_{t+1}^{k}\leftarrow w_t-\eta g_k\), and the central server then taking a weighted average of the resulting models:

\[w_{t+1}\leftarrow\sum_{k=1}^{K}\frac{n_k}{n}w_{t+1}^{k}\]
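A small numerical check (with made-up client sizes and gradients) that the two forms give the same update:

```python
import numpy as np

eta = 0.1
w_t = np.array([1.0, -2.0])        # current global model (illustrative)
sizes = np.array([600, 300, 100])  # hypothetical n_k
grads = np.array([[0.2, -0.1],     # hypothetical g_k = grad F_k(w_t)
                  [0.5,  0.3],
                  [-0.4, 0.2]])
weights = sizes / sizes.sum()

# Form 1: the server averages the gradients, then takes one step
w_next_a = w_t - eta * np.sum(weights[:, None] * grads, axis=0)

# Form 2: each client takes one local step, the server averages the models
local_models = w_t - eta * grads
w_next_b = np.sum(weights[:, None] * local_models, axis=0)

assert np.allclose(w_next_a, w_next_b)
print(w_next_a)
```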
2.5 FedAvg
Once the update is written in the second form above, each client can perform several iterations of the local update before the averaging step:

\[w^{k}\leftarrow w^{k}-\eta\nabla F_k(w^{k})\]

Each client repeats this update multiple times to obtain its local model for round \(t\), and the central server then aggregates these local models to obtain \(w_{t+1}\).
This is the idea behind FedAvg. The algorithm has three main hyperparameters:
- \(C\): fraction of clients selected in each round
- \(B\): local mini-batch size; \(B=\infty\) means full-batch gradient descent
- \(E\): number of local training epochs per round
When \(B=\infty\) and \(E=1\), FedAvg is equivalent to FedSGD.
The number of local updates per round is also defined here: \(u_k=E\frac{n_k}{B}\). From this it follows that FedSGD performs exactly one local update per round.
Complete pseudo-code:
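The paper's pseudo-code figure is not reproduced here. Below is a minimal Python sketch of the loop described above (client sampling with fraction C, E local epochs of mini-batch SGD with batch size B, then weighted model averaging); the synthetic linear-regression clients, the `client_update` helper, and all hyperparameter values are illustrative assumptions, not the paper's code.

```python
import numpy as np

rng = np.random.default_rng(0)

def client_update(w, data, E, B, eta):
    """E local epochs of mini-batch SGD on one client; returns the updated model."""
    X, y = data
    n_k = len(y)
    for _ in range(E):
        idx = rng.permutation(n_k)
        for start in range(0, n_k, B):
            batch = idx[start:start + B]
            # Illustrative least-squares gradient; a real model supplies its own loss.
            grad = 2 * X[batch].T @ (X[batch] @ w - y[batch]) / len(batch)
            w = w - eta * grad
    return w

def fedavg(clients, w0, rounds, C, E, B, eta):
    """clients: list of (X, y) pairs, one per client."""
    w = w0
    K = len(clients)
    m = max(int(C * K), 1)
    for _ in range(rounds):
        selected = rng.choice(K, size=m, replace=False)
        sizes = np.array([len(clients[k][1]) for k in selected])
        local = [client_update(w.copy(), clients[k], E, B, eta) for k in selected]
        # Weighted average of the returned client models (weights n_k / sum n_k).
        weights = sizes / sizes.sum()
        w = sum(wk * lk for wk, lk in zip(weights, local))
    return w

# Tiny synthetic example: 10 clients, each holding a small linear-regression dataset.
true_w = np.array([1.0, -3.0])
clients = [(X, X @ true_w + 0.1 * rng.normal(size=50))
           for X in (rng.normal(size=(50, 2)) for _ in range(10))]

w = fedavg(clients, w0=np.zeros(2), rounds=20, C=0.5, E=5, B=10, eta=0.05)
print(w)  # should end up close to [1, -3]
```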
At this point we can briefly compare FedSGD and FedAvg:
| Algorithm | Client (local) | Server |
|---|---|---|
| FedSGD | Computes the gradient for the current round | Collects the local gradients, takes their weighted average, and uses it as the gradient for one server-side descent step |
| FedAvg | Runs multiple gradient-descent steps to obtain this round's local model | Collects the local models and takes their weighted average as this round's global model |
3 Experiments
3.1 Model initialization
Aggregation parameter \(\theta\): the two models \(w\) and \(w'\) are combined as \(\theta w+(1-\theta)w'\) to obtain the final model.
The left figure shows the loss when the two models are trained from different initializations \(w\) and \(w'\) on different subsets of the data; the right figure shows the two models trained from the same initialization \(w\) on different subsets. The loss on the right is smaller, and \(\theta=0.5\) works best. Therefore, in the federated learning experiments, all clients must share the same initial model.
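A minimal sketch of this interpolation, where `w` and `w_prime` stand for two trained parameter vectors (the values here are made up):

```python
import numpy as np

def interpolate(w, w_prime, theta):
    """Aggregate two models as theta * w + (1 - theta) * w_prime."""
    return theta * w + (1 - theta) * w_prime

w = np.array([0.8, -2.9, 1.1])        # hypothetical parameters of model 1
w_prime = np.array([1.2, -3.1, 0.9])  # hypothetical parameters of model 2
for theta in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(theta, interpolate(w, w_prime, theta))
```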
3.2 Dataset and training task
Datasets of moderate size were chosen so that the hyperparameters could be studied thoroughly.
The first task is MNIST digit recognition using two models:
- Multilayer perceptron (MLP): 2 hidden layers with 200 units each, using ReLU activations. 199,210 parameters in total: the \(28\times 28\) image is flattened to 784 inputs, so the first layer has \(784\times 200+200\) bias parameters, the second layer \(200\times 200+200\), and the output layer \(200\times 10+10\) (verified in the short sketch after this list).
- CNN: a \(5\times 5\) convolution with 32 channels, \(2\times 2\) max pooling, a \(5\times 5\) convolution with 64 channels, \(2\times 2\) max pooling, a fully connected layer with 512 units and ReLU, and a softmax output layer.
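A quick arithmetic check of the MLP parameter count:

```python
# 28x28 input flattened to 784, two hidden layers of 200 units, 10 outputs
layers = [784, 200, 200, 10]
params = sum(n_in * n_out + n_out for n_in, n_out in zip(layers, layers[1:]))
print(params)  # 199210
```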
Dataset segmentation:
- iid: the data is shuffled and divided among 100 clients, each receiving 600 images.
- Non-iid: the images are first sorted by digit label and divided into 200 shards of size 300; each of the 100 clients receives 2 shards, so most clients only have examples of two digits (see the sketch below).

Both the iid and non-iid splits are balanced: every client holds the same number of examples.
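Below is a sketch of the non-iid partition just described, under the assumption that the 60,000 MNIST training labels are available as a numpy array; the random shard assignment is an illustrative choice.

```python
import numpy as np

rng = np.random.default_rng(0)

def noniid_partition(labels, num_clients=100, shards_per_client=2, shard_size=300):
    """Sort examples by label, cut them into shards, give each client a few shards."""
    order = np.argsort(labels)                    # indices sorted by digit
    num_shards = num_clients * shards_per_client  # 200 shards of 300 images
    shards = order[:num_shards * shard_size].reshape(num_shards, shard_size)
    shard_ids = rng.permutation(num_shards)
    return [np.concatenate(shards[shard_ids[c * shards_per_client:
                                            (c + 1) * shards_per_client]])
            for c in range(num_clients)]

# Illustrative fake labels standing in for the 60,000 MNIST training labels.
labels = rng.integers(0, 10, size=60000)
clients = noniid_partition(labels)
print(len(clients), len(clients[0]))  # 100 clients, 600 examples each
print(np.unique(labels[clients[0]]))  # only a couple of distinct digits per client
```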
The second task is character prediction, using LSTM, which reads a line of characters to predict the next character.
The dataset is the full Shakespeare collection, with one client per speaking character, totaling 1146. For each client, the first 80% of rows are the training set and the last 20% of rows are the test set.
Dataset segmentation:
- iid: the full text is divided equally among the clients.
- Non-iid: each client holds only the lines spoken by its character.
The learning rate was tuned on a multiplicative grid with resolution between \(10^{\frac{1}{3}}\) and \(10^{\frac{1}{6}}\).
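For example, a grid with multiplicative resolution \(10^{1/3}\) around an arbitrarily chosen center of 0.1 could be generated as:

```python
import numpy as np

# Learning-rate candidates spaced by a factor of 10**(1/3), centered (arbitrarily) at 0.1
exponents = np.arange(-3, 4) / 3   # -1, -2/3, ..., 1
grid = 0.1 * 10.0 ** exponents
print(np.round(grid, 4))
```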
3.3 Increased parallelism
\(C\) controls the amount of parallelism, so the experiments first vary \(C\).
The experiments record the number of communication rounds required for the MLP to reach 97% test accuracy and for the CNN to reach 99% test accuracy.
With small batches, the results are already good at \(C=0.1\). To balance computational efficiency and convergence speed, the remaining experiments fix \(C=0.1\).
3.4 Increasing the amount of computation per client
As noted in the FedAvg section, the number of local updates per round for client \(k\) is \(u_k=E\frac{n_k}{B}\). In the experiments on iid data this is summarized by the expected number of updates per client per round, \(u=\frac{En}{KB}\).
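For instance, with the MNIST iid split (\(n=60000\), \(K=100\)) and illustrative settings \(E=5\), \(B=10\):

```python
n, K = 60000, 100    # MNIST training examples and number of clients
E, B = 5, 10         # illustrative local epochs and batch size
u = E * n / (K * B)  # expected local updates per client per round
print(u)             # 300.0
```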
First, for both tasks, decreasing \(B\) and increasing \(E\) both reduce the number of communication rounds required.
For the MNIST task, the improvement is more pronounced for iid data than for non-iid data. Yet in real life the digits on our devices would not be neatly distributed either, so the non-iid result speaks to the robustness of the method.
For the Shakespeare dataset, the non-iid split works especially well, and this split is representative of the data distributions we see in real life (the amount of text varies greatly from character to character). The paper speculates that some clients have relatively large local datasets, which makes the extra local training particularly valuable.
3.5 FedSGD vs FedAvg
As can be seen, FedAvg not only reduces the number of communication rounds but also improves test accuracy (the blue solid line is FedSGD). The paper speculates that model averaging produces a regularization benefit similar to dropout.
3.6 Can clients over-optimize locally?
For very large numbers of local iterations, FedAvg may stagnate or diverge. This result suggests that for some models, especially in the later stages of convergence, it may be beneficial to reduce the amount of local computation per round (i.e., decrease E or increase B), much as one decays the learning rate.
3.7 CIFAR experiment
The dataset contains 50,000 training examples and 10,000 test examples, divided equally among 100 clients, each receiving 500 training examples and 100 test examples.
The model used is two convolutional layers + two fully connected layers + one linear transformation layer.
Preprocessing includes cropping images to \(24\times 24\), random left-right flips, and adjustments of contrast and brightness.
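One way to express this preprocessing with torchvision (illustrative only; the original experiments did not use this library, and the jitter strengths here are arbitrary):

```python
from torchvision import transforms

# Random 24x24 crop, horizontal flip, and brightness/contrast jitter, as described above.
cifar_train_transform = transforms.Compose([
    transforms.RandomCrop(24),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
])
```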
SGD on a single machine compared to FedSGD and FedAvg on 10 clients:
Existing models already achieve much higher test accuracy on the CIFAR classification task; only about 80% accuracy is targeted here because the goal of this paper is to evaluate the FedAvg method, not to improve CIFAR test accuracy.
Effects of different learning rates:
3.8 Large-scale LSTM experiments
To demonstrate that the methods work on real-world problems, experiments were also conducted on the large-scale task of predicting the next word.
The training dataset consists of 10 million public posts from a large social network. Posts were grouped by author, giving over 500,000 clients. The paper limits each client to at most 5,000 words and evaluates on data from 10,000 authors.
Link to original article: https://arxiv.org/abs/1602.05629