Implementing Federated Learning in Android

Deep Learning has gained intense attention in recent years, and ongoing research continues to broaden its horizons. However, Deep Learning algorithms are notorious for requiring large amounts of data to perform well.

At the same time, privacy concerns have grown as companies use questionable means to collect data for their algorithms. As a result, many people are unwilling to share their data to improve these models.

Researchers and organisations are also often unable to share data with each other, which makes crowdsourcing data difficult and increases costs.

Federated Learning aims to address these growing privacy concerns while still allowing data collection to be crowdsourced.

What is Federated Learning?

The process of Federated Learning consists of the following cycle:
  1. The model on each user device is trained locally on that user's data (on-device).

  2. The updated weights from the user devices are sent to the server, where they are averaged before being stored.

  3. These averaged weights are then used to update the global model, which is pushed back to the user devices.

This achieves the following aims:
  1. Sensitive user data (images, audio, etc.) stays on the device and is never uploaded to the server, preserving the user's privacy.
  2. The number of client devices can be increased easily, which makes data collection straightforward.

Implementing Federated Learning in Android

This tutorial will guide you through the process of implementing Federated Learning with Android devices as the clients. It is divided into 5 parts:

  1. Creating a graph and checkpoint for the model

  2. Inference and Training on the Android Device

  3. Extracting weights from the on-device model

  4. Uploading Weights to the Server

  5. Performing Federated Averaging on the server

A. Creating a graph and checkpoint for the model

To run inference and training on Android, we'll need the model's metagraph and, optionally, a checkpoint file.

For this tutorial, we'll use a simple linear regressor.

Let's generate the graph for the model:

              #Python 3

              1. import tensorflow as tf

              2. x = tf.placeholder(tf.float32, name='input')
              3. y_ = tf.placeholder(tf.float32, name='target')

              4. W = tf.Variable(5., name='W')
              5. b = tf.Variable(3., name='b')

              6. y = x * W + b
              7. y = tf.identity(y, name='output')

              8. loss = tf.reduce_mean(tf.square(y - y_))
              9. optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
              10. train_op = optimizer.minimize(loss, name='train')

              11. init = tf.global_variables_initializer()

              # Creating a tf.train.Saver adds operations to the graph to save and
              # restore variables from checkpoints.

              12. saver_def = tf.train.Saver().as_saver_def()

              13. with open('graph.pb', 'wb') as f:
              14.   f.write(tf.get_default_graph().as_graph_def().SerializeToString())
            

In line 14, tf.get_default_graph().as_graph_def() returns the GraphDef for the graph, and SerializeToString() converts it to bytes before it is written to graph.pb.

To generate the checkpoint files, use tf.train.Saver():

              saver = tf.train.Saver()
              with tf.Session() as sess:
                  sess.run(init)
                  # ... run training steps here ...
                  # your_path: directory in which to store the checkpoint
                  saver.save(sess, your_path + "/checkpoint_name.ckpt")
            

B. Inference and Training on Android

(i) Loading the model in the Android Application

1. We use TensorFlow's Java API to enable on-device training and inference. First, include the TensorFlow Android package (which bundles the inference library and the Java API) in your application by adding the following to your app's build.gradle:

                    dependencies {
                        // ... other dependencies ...
                        implementation 'org.tensorflow:tensorflow-android:1.13.1'
                    }
                    // Replace 1.13.1 with the latest version of the package
              

2. Then, import the graph in the application. To do this:

Create a variable of class org.tensorflow.Graph:

Graph graph = new Graph();

Place the graph.pb file generated earlier in the assets folder and read it into a byte[] array; let the array's name be graphdef.
Now, load the graph from graphdef:

graph.importGraphDef(graphdef);
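As a reference, here is a minimal sketch of how graph.pb might be read from the assets folder and how the Session (referred to as sess in the following steps) can be created from the graph. The context variable (an Android Context) and the buffered-copy loop are assumptions of this sketch, not a prescribed implementation:

Graph graph = new Graph();

try (InputStream is = context.getAssets().open("graph.pb");
     ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
    byte[] buffer = new byte[4096];
    int bytesRead;
    while ((bytesRead = is.read(buffer)) != -1) {
        bos.write(buffer, 0, bytesRead);
    }
    // graphdef holds the serialized GraphDef read from assets
    byte[] graphdef = bos.toByteArray();
    graph.importGraphDef(graphdef);
} catch (IOException e) {
    e.printStackTrace();
}

// Create the Session that the later snippets refer to as `sess`
Session sess = new Session(graph);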

3. To load the checkpoint, place the checkpoint files on the device and create a string Tensor holding the path to the checkpoint prefix:

Tensor<String> checkpointPrefix = org.tensorflow.Tensors.create("path/to/checkpoint_name.ckpt");

4. Now, load the checkpoint by running the restore checkpoint op in the graph:

sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/restore_all").run();

5. Alternatively, if no checkpoint is available, initialize the variables by running the init op:

sess.runner().addTarget("init").run();

(ii) Performing Inference using the model

First, create an input tensor:

Tensor input = Tensor.create(features);

Then perform inference by:

Tensor op_tensor = sess.runner().feed("input",input).fetch("output").run().get(0).expect(Float.class);
              

Copy this output into a pre-allocated float[] array, output, whose length matches the output tensor:

op_tensor.copyTo(output);
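Putting these pieces together, a minimal end-to-end inference call for this linear regressor could look like the sketch below. The example feature values, and the assumption that the output has the same length as the input, are illustrative only:

// Minimal sketch: one inference pass through the linear regressor
float[] features = {1.0f, 2.0f, 3.0f};        // illustrative input values
float[] output = new float[features.length];  // buffer matching the output shape

try (Tensor<Float> input = Tensors.create(features)) {
    Tensor<Float> op_tensor = sess.runner()
            .feed("input", input)
            .fetch("output")
            .run()
            .get(0)
            .expect(Float.class);
    op_tensor.copyTo(output);
    op_tensor.close();
}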
              

(iii) Training the model

First, create the tensors for the input and the labels:

Tensor x_train = Tensor.create(features);
Tensor y_train = Tensor.create(label);

Then, run the 'train' operation defined in the graph (via optimizer.minimize(loss, name='train')) to perform a training step:

sess.runner().feed("input", x_train).feed("target", y_train).addTarget("train").run();
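In practice, this training step is run repeatedly over the local data. A minimal sketch follows; the epoch count and the float[] arrays features and label are illustrative assumptions:

// Minimal sketch: repeat the training step over the local data
int numEpochs = 10;  // illustrative value
try (Tensor<Float> x_train = Tensors.create(features);
     Tensor<Float> y_train = Tensors.create(label)) {
    for (int epoch = 0; epoch < numEpochs; epoch++) {
        sess.runner()
            .feed("input", x_train)
            .feed("target", y_train)
            .addTarget("train")
            .run();
    }
}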

C. Extracting weights from the on-device model

Extract the weights of the model by:

List<Tensor<?>> w1 = sess.runner().fetch("W").run();
List<Tensor<?>> b1 = sess.runner().fetch("b").run();

Now, save these weights on the device so they can be uploaded to the server; one possible approach is sketched below.
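In this sketch, the file name weights.bin, the use of context.getFilesDir(), and the plain-float encoding are assumptions for illustration, not the only option:

// Sketch: pull the scalar values out of the fetched tensors and write them
// to a file that the upload code can read later.
float W = w1.get(0).expect(Float.class).floatValue();
float b = b1.get(0).expect(Float.class).floatValue();

File weightsFile = new File(context.getFilesDir(), "weights.bin");
try (DataOutputStream out = new DataOutputStream(new FileOutputStream(weightsFile))) {
    out.writeFloat(W);
    out.writeFloat(b);
} catch (IOException e) {
    e.printStackTrace();
}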

D. Uploading the Weights to the Server

To enable Federated Averaging, we'll need a server that can receive the weights. In this tutorial, the server is hosted on Heroku and the server-side code is written in Python using Flask.

The following Flask route handles weight uploads on the server:

import io
import pickle

import flask
from firebase_admin import storage

# Assumes `app = flask.Flask(__name__)` and firebase_admin.initialize_app(...)
# are set up elsewhere in the server code.

@app.route("/upload", methods=['POST'])
def upload():
    if flask.request.method == "POST":
        print("Uploading File")
        if flask.request.files["file"]:
            weights = flask.request.files["file"].read()
            weights_stream = io.BytesIO(weights)
            bucket = storage.bucket()
            # Saving a copy of the raw upload at the server
            print("Saving at Server")
            with open("delta.bin", "wb") as f:
                f.write(weights_stream.read())
            # Preprocessing data before upload. The file sent to Firebase is named "Weights.bin"
            with open("Weights.bin", "wb") as f:
                pickle.dump(weights, f)
            # Uploading the file to Firebase
            print("Starting upload to Firebase")
            with open("Weights.bin", "rb") as f:
                blob = bucket.blob('weight1')
                blob.upload_from_file(f)
            print("File Successfully Uploaded to Firebase")
            return "File Uploaded\n"
        else:
            print("File not found")
            return "File not found\n", 400
  

The following function uploads weights from the Android Application to the server:

    String lineEnd = "\r\n";
    String twoHyphens = "--";
    String boundary = "*****";
    private void uploadWeight(HttpURLConnection conn) {
        try {
            conn.setDoOutput(true);
            conn.setRequestMethod("POST");
            conn.setRequestProperty("Connection", "Keep-Alive");
            conn.setRequestProperty("ENCTYPE", "multipart/form-data");
            conn.setRequestProperty("Content-Type",
                    "multipart/form-data;boundary=" + boundary);
            conn.setRequestProperty("file", "weights");

            DataOutputStream dos = new DataOutputStream(conn.getOutputStream());
            dos.writeBytes(twoHyphens + boundary + lineEnd);
            dos.writeBytes("Content-Disposition: form-data; name=\"file\";filename=\"" + "weights.bin" + "\"" + lineEnd);
            dos.writeBytes(lineEnd);

            // Write the serialized weights as the body of the multipart file part
            dos.write(readArrayFromDevice());

            dos.writeBytes(lineEnd);
            dos.writeBytes(twoHyphens + boundary + twoHyphens + lineEnd);
            dos.flush();
            dos.close();

            int serverResponseCode = conn.getResponseCode();
            String serverResponseMessage = conn.getResponseMessage();
            Log.i("Response Message: ", serverResponseMessage);
            Log.i("Response Code: ", String.valueOf(serverResponseCode));
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            conn.disconnect();
        }
    }

where readArrayFromDevice() reads the serialized weights stored on the device into a byte[]. A sketch of such a helper is shown below.
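A possible implementation of this helper, assuming the weights were written to weights.bin in the app's internal storage as in the earlier sketch, is:

// Sketch: read the serialized weights file back into a byte[]
private byte[] readArrayFromDevice() throws IOException {
    File weightsFile = new File(context.getFilesDir(), "weights.bin");
    byte[] bytes = new byte[(int) weightsFile.length()];
    try (DataInputStream in = new DataInputStream(new FileInputStream(weightsFile))) {
        in.readFully(bytes);
    }
    return bytes;
}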

E. Performing Federated Averaging on the Server

Now, let's average the weights. To get started, extract the weight pairs (W, b) from each client's uploaded file.

Let the weights be (W1, W2, W3, ..., Wn) and (b1, b2, b3, ..., bn).

Then, the new averaged weights are:

Wavg = (W1 + W2 + W3 + ... + Wn) / n

and

bavg = (b1 + b2 + b3 + ... + bn) / n
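As a concrete sketch, if the (W, b) pairs from the clients have been deserialized into a Python list (the list name and the example values below are assumptions for illustration), the averaging step is simply:

import numpy as np  # numpy is used just below to shape the averages for model.set_weights

# One (W, b) pair per client, deserialized from the uploaded weight files
client_weights = [(5.2, 3.1), (4.8, 2.9), (5.0, 3.0)]  # illustrative values

n = len(client_weights)
Wavg = sum(W for W, _ in client_weights) / n
bavg = sum(b for _, b in client_weights) / n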

Put these averaged values into a list of numpy arrays whose shapes match the model's kernel and bias (a (1, 1) kernel and a (1,) bias for this single-feature model):

w_ls = [np.array([[Wavg]], dtype=np.float32), np.array([bavg], dtype=np.float32)]

Now, recreate the model on the server:

# features_shape is 1 for this single-feature linear regressor
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1, input_shape=(features_shape,))])
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.01), loss=tf.keras.losses.mean_squared_error)

Set the list of weights as the model's weights and generate the checkpoint file for it:

model.set_weights(w_ls)
saver = tf.train.Saver()
sess = tf.keras.backend.get_session()
save_path = saver.save(sess, "model.ckpt")

Send these checkpoint files back to the Android application and store them in the folder referenced by the checkpoint prefix. The updated weights will then be loaded when the application next restores the checkpoint (step B.4 above).
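One way to get the updated checkpoint back to the clients is to serve it from the same Flask server. The route name and the whitelist of file names below are assumptions for illustration:

@app.route("/checkpoint/<name>", methods=['GET'])
def checkpoint(name):
    # Serve only the known checkpoint files produced by saver.save above
    allowed = {"checkpoint", "model.ckpt.index",
               "model.ckpt.data-00000-of-00001", "model.ckpt.meta"}
    if name not in allowed:
        flask.abort(404)
    return flask.send_file(name, as_attachment=True)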

With this, one full round of the Federated Learning process is complete.