iterations=$($env_python-c"from math import floor; print($batches_per_epoch * $global_epochs) if $comm_rounds_per_global_epoch >= $batches_per_epoch else print($global_epochs * $comm_rounds_per_global_epoch)")
echo iterations: $iterations
# calculating the number of batches each user/proc uses per communication step (The actual number may be a float, which we round down)
batches_per_comm_round=$($env_python-c"from math import floor; x = floor($batches_per_epoch / $comm_rounds_per_global_epoch); print(1 if x==0 else x)")
# since the batches per communication round were rounded down we need to change the number of iterations to reflect that
new_iterations=$($env_python-c"from math import floor; tmp = floor($batches_per_epoch / $comm_rounds_per_global_epoch); x = 1 if tmp == 0 else tmp; y = floor((($batches_per_epoch / $comm_rounds_per_global_epoch)/x)*$iterations); print($iterations if y<$iterations else y)")
echo batches per communication round: $batches_per_comm_round