Collaborative mode is used for actual distributed federated learning across multiple devices or organizations.
Step 1: Server Setup
On the server machine, create a script run_server.py:
Run the server:
Step 2: Client Setup
On each client machine, create a script run_client.py:
Run the client on each participating machine:
5.3 Example: Image Classification with MNIST
Here is a complete image-classification example based on the MNIST setup and adapted for lung cancer detection. Use it when running FedLearn in Collaborative Mode (Section 5.2):
This example shows how to use Cifer's FedLearn for a typical image classification task, following the MNIST workflow. It covers data loading, model definition, helpers for flattening and restoring model weights (used when exchanging weights during federated learning), local training, and final model evaluation.
from cifer import *
from imutils import paths
import numpy as np
import os
import cv2
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense
from tensorflow.keras.optimizers import SGD
from keras.utils import to_categorical
import random
def load_mnist_byCid(cid, data_dir="data/standalone", image_size=(256, 256)):
    """Load the grayscale image dataset used by one client.

    Images are discovered recursively under ``data_dir`` with
    ``imutils.paths.list_images``. Each image is read in grayscale, resized
    with ``cv2.resize`` to ``image_size`` and scaled to [0, 1]. Its label is
    the name of the folder that directly contains it, e.g.
    ``data/standalone/<label>/image.png``.

    Args:
        cid: Client id. It is not used to select data yet: every client reads
            ``data_dir``. Pass a per-client ``data_dir`` (for example
            ``f"data/client_{cid}"``) to give each client its own shard.
        data_dir: Root folder to scan for images.
        image_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        A ``(data, labels)`` tuple. ``data`` has shape
        ``(n_images, height, width, 1)`` with values in [0, 1]. ``labels`` is
        an array of folder-name strings, in the same order as ``data``.
    """
    data = []
    labels = []
    for img_path in paths.list_images(data_dir):
        img_grayscale = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img_grayscale is None:
            # cv2.imread returns None instead of raising for unreadable or
            # corrupt files; skip them rather than crash inside cv2.resize.
            print(f"Warning: could not read image {img_path}; skipping it.")
            continue
        img_resized = cv2.resize(img_grayscale, image_size)
        # Add a trailing channel axis so Conv2D sees (height, width, 1),
        # and normalize pixel values to [0, 1].
        data.append(np.expand_dims(img_resized, axis=-1) / 255.0)
        # The label is the name of the image's parent folder.
        labels.append(img_path.split(os.path.sep)[-2])
    return np.array(data), np.array(labels)
