W5XDE contains two main classes: CentralServer and TrainingNode.
The CentralServer class is the central hub for distributed training, handling batch distribution and gradient aggregation.
model (required): PyTorch model to train
dataset (required): PyTorch dataset for training
batch_size (int): Size of training batches. Default: 16
ip (str): IP address to bind server. Default: "localhost". Use "0.0.0.0" for all interfaces
port (int): Port number. Default: 5555
checkpoint_dir (str): Directory for model checkpoints. Default: "checkpoints"
checkpoint_interval (int): Minutes between checkpoints. Default: 5
secure (bool): Enable encrypted communication. Default: False
queue_size (int): Size of batch queue. Default: 1000
start(): Starts the server (blocking call)
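
The parameters above can be put together roughly as follows. This is a minimal sketch, not the library's reference usage: the import path `w5xde` is an assumption (the text names the class but not its module), and `server_options` and `run_server` are hypothetical helpers introduced only for illustration. The keyword names and default values are the ones listed above.

```python
def server_options(public=False):
    """Return CentralServer keyword arguments, starting from the documented defaults."""
    options = {
        "batch_size": 16,
        "ip": "localhost",
        "port": 5555,
        "checkpoint_dir": "checkpoints",
        "checkpoint_interval": 5,  # minutes between checkpoints
        "secure": False,
        "queue_size": 1000,
    }
    if public:
        # Bind to all interfaces so nodes on other machines can reach the server.
        options["ip"] = "0.0.0.0"
    return options


def run_server(model, dataset, public=False):
    """Create a CentralServer and run it (hypothetical helper)."""
    # Assumed import path; adjust to wherever CentralServer actually lives.
    from w5xde import CentralServer

    server = CentralServer(model=model, dataset=dataset, **server_options(public))
    server.start()  # blocking call: returns only when the server stops
```

Because start() blocks, a script that needs to do other work alongside the server would have to run it in a separate process or thread.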
The TrainingNode class connects to the CentralServer to receive batches and send gradients.
model (required): PyTorch model matching server's model
server_address (tuple): Server (ip, port). Default: ('localhost', 5555)
secure (bool): Enable encrypted communication. Default: False
collect_metrics (bool): Enable performance metrics. Default: False
compress_gradients (bool): Enable gradient compression. Default: False
batch_gradients (bool): Batch gradients before sending. Default: True
train(loss_callback=None, network_callback=None): Start training
  loss_callback(loss_value: float, batch_id: str): Track training loss
  network_callback(sent_bytes: int, received_bytes: int, comp_time: float, net_time: float, orig_size: int, comp_size: int): Track network performance
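
As an illustration of the two callback signatures, the sketch below records whatever train() reports and derives a mean loss and per-call compression ratios. `TrainingStats` and `run_node` are hypothetical names, the import path `w5xde` is an assumption, and enabling `collect_metrics` for the network callback is a guess rather than a documented requirement.

```python
class TrainingStats:
    """Collects values passed to loss_callback and network_callback."""

    def __init__(self):
        self.losses = []   # (batch_id, loss_value) pairs
        self.network = []  # one dict per network_callback call

    def on_loss(self, loss_value, batch_id):
        # Matches loss_callback(loss_value: float, batch_id: str)
        self.losses.append((batch_id, float(loss_value)))

    def on_network(self, sent_bytes, received_bytes, comp_time, net_time,
                   orig_size, comp_size):
        # Matches the six-argument network_callback signature
        self.network.append({
            "sent_bytes": sent_bytes,
            "received_bytes": received_bytes,
            "comp_time": comp_time,
            "net_time": net_time,
            "orig_size": orig_size,
            "comp_size": comp_size,
        })

    def mean_loss(self):
        """Average of all recorded losses, or None if nothing was recorded."""
        if not self.losses:
            return None
        return sum(loss for _, loss in self.losses) / len(self.losses)

    def compression_ratios(self):
        """orig_size / comp_size for each record with a non-zero comp_size."""
        return [r["orig_size"] / r["comp_size"]
                for r in self.network if r["comp_size"]]


def run_node(model, stats, server_address=("localhost", 5555)):
    """Connect a TrainingNode and train with both callbacks (hypothetical helper)."""
    # Assumed import path; adjust to wherever TrainingNode actually lives.
    from w5xde import TrainingNode

    node = TrainingNode(model=model, server_address=server_address,
                        collect_metrics=True)  # assumed to be needed for metrics
    node.train(loss_callback=stats.on_loss, network_callback=stats.on_network)
```

Storing each network report as its own record avoids assuming whether the reported byte counts are per-call or cumulative.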