Skip to content
35 changes: 29 additions & 6 deletions src/google/adk/artifacts/gcs_artifact_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from google.genai import types
from typing_extensions import override

from . import artifact_util
from ..errors.input_validation_error import InputValidationError
from .base_artifact_service import ArtifactVersion
from .base_artifact_service import BaseArtifactService
Expand Down Expand Up @@ -230,9 +231,21 @@ def _save_artifact(
content_type="text/plain",
)
elif artifact.file_data:
raise NotImplementedError(
"Saving artifact with file_data is not supported yet in"
" GcsArtifactService."
if not artifact.file_data.file_uri:
raise InputValidationError("Artifact file_data must have a file_uri.")
if artifact_util.is_artifact_ref(artifact):
if not artifact_util.parse_artifact_uri(artifact.file_data.file_uri):
raise InputValidationError(
f"Invalid artifact reference URI: {artifact.file_data.file_uri}"
)
# Store the URI as blob metadata; no content to upload.
blob.metadata = {
**(blob.metadata or {}),
"file_uri": artifact.file_data.file_uri,
}
blob.upload_from_string(
b"",
content_type=artifact.file_data.mime_type or None,
)
else:
raise InputValidationError(
Expand Down Expand Up @@ -263,15 +276,25 @@ def _load_artifact(
blob_name = self._get_blob_name(
app_name, user_id, filename, version, session_id
)
blob = self.bucket.blob(blob_name)
blob = self.bucket.get_blob(blob_name)
if blob is None:
return None

# If the artifact was saved as a file_data URI reference, restore it.
if blob.metadata and "file_uri" in blob.metadata:
return types.Part(
file_data=types.FileData(
file_uri=blob.metadata["file_uri"],
mime_type=blob.content_type or None,
)
)

artifact_bytes = blob.download_as_bytes()
if not artifact_bytes:
return None
artifact = types.Part.from_bytes(
return types.Part.from_bytes(
data=artifact_bytes, mime_type=blob.content_type
)
return artifact

def _list_artifact_keys(
self, app_name: str, user_id: str, session_id: Optional[str]
Expand Down
2 changes: 1 addition & 1 deletion src/google/adk/errors/input_validation_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class InputValidationError(ValueError):
"""Represents an error raised when user input fails validation."""

def __init__(self, message="Invalid input."):
def __init__(self, message: str = "Invalid input.") -> None:
"""Initializes the InputValidationError exception.

Args:
Expand Down
113 changes: 113 additions & 0 deletions tests/unittests/artifacts/test_artifact_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,119 @@ def test_converts_text_dict(self):
assert result.text == "hello world"


# ---------------------------------------------------------------------------
# GCS file_data (URI reference) tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio # type: ignore[untyped-decorator]
async def test_gcs_save_artifact_with_external_gcs_uri() -> None:
"""GcsArtifactService saves and loads a gs:// file_data URI reference."""
service = mock_gcs_artifact_service() # type: ignore[no-untyped-call]
artifact = types.Part(
file_data=types.FileData(
file_uri="gs://my-bucket/report.pdf",
mime_type="application/pdf",
)
)

version = await service.save_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="report.pdf",
artifact=artifact,
)
assert version == 0

loaded = await service.load_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="report.pdf",
)
assert loaded is not None
assert loaded.file_data is not None
assert loaded.file_data.file_uri == "gs://my-bucket/report.pdf"
assert loaded.file_data.mime_type == "application/pdf"


@pytest.mark.asyncio # type: ignore[untyped-decorator]
async def test_gcs_save_artifact_with_artifact_ref_uri() -> None:
"""GcsArtifactService saves and loads an internal artifact:// URI reference."""
service = mock_gcs_artifact_service() # type: ignore[no-untyped-call]
artifact_ref_uri = "artifact://apps/app/users/user1/sessions/sess1/artifacts/source.txt/versions/0"
artifact = types.Part(
file_data=types.FileData(
file_uri=artifact_ref_uri,
mime_type="text/plain",
)
)

version = await service.save_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="ref.txt",
artifact=artifact,
)
assert version == 0

loaded = await service.load_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="ref.txt",
)
assert loaded is not None
assert loaded.file_data is not None
assert loaded.file_data.file_uri == artifact_ref_uri


@pytest.mark.asyncio # type: ignore[untyped-decorator]
async def test_gcs_save_artifact_file_data_without_mime_type() -> None:
"""GcsArtifactService handles file_data with no mime_type."""
service = mock_gcs_artifact_service() # type: ignore[no-untyped-call]
artifact = types.Part(
file_data=types.FileData(file_uri="gs://my-bucket/data.bin")
)

version = await service.save_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="data.bin",
artifact=artifact,
)
assert version == 0

loaded = await service.load_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="data.bin",
)
assert loaded is not None
assert loaded.file_data is not None
assert loaded.file_data.file_uri == "gs://my-bucket/data.bin"


@pytest.mark.asyncio # type: ignore[untyped-decorator]
async def test_gcs_save_artifact_file_data_missing_uri_raises() -> None:
"""GcsArtifactService raises InputValidationError when file_uri is empty."""
service = mock_gcs_artifact_service() # type: ignore[no-untyped-call]
artifact = types.Part(file_data=types.FileData(file_uri=""))

with pytest.raises(InputValidationError):
await service.save_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="empty.bin",
artifact=artifact,
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"service_type",
Expand Down