Hey Ovidiu!
So it seems what you are looking for is something more dynamic, beyond compilation.
For some more context: the way Shardy is being integrated into the XLA compilation pipeline, it's meant to be a drop-in replacement for GSPMD, which deals with compiling static graphs. We will not be restructuring any other parts of PJRT/XLA/etc. as part of this integration. Also note that neither GSPMD nor Shardy solves anything with regard to dynamic distributed execution.
To handle the case of starting with N devices and then dynamically growing to 2*N devices, either you'd handle it yourself in Python or some sort of runtime would need to handle it - I'm not aware of any runtime that does this, but I may be wrong!
You can maybe get around this in Python in a few ways, but it totally depends on how the sharding strategies are set up. Let's keep it simple and say you are just doing data parallelism.
If you start with N devices, let's say you increase by k devices at a time (where N % k == 0). Then what you can do is `jax.jit` a program on a mesh with k devices, and run more and more copies of that program asynchronously: first N/k of them, then N/k + 1, and so on up to 2N/k once you've reached 2N devices.
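Here's a minimal sketch of that first option, assuming pure data parallelism. The names (`train_step`, `make_instance`, `k`) and the shapes are just placeholders I made up for illustration; the point is that each group of k devices gets its own mesh and its own jitted copy of the program, and JAX's async dispatch lets the copies run concurrently:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Devices per program instance. Would be >1 in practice; kept small here so
# the sketch runs even on a machine with few devices. Assumes the total
# device count is a multiple of k.
k = 1

def train_step(params, batch):
    # Stand-in for a real data-parallel training step.
    return params - 0.1 * jnp.mean(batch, axis=0)

def make_instance(group_devices):
    """Build a jitted step bound to one group of k devices."""
    mesh = Mesh(np.array(group_devices), axis_names=("data",))
    batch_sharding = NamedSharding(mesh, P("data"))  # split the batch over the group
    replicated = NamedSharding(mesh, P())            # replicate params within the group
    step = jax.jit(train_step,
                   in_shardings=(replicated, batch_sharding),
                   out_shardings=replicated)
    return step, replicated, batch_sharding

# One instance per group of k devices currently available.
devices = jax.devices()
groups = [devices[i:i + k] for i in range(0, len(devices), k)]
instances = [make_instance(g) for g in groups]

params = jnp.zeros((4,))
batch = jnp.ones((k * 8, 4))  # per-instance batch, sized to divide evenly over k devices

# Dispatch is asynchronous in JAX, so the instances run concurrently.
outs = []
for step, replicated, batch_sharding in instances:
    p = jax.device_put(params, replicated)
    b = jax.device_put(batch, batch_sharding)
    outs.append(step(p, b))
outs = [jax.block_until_ready(o) for o in outs]
```

When another k devices come online, you'd build one more instance for them and start feeding it batches as well.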
Another option is to pre-compile various programs and substitute them in. Say we have N devices and are again doing just data parallelism. You compile one program for N devices, then another for N+k, one for N+2k, ..., up to N+N = 2N devices.
You then shard your samples across however many devices you currently have, and look up the program you need in a number_of_devices -> executable Python dictionary to run it.
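And a rough sketch of that second option, again with placeholder names (`build_program`, `train_step`, `k`): build one executable per device count ahead of time, keep them in a dict keyed by device count, and pick the right one at run time.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def train_step(params, batch):
    # Stand-in for a real data-parallel training step.
    return params - 0.1 * jnp.mean(batch, axis=0)

def build_program(devices):
    """jit the step against a data-parallel mesh over the given devices."""
    mesh = Mesh(np.array(devices), axis_names=("data",))
    batch_sharding = NamedSharding(mesh, P("data"))
    replicated = NamedSharding(mesh, P())
    step = jax.jit(train_step,
                   in_shardings=(replicated, batch_sharding),
                   out_shardings=replicated)
    return step, replicated, batch_sharding

all_devices = jax.devices()
k = 1  # growth step; would be >1 in practice

# number_of_devices -> executable lookup table, built ahead of time for
# every device count you expect to see (here: every count reachable now).
programs = {n: build_program(all_devices[:n])
            for n in range(k, len(all_devices) + 1, k)}

# At run time, pick the program matching however many devices you have and
# shard the samples accordingly.
current = len(all_devices)
step, replicated, batch_sharding = programs[current]
params = jax.device_put(jnp.zeros((4,)), replicated)
batch = jax.device_put(jnp.ones((current * 8, 4)), batch_sharding)
out = step(params, batch)
```

One caveat: if you need executables for device counts you don't have attached yet, you'd need some form of ahead-of-time compilation against the target topology; otherwise you compile each entry once those devices actually show up.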
If we are considering model/tensor parallelism on top of data parallelism, then the first option may not work, but the second one still should!
Hopefully that clarifies it. Let me know if you have any more Qs!
- Bart