def define_model(input_shape, num_classes):
    """Build and compile a small convolutional classifier.

    Layers, in order:
      1. 3x3 convolution with 32 filters, ReLU, He-uniform initialization.
      2. 2x2 max pooling.
      3. Flatten.
      4. 100-unit dense layer, ReLU, He-uniform initialization.
      5. Softmax output over ``num_classes``.

    The model is compiled with SGD (learning rate 0.01, momentum 0.9),
    categorical cross-entropy loss and an accuracy metric. Targets must
    therefore be one-hot encoded.

    NOTE: with the default 'valid' padding, a 256x256x1 input flattens to
    127*127*32 = 516,128 features. The first Dense layer alone then holds
    roughly 51.6M weights. Consider a smaller input size or extra pooling
    if memory is tight.
    """
    stack = [
        Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform',
               input_shape=input_shape),
        MaxPool2D((2, 2)),
        Flatten(),
        Dense(100, activation='relu', kernel_initializer='he_uniform'),
        Dense(num_classes, activation='softmax'),
    ]
    model = Sequential(stack)
    model.compile(
        optimizer=SGD(learning_rate=0.01, momentum=0.9),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
def setWeightSingleList(weights):
    """Flatten a sequence of weight arrays into one flat Python list.

    Each array is flattened in C (row-major) order, and the pieces are
    concatenated in sequence order. This is the layout ``reshapeWeight``
    splits back apart.

    Args:
        weights: Iterable of numpy arrays (or array-likes).

    Returns:
        A flat list of Python scalars, or ``[]`` when ``weights`` is empty.
    """
    flat_parts = [np.ravel(w) for w in weights]
    if not flat_parts:
        # np.concatenate raises ValueError on an empty sequence.
        return []
    return np.concatenate(flat_parts).tolist()
def reshapeWeight(server_weight, client_weight):
    """Split a flat weight vector into arrays shaped like ``client_weight``.

    This is the inverse of flattening a list of arrays in C order.
    Consecutive chunks of ``server_weight`` are taken in order and reshaped
    to the shape of the corresponding template array.

    Args:
        server_weight: Flat sequence of numbers (list or 1-D array).
        client_weight: Sequence of numpy arrays whose shapes define the
            layout. Only their ``shape``/``size`` are read.

    Returns:
        A list of new numpy arrays, one per entry of ``client_weight``.

    Raises:
        ValueError: If ``server_weight`` does not contain exactly as many
            values as the template arrays hold in total. A mismatch means
            the two models have different architectures.
    """
    flat = np.asarray(server_weight)
    templates = list(client_weight)
    # ndarray.size is always an int (1 for 0-d arrays). np.prod(()) would
    # return the float 1.0, which cannot be used as a slice bound.
    total = sum(template.size for template in templates)
    if len(flat) != total:
        raise ValueError(
            f"server_weight has {len(flat)} values, but client_weight "
            f"layers need {total}"
        )
    reshaped = []
    offset = 0  # walk an index instead of re-slicing the remainder (O(n))
    for template in templates:
        count = template.size
        # np.array copies, so the result does not alias server_weight.
        reshaped.append(np.array(flat[offset:offset + count]).reshape(template.shape))
        offset += count
    return reshaped
def createRandomClientList(clients_dictionary, n_round_clients):
    """Choose ``n_round_clients`` distinct client keys at random for a round.

    The draw uses ``random.sample`` over the dictionary's keys, in iteration
    order, without replacement. It relies on the module-level ``random``
    generator, so ``random.seed`` makes it reproducible. ``random.sample``
    raises ValueError when more clients are requested than exist.
    """
    candidate_ids = [*clients_dictionary.keys()]
    chosen = random.sample(candidate_ids, n_round_clients)
    return chosen
def _encode_labels(raw_labels):
    """Convert folder-name labels into integer class ids."""
    try:
        # Numeric folder names ("0".."9" for MNIST) keep their own values.
        return raw_labels.astype(int)
    except ValueError:
        # Named folders (e.g. lung-cancer classes) get ids in sorted order.
        class_names, class_ids = np.unique(raw_labels, return_inverse=True)
        print(f"Label mapping: {dict(enumerate(class_names.tolist()))}")
        return class_ids


def _shuffled_split(data, labels, test_fraction, seed):
    """Shuffle samples and labels together, then split into train/test."""
    order = np.random.default_rng(seed).permutation(len(data))
    data, labels = data[order], labels[order]
    split = int((1 - test_fraction) * len(data))
    # Keep at least one sample on each side so fit() and evaluate() get data.
    split = min(max(split, 1), len(data) - 1)
    return data[:split], data[split:], labels[:split], labels[split:]


def train_model(cid, num_classes, *, epochs=10, batch_size=32,
                test_fraction=0.2, seed=42):
    """Train and evaluate the CNN on one client's local images.

    Steps:
      1. Load the images with ``load_mnist_byCid``.
      2. One-hot encode the folder labels. Numeric folder names are used
         as-is; other names are mapped to ids in sorted order.
      3. Shuffle with a seeded permutation, then split into train and test
         sets. The loader returns images folder by folder, so an unshuffled
         split would leave whole classes out of training.
      4. Train, evaluate, and print the test loss and accuracy.

    Args:
        cid: Client id forwarded to ``load_mnist_byCid``.
        num_classes: Number of output classes. Every label id must be below it.
        epochs: Training epochs.
        batch_size: Mini-batch size.
        test_fraction: Fraction of samples held out for evaluation.
        seed: Seed for the shuffle, so splits are reproducible.

    Returns:
        The trained Keras model.

    Raises:
        ValueError: If fewer than two images were loaded.
    """
    data, raw_labels = load_mnist_byCid(cid)
    if len(data) < 2:
        raise ValueError(
            f"Client {cid}: need at least 2 images to train and evaluate, "
            f"found {len(data)}."
        )
    labels = to_categorical(_encode_labels(raw_labels), num_classes=num_classes)
    # Derive the input shape from the loaded images so it always matches the
    # size the loader resized them to.
    input_shape = data.shape[1:]
    model = define_model(input_shape, num_classes)
    x_train, x_test, y_train, y_test = _shuffled_split(data, labels, test_fraction, seed)
    model.fit(x_train, y_train, validation_data=(x_test, y_test),
              epochs=epochs, batch_size=batch_size)
    loss, accuracy = model.evaluate(x_test, y_test)
    print(f"Test Loss: {loss}, Test Accuracy: {accuracy}")
    return model
# Script entry point: train one client's model on its local data.
if __name__ == "__main__":
    client_id, num_classes = 0, 10  # client to train; 10 classes as in MNIST
    train_model(client_id, num_classes=num_classes)
bash
python run_fedlearn_local.py
python
import json
import cifer.fedlearn as fedlearn
import cifer.server as server
def load_config(config_path='config.json'):
    """Load the JSON configuration file.

    Args:
        config_path: Path to the JSON file. A relative path is resolved
            against the current working directory.

    Returns:
        The parsed JSON value (normally a dict). Returns ``None`` if the file
        is missing, cannot be read, or is not valid UTF-8 JSON; the problem
        is printed.
    """
    try:
        # JSON text is UTF-8 by specification; don't rely on the locale.
        with open(config_path, 'r', encoding='utf-8') as config_file:
            return json.load(config_file)
    except FileNotFoundError:
        print(f"Configuration file {config_path} not found.")
    except json.JSONDecodeError as err:
        # err includes the line and column of the syntax error.
        print(f"Error decoding JSON from the configuration file: {err}")
    except (OSError, UnicodeDecodeError) as err:
        # e.g. permission denied, the path is a directory, or bytes that
        # are not valid UTF-8.
        print(f"An error occurred: {err}")
    return None
def main():
    """Load ``config.json`` and start the FedLearn server with it.

    If ``load_config`` returns a falsy value (None after a load error, or an
    empty configuration), nothing is started.
    """
    config = load_config()
    if not config:
        return
    # Build the FedLearn object from the configuration and launch its server.
    fedlearn.FedLearn(config).start_server()


if __name__ == "__main__":
    main()