Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
# Container image flags. The server image has no default value; the sidecar
# image defaults to a pinned, dated build.
_SERVER_IMAGE = flags.DEFINE_string(
    name="server_image",
    default=None,
    help="Full path to the server Docker image",
)
_SIDECAR_IMAGE = flags.DEFINE_string(
    name="sidecar_image",
    default="us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:20260423-python_3.12-jax_0.10.0",
    help="Full path to the sidecar Docker image",
)
# TPU generation; restricted to the enumerated values below.
_TPU_TYPE = flags.DEFINE_enum(
    name="tpu_type",
    default="v6e",
    enum_values=["v5e", "v5p", "v6e", "tpu7x"],
    help="TPU type",
)
Expand All @@ -52,6 +57,7 @@
False,
"If true, only print the generated YAML without deploying.",
)
# Directory path supplied to the deployment template context under the
# "SIDECAR_DIR" key.
_SIDECAR_DIR = "/tmp/sidecar_dir"


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -191,6 +197,7 @@ def run_deployment(
jobset_name,
gcs_bucket,
server_image,
sidecar_image,
template_file,
dry_run,
deploy_func: Callable[[dict[str, Any]], None] = deploy_jobset,
Expand All @@ -202,6 +209,8 @@ def run_deployment(
context = {
"JOBSET_NAME": jobset_name,
"SERVER_IMAGE": server_image,
"SIDECAR_IMAGE": sidecar_image,
"SIDECAR_DIR": _SIDECAR_DIR,
"GCS_SCRATCH_LOCATION": gcs_bucket,
"NUM_SLICES": num_slices,
"INSTANCE_TYPE": f"{tpu_config.instance_prefix}:{topology}",
Expand Down Expand Up @@ -246,6 +255,7 @@ def main(argv: Sequence[str]) -> None:
jobset_name=_JOBSET_NAME.value,
gcs_bucket=_GCS_BUCKET.value,
server_image=server_image,
sidecar_image=_SIDECAR_IMAGE.value,
template_file=_TEMPLATE_FILE.value,
dry_run=_DRY_RUN.value,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# JAX version used both to select the base image tag and to pin jax/jaxlib
# below. Declared before FROM so it can be used in the image reference.
ARG JAX_VERSION=0.10.0
# Use the JAX image with the custom-built sidecar as the base.
FROM us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:20260423-python_3.12-jax_${JAX_VERSION}

# An ARG declared before FROM is out of scope after it; redeclare (without a
# value) so the default above is visible to the RUN instructions below.
ARG JAX_VERSION

# Set the working directory
WORKDIR /app

# 1. Upgrade pip and build tools. The cache mount keeps uv's download cache
#    across builds without baking it into an image layer.
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --upgrade pip setuptools wheel

# 2. Copy ONLY requirements first to leverage Docker layer caching.
COPY maxtext/src/dependencies/requirements/base_requirements/requirements.txt ./requirements.txt

# Install the requirements (reusing the same uv cache mount), then force
# jax/jaxlib to the version matching the base image in case the requirements
# resolved a different one.
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install -r requirements.txt && \
    uv pip install --upgrade "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}"

# 3. Copy ONLY the actual MaxText source code
COPY maxtext/src /app/maxtext/src

# Ensure MaxText src and Orbax are in PYTHONPATH.
# The ${VAR:+...} form appends the inherited PYTHONPATH only when it is set and
# non-empty; a bare ":$PYTHONPATH" would leave a trailing colon, which Python
# treats as an empty entry meaning "current directory".
# NOTE(review): /app/orbax/checkpoint is not copied by this Dockerfile —
# presumably it is provided by the base image or mounted at runtime; confirm.
ENV PYTHONPATH=/app/maxtext/src:/app/orbax/checkpoint${PYTHONPATH:+:${PYTHONPATH}}
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,26 @@ class ProxyOptions:
use_insecure_credentials: Whether to use insecure gRPC credentials for the
proxy server.
xla_flags: A list of XLA flags to pass to the proxy server.
sidecar: Whether to use the worker sidecar or not.
"""
use_insecure_credentials: bool = False
xla_flags: list[str] = dataclasses.field(default_factory=list)
sidecar: bool = False

@classmethod
def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions":
"""Creates a ProxyOptions object from a list of 'key:value' strings."""
use_insecure = False
use_sidecar = False
xla_flags = []
for option in options or []:
if ":" in option:
key, value = option.split(":", 1)
key_strip = key.strip().lower()
if key_strip == "use_insecure_credentials":
use_insecure = value.strip().lower() == "true"
elif key_strip == "sidecar":
use_sidecar = value.strip().lower() == "true"
elif key_strip == "xla_flags":
val_strip = value.strip()
if (
Expand All @@ -78,7 +83,11 @@ def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions":
if xla_flags:
validators.validate_xla_flags(xla_flags)

return cls(use_insecure_credentials=use_insecure, xla_flags=xla_flags)
return cls(
use_insecure_credentials=use_insecure,
xla_flags=xla_flags,
sidecar=use_sidecar,
)


def _deploy_pathways_proxy_server(
Expand Down Expand Up @@ -134,6 +143,9 @@ def _deploy_pathways_proxy_server(
)
proxy_args_str = "\n" + proxy_args_str

if proxy_options.sidecar:
proxy_args_str += "\n - --sidecar_name=external"

template = string.Template(yaml_template)
substituted_yaml = template.substitute(
PROXY_JOB_NAME=proxy_job_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ spec:
- --server_port=29005
- --resource_manager_address=$$(PATHWAYS_HEAD):29001
- --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
- --cloud_pathways_sidecar_shm_directory=${SIDECAR_DIR}
env:
- name: TPU_MIN_LOG_LEVEL
value: "0"
Expand Down Expand Up @@ -133,8 +134,51 @@ spec:
limits:
google.com/tpu: "${CHIPS_PER_VM}"
volumeMounts:
- mountPath: /tmp
name: shared-tmp
- name: shared-tmp
mountPath: /tmp
- name: cache
mountPath: /tmp/checkpoints
- name: sidecar-shared-memory
mountPath: ${SIDECAR_DIR}
initContainers:
- name: colocated-python-sidecar
image: ${SIDECAR_IMAGE}
imagePullPolicy: Always
env:
- name: GRPC_SERVER_ADDRESS
value: '''0.0.0.0:50051'''
- name: CLOUD_PATHWAYS_SIDECAR_SHM_DIRECTORY
value: ${SIDECAR_DIR}
- name: PYTHONUNBUFFERED
value: '1'
# --- High Verbosity Logging Variables ---
- name: LOGLEVEL
value: 'DEBUG'
- name: GLOG_minloglevel
value: '0' # 0 = INFO level base
- name: GLOG_v
value: '5' # Extreme verbosity for all C++ modules
- name: TF_CPP_MIN_LOG_LEVEL
value: '0'
- name: TF_CPP_MIN_VLOG_LEVEL
value: '5' # TF/XLA verbose logging
- name: TPU_MIN_LOG_LEVEL
value: '0'
- name: GLOG_vmodule
value: 'jax_array_handlers=5,type_handlers=5,tensorstore_utils=5'
# ----------------------------------------
ports:
- containerPort: 50051
protocol: TCP
resources: {}
restartPolicy: Always
volumeMounts:
- name: shared-tmp
mountPath: /tmp
- name: cache
mountPath: /tmp/checkpoints
- name: sidecar-shared-memory
mountPath: ${SIDECAR_DIR}
dnsPolicy: ClusterFirstWithHostNet
hostNetwork: true
nodeSelector:
Expand All @@ -146,6 +190,12 @@ spec:
hostPath:
path: /tmp
type: DirectoryOrCreate
- name: cache
csi:
driver: multitier-checkpoint.csi.storage.gke.io
- name: sidecar-shared-memory
emptyDir:
medium: Memory
startupPolicy:
startupPolicyOrder: InOrder
successPolicy:
Expand Down
Loading