diff --git a/.gitattributes b/.gitattributes index 195ac9b4dd..7e20d7d98c 100644 --- a/.gitattributes +++ b/.gitattributes @@ -171,6 +171,15 @@ clients/python/src/mr_openapi/models/serving_environment_list.py linguist-genera clients/python/src/mr_openapi/models/serving_environment_update.py linguist-generated=true clients/python/src/mr_openapi/models/sort_order.py linguist-generated=true clients/python/src/mr_openapi/rest.py linguist-generated=true +cmd/catalog-gen/templates/api/openapi_components.gotmpl linguist-generated=true +cmd/catalog-gen/templates/catalog/loader.gotmpl linguist-generated=true +cmd/catalog-gen/templates/models/artifact.gotmpl linguist-generated=true +cmd/catalog-gen/templates/models/entity.gotmpl linguist-generated=true +cmd/catalog-gen/templates/plugin/plugin.gotmpl linguist-generated=true +cmd/catalog-gen/templates/plugin/register.gotmpl linguist-generated=true +cmd/catalog-gen/templates/service/artifact_repository.gotmpl linguist-generated=true +cmd/catalog-gen/templates/service/filter_mappings.gotmpl linguist-generated=true +cmd/catalog-gen/templates/service/spec.gotmpl linguist-generated=true internal/converter/generated/embedmd_openapi_converter.gen.go linguist-generated=true internal/converter/generated/openapi_converter.gen.go linguist-generated=true internal/converter/generated/openapi_embedmd_converter.gen.go linguist-generated=true diff --git a/.github/workflows/check-openapi-spec-pr.yaml b/.github/workflows/check-openapi-spec-pr.yaml index 8bc44ab1cf..8a71a42afc 100644 --- a/.github/workflows/check-openapi-spec-pr.yaml +++ b/.github/workflows/check-openapi-spec-pr.yaml @@ -3,7 +3,10 @@ on: pull_request: paths: - ".github/workflows/**" - - "api/openapi/model-registry.yaml" + - "api/openapi/**" + - "catalog/plugins/*/api/openapi/**" + - "scripts/merge_openapi.sh" + - "scripts/merge_catalog_specs.sh" permissions: # set contents: read at top-level, per OpenSSF ScoreCard rule TokenPermissionsID contents: read diff --git a/Makefile 
b/Makefile index 7351720ef7..ac28eb7f6b 100644 --- a/Makefile +++ b/Makefile @@ -67,13 +67,18 @@ api/openapi/model-registry.yaml: api/openapi/src/model-registry.yaml api/openapi api/openapi/catalog.yaml: api/openapi/src/catalog.yaml api/openapi/src/lib/*.yaml bin/yq scripts/merge_openapi.sh catalog.yaml +api/openapi/catalog-spec.yaml: api/openapi/catalog.yaml $(wildcard catalog/plugins/*/api/openapi/openapi.yaml) bin/yq scripts/merge_catalog_specs.sh + scripts/merge_catalog_specs.sh catalog-spec.yaml + # validate the openapi schema .PHONY: openapi/validate openapi/validate: bin/openapi-generator-cli bin/yq @scripts/merge_openapi.sh --check model-registry.yaml || (echo "api/openapi/model-registry.yaml is incorrectly formatted. Run 'make api/openapi/model-registry.yaml' to fix it."; exit 1) @scripts/merge_openapi.sh --check catalog.yaml || (echo "$< is incorrectly formatted. Run 'make api/openapi/catalog.yaml' to fix it."; exit 1) + @scripts/merge_catalog_specs.sh --check catalog-spec.yaml || (echo "api/openapi/catalog-spec.yaml is incorrectly formatted. 
Run 'make api/openapi/catalog-spec.yaml' to fix it."; exit 1) $(OPENAPI_GENERATOR) validate -i api/openapi/model-registry.yaml $(OPENAPI_GENERATOR) validate -i api/openapi/catalog.yaml + $(OPENAPI_GENERATOR) validate -i api/openapi/catalog-spec.yaml # generate the openapi server implementation .PHONY: gen/openapi-server @@ -85,7 +90,7 @@ internal/server/openapi/api_model_registry_service.go: bin/openapi-generator-cli # generate the openapi schema model and client .PHONY: gen/openapi -gen/openapi: bin/openapi-generator-cli api/openapi/model-registry.yaml api/openapi/catalog.yaml openapi/validate pkg/openapi/client.go +gen/openapi: bin/openapi-generator-cli api/openapi/model-registry.yaml api/openapi/catalog.yaml api/openapi/catalog-spec.yaml openapi/validate pkg/openapi/client.go make -C catalog $@ pkg/openapi/client.go: bin/openapi-generator-cli api/openapi/model-registry.yaml clean-pkg-openapi bin/goimports diff --git a/api/openapi/catalog-spec.yaml b/api/openapi/catalog-spec.yaml new file mode 100644 index 0000000000..75a1576e15 --- /dev/null +++ b/api/openapi/catalog-spec.yaml @@ -0,0 +1,1545 @@ +openapi: 3.0.3 +info: + title: Model Catalog REST API + version: v1alpha1 + description: REST API for Model Registry to create and manage ML model metadata + license: + name: Apache 2.0 + url: "https://www.apache.org/licenses/LICENSE-2.0" +servers: + - url: "https://localhost:8080" + - url: "http://localhost:8080" +paths: + /api/model_catalog/v1alpha1/labels: + summary: Path used to get the list of catalog labels. + description: >- + The REST endpoint/path used to list zero or more `CatalogLabel` entities. 
+ get: + summary: List All CatalogLabels + tags: + - ModelCatalogService + parameters: + - $ref: "#/components/parameters/pageSize" + - $ref: "#/components/parameters/labelOrderBy" + - $ref: "#/components/parameters/sortOrder" + - $ref: "#/components/parameters/nextPageToken" + responses: + "200": + $ref: "#/components/responses/CatalogLabelListResponse" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "404": + $ref: "#/components/responses/NotFound" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: findLabels + description: Gets a list of all `CatalogLabel` entities. + /api/model_catalog/v1alpha1/models: + description: >- + The REST endpoint/path used to list zero or more `CatalogModel` entities from all `CatalogSources`. + get: + summary: Search catalog models across sources. + tags: + - ModelCatalogService + parameters: + - name: recommendations + in: query + description: Sort models by lowest recommended latency using Pareto filtering + required: false + schema: + type: boolean + default: false + - name: targetRPS + in: query + description: Target requests per second for latency calculations + required: false + schema: + type: integer + format: int32 + - name: latencyProperty + in: query + description: Property name for latency metric + required: false + schema: + type: string + default: "ttft_p90" + - name: rpsProperty + in: query + description: Property name for RPS metric + required: false + schema: + type: string + default: "requests_per_second" + - name: hardwareCountProperty + in: query + description: Property name for hardware count + required: false + schema: + type: string + default: "hardware_count" + - name: hardwareTypeProperty + in: query + description: Property name for hardware type grouping + required: false + schema: + type: string + default: "hardware_type" + - name: source + description: |- + Filter models by source. 
Multiple values can be separated by commas + to filter by multiple sources (OR logic). For example: + ?source=huggingface,local will return models from either + huggingface OR local sources. + schema: + type: array + items: + type: string + style: form + explode: true + in: query + required: false + - name: q + description: Free-form keyword search used to filter the response. + schema: + type: string + in: query + required: false + - name: sourceLabel + description: |- + Filter models by the label associated with the source. Multiple + values can be separated by commas. If one of the values is the + string `null`, then models from every source without a label will + be returned. + schema: + type: array + items: + type: string + in: query + required: false + - $ref: "#/components/parameters/filterQuery" + - $ref: "#/components/parameters/pageSize" + - name: orderBy + style: form + explode: true + examples: + orderBy: + value: ID + description: |- + Specifies the order by criteria for listing entities. + + Supported values are: + - CREATE_TIME + - LAST_UPDATE_TIME + - ID + - NAME + - ACCURACY + + Defaults to `NAME`. + + The `ACCURACY` sort will sort by the `overall_average` property in any linked metrics artifact. + + In addition, models can be sorted by properties. 
For example: + - `provider.string_value` sorts by provider name + - `artifacts.ifeval.double_value` sorts by the min/max value a property called ifeval across all associated artifacts + schema: + $ref: "#/components/schemas/OrderByField" + in: query + required: false + - $ref: "#/components/parameters/sortOrder" + - $ref: "#/components/parameters/nextPageToken" + responses: + "200": + $ref: "#/components/responses/CatalogModelListResponse" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "404": + $ref: "#/components/responses/NotFound" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: findModels + /api/model_catalog/v1alpha1/models/filter_options: + description: Lists options for `filterQuery` when listing models. + get: + summary: Lists fields and available options that can be used in `filterQuery` on the list models endpoint. + tags: + - ModelCatalogService + responses: + "200": + $ref: "#/components/responses/FilterOptionsResponse" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: findModelsFilterOptions + /api/model_catalog/v1alpha1/sources: + summary: Path used to get the list of catalog sources. + description: >- + The REST endpoint/path used to list zero or more `CatalogSource` entities. 
+ get: + summary: List All CatalogSources + tags: + - ModelCatalogService + parameters: + - $ref: "#/components/parameters/name" + - $ref: "#/components/parameters/pageSize" + - $ref: "#/components/parameters/orderBy" + - $ref: "#/components/parameters/sortOrder" + - $ref: "#/components/parameters/nextPageToken" + responses: + "200": + $ref: "#/components/responses/CatalogSourceListResponse" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "404": + $ref: "#/components/responses/NotFound" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: findSources + description: Gets a list of all `CatalogSource` entities. + /api/model_catalog/v1alpha1/sources/preview: + description: >- + The REST endpoint/path used to preview a catalog source configuration. This endpoint accepts a catalog source definition and returns a list of models with their inclusion/exclusion status based on the configured filters. + post: + summary: Preview catalog source configuration + description: |- + Accepts a catalog source configuration and returns a list of models showing + which would be included or excluded based on the configured filters. This allows + users to test and validate their source configurations before applying them. + + **Two modes of operation:** + + 1. **Stateless mode (recommended for new sources):** Upload both `config` and + `catalogData` files via multipart form. The models are read directly from + the uploaded `catalogData`, enabling preview of new sources before saving + anything to the server. This is ideal for testing configurations. + + 2. **Path mode (for existing sources):** Upload only `config` with a `yamlCatalogPath` + property. The models are read from the specified file path on the server. + Use this for previewing changes to existing saved sources. 
+ tags: + - ModelCatalogService + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: + - config + properties: + config: + type: string + format: binary + description: |- + YAML file containing the catalog source configuration. + The file should contain a source definition with type and properties + fields, including optional includedModels and excludedModels filters. + + Model filter patterns support the `*` wildcard only and are case-insensitive. + Patterns match the entire model name (e.g., `ibm-granite/*` matches all + models starting with "ibm-granite/"). + catalogData: + type: string + format: binary + description: |- + Optional YAML file containing the catalog data (models). + + This field enables stateless preview of new sources before saving them. + When provided, the catalog data is read directly from this file instead of + from the `yamlCatalogPath` property in the config. + + **Two modes of operation:** + 1. **Stateless mode (recommended for new sources):** Upload both `config` and + `catalogData` files. The models are read from `catalogData`, allowing preview + without saving anything to the server. + 2. **Path mode (for existing sources):** Upload only `config` with a `yamlCatalogPath` + property pointing to a catalog file on the server filesystem. + + If both `catalogData` and `yamlCatalogPath` are provided, `catalogData` takes precedence. + examples: + statelessPreview: + summary: Stateless preview with uploaded catalog data + description: |- + Upload both config and catalogData files to preview a new source + before saving. This is the recommended approach for testing new configurations. 
+ value: | + # config file content: + type: yaml + includedModels: + - "ibm-granite/*" + - "meta-llama/*" + excludedModels: + - "*-draft" + - "*-experimental" + + # catalogData file content (separate file): + models: + - name: ibm-granite/granite-3.0-8b-instruct + description: Granite 8B Instruct model + - name: ibm-granite/granite-3.0-2b-draft + description: Draft version + - name: meta-llama/Llama-2-7b-hf + description: Llama 2 7B + pathBasedPreview: + summary: Path-based preview using server-side catalog + description: |- + Upload only config file with yamlCatalogPath pointing to an + existing catalog file on the server. Use this for previewing + changes to existing saved sources. + value: | + type: yaml + includedModels: + - "ibm-granite/*" + - "meta-llama/*" + - "mistralai/*" + excludedModels: + - "*-draft" + - "*-experimental" + properties: + yamlCatalogPath: "models-catalog.yaml" + huggingfaceSource: + summary: HuggingFace catalog source + description: |- + Upload configuration for HuggingFace source with API credentials. + The API key is passed per-request and not persisted anywhere. 
+ value: | + type: hf + includedModels: + - "microsoft/*" + - "google/*" + excludedModels: + - "*-gguf" + properties: + apiKey: "your-huggingface-api-key-here" # notsecret + modelLimit: 100 + parameters: + - $ref: "#/components/parameters/pageSize" + - $ref: "#/components/parameters/nextPageToken" + - name: filterStatus + description: |- + Filter the response to show specific model statuses: + - `all` (default): Show all models regardless of inclusion status + - `included`: Show only models that pass the configured filters + - `excluded`: Show only models that are filtered out + schema: + type: string + enum: + - all + - included + - excluded + default: all + in: query + required: false + responses: + "200": + $ref: "#/components/responses/CatalogSourcePreviewResponse" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "422": + description: Unprocessable Entity - Invalid source configuration + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: previewCatalogSource + /api/model_catalog/v1alpha1/sources/{source_id}/models/{model_name+}: + description: >- + The REST endpoint/path used to get a `CatalogModel`. + get: + summary: Get a `CatalogModel`. + tags: + - ModelCatalogService + responses: + "200": + $ref: "#/components/responses/CatalogModelResponse" + "401": + $ref: "#/components/responses/Unauthorized" + "404": + $ref: "#/components/responses/NotFound" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: getModel + parameters: + - name: source_id + description: A unique identifier for a `CatalogSource`. + schema: + type: string + in: path + required: true + - name: model_name+ + description: A unique identifier for the model. 
+ schema: + type: string + in: path + required: true + /api/model_catalog/v1alpha1/sources/{source_id}/models/{model_name}/artifacts: + description: >- + The REST endpoint/path used to list `CatalogArtifacts`. + get: + summary: List CatalogArtifacts. + tags: + - ModelCatalogService + responses: + "200": + $ref: "#/components/responses/CatalogArtifactListResponse" + "401": + $ref: "#/components/responses/Unauthorized" + "404": + $ref: "#/components/responses/NotFound" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: getAllModelArtifacts + parameters: + - name: source_id + description: A unique identifier for a `CatalogSource`. + schema: + type: string + in: path + required: true + - name: model_name + description: A unique identifier for the model. + schema: + type: string + in: path + required: true + - $ref: "#/components/parameters/artifactType" + - $ref: "#/components/parameters/artifact_type" + - $ref: "#/components/parameters/artifactFilterQuery" + - $ref: "#/components/parameters/pageSize" + - $ref: "#/components/parameters/artifactOrderBy" + - $ref: "#/components/parameters/sortOrder" + - $ref: "#/components/parameters/nextPageToken" + /api/model_catalog/v1alpha1/sources/{source_id}/models/{model_name}/artifacts/performance: + description: >- + Lists performance metrics artifacts (that is, artifacts of type `CatalogMetricsArtifact` where `metricsType` is `performance-metrics`). + get: + summary: List CatalogArtifacts. + tags: + - ModelCatalogService + responses: + "200": + $ref: "#/components/responses/CatalogArtifactListResponse" + "401": + $ref: "#/components/responses/Unauthorized" + "404": + $ref: "#/components/responses/NotFound" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: getAllModelPerformanceArtifacts + parameters: + - name: source_id + description: A unique identifier for a `CatalogSource`. 
+ schema: + type: string + in: path + required: true + - name: model_name + description: A unique identifier for the model. + schema: + type: string + in: path + required: true + - name: targetRPS + description: |- + Target requests per second. If specified, values for `replicas` and + `total_requests_per_second` will be calculated and returned as custom + properties. + schema: + type: integer + in: query + - name: recommendations + description: Filter records that are less optimal based on approximate cost to run and latency. + schema: + type: boolean + default: false + in: query + - name: rpsProperty + description: Custom property name for requests per second metric. + schema: + type: string + default: "requests_per_second" + in: query + - name: latencyProperty + description: Custom property name for latency metric (e.g., ttft_p90, p90_latency). + schema: + type: string + default: "ttft_p90" + in: query + - name: hardwareCountProperty + description: Custom property name for hardware count metric. + schema: + type: string + default: "hardware_count" + in: query + - name: hardwareTypeProperty + description: Custom property name for hardware type grouping. + schema: + type: string + default: "hardware_type" + in: query + - $ref: "#/components/parameters/artifactFilterQuery" + - $ref: "#/components/parameters/pageSize" + - $ref: "#/components/parameters/artifactOrderBy" + - $ref: "#/components/parameters/sortOrder" + - $ref: "#/components/parameters/nextPageToken" +components: + schemas: + ArtifactTypeQueryParam: + description: Supported artifact types for querying. + enum: + - model-artifact + - metrics-artifact + type: string + BaseModel: + type: object + properties: + description: + type: string + description: Human-readable description of the model. + readme: + type: string + description: Model documentation in Markdown. + maturity: + type: string + description: Maturity level of the model. 
+ example: Generally Available + language: + type: array + description: List of supported languages (https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes). + items: + type: string + example: + - en + - es + - cz + tasks: + type: array + description: List of tasks the model is designed for. + items: + type: string + example: + - text-generation + provider: + type: string + description: Name of the organization or entity that provides the model. + example: IBM + logo: + type: string + format: uri + description: |- + URL to the model's logo. A [data + URL](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data) + is recommended. + license: + type: string + description: Short name of the model's license. + example: apache-2.0 + licenseLink: + type: string + format: uri + description: URL to the license text. + libraryName: + type: string + example: transformers + customProperties: + description: User provided custom properties which are not defined by its type. + type: object + additionalProperties: + $ref: "#/components/schemas/MetadataValue" + BaseResource: + allOf: + - type: object + properties: + customProperties: + description: User provided custom properties which are not defined by its type. + type: object + additionalProperties: + $ref: "#/components/schemas/MetadataValue" + description: + description: |- + An optional description about the resource. + type: string + externalId: + description: |- + The external id that come from the clients’ system. This field is optional. + If set, it must be unique among all resources within a database instance. + type: string + name: + description: |- + The client provided name of the artifact. This field is optional. If set, + it must be unique among all the artifacts of the same artifact type within + a database instance and cannot be changed once set. + type: string + id: + format: int64 + description: The unique server generated id of the resource. 
+ type: string + - $ref: "#/components/schemas/BaseResourceDates" + BaseResourceDates: + description: Common timestamp fields for resources + type: object + properties: + createTimeSinceEpoch: + format: int64 + description: Output only. Create time of the resource in millisecond since epoch. + type: string + readOnly: true + lastUpdateTimeSinceEpoch: + format: int64 + description: Output only. Last update time of the resource since epoch in millisecond since epoch. + type: string + readOnly: true + BaseResourceList: + required: + - nextPageToken + - pageSize + - size + type: object + properties: + nextPageToken: + description: Token to use to retrieve next page of results. + type: string + pageSize: + format: int32 + description: Maximum number of resources to return in the result. + type: integer + size: + format: int32 + description: Number of items in result list. + type: integer + CatalogArtifact: + description: A single artifact in the catalog API. + oneOf: + - $ref: "#/components/schemas/CatalogModelArtifact" + - $ref: "#/components/schemas/CatalogMetricsArtifact" + discriminator: + propertyName: artifactType + mapping: + model-artifact: "#/components/schemas/CatalogModelArtifact" + metrics-artifact: "#/components/schemas/CatalogMetricsArtifact" + CatalogArtifactList: + description: List of CatalogModel entities. + allOf: + - type: object + properties: + items: + description: Array of `CatalogArtifact` entities. + type: array + items: + $ref: "#/components/schemas/CatalogArtifact" + required: + - items + - $ref: "#/components/schemas/BaseResourceList" + CatalogLabel: + description: A catalog label. Labels are used to categorize catalog sources. Represented as a flexible map of string key-value pairs with a required 'name' field. + type: object + required: + - name + properties: + name: + type: string + nullable: true + description: The unique name identifier for the label. 
+ displayName: + type: string + description: An optional human-readable name to show in place of `name`. + additionalProperties: + type: string + example: + name: huggingface + displayName: HuggingFace Hub + description: HuggingFace models with full support and legal indemnification. + CatalogLabelList: + description: List of CatalogLabel entities. + allOf: + - type: object + properties: + items: + description: Array of `CatalogLabel` entities. + type: array + items: + $ref: "#/components/schemas/CatalogLabel" + required: + - items + - $ref: "#/components/schemas/BaseResourceList" + CatalogMetricsArtifact: + description: A metadata Artifact Entity. + allOf: + - type: object + required: + - artifactType + - metricsType + properties: + artifactType: + type: string + default: metrics-artifact + metricsType: + type: string + enum: + - performance-metrics + - accuracy-metrics + customProperties: + description: User provided custom properties which are not defined by its type. + type: object + additionalProperties: + $ref: "#/components/schemas/MetadataValue" + - $ref: "#/components/schemas/BaseResource" + CatalogModel: + description: A model in the model catalog. + allOf: + - type: object + required: + - name + properties: + name: + type: string + description: Name of the model. Must be unique within a source. + example: ibm-granite/granite-3.1-8b-base + source_id: + type: string + description: ID of the source this model belongs to. + - $ref: "#/components/schemas/BaseModel" + - $ref: "#/components/schemas/BaseResource" + CatalogModelArtifact: + description: A Catalog Model Artifact Entity. + allOf: + - type: object + required: + - artifactType + - uri + properties: + artifactType: + type: string + default: model-artifact + uri: + type: string + format: uri + description: URI where the model can be retrieved. + customProperties: + description: User provided custom properties which are not defined by its type. 
+ type: object + additionalProperties: + $ref: "#/components/schemas/MetadataValue" + - $ref: "#/components/schemas/BaseResource" + CatalogModelList: + description: List of CatalogModel entities. + allOf: + - type: object + properties: + items: + description: Array of `CatalogModel` entities. + type: array + items: + $ref: "#/components/schemas/CatalogModel" + required: + - items + - $ref: "#/components/schemas/BaseResourceList" + CatalogSource: + description: A catalog source. A catalog source has CatalogModel children. + required: + - id + - name + - labels + type: object + properties: + id: + description: A unique identifier for a `CatalogSource`. + type: string + name: + description: The name of the catalog source. + type: string + enabled: + description: Whether the catalog source is enabled. + type: boolean + default: true + labels: + description: Labels for the catalog source. + type: array + items: + type: string + status: + $ref: "#/components/schemas/CatalogSourceStatus" + description: Current operational status of the catalog source. + error: + description: |- + Detailed error information when the status is "Error". + This field is null or empty when the source is functioning normally. + type: string + nullable: true + includedModels: + description: |- + Optional list of glob patterns for models to include. If specified, only models matching + at least one pattern will be included. If omitted, all models are considered for inclusion. 
+ + Pattern Syntax: + - Only the `*` wildcard is supported (matches zero or more characters) + - Patterns are case-insensitive (e.g., `Granite/*` matches `granite/model` and `GRANITE/model`) + - Patterns match the entire model name (anchored at start and end) + - Wildcards can appear anywhere: `Granite/*`, `*-beta`, `*deprecated*`, `*/old*` + + Examples: + - `ibm-granite/*` - matches all models starting with "ibm-granite/" + - `meta-llama/*` - matches all models in the meta-llama namespace + - `*` - matches all models + + Constraints: + - Patterns cannot be empty or whitespace-only + - A pattern cannot appear in both includedModels and excludedModels + type: array + items: + type: string + excludedModels: + description: |- + Optional list of glob patterns for models to exclude. Models matching any pattern + will be excluded even if they match an includedModels pattern. Exclusions take + precedence over inclusions. + + Pattern Syntax: + - Only the `*` wildcard is supported (matches zero or more characters) + - Patterns are case-insensitive + - Patterns match the entire model name (anchored at start and end) + - Wildcards can appear anywhere in the pattern + + Examples: + - `*-draft` - excludes all models ending with "-draft" + - `*-experimental` - excludes experimental models + - `*deprecated*` - excludes models with "deprecated" anywhere in the name + - `*/beta-*` - excludes models with "/beta-" in the path + + Constraints: + - Patterns cannot be empty or whitespace-only + - A pattern cannot appear in both includedModels and excludedModels + type: array + items: + type: string + CatalogSourceList: + description: List of CatalogSource entities. + allOf: + - type: object + properties: + items: + description: Array of `CatalogSource` entities. 
+ type: array + items: + $ref: "#/components/schemas/CatalogSource" + required: + - items + - $ref: "#/components/schemas/BaseResourceList" + CatalogSourcePreviewResponse: + description: Response containing models and their inclusion/exclusion status. + allOf: + - type: object + properties: + items: + description: Array of model preview results. + type: array + items: + $ref: "#/components/schemas/ModelPreviewResult" + summary: + description: Summary of the preview results + type: object + properties: + totalModels: + type: integer + description: Total number of models evaluated + example: 1500 + includedModels: + type: integer + description: Number of models that would be included + example: 850 + excludedModels: + type: integer + description: Number of models that would be excluded + example: 650 + required: + - totalModels + - includedModels + - excludedModels + required: + - items + - summary + - $ref: "#/components/schemas/BaseResourceList" + CatalogSourceStatus: + description: |- + Operational status of a catalog source. + - `available`: The source is functioning correctly and models can be retrieved + - `error`: The source is experiencing issues and cannot provide models + - `disabled`: The source has been intentionally disabled + enum: + - available + - error + - disabled + type: string + Error: + description: Error code and message. 
+ required: + - code + - message + type: object + properties: + code: + description: Error code + type: string + message: + description: Error message + type: string + FieldFilter: + type: object + required: + - operator + - value + properties: + operator: + type: string + description: Filter operator (e.g., '<', '=', '>', 'IN') + example: '<' + value: + description: Filter value (can be number, string, or array) + example: 70 + FilterOption: + type: object + required: + - type + properties: + type: + type: string + description: The data type of the filter option + enum: + - string + - number + values: + type: array + description: Known values of the property for string types with a small number of possible options. + items: {} + range: + $ref: "#/components/schemas/FilterOptionRange" + FilterOptionRange: + type: object + description: Min and max values for number types. + properties: + min: + type: number + format: double + max: + type: number + format: double + FilterOptionsList: + description: List of FilterOptions + type: object + properties: + filters: + type: object + description: A single filter option. + additionalProperties: + $ref: "#/components/schemas/FilterOption" + namedQueries: + type: object + description: Predefined named queries for common filtering scenarios + additionalProperties: + type: object + additionalProperties: + $ref: "#/components/schemas/FieldFilter" + MetadataBoolValue: + description: A bool property value. + type: object + required: + - metadataType + - bool_value + properties: + bool_value: + type: boolean + metadataType: + type: string + example: MetadataBoolValue + default: MetadataBoolValue + MetadataDoubleValue: + description: A double property value. 
+ type: object + required: + - metadataType + - double_value + properties: + double_value: + format: double + type: number + metadataType: + type: string + example: MetadataDoubleValue + default: MetadataDoubleValue + MetadataIntValue: + description: An integer (int64) property value. + type: object + required: + - metadataType + - int_value + properties: + int_value: + format: int64 + type: string + metadataType: + type: string + example: MetadataIntValue + default: MetadataIntValue + MetadataProtoValue: + description: A proto property value. + type: object + required: + - metadataType + - type + - proto_value + properties: + type: + description: url describing proto value + type: string + proto_value: + description: Base64 encoded bytes for proto value + type: string + metadataType: + type: string + example: MetadataProtoValue + default: MetadataProtoValue + MetadataStringValue: + description: A string property value. + type: object + required: + - metadataType + - string_value + properties: + string_value: + type: string + metadataType: + type: string + example: MetadataStringValue + default: MetadataStringValue + MetadataStructValue: + description: A struct property value. 
+ type: object + required: + - metadataType + - struct_value + properties: + struct_value: + description: Base64 encoded bytes for struct value + type: string + metadataType: + type: string + example: MetadataStructValue + default: MetadataStructValue + MetadataValue: + oneOf: + - $ref: "#/components/schemas/MetadataIntValue" + - $ref: "#/components/schemas/MetadataDoubleValue" + - $ref: "#/components/schemas/MetadataStringValue" + - $ref: "#/components/schemas/MetadataStructValue" + - $ref: "#/components/schemas/MetadataProtoValue" + - $ref: "#/components/schemas/MetadataBoolValue" + discriminator: + propertyName: metadataType + mapping: + MetadataBoolValue: "#/components/schemas/MetadataBoolValue" + MetadataDoubleValue: "#/components/schemas/MetadataDoubleValue" + MetadataIntValue: "#/components/schemas/MetadataIntValue" + MetadataProtoValue: "#/components/schemas/MetadataProtoValue" + MetadataStringValue: "#/components/schemas/MetadataStringValue" + MetadataStructValue: "#/components/schemas/MetadataStructValue" + description: A value in properties. + example: + string_value: my_value + metadataType: MetadataStringValue + ModelPreviewResult: + description: |- + A model with its inclusion/exclusion status based on the + configured catalog source filters. + type: object + required: + - name + - included + properties: + name: + type: string + description: Name of the model + example: microsoft/DialoGPT-medium + included: + type: boolean + description: Whether this model would be included based on the source configuration + OrderByField: + description: |- + Supported fields for ordering result entities. + enum: + - CREATE_TIME + - LAST_UPDATE_TIME + - ID + - NAME + type: string + SortOrder: + description: Supported sort direction for ordering result entities. 
+ enum: + - ASC + - DESC + type: string + responses: + BadRequest: + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + description: Bad Request parameters + CatalogArtifactListResponse: + content: + application/json: + schema: + $ref: "#/components/schemas/CatalogArtifactList" + description: A response containing a list of CatalogArtifact entities. + CatalogLabelListResponse: + content: + application/json: + schema: + $ref: "#/components/schemas/CatalogLabelList" + description: A response containing a list of CatalogLabel entities. + CatalogModelListResponse: + content: + application/json: + schema: + $ref: "#/components/schemas/CatalogModelList" + description: A response containing a list of CatalogModel entities. + CatalogModelResponse: + content: + application/json: + schema: + $ref: "#/components/schemas/CatalogModel" + description: A response containing a `CatalogModel` entity. + CatalogSourceListResponse: + content: + application/json: + schema: + $ref: "#/components/schemas/CatalogSourceList" + description: A response containing a list of CatalogSource entities. 
+ CatalogSourcePreviewResponse: + content: + application/json: + schema: + $ref: "#/components/schemas/CatalogSourcePreviewResponse" + examples: + allModels: + summary: All models with inclusion status + description: Response showing all models when filterStatus=all + value: + nextPageToken: "" + pageSize: 10 + size: 5 + items: + - name: "ibm-granite/granite-3.0-8b-instruct" + included: true + - name: "ibm-granite/granite-3.0-2b-instruct" + included: true + - name: "meta-llama/Llama-2-7b-hf" + included: true + - name: "mistralai/Mistral-7B-v0.1-draft" + included: false + - name: "microsoft/phi-2-experimental" + included: false + summary: + totalModels: 5 + includedModels: 3 + excludedModels: 2 + includedOnly: + summary: Only included models + description: Response when filterStatus=included + value: + nextPageToken: "" + pageSize: 10 + size: 3 + items: + - name: "ibm-granite/granite-3.0-8b-instruct" + included: true + - name: "ibm-granite/granite-3.0-2b-instruct" + included: true + - name: "meta-llama/Llama-2-7b-hf" + included: true + summary: + totalModels: 5 + includedModels: 3 + excludedModels: 2 + excludedOnly: + summary: Only excluded models + description: Response when filterStatus=excluded + value: + nextPageToken: "" + pageSize: 10 + size: 2 + items: + - name: "mistralai/Mistral-7B-v0.1-draft" + included: false + - name: "microsoft/phi-2-experimental" + included: false + summary: + totalModels: 5 + includedModels: 3 + excludedModels: 2 + withPagination: + summary: Paginated response + description: Response with pagination when pageSize is smaller than total + value: + nextPageToken: "eyJvZmZzZXQiOjEwfQ==" # notsecret + pageSize: 10 + size: 10 + items: + - name: "ibm-granite/granite-3.0-8b-instruct" + included: true + - name: "ibm-granite/granite-3.0-2b-instruct" + included: true + - name: "meta-llama/Llama-2-7b-hf" + included: true + - name: "meta-llama/Llama-2-13b-hf" + included: true + - name: "mistralai/Mistral-7B-Instruct-v0.3" + included: true + - 
name: "mistralai/Mistral-7B-v0.1" + included: true + - name: "microsoft/phi-2" + included: true + - name: "google/gemma-7b" + included: true + - name: "mistralai/Mistral-7B-v0.1-draft" + included: false + - name: "microsoft/phi-2-experimental" + included: false + summary: + totalModels: 150 + includedModels: 85 + excludedModels: 65 + description: |- + A response containing a list of models with their inclusion/exclusion + status based on the provided catalog source configuration. + CatalogSourceResponse: + content: + application/json: + schema: + $ref: "#/components/schemas/CatalogSource" + description: A response containing a `CatalogSource` entity. + Conflict: + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + description: Conflict with current state of target resource + FilterOptionsResponse: + content: + application/json: + schema: + $ref: "#/components/schemas/FilterOptionsList" + description: A response containing options for a `filterQuery` parameter. + InternalServerError: + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + description: Unexpected internal server error + NotFound: + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + description: The specified resource was not found + ServiceUnavailable: + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + description: Service is unavailable + Unauthorized: + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + description: Unauthorized + UnprocessableEntity: + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + description: Unprocessable Entity error + parameters: + artifactFilterQuery: + examples: + artifactFilterQuery: + value: "name='my-artifact' AND uri LIKE '%s3%'" + name: filterQuery + description: | + A SQL-like query string to filter catalog artifacts. The query supports rich filtering capabilities with automatic type inference. 
+ + **Supported Operators:** + - Comparison: `=`, `!=`, `<>`, `>`, `<`, `>=`, `<=` + - Pattern matching: `LIKE`, `ILIKE` (case-insensitive) + - Set membership: `IN` + - Logical: `AND`, `OR` + - Grouping: `()` for complex expressions + + **Data Types:** + - Strings: `"value"` or `'value'` + - Numbers: `42`, `3.14`, `1e-5` + - Booleans: `true`, `false` (case-insensitive) + + **Property Access (Artifacts):** + - Standard properties: `name`, `id`, `uri`, `artifactType`, `createTimeSinceEpoch` + - Custom properties: Any user-defined property name in `customProperties` + - Escaped properties: Use backticks for special characters: `` `custom-property` `` + - Type-specific access: `property.string_value`, `property.double_value`, `property.int_value`, `property.bool_value` + + **Examples:** + - Basic: `name = "my-artifact"` + - Comparison: `ttft_mean > 90` + - Pattern: `uri LIKE "%s3.amazonaws.com%"` + - Complex: `(artifactType = "model-artifact" OR artifactType = "metrics-artifact") AND name LIKE "%pytorch%"` + - Custom property: `format.string_value = "pytorch"` + - Escaped property: `` `custom-key` = "value" `` + schema: + type: string + in: query + required: false + artifactOrderBy: + style: form + explode: true + examples: + standardField: + value: ID + summary: Order by standard field + customPropertyDouble: + value: mmlu.double_value + summary: Order by custom double property + customPropertyString: + value: framework_type.string_value + summary: Order by custom string property + customPropertyInt: + value: hardware_count.int_value + summary: Order by custom integer property + name: orderBy + description: | + Specifies the order by criteria for listing artifacts. 
+ + **Standard Fields:** + - `ID` - Order by artifact ID + - `NAME` - Order by artifact name + - `CREATE_TIME` - Order by creation timestamp + - `LAST_UPDATE_TIME` - Order by last update timestamp + + **Custom Property Ordering:** + + Artifacts can be ordered by custom properties using the format: `<property_name>.<value_type>` + + Supported value types: + - `double_value` - For numeric (floating-point) properties + - `int_value` - For integer properties + - `string_value` - For string properties + + Examples: + - `mmlu.double_value` - Order by the 'mmlu' benchmark score + - `accuracy.double_value` - Order by accuracy metric + - `framework_type.string_value` - Order by framework type + - `hardware_count.int_value` - Order by hardware count + - `ttft_mean.double_value` - Order by time-to-first-token mean + + **Behavior:** + - If an invalid value type is specified (e.g., `accuracy.invalid_type`), an error is returned + - If an invalid format is used (e.g., `accuracy` without `.value_type`), it falls back to ID ordering + - If a property doesn't exist, it falls back to ID ordering + - Artifacts with the specified property are ordered first (by the property value), followed by artifacts without the property (ordered by ID) + - Empty property names (e.g., `.double_value`) return an error + schema: + type: string + in: query + required: false + artifactType: + style: form + explode: true + examples: + artifactType: + value: model-artifact + name: artifactType + description: "Specifies the artifact type for listing artifacts." + schema: + type: array + items: + $ref: "#/components/schemas/ArtifactTypeQueryParam" + in: query + required: false + artifact_type: + deprecated: true + style: form + explode: true + examples: + artifact_type: + value: model-artifact + name: artifact_type + description: "Specifies the artifact type for listing artifacts."
+ schema: + type: array + items: + $ref: "#/components/schemas/ArtifactTypeQueryParam" + in: query + required: false + externalId: + examples: + externalId: + value: "10" + name: externalId + description: External ID of entity to search. + schema: + type: string + in: query + required: false + filterQuery: + examples: + filterQuery: + value: "name='my-model' AND state='LIVE'" + name: filterQuery + description: | + A SQL-like query string to filter the list of entities. The query supports rich filtering capabilities with automatic type inference. + + **Supported Operators:** + - Comparison: `=`, `!=`, `<>`, `>`, `<`, `>=`, `<=` + - Pattern matching: `LIKE`, `ILIKE` (case-insensitive) + - Set membership: `IN` + - Logical: `AND`, `OR` + - Grouping: `()` for complex expressions + + **Data Types:** + - Strings: `"value"` or `'value'` + - Numbers: `42`, `3.14`, `1e-5` + - Booleans: `true`, `false` (case-insensitive) + + **Property Access:** + - Standard properties: `name`, `id`, `state`, `createTimeSinceEpoch` + - Custom properties: Any user-defined property name + - Escaped properties: Use backticks for special characters: `` `custom-property` `` + - Type-specific access: `property.string_value`, `property.double_value`, `property.int_value`, `property.bool_value` + + **Examples:** + - Basic: `name = "my-model"` + - Comparison: `accuracy > 0.95` + - Pattern: `name LIKE "%tensorflow%"` + - Complex: `(name = "model-a" OR name = "model-b") AND state = "LIVE"` + - Custom property: `framework.string_value = "pytorch"` + - Escaped property: `` `mlflow.source.type` = "notebook" `` + schema: + type: string + in: query + required: false + id: + name: id + description: The ID of resource. + schema: + type: string + in: path + required: true + labelOrderBy: + style: form + explode: true + examples: + labelOrderBy: + value: name + name: orderBy + description: | + Specifies the key to order catalog labels by. You can provide any string key + that may exist in the label maps. 
Labels that contain the specified key will + be sorted by that key's value. Labels that don't contain the key will maintain + their original order and appear after labels that do contain the key. + schema: + type: string + in: query + required: false + name: + examples: + name: + value: entity-name + name: name + description: Name of entity to search. + schema: + type: string + in: query + required: false + nextPageToken: + name: nextPageToken + description: Token to use to retrieve next page of results. + schema: + type: string + in: query + required: false + orderBy: + style: form + explode: true + examples: + orderBy: + value: ID + name: orderBy + description: Specifies the order by criteria for listing entities. + schema: + $ref: "#/components/schemas/OrderByField" + in: query + required: false + pageSize: + examples: + pageSize: + value: "100" + name: pageSize + description: Number of entities in each page. + schema: + type: string + in: query + required: false + parentResourceId: + examples: + parentResourceId: + value: "10" + name: parentResourceId + description: ID of the parent resource to use for search. + schema: + type: string + in: query + required: false + sortOrder: + style: form + explode: true + examples: + sortOrder: + value: DESC + name: sortOrder + description: "Specifies the sort order for listing entities, defaults to ASC." + schema: + $ref: "#/components/schemas/SortOrder" + in: query + required: false + stepIds: + style: form + explode: true + examples: + stepIds: + value: "1,2,3" + name: stepIds + description: "Comma-separated list of step IDs to filter metrics by." 
+ schema: + type: string + in: query + required: false + securitySchemes: + Bearer: + scheme: bearer + bearerFormat: JWT + type: http + description: Bearer JWT scheme +security: + - Bearer: [] +tags: [] +x-catalog-plugins: null diff --git a/catalog/plugins/model/plugin.go b/catalog/plugins/model/plugin.go new file mode 100644 index 0000000000..8ce0300fdd --- /dev/null +++ b/catalog/plugins/model/plugin.go @@ -0,0 +1,281 @@ +// Package model provides the model catalog plugin for the unified catalog server. +// This plugin wraps the existing catalog internals and exposes them via the plugin interface. +package model + +import ( + "context" + "fmt" + "log/slog" + "path/filepath" + "reflect" + "sync/atomic" + + "github.com/go-chi/chi/v5" + "gorm.io/gorm" + + "github.com/kubeflow/model-registry/catalog/internal/catalog" + "github.com/kubeflow/model-registry/catalog/internal/db/models" + "github.com/kubeflow/model-registry/catalog/internal/db/service" + "github.com/kubeflow/model-registry/catalog/internal/server/openapi" + "github.com/kubeflow/model-registry/internal/datastore" + "github.com/kubeflow/model-registry/internal/datastore/embedmd" + "github.com/kubeflow/model-registry/pkg/catalog/plugin" +) + +const ( + // PluginName is the identifier for this plugin. + PluginName = "model" + + // PluginVersion is the API version. + PluginVersion = "v1alpha1" +) + +// ModelCatalogPlugin implements the CatalogPlugin interface for model catalogs. +type ModelCatalogPlugin struct { + cfg plugin.Config + logger *slog.Logger + loader *catalog.Loader + dbCatalog catalog.APIProvider + services service.Services + sources *catalog.SourceCollection + labels *catalog.LabelCollection + healthy atomic.Bool + started atomic.Bool +} + +// Name returns the plugin name. +func (p *ModelCatalogPlugin) Name() string { + return PluginName +} + +// SourceKey returns the config key used in sources.yaml catalogs map. 
+func (p *ModelCatalogPlugin) SourceKey() string { + return "models" +} + +// Version returns the plugin API version. +func (p *ModelCatalogPlugin) Version() string { + return PluginVersion +} + +// Description returns a human-readable description. +func (p *ModelCatalogPlugin) Description() string { + return "Model catalog for ML models" +} + +// BasePath returns the API base path for this plugin. +func (p *ModelCatalogPlugin) BasePath() string { + return "/api/model_catalog/v1alpha1" +} + +// Init initializes the plugin with configuration. +func (p *ModelCatalogPlugin) Init(ctx context.Context, cfg plugin.Config) error { + p.cfg = cfg + p.logger = cfg.Logger + if p.logger == nil { + p.logger = slog.Default() + } + + p.logger.Info("initializing model catalog plugin") + + // Build paths from config sources + // The paths are the config file origins that contain the sources + paths := make([]string, 0) + originPaths := make(map[string]bool) + + for _, src := range cfg.Section.Sources { + if src.Origin != "" { + originPath := src.Origin + if !originPaths[originPath] { + paths = append(paths, originPath) + originPaths[originPath] = true + } + } + } + + // If no origins found in sources, use ConfigPaths + if len(paths) == 0 { + paths = cfg.ConfigPaths + } + + // Convert paths to absolute + absPaths := make([]string, 0, len(paths)) + for _, path := range paths { + absPath, err := filepath.Abs(path) + if err != nil { + absPath = path + } + absPaths = append(absPaths, absPath) + } + + // Initialize the services from the database connection + services, err := p.initServices(cfg.DB) + if err != nil { + return fmt.Errorf("failed to initialize services: %w", err) + } + p.services = services + + // Create the loader with existing catalog code + p.loader = catalog.NewLoader(services, absPaths) + p.sources = p.loader.Sources + p.labels = p.loader.Labels + + // Create the DB catalog provider + p.dbCatalog = catalog.NewDBCatalog(services, p.sources) + + p.logger.Info("model 
catalog plugin initialized", "paths", absPaths) + + return nil +} + +// initServices creates the service layer from the database connection. +func (p *ModelCatalogPlugin) initServices(db *gorm.DB) (service.Services, error) { + // Get the datastore spec for the catalog + spec := service.DatastoreSpec() + + // We need to create the RepoSet from the existing database + // This requires the types to already be registered in the database + repoSet, err := p.createRepoSet(db, spec) + if err != nil { + return service.Services{}, fmt.Errorf("failed to create repo set: %w", err) + } + + // Extract repositories from the RepoSet + catalogModelRepo, err := getRepository[models.CatalogModelRepository](repoSet) + if err != nil { + return service.Services{}, fmt.Errorf("failed to get catalog model repository: %w", err) + } + + catalogArtifactRepo, err := getRepository[models.CatalogArtifactRepository](repoSet) + if err != nil { + return service.Services{}, fmt.Errorf("failed to get catalog artifact repository: %w", err) + } + + catalogModelArtifactRepo, err := getRepository[models.CatalogModelArtifactRepository](repoSet) + if err != nil { + return service.Services{}, fmt.Errorf("failed to get catalog model artifact repository: %w", err) + } + + catalogMetricsArtifactRepo, err := getRepository[models.CatalogMetricsArtifactRepository](repoSet) + if err != nil { + return service.Services{}, fmt.Errorf("failed to get catalog metrics artifact repository: %w", err) + } + + catalogSourceRepo, err := getRepository[models.CatalogSourceRepository](repoSet) + if err != nil { + return service.Services{}, fmt.Errorf("failed to get catalog source repository: %w", err) + } + + propertyOptionsRepo, err := getRepository[models.PropertyOptionsRepository](repoSet) + if err != nil { + return service.Services{}, fmt.Errorf("failed to get property options repository: %w", err) + } + + return service.NewServices( + catalogModelRepo, + catalogArtifactRepo, + catalogModelArtifactRepo, + 
catalogMetricsArtifactRepo, + catalogSourceRepo, + propertyOptionsRepo, + ), nil +} + +// createRepoSet creates a RepoSet from the database using the spec. +// This uses the embedmd connector logic to initialize repositories. +func (p *ModelCatalogPlugin) createRepoSet(db *gorm.DB, spec *datastore.Spec) (datastore.RepoSet, error) { + // Create a connector that uses the existing database + connector, err := datastore.NewConnector("embedmd", &embedmd.EmbedMDConfig{DB: db, SkipMigrations: true}) + if err != nil { + return nil, fmt.Errorf("failed to create connector: %w", err) + } + + return connector.Connect(spec) +} + +// Start begins background operations (hot-reload, watchers). +func (p *ModelCatalogPlugin) Start(ctx context.Context) error { + p.logger.Info("starting model catalog plugin") + + if err := p.loader.Start(ctx); err != nil { + return fmt.Errorf("failed to start loader: %w", err) + } + + p.started.Store(true) + p.healthy.Store(true) + + p.logger.Info("model catalog plugin started") + return nil +} + +// Stop gracefully shuts down the plugin. +func (p *ModelCatalogPlugin) Stop(ctx context.Context) error { + p.logger.Info("stopping model catalog plugin") + p.started.Store(false) + p.healthy.Store(false) + return nil +} + +// Healthy returns true if the plugin is functioning correctly. +func (p *ModelCatalogPlugin) Healthy() bool { + return p.healthy.Load() +} + +// RegisterRoutes mounts the plugin's HTTP routes on the provided router. 
+func (p *ModelCatalogPlugin) RegisterRoutes(router chi.Router) error { + p.logger.Info("registering model catalog routes") + + // Create the OpenAPI service using existing handlers + apiService := openapi.NewModelCatalogServiceAPIService( + p.dbCatalog, + p.sources, + p.labels, + p.services.CatalogSourceRepository, + ) + + // Create the controller + apiController := openapi.NewModelCatalogServiceAPIController(apiService) + + // Mount routes - remove the base path prefix since chi.Router already handles that + for _, route := range apiController.OrderedRoutes() { + // Remove the /api/model_catalog/v1alpha1 prefix from the pattern + pattern := route.Pattern + basePath := "/api/model_catalog/v1alpha1" + if len(pattern) > len(basePath) && pattern[:len(basePath)] == basePath { + pattern = pattern[len(basePath):] + } + if pattern == "" { + pattern = "/" + } + + router.Method(route.Method, pattern, route.HandlerFunc) + p.logger.Debug("registered route", "method", route.Method, "pattern", pattern) + } + + return nil +} + +// Migrations returns database migrations for this plugin. +func (p *ModelCatalogPlugin) Migrations() []plugin.Migration { + // The model catalog uses the existing database schema from embedmd + // No additional migrations are needed as the schema is managed by the datastore layer + return nil +} + +// getRepository extracts a repository of type T from the RepoSet. 
+func getRepository[T any](rs datastore.RepoSet) (T, error) { + var zero T + t := reflect.TypeFor[T]() + + repo, err := rs.Repository(t) + if err != nil { + return zero, err + } + + result, ok := repo.(T) + if !ok { + return zero, fmt.Errorf("repository type mismatch: expected %T, got %T", zero, repo) + } + + return result, nil +} diff --git a/catalog/plugins/model/register.go b/catalog/plugins/model/register.go new file mode 100644 index 0000000000..1a0c16b34e --- /dev/null +++ b/catalog/plugins/model/register.go @@ -0,0 +1,7 @@ +package model + +import "github.com/kubeflow/model-registry/pkg/catalog/plugin" + +func init() { + plugin.Register(&ModelCatalogPlugin{}) +} diff --git a/cmd/catalog-gen/add_provider.go b/cmd/catalog-gen/add_provider.go new file mode 100644 index 0000000000..2a03301f7f --- /dev/null +++ b/cmd/catalog-gen/add_provider.go @@ -0,0 +1,149 @@ +package main + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/spf13/cobra" +) + +func newAddProviderCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "add-provider <type>", + Short: "Add a new provider type to the catalog", + Long: `Add a new provider type to the catalog configuration.
+ +Supported provider types: + - yaml: File-based provider using YAML catalog files + - http: HTTP-based provider for remote APIs + +Example: + catalog-gen add-provider yaml + catalog-gen add-provider http`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + providerType := args[0] + + switch providerType { + case "yaml", "http": + // OK + default: + return fmt.Errorf("unsupported provider type: %s (supported: yaml, http)", providerType) + } + + return addProvider(providerType) + }, + } + + return cmd +} + +func addProvider(providerType string) error { + config, err := loadConfig() + if err != nil { + return err + } + + // Check if provider already exists + for _, p := range config.Spec.Providers { + if p.Type == providerType { + fmt.Printf("Provider '%s' already exists in catalog.yaml\n", providerType) + return nil + } + } + + // Add provider + config.Spec.Providers = append(config.Spec.Providers, ProviderConfig{Type: providerType}) + + // Save config + if err := saveConfig(config); err != nil { + return err + } + + // Generate provider file + if err := generateProviderFile(config.Spec.Entity.Name, providerType); err != nil { + return err + } + + fmt.Printf("Added provider '%s' to catalog.yaml\n", providerType) + fmt.Printf("Generated provider file at internal/catalog/providers/%s_provider.go\n", providerType) + + return nil +} + +func newAddArtifactCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "add-artifact <name>", + Short: "Add a new artifact type to the catalog", + Long: `Add a new artifact type to the catalog configuration.
+ +Example: + catalog-gen add-artifact Tool + catalog-gen add-artifact Resource`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + artifactName := args[0] + return addArtifact(artifactName) + }, + } + + return cmd +} + +func addArtifact(artifactName string) error { + config, err := loadConfig() + if err != nil { + return err + } + + // Check if artifact already exists + for _, a := range config.Spec.Artifacts { + if a.Name == artifactName { + fmt.Printf("Artifact '%s' already exists in catalog.yaml\n", artifactName) + return nil + } + } + + // Add artifact + config.Spec.Artifacts = append(config.Spec.Artifacts, ArtifactConfig{ + Name: artifactName, + Properties: []PropertyConfig{ + {Name: "uri", Type: "string"}, + }, + }) + + // Save config + if err := saveConfig(config); err != nil { + return err + } + + // Generate artifact model stub file + if err := generateArtifactModelStub(artifactName); err != nil { + return err + } + + fmt.Printf("Added artifact '%s' to catalog.yaml\n", artifactName) + fmt.Printf("Generated artifact model at internal/db/models/%s.go\n", strings.ToLower(artifactName)) + fmt.Println("\nRun 'catalog-gen generate' to generate the full artifact implementation.") + + return nil +} + +func generateArtifactModelStub(artifactName string) error { + lowerName := strings.ToLower(artifactName) + data := map[string]any{ + "EntityName": "", + "ArtifactName": artifactName, + "LowerEntityName": "", + "LowerArtifactName": lowerName, + "Properties": "\tURI *string\n", + } + + modelsDir := filepath.Join("internal", "db", "models") + if err := ensureDir(modelsDir); err != nil { + return fmt.Errorf("failed to create models directory: %w", err) + } + + return executeTemplate(TmplModelsArtifact, filepath.Join(modelsDir, fmt.Sprintf("%s.go", lowerName)), data) +} diff --git a/cmd/catalog-gen/docs/skills-and-commands.md b/cmd/catalog-gen/docs/skills-and-commands.md new file mode 100644 index 0000000000..95ff012161 --- /dev/null 
+++ b/cmd/catalog-gen/docs/skills-and-commands.md @@ -0,0 +1,145 @@ +# catalog-gen Skills & Commands Reference + +This document describes the AI agent skills and slash commands generated by `catalog-gen generate` for each catalog project. + +## Overview + +When you run `catalog-gen generate`, it creates two directories in the generated catalog: + +- `.claude/commands/` - Slash commands that can be invoked directly (e.g., `/run-local`) +- `.claude/skills/` - Detailed skill files with step-by-step instructions + +## Generated Commands + +### Development Workflow + +| Command | File | Description | +|---------|------|-------------| +| `/add-property` | `add-property.md` | Add a new property to the entity. Guides through updating catalog.yaml, regenerating, and updating converters. | +| `/add-artifact` | `add-artifact.md` | Add a new artifact type. Guides through updating catalog.yaml, regenerating, and running /implement-artifacts. | +| `/add-provider` | `add-provider.md` | Add a new data provider (http, grpc, s3, etc.). Creates provider file and registers it. | +| `/regenerate` | `regenerate.md` | Regenerate code from catalog.yaml and fix any resulting build issues. | +| `/implement-artifacts` | `implement-artifacts.md` | Implement missing artifact support after adding artifacts to catalog.yaml. | +| `/fix-build` | `fix-build.md` | Diagnose and fix compilation errors. | + +### Testing & Debugging + +| Command | File | Description | +|---------|------|-------------| +| `/run-local` | `run-local.md` | Start the catalog server locally for testing. Includes port, base URL, and quick test curl. | +| `/test-api` | `test-api.md` | Generate curl commands to test all API endpoints including pagination and artifacts. | +| `/seed-data` | `seed-data.md` | Populate sample-catalog.yaml and artifacts.yaml with test data. | + +## Generated Skills + +Skills provide detailed, step-by-step instructions with code snippets customized for the specific catalog. 
+ +| Skill File | Purpose | +|------------|---------| +| `README.md` | Master reference listing all commands, workflows, and common scenarios | +| `add-property.md` | Step-by-step guide for adding entity properties with code examples | +| `add-artifact.md` | Step-by-step guide for adding artifact types | +| `add-provider.md` | Step-by-step guide for adding data providers with template code | +| `implement-artifacts.md` | Exact code snippets for implementing artifact support (NewServices, API methods, conversions) | +| `test-api.md` | Curl commands for testing all endpoints (list, get, pagination, artifacts) | +| `seed-data.md` | Sample data templates for entities and artifacts | + +## Customization + +All generated files are customized with: + +- **Entity name** (e.g., `McpServer`) +- **Artifact names** (e.g., `Tool`) +- **Catalog name** (e.g., `mcp-catalog`) +- **API port** (e.g., `8081`) +- **Property names and types** +- **File paths** specific to the catalog + +## Common Workflows + +### Adding a New Property + +``` +/add-property version string +``` + +This guides you through: +1. Updating `catalog.yaml` +2. Running `catalog-gen generate` +3. Updating repository converter +4. Updating OpenAPI service implementation +5. Updating YAML provider + +### Adding a New Artifact + +``` +/add-artifact License +``` + +This guides you through: +1. Updating `catalog.yaml` with artifact definition +2. Running `catalog-gen generate` +3. 
Running `/implement-artifacts` to complete integration + +### Testing the API + +``` +/seed-data +/run-local +/test-api +``` + +### Fixing Build Errors + +``` +/fix-build +``` + +Or for artifact-specific errors: + +``` +/implement-artifacts +``` + +## File Structure + +After running `catalog-gen generate`, the catalog will have: + +``` +/ +├── .claude/ +│ ├── commands/ +│ │ ├── add-property.md +│ │ ├── add-artifact.md +│ │ ├── add-provider.md +│ │ ├── regenerate.md +│ │ ├── implement-artifacts.md +│ │ ├── fix-build.md +│ │ ├── run-local.md +│ │ ├── test-api.md +│ │ └── seed-data.md +│ └── skills/ +│ ├── README.md +│ ├── add-property.md +│ ├── add-artifact.md +│ ├── add-provider.md +│ ├── implement-artifacts.md +│ ├── test-api.md +│ └── seed-data.md +├── CLAUDE.md # Main context file for AI agents +└── docs/ + └── post-artifact-generation.md +``` + +## Integration with Claude Code + +These files are automatically recognized by Claude Code: + +- **CLAUDE.md** - Loaded as project context +- **`.claude/commands/*.md`** - Available as slash commands +- **`.claude/skills/*.md`** - Referenced by commands for detailed instructions + +When an AI agent is working in the catalog directory, it can: +1. Read CLAUDE.md for overall project context +2. Use slash commands for common tasks +3. Follow skill files for detailed implementation steps diff --git a/cmd/catalog-gen/gen_api.go b/cmd/catalog-gen/gen_api.go new file mode 100644 index 0000000000..30e81da437 --- /dev/null +++ b/cmd/catalog-gen/gen_api.go @@ -0,0 +1,181 @@ +package main + +import ( + "fmt" + "path/filepath" + "strings" +) + +// generateOpenAPIComponents generates the OpenAPI components file. +// The generated components use allOf composition with BaseResource from common.yaml. 
func generateOpenAPIComponents(config CatalogConfig) error {
	entityName := config.Spec.Entity.Name

	// Detect if we're in a plugin context to determine reference type:
	// plugins $ref the shared lib/common.yaml schemas, standalone catalogs
	// reference schemas defined in the same document.
	isPlugin := isPluginContext()

	// Base properties that come from BaseResource - skip these in entity definition
	// These are defined in api/openapi/src/lib/common.yaml
	// Keys are lowercased so the membership test below is case-insensitive.
	baseResourceProperties := map[string]bool{
		"name":                     true,
		"id":                       true,
		"externalid":               true,
		"description":              true,
		"customproperties":         true,
		"createtimesinceepoch":     true,
		"lastupdatetimesinceepoch": true,
	}

	var propDefs strings.Builder
	var requiredFields strings.Builder
	for _, prop := range config.Spec.Entity.Properties {
		if baseResourceProperties[strings.ToLower(prop.Name)] {
			continue
		}
		// Use 12 spaces for properties inside allOf structure
		propDefs.WriteString(generateOpenAPIPropertyDef(prop, 12))
		if prop.Required {
			fmt.Fprintf(&requiredFields, "        - %s\n", prop.Name)
		}
	}

	// Build artifact schemas if artifacts are configured
	// Artifacts also use allOf composition with BaseResource
	var artifactSchemas strings.Builder
	if len(config.Spec.Artifacts) > 0 {
		// Generate individual artifact schemas using allOf composition
		for _, artifact := range config.Spec.Artifacts {
			// Schema name is <Entity><Artifact>Artifact, e.g. McpServerToolArtifact.
			fmt.Fprintf(&artifactSchemas, "    %s%sArtifact:\n", entityName, artifact.Name)
			artifactSchemas.WriteString("      allOf:\n")
			if isPlugin {
				artifactSchemas.WriteString("        - $ref: 'lib/common.yaml#/components/schemas/BaseResource'\n")
			} else {
				artifactSchemas.WriteString("        - $ref: '#/components/schemas/BaseResource'\n")
			}
			artifactSchemas.WriteString("        - type: object\n")
			artifactSchemas.WriteString("          properties:\n")
			artifactSchemas.WriteString("            artifactType:\n")
			artifactSchemas.WriteString("              type: string\n")
			for _, prop := range artifact.Properties {
				artifactSchemas.WriteString(generateOpenAPIPropertyDef(prop, 12))
			}
		}

		// Generate artifact list schema using allOf composition
		fmt.Fprintf(&artifactSchemas, "    %sArtifactList:\n", entityName)
		artifactSchemas.WriteString("      allOf:\n")
		if isPlugin {
			artifactSchemas.WriteString("        - $ref: 'lib/common.yaml#/components/schemas/BaseResourceList'\n")
		} else {
			artifactSchemas.WriteString("        - $ref: '#/components/schemas/BaseResourceList'\n")
		}
		artifactSchemas.WriteString("        - type: object\n")
		artifactSchemas.WriteString("          properties:\n")
		artifactSchemas.WriteString("            items:\n")
		artifactSchemas.WriteString("              type: array\n")
		artifactSchemas.WriteString("              items:\n")
		if len(config.Spec.Artifacts) == 1 {
			// Single artifact type: reference it directly.
			fmt.Fprintf(&artifactSchemas, "                $ref: '#/components/schemas/%s%sArtifact'\n", entityName, config.Spec.Artifacts[0].Name)
		} else {
			// Multiple artifact types: an item may be any one of them.
			artifactSchemas.WriteString("                oneOf:\n")
			for _, artifact := range config.Spec.Artifacts {
				fmt.Fprintf(&artifactSchemas, "                  - $ref: '#/components/schemas/%s%sArtifact'\n", entityName, artifact.Name)
			}
		}
	}

	data := map[string]any{
		"EntityName":      entityName,
		"Properties":      strings.TrimSpace(propDefs.String()),
		"RequiredFields":  requiredFields.String(),
		"ArtifactSchemas": artifactSchemas.String(),
		"IsPlugin":        isPlugin,
	}

	generatedDir := filepath.Join("api", "openapi", "src", "generated")
	if err := ensureDir(generatedDir); err != nil {
		return err
	}

	// NOTE(review): success message is printed before executeTemplate runs,
	// so it appears even if template rendering fails — consider reordering.
	fmt.Printf("  Generated: api/openapi/src/generated/components.yaml\n")
	return executeTemplate(TmplAPIOpenAPIComponents, filepath.Join(generatedDir, "components.yaml"), data)
}

// generateOpenAPIMain generates the OpenAPI main spec file.
func generateOpenAPIMain(config CatalogConfig) error {
	entityName := config.Spec.Entity.Name
	lowerName := strings.ToLower(entityName)

	// Build artifact routes if artifacts are configured.
	// The block is spliced verbatim into the spec template's paths section,
	// so its YAML indentation must match the template.
	artifactRoutes := ""
	if len(config.Spec.Artifacts) > 0 {
		artifactRoutes = `
  /` + lowerName + `s/{name}/artifacts:
    get:
      summary: List artifacts for a ` + entityName + `
      operationId: get` + entityName + `Artifacts
      parameters:
        - name: name
          in: path
          required: true
          schema:
            type: string
        - name: pageSize
          in: query
          schema:
            type: integer
            default: 20
        - name: pageToken
          in: query
          schema:
            type: string
      responses:
        '200':
          description: Successful response
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/` + entityName + `ArtifactList'
        '404':
          description: Not found
`
	}

	data := map[string]any{
		"Name":            config.Metadata.Name,
		"EntityName":      entityName,
		"EntityNameLower": lowerName,
		"BasePath":        config.Spec.API.BasePath,
		"ArtifactRoutes":  artifactRoutes,
	}

	srcDir := filepath.Join("api", "openapi", "src")
	if err := ensureDir(srcDir); err != nil {
		return err
	}

	// This file is created once and is user-editable afterwards (see callers).
	return executeTemplate(TmplAPIOpenAPIMain, filepath.Join(srcDir, "openapi.yaml"), data)
}

// generateOpenAPIServiceImpl generates the OpenAPI service implementation stub.
+func generateOpenAPIServiceImpl(config CatalogConfig) error { + entityName := config.Spec.Entity.Name + lowerName := strings.ToLower(entityName) + + // Build dynamic property conversion code + propConversions := buildOpenAPIPropertyConversions(config.Spec.Entity.Properties) + + data := map[string]any{ + "EntityName": entityName, + "EntityNameLower": lowerName, + "Package": config.Spec.Package, + "PropConversions": propConversions, + } + + openapiDir := filepath.Join("internal", "server", "openapi") + if err := ensureDir(openapiDir); err != nil { + return err + } + + return executeTemplate(TmplServerOpenAPIServiceImpl, filepath.Join(openapiDir, fmt.Sprintf("api_%s_service_impl.go", lowerName)), data) +} diff --git a/cmd/catalog-gen/gen_filter.go b/cmd/catalog-gen/gen_filter.go new file mode 100644 index 0000000000..2858294bbf --- /dev/null +++ b/cmd/catalog-gen/gen_filter.go @@ -0,0 +1,72 @@ +package main + +import ( + "fmt" + "path/filepath" + "strings" +) + +// filterValueType maps a catalog.yaml property type to a filter value type constant. +func filterValueType(specType string) string { + switch specType { + case "string": + return "string_value" + case "integer", "int": + return "int_value" + case "int64": + return "int_value" + case "boolean", "bool": + return "int_value" // bools stored as int (0/1) + case "number", "float", "double": + return "double_value" + default: + return "string_value" + } +} + +// generateFilterMappings generates the filter entity mappings file. 
+func generateFilterMappings(config CatalogConfig) error { + entityName := config.Spec.Entity.Name + + // Built-in fields that are already handled in the template + builtinFields := map[string]bool{ + "name": true, "externalid": true, "createtimesinceepoch": true, + "lastupdatetimesinceepoch": true, "id": true, + } + + // Build property registrations for init() function + var propRegistrations strings.Builder + for _, prop := range config.Spec.Entity.Properties { + if builtinFields[strings.ToLower(prop.Name)] { + continue + } + propRegistrations.WriteString(fmt.Sprintf("\t\t\"%s\": true,\n", prop.Name)) + } + + // Build property definitions for GetPropertyDefinitionForRestEntity switch cases + var propDefinitions strings.Builder + for _, prop := range config.Spec.Entity.Properties { + if builtinFields[strings.ToLower(prop.Name)] { + continue + } + valueType := filterValueType(prop.Type) + propDefinitions.WriteString(fmt.Sprintf("\tcase \"%s\":\n", prop.Name)) + propDefinitions.WriteString(fmt.Sprintf("\t\treturn filter.PropertyDefinition{Location: filter.PropertyTable, ValueType: \"%s\", Column: \"%s\"}\n", valueType, prop.Name)) + } + + data := map[string]any{ + "EntityName": entityName, + "Package": config.Spec.Package, + "PropertyRegistrations": propRegistrations.String(), + "PropertyDefinitions": propDefinitions.String(), + } + + serviceDir := filepath.Join("internal", "db", "service") + if err := ensureDir(serviceDir); err != nil { + return err + } + + outputPath := filepath.Join(serviceDir, "filter_mappings.go") + fmt.Printf(" Generated: internal/db/service/filter_mappings.go\n") + return executeTemplate(TmplServiceFilterMappings, outputPath, data) +} diff --git a/cmd/catalog-gen/gen_manifests.go b/cmd/catalog-gen/gen_manifests.go new file mode 100644 index 0000000000..f91cd7d655 --- /dev/null +++ b/cmd/catalog-gen/gen_manifests.go @@ -0,0 +1,11 @@ +package main + +// generateGitignore generates the .gitignore file. 
+func generateGitignore() error { + return executeTemplate(TmplMiscGitignore, ".gitignore", nil) +} + +// generateOpenAPIGeneratorIgnore generates the .openapi-generator-ignore file. +func generateOpenAPIGeneratorIgnore() error { + return executeTemplate(TmplMiscOpenAPIGeneratorIgnore, ".openapi-generator-ignore", nil) +} diff --git a/cmd/catalog-gen/gen_models.go b/cmd/catalog-gen/gen_models.go new file mode 100644 index 0000000000..e9d41cca12 --- /dev/null +++ b/cmd/catalog-gen/gen_models.go @@ -0,0 +1,74 @@ +package main + +import ( + "fmt" + "path/filepath" + "strings" +) + +// generateEntityModel generates the entity model file. +func generateEntityModel(config CatalogConfig) error { + entityName := config.Spec.Entity.Name + lowerName := strings.ToLower(entityName) + + // Build properties (skip built-in fields) + builtinFields := map[string]bool{ + "name": true, "externalid": true, "createtimesinceepoch": true, + "lastupdatetimesinceepoch": true, "id": true, + } + var propDefs strings.Builder + for _, prop := range config.Spec.Entity.Properties { + if builtinFields[strings.ToLower(prop.Name)] { + continue + } + goType := goTypeFromSpec(prop.Type) + propDefs.WriteString(fmt.Sprintf("\t%s\t%s\n", capitalize(prop.Name), goType)) + } + + data := map[string]any{ + "EntityName": entityName, + "EntityNameLower": lowerName, + "Properties": propDefs.String(), + } + + modelsDir := filepath.Join("internal", "db", "models") + if err := ensureDir(modelsDir); err != nil { + return err + } + + outputPath := filepath.Join(modelsDir, fmt.Sprintf("%s.go", lowerName)) + fmt.Printf(" Generated: internal/db/models/%s.go\n", lowerName) + return executeTemplate(TmplModelsEntity, outputPath, data) +} + +// generateArtifactModel generates an artifact model file. 
+func generateArtifactModel(config CatalogConfig, artifact ArtifactConfig) error { + entityName := config.Spec.Entity.Name + artifactName := artifact.Name + lowerEntityName := strings.ToLower(entityName) + lowerArtifactName := strings.ToLower(artifactName) + + // Build properties + var propDefs strings.Builder + for _, prop := range artifact.Properties { + goType := goTypeFromSpec(prop.Type) + propDefs.WriteString(fmt.Sprintf("\t%s\t%s\n", capitalize(prop.Name), goType)) + } + + data := map[string]any{ + "EntityName": entityName, + "ArtifactName": artifactName, + "LowerEntityName": lowerEntityName, + "LowerArtifactName": lowerArtifactName, + "Properties": propDefs.String(), + } + + modelsDir := filepath.Join("internal", "db", "models") + if err := ensureDir(modelsDir); err != nil { + return err + } + + filename := fmt.Sprintf("%s_%s_artifact.go", lowerEntityName, lowerArtifactName) + fmt.Printf(" Generated: internal/db/models/%s\n", filename) + return executeTemplate(TmplModelsArtifact, filepath.Join(modelsDir, filename), data) +} diff --git a/cmd/catalog-gen/gen_plugin.go b/cmd/catalog-gen/gen_plugin.go new file mode 100644 index 0000000000..90669d1199 --- /dev/null +++ b/cmd/catalog-gen/gen_plugin.go @@ -0,0 +1,531 @@ +package main + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "gopkg.in/yaml.v3" +) + +// sanitizeCatalogName extracts a clean slug from a catalog name. +// For example, "catalog/plugins/mcp" becomes "mcp", and "test-widgets" stays "test-widgets". +// Slashes and dots are stripped, keeping only the last segment. +func sanitizeCatalogName(name string) string { + // Take the last path segment if name contains slashes + name = filepath.Base(name) + // Replace any remaining non-alphanumeric chars (except hyphens/underscores) with underscores + name = strings.ReplaceAll(name, ".", "_") + return name +} + +// initCatalogPlugin initializes a new catalog plugin for the unified catalog server. 
func initCatalogPlugin(name, entityName, packageName, outputDir string) error {
	fmt.Printf("Initializing catalog: %s\n", name)
	fmt.Printf("  Entity: %s\n", entityName)
	fmt.Printf("  Package: %s\n", packageName)
	fmt.Printf("  Output: %s\n", outputDir)

	// Create plugin-specific directory structure (no cmd/, no manifests/)
	dirs := []string{
		outputDir,
		filepath.Join(outputDir, "internal", "catalog", "providers"),
		filepath.Join(outputDir, "internal", "db", "models"),
		filepath.Join(outputDir, "internal", "db", "service"),
		filepath.Join(outputDir, "internal", "server", "openapi"),
		filepath.Join(outputDir, "pkg", "openapi"),
		filepath.Join(outputDir, "api", "generated"),
		filepath.Join(outputDir, "docs"),
	}

	for _, dir := range dirs {
		if err := os.MkdirAll(dir, 0755); err != nil {
			return fmt.Errorf("failed to create directory %s: %w", dir, err)
		}
	}

	// Create catalog.yaml config
	config := CatalogConfig{
		APIVersion: "catalog.kubeflow.org/v1alpha1",
		Kind:       "CatalogConfig",
		Metadata: CatalogMetadata{
			Name: sanitizeCatalogName(name),
		},
		Spec: CatalogSpec{
			Package: packageName,
			Entity: EntityConfig{
				Name:       entityName,
				Properties: []PropertyConfig{},
			},
			Providers: []ProviderConfig{
				{Type: "yaml"},
			},
			API: APIConfig{
				BasePath: fmt.Sprintf("/api/%s_catalog/v1alpha1", sanitizeCatalogName(name)),
				Port:     8081,
			},
		},
	}

	configPath := filepath.Join(outputDir, "catalog.yaml")
	configFile, err := os.Create(configPath)
	if err != nil {
		return fmt.Errorf("failed to create config file: %w", err)
	}
	defer func() { _ = configFile.Close() }()

	encoder := yaml.NewEncoder(configFile)
	encoder.SetIndent(2)
	if err := encoder.Encode(config); err != nil {
		return fmt.Errorf("failed to write config: %w", err)
	}
	// Append a comment about BaseResource fields (YAML encoder doesn't support comments)
	comment := `
# The following fields are already included from BaseResource and should NOT
# be added to properties above:
#   name, id, externalId, description, customProperties,
#   createTimeSinceEpoch, lastUpdateTimeSinceEpoch
`
	if _, err := configFile.WriteString(comment); err != nil {
		return fmt.Errorf("failed to write config comment: %w", err)
	}

	// Change to output directory to use generate functions.
	// NOTE: all generate* helpers below write paths relative to the CWD; the
	// deferred Chdir restores the original working directory on every exit path.
	originalDir, err := os.Getwd()
	if err != nil {
		return fmt.Errorf("failed to get current directory: %w", err)
	}
	if err := os.Chdir(outputDir); err != nil {
		return fmt.Errorf("failed to change to output directory: %w", err)
	}
	defer func() { _ = os.Chdir(originalDir) }()

	fmt.Println("\n=== Generating plugin files ===")

	// Generate plugin.go and register.go (instead of cmd/main.go)
	if err := generatePluginFiles(config); err != nil {
		return fmt.Errorf("failed to generate plugin files: %w", err)
	}

	fmt.Println("\n=== Generating editable files (created once, you can modify) ===")

	// Generate repository (same as standalone)
	if err := generateRepository(config); err != nil {
		return fmt.Errorf("failed to generate repository: %w", err)
	}

	// Generate OpenAPI main file
	if err := generateOpenAPIMain(config); err != nil {
		return fmt.Errorf("failed to generate OpenAPI main: %w", err)
	}
	fmt.Printf("  Created: api/openapi/src/openapi.yaml\n")

	// Generate OpenAPI service implementation
	if err := generateOpenAPIServiceImpl(config); err != nil {
		return fmt.Errorf("failed to generate OpenAPI service impl: %w", err)
	}
	fmt.Printf("  Created: internal/server/openapi/api_%s_service_impl.go\n", strings.ToLower(entityName))

	// Generate YAML provider (only for providers of type "yaml")
	for _, provider := range config.Spec.Providers {
		if provider.Type == "yaml" {
			if err := generateYAMLProvider(config); err != nil {
				return fmt.Errorf("failed to generate YAML provider: %w", err)
			}
		}
	}

	// Generate Makefile (simplified for plugins)
	if err := generatePluginMakefile(config); err != nil {
		return fmt.Errorf("failed to generate Makefile: %w", err)
	}
	fmt.Printf("  Created: Makefile\n")

	// Generate README for plugin
	if err := generatePluginREADME(config); err != nil {
		return fmt.Errorf("failed to generate README: %w", err)
	}
	fmt.Printf("  Created: README.md\n")

	// Generate .gitignore
	if err := generateGitignore(); err != nil {
		return fmt.Errorf("failed to generate .gitignore: %w", err)
	}
	fmt.Printf("  Created: .gitignore\n")

	// Generate .openapi-generator-ignore
	if err := generateOpenAPIGeneratorIgnore(); err != nil {
		return fmt.Errorf("failed to generate .openapi-generator-ignore: %w", err)
	}
	fmt.Printf("  Created: .openapi-generator-ignore\n")

	// Generate Claude Code skills and commands
	if err := generateClaudeSkills(config); err != nil {
		return fmt.Errorf("failed to generate Claude skills: %w", err)
	}

	// Create symlink to shared common schemas (BaseResource, etc.)
	if err := ensureCommonLibSymlink(); err != nil {
		return fmt.Errorf("failed to create common lib symlink: %w", err)
	}

	fmt.Println("\n=== Generating auto-regenerated files ===")

	// Generate entity model
	if err := generateEntityModel(config); err != nil {
		return fmt.Errorf("failed to generate entity model: %w", err)
	}

	// Generate datastore spec
	if err := generateDatastoreSpec(config); err != nil {
		return fmt.Errorf("failed to generate datastore spec: %w", err)
	}

	// Generate filter mappings for filterQuery support
	if err := generateFilterMappings(config); err != nil {
		return fmt.Errorf("failed to generate filter mappings: %w", err)
	}

	// Generate OpenAPI components
	if err := generateOpenAPIComponents(config); err != nil {
		return fmt.Errorf("failed to generate OpenAPI components: %w", err)
	}

	// Generate loader
	if err := generateLoader(config); err != nil {
		return fmt.Errorf("failed to generate loader: %w", err)
	}

	fmt.Printf("\nCatalog %s initialized successfully!\n", name)
	fmt.Println("\nNext steps:")
	fmt.Println("  1. Run 'make gen/openapi-server' to generate OpenAPI handlers")
	fmt.Println("  2. Import the plugin in cmd/catalog-server/main.go:")
	fmt.Printf("     _ \"%s\"\n", packageName)
	fmt.Println("  3. Add the plugin to sources.yaml under catalogs:")
	fmt.Printf("     %s:\n", name)
	fmt.Println("       sources:")
	fmt.Println("         - id: \"my-source\"")
	fmt.Println("           type: \"yaml\"")

	return nil
}

// generatePluginFiles generates the plugin.go and register.go files.
func generatePluginFiles(config CatalogConfig) error {
	entityName := config.Spec.Entity.Name

	// Always use 'any' for artifact type to avoid import cycles
	artifactType := "any"

	// Extract the last segment of the catalog name for the Go package name
	// e.g., "catalog/plugins/mcp" -> "mcp"
	packageName := filepath.Base(config.Metadata.Name)

	data := map[string]any{
		"Name":         packageName,
		"PackageName":  packageName,
		"EntityName":   entityName,
		"Package":      config.Spec.Package,
		"BasePath":     config.Spec.API.BasePath,
		"ArtifactType": artifactType,
	}

	// Generate plugin.go
	if err := executeTemplate(TmplPluginPlugin, "plugin.go", data); err != nil {
		return fmt.Errorf("failed to generate plugin.go: %w", err)
	}
	fmt.Printf("  Created: plugin.go\n")

	// Generate register.go
	if err := executeTemplate(TmplPluginRegister, "register.go", data); err != nil {
		return fmt.Errorf("failed to generate register.go: %w", err)
	}
	fmt.Printf("  Created: register.go\n")

	return nil
}

// generatePluginMakefile generates a Makefile for plugins with proper OpenAPI generation.
func generatePluginMakefile(config CatalogConfig) error {
	// The Makefile is written once at init time and is user-editable afterwards.
	// Recipe lines in the literal below must be indented with hard tabs
	// (a make requirement), not spaces.
	content := fmt.Sprintf(`# Generated by catalog-gen - you can modify this file

PACKAGE := %s
CATALOG_NAME := %s
PROJECT_ROOT := $(shell pwd)
REPO_ROOT := $(shell git rev-parse --show-toplevel)

# Use local binary from repo bin/ if available, fall back to Docker
OPENAPI_GENERATOR ?= $(if $(wildcard $(REPO_ROOT)/bin/openapi-generator-cli),$(REPO_ROOT)/bin/openapi-generator-cli,docker run --rm -v $(PROJECT_ROOT):/local -w /local openapitools/openapi-generator-cli:v7.13.0)
CATALOG_GEN ?= $(if $(wildcard $(REPO_ROOT)/bin/catalog-gen),$(REPO_ROOT)/bin/catalog-gen,catalog-gen)

.PHONY: all build test gen/catalog gen/openapi gen/openapi-server gen/openapi-client clean

all: gen/openapi-server build

# Regenerate code from catalog.yaml
gen/catalog:
	$(CATALOG_GEN) generate

build:
	go build ./...

test:
	go test ./...

# Merge OpenAPI spec with generated components
api/openapi/openapi.yaml: api/openapi/src/openapi.yaml api/openapi/src/generated/components.yaml
	@echo "Merging OpenAPI specs..."
	@mkdir -p api/openapi
	@test -e api/openapi/lib || ln -s src/lib api/openapi/lib
	@cat api/openapi/src/openapi.yaml > api/openapi/openapi.yaml
	@echo "" >> api/openapi/openapi.yaml
	@cat api/openapi/src/generated/components.yaml >> api/openapi/openapi.yaml
	@echo "Merged to api/openapi/openapi.yaml"

gen/openapi: api/openapi/openapi.yaml

# Generate OpenAPI server code (controllers, models, routers)
# The generated controller calls your service implementation in internal/server/openapi/api_*_service_impl.go
gen/openapi-server: api/openapi/openapi.yaml
	@echo "Generating OpenAPI server code..."
	@mkdir -p internal/server/openapi
	$(OPENAPI_GENERATOR) generate \
		-i api/openapi/openapi.yaml \
		-g go-server \
		-o internal/server/openapi \
		--package-name openapi \
		--ignore-file-override .openapi-generator-ignore \
		--additional-properties=outputAsLibrary=true,router=chi,sourceFolder=,onlyInterfaces=true,isGoSubmodule=true,enumClassPrefix=true
	@echo "Running goimports..."
	@command -v goimports >/dev/null 2>&1 && goimports -w internal/server/openapi || echo "goimports not found, skipping"
	@echo "Done"

# Generate OpenAPI client code (optional - for SDK generation)
gen/openapi-client: api/openapi/openapi.yaml
	@echo "Generating OpenAPI client code..."
	@mkdir -p pkg/openapi
	$(OPENAPI_GENERATOR) generate \
		-i api/openapi/openapi.yaml \
		-g go \
		-o pkg/openapi \
		--package-name openapi \
		--ignore-file-override .openapi-generator-ignore \
		--additional-properties=isGoSubmodule=true,enumClassPrefix=true
	@command -v goimports >/dev/null 2>&1 && goimports -w pkg/openapi || echo "goimports not found, skipping"
	@echo "Done"

clean:
	rm -rf internal/server/openapi/.openapi-generator
	rm -rf pkg/openapi/.openapi-generator
	rm -f api/openapi/openapi.yaml
`, config.Spec.Package, config.Metadata.Name)

	return os.WriteFile("Makefile", []byte(content), 0644)
}

// generatePluginREADME generates a README.md for the plugin.
func generatePluginREADME(config CatalogConfig) error {
	entityName := config.Spec.Entity.Name
	lowerEntity := strings.ToLower(entityName)

	// Build list of filterable properties: name/externalId are always
	// advertised; id and the timestamp fields are excluded, and base fields
	// are not repeated.
	filterableProps := []string{"name", "externalId"}
	for _, prop := range config.Spec.Entity.Properties {
		lowerName := strings.ToLower(prop.Name)
		if lowerName == "name" || lowerName == "externalid" || lowerName == "id" ||
			lowerName == "createtimesinceepoch" || lowerName == "lastupdatetimesinceepoch" {
			continue
		}
		filterableProps = append(filterableProps, prop.Name)
	}

	// Pick first custom property for example, or fallback
	exampleProp := "name"
	exampleValue := "'example'"
	if len(config.Spec.Entity.Properties) > 0 {
		for _, prop := range config.Spec.Entity.Properties {
			if strings.ToLower(prop.Name) != "name" && strings.ToLower(prop.Name) != "externalid" {
				exampleProp = prop.Name
				// Numeric properties get an unquoted example value.
				if prop.Type == "integer" || prop.Type == "int" || prop.Type == "int64" || prop.Type == "number" {
					exampleValue = "5"
				} else {
					exampleValue = "'example'"
				}
				break
			}
		}
	}

	// NOTE: %%25 below is a literal "%25" in the README (a URL-encoded '%'
	// for the LIKE wildcard in curl examples); backticked code fences are
	// spliced in via string concatenation since raw strings cannot contain
	// backticks.
	content := fmt.Sprintf(`# %s Catalog Plugin

This is a catalog plugin generated by catalog-gen for the unified catalog server.

## Overview

- **Entity**: %s
- **Package**: %s
- **API Base Path**: %s

## Usage

### 1. Generate OpenAPI Handlers

`+"```bash"+`
make gen/openapi-server
`+"```"+`

### 2. Import the Plugin

Add the plugin import to `+"`cmd/catalog-server/main.go`"+`:

`+"```go"+`
import (
    // Import plugins - their init() registers them
    _ "%s"
)
`+"```"+`

### 3. Configure Sources

Add the plugin configuration to your `+"`sources.yaml`"+`:

`+"```yaml"+`
catalogs:
  %s:
    sources:
      - id: "my-source"
        name: "My Data Source"
        type: "yaml"
        properties:
          yamlCatalogPath: "./data/%s.yaml"
`+"```"+`

### 4. Build and Run

`+"```bash"+`
go build ./cmd/catalog-server
./catalog-server --sources=./sources.yaml --listen=:8080
`+"```"+`

## Filtering

List endpoints support advanced filtering via `+"`filterQuery`"+`:

`+"```bash"+`
# Filter by property
curl "http://localhost:8080%s/%ss?filterQuery=%s=%s"

# Multiple conditions
curl "http://localhost:8080%s/%ss?filterQuery=%s=%s AND name LIKE '%%25server%%25'"

# Pattern matching
curl "http://localhost:8080%s/%ss?filterQuery=name LIKE '%%25server%%25'"

# Ordering
curl "http://localhost:8080%s/%ss?orderBy=name&sortOrder=DESC"
`+"```"+`

Supported operators: `+"`` = ``"+`, `+"`` != ``"+`, `+"`` > ``"+`, `+"`` < ``"+`, `+"`` >= ``"+`, `+"`` <= ``"+`, `+"`` LIKE ``"+`, `+"`` ILIKE ``"+`, `+"`` IN ``"+`, `+"`` AND ``"+`, `+"`` OR ``"+`

Filterable properties: %s

## Development

### Adding Properties

Edit `+"`catalog.yaml`"+` and run:

`+"```bash"+`
catalog-gen generate
`+"```"+`

### Adding Artifacts

`+"```bash"+`
catalog-gen add-artifact MyArtifact
`+"```"+`

## Files

| File | Description |
|------|-------------|
| `+"`plugin.go`"+` | Plugin implementation (auto-generated) |
| `+"`register.go`"+` | Plugin registration (auto-generated) |
| `+"`internal/db/models/`"+` | Entity models |
| `+"`internal/db/service/`"+` | Repository implementations |
| `+"`internal/catalog/`"+` | Data loader and providers |
| `+"`internal/server/openapi/`"+` | API handler implementations |
`, config.Metadata.Name, entityName, config.Spec.Package,
		config.Spec.API.BasePath, config.Spec.Package, config.Metadata.Name,
		lowerEntity,
		config.Spec.API.BasePath, lowerEntity, exampleProp, exampleValue,
		config.Spec.API.BasePath, lowerEntity, exampleProp, exampleValue,
		config.Spec.API.BasePath, lowerEntity,
		config.Spec.API.BasePath, lowerEntity,
		strings.Join(filterableProps, ", "))

	return os.WriteFile("README.md", []byte(content), 0644)
}

// generateClaudeSkills generates Claude Code
// skills and commands for the plugin.
func generateClaudeSkills(config CatalogConfig) error {
	entityName := config.Spec.Entity.Name
	entityNameLower := strings.ToLower(entityName)

	// Shared template data for CLAUDE.md, commands, and skills.
	data := map[string]any{
		"Name":            config.Metadata.Name,
		"EntityName":      entityName,
		"EntityNameLower": entityNameLower,
		"Package":         config.Spec.Package,
		"BasePath":        config.Spec.API.BasePath,
		"HasArtifacts":    len(config.Spec.Artifacts) > 0,
	}

	// Create directories
	if err := os.MkdirAll(".claude/commands", 0755); err != nil {
		return err
	}
	if err := os.MkdirAll(".claude/skills", 0755); err != nil {
		return err
	}

	// Generate CLAUDE.md
	if err := executeTemplate(TmplAgentClaudeMD, "CLAUDE.md", data); err != nil {
		return err
	}
	fmt.Printf("  Created: CLAUDE.md\n")

	// Generate commands (slash commands recognized by Claude Code)
	commands := []struct{ tmpl, file string }{
		{TmplAgentCmdAddProperty, ".claude/commands/add-property.md"},
		{TmplAgentCmdAddArtifact, ".claude/commands/add-artifact.md"},
		{TmplAgentCmdAddArtifactProp, ".claude/commands/add-artifact-property.md"},
		{TmplAgentCmdRegenerate, ".claude/commands/regenerate.md"},
		{TmplAgentCmdFixBuild, ".claude/commands/fix-build.md"},
		{TmplAgentCmdGenTestdata, ".claude/commands/gen-testdata.md"},
	}
	for _, cmd := range commands {
		if err := executeTemplate(cmd.tmpl, cmd.file, data); err != nil {
			return err
		}
	}
	fmt.Printf("  Created: .claude/commands/ (%d commands)\n", len(commands))

	// Generate skills (detailed instructions referenced by the commands)
	skills := []struct{ tmpl, file string }{
		{TmplAgentSkillAddProperty, ".claude/skills/add-property.md"},
		{TmplAgentSkillAddArtifact, ".claude/skills/add-artifact.md"},
		{TmplAgentSkillAddArtifactProp, ".claude/skills/add-artifact-property.md"},
		{TmplAgentSkillRegenerate, ".claude/skills/regenerate.md"},
		{TmplAgentSkillGenTestdata, ".claude/skills/gen-testdata.md"},
	}
	for _, skill := range skills {
		if err := executeTemplate(skill.tmpl, skill.file, data); err != nil {
			return err
		}
	}
	fmt.Printf("  Created: .claude/skills/ (%d skills)\n", len(skills))

	return nil
}
diff --git a/cmd/catalog-gen/gen_providers.go b/cmd/catalog-gen/gen_providers.go
new file mode 100644
index 0000000000..f1c827ce24
--- /dev/null
+++ b/cmd/catalog-gen/gen_providers.go
@@ -0,0 +1,237 @@
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// generateLoader generates the catalog loader file.
func generateLoader(config CatalogConfig) error {
	entityName := config.Spec.Entity.Name
	hasArtifacts := len(config.Spec.Artifacts) > 0

	// Build artifact type switch cases for SaveArtifact
	var artifactSaveCases strings.Builder
	var artifactDeleteCalls strings.Builder
	if hasArtifacts {
		for _, artifact := range config.Spec.Artifacts {
			artifactSaveCases.WriteString(fmt.Sprintf(`	case models.%s%sArtifact:
		_, err := services.%s%sArtifactRepository.Save(a, &entityID)
		return err
`, entityName, artifact.Name, entityName, artifact.Name))
			artifactDeleteCalls.WriteString(fmt.Sprintf(`	if err := services.%s%sArtifactRepository.DeleteByParentID(entityID); err != nil {
		return err
	}
`, entityName, artifact.Name))
		}
	}

	// Always use 'any' for artifact type to avoid import cycles with providers
	artifactType := "any"

	data := map[string]any{
		"EntityName":          entityName,
		"Package":             config.Spec.Package,
		"HasArtifacts":        hasArtifacts,
		"ArtifactType":        artifactType,
		"ArtifactSaveCases":   artifactSaveCases.String(),
		"ArtifactDeleteCalls": artifactDeleteCalls.String(),
	}

	catalogDir := filepath.Join("internal", "catalog")
	if err := ensureDir(catalogDir); err != nil {
		return err
	}

	fmt.Printf("  Generated: internal/catalog/loader.go\n")
	return executeTemplate(TmplCatalogLoader, filepath.Join(catalogDir, "loader.go"), data)
}

// generateYAMLProvider generates the YAML provider file.
// generateYAMLProvider generates the YAML provider file at
// internal/catalog/providers/yaml_provider.go from the TmplProvidersYAML
// template. The builders below assemble Go source fragments (struct
// definitions, parsing code, match code) that are spliced into the template.
func generateYAMLProvider(config CatalogConfig) error {
	entityName := config.Spec.Entity.Name
	lowerName := strings.ToLower(entityName)
	hasArtifacts := len(config.Spec.Artifacts) > 0

	// Create providers directory
	providersDir := filepath.Join("internal", "catalog", "providers")
	if err := ensureDir(providersDir); err != nil {
		return err
	}

	// Determine artifact type for the provider
	artifactType := "any"
	if hasArtifacts {
		artifactType = "catalog.Artifact"
	}

	// Build artifact structs and parsing code
	var artifactStructs strings.Builder
	var artifactParseCode strings.Builder
	var artifactMatchCode strings.Builder
	if hasArtifacts {
		// Generate struct for each artifact type. Struct tags are built by string
		// concatenation (not in the format string) because they contain backquotes.
		for _, artifact := range config.Spec.Artifacts {
			lowerArtifactName := strings.ToLower(artifact.Name)
			artifactStructs.WriteString(fmt.Sprintf(`
// yaml%s%s represents a %s entry in the artifacts YAML file.
type yaml%s%s struct {
 %sName string %s
 Name string %s
`, entityName, artifact.Name, artifact.Name, entityName, artifact.Name, entityName,
				"`json:\""+lowerName+"Name\" yaml:\""+lowerName+"Name\"`",
				"`json:\"name\" yaml:\"name\"`"))
			for _, prop := range artifact.Properties {
				goType := goTypeFromSpec(prop.Type)
				// Remove pointer for optional fields in yaml struct
				yamlGoType := strings.TrimPrefix(goType, "*")
				artifactStructs.WriteString(fmt.Sprintf("\t%s %s `json:\"%s,omitempty\" yaml:\"%s,omitempty\"`\n",
					capitalize(prop.Name), yamlGoType, prop.Name, prop.Name))
			}
			artifactStructs.WriteString("}\n")

			// Add to artifacts catalog struct
			artifactStructs.WriteString(fmt.Sprintf(`
// yaml%sArtifactsCatalog is the structure of the artifacts YAML file.
type yaml%sArtifactsCatalog struct {
 %ss []yaml%s%s %s
}
`, entityName, entityName, artifact.Name, entityName, artifact.Name,
				"`json:\""+lowerArtifactName+"s\" yaml:\""+lowerArtifactName+"s\"`"))
		}

		// Generate artifact parsing code
		artifactParseCode.WriteString(`
 // Parse artifacts file if provided
 artifactsByEntity := make(map[string][]catalog.Artifact)
 if artifactsData != nil {
`)
		for _, artifact := range config.Spec.Artifacts {
			lowerArtifactName := strings.ToLower(artifact.Name)
			// NOTE(review): the generated loop declares a local `entityName` that
			// shadows the generator's variable name — harmless, but confusing to read.
			artifactParseCode.WriteString(fmt.Sprintf(` var %sArtifacts yaml%sArtifactsCatalog
 if err := k8syaml.UnmarshalStrict(artifactsData, &%sArtifacts); err == nil {
 for _, a := range %sArtifacts.%ss {
 entityName := a.%sName
 artifactName := a.Name
`, lowerArtifactName, entityName, lowerArtifactName, lowerArtifactName, artifact.Name, entityName))
			// Build property assignments
			var propAssignments strings.Builder
			for _, prop := range artifact.Properties {
				propName := capitalize(prop.Name)
				goType := goTypeFromSpec(prop.Type)
				if strings.HasPrefix(goType, "*") {
					// For pointer types, take address
					propAssignments.WriteString(fmt.Sprintf("\t\t\t\t\t%s: &a.%s,\n", propName, propName))
				} else {
					propAssignments.WriteString(fmt.Sprintf("\t\t\t\t\t%s: a.%s,\n", propName, propName))
				}
			}
			artifactParseCode.WriteString(fmt.Sprintf(` artifact := models.New%s%sArtifact(&models.%s%sArtifactAttributes{
 Name: &artifactName,
%s })
 artifactsByEntity[entityName] = append(artifactsByEntity[entityName], artifact)
 }
 }
`, entityName, artifact.Name, entityName, artifact.Name, propAssignments.String()))
		}
		artifactParseCode.WriteString(` }
`)

		// Generate code to match artifacts to entities
		artifactMatchCode.WriteString(`
 // Attach artifacts to this entity
 if artifacts, ok := artifactsByEntity[name]; ok {
 record.Artifacts = artifacts
 }
`)
	}

	// Build entity property struct fields for yaml struct
	var entityPropertyFields strings.Builder
	var entityPropertyAssignments strings.Builder
	for _, prop := range config.Spec.Entity.Properties {
		if prop.Name == "description" || prop.Name == "externalId" || prop.Name == "name" || prop.Name == "customProperties" {
			continue // Skip base fields already in template
		}
		goType := goTypeFromSpec(prop.Type)
		// For yaml struct, use non-pointer types with omitempty
		yamlGoType := strings.TrimPrefix(goType, "*")
		if prop.Type == "array" {
			// Determine array item type from items field
			itemType := "string" // default
			if prop.Items != nil && prop.Items.Type != "" {
				itemType = strings.TrimPrefix(goTypeFromSpec(prop.Items.Type), "*")
			}
			yamlGoType = "[]" + itemType
		}
		propName := capitalize(prop.Name)
		entityPropertyFields.WriteString(fmt.Sprintf("\t%s %s `json:\"%s,omitempty\" yaml:\"%s,omitempty\"`\n",
			propName, yamlGoType, prop.Name, prop.Name))
		// Generate assignment code - handle pointers and arrays
		if prop.Type == "array" {
			entityPropertyAssignments.WriteString(fmt.Sprintf(` // Handle %s (array -> comma-separated string)
 if len(item.%s) > 0 {
 %sStr := ""
 for i, v := range item.%s {
 if i > 0 {
 %sStr += ","
 }
 %sStr += v
 }
 entity.GetAttributes().%s = &%sStr
 }
`, propName, propName, prop.Name, propName, prop.Name, prop.Name, propName, prop.Name))
		} else if strings.HasPrefix(goType, "*") {
			// Pointer type - take address
			entityPropertyAssignments.WriteString(fmt.Sprintf("\t\tentity.GetAttributes().%s = &item.%s\n", propName, propName))
		} else {
			entityPropertyAssignments.WriteString(fmt.Sprintf("\t\tentity.GetAttributes().%s = item.%s\n", propName, propName))
		}
	}

	data := map[string]any{
		"Package":                   config.Spec.Package,
		"EntityName":                entityName,
		"EntityNameLower":           lowerName,
		"HasArtifacts":              hasArtifacts,
		"ArtifactType":              artifactType,
		"ArtifactStructs":           artifactStructs.String(),
		"ArtifactParseCode":         artifactParseCode.String(),
		"ArtifactMatchCode":         artifactMatchCode.String(),
		"EntityPropertyFields":      entityPropertyFields.String(),
		"EntityPropertyAssignments": entityPropertyAssignments.String(),
	}

	providerPath := filepath.Join(providersDir, "yaml_provider.go")
	if err := executeTemplate(TmplProvidersYAML, providerPath, data); err != nil {
		return err
	}
	fmt.Printf(" Generated: %s\n", providerPath)

	return nil
}

// generateProviderFile generates a provider file for the specified type.
// Supported types are "yaml" and "http"; anything else is an error.
func generateProviderFile(entityName, providerType string) error {
	data := map[string]any{
		"EntityName": entityName,
	}

	providersDir := filepath.Join("internal", "catalog", "providers")
	if err := ensureDir(providersDir); err != nil {
		return fmt.Errorf("failed to create providers directory: %w", err)
	}

	var templatePath string
	switch providerType {
	case "yaml":
		templatePath = TmplProvidersYAML
	case "http":
		templatePath = TmplProvidersHTTP
	default:
		return fmt.Errorf("unknown provider type: %s", providerType)
	}

	return executeTemplate(templatePath, filepath.Join(providersDir, fmt.Sprintf("%s_provider.go", providerType)), data)
}
diff --git a/cmd/catalog-gen/gen_service.go b/cmd/catalog-gen/gen_service.go
new file mode 100644
index 0000000000..bab903c5a5
--- /dev/null
+++ b/cmd/catalog-gen/gen_service.go
@@ -0,0 +1,166 @@
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// generateRepository generates the entity repository file.
+func generateRepository(config CatalogConfig) error { + entityName := config.Spec.Entity.Name + lowerName := strings.ToLower(entityName) + + // Build dynamic property mapping code + propVarDecls := buildPropertyVarDeclarations(config.Spec.Entity.Properties) + propReadCases := buildPropertyReadCases(config.Spec.Entity.Properties) + propAttrAssignments := buildPropertyAttrAssignments(config.Spec.Entity.Properties) + propWriteStatements := buildPropertyWriteStatements(config.Spec.Entity.Properties) + + data := map[string]any{ + "EntityName": entityName, + "EntityNameLower": lowerName, + "Package": config.Spec.Package, + "PropVarDecls": propVarDecls, + "PropReadCases": propReadCases, + "PropAttrAssignments": propAttrAssignments, + "PropWriteStatements": propWriteStatements, + } + + serviceDir := filepath.Join("internal", "db", "service") + if err := ensureDir(serviceDir); err != nil { + return err + } + + outputPath := filepath.Join(serviceDir, fmt.Sprintf("%s.go", lowerName)) + fmt.Printf(" Generated: internal/db/service/%s.go\n", lowerName) + return executeTemplate(TmplServiceRepository, outputPath, data) +} + +// generateArtifactRepository generates an artifact repository file. 
+func generateArtifactRepository(config CatalogConfig, artifact ArtifactConfig) error { + entityName := config.Spec.Entity.Name + artifactName := artifact.Name + lowerEntityName := strings.ToLower(entityName) + lowerArtifactName := strings.ToLower(artifactName) + + // Build property mapping code for artifact properties + propVarDecls := buildPropertyVarDeclarations(artifact.Properties) + propReadCases := buildPropertyReadCases(artifact.Properties) + propAttrAssignments := buildPropertyAttrAssignments(artifact.Properties) + propWriteStatements := buildArtifactPropertyWriteStatements(artifact.Properties) + + data := map[string]any{ + "Package": config.Spec.Package, + "EntityName": entityName, + "ArtifactName": artifactName, + "LowerEntityName": lowerEntityName, + "LowerArtifactName": lowerArtifactName, + "PropVarDecls": propVarDecls, + "PropReadCases": propReadCases, + "PropAttrAssignments": propAttrAssignments, + "PropWriteStatements": propWriteStatements, + } + + serviceDir := filepath.Join("internal", "db", "service") + if err := ensureDir(serviceDir); err != nil { + return err + } + + filename := fmt.Sprintf("%s_%s_artifact.go", lowerEntityName, lowerArtifactName) + fmt.Printf(" Generated: internal/db/service/%s\n", filename) + return executeTemplate(TmplServiceArtifactRepository, filepath.Join(serviceDir, filename), data) +} + +// generateDatastoreSpec generates the datastore spec file. 
// generateDatastoreSpec generates internal/db/service/spec.go from the
// TmplServiceSpec template: property definitions for the entity, artifact
// type-name constants, artifact spec registrations, and the Services struct
// fields/params/assignments for each artifact repository.
func generateDatastoreSpec(config CatalogConfig) error {
	entityName := config.Spec.Entity.Name
	lowerEntityName := strings.ToLower(entityName)

	// Build property definitions from config
	var propDefs []string
	for _, prop := range config.Spec.Entity.Properties {
		propMethod := datastorePropertyMethod(prop.Type)
		propDefs = append(propDefs, fmt.Sprintf("\t\t\t%s(\"%s\")", propMethod, prop.Name))
	}

	// Join properties with ".\n" but don't add trailing dot
	propDefsStr := ""
	if len(propDefs) > 0 {
		propDefsStr = strings.Join(propDefs, ".\n") + ","
	}

	// Build artifact type constants
	var artifactConstants strings.Builder
	for _, artifact := range config.Spec.Artifacts {
		artifactConstants.WriteString(fmt.Sprintf("\t%s%sArtifactTypeName = \"kf.%s%sArtifact\"\n",
			entityName, artifact.Name, entityName, artifact.Name))
	}

	// Build artifact spec registrations
	var artifactSpecs strings.Builder
	for i, artifact := range config.Spec.Artifacts {
		// Build artifact property definitions
		var artifactPropDefs []string
		for _, prop := range artifact.Properties {
			propMethod := datastorePropertyMethod(prop.Type)
			artifactPropDefs = append(artifactPropDefs, fmt.Sprintf("\t\t\t%s(\"%s\")", propMethod, prop.Name))
		}
		artifactPropDefsStr := ""
		if len(artifactPropDefs) > 0 {
			artifactPropDefsStr = strings.Join(artifactPropDefs, ".\n") + ","
		}

		// Add trailing dot only if not the last artifact
		trailingDot := ""
		if i < len(config.Spec.Artifacts)-1 {
			trailingDot = "."
		}

		artifactSpecs.WriteString(fmt.Sprintf(` AddArtifact(%s%sArtifactTypeName, datastore.NewSpecType(New%s%sArtifactRepository).
%s
 )%s
`, entityName, artifact.Name, entityName, artifact.Name, artifactPropDefsStr, trailingDot))
	}

	// Build Services struct fields for artifacts
	var artifactServiceFields strings.Builder
	var artifactServiceParams strings.Builder
	var artifactServiceAssignments strings.Builder
	for _, artifact := range config.Spec.Artifacts {
		artifactServiceFields.WriteString(fmt.Sprintf("\t%s%sArtifactRepository models.%s%sArtifactRepository\n",
			entityName, artifact.Name, entityName, artifact.Name))
		artifactServiceParams.WriteString(fmt.Sprintf("\t%s%sArtifactRepository models.%s%sArtifactRepository,\n",
			lowerEntityName, artifact.Name, entityName, artifact.Name))
		artifactServiceAssignments.WriteString(fmt.Sprintf("\t\t%s%sArtifactRepository: %s%sArtifactRepository,\n",
			entityName, artifact.Name, lowerEntityName, artifact.Name))
	}

	// If there are artifacts, the context registration needs to chain to them
	hasArtifacts := len(config.Spec.Artifacts) > 0
	contextTrailingDot := ""
	if hasArtifacts {
		contextTrailingDot = "."
	}

	data := map[string]any{
		"EntityName":                 entityName,
		"EntityNameLower":            lowerEntityName,
		"Package":                    config.Spec.Package,
		"PropertyDefs":               propDefsStr,
		"ArtifactConstants":          artifactConstants.String(),
		"ArtifactSpecs":              artifactSpecs.String(),
		"ArtifactServiceFields":      artifactServiceFields.String(),
		"ArtifactServiceParams":      artifactServiceParams.String(),
		"ArtifactServiceAssignments": artifactServiceAssignments.String(),
		"ContextTrailingDot":         contextTrailingDot,
	}

	serviceDir := filepath.Join("internal", "db", "service")
	if err := ensureDir(serviceDir); err != nil {
		return err
	}

	fmt.Printf(" Generated: internal/db/service/spec.go\n")
	return executeTemplate(TmplServiceSpec, filepath.Join(serviceDir, "spec.go"), data)
}
diff --git a/cmd/catalog-gen/gen_testdata.go b/cmd/catalog-gen/gen_testdata.go
new file mode 100644
index 0000000000..4bc0c9372d
--- /dev/null
+++ b/cmd/catalog-gen/gen_testdata.go
@@ -0,0 +1,160 @@
package main

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"github.com/spf13/cobra"
)

// newGenTestdataCmd returns the cobra command that generates local testdata
// files (server config, loader config, and sample entity data).
func newGenTestdataCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "gen-testdata",
		Short: "Generate testdata files for testing the catalog plugin",
		Long: `Generate testdata files for testing the catalog plugin locally.

This creates:
  - testdata/test-sources.yaml (catalog server config)
  - testdata/-sources.yaml (loader config)
  - testdata/s.yaml (sample entity data)

Example:
  catalog-gen gen-testdata`,
		// NOTE(review): the file names in the Long text above look like they lost
		// an "<entity>" placeholder (e.g. "testdata/-sources.yaml") — confirm the
		// intended help text against the generated file names below.
		RunE: func(cmd *cobra.Command, args []string) error {
			return generateTestdata()
		},
	}

	return cmd
}

// generateTestdata loads catalog.yaml and writes the three testdata files:
// sample entity data, loader sources config, and catalog server config.
// It prints run instructions for a local catalog-server afterwards.
func generateTestdata() error {
	config, err := loadConfig()
	if err != nil {
		return err
	}

	entityName := config.Spec.Entity.Name
	entityNameLower := strings.ToLower(entityName)
	catalogName := filepath.Base(config.Metadata.Name)

	// Create testdata directory
	testdataDir := "testdata"
	if err := os.MkdirAll(testdataDir, 0755); err != nil {
		return fmt.Errorf("failed to create testdata directory: %w", err)
	}

	// Generate sample entity data file
	entityDataPath := filepath.Join(testdataDir, fmt.Sprintf("%ss.yaml", entityNameLower))
	if err := generateSampleEntityData(config, entityDataPath); err != nil {
		return fmt.Errorf("failed to generate sample entity data: %w", err)
	}
	fmt.Printf(" Generated: %s\n", entityDataPath)

	// Generate loader sources config
	loaderSourcesPath := filepath.Join(testdataDir, fmt.Sprintf("%s-sources.yaml", entityNameLower))
	if err := generateLoaderSourcesConfig(config, loaderSourcesPath, entityDataPath); err != nil {
		return fmt.Errorf("failed to generate loader sources config: %w", err)
	}
	fmt.Printf(" Generated: %s\n", loaderSourcesPath)

	// Generate catalog server config
	serverConfigPath := filepath.Join(testdataDir, "test-sources.yaml")
	if err := generateServerSourcesConfig(config, serverConfigPath, loaderSourcesPath, entityDataPath); err != nil {
		return fmt.Errorf("failed to generate server sources config: %w", err)
	}
	fmt.Printf(" Generated: %s\n", serverConfigPath)

	fmt.Println("\nTestdata generated successfully!")
	fmt.Println("\nTo test the plugin, run:")
	fmt.Printf(" ./catalog-server --sources=%s --listen=:8080 --db-dsn=\"host=localhost port=5432 user=postgres password=postgres dbname=model_registry sslmode=disable\"\n", serverConfigPath)
	fmt.Println("\nThen test with:")
	fmt.Printf(" curl -s http://localhost:8080/api/%s/v1alpha1/%ss | jq\n", catalogName, entityNameLower)

	return nil
}

// generateSampleEntityData writes a YAML file with three sample entities,
// emitting the BaseResource shared fields plus a type-appropriate sample
// value for each custom property declared in catalog.yaml.
func generateSampleEntityData(config CatalogConfig, outputPath string) error {
	entityName := config.Spec.Entity.Name
	entityNameLower := strings.ToLower(entityName)

	var content strings.Builder
	content.WriteString(fmt.Sprintf("%ss:\n", entityNameLower))

	// Track which shared fields the user already declared as custom properties
	// NOTE(review): declaredProps is built but never read below — dead code, or
	// a leftover from an intended de-duplication step; confirm.
	declaredProps := make(map[string]bool)
	for _, prop := range config.Spec.Entity.Properties {
		declaredProps[strings.ToLower(prop.Name)] = true
	}

	// Generate 3 sample entities
	for i := 1; i <= 3; i++ {
		// BaseResource shared fields
		content.WriteString(fmt.Sprintf(" - name: \"sample-%s-%d\"\n", entityNameLower, i))
		content.WriteString(fmt.Sprintf(" externalId: \"ext-%s-%d\"\n", entityNameLower, i))
		content.WriteString(fmt.Sprintf(" description: \"Sample %s %d description\"\n", entityName, i))

		// Add custom properties from catalog.yaml
		for _, prop := range config.Spec.Entity.Properties {
			lower := strings.ToLower(prop.Name)
			// Skip fields already emitted as BaseResource shared fields
			if lower == "name" || lower == "externalid" || lower == "description" {
				continue
			}

			switch prop.Type {
			case "string":
				content.WriteString(fmt.Sprintf(" %s: \"sample-%s-value-%d\"\n", prop.Name, prop.Name, i))
			case "integer", "int":
				content.WriteString(fmt.Sprintf(" %s: %d\n", prop.Name, i*10))
			case "int64":
				content.WriteString(fmt.Sprintf(" %s: %d\n", prop.Name, i*1000))
			case "boolean", "bool":
				content.WriteString(fmt.Sprintf(" %s: %t\n", prop.Name, i%2 == 0))
			case "number", "float", "double":
				content.WriteString(fmt.Sprintf(" %s: %d.%d\n", prop.Name, i, i))
			case "array":
				content.WriteString(fmt.Sprintf(" %s:\n", prop.Name))
				content.WriteString(fmt.Sprintf(" - \"item-%d-a\"\n", i))
				content.WriteString(fmt.Sprintf(" - \"item-%d-b\"\n", i))
			}
		}
	}

	return os.WriteFile(outputPath, []byte(content.String()), 0644)
}

// generateLoaderSourcesConfig writes the loader sources YAML pointing at the
// sample entity data file (referenced by base name, relative to testdata/).
func generateLoaderSourcesConfig(_ CatalogConfig, outputPath, entityDataPath string) error {
	entityDataFile := filepath.Base(entityDataPath)

	content := fmt.Sprintf(`catalogs:
  - id: "test-source"
    name: "Test Data Source"
    type: "yaml"
    properties:
      yamlCatalogPath: "./%s"
`, entityDataFile)

	return os.WriteFile(outputPath, []byte(content), 0644)
}

// generateServerSourcesConfig writes the catalog-server sources YAML, keyed by
// the catalog name, referencing the loader config and entity data by base name.
func generateServerSourcesConfig(config CatalogConfig, outputPath, loaderSourcesPath, entityDataPath string) error {
	catalogName := filepath.Base(config.Metadata.Name)
	loaderSourcesFile := filepath.Base(loaderSourcesPath)
	entityDataFile := filepath.Base(entityDataPath)

	content := fmt.Sprintf(`catalogs:
  %s:
    sources:
      - id: "test-source"
        name: "Test Data Source"
        type: "yaml"
        properties:
          loaderConfigPath: "./%s"
          yamlCatalogPath: "./%s"
`, catalogName, loaderSourcesFile, entityDataFile)

	return os.WriteFile(outputPath, []byte(content), 0644)
}
diff --git a/cmd/catalog-gen/generate.go b/cmd/catalog-gen/generate.go
new file mode 100644
index 0000000000..0c7cf17be0
--- /dev/null
+++ b/cmd/catalog-gen/generate.go
@@ -0,0 +1,243 @@
package main

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"github.com/spf13/cobra"
)

// newGenerateCmd returns the cobra command that regenerates plugin code
// from catalog.yaml.
func newGenerateCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "generate",
		Short: "Generate or regenerate code from catalog.yaml",
		Long: `Generate or regenerate the catalog plugin code based on the catalog.yaml configuration.

This command reads the catalog.yaml file and generates:
  - Entity models and repositories
  - Artifact models and repositories
  - Provider implementations
  - OpenAPI specification
  - Plugin registration files

Example:
  catalog-gen generate`,
		RunE: func(cmd *cobra.Command, args []string) error {
			return generate()
		},
	}

	return cmd
}

// generate is the top-level regeneration pipeline: it loads catalog.yaml and
// regenerates all non-editable files (models, repositories, datastore spec,
// filter mappings, OpenAPI files, loader). The YAML provider is only created
// when missing, and post-generation docs only when artifacts are configured.
func generate() error {
	config, err := loadConfig()
	if err != nil {
		return err
	}

	fmt.Printf("Regenerating code for catalog plugin: %s\n", config.Metadata.Name)
	fmt.Println("Note: Only non-editable files are regenerated. Editable files are created by 'catalog-gen init'.")
	fmt.Println()

	// === PLUGIN FILES ===

	if err := generatePluginFiles(config); err != nil {
		return fmt.Errorf("failed to generate plugin files: %w", err)
	}

	// === ALWAYS REGENERATED (not meant to be edited) ===

	// Generate entity model
	if err := generateEntityModel(config); err != nil {
		return fmt.Errorf("failed to generate entity model: %w", err)
	}

	// Generate artifact models (if artifacts are configured)
	for _, artifact := range config.Spec.Artifacts {
		if err := generateArtifactModel(config, artifact); err != nil {
			return fmt.Errorf("failed to generate artifact model for %s: %w", artifact.Name, err)
		}
		if err := generateArtifactRepository(config, artifact); err != nil {
			return fmt.Errorf("failed to generate artifact repository for %s: %w", artifact.Name, err)
		}
	}

	// Generate datastore spec
	if err := generateDatastoreSpec(config); err != nil {
		return fmt.Errorf("failed to generate datastore spec: %w", err)
	}

	// Generate filter mappings for filterQuery support
	if err := generateFilterMappings(config); err != nil {
		return fmt.Errorf("failed to generate filter mappings: %w", err)
	}

	// Ensure symlink to shared common schemas exists (for plugin contexts)
	if err := ensureCommonLibSymlink(); err != nil {
		return fmt.Errorf("failed to create common lib symlink: %w", err)
	}

	// Generate OpenAPI components
	if err := generateOpenAPIComponents(config); err != nil {
		return fmt.Errorf("failed to generate OpenAPI components: %w", err)
	}

	// Generate OpenAPI main file
	if err := generateOpenAPIMain(config); err != nil {
		return fmt.Errorf("failed to generate OpenAPI main: %w", err)
	}
	fmt.Printf(" Generated: api/openapi/src/openapi.yaml\n")

	// Generate loader
	if err := generateLoader(config); err != nil {
		return fmt.Errorf("failed to generate loader: %w", err)
	}

	// Regenerate YAML provider if it doesn't exist
	yamlProviderPath := filepath.Join("internal", "catalog", "providers", "yaml_provider.go")
	if _, err := os.Stat(yamlProviderPath); os.IsNotExist(err) {
		for _, provider := range config.Spec.Providers {
			if provider.Type == "yaml" {
				if err := generateYAMLProvider(config); err != nil {
					return fmt.Errorf("failed to generate YAML provider: %w", err)
				}
				break
			}
		}
	}

	// Generate post-generation documentation (if artifacts are configured)
	if len(config.Spec.Artifacts) > 0 {
		if err := generatePostGenerationDocs(config); err != nil {
			return fmt.Errorf("failed to generate post-generation docs: %w", err)
		}
	}

	fmt.Println("\nGeneration complete!")
	fmt.Println("\nIf you added new properties or artifacts to catalog.yaml, you may need to manually update:")
	fmt.Println(" - internal/db/service/.go (property mapping in converters)")
	fmt.Println(" - internal/server/openapi/api_*_service_impl.go (OpenAPI conversion)")
	fmt.Println(" - internal/catalog/providers/*.go (provider parsing)")
	fmt.Println("\nThen run 'make gen/openapi-server' to regenerate OpenAPI handlers.")

	return nil
}

// =============================================================================
// Post-Generation Documentation
// =============================================================================

// generatePostGenerationDocs writes docs/post-artifact-generation.md, a
// checklist of the manual steps needed after adding an artifact type.
func generatePostGenerationDocs(config CatalogConfig) error {
	docsDir := "docs"
	if err := ensureDir(docsDir); err != nil {
		return fmt.Errorf("failed to create docs directory: %w", err)
	}

	entityName := config.Spec.Entity.Name
	lowerEntity := strings.ToLower(entityName)
	catalogName := config.Metadata.Name

	// Build artifact-specific sections
	var artifactRepoLines strings.Builder
	var artifactMethodsDoc strings.Builder
	var artifactConversionsDoc strings.Builder
	var artifactChecklist strings.Builder
	var artifactYAMLFields strings.Builder

	for _, artifact := range config.Spec.Artifacts {
		artifactName := artifact.Name
		lowerArtifact := strings.ToLower(artifactName)
		fullArtifactName := entityName + artifactName + "Artifact"
		repoName := fullArtifactName + "Repository"

		artifactRepoLines.WriteString(fmt.Sprintf(" getRepo[models.%s](repoSet),\n", repoName))

		artifactMethodsDoc.WriteString(fmt.Sprintf(`
### Get%s%ss Endpoint

Add this method to `+"`internal/server/openapi/api_%s_service_impl.go`:\n\n", entityName, artifactName, lowerEntity))
		artifactMethodsDoc.WriteString("```go\n")
		artifactMethodsDoc.WriteString(fmt.Sprintf(`// Get%s%ss implements DefaultAPIServicer.Get%s%ss
func (s *%sCatalogServiceAPIService) Get%s%ss(
	ctx context.Context,
	name string,
	pageSize int32,
	pageToken string,
) (ImplResponse, error) {
	// Implementation here
}
`, entityName, artifactName, entityName, artifactName, entityName, entityName, artifactName))
		artifactMethodsDoc.WriteString("```\n")

		artifactConversionsDoc.WriteString(fmt.Sprintf("\n### convert%sToOpenAPIModel\n\n", fullArtifactName))
		artifactConversionsDoc.WriteString("```go\n")
		artifactConversionsDoc.WriteString(fmt.Sprintf(`func convert%sToOpenAPIModel(artifact models.%s) %s {
	// Implementation here
}
`, fullArtifactName, fullArtifactName, fullArtifactName))
		artifactConversionsDoc.WriteString("```\n")

		artifactChecklist.WriteString(fmt.Sprintf("| `internal/server/openapi/api_%s_service_impl.go` | Add Get%s%ss method | ☐ Manual |\n", lowerEntity, entityName, artifactName))
		artifactChecklist.WriteString(fmt.Sprintf("| `internal/server/openapi/api_%s_service_impl.go` | Add convert%sToOpenAPIModel | ☐ Manual |\n", lowerEntity, fullArtifactName))

		artifactYAMLFields.WriteString(fmt.Sprintf(" - %sName: \"example-%s\"\n", lowerEntity, lowerEntity))
		artifactYAMLFields.WriteString(fmt.Sprintf(" name: \"example-%s\"\n", lowerArtifact))
		for _, prop := range artifact.Properties {
			artifactYAMLFields.WriteString(fmt.Sprintf(" %s: \"example-value\"\n", prop.Name))
		}
	}

	// NOTE(review): the indexed template below only references args %[3]s, %[4]s,
	// %[5]s, %[6]s and %[9]s — lowerEntity, catalogName, artifactYAMLFields and
	// config.Spec.API.Port are passed but never used (Go's fmt does not append
	// EXTRA when explicit indexes are present). Confirm whether sections that
	// used them were removed intentionally.
	content := fmt.Sprintf(`# Post-Artifact Generation Manual Steps

> **Auto-generated by catalog-gen** - Regenerate with `+"`catalog-gen generate`"+`

After adding an artifact to `+"`catalog.yaml`"+` and running `+"`catalog-gen generate`"+`, complete these manual steps.

## Manual Steps Required

### 1. Update plugin.go

Add the artifact repository to initServices:

`+"```go"+`
services := service.NewServices(
	getRepo[models.%[3]sRepository](repoSet),
%[4]s)
`+"```"+`

### 2. Regenerate OpenAPI Server Code

`+"```bash"+`
make gen/openapi-server
`+"```"+`

### 3. Implement the Artifact List Endpoint(s)

%[5]s

### 4. Add the Artifact Conversion Function(s)

%[6]s

## File Checklist

| File | Action | Status |
|------|--------|--------|
| `+"`plugin.go`"+` | Add artifact repository to initServices | ☐ Manual |
%[9]s
`, lowerEntity, catalogName, entityName, artifactRepoLines.String(),
		artifactMethodsDoc.String(), artifactConversionsDoc.String(),
		artifactYAMLFields.String(), config.Spec.API.Port,
		artifactChecklist.String())

	docsPath := filepath.Join(docsDir, "post-artifact-generation.md")
	if err := os.WriteFile(docsPath, []byte(content), 0644); err != nil {
		return fmt.Errorf("failed to write docs file: %w", err)
	}

	fmt.Printf(" Generated: %s\n", docsPath)
	return nil
}
diff --git a/cmd/catalog-gen/helpers.go b/cmd/catalog-gen/helpers.go
new file mode 100644
index 0000000000..5d5c85cff1
--- /dev/null
+++ b/cmd/catalog-gen/helpers.go
@@ -0,0 +1,361 @@
package main

import (
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"strings"

	"gopkg.in/yaml.v3"
)

// loadConfig reads and parses catalog.yaml from the current directory.
// It returns a dedicated error if the file is missing, pointing the user
// at 'catalog-gen init'.
func loadConfig() (CatalogConfig, error) {
	configPath := "catalog.yaml"
	if _, err := os.Stat(configPath); os.IsNotExist(err) {
		return CatalogConfig{}, fmt.Errorf("catalog.yaml not found in current directory. Run 'catalog-gen init' first")
	}

	configData, err := os.ReadFile(configPath)
	if err != nil {
		return CatalogConfig{}, fmt.Errorf("failed to read catalog.yaml: %w", err)
	}

	var config CatalogConfig
	if err := yaml.Unmarshal(configData, &config); err != nil {
		return CatalogConfig{}, fmt.Errorf("failed to parse catalog.yaml: %w", err)
	}

	return config, nil
}

// saveConfig writes the config back to catalog.yaml.
+func saveConfig(config CatalogConfig) error { + configFile, err := os.Create("catalog.yaml") + if err != nil { + return fmt.Errorf("failed to open catalog.yaml: %w", err) + } + defer func() { _ = configFile.Close() }() + + encoder := yaml.NewEncoder(configFile) + encoder.SetIndent(2) + if err := encoder.Encode(config); err != nil { + return fmt.Errorf("failed to write catalog.yaml: %w", err) + } + + return nil +} + +// ensureDir creates a directory if it doesn't exist. +func ensureDir(path string) error { + return os.MkdirAll(path, 0755) +} + +// isPluginContext detects if we're running in a plugin context by checking the current working directory. +func isPluginContext() bool { + wd, err := os.Getwd() + if err != nil { + return false + } + // We're in a plugin if the working directory contains "/catalog/plugins/" + return strings.Contains(wd, "/catalog/plugins/") +} + +// ensureCommonLibSymlink creates a symlink at api/openapi/src/lib pointing to the +// shared common schemas in the repo root. This allows plugin OpenAPI specs to reference +// shared schemas (BaseResource, etc.) via relative paths like 'lib/common.yaml'. +// The symlink is only created in plugin contexts and is a no-op otherwise. 
+func ensureCommonLibSymlink() error { + if !isPluginContext() { + return nil + } + + symlinkPath := filepath.Join("api", "openapi", "src", "lib") + + // If the symlink (or directory) already exists, nothing to do + if _, err := os.Lstat(symlinkPath); err == nil { + return nil + } + + // Find the repo root via git + out, err := exec.Command("git", "rev-parse", "--show-toplevel").Output() + if err != nil { + return fmt.Errorf("failed to find git repo root: %w", err) + } + repoRoot := strings.TrimSpace(string(out)) + + // Target is the shared lib directory at the repo root + target := filepath.Join(repoRoot, "api", "openapi", "src", "lib") + if _, err := os.Stat(target); os.IsNotExist(err) { + return fmt.Errorf("common lib directory not found at %s", target) + } + + // Ensure parent directory exists + if err := ensureDir(filepath.Dir(symlinkPath)); err != nil { + return err + } + + // Compute relative path from the symlink's parent to the target + symlinkParent, err := filepath.Abs(filepath.Dir(symlinkPath)) + if err != nil { + return fmt.Errorf("failed to resolve symlink parent: %w", err) + } + relTarget, err := filepath.Rel(symlinkParent, target) + if err != nil { + return fmt.Errorf("failed to compute relative path: %w", err) + } + + if err := os.Symlink(relTarget, symlinkPath); err != nil { + return fmt.Errorf("failed to create lib symlink: %w", err) + } + + fmt.Printf(" Created symlink: %s -> %s\n", symlinkPath, relTarget) + return nil +} + +// capitalize returns the string with the first letter uppercased. +func capitalize(s string) string { + if s == "" { + return s + } + return strings.ToUpper(s[:1]) + s[1:] +} + +// goTypeFromSpec converts a spec type to a Go type. 
// goTypeFromSpec converts a spec type to a Go type.
// Unknown types fall back to *string.
func goTypeFromSpec(specType string) string {
	mapping := map[string]string{
		"string":  "*string",
		"integer": "*int32",
		"int":     "*int32",
		"int64":   "*int64",
		"boolean": "*bool",
		"bool":    "*bool",
		"number":  "*float64",
		"float":   "*float64",
		"double":  "*float64",
		"array":   "[]string",
	}
	if goType, ok := mapping[specType]; ok {
		return goType
	}
	return "*string"
}

// openAPIType converts a spec type to an OpenAPI type.
// Unknown types fall back to "string".
func openAPIType(specType string) string {
	mapping := map[string]string{
		"string":  "string",
		"integer": "integer",
		"int":     "integer",
		"int64":   "integer",
		"boolean": "boolean",
		"bool":    "boolean",
		"number":  "number",
		"float":   "number",
		"double":  "number",
		"array":   "array",
	}
	if oaType, ok := mapping[specType]; ok {
		return oaType
	}
	return "string"
}

// datastorePropertyMethod returns the datastore method for a property type.
// Unknown types fall back to AddString.
func datastorePropertyMethod(specType string) string {
	mapping := map[string]string{
		"string":  "AddString",
		"integer": "AddInt",
		"int":     "AddInt",
		"int64":   "AddInt",
		"boolean": "AddBool",
		"bool":    "AddBool",
		"number":  "AddDouble",
		"float":   "AddDouble",
		"double":  "AddDouble",
		"struct":  "AddStruct",
		"object":  "AddStruct",
	}
	if method, ok := mapping[specType]; ok {
		return method
	}
	return "AddString"
}

// mlmdValueField returns the MLMD value field name for a property type.
// Bools are stored as int (0/1); unknown types fall back to StringValue.
func mlmdValueField(specType string) string {
	mapping := map[string]string{
		"string":  "StringValue",
		"integer": "IntValue",
		"int":     "IntValue",
		"int64":   "IntValue",
		"boolean": "IntValue",
		"bool":    "IntValue",
		"number":  "DoubleValue",
		"float":   "DoubleValue",
		"double":  "DoubleValue",
	}
	if field, ok := mapping[specType]; ok {
		return field
	}
	return "StringValue"
}

// isStructType returns true if the type is a struct/object type.
func isStructType(specType string) bool {
	return specType == "struct" || specType == "object"
}

// buildPropertyVarDeclarations generates variable declarations for reading properties.
+func buildPropertyVarDeclarations(props []PropertyConfig) string { + var sb strings.Builder + for _, prop := range props { + goType := goTypeFromSpec(prop.Type) + varName := strings.ToLower(prop.Name[:1]) + prop.Name[1:] // camelCase + if isStructType(prop.Type) { + sb.WriteString(fmt.Sprintf("\tvar %s %s // TODO: struct type requires manual handling\n", varName, goType)) + } else { + sb.WriteString(fmt.Sprintf("\tvar %s %s\n", varName, goType)) + } + } + return sb.String() +} + +// buildPropertyReadCases generates switch cases for reading properties from ContextProperty. +func buildPropertyReadCases(props []PropertyConfig) string { + var sb strings.Builder + for _, prop := range props { + varName := strings.ToLower(prop.Name[:1]) + prop.Name[1:] + if isStructType(prop.Type) { + sb.WriteString(fmt.Sprintf("\t\tcase \"%s\":\n", prop.Name)) + sb.WriteString(fmt.Sprintf("\t\t\t// TODO: Struct property '%s' requires manual conversion.\n", prop.Name)) + sb.WriteString("\t\t\t// Example: json.Unmarshal([]byte(*p.StringValue), &" + varName + ")\n") + } else { + valueField := mlmdValueField(prop.Type) + sb.WriteString(fmt.Sprintf("\t\tcase \"%s\":\n", prop.Name)) + sb.WriteString(fmt.Sprintf("\t\t\tif p.%s != nil {\n", valueField)) + sb.WriteString(fmt.Sprintf("\t\t\t\t%s = p.%s\n", varName, valueField)) + sb.WriteString("\t\t\t}\n") + } + } + return sb.String() +} + +// buildPropertyAttrAssignments generates attribute assignments for properties. +func buildPropertyAttrAssignments(props []PropertyConfig) string { + var sb strings.Builder + for _, prop := range props { + varName := strings.ToLower(prop.Name[:1]) + prop.Name[1:] + fieldName := capitalize(prop.Name) + sb.WriteString(fmt.Sprintf("\t\t\t%s: %s,\n", fieldName, varName)) + } + return sb.String() +} + +// buildPropertyWriteStatements generates property write statements for entity properties. 
+func buildPropertyWriteStatements(props []PropertyConfig) string { + var sb strings.Builder + for _, prop := range props { + propName := capitalize(prop.Name) + switch prop.Type { + case "string": + sb.WriteString(fmt.Sprintf(` if attrs.%s != nil { + props = append(props, schema.ContextProperty{ + ContextID: entityID, + Name: "%s", + StringValue: attrs.%s, + }) + } +`, propName, prop.Name, propName)) + case "integer", "int": + sb.WriteString(fmt.Sprintf(` if attrs.%s != nil { + props = append(props, schema.ContextProperty{ + ContextID: entityID, + Name: "%s", + IntValue: attrs.%s, + }) + } +`, propName, prop.Name, propName)) + case "boolean", "bool": + sb.WriteString(fmt.Sprintf(` if attrs.%s != nil { + boolVal := int64(0) + if *attrs.%s { + boolVal = 1 + } + props = append(props, schema.ContextProperty{ + ContextID: entityID, + Name: "%s", + IntValue: &boolVal, + }) + } +`, propName, propName, prop.Name)) + } + } + return sb.String() +} + +// buildArtifactPropertyWriteStatements generates property write statements for artifact properties. 
+func buildArtifactPropertyWriteStatements(properties []PropertyConfig) string { + var sb strings.Builder + for _, prop := range properties { + propName := capitalize(prop.Name) + switch prop.Type { + case "string": + sb.WriteString(fmt.Sprintf(` if attr.%s != nil { + properties = append(properties, schema.ArtifactProperty{ + ArtifactID: artifactID, + Name: "%s", + StringValue: attr.%s, + }) + } +`, propName, prop.Name, propName)) + case "integer", "int": + sb.WriteString(fmt.Sprintf(` if attr.%s != nil { + properties = append(properties, schema.ArtifactProperty{ + ArtifactID: artifactID, + Name: "%s", + IntValue: attr.%s, + }) + } +`, propName, prop.Name, propName)) + case "boolean", "bool": + sb.WriteString(fmt.Sprintf(` if attr.%s != nil { + boolVal := int64(0) + if *attr.%s { + boolVal = 1 + } + properties = append(properties, schema.ArtifactProperty{ + ArtifactID: artifactID, + Name: "%s", + IntValue: &boolVal, + }) + } +`, propName, propName, prop.Name)) + } + } + return sb.String() +} + +// buildOpenAPIPropertyConversions generates OpenAPI property conversion code. +func buildOpenAPIPropertyConversions(props []PropertyConfig) string { + var sb strings.Builder + for _, prop := range props { + fieldName := capitalize(prop.Name) + sb.WriteString(fmt.Sprintf(` if attrs.%s != nil { + result.%s = *attrs.%s + } +`, fieldName, fieldName, fieldName)) + } + return sb.String() +} + +// generateOpenAPIPropertyDef generates an OpenAPI property definition. 
+func generateOpenAPIPropertyDef(prop PropertyConfig, indent int) string { + spaces := strings.Repeat(" ", indent) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("%s%s:\n%s type: %s\n", spaces, prop.Name, spaces, openAPIType(prop.Type))) + if prop.Type == "array" && prop.Items != nil { + sb.WriteString(fmt.Sprintf("%s items:\n%s type: %s\n", spaces, spaces, openAPIType(prop.Items.Type))) + } + return sb.String() +} diff --git a/cmd/catalog-gen/init.go b/cmd/catalog-gen/init.go new file mode 100644 index 0000000000..50c0589c23 --- /dev/null +++ b/cmd/catalog-gen/init.go @@ -0,0 +1,50 @@ +package main + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +func newInitCmd() *cobra.Command { + var ( + entityName string + packageName string + outputDir string + ) + + cmd := &cobra.Command{ + Use: "init ", + Short: "Initialize a new catalog plugin", + Long: `Initialize a new catalog plugin for the unified catalog server. + +This creates the basic directory structure and configuration file +for a new catalog plugin. 
+ +Example: + catalog-gen init mcp-catalog --entity=MCPServer --package=github.com/kubeflow/model-registry/catalog/plugins/mcp`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + name := args[0] + + if entityName == "" { + return fmt.Errorf("--entity is required") + } + if packageName == "" { + return fmt.Errorf("--package is required") + } + + if outputDir == "" { + outputDir = name + } + + return initCatalogPlugin(name, entityName, packageName, outputDir) + }, + } + + cmd.Flags().StringVar(&entityName, "entity", "", "Name of the main entity (e.g., MCPServer)") + cmd.Flags().StringVar(&packageName, "package", "", "Go package path (e.g., github.com/kubeflow/model-registry/catalog/plugins/mcp)") + cmd.Flags().StringVar(&outputDir, "output", "", "Output directory (defaults to catalog name)") + + return cmd +} diff --git a/cmd/catalog-gen/main.go b/cmd/catalog-gen/main.go new file mode 100644 index 0000000000..d5ac2598a7 --- /dev/null +++ b/cmd/catalog-gen/main.go @@ -0,0 +1,46 @@ +// catalog-gen is a scaffolding tool for creating new catalog components. +// It generates the boilerplate code needed for a new catalog service, +// similar to how kubebuilder scaffolds Kubernetes controllers. +package main + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" +) + +var version = "dev" + +func main() { + rootCmd := &cobra.Command{ + Use: "catalog-gen", + Short: "Scaffolding tool for creating catalog components", + Long: `catalog-gen is a scaffolding tool for creating new catalog components +in the Model Registry project. 
It generates the boilerplate code needed +for a new catalog service, including: + +- Entity models and repositories +- REST API handlers (OpenAPI-based) +- Data providers (YAML, HTTP) +- Kustomize manifests for deployment + +Usage: + catalog-gen init --entity= --package= + catalog-gen add-provider + catalog-gen add-artifact + catalog-gen generate`, + Version: version, + } + + rootCmd.AddCommand(newInitCmd()) + rootCmd.AddCommand(newAddProviderCmd()) + rootCmd.AddCommand(newAddArtifactCmd()) + rootCmd.AddCommand(newGenerateCmd()) + rootCmd.AddCommand(newGenTestdataCmd()) + + if err := rootCmd.Execute(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} diff --git a/cmd/catalog-gen/templates.go b/cmd/catalog-gen/templates.go new file mode 100644 index 0000000000..257e182a13 --- /dev/null +++ b/cmd/catalog-gen/templates.go @@ -0,0 +1,86 @@ +package main + +import ( + "embed" + "fmt" + "os" + "text/template" +) + +//go:embed templates/* +var templateFS embed.FS + +// Template path constants +const ( + // models templates + TmplModelsEntity = "templates/models/entity.gotmpl" + TmplModelsArtifact = "templates/models/artifact.gotmpl" + TmplModelsBase = "templates/models/base.gotmpl" + + // service templates + TmplServiceRepository = "templates/service/repository.gotmpl" + TmplServiceArtifactRepository = "templates/service/artifact_repository.gotmpl" + TmplServiceSpec = "templates/service/spec.gotmpl" + TmplServiceFilterMappings = "templates/service/filter_mappings.gotmpl" + + // server templates + TmplServerOpenAPIServiceImpl = "templates/server/openapi_service_impl.gotmpl" + + // catalog templates + TmplCatalogLoader = "templates/catalog/loader.gotmpl" + + // providers templates + TmplProvidersYAML = "templates/providers/yaml.gotmpl" + TmplProvidersHTTP = "templates/providers/http.gotmpl" + + // api templates + TmplAPIOpenAPIMain = "templates/api/openapi_main.gotmpl" + TmplAPIOpenAPIComponents = "templates/api/openapi_components.gotmpl" + + // 
misc templates + TmplMiscGitignore = "templates/misc/gitignore.gotmpl" + TmplMiscOpenAPIGeneratorIgnore = "templates/misc/openapi_generator_ignore.gotmpl" + + // plugin templates + TmplPluginPlugin = "templates/plugin/plugin.gotmpl" + TmplPluginRegister = "templates/plugin/register.gotmpl" + + // agent templates + TmplAgentClaudeMD = "templates/agent/claude_md.gotmpl" + TmplAgentCmdAddProperty = "templates/agent/commands/add_property.gotmpl" + TmplAgentCmdAddArtifact = "templates/agent/commands/add_artifact.gotmpl" + TmplAgentCmdAddArtifactProp = "templates/agent/commands/add_artifact_property.gotmpl" + TmplAgentCmdRegenerate = "templates/agent/commands/regenerate.gotmpl" + TmplAgentCmdFixBuild = "templates/agent/commands/fix_build.gotmpl" + TmplAgentSkillAddProperty = "templates/agent/skills/add_property.gotmpl" + TmplAgentSkillAddArtifact = "templates/agent/skills/add_artifact.gotmpl" + TmplAgentSkillAddArtifactProp = "templates/agent/skills/add_artifact_property.gotmpl" + TmplAgentSkillRegenerate = "templates/agent/skills/regenerate.gotmpl" + TmplAgentCmdGenTestdata = "templates/agent/commands/gen_testdata.gotmpl" + TmplAgentSkillGenTestdata = "templates/agent/skills/gen_testdata.gotmpl" +) + +// executeTemplate reads a template from the embedded filesystem and executes it to a file. 
+func executeTemplate(templatePath, outputPath string, data any) error { + content, err := templateFS.ReadFile(templatePath) + if err != nil { + return fmt.Errorf("failed to read template %s: %w", templatePath, err) + } + + tmpl, err := template.New(templatePath).Parse(string(content)) + if err != nil { + return fmt.Errorf("failed to parse template %s: %w", templatePath, err) + } + + file, err := os.Create(outputPath) + if err != nil { + return fmt.Errorf("failed to create file %s: %w", outputPath, err) + } + defer func() { _ = file.Close() }() + + if err := tmpl.Execute(file, data); err != nil { + return fmt.Errorf("failed to execute template %s: %w", templatePath, err) + } + + return nil +} diff --git a/cmd/catalog-gen/templates/agent/claude_md.gotmpl b/cmd/catalog-gen/templates/agent/claude_md.gotmpl new file mode 100644 index 0000000000..8da65c0ec3 --- /dev/null +++ b/cmd/catalog-gen/templates/agent/claude_md.gotmpl @@ -0,0 +1,52 @@ +# {{.EntityName}} Catalog Plugin + +This is a catalog plugin for the unified catalog server. + +## Structure + +- `catalog.yaml` - Plugin configuration (entity, properties, artifacts) +- `plugin.go` - Plugin implementation +- `internal/db/models/` - Entity and artifact models +- `internal/db/service/` - Repository implementations +- `internal/server/openapi/` - API handlers +- `internal/catalog/providers/` - Data providers + +## Development Commands + +Use these slash commands for common tasks: + +- `/add-property ` - Add a property to {{.EntityName}} +- `/add-artifact ` - Add a new artifact type +- `/add-artifact-property ` - Add property to artifact +- `/regenerate` - Regenerate code from catalog.yaml +- `/gen-testdata` - Generate sample test data in testdata/ +- `/fix-build` - Fix compilation errors + +## Workflow + +1. Edit `catalog.yaml` to define schema +2. Run `/regenerate` or `catalog-gen generate` +3. Run `make gen/openapi-server` to generate OpenAPI handlers +4. 
Implement business logic in `internal/server/openapi/api_*_service_impl.go` + +## Property Types + +| YAML Type | Go Type | OpenAPI Type | +|-----------|---------|--------------| +| string | *string | string | +| integer | *int32 | integer | +| int64 | *int64 | integer (format: int64) | +| boolean | *bool | boolean | +| number | *float64 | number | +| array | []string | array | + +## Filtering + +All list endpoints support `filterQuery` parameter with SQL-like syntax: +- `?filterQuery=protocol='http'` +- `?filterQuery=toolCount>5 AND protocol='stdio'` +- `?filterQuery=name LIKE '%server%'` +- Supported operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IN, AND, OR + +Results can be ordered with `orderBy` and `sortOrder` parameters: +- `?orderBy=name&sortOrder=DESC` diff --git a/cmd/catalog-gen/templates/agent/commands/add_artifact.gotmpl b/cmd/catalog-gen/templates/agent/commands/add_artifact.gotmpl new file mode 100644 index 0000000000..5a6b764baa --- /dev/null +++ b/cmd/catalog-gen/templates/agent/commands/add_artifact.gotmpl @@ -0,0 +1,7 @@ +Add a new artifact type to {{.EntityName}}. + +Usage: /add-artifact + +Example: /add-artifact Weights + +Follow the instructions in .claude/skills/add-artifact.md diff --git a/cmd/catalog-gen/templates/agent/commands/add_artifact_property.gotmpl b/cmd/catalog-gen/templates/agent/commands/add_artifact_property.gotmpl new file mode 100644 index 0000000000..4fd190fa42 --- /dev/null +++ b/cmd/catalog-gen/templates/agent/commands/add_artifact_property.gotmpl @@ -0,0 +1,7 @@ +Add a property to an existing artifact. 
+ +Usage: /add-artifact-property + +Example: /add-artifact-property Weights format string + +Follow the instructions in .claude/skills/add-artifact-property.md diff --git a/cmd/catalog-gen/templates/agent/commands/add_property.gotmpl b/cmd/catalog-gen/templates/agent/commands/add_property.gotmpl new file mode 100644 index 0000000000..4cb6ac4416 --- /dev/null +++ b/cmd/catalog-gen/templates/agent/commands/add_property.gotmpl @@ -0,0 +1,9 @@ +Add a new property to {{.EntityName}}. + +Usage: /add-property + +Types: string, integer, int64, boolean, number, array + +Example: /add-property version string + +Follow the instructions in .claude/skills/add-property.md diff --git a/cmd/catalog-gen/templates/agent/commands/fix_build.gotmpl b/cmd/catalog-gen/templates/agent/commands/fix_build.gotmpl new file mode 100644 index 0000000000..35dd6e3223 --- /dev/null +++ b/cmd/catalog-gen/templates/agent/commands/fix_build.gotmpl @@ -0,0 +1,10 @@ +Diagnose and fix compilation errors. + +Usage: /fix-build + +This will: +1. Run `go build ./...` and capture errors +2. Analyze error messages +3. Suggest or apply fixes + +Run `go build ./...` to see current errors, then analyze and fix them. diff --git a/cmd/catalog-gen/templates/agent/commands/gen_testdata.gotmpl b/cmd/catalog-gen/templates/agent/commands/gen_testdata.gotmpl new file mode 100644 index 0000000000..cd05945d83 --- /dev/null +++ b/cmd/catalog-gen/templates/agent/commands/gen_testdata.gotmpl @@ -0,0 +1,10 @@ +Generate sample test data for {{.EntityName}}. 
+ +Usage: /gen-testdata + +This will run `catalog-gen gen-testdata` to create: +- testdata/{{.EntityNameLower}}s.yaml (sample entity data) +- testdata/{{.EntityNameLower}}-sources.yaml (loader config) +- testdata/test-sources.yaml (catalog server config) + +Follow the instructions in .claude/skills/gen-testdata.md diff --git a/cmd/catalog-gen/templates/agent/commands/regenerate.gotmpl b/cmd/catalog-gen/templates/agent/commands/regenerate.gotmpl new file mode 100644 index 0000000000..709be72638 --- /dev/null +++ b/cmd/catalog-gen/templates/agent/commands/regenerate.gotmpl @@ -0,0 +1,10 @@ +Regenerate all code from catalog.yaml. + +Usage: /regenerate + +This will: +1. Run `catalog-gen generate` +2. Run `make gen/openapi-server` +3. Report any issues + +Follow the instructions in .claude/skills/regenerate.md diff --git a/cmd/catalog-gen/templates/agent/skills/add_artifact.gotmpl b/cmd/catalog-gen/templates/agent/skills/add_artifact.gotmpl new file mode 100644 index 0000000000..7f029c4dfb --- /dev/null +++ b/cmd/catalog-gen/templates/agent/skills/add_artifact.gotmpl @@ -0,0 +1,51 @@ +# Add Artifact to {{.EntityName}} + +## Arguments +- `ArtifactName`: Name of the artifact (PascalCase, e.g., Weights, Metrics) + +## Steps + +### Step 1: Update catalog.yaml + +Add the artifact under `spec.artifacts`: + +```yaml +spec: + artifacts: + - name: + properties: + - name: uri + type: string + - name: format + type: string +``` + +### Step 2: Regenerate Code + +```bash +catalog-gen generate +make gen/openapi-server +``` + +### Step 3: Update OpenAPI Spec + +Edit `api/openapi/src/openapi.yaml` to add artifact endpoints: + +```yaml +paths: + /{{.EntityNameLower}}s/{name}/artifacts: + get: + summary: List artifacts for a {{.EntityName}} + ... +``` + +### Step 4: Implement Artifact Service + +Create or update artifact handling in service implementation. + +### Step 5: Verify + +```bash +go build ./... +go test ./... 
+``` diff --git a/cmd/catalog-gen/templates/agent/skills/add_artifact_property.gotmpl b/cmd/catalog-gen/templates/agent/skills/add_artifact_property.gotmpl new file mode 100644 index 0000000000..6239d9e00f --- /dev/null +++ b/cmd/catalog-gen/templates/agent/skills/add_artifact_property.gotmpl @@ -0,0 +1,39 @@ +# Add Property to Artifact + +## Arguments +- `ArtifactName`: Name of the artifact (PascalCase) +- `property-name`: Name of the property (camelCase) +- `type`: One of: string, integer, int64, boolean, number, array + +## Steps + +### Step 1: Update catalog.yaml + +Add the property under the artifact's properties: + +```yaml +spec: + artifacts: + - name: + properties: + # ... existing properties ... + - name: + type: +``` + +### Step 2: Regenerate Code + +```bash +catalog-gen generate +make gen/openapi-server +``` + +### Step 3: Update YAML Provider + +If using YAML source, update the artifact struct and parsing. + +### Step 4: Verify + +```bash +go build ./... +``` diff --git a/cmd/catalog-gen/templates/agent/skills/add_property.gotmpl b/cmd/catalog-gen/templates/agent/skills/add_property.gotmpl new file mode 100644 index 0000000000..7379cefeb3 --- /dev/null +++ b/cmd/catalog-gen/templates/agent/skills/add_property.gotmpl @@ -0,0 +1,58 @@ +# Add Property to {{.EntityName}} + +## Arguments +- `property-name`: Name of the property (camelCase) +- `type`: One of: string, integer, int64, boolean, number, array + +## Steps + +### Step 1: Update catalog.yaml + +Add the property under `spec.entity.properties`: + +```yaml +spec: + entity: + name: {{.EntityName}} + properties: + # ... existing properties ... + - name: + type: + required: false # optional +``` + +### Step 2: Regenerate Code + +```bash +catalog-gen generate +make gen/openapi-server +``` + +The regeneration automatically updates filter mappings so the new property +is available in filterQuery expressions (e.g., `?filterQuery=newProp='value'`). 
+ +### Step 3: Update YAML Provider (if using YAML source) + +Edit `internal/catalog/providers/yaml_provider.go`: + +1. Add field to `yaml{{.EntityName}}` struct +2. Add mapping in `parse{{.EntityName}}Catalog()` function + +### Step 4: Update Service Implementation + +Edit `internal/server/openapi/api_{{.EntityNameLower}}_service_impl.go`: + +Add the property mapping in `convertToOpenAPIModel()`: + +```go +if attrs.PropertyName != nil { + result.PropertyName = *attrs.PropertyName +} +``` + +### Step 5: Verify + +```bash +go build ./... +go test ./... +``` diff --git a/cmd/catalog-gen/templates/agent/skills/gen_testdata.gotmpl b/cmd/catalog-gen/templates/agent/skills/gen_testdata.gotmpl new file mode 100644 index 0000000000..2ddf856078 --- /dev/null +++ b/cmd/catalog-gen/templates/agent/skills/gen_testdata.gotmpl @@ -0,0 +1,41 @@ +# Generate Test Data for {{.EntityName}} + +## Usage + +Run `catalog-gen gen-testdata` from the plugin root directory. + +## What It Does + +Generates sample data files in `testdata/` based on the current `catalog.yaml`: + +1. **`testdata/{{.EntityNameLower}}s.yaml`** - Sample entities with all fields populated + (includes BaseResource fields: name, externalId, description, plus all custom properties) +2. **`testdata/{{.EntityNameLower}}-sources.yaml`** - Loader configuration pointing to the sample data +3. 
**`testdata/test-sources.yaml`** - Catalog server configuration for local testing + +## Steps + +### Step 1: Generate test data + +```bash +catalog-gen gen-testdata +``` + +### Step 2: Run the catalog server locally + +```bash +./catalog-server --sources=testdata/test-sources.yaml --listen=:8080 --db-dsn="" +``` + +### Step 3: Test the API + +```bash +curl -s http://localhost:8080{{.BasePath}}/{{.EntityNameLower}}s | jq +``` + +## Notes + +- Regenerate test data after adding properties: `catalog-gen gen-testdata` +- The generated YAML includes all BaseResource fields (name, externalId, description) + plus all custom properties defined in `catalog.yaml` +- Sample data contains 3 entities with varied values for testing diff --git a/cmd/catalog-gen/templates/agent/skills/regenerate.gotmpl b/cmd/catalog-gen/templates/agent/skills/regenerate.gotmpl new file mode 100644 index 0000000000..bf885d6d9f --- /dev/null +++ b/cmd/catalog-gen/templates/agent/skills/regenerate.gotmpl @@ -0,0 +1,40 @@ +# Regenerate {{.EntityName}} Catalog + +## Binary locations + +Both `catalog-gen` and `openapi-generator-cli` are in the repo's `bin/` directory. +Resolve the repo root with `git rev-parse --show-toplevel` and use that to build paths. + +## Steps + +### Step 1: Regenerate from catalog.yaml + +```bash +"$(git rev-parse --show-toplevel)/bin/catalog-gen" generate +``` + +This regenerates: +- `internal/db/models/{{.EntityNameLower}}.go` +- `internal/db/service/spec.go` +- `api/openapi/src/generated/components.yaml` +- `internal/catalog/loader.go` +- `plugin.go`, `register.go` + +### Step 2: Generate OpenAPI Handlers + +```bash +make gen/openapi-server +``` + +### Step 3: Verify Build + +```bash +go build ./... +``` + +### Step 4: Fix Issues + +If build fails, check: +1. Service implementation matches new model fields +2. YAML provider matches new properties +3. 
OpenAPI service impl has correct type conversions diff --git a/cmd/catalog-gen/templates/api/openapi_components.gotmpl b/cmd/catalog-gen/templates/api/openapi_components.gotmpl new file mode 100644 index 0000000000..3cc5d82c21 --- /dev/null +++ b/cmd/catalog-gen/templates/api/openapi_components.gotmpl @@ -0,0 +1,65 @@ +# Code generated by catalog-gen. DO NOT EDIT. +# To regenerate: catalog-gen generate +# Source: catalog.yaml +# +# Entity schemas are composed using allOf with BaseResource. +components: + schemas:{{if not .IsPlugin}} + # Base schemas shared across all catalog entities + BaseResource: + type: object + properties: + id: + type: string + description: Unique identifier + name: + type: string + description: Resource name + externalId: + type: string + description: External identifier + description: + type: string + description: Resource description + customProperties: + type: object + additionalProperties: true + description: Custom properties + createTimeSinceEpoch: + type: string + description: Creation timestamp + lastUpdateTimeSinceEpoch: + type: string + description: Last update timestamp + + BaseResourceList: + type: object + properties: + nextPageToken: + type: string + description: Token for next page + size: + type: integer + description: Total size +{{end}} + {{.EntityName}}: + allOf:{{if .IsPlugin}} + - $ref: 'lib/common.yaml#/components/schemas/BaseResource'{{else}} + - $ref: '#/components/schemas/BaseResource'{{end}} + - type: object + required: + - name +{{.RequiredFields}}{{if gt (len .Properties) 0}} + properties: +{{.Properties}}{{end}} + {{.EntityName}}List: + allOf:{{if .IsPlugin}} + - $ref: 'lib/common.yaml#/components/schemas/BaseResourceList'{{else}} + - $ref: '#/components/schemas/BaseResourceList'{{end}} + - type: object + properties: + items: + type: array + items: + $ref: '#/components/schemas/{{.EntityName}}' +{{.ArtifactSchemas}} diff --git a/cmd/catalog-gen/templates/api/openapi_main.gotmpl 
b/cmd/catalog-gen/templates/api/openapi_main.gotmpl new file mode 100644 index 0000000000..d092394f9d --- /dev/null +++ b/cmd/catalog-gen/templates/api/openapi_main.gotmpl @@ -0,0 +1,77 @@ +openapi: 3.0.3 +info: + title: {{.Name}} Catalog API + version: v1alpha1 + description: API for the {{.Name}} catalog service + +servers: + - url: {{.BasePath}} + +paths: + /{{.EntityNameLower}}s: + get: + summary: List {{.EntityName}} entities + operationId: list{{.EntityName}}s + parameters: + - name: pageSize + in: query + schema: + type: integer + default: 20 + - name: pageToken + in: query + schema: + type: string + - name: q + in: query + schema: + type: string + description: Search query + - name: filterQuery + in: query + schema: + type: string + description: SQL-like filter expression (e.g., "protocol = 'http' AND toolCount > 5") + - name: orderBy + in: query + schema: + type: string + description: Field to order results by + - name: sortOrder + in: query + schema: + type: string + enum: [ASC, DESC] + default: ASC + description: Sort direction + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/{{.EntityName}}List' + + /{{.EntityNameLower}}s/{name}: + get: + summary: Get a {{.EntityName}} by name + operationId: get{{.EntityName}} + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/{{.EntityName}}' + '404': + description: Not found +{{.ArtifactRoutes}} +# You can add custom paths, parameters, and responses below. +# Schemas are defined in ./generated/components.yaml and merged via 'make api/openapi'. +# Use local $refs like '#/components/schemas/YourSchema' - they resolve after merge. 
diff --git a/cmd/catalog-gen/templates/catalog/loader.gotmpl b/cmd/catalog-gen/templates/catalog/loader.gotmpl new file mode 100644 index 0000000000..a3776b0041 --- /dev/null +++ b/cmd/catalog-gen/templates/catalog/loader.gotmpl @@ -0,0 +1,121 @@ +// Code generated by catalog-gen. DO NOT EDIT. +// To regenerate: catalog-gen generate +// Source: catalog.yaml + +package catalog + +import ( + "context" +{{if .HasArtifacts}} "fmt" +{{end}} + "github.com/golang/glog" + "github.com/kubeflow/model-registry/pkg/catalog" + sharedmodels "github.com/kubeflow/model-registry/internal/db/models" + "{{.Package}}/internal/db/models" + "{{.Package}}/internal/db/service" +) + +// glogLogger implements catalog.LoaderLogger using glog. +type glogLogger struct{} + +func (glogLogger) Infof(format string, args ...any) { glog.Infof(format, args...) } +func (glogLogger) Errorf(format string, args ...any) { glog.Errorf(format, args...) } +{{if .HasArtifacts}} +// Artifact is a union interface for all artifact types in this catalog. +type Artifact interface { + GetID() *int32 +} +{{end}} +// Loader wraps the generic catalog loader with {{.EntityName}}-specific types. +type Loader struct { + *catalog.Loader[models.{{.EntityName}}, {{.ArtifactType}}] + services service.Services +} + +// NewLoader creates a new catalog loader. 
+func NewLoader(services service.Services, paths []string, registry *catalog.ProviderRegistry[models.{{.EntityName}}, {{.ArtifactType}}]) *Loader { + cfg := catalog.LoaderConfig[models.{{.EntityName}}, {{.ArtifactType}}]{ + Paths: paths, + ProviderRegistry: registry, + Logger: glogLogger{}, + SaveEntity: func(entity models.{{.EntityName}}) (models.{{.EntityName}}, error) { + return services.{{.EntityName}}Repository.Save(entity) + }, + SaveArtifact: saveArtifact(services), + GetEntityID: func(entity models.{{.EntityName}}) *int32 { + return entity.GetID() + }, + GetEntityName: func(entity models.{{.EntityName}}) string { + if attrs := entity.GetAttributes(); attrs != nil && attrs.Name != nil { + return *attrs.Name + } + return "" + }, + DeleteArtifactsByEntity: deleteArtifactsByEntity(services), + DeleteEntitiesBySource: func(sourceID string) error { + return services.{{.EntityName}}Repository.DeleteBySource(sourceID) + }, + GetDistinctSourceIDs: func() ([]string, error) { + return services.{{.EntityName}}Repository.GetDistinctSourceIDs() + }, + SetEntitySourceID: func(entity models.{{.EntityName}}, sourceID string) { + setEntitySourceID(entity, sourceID) + }, + IsEntityNil: func(entity models.{{.EntityName}}) bool { + return entity == nil + }, + } + + return &Loader{ + Loader: catalog.NewLoader(cfg), + services: services, + } +} + +// saveArtifact returns a function that saves an artifact to the appropriate repository. +func saveArtifact(services service.Services) func(artifact {{.ArtifactType}}, entityID int32) error { + return func(artifact {{.ArtifactType}}, entityID int32) error { +{{if .HasArtifacts}} switch a := artifact.(type) { +{{.ArtifactSaveCases}} default: + return fmt.Errorf("unknown artifact type: %T", artifact) + } +{{else}} // No artifacts configured + return nil +{{end}} } +} + +// deleteArtifactsByEntity returns a function that deletes all artifacts for an entity. 
+func deleteArtifactsByEntity(services service.Services) func(entityID int32) error { + return func(entityID int32) error { +{{.ArtifactDeleteCalls}} return nil + } +} + +// Start begins loading catalog data. +func (l *Loader) Start(ctx context.Context) error { + return l.Loader.Start(ctx) +} + +// setEntitySourceID sets the source_id as a property on the entity. +// This follows the MLMD pattern where source_id is stored as a ContextProperty. +func setEntitySourceID(entity models.{{.EntityName}}, sourceID string) { + props := entity.GetProperties() + if props == nil { + newProps := []sharedmodels.Properties{} + props = &newProps + } + + // Check if source_id already exists + for i := range *props { + if (*props)[i].Name == "source_id" { + (*props)[i].StringValue = &sourceID + return + } + } + + // Add new source_id property + *props = append(*props, sharedmodels.Properties{ + Name: "source_id", + StringValue: &sourceID, + }) +} diff --git a/cmd/catalog-gen/templates/misc/gitignore.gotmpl b/cmd/catalog-gen/templates/misc/gitignore.gotmpl new file mode 100644 index 0000000000..0ff8f1d4d7 --- /dev/null +++ b/cmd/catalog-gen/templates/misc/gitignore.gotmpl @@ -0,0 +1,10 @@ +# Build artifacts +bin/ +coverage.txt +coverage.html + +# IDE +.idea/ +.vscode/ +*.swp +*.swo diff --git a/cmd/catalog-gen/templates/misc/openapi_generator_ignore.gotmpl b/cmd/catalog-gen/templates/misc/openapi_generator_ignore.gotmpl new file mode 100644 index 0000000000..d647cf082b --- /dev/null +++ b/cmd/catalog-gen/templates/misc/openapi_generator_ignore.gotmpl @@ -0,0 +1,34 @@ +# OpenAPI Generator Ignore +# Use this file to prevent files from being overwritten by the generator. 
+ +# Service implementation (hand-written, created by catalog-gen) +internal/server/openapi/api_*_service_impl.go + +# Test files (hand-written) +internal/server/openapi/*_test.go + +# Client files to ignore +pkg/openapi/api +pkg/openapi/api/** +pkg/openapi/git_push.sh +pkg/openapi/.gitignore +pkg/openapi/.travis.yml +pkg/openapi/.openapi-generator-ignore +pkg/openapi/README.md +pkg/openapi/docs +pkg/openapi/docs/** +pkg/openapi/test +pkg/openapi/test/** +pkg/openapi/**all_of.go +pkg/openapi/go.mod +pkg/openapi/go.sum + +# Server files to ignore +internal/server/openapi/api +internal/server/openapi/api/** +internal/server/openapi/.openapi-generator-ignore +internal/server/openapi/README.md +internal/server/openapi/main.go +internal/server/openapi/go.mod +internal/server/openapi/go.sum +internal/server/openapi/Dockerfile diff --git a/cmd/catalog-gen/templates/models/artifact.gotmpl b/cmd/catalog-gen/templates/models/artifact.gotmpl new file mode 100644 index 0000000000..c1a355b85a --- /dev/null +++ b/cmd/catalog-gen/templates/models/artifact.gotmpl @@ -0,0 +1,50 @@ +// Code generated by catalog-gen. DO NOT EDIT. +// To regenerate: catalog-gen generate +// Source: catalog.yaml + +package models + +import ( + "github.com/kubeflow/model-registry/internal/db/models" +) + +const {{.EntityName}}{{.ArtifactName}}ArtifactType = "{{.LowerArtifactName}}-artifact" + +// {{.EntityName}}{{.ArtifactName}}ArtifactAttributes contains the attributes for a {{.ArtifactName}} artifact. +type {{.EntityName}}{{.ArtifactName}}ArtifactAttributes struct { + Name *string + ExternalID *string + CreateTimeSinceEpoch *int64 + LastUpdateTimeSinceEpoch *int64 +{{.Properties}}} + +// {{.EntityName}}{{.ArtifactName}}Artifact is the interface for {{.ArtifactName}} artifacts. +type {{.EntityName}}{{.ArtifactName}}Artifact interface { + models.Entity[{{.EntityName}}{{.ArtifactName}}ArtifactAttributes] +} + +// {{.EntityName}}{{.ArtifactName}}ArtifactImpl is the concrete implementation. 
+type {{.EntityName}}{{.ArtifactName}}ArtifactImpl = models.BaseEntity[{{.EntityName}}{{.ArtifactName}}ArtifactAttributes] + +// New{{.EntityName}}{{.ArtifactName}}Artifact creates a new {{.ArtifactName}} artifact. +func New{{.EntityName}}{{.ArtifactName}}Artifact(attrs *{{.EntityName}}{{.ArtifactName}}ArtifactAttributes) {{.EntityName}}{{.ArtifactName}}Artifact { + return &{{.EntityName}}{{.ArtifactName}}ArtifactImpl{ + Attributes: attrs, + } +} + +// {{.EntityName}}{{.ArtifactName}}ArtifactListOptions contains options for listing {{.ArtifactName}} artifacts. +type {{.EntityName}}{{.ArtifactName}}ArtifactListOptions struct { + models.Pagination + Name *string + ExternalID *string + ParentResourceID *int32 +} + +// {{.EntityName}}{{.ArtifactName}}ArtifactRepository is the interface for {{.ArtifactName}} artifact data access. +type {{.EntityName}}{{.ArtifactName}}ArtifactRepository interface { + GetByID(id int32) ({{.EntityName}}{{.ArtifactName}}Artifact, error) + List(options {{.EntityName}}{{.ArtifactName}}ArtifactListOptions) (*models.ListWrapper[{{.EntityName}}{{.ArtifactName}}Artifact], error) + Save(artifact {{.EntityName}}{{.ArtifactName}}Artifact, parentResourceID *int32) ({{.EntityName}}{{.ArtifactName}}Artifact, error) + DeleteByParentID(parentID int32) error +} diff --git a/cmd/catalog-gen/templates/models/base.gotmpl b/cmd/catalog-gen/templates/models/base.gotmpl new file mode 100644 index 0000000000..3c5961c05e --- /dev/null +++ b/cmd/catalog-gen/templates/models/base.gotmpl @@ -0,0 +1,32 @@ +package models + +// Property represents a key-value property. +type Property struct { + Name string + ValueType string + Value any +} + +// Pagination contains pagination parameters for list operations. +type Pagination struct { + PageSize int32 + PageToken string + OrderBy string + SortOrder string +} + +// ListWrapper wraps a list of items with pagination info. 
+type ListWrapper[T any] struct { + Items []T + NextPageToken string + Size int32 +} + +// NewListWrapper creates a new ListWrapper. +func NewListWrapper[T any](items []T, nextPageToken string) *ListWrapper[T] { + return &ListWrapper[T]{ + Items: items, + NextPageToken: nextPageToken, + Size: int32(len(items)), + } +} diff --git a/cmd/catalog-gen/templates/models/entity.gotmpl b/cmd/catalog-gen/templates/models/entity.gotmpl new file mode 100644 index 0000000000..2f86a9143e --- /dev/null +++ b/cmd/catalog-gen/templates/models/entity.gotmpl @@ -0,0 +1,63 @@ +// Code generated by catalog-gen. DO NOT EDIT. +// To regenerate: catalog-gen generate +// Source: catalog.yaml + +package models + +import ( + "github.com/kubeflow/model-registry/internal/db/filter" + "github.com/kubeflow/model-registry/internal/db/models" +) + +// RestEntity{{.EntityName}} is the filter.RestEntityType constant for {{.EntityName}} entities. +const RestEntity{{.EntityName}} filter.RestEntityType = "{{.EntityName}}" + +// {{.EntityName}}Attributes contains the attributes for a {{.EntityName}} entity. +type {{.EntityName}}Attributes struct { + Name *string + ExternalID *string + CreateTimeSinceEpoch *int64 + LastUpdateTimeSinceEpoch *int64 +{{.Properties}}} + +// {{.EntityName}} is the interface for {{.EntityName}} entities. +// It extends the shared Entity interface with {{.EntityName}}-specific attributes. +type {{.EntityName}} interface { + models.Entity[{{.EntityName}}Attributes] +} + +// {{.EntityName}}Impl is the concrete implementation of the {{.EntityName}} interface. +// It uses the shared BaseEntity implementation. +type {{.EntityName}}Impl = models.BaseEntity[{{.EntityName}}Attributes] + +// New{{.EntityName}} creates a new {{.EntityName}} entity. +func New{{.EntityName}}(attrs *{{.EntityName}}Attributes) {{.EntityName}} { + return &{{.EntityName}}Impl{ + Attributes: attrs, + } +} + +// {{.EntityName}}ListOptions contains options for listing {{.EntityName}} entities. 
+type {{.EntityName}}ListOptions struct { + models.Pagination + Name *string + ExternalID *string + SourceIDs *[]string + Query *string +} + +// GetRestEntityType implements the FilterApplier interface for advanced filtering. +func (o *{{.EntityName}}ListOptions) GetRestEntityType() filter.RestEntityType { + return RestEntity{{.EntityName}} +} + +// {{.EntityName}}Repository is the interface for {{.EntityName}} data access. +type {{.EntityName}}Repository interface { + GetByID(id int32) ({{.EntityName}}, error) + GetByName(name string) ({{.EntityName}}, error) + List(options {{.EntityName}}ListOptions) (*models.ListWrapper[{{.EntityName}}], error) + Save(entity {{.EntityName}}) ({{.EntityName}}, error) + DeleteBySource(sourceID string) error + DeleteByID(id int32) error + GetDistinctSourceIDs() ([]string, error) +} diff --git a/cmd/catalog-gen/templates/plugin/hooks.gotmpl b/cmd/catalog-gen/templates/plugin/hooks.gotmpl new file mode 100644 index 0000000000..a8b5da2118 --- /dev/null +++ b/cmd/catalog-gen/templates/plugin/hooks.gotmpl @@ -0,0 +1,32 @@ +// Plugin hooks - created once by catalog-gen, you can modify this file. +// These hooks are called by the generated plugin.go at specific lifecycle points. +// Use them to register event handlers, custom loaders, or any plugin-specific initialization. + +package {{.PackageName}} + +import ( + "github.com/kubeflow/model-registry/pkg/catalog/plugin" +) + +// initHooks is called at the end of Init(), after the loader and services are created. +// Use this to register event handlers, set up custom loaders, or read plugin config. +// +// Available fields on the plugin struct: +// - p.loader — the catalog loader (call p.loader.RegisterEventHandler(...) 
to add hooks) +// - p.services — the service layer with repositories +// - p.cfg — the plugin config (p.cfg.Section.Config for plugin-specific settings) +// - p.logger — structured logger +func (p *{{.EntityName}}CatalogPlugin) initHooks(cfg plugin.Config) error { + // Example: register an event handler + // p.loader.RegisterEventHandler(func(ctx context.Context, record catalog.Record[models.{{.EntityName}}, any]) error { + // p.logger.Info("entity loaded", "name", record.Entity.GetAttributes().Name) + // return nil + // }) + + // Example: read plugin-specific config from sources.yaml + // if val, ok := cfg.Section.Config["myCustomSetting"]; ok { + // p.logger.Info("custom setting", "value", val) + // } + + return nil +} diff --git a/cmd/catalog-gen/templates/plugin/plugin.gotmpl b/cmd/catalog-gen/templates/plugin/plugin.gotmpl new file mode 100644 index 0000000000..587c9193e0 --- /dev/null +++ b/cmd/catalog-gen/templates/plugin/plugin.gotmpl @@ -0,0 +1,246 @@ +// Code generated by catalog-gen. DO NOT EDIT. +// To regenerate: catalog-gen generate +// Source: catalog.yaml +// +// Package {{.PackageName}} provides the {{.EntityName}} catalog plugin for the unified catalog server. +package {{.PackageName}} + +import ( + "context" + "flag" + "fmt" + "log/slog" + "path/filepath" + "reflect" + "sync/atomic" + + "github.com/go-chi/chi/v5" + "gorm.io/gorm" + + "{{.Package}}/internal/catalog" + "{{.Package}}/internal/catalog/providers" + "{{.Package}}/internal/db/models" + "{{.Package}}/internal/db/service" + "{{.Package}}/internal/server/openapi" + "github.com/kubeflow/model-registry/internal/datastore" + "github.com/kubeflow/model-registry/internal/datastore/embedmd" + pkgcatalog "github.com/kubeflow/model-registry/pkg/catalog" + "github.com/kubeflow/model-registry/pkg/catalog/plugin" +) + +const ( + // PluginName is the identifier for this plugin. + PluginName = "{{.Name}}" + + // PluginVersion is the API version. 
+ PluginVersion = "v1alpha1" +) + +// {{.EntityName}}CatalogPlugin implements the CatalogPlugin interface for {{.EntityName}} catalogs. +type {{.EntityName}}CatalogPlugin struct { + cfg plugin.Config + logger *slog.Logger + loaders []plugin.CatalogLoader + services service.Services + healthy atomic.Bool + started atomic.Bool +} + +// Name returns the plugin name. +func (p *{{.EntityName}}CatalogPlugin) Name() string { + return PluginName +} + +// Version returns the plugin API version. +func (p *{{.EntityName}}CatalogPlugin) Version() string { + return PluginVersion +} + +// Description returns a human-readable description. +func (p *{{.EntityName}}CatalogPlugin) Description() string { + return "{{.EntityName}} catalog" +} + +// BasePath returns the API base path for this plugin. +func (p *{{.EntityName}}CatalogPlugin) BasePath() string { + return "{{.BasePath}}" +} + +// RegisterFlags registers custom CLI flags for this plugin. +// Add your plugin-specific flags here. They will be available after flag.Parse(). +func (p *{{.EntityName}}CatalogPlugin) RegisterFlags(fs *flag.FlagSet) { + // Register plugin-specific CLI flags here. Example: + // fs.StringVar(&p.myFlag, "{{.Name}}-my-flag", "default", "Description") +} + +// Init initializes the plugin with configuration. 
+func (p *{{.EntityName}}CatalogPlugin) Init(ctx context.Context, cfg plugin.Config) error { + p.cfg = cfg + p.logger = cfg.Logger + if p.logger == nil { + p.logger = slog.Default() + } + + p.logger.Info("initializing {{.Name}} plugin") + + // Build config paths from source properties or origins + paths := make([]string, 0) + pathsSet := make(map[string]bool) + for _, src := range cfg.Section.Sources { + // Check for loaderConfigPath property first (allows separate loader config) + if loaderPath, ok := src.Properties["loaderConfigPath"].(string); ok && loaderPath != "" { + // Resolve relative to source origin directory + if !filepath.IsAbs(loaderPath) && src.Origin != "" { + loaderPath = filepath.Join(filepath.Dir(src.Origin), loaderPath) + } + if !pathsSet[loaderPath] { + paths = append(paths, loaderPath) + pathsSet[loaderPath] = true + } + } else if src.Origin != "" && !pathsSet[src.Origin] { + paths = append(paths, src.Origin) + pathsSet[src.Origin] = true + } + } + if len(paths) == 0 { + paths = cfg.ConfigPaths + } + + // Convert to absolute paths + absPaths := make([]string, 0, len(paths)) + for _, path := range paths { + if absPath, err := filepath.Abs(path); err == nil { + absPaths = append(absPaths, absPath) + } else { + absPaths = append(absPaths, path) + } + } + + // Initialize services from the database connection + services, err := p.initServices(cfg.DB) + if err != nil { + return fmt.Errorf("failed to initialize services: %w", err) + } + p.services = services + + // Set up provider registry + registry := pkgcatalog.NewProviderRegistry[models.{{.EntityName}}, {{.ArtifactType}}]() + if err := registry.Register("yaml", providers.New{{.EntityName}}YAMLProvider()); err != nil { + return fmt.Errorf("failed to register YAML provider: %w", err) + } + + // Create the core loader and add it to the loaders list. + // Additional custom loaders can be appended here. 
+ p.loaders = append(p.loaders, catalog.NewLoader(services, absPaths, registry)) + + p.logger.Info("{{.Name}} plugin initialized", "paths", absPaths) + return nil +} + +// initServices creates the service layer from the database connection. +func (p *{{.EntityName}}CatalogPlugin) initServices(db *gorm.DB) (service.Services, error) { + spec := service.DatastoreSpec() + + connector, err := datastore.NewConnector("embedmd", &embedmd.EmbedMDConfig{DB: db, SkipMigrations: true}) + if err != nil { + return service.Services{}, fmt.Errorf("failed to create connector: %w", err) + } + + repoSet, err := connector.Connect(spec) + if err != nil { + return service.Services{}, fmt.Errorf("failed to connect: %w", err) + } + + repo, err := getRepository[models.{{.EntityName}}Repository](repoSet) + if err != nil { + return service.Services{}, fmt.Errorf("failed to get {{.EntityName}} repository: %w", err) + } + + return service.NewServices(repo), nil +} + +// Start begins background operations (hot-reload, watchers). +func (p *{{.EntityName}}CatalogPlugin) Start(ctx context.Context) error { + p.logger.Info("starting {{.Name}} plugin") + + for _, loader := range p.loaders { + if err := loader.Start(ctx); err != nil { + return fmt.Errorf("failed to start loader: %w", err) + } + } + + p.started.Store(true) + p.healthy.Store(true) + + p.logger.Info("{{.Name}} plugin started") + return nil +} + +// Stop gracefully shuts down the plugin. +func (p *{{.EntityName}}CatalogPlugin) Stop(ctx context.Context) error { + p.logger.Info("stopping {{.Name}} plugin") + + for i := len(p.loaders) - 1; i >= 0; i-- { + if err := p.loaders[i].Stop(ctx); err != nil { + p.logger.Error("loader stop failed", "error", err) + } + } + + p.started.Store(false) + p.healthy.Store(false) + return nil +} + +// Healthy returns true if the plugin is functioning correctly. 
+func (p *{{.EntityName}}CatalogPlugin) Healthy() bool { + return p.healthy.Load() +} + +// RegisterRoutes mounts the plugin's HTTP routes on the provided router. +func (p *{{.EntityName}}CatalogPlugin) RegisterRoutes(router chi.Router) error { + p.logger.Info("registering {{.Name}} routes") + + // Create the OpenAPI service + apiService := openapi.New{{.EntityName}}CatalogServiceAPIService(p.services) + apiController := openapi.NewDefaultAPIController(apiService) + + // Mount routes - remove the base path prefix since chi.Router already handles that + basePath := "{{.BasePath}}" + for _, route := range apiController.OrderedRoutes() { + pattern := route.Pattern + if len(pattern) > len(basePath) && pattern[:len(basePath)] == basePath { + pattern = pattern[len(basePath):] + } + if pattern == "" { + pattern = "/" + } + router.Method(route.Method, pattern, route.HandlerFunc) + p.logger.Debug("registered route", "method", route.Method, "pattern", pattern) + } + + return nil +} + +// Migrations returns database migrations for this plugin. +func (p *{{.EntityName}}CatalogPlugin) Migrations() []plugin.Migration { + // Migrations are handled by the datastore layer + return nil +} + +// getRepository extracts a repository of type T from the RepoSet. +func getRepository[T any](rs datastore.RepoSet) (T, error) { + var zero T + t := reflect.TypeFor[T]() + + repo, err := rs.Repository(t) + if err != nil { + return zero, err + } + + result, ok := repo.(T) + if !ok { + return zero, fmt.Errorf("repository type mismatch: expected %T, got %T", zero, repo) + } + + return result, nil +} diff --git a/cmd/catalog-gen/templates/plugin/register.gotmpl b/cmd/catalog-gen/templates/plugin/register.gotmpl new file mode 100644 index 0000000000..af16c45ba7 --- /dev/null +++ b/cmd/catalog-gen/templates/plugin/register.gotmpl @@ -0,0 +1,11 @@ +// Code generated by catalog-gen. DO NOT EDIT. 
+// To regenerate: catalog-gen generate +// Source: catalog.yaml + +package {{.PackageName}} + +import "github.com/kubeflow/model-registry/pkg/catalog/plugin" + +func init() { + plugin.Register(&{{.EntityName}}CatalogPlugin{}) +} diff --git a/cmd/catalog-gen/templates/providers/http.gotmpl b/cmd/catalog-gen/templates/providers/http.gotmpl new file mode 100644 index 0000000000..9c537032a7 --- /dev/null +++ b/cmd/catalog-gen/templates/providers/http.gotmpl @@ -0,0 +1,29 @@ +package providers + +import ( + "context" + "net/http" + + "github.com/kubeflow/model-registry/pkg/catalog" + httpprovider "github.com/kubeflow/model-registry/pkg/catalog/providers/http" +) + +// New{{.EntityName}}HTTPProvider creates an HTTP provider for {{.EntityName}} entities. +func New{{.EntityName}}HTTPProvider() catalog.ProviderFunc[*{{.EntityName}}, *{{.EntityName}}Artifact] { + config := httpprovider.Config[*{{.EntityName}}, *{{.EntityName}}Artifact]{ + BaseURLKey: "url", + DefaultBaseURL: "", // Set a default URL if applicable + FetchRecords: fetch{{.EntityName}}Records, + } + return httpprovider.NewProviderFunc(config) +} + +func fetch{{.EntityName}}Records(ctx context.Context, client *http.Client, baseURL string, source *catalog.Source) ([]catalog.Record[*{{.EntityName}}, *{{.EntityName}}Artifact], error) { + // TODO: Implement HTTP fetching logic + // 1. Build API URL from baseURL and source.Properties + // 2. Make HTTP request + // 3. Parse response + // 4. 
Convert to catalog.Record + + return nil, nil +} diff --git a/cmd/catalog-gen/templates/providers/yaml.gotmpl b/cmd/catalog-gen/templates/providers/yaml.gotmpl new file mode 100644 index 0000000000..451dcf650b --- /dev/null +++ b/cmd/catalog-gen/templates/providers/yaml.gotmpl @@ -0,0 +1,128 @@ +package providers + +import ( + "context" + "os" + "path/filepath" + + "fmt" + + "github.com/golang/glog" + k8syaml "k8s.io/apimachinery/pkg/util/yaml" +{{if .HasArtifacts}} + "{{.Package}}/internal/catalog" +{{end}} "{{.Package}}/internal/db/models" + sharedmodels "github.com/kubeflow/model-registry/internal/db/models" + pkgcatalog "github.com/kubeflow/model-registry/pkg/catalog" + yamlprovider "github.com/kubeflow/model-registry/pkg/catalog/providers/yaml" +) + +// yaml{{.EntityName}} represents a {{.EntityName}} entry in the YAML catalog file. +type yaml{{.EntityName}} struct { + Name string `json:"name" yaml:"name"` + ExternalId string `json:"externalId,omitempty" yaml:"externalId,omitempty"` + Description *string `json:"description,omitempty" yaml:"description,omitempty"` + CustomProperties map[string]any `json:"customProperties,omitempty" yaml:"customProperties,omitempty"` +{{.EntityPropertyFields}}} + +// yaml{{.EntityName}}Catalog is the structure of the YAML catalog file. +type yaml{{.EntityName}}Catalog struct { + {{.EntityName}}s []yaml{{.EntityName}} `json:"{{.EntityNameLower}}s" yaml:"{{.EntityNameLower}}s"` +} +{{.ArtifactStructs}} +// glogLogger implements yaml.Logger using glog. +type glogLogger struct{} + +func (glogLogger) Infof(format string, args ...any) { glog.Infof(format, args...) } +func (glogLogger) Errorf(format string, args ...any) { glog.Errorf(format, args...) } + +// New{{.EntityName}}YAMLProvider creates a new YAML provider for {{.EntityName}} entities. +// It uses the reusable yaml.NewProviderFunc which includes automatic hot-reload +// via file watching (polling every 5 seconds for file changes). 
+func New{{.EntityName}}YAMLProvider() pkgcatalog.ProviderFunc[models.{{.EntityName}}, {{.ArtifactType}}] { + return func(ctx context.Context, source *pkgcatalog.Source, reldir string) (<-chan pkgcatalog.Record[models.{{.EntityName}}, {{.ArtifactType}}], error) { + // Resolve artifacts path from source properties (captured in parse closure) + var artifactsPath string + if ap, ok := source.Properties["yamlArtifactsPath"].(string); ok && ap != "" { + if !filepath.IsAbs(ap) { + ap = filepath.Join(reldir, ap) + } + artifactsPath = ap + } + + config := yamlprovider.Config[models.{{.EntityName}}, {{.ArtifactType}}]{ + Parse: func(data []byte) ([]pkgcatalog.Record[models.{{.EntityName}}, {{.ArtifactType}}], error) { + var artifactsData []byte + if artifactsPath != "" { + var err error + artifactsData, err = os.ReadFile(artifactsPath) + if err != nil { + glog.Warningf("failed to read artifacts file %s: %v", artifactsPath, err) + } + } + return parse{{.EntityName}}Catalog(data, artifactsData) + }, + Logger: glogLogger{}, + } + + provider, err := yamlprovider.NewProvider(config, source, reldir) + if err != nil { + return nil, err + } + return provider.Records(ctx) + } +} + +// parse{{.EntityName}}Catalog parses the YAML catalog files into records. 
+func parse{{.EntityName}}Catalog(catalogData, artifactsData []byte) ([]pkgcatalog.Record[models.{{.EntityName}}, {{.ArtifactType}}], error) { + var entityCatalog yaml{{.EntityName}}Catalog + if err := k8syaml.UnmarshalStrict(catalogData, &entityCatalog); err != nil { + return nil, err + } +{{.ArtifactParseCode}} + records := make([]pkgcatalog.Record[models.{{.EntityName}}, {{.ArtifactType}}], 0, len(entityCatalog.{{.EntityName}}s)) + for _, item := range entityCatalog.{{.EntityName}}s { + name := item.Name + var externalID *string + if item.ExternalId != "" { + externalID = &item.ExternalId + } + entity := models.New{{.EntityName}}(&models.{{.EntityName}}Attributes{ + Name: &name, + ExternalID: externalID, + }) +{{.EntityPropertyAssignments}} + // Set base properties (description) + var props []sharedmodels.Properties + if item.Description != nil { + props = append(props, sharedmodels.NewStringProperty("description", *item.Description, false)) + } + if len(props) > 0 { + entity.(*models.{{.EntityName}}Impl).Properties = &props + } + + // Set custom properties + if len(item.CustomProperties) > 0 { + var customProps []sharedmodels.Properties + for k, v := range item.CustomProperties { + switch val := v.(type) { + case string: + customProps = append(customProps, sharedmodels.NewStringProperty(k, val, true)) + case float64: + customProps = append(customProps, sharedmodels.NewDoubleProperty(k, val, true)) + case bool: + customProps = append(customProps, sharedmodels.NewBoolProperty(k, val, true)) + default: + customProps = append(customProps, sharedmodels.NewStringProperty(k, fmt.Sprintf("%v", val), true)) + } + } + entity.(*models.{{.EntityName}}Impl).CustomProperties = &customProps + } + + record := pkgcatalog.Record[models.{{.EntityName}}, {{.ArtifactType}}]{Entity: entity} +{{.ArtifactMatchCode}} + records = append(records, record) + } + + return records, nil +} diff --git a/cmd/catalog-gen/templates/server/openapi_service_impl.gotmpl 
b/cmd/catalog-gen/templates/server/openapi_service_impl.gotmpl new file mode 100644 index 0000000000..bbbba5772f --- /dev/null +++ b/cmd/catalog-gen/templates/server/openapi_service_impl.gotmpl @@ -0,0 +1,128 @@ +package openapi + +import ( + "context" + "errors" + "fmt" + "net/http" + + "{{.Package}}/internal/db/models" + "{{.Package}}/internal/db/service" +) + +// {{.EntityName}}CatalogServiceAPIService implements the business logic for the {{.EntityName}} Catalog API. +type {{.EntityName}}CatalogServiceAPIService struct { + services service.Services +} + +// New{{.EntityName}}CatalogServiceAPIService creates a new service instance. +func New{{.EntityName}}CatalogServiceAPIService(services service.Services) *{{.EntityName}}CatalogServiceAPIService { + return &{{.EntityName}}CatalogServiceAPIService{ + services: services, + } +} + +// Ensure we implement the DefaultAPIServicer interface. +var _ DefaultAPIServicer = &{{.EntityName}}CatalogServiceAPIService{} + +// List{{.EntityName}}s implements DefaultAPIServicer.List{{.EntityName}}s +// This returns a paginated list of {{.EntityName}} entities. 
+func (s *{{.EntityName}}CatalogServiceAPIService) List{{.EntityName}}s(ctx context.Context, pageSize int32, pageToken string, q string, filterQuery string, orderBy string, sortOrder string) (ImplResponse, error) { + listOptions := models.{{.EntityName}}ListOptions{ + Query: &q, + } + listOptions.PageSize = &pageSize + listOptions.NextPageToken = &pageToken + listOptions.FilterQuery = &filterQuery + listOptions.OrderBy = &orderBy + listOptions.SortOrder = &sortOrder + + result, err := s.services.{{.EntityName}}Repository.List(listOptions) + if err != nil { + return Response(http.StatusInternalServerError, nil), err + } + + // Convert to OpenAPI model types + items := make([]{{.EntityName}}, len(result.Items)) + for i, item := range result.Items { + items[i] = convertToOpenAPIModel(item) + } + + response := {{.EntityName}}List{ + Items: items, + NextPageToken: result.NextPageToken, + Size: int32(len(items)), + } + + return Response(http.StatusOK, response), nil +} + +// Get{{.EntityName}} implements DefaultAPIServicer.Get{{.EntityName}} +// This returns a single {{.EntityName}} by name. +func (s *{{.EntityName}}CatalogServiceAPIService) Get{{.EntityName}}(ctx context.Context, name string) (ImplResponse, error) { + entity, err := s.services.{{.EntityName}}Repository.GetByName(name) + if err != nil { + if errors.Is(err, service.Err{{.EntityName}}NotFound) { + return Response(http.StatusNotFound, nil), err + } + return Response(http.StatusInternalServerError, nil), err + } + + return Response(http.StatusOK, convertToOpenAPIModel(entity)), nil +} + +// convertToOpenAPIModel converts a database entity to the OpenAPI model type. 
+func convertToOpenAPIModel(entity models.{{.EntityName}}) {{.EntityName}} { + attrs := entity.GetAttributes() + result := {{.EntityName}}{} + + if attrs.Name != nil { + result.Name = *attrs.Name + } + if attrs.ExternalID != nil { + result.ExternalId = *attrs.ExternalID + } + if attrs.CreateTimeSinceEpoch != nil { + result.CreateTimeSinceEpoch = fmt.Sprintf("%d", *attrs.CreateTimeSinceEpoch) + } + if attrs.LastUpdateTimeSinceEpoch != nil { + result.LastUpdateTimeSinceEpoch = fmt.Sprintf("%d", *attrs.LastUpdateTimeSinceEpoch) + } + if entity.GetID() != nil { + result.Id = fmt.Sprintf("%d", *entity.GetID()) + } + + // Extract base properties (description) + if entity.GetProperties() != nil { + for _, prop := range *entity.GetProperties() { + switch prop.Name { + case "description": + if prop.StringValue != nil { + result.Description = *prop.StringValue + } + } + } + } + + // Extract custom properties + if entity.GetCustomProperties() != nil { + customProps := make(map[string]MetadataValue) + for _, prop := range *entity.GetCustomProperties() { + switch { + case prop.StringValue != nil: + customProps[prop.Name] = MetadataValue{StringValue: *prop.StringValue, MetadataType: "MetadataStringValue"} + case prop.IntValue != nil: + customProps[prop.Name] = MetadataValue{IntValue: fmt.Sprintf("%d", *prop.IntValue), MetadataType: "MetadataIntValue"} + case prop.DoubleValue != nil: + customProps[prop.Name] = MetadataValue{DoubleValue: *prop.DoubleValue, MetadataType: "MetadataDoubleValue"} + case prop.BoolValue != nil: + customProps[prop.Name] = MetadataValue{BoolValue: *prop.BoolValue, MetadataType: "MetadataBoolValue"} + } + } + if len(customProps) > 0 { + result.CustomProperties = customProps + } + } +{{.PropConversions}} + return result +} diff --git a/cmd/catalog-gen/templates/service/artifact_repository.gotmpl b/cmd/catalog-gen/templates/service/artifact_repository.gotmpl new file mode 100644 index 0000000000..4a967d9d41 --- /dev/null +++ 
b/cmd/catalog-gen/templates/service/artifact_repository.gotmpl @@ -0,0 +1,222 @@ +// Code generated by catalog-gen. DO NOT EDIT. +// To regenerate: catalog-gen generate +// Source: catalog.yaml + +package service + +import ( + "errors" + "fmt" + + "gorm.io/gorm" + + "{{.Package}}/internal/db/models" + "github.com/kubeflow/model-registry/internal/db/schema" + sharedmodels "github.com/kubeflow/model-registry/internal/db/models" + "github.com/kubeflow/model-registry/internal/db/service" + "github.com/kubeflow/model-registry/internal/db/utils" + "github.com/kubeflow/model-registry/internal/apiutils" +) + +var Err{{.EntityName}}{{.ArtifactName}}ArtifactNotFound = errors.New("{{.LowerEntityName}} {{.LowerArtifactName}} artifact not found") + +// {{.EntityName}}{{.ArtifactName}}ArtifactRepositoryImpl uses the shared GenericRepository for MLMD-based storage. +type {{.EntityName}}{{.ArtifactName}}ArtifactRepositoryImpl struct { + *service.GenericRepository[models.{{.EntityName}}{{.ArtifactName}}Artifact, schema.Artifact, schema.ArtifactProperty, *models.{{.EntityName}}{{.ArtifactName}}ArtifactListOptions] +} + +// New{{.EntityName}}{{.ArtifactName}}ArtifactRepository creates a new {{.ArtifactName}} artifact repository. 
+func New{{.EntityName}}{{.ArtifactName}}ArtifactRepository(db *gorm.DB, typeID int32) models.{{.EntityName}}{{.ArtifactName}}ArtifactRepository { + r := &{{.EntityName}}{{.ArtifactName}}ArtifactRepositoryImpl{} + + r.GenericRepository = service.NewGenericRepository(service.GenericRepositoryConfig[ + models.{{.EntityName}}{{.ArtifactName}}Artifact, + schema.Artifact, + schema.ArtifactProperty, + *models.{{.EntityName}}{{.ArtifactName}}ArtifactListOptions, + ]{ + DB: db, + TypeID: typeID, + EntityToSchema: map{{.EntityName}}{{.ArtifactName}}ArtifactToSchema, + SchemaToEntity: mapSchemaTo{{.EntityName}}{{.ArtifactName}}Artifact, + EntityToProperties: map{{.EntityName}}{{.ArtifactName}}ArtifactToProperties, + NotFoundError: Err{{.EntityName}}{{.ArtifactName}}ArtifactNotFound, + EntityName: "{{.LowerEntityName}} {{.LowerArtifactName}} artifact", + PropertyFieldName: "artifact_id", + ApplyListFilters: apply{{.EntityName}}{{.ArtifactName}}ArtifactListFilters, + IsNewEntity: func(entity models.{{.EntityName}}{{.ArtifactName}}Artifact) bool { return entity.GetID() == nil }, + HasCustomProperties: func(entity models.{{.EntityName}}{{.ArtifactName}}Artifact) bool { return entity.GetCustomProperties() != nil }, + PreserveHistoricalTimes: true, + }) + + return r +} + +func (r *{{.EntityName}}{{.ArtifactName}}ArtifactRepositoryImpl) Save(artifact models.{{.EntityName}}{{.ArtifactName}}Artifact, parentResourceID *int32) (models.{{.EntityName}}{{.ArtifactName}}Artifact, error) { + config := r.GetConfig() + if artifact.GetTypeID() == nil { + if config.TypeID > 0 { + artifact.SetTypeID(config.TypeID) + } + } + + attr := artifact.GetAttributes() + if artifact.GetID() == nil && attr != nil && attr.Name != nil { + existing, err := r.lookupByName(*attr.Name) + if err != nil { + if !errors.Is(err, Err{{.EntityName}}{{.ArtifactName}}ArtifactNotFound) { + return nil, fmt.Errorf("error finding existing artifact named %s: %w", *attr.Name, err) + } + } else { + artifact.SetID(existing.ID) + 
		}
	}

	return r.GenericRepository.Save(artifact, parentResourceID)
}

// List returns a page of artifacts matching listOptions.
func (r *{{.EntityName}}{{.ArtifactName}}ArtifactRepositoryImpl) List(listOptions models.{{.EntityName}}{{.ArtifactName}}ArtifactListOptions) (*sharedmodels.ListWrapper[models.{{.EntityName}}{{.ArtifactName}}Artifact], error) {
	return r.GenericRepository.List(&listOptions)
}

// DeleteByParentID removes every artifact of this repository's type attributed
// to the given parent context, along with its properties and attribution rows.
// NOTE(review): the three deletes below are not wrapped in a transaction, so a
// failure part-way can leave orphaned rows — confirm whether that is acceptable.
func (r *{{.EntityName}}{{.ArtifactName}}ArtifactRepositoryImpl) DeleteByParentID(parentID int32) error {
	config := r.GetConfig()

	// Find all artifact IDs linked to this parent via Attribution
	var artifactIDs []int32
	err := config.DB.Model(&schema.Attribution{}).
		Select("artifact_id").
		Where("context_id = ?", parentID).
		Joins("JOIN "+utils.GetTableName(config.DB, &schema.Artifact{})+" ON "+utils.GetColumnRef(config.DB, &schema.Artifact{}, "id")+" = "+utils.GetColumnRef(config.DB, &schema.Attribution{}, "artifact_id")).
		Where(utils.GetColumnRef(config.DB, &schema.Artifact{}, "type_id")+" = ?", config.TypeID).
		Pluck("artifact_id", &artifactIDs).Error
	if err != nil {
		return fmt.Errorf("failed to find artifacts for parent %d: %w", parentID, err)
	}

	if len(artifactIDs) == 0 {
		return nil
	}

	// Delete properties
	if err := config.DB.Where("artifact_id IN ?", artifactIDs).Delete(&schema.ArtifactProperty{}).Error; err != nil {
		return fmt.Errorf("failed to delete artifact properties: %w", err)
	}

	// Delete attributions
	if err := config.DB.Where("artifact_id IN ?", artifactIDs).Delete(&schema.Attribution{}).Error; err != nil {
		return fmt.Errorf("failed to delete attributions: %w", err)
	}

	// Delete artifacts
	if err := config.DB.Where("id IN ?", artifactIDs).Delete(&schema.Artifact{}).Error; err != nil {
		return fmt.Errorf("failed to delete artifacts: %w", err)
	}

	return nil
}

// lookupByName fetches a single artifact row by exact name for this type,
// mapping gorm.ErrRecordNotFound onto the repository's NotFoundError sentinel.
func (r *{{.EntityName}}{{.ArtifactName}}ArtifactRepositoryImpl) lookupByName(name string) (*schema.Artifact, error) {
	var entity schema.Artifact
	config := r.GetConfig()

	if err := config.DB.Where("name = ? AND type_id = ?", name, config.TypeID).First(&entity).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, fmt.Errorf("%w: %v", config.NotFoundError, err)
		}
		return nil, fmt.Errorf("error getting %s by name: %w", config.EntityName, err)
	}

	return &entity, nil
}

// apply…ListFilters narrows a list query by name, external ID and parent context.
func apply{{.EntityName}}{{.ArtifactName}}ArtifactListFilters(query *gorm.DB, listOptions *models.{{.EntityName}}{{.ArtifactName}}ArtifactListOptions) *gorm.DB {
	if listOptions.Name != nil {
		// Matches the suffix after a ':' — presumably names are stored
		// prefixed as "<parent>:<name>"; confirm against the writer side.
		query = query.Where("name LIKE ?", fmt.Sprintf("%%:%s", *listOptions.Name))
	} else if listOptions.ExternalID != nil {
		query = query.Where("external_id = ?", listOptions.ExternalID)
	}

	if listOptions.ParentResourceID != nil {
		query = query.Joins(utils.BuildAttributionJoin(query)).
			Where(utils.GetColumnRef(query, &schema.Attribution{}, "context_id")+" = ?", listOptions.ParentResourceID)
	}

	return query
}

// map…ArtifactToSchema converts a model artifact into its Artifact table row.
func map{{.EntityName}}{{.ArtifactName}}ArtifactToSchema(artifact models.{{.EntityName}}{{.ArtifactName}}Artifact) schema.Artifact {
	if artifact == nil {
		return schema.Artifact{}
	}

	schemaArtifact := schema.Artifact{
		ID:     apiutils.ZeroIfNil(artifact.GetID()),
		TypeID: apiutils.ZeroIfNil(artifact.GetTypeID()),
	}

	if artifact.GetAttributes() != nil {
		schemaArtifact.Name = artifact.GetAttributes().Name
		schemaArtifact.ExternalID = artifact.GetAttributes().ExternalID
		schemaArtifact.CreateTimeSinceEpoch = apiutils.ZeroIfNil(artifact.GetAttributes().CreateTimeSinceEpoch)
		schemaArtifact.LastUpdateTimeSinceEpoch = apiutils.ZeroIfNil(artifact.GetAttributes().LastUpdateTimeSinceEpoch)
	}

	return schemaArtifact
}

// map…ArtifactToProperties flattens an artifact's attributes and custom
// properties into ArtifactProperty rows keyed by artifactID.
func map{{.EntityName}}{{.ArtifactName}}ArtifactToProperties(artifact models.{{.EntityName}}{{.ArtifactName}}Artifact, artifactID int32) []schema.ArtifactProperty {
	if artifact == nil {
		return []schema.ArtifactProperty{}
	}

	properties := []schema.ArtifactProperty{}

	// Map custom properties from attributes
	if artifact.GetAttributes() != nil {
		attr := artifact.GetAttributes()
{{.PropWriteStatements}}
	}

	if artifact.GetCustomProperties() != nil {
		for _, prop := range *artifact.GetCustomProperties() {
			properties = append(properties, service.MapPropertiesToArtifactProperty(prop, artifactID, true))
		}
	}

	return properties
}

// mapSchemaTo…Artifact reconstructs a model artifact from its table row and
// property rows, splitting well-known properties from custom ones.
func mapSchemaTo{{.EntityName}}{{.ArtifactName}}Artifact(artifact schema.Artifact, artProperties []schema.ArtifactProperty) models.{{.EntityName}}{{.ArtifactName}}Artifact {
{{.PropVarDecls}}

	for _, p := range artProperties {
		if !p.IsCustomProperty {
			switch p.Name {
{{.PropReadCases}}
			}
		}
	}

	result := models.{{.EntityName}}{{.ArtifactName}}ArtifactImpl{
		ID:     &artifact.ID,
		TypeID: &artifact.TypeID,
		Attributes: &models.{{.EntityName}}{{.ArtifactName}}ArtifactAttributes{
			Name:                     artifact.Name,
			ExternalID:               artifact.ExternalID,
			CreateTimeSinceEpoch:     &artifact.CreateTimeSinceEpoch,
			LastUpdateTimeSinceEpoch: &artifact.LastUpdateTimeSinceEpoch,
{{.PropAttrAssignments}}
		},
	}

	customProperties := []sharedmodels.Properties{}
	for _, prop := range artProperties {
		if prop.IsCustomProperty {
			customProperties = append(customProperties, service.MapArtifactPropertyToProperties(prop))
		}
	}
	result.CustomProperties = &customProperties

	return &result
}
diff --git a/cmd/catalog-gen/templates/service/filter_mappings.gotmpl b/cmd/catalog-gen/templates/service/filter_mappings.gotmpl
new file mode 100644
index 0000000000..115aa866e1
--- /dev/null
+++ b/cmd/catalog-gen/templates/service/filter_mappings.gotmpl
@@ -0,0 +1,62 @@
// Code generated by catalog-gen. DO NOT EDIT.
// To regenerate: catalog-gen generate
// Source: catalog.yaml

package service

import (
	"github.com/kubeflow/model-registry/internal/db/filter"

	"{{.Package}}/internal/db/models"
)

func init() {
	// Register {{.EntityName}} entity properties in the global filter property map.
	// This allows filterQuery validation and property resolution for {{.EntityName}} entities.
	filter.RestEntityPropertyMap[models.RestEntity{{.EntityName}}] = map[string]bool{
		// Built-in entity table properties
		"id": true, "name": true, "externalId": true,
		"createTimeSinceEpoch": true, "lastUpdateTimeSinceEpoch": true,
{{.PropertyRegistrations}}	}
}

// entityMappings implements filter.EntityMappingFunctions for {{.EntityName}} entities.
type entityMappings struct{}

// GetMLMDEntityType maps a REST entity type to its underlying MLMD entity type.
// {{.EntityName}} entities are stored as MLMD Contexts.
func (m *entityMappings) GetMLMDEntityType(t filter.RestEntityType) filter.EntityType {
	return filter.EntityTypeContext
}

// GetPropertyDefinitionForRestEntity returns the property definition for a REST entity property.
// It maps well-known properties to their storage locations and types.
func (m *entityMappings) GetPropertyDefinitionForRestEntity(t filter.RestEntityType, prop string) filter.PropertyDefinition {
	switch prop {
	// Entity table columns (Context table)
	case "id":
		return filter.PropertyDefinition{Location: filter.EntityTable, ValueType: "int_value", Column: "id"}
	case "name":
		return filter.PropertyDefinition{Location: filter.EntityTable, ValueType: "string_value", Column: "name"}
	case "externalId":
		return filter.PropertyDefinition{Location: filter.EntityTable, ValueType: "string_value", Column: "external_id"}
	case "createTimeSinceEpoch":
		return filter.PropertyDefinition{Location: filter.EntityTable, ValueType: "int_value", Column: "create_time_since_epoch"}
	case "lastUpdateTimeSinceEpoch":
		return filter.PropertyDefinition{Location: filter.EntityTable, ValueType: "int_value", Column: "last_update_time_since_epoch"}
{{.PropertyDefinitions}}
	default:
		// Unknown properties are treated as custom properties in the ContextProperty table
		return filter.PropertyDefinition{
			Location:  filter.Custom,
			ValueType: "string_value",
			Column:    prop,
		}
	}
}

// IsChildEntity returns true if the REST entity type uses prefixed names (parentId:name).
// Catalog entities are top-level and do not use prefixed names.
func (m *entityMappings) IsChildEntity(t filter.RestEntityType) bool {
	return false
}
diff --git a/cmd/catalog-gen/templates/service/repository.gotmpl b/cmd/catalog-gen/templates/service/repository.gotmpl
new file mode 100644
index 0000000000..4a7d4d73a8
--- /dev/null
+++ b/cmd/catalog-gen/templates/service/repository.gotmpl
@@ -0,0 +1,215 @@
package service

import (
	"errors"
	"fmt"

	"gorm.io/gorm"

	"{{.Package}}/internal/db/models"
	"github.com/kubeflow/model-registry/internal/db/schema"
	sharedmodels "github.com/kubeflow/model-registry/internal/db/models"
	"github.com/kubeflow/model-registry/internal/db/service"
)

// Err{{.EntityName}}NotFound is the sentinel error returned when lookups miss.
var Err{{.EntityName}}NotFound = errors.New("{{.EntityNameLower}} not found")

// {{.EntityName}}RepositoryImpl uses the shared GenericRepository for MLMD-based storage.
type {{.EntityName}}RepositoryImpl struct {
	*service.GenericRepository[models.{{.EntityName}}, schema.Context, schema.ContextProperty, *models.{{.EntityName}}ListOptions]
}

// New{{.EntityName}}Repository creates a new {{.EntityName}} repository.
func New{{.EntityName}}Repository(db *gorm.DB, typeID int32) models.{{.EntityName}}Repository {
	r := &{{.EntityName}}RepositoryImpl{}

	r.GenericRepository = service.NewGenericRepository(service.GenericRepositoryConfig[
		models.{{.EntityName}},
		schema.Context,
		schema.ContextProperty,
		*models.{{.EntityName}}ListOptions,
	]{
		DB:                      db,
		TypeID:                  typeID,
		EntityToSchema:          map{{.EntityName}}ToContext,
		SchemaToEntity:          mapContextTo{{.EntityName}},
		EntityToProperties:      map{{.EntityName}}ToContextProperties,
		NotFoundError:           Err{{.EntityName}}NotFound,
		EntityName:              "{{.EntityNameLower}}",
		PropertyFieldName:       "context_id",
		ApplyListFilters:        apply{{.EntityName}}ListFilters,
		IsNewEntity:             func(entity models.{{.EntityName}}) bool { return entity.GetID() == nil },
		HasCustomProperties:     func(entity models.{{.EntityName}}) bool { return entity.GetCustomProperties() != nil },
		PreserveHistoricalTimes: true,
		EntityMappingFuncs:      &entityMappings{},
	})

	return r
}

// Save saves or updates an entity. A new entity whose name already exists is
// converted into an update of the existing row (upsert-by-name).
func (r *{{.EntityName}}RepositoryImpl) Save(entity models.{{.EntityName}}) (models.{{.EntityName}}, error) {
	config := r.GetConfig()
	if entity.GetTypeID() == nil {
		if config.TypeID > 0 {
			entity.SetTypeID(config.TypeID)
		}
	}

	// Check for existing entity by name if this is a new entity
	attr := entity.GetAttributes()
	if entity.GetID() == nil && attr != nil && attr.Name != nil {
		existing, err := r.GenericRepository.GetByName(*attr.Name)
		if err == nil {
			entity.SetID(*existing.GetID())
		} else if !errors.Is(err, Err{{.EntityName}}NotFound) {
			return nil, fmt.Errorf("error finding existing entity: %w", err)
		}
	}

	return r.GenericRepository.Save(entity, nil)
}

// List returns entities matching the options.
func (r *{{.EntityName}}RepositoryImpl) List(options models.{{.EntityName}}ListOptions) (*sharedmodels.ListWrapper[models.{{.EntityName}}], error) {
	return r.GenericRepository.List(&options)
}

// DeleteBySource deletes all entities with the given source ID.
// NOTE(review): the raw SQL uses double-quoted identifiers — valid on
// PostgreSQL (and SQLite), not MySQL; confirm the supported backends.
func (r *{{.EntityName}}RepositoryImpl) DeleteBySource(sourceID string) error {
	config := r.GetConfig()

	return config.DB.Transaction(func(tx *gorm.DB) error {
		// Delete Context records where there's a ContextProperty with name='source_id' and matching value
		deleteContextQuery := `DELETE FROM "Context" WHERE id IN (
			SELECT "Context".id
			FROM "Context"
			INNER JOIN "ContextProperty" ON "Context".id="ContextProperty".context_id
				AND "ContextProperty".name='source_id'
				AND "ContextProperty".string_value=?
			WHERE "Context".type_id=?
		)`
		if err := tx.Exec(deleteContextQuery, sourceID, config.TypeID).Error; err != nil {
			return fmt.Errorf("error deleting {{.EntityNameLower}}s by source: %w", err)
		}
		return nil
	})
}

// DeleteByID deletes an entity by ID.
func (r *{{.EntityName}}RepositoryImpl) DeleteByID(id int32) error {
	config := r.GetConfig()
	return config.DB.Where("id = ? AND type_id = ?", id, config.TypeID).Delete(&schema.Context{}).Error
}

// GetDistinctSourceIDs retrieves all unique source_id values.
func (r *{{.EntityName}}RepositoryImpl) GetDistinctSourceIDs() ([]string, error) {
	config := r.GetConfig()
	var sourceIDs []string

	query := `SELECT DISTINCT string_value FROM "ContextProperty"
		WHERE name='source_id'
		AND context_id IN (SELECT id FROM "Context" WHERE type_id=?)`
	if err := config.DB.Raw(query, config.TypeID).Scan(&sourceIDs).Error; err != nil {
		return nil, fmt.Errorf("error getting distinct source IDs: %w", err)
	}
	return sourceIDs, nil
}

// map…ToContext converts a model entity into its Context table row.
func map{{.EntityName}}ToContext(entity models.{{.EntityName}}) schema.Context {
	ctx := schema.Context{}
	if entity.GetID() != nil {
		ctx.ID = *entity.GetID()
	}
	if entity.GetTypeID() != nil {
		ctx.TypeID = *entity.GetTypeID()
	}
	if attrs := entity.GetAttributes(); attrs != nil {
		if attrs.Name != nil {
			ctx.Name = *attrs.Name
		}
		if attrs.ExternalID != nil {
			ctx.ExternalID = attrs.ExternalID
		}
		if attrs.CreateTimeSinceEpoch != nil {
			ctx.CreateTimeSinceEpoch = *attrs.CreateTimeSinceEpoch
		}
		if attrs.LastUpdateTimeSinceEpoch != nil {
			ctx.LastUpdateTimeSinceEpoch = *attrs.LastUpdateTimeSinceEpoch
		}
	}
	return ctx
}

// mapContextTo… reconstructs a model entity from its Context row and properties.
func mapContextTo{{.EntityName}}(ctx schema.Context, props []schema.ContextProperty) models.{{.EntityName}} {
	// Convert schema properties to model properties and extract known attributes
	var modelProps []sharedmodels.Properties
{{.PropVarDecls}}
	for _, p := range props {
		switch p.Name {
{{.PropReadCases}}
		}
		modelProps = append(modelProps, service.MapContextPropertyToProperties(p))
	}

	entity := &models.{{.EntityName}}Impl{
		ID:     &ctx.ID,
		TypeID: &ctx.TypeID,
		Attributes: &models.{{.EntityName}}Attributes{
			Name:                     &ctx.Name,
			ExternalID:               ctx.ExternalID,
			CreateTimeSinceEpoch:     &ctx.CreateTimeSinceEpoch,
			LastUpdateTimeSinceEpoch: &ctx.LastUpdateTimeSinceEpoch,
{{.PropAttrAssignments}}
		},
		Properties: &modelProps,
	}

	return entity
}

// map…ToContextProperties flattens attributes and properties into
// ContextProperty rows keyed by entityID.
func map{{.EntityName}}ToContextProperties(entity models.{{.EntityName}}, entityID int32) []schema.ContextProperty {
	var props []schema.ContextProperty
{{if .PropWriteStatements}}	attrs := entity.GetAttributes()

{{.PropWriteStatements}}{{end}}	// Add other properties
	if entity.GetProperties() != nil {
		for _, p := range *entity.GetProperties() {
			props = append(props, service.MapPropertiesToContextProperty(p, entityID, p.IsCustomProperty))
		}
	}

	return props
}

// apply…ListFilters narrows the list query by name substring, external ID and
// source IDs (via a ContextProperty join).
func apply{{.EntityName}}ListFilters(db *gorm.DB, opts *models.{{.EntityName}}ListOptions) *gorm.DB {
	if opts == nil {
		return db
	}

	if opts.Name != nil && *opts.Name != "" {
		db = db.Where("name LIKE ?", "%"+*opts.Name+"%")
	}

	if opts.ExternalID != nil && *opts.ExternalID != "" {
		db = db.Where("external_id = ?", *opts.ExternalID)
	}

	// Filter by source IDs using context properties
	var nonEmptySourceIDs []string
	if opts.SourceIDs != nil {
		for _, sourceID := range *opts.SourceIDs {
			if sourceID != "" {
				nonEmptySourceIDs = append(nonEmptySourceIDs, sourceID)
			}
		}
	}

	if len(nonEmptySourceIDs) > 0 {
		db = db.Joins("JOIN \"ContextProperty\" cp ON \"Context\".id = cp.context_id").
			Where("cp.name = ? AND cp.string_value IN ?", "source_id", nonEmptySourceIDs)
	}

	return db
}
diff --git a/cmd/catalog-gen/templates/service/spec.gotmpl b/cmd/catalog-gen/templates/service/spec.gotmpl
new file mode 100644
index 0000000000..5f24100da1
--- /dev/null
+++ b/cmd/catalog-gen/templates/service/spec.gotmpl
@@ -0,0 +1,38 @@
// Code generated by catalog-gen. DO NOT EDIT.
// To regenerate: catalog-gen generate
// Source: catalog.yaml

package service

import (
	"{{.Package}}/internal/db/models"
	"github.com/kubeflow/model-registry/internal/datastore"
)

const (
	{{.EntityName}}TypeName = "kf.{{.EntityName}}"
{{.ArtifactConstants}})

// DatastoreSpec returns the datastore specification for this catalog.
// This defines the types and properties that will be stored in the MLMD database.
func DatastoreSpec() *datastore.Spec {
	return datastore.NewSpec().
		AddContext({{.EntityName}}TypeName, datastore.NewSpecType(New{{.EntityName}}Repository).
			AddString("source_id"){{if .PropertyDefs}}.
{{.PropertyDefs}}{{else}},{{end}}
		){{.ContextTrailingDot}}
{{.ArtifactSpecs}}}

// Services holds all repository instances for this catalog.
type Services struct {
	{{.EntityName}}Repository models.{{.EntityName}}Repository
{{.ArtifactServiceFields}}}

// NewServices creates a new Services instance from repository instances.
func NewServices(
	{{.EntityNameLower}}Repository models.{{.EntityName}}Repository,
{{.ArtifactServiceParams}}) Services {
	return Services{
		{{.EntityName}}Repository: {{.EntityNameLower}}Repository,
{{.ArtifactServiceAssignments}}	}
}
diff --git a/cmd/catalog-gen/types.go b/cmd/catalog-gen/types.go
new file mode 100644
index 0000000000..07f76e5f99
--- /dev/null
+++ b/cmd/catalog-gen/types.go
@@ -0,0 +1,54 @@
package main

// CatalogConfig is the configuration structure for a catalog.
type CatalogConfig struct {
	APIVersion string          `yaml:"apiVersion"`
	Kind       string          `yaml:"kind"`
	Metadata   CatalogMetadata `yaml:"metadata"`
	Spec       CatalogSpec     `yaml:"spec"`
}

// CatalogMetadata contains catalog metadata.
type CatalogMetadata struct {
	Name string `yaml:"name"`
}

// CatalogSpec contains the catalog specification.
type CatalogSpec struct {
	Package   string           `yaml:"package"`
	Entity    EntityConfig     `yaml:"entity"`
	Artifacts []ArtifactConfig `yaml:"artifacts,omitempty"`
	Providers []ProviderConfig `yaml:"providers,omitempty"`
	API       APIConfig        `yaml:"api"`
}

// EntityConfig defines the main entity type.
type EntityConfig struct {
	Name       string           `yaml:"name"`
	Properties []PropertyConfig `yaml:"properties,omitempty"`
}

// ArtifactConfig defines an artifact type linked to the entity.
type ArtifactConfig struct {
	Name       string           `yaml:"name"`
	Properties []PropertyConfig `yaml:"properties,omitempty"`
}

// PropertyConfig defines a property on an entity or artifact.
type PropertyConfig struct {
	Name     string          `yaml:"name"`
	Type     string          `yaml:"type"`
	Required bool            `yaml:"required,omitempty"`
	Items    *PropertyConfig `yaml:"items,omitempty"` // For array types
}

// ProviderConfig defines a data provider.
type ProviderConfig struct {
	Type string `yaml:"type"`
}

// APIConfig defines API settings.
type APIConfig struct {
	BasePath string `yaml:"basePath"`
	Port     int    `yaml:"port"`
}
diff --git a/cmd/catalog-server/main.go b/cmd/catalog-server/main.go
new file mode 100644
index 0000000000..7a71fe0945
--- /dev/null
+++ b/cmd/catalog-server/main.go
@@ -0,0 +1,189 @@
// Package main provides the unified catalog server entry point.
// This server hosts all registered catalog plugins under a single process.
package main

import (
	"context"
	"flag"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/golang/glog"
	"gorm.io/gorm"

	"github.com/kubeflow/model-registry/internal/datastore"
	"github.com/kubeflow/model-registry/internal/datastore/embedmd"
	"github.com/kubeflow/model-registry/internal/db"
	"github.com/kubeflow/model-registry/pkg/catalog/plugin"

	// Import plugins - their init() registers them
	_ "github.com/kubeflow/model-registry/catalog/plugins/model"
	// _ "github.com/kubeflow/model-registry/catalog/plugins/mcp" // generated via catalog-gen
	// _ "github.com/kubeflow/model-registry/catalog/plugins/datasets" // future
)

func main() {
	var (
		listenAddr   string
		sourcesPath  string
		databaseType string
		databaseDSN  string
	)

	flag.StringVar(&listenAddr, "listen", ":8080", "Address to listen on")
	flag.StringVar(&sourcesPath, "sources", "/config/sources.yaml", "Path to catalog sources config")
	flag.StringVar(&databaseType, "db-type", "postgres", "Database 
type (postgres or mysql)")
	flag.StringVar(&databaseDSN, "db-dsn", "", "Database connection string")

	// Let plugins register their custom flags before parsing
	plugin.RegisterAllFlags(flag.CommandLine)
	flag.Parse()

	// Initialize glog for backwards compatibility
	_ = flag.Set("logtostderr", "true")

	// Set up structured logger
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
		Level: slog.LevelInfo,
	}))
	slog.SetDefault(logger)

	logger.Info("starting catalog server",
		"listen", listenAddr,
		"sources", sourcesPath,
		"plugins", plugin.Names(),
	)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Handle shutdown signals
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		sig := <-sigCh
		logger.Info("received shutdown signal", "signal", sig)
		cancel()
	}()

	// Load config
	cfg, err := plugin.LoadConfig(sourcesPath)
	if err != nil {
		glog.Fatalf("Failed to load config: %v", err)
	}

	logger.Info("loaded config",
		"apiVersion", cfg.APIVersion,
		"kind", cfg.Kind,
		"catalogs", len(cfg.Catalogs),
	)

	// Setup database
	gormDB, err := setupDatabase(databaseType, databaseDSN)
	if err != nil {
		glog.Fatalf("Failed to connect to database: %v", err)
	}

	// Create and initialize server
	server := plugin.NewServer(cfg, []string{sourcesPath}, gormDB, logger)
	if err := server.Init(ctx); err != nil {
		glog.Fatalf("Failed to initialize plugins: %v", err)
	}

	// Mount routes and start
	router := server.MountRoutes()

	if err := server.Start(ctx); err != nil {
		glog.Fatalf("Failed to start plugins: %v", err)
	}

	logger.Info("catalog server ready",
		"listen", listenAddr,
		"plugins", plugin.Names(),
	)

	// Create HTTP server with graceful shutdown
	httpServer := &http.Server{
		Addr:    listenAddr,
		Handler: router,
	}

	// Start HTTP server in goroutine
	go func() {
		if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			glog.Fatalf("HTTP server error: %v", err)
		}
	}()

	// Wait for shutdown signal
	<-ctx.Done()

	logger.Info("shutting down...")

	// Graceful shutdown with timeout
	shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer shutdownCancel()

	if err := httpServer.Shutdown(shutdownCtx); err != nil {
		logger.Error("HTTP server shutdown error", "error", err)
	}

	if err := server.Stop(shutdownCtx); err != nil {
		logger.Error("plugin shutdown error", "error", err)
	}

	logger.Info("catalog server stopped")
}

// setupDatabase resolves the DSN/type (flags first, then DATABASE_DSN /
// DATABASE_TYPE env vars, defaulting the type to postgres), connects via the
// embedmd connector, and returns the underlying GORM handle.
func setupDatabase(dbType, dsn string) (*gorm.DB, error) {
	if dsn == "" {
		// Try to get from environment
		dsn = os.Getenv("DATABASE_DSN")
		if dsn == "" {
			return nil, fmt.Errorf("database DSN is required (use -db-dsn flag or DATABASE_DSN environment variable)")
		}
	}

	if dbType == "" {
		dbType = os.Getenv("DATABASE_TYPE")
		if dbType == "" {
			dbType = "postgres"
		}
	}

	// Create embedmd connector
	cfg := &embedmd.EmbedMDConfig{
		DatabaseType: dbType,
		DatabaseDSN:  dsn,
	}

	connector, err := datastore.NewConnector("embedmd", cfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create database connector: %w", err)
	}

	// Connect to initialize the database
	// We need a minimal spec just to establish the connection
	_, err = connector.Connect(datastore.NewSpec())
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	// Get the GORM DB from the db package
	dbConnector, ok := db.GetConnector()
	if !ok {
		return nil, fmt.Errorf("database connector not available")
	}

	gormDB, err := dbConnector.Connect()
	if err != nil {
		return nil, fmt.Errorf("failed to get GORM connection: %w", err)
	}

	return gormDB, nil
}
diff --git a/go.mod b/go.mod
index cb23d5ed94..828d9f6cfb 100644
--- a/go.mod
+++ b/go.mod
@@ -225,7 +225,7 @@ require (
 	golang.org/x/sys 
v0.38.0 // indirect
 	golang.org/x/text v0.31.0 // indirect
 	google.golang.org/genproto/googleapis/rpc v0.0.0-20250528174236-200df99c418a // indirect
-	gopkg.in/yaml.v3 v3.0.1 // indirect
+	gopkg.in/yaml.v3 v3.0.1
 )
 
 replace github.com/kubeflow/model-registry/pkg/openapi => ./pkg/openapi
diff --git a/internal/datastore/embedmd/service.go b/internal/datastore/embedmd/service.go
index a09de0e569..f91f9b6163 100644
--- a/internal/datastore/embedmd/service.go
+++ b/internal/datastore/embedmd/service.go
@@ -43,6 +43,10 @@ type EmbedMDConfig struct {
 	// DB is an already connected database instance that, if provided, will
 	// be used instead of making a new connection.
 	DB *gorm.DB
+
+	// SkipMigrations skips running database migrations during Connect.
+	// Use this when migrations have already been run (e.g., by the server at startup).
+	SkipMigrations bool
 }
 
 func (c *EmbedMDConfig) Validate() error {
@@ -97,7 +101,8 @@ func (c *EmbedMDConfig) Validate() error {
 }
 
 type EmbedMDService struct {
-	dbConnector db.Connector
+	dbConnector    db.Connector
+	skipMigrations bool
 }
 
 func NewEmbedMDService(cfg *EmbedMDConfig) (*EmbedMDService, error) {
@@ -116,7 +121,8 @@ func NewEmbedMDService(cfg *EmbedMDConfig) (*EmbedMDService, error) {
 	}
 
 	return &EmbedMDService{
-		dbConnector: dbConnector,
+		dbConnector:    dbConnector,
+		skipMigrations: cfg.SkipMigrations,
 	}, nil
 }
 
@@ -130,19 +136,21 @@ func (s *EmbedMDService) Connect(spec *datastore.Spec) (datastore.RepoSet, error
 
 	glog.Infof("Connected to EmbedMD service")
 
-	migrator, err := db.NewDBMigrator(connectedDB)
-	if err != nil {
-		return nil, err
-	}
+	if !s.skipMigrations {
+		migrator, err := db.NewDBMigrator(connectedDB)
+		if err != nil {
+			return nil, err
+		}
 
-	glog.Infof("Running migrations...")
+		glog.Infof("Running migrations...")
 
-	err = migrator.Migrate()
-	if err != nil {
-		return nil, err
-	}
+		err = migrator.Migrate()
+		if err != nil {
+			return nil, err
+		}
 
-	glog.Infof("Migrations completed")
+		glog.Infof("Migrations completed")
+	}
 
 	glog.Infof("Syncing types...")
 
 	err = s.syncTypes(connectedDB, spec)
diff --git a/pkg/catalog/filter.go b/pkg/catalog/filter.go
new file mode 100644
index 0000000000..6f854a7b59
--- /dev/null
+++ b/pkg/catalog/filter.go
@@ -0,0 +1,179 @@
package catalog

import (
	"fmt"
	"regexp"
	"strings"
)

// ItemFilter provides include/exclude pattern matching for item names.
// It uses glob-style patterns with '*' as the only supported wildcard.
type ItemFilter struct {
	included []*compiledPattern // item must match one of these (if non-empty)
	excluded []*compiledPattern // item must match none of these
}

// compiledPattern pairs the raw glob with its compiled regexp form.
type compiledPattern struct {
	raw string
	re  *regexp.Regexp
}

// newCompiledPattern trims and compiles one glob pattern; field/idx are used
// only to build a descriptive error message.
func newCompiledPattern(field string, idx int, raw string) (*compiledPattern, error) {
	value := strings.TrimSpace(raw)
	if value == "" {
		return nil, fmt.Errorf("%s[%d]: pattern cannot be empty", field, idx)
	}

	// Convert a simple glob (only supporting '*') into a regexp.
	var b strings.Builder
	b.WriteString("(?i)^") // case insensitive
	for _, r := range value {
		if r == '*' {
			b.WriteString(".*")
			continue
		}
		b.WriteString(regexp.QuoteMeta(string(r)))
	}
	b.WriteString("$")

	re, err := regexp.Compile(b.String())
	if err != nil {
		return nil, fmt.Errorf("%s[%d]: invalid pattern %q: %w", field, idx, value, err)
	}

	return &compiledPattern{
		raw: value,
		re:  re,
	}, nil
}

// compilePatterns compiles a whole list, failing on the first bad pattern.
func compilePatterns(field string, patterns []string) ([]*compiledPattern, error) {
	if len(patterns) == 0 {
		return nil, nil
	}

	compiled := make([]*compiledPattern, 0, len(patterns))
	for i, pattern := range patterns {
		cp, err := newCompiledPattern(field, i, pattern)
		if err != nil {
			return nil, err
		}
		compiled = append(compiled, cp)
	}
	return compiled, nil
}

// ValidatePatterns validates that the included and excluded patterns
// are valid (non-empty, compilable). This is useful for early validation
// at configuration load time without constructing the full ItemFilter.
+func ValidatePatterns(included, excluded []string) error { + if _, err := compilePatterns("included", included); err != nil { + return err + } + + if _, err := compilePatterns("excluded", excluded); err != nil { + return err + } + + return nil +} + +// NewItemFilter builds an ItemFilter from the provided include/exclude pattern lists. +// Patterns support glob-style wildcards ('*'). +// +// Include logic: +// - If included is non-empty, items must match at least one pattern to be allowed. +// - If included is empty, all items are allowed (subject to exclusions). +// +// Exclude logic: +// - Items matching any excluded pattern are rejected, even if they match an include. +// +// Returns nil if both lists are empty (no filtering needed). +func NewItemFilter(included, excluded []string) (*ItemFilter, error) { + if err := ValidatePatterns(included, excluded); err != nil { + return nil, err + } + + inc, err := compilePatterns("included", included) + if err != nil { + return nil, err + } + + exc, err := compilePatterns("excluded", excluded) + if err != nil { + return nil, err + } + + if len(inc) == 0 && len(exc) == 0 { + return nil, nil + } + + return &ItemFilter{ + included: inc, + excluded: exc, + }, nil +} + +// NewItemFilterFromSource creates an ItemFilter from a Source's configuration. +// Additional patterns can be appended via extraIncluded and extraExcluded. +func NewItemFilterFromSource(source *Source, extraIncluded, extraExcluded []string) (*ItemFilter, error) { + if source == nil { + return nil, fmt.Errorf("source cannot be nil when building filters") + } + + included := append([]string{}, source.IncludedItems...) + if len(extraIncluded) > 0 { + included = append(included, extraIncluded...) + } + + excluded := append([]string{}, source.ExcludedItems...) + if len(extraExcluded) > 0 { + excluded = append(excluded, extraExcluded...) 
+ } + + filter, err := NewItemFilter(included, excluded) + if err != nil { + return nil, fmt.Errorf("invalid include/exclude configuration for source %s: %w", source.ID, err) + } + + return filter, nil +} + +// Allows returns true if the provided item name passes the include/exclude rules. +// A nil filter allows everything. +func (f *ItemFilter) Allows(name string) bool { + if f == nil { + return true + } + + // Check include patterns first + if len(f.included) > 0 { + matched := false + for _, pattern := range f.included { + if pattern.re.MatchString(name) { + matched = true + break + } + } + if !matched { + return false + } + } + + // Check exclude patterns + for _, pattern := range f.excluded { + if pattern.re.MatchString(name) { + return false + } + } + + return true +} + +// HasPatterns returns true if the filter has any include or exclude patterns. +func (f *ItemFilter) HasPatterns() bool { + if f == nil { + return false + } + return len(f.included) > 0 || len(f.excluded) > 0 +} diff --git a/pkg/catalog/filter_test.go b/pkg/catalog/filter_test.go new file mode 100644 index 0000000000..84af950262 --- /dev/null +++ b/pkg/catalog/filter_test.go @@ -0,0 +1,206 @@ +package catalog + +import ( + "testing" +) + +func TestItemFilter(t *testing.T) { + tests := []struct { + name string + included []string + excluded []string + items map[string]bool // item -> expected result + }{ + { + name: "no patterns allows everything", + included: nil, + excluded: nil, + items: map[string]bool{ + "anything": true, + "foo": true, + }, + }, + { + name: "include pattern only", + included: []string{"foo*"}, + excluded: nil, + items: map[string]bool{ + "foo": true, + "foobar": true, + "bar": false, + "barfoo": false, + "FOO": true, // case insensitive + "FOOBAR": true, + }, + }, + { + name: "exclude pattern only", + included: nil, + excluded: []string{"*test*"}, + items: map[string]bool{ + "foo": true, + "test": false, + "testing": false, + "mytest": false, + "mytesting": false, + 
				"production": true,
			},
		},
		{
			name:     "include and exclude combined",
			included: []string{"model-*"},
			excluded: []string{"*-test"},
			items: map[string]bool{
				"model-a":      true,
				"model-b":      true,
				"model-test":   false, // excluded
				"other-model":  false, // not included
				"model-a-test": false, // excluded
			},
		},
		{
			name:     "exact match",
			included: []string{"mymodel"},
			excluded: nil,
			items: map[string]bool{
				"mymodel":    true,
				"mymodel2":   false,
				"themymodel": false,
				"MYMODEL":    true, // case insensitive
			},
		},
		{
			name:     "wildcard at start",
			included: []string{"*-v1"},
			excluded: nil,
			items: map[string]bool{
				"model-v1": true,
				"other-v1": true,
				"model-v2": false,
				"v1":       false, // doesn't end with -v1
				"-v1":      true,
			},
		},
		{
			name:     "multiple includes",
			included: []string{"foo*", "bar*"},
			excluded: nil,
			items: map[string]bool{
				"foo":   true,
				"fooX":  true,
				"bar":   true,
				"barY":  true,
				"baz":   false,
				"other": false,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			filter, err := NewItemFilter(tt.included, tt.excluded)
			if err != nil {
				t.Fatalf("NewItemFilter failed: %v", err)
			}

			for item, expected := range tt.items {
				result := filter.Allows(item)
				if result != expected {
					t.Errorf("Allows(%q) = %v, want %v", item, result, expected)
				}
			}
		})
	}
}

// nil receiver must be a no-op filter that allows everything.
func TestItemFilterNil(t *testing.T) {
	// nil filter allows everything
	var filter *ItemFilter
	if !filter.Allows("anything") {
		t.Error("nil filter should allow everything")
	}
}

// Empty pattern lists must yield a nil (pass-through) filter, not an empty one.
func TestItemFilterEmptyReturnNil(t *testing.T) {
	filter, err := NewItemFilter(nil, nil)
	if err != nil {
		t.Fatalf("NewItemFilter failed: %v", err)
	}
	if filter != nil {
		t.Error("Expected nil filter when no patterns provided")
	}
}

func TestValidatePatterns(t *testing.T) {
	// Valid patterns
	if err := ValidatePatterns([]string{"foo*"}, []string{"bar*"}); err != nil {
		t.Errorf("ValidatePatterns failed for valid patterns: %v", err)
	}

	// Empty pattern should fail
	if err := ValidatePatterns([]string{""}, nil); err == nil {
		t.Error("Expected error for empty pattern")
	}

	// Whitespace-only pattern should fail
	if err := ValidatePatterns([]string{" "}, nil); err == nil {
		t.Error("Expected error for whitespace-only pattern")
	}
}

func TestNewItemFilterFromSource(t *testing.T) {
	source := &Source{
		ID:            "test",
		IncludedItems: []string{"model-*"},
		ExcludedItems: []string{"*-test"},
	}

	filter, err := NewItemFilterFromSource(source, nil, nil)
	if err != nil {
		t.Fatalf("NewItemFilterFromSource failed: %v", err)
	}

	if !filter.Allows("model-a") {
		t.Error("Expected model-a to be allowed")
	}
	if filter.Allows("model-test") {
		t.Error("Expected model-test to be excluded")
	}
}

func TestNewItemFilterFromSourceWithExtras(t *testing.T) {
	source := &Source{
		ID:            "test",
		IncludedItems: []string{"model-*"},
	}

	// Add extra excluded patterns
	filter, err := NewItemFilterFromSource(source, nil, []string{"*-deprecated"})
	if err != nil {
		t.Fatalf("NewItemFilterFromSource failed: %v", err)
	}

	if !filter.Allows("model-a") {
		t.Error("Expected model-a to be allowed")
	}
	if filter.Allows("model-deprecated") {
		t.Error("Expected model-deprecated to be excluded by extra pattern")
	}
}

func TestItemFilterHasPatterns(t *testing.T) {
	var nilFilter *ItemFilter
	if nilFilter.HasPatterns() {
		t.Error("nil filter should not have patterns")
	}

	filter, _ := NewItemFilter([]string{"foo*"}, nil)
	if !filter.HasPatterns() {
		t.Error("filter with include patterns should have patterns")
	}

	filter, _ = NewItemFilter(nil, []string{"bar*"})
	if !filter.HasPatterns() {
		t.Error("filter with exclude patterns should have patterns")
	}
}
diff --git a/pkg/catalog/loader.go b/pkg/catalog/loader.go
new file mode 100644
index 0000000000..d977675ee2
--- /dev/null
+++ b/pkg/catalog/loader.go
@@ -0,0 +1,460 @@
package catalog

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"time"

	"k8s.io/apimachinery/pkg/util/yaml"
)

// LoaderEventHandler is called after each record is successfully processed.
// Use this to trigger side effects like cache invalidation or notifications.
type LoaderEventHandler[E any, A any] func(ctx context.Context, record Record[E, A]) error

// EntitySaver persists an entity and returns the saved entity (with ID populated).
type EntitySaver[E any] func(entity E) (E, error)

// ArtifactSaver persists an artifact associated with an entity.
type ArtifactSaver[A any] func(artifact A, entityID int32) error

// EntityIDGetter extracts the ID from an entity (returns nil if not set).
type EntityIDGetter[E any] func(entity E) *int32

// EntityNameGetter extracts the name from an entity for logging.
type EntityNameGetter[E any] func(entity E) string

// SourceStatusSaver persists the status of a source (available/error/disabled).
type SourceStatusSaver func(sourceID, status, errorMsg string)

// Source status constants persisted via SourceStatusSaver.
const (
	SourceStatusAvailable = "available"
	SourceStatusError     = "error"
	SourceStatusDisabled  = "disabled"
)

// LoaderConfig configures a Loader instance. The required callbacks are
// SaveEntity, SaveArtifact, and GetEntityID; callbacks documented as
// "Optional" may be left nil and the corresponding step is skipped.
type LoaderConfig[E any, A any] struct {
	// Paths are the config file paths to load sources from.
	Paths []string

	// ProviderRegistry contains registered provider implementations,
	// keyed by source type.
	ProviderRegistry *ProviderRegistry[E, A]

	// SaveEntity persists an entity to the database and must return the
	// saved entity so that GetEntityID can read the assigned ID.
	SaveEntity EntitySaver[E]

	// SaveArtifact persists an artifact associated with an entity.
	SaveArtifact ArtifactSaver[A]

	// GetEntityID extracts the ID from an entity after it has been saved.
	GetEntityID EntityIDGetter[E]

	// GetEntityName extracts a name from an entity for logging.
	// Optional - if nil, entities are logged with an empty name.
	GetEntityName EntityNameGetter[E]

	// SaveSourceStatus persists source status to the database.
	// Optional - if nil, status is not persisted.
	SaveSourceStatus SourceStatusSaver

	// DeleteArtifactsByEntity removes all artifacts for an entity before re-adding.
	// Optional - if nil, artifacts are not cleaned up before save.
	DeleteArtifactsByEntity func(entityID int32) error

	// DeleteEntitiesBySource removes all entities for a source ID.
	// Called when a source is removed or disabled.
	// Optional - cleanup is skipped when this or GetDistinctSourceIDs is nil.
	DeleteEntitiesBySource func(sourceID string) error

	// GetDistinctSourceIDs returns all source IDs that have entities in the DB.
	// Used for cleanup of orphaned sources.
	// Optional - cleanup is skipped when this or DeleteEntitiesBySource is nil.
	GetDistinctSourceIDs func() ([]string, error)

	// SetEntitySourceID sets the source ID on an entity before it is saved.
	// NOTE(review): the entity is passed by value, so the mutation is only
	// visible to the loader when E is a pointer type — confirm with callers.
	SetEntitySourceID func(entity E, sourceID string)

	// IsEntityNil checks if an entity is nil (for batch completion detection).
	// Optional - if nil, every received record is treated as a real entity.
	IsEntityNil func(entity E) bool

	// Logger for logging messages (optional, defaults to a no-op logger).
	Logger LoaderLogger
}

// LoaderLogger is an interface for logging within the loader.
type LoaderLogger interface {
	Infof(format string, args ...any)
	Errorf(format string, args ...any)
}

// noopLogger is a no-op logger implementation used when LoaderConfig.Logger is nil.
type noopLogger struct{}

func (noopLogger) Infof(format string, args ...any)  {}
func (noopLogger) Errorf(format string, args ...any) {}

// SourceConfig is the structure for catalog sources YAML files.
type SourceConfig struct {
	// Catalogs lists the configured source entries.
	Catalogs []SourceConfigEntry `json:"catalogs" yaml:"catalogs"`
}

// SourceConfigEntry is a single entry in the sources YAML file.
+type SourceConfigEntry struct { + ID string `json:"id" yaml:"id"` + Name string `json:"name,omitempty" yaml:"name,omitempty"` + Type string `json:"type" yaml:"type"` + Enabled *bool `json:"enabled,omitempty" yaml:"enabled,omitempty"` + Labels []string `json:"labels,omitempty" yaml:"labels,omitempty"` + Properties map[string]any `json:"properties,omitempty" yaml:"properties,omitempty"` + IncludedModels []string `json:"includedModels,omitempty" yaml:"includedModels,omitempty"` + ExcludedModels []string `json:"excludedModels,omitempty" yaml:"excludedModels,omitempty"` +} + +// ToSource converts a config entry to a Source. +func (e SourceConfigEntry) ToSource(origin string) Source { + return Source{ + ID: e.ID, + Name: e.Name, + Type: e.Type, + Enabled: e.Enabled, + Labels: e.Labels, + Properties: e.Properties, + IncludedItems: e.IncludedModels, + ExcludedItems: e.ExcludedModels, + Origin: origin, + } +} + +// Loader manages loading data from sources into repositories. +type Loader[E any, A any] struct { + config LoaderConfig[E, A] + + // Sources contains current source information loaded from the configuration files. + Sources *SourceCollection + + closersMu sync.Mutex + closer func() // cancels the current loading goroutines + handlers []LoaderEventHandler[E, A] + loadedSources map[string]bool // tracks which source IDs have been loaded + logger LoaderLogger +} + +// NewLoader creates a new Loader with the given configuration. +func NewLoader[E any, A any](config LoaderConfig[E, A]) *Loader[E, A] { + // Convert paths to absolute for consistent origin ordering. 
+ absPaths := make([]string, 0, len(config.Paths)) + for _, p := range config.Paths { + absPath, err := filepath.Abs(p) + if err != nil { + absPath = p + } + absPaths = append(absPaths, absPath) + } + + logger := config.Logger + if logger == nil { + logger = noopLogger{} + } + + return &Loader[E, A]{ + config: config, + Sources: NewSourceCollection(absPaths...), + loadedSources: map[string]bool{}, + logger: logger, + } +} + +// RegisterEventHandler adds a function that will be called for every +// successfully processed record. This should be called before Start. +func (l *Loader[E, A]) RegisterEventHandler(fn LoaderEventHandler[E, A]) { + l.handlers = append(l.handlers, fn) +} + +// Start processes the sources YAML files and loads data. +// Background goroutines will be stopped when the context is canceled. +func (l *Loader[E, A]) Start(ctx context.Context) error { + // Phase 1: Parse all config files and merge sources + for _, path := range l.config.Paths { + if err := l.parseAndMerge(path); err != nil { + return fmt.Errorf("%s: %w", path, err) + } + } + + // Delete entities from unknown or disabled sources + if err := l.removeEntitiesFromMissingSources(); err != nil { + return fmt.Errorf("failed to remove entities from missing sources: %w", err) + } + + // Phase 2: Load entities from merged sources + if err := l.loadAllEntities(ctx); err != nil { + return err + } + + // Phase 3: Watch config files for hot-reload + for _, path := range l.config.Paths { + watcher := NewFileWatcher(path, 5*time.Second) + changes := watcher.Watch(ctx) + go func(p string) { + for range changes { + l.logger.Infof("Config file changed: %s, reloading...", p) + if err := l.Reload(ctx); err != nil { + l.logger.Errorf("Failed to reload after config change %s: %v", p, err) + } + } + }(path) + } + + return nil +} + +// Stop gracefully shuts down the loader by canceling any background operations. 
+func (l *Loader[E, A]) Stop(ctx context.Context) error { + l.closersMu.Lock() + defer l.closersMu.Unlock() + if l.closer != nil { + l.closer() + l.closer = nil + } + return nil +} + +// Reload re-parses config files, cleans up missing sources, and reloads all entities. +func (l *Loader[E, A]) Reload(ctx context.Context) error { + for _, path := range l.config.Paths { + if err := l.parseAndMerge(path); err != nil { + l.logger.Errorf("failed to reload config %s: %v", path, err) + } + } + _ = l.removeEntitiesFromMissingSources() + return l.loadAllEntities(ctx) +} + +// parseAndMerge parses a config file and merges its sources into the collection. +func (l *Loader[E, A]) parseAndMerge(path string) error { + path, err := filepath.Abs(path) + if err != nil { + return fmt.Errorf("failed to get absolute path for %s: %v", path, err) + } + + config, err := l.readConfig(path) + if err != nil { + return err + } + + sources := make(map[string]Source, len(config.Catalogs)) + for _, entry := range config.Catalogs { + l.logger.Infof("reading config type %s...", entry.Type) + if entry.ID == "" { + return fmt.Errorf("invalid source: missing id") + } + if _, exists := sources[entry.ID]; exists { + return fmt.Errorf("invalid source: duplicate id %s", entry.ID) + } + + // Validate include/exclude patterns early + if err := ValidatePatterns(entry.IncludedModels, entry.ExcludedModels); err != nil { + return fmt.Errorf("invalid source %s: %w", entry.ID, err) + } + + sources[entry.ID] = entry.ToSource(path) + l.logger.Infof("loaded source %s of type %s", entry.ID, entry.Type) + } + + return l.Sources.Merge(path, sources) +} + +func (l *Loader[E, A]) readConfig(path string) (*SourceConfig, error) { + config := &SourceConfig{} + bytes, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + if err = yaml.UnmarshalStrict(bytes, config); err != nil { + return nil, err + } + + return config, nil +} + +// loadAllEntities loads entities from all merged sources. 
+func (l *Loader[E, A]) loadAllEntities(ctx context.Context) error { + l.loadedSources = map[string]bool{} + return l.updateDatabase(ctx) +} + +func (l *Loader[E, A]) updateDatabase(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + + l.closersMu.Lock() + if l.closer != nil { + l.closer() + } + l.closer = cancel + l.closersMu.Unlock() + + records := l.readProviderRecords(ctx) + + go func() { + for record := range records { + if l.config.IsEntityNil != nil && l.config.IsEntityNil(record.Entity) { + continue + } + + name := "" + if l.config.GetEntityName != nil { + name = l.config.GetEntityName(record.Entity) + } + + l.logger.Infof("Loading entity %s with %d artifact(s)", name, len(record.Artifacts)) + + entity, err := l.config.SaveEntity(record.Entity) + if err != nil { + l.logger.Errorf("%s: unable to save: %v", name, err) + continue + } + + entityID := l.config.GetEntityID(entity) + if entityID == nil { + l.logger.Errorf("%s: entity has no ID after save", name) + continue + } + + // Remove old artifacts before adding new ones + if l.config.DeleteArtifactsByEntity != nil { + if err := l.config.DeleteArtifactsByEntity(*entityID); err != nil { + l.logger.Errorf("%s: unable to remove old artifacts: %v", name, err) + } + } + + // Save new artifacts + for i, artifact := range record.Artifacts { + if err := l.config.SaveArtifact(artifact, *entityID); err != nil { + l.logger.Errorf("%s, artifact %d: %v", name, i, err) + } + } + + // Call event handlers + for _, handler := range l.handlers { + if err := handler(ctx, record); err != nil { + l.logger.Errorf("%s: event handler error: %v", name, err) + } + } + } + }() + + return nil +} + +// readProviderRecords calls the provider for every merged source that hasn't +// been loaded yet, and merges the returned channels together. 
+func (l *Loader[E, A]) readProviderRecords(ctx context.Context) <-chan Record[E, A] { + ch := make(chan Record[E, A]) + var wg sync.WaitGroup + + mergedSources := l.Sources.AllSources() + + for _, source := range mergedSources { + // Skip disabled sources + if !source.IsEnabled() { + if l.config.SaveSourceStatus != nil { + l.config.SaveSourceStatus(source.ID, SourceStatusDisabled, "") + } + continue + } + + // Skip already loaded sources + if l.loadedSources[source.ID] { + continue + } + + if source.Type == "" { + l.logger.Errorf("source %s has no type defined, skipping", source.ID) + if l.config.SaveSourceStatus != nil { + l.config.SaveSourceStatus(source.ID, SourceStatusError, "source has no type defined") + } + continue + } + + l.loadedSources[source.ID] = true + + l.logger.Infof("Reading entities from %s source %s", source.Type, source.ID) + + providerFunc, ok := l.config.ProviderRegistry.Get(source.Type) + if !ok { + l.logger.Errorf("provider type %s not registered", source.Type) + if l.config.SaveSourceStatus != nil { + l.config.SaveSourceStatus(source.ID, SourceStatusError, fmt.Sprintf("provider type %q not registered", source.Type)) + } + continue + } + + sourceDir := filepath.Dir(source.Origin) + sourceCopy := source // capture for goroutine + + records, err := providerFunc(ctx, &sourceCopy, sourceDir) + if err != nil { + l.logger.Errorf("error reading provider type %s with id %s: %v", source.Type, source.ID, err) + if l.config.SaveSourceStatus != nil { + l.config.SaveSourceStatus(source.ID, SourceStatusError, err.Error()) + } + continue + } + + wg.Add(1) + go func(sourceID string) { + defer wg.Done() + + for record := range records { + // Set source ID on entity + if l.config.SetEntitySourceID != nil && !l.config.IsEntityNil(record.Entity) { + l.config.SetEntitySourceID(record.Entity, sourceID) + } + ch <- record + } + + // Mark source as available + if l.config.SaveSourceStatus != nil && ctx.Err() == nil { + l.config.SaveSourceStatus(sourceID, 
SourceStatusAvailable, "") + } + }(source.ID) + } + + go func() { + defer close(ch) + wg.Wait() + }() + + return ch +} + +func (l *Loader[E, A]) removeEntitiesFromMissingSources() error { + if l.config.DeleteEntitiesBySource == nil || l.config.GetDistinctSourceIDs == nil { + return nil + } + + enabledSourceIDs := make(map[string]bool) + for id, source := range l.Sources.AllSources() { + if source.IsEnabled() { + enabledSourceIDs[id] = true + } + } + + existingSourceIDs, err := l.config.GetDistinctSourceIDs() + if err != nil { + return fmt.Errorf("unable to retrieve existing source IDs: %w", err) + } + + for _, oldSource := range existingSourceIDs { + if !enabledSourceIDs[oldSource] { + l.logger.Infof("Removing entities from source %s", oldSource) + if err := l.config.DeleteEntitiesBySource(oldSource); err != nil { + return fmt.Errorf("unable to remove entities from source %q: %w", oldSource, err) + } + } + } + + return nil +} diff --git a/pkg/catalog/plugin/config.go b/pkg/catalog/plugin/config.go new file mode 100644 index 0000000000..9b45b5c8ea --- /dev/null +++ b/pkg/catalog/plugin/config.go @@ -0,0 +1,249 @@ +package plugin + +import ( + "fmt" + "maps" + "os" + + "k8s.io/apimachinery/pkg/util/yaml" +) + +// CatalogSourcesConfig is the root configuration structure for multi-catalog sources.yaml. +type CatalogSourcesConfig struct { + // APIVersion identifies the config format version (e.g., "catalog/v1alpha1"). + APIVersion string `json:"apiVersion" yaml:"apiVersion"` + + // Kind identifies the config type (e.g., "CatalogSources"). + Kind string `json:"kind" yaml:"kind"` + + // Catalogs maps plugin names to their configurations. + // The key is the plugin name (e.g., "models", "datasets"). + Catalogs map[string]CatalogSection `json:"catalogs" yaml:"catalogs"` +} + +// CatalogSection contains configuration for a single catalog plugin. +type CatalogSection struct { + // Sources is the list of data sources for this catalog. 
+ Sources []SourceConfig `json:"sources" yaml:"sources"` + + // Labels defines custom labels available in this catalog. + Labels []map[string]any `json:"labels,omitempty" yaml:"labels,omitempty"` + + // NamedQueries defines preset filter queries. + NamedQueries map[string]map[string]FieldFilter `json:"namedQueries,omitempty" yaml:"namedQueries,omitempty"` +} + +// SourceConfig represents a single data source configuration. +// This is a unified structure that works across all catalog types. +type SourceConfig struct { + // ID is the unique identifier for this source. + ID string `json:"id" yaml:"id"` + + // Name is the human-readable display name. + Name string `json:"name" yaml:"name"` + + // Type identifies the provider type (e.g., "yaml", "http", "hf"). + Type string `json:"type" yaml:"type"` + + // Enabled indicates whether this source should be loaded. + Enabled *bool `json:"enabled,omitempty" yaml:"enabled,omitempty"` + + // Labels are tags for filtering and categorization. + Labels []string `json:"labels,omitempty" yaml:"labels,omitempty"` + + // Properties contains provider-specific configuration. + Properties map[string]any `json:"properties,omitempty" yaml:"properties,omitempty"` + + // IncludedItems are glob patterns for items to include. + IncludedItems []string `json:"includedItems,omitempty" yaml:"includedItems,omitempty"` + + // ExcludedItems are glob patterns for items to exclude. + ExcludedItems []string `json:"excludedItems,omitempty" yaml:"excludedItems,omitempty"` + + // Origin is set programmatically to the config file path. + Origin string `json:"-" yaml:"-"` +} + +// FieldFilter represents a filter condition for named queries. +type FieldFilter struct { + Operator string `json:"operator" yaml:"operator"` + Value any `json:"value" yaml:"value"` +} + +// IsEnabled returns true if this source is enabled (defaults to true if nil). 
+func (s SourceConfig) IsEnabled() bool { + return s.Enabled == nil || *s.Enabled +} + +// LoadConfig loads a CatalogSourcesConfig from a YAML file. +func LoadConfig(path string) (*CatalogSourcesConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read config file %s: %w", path, err) + } + + return ParseConfig(data, path) +} + +// ParseConfig parses a CatalogSourcesConfig from YAML bytes. +// The origin parameter is used to set the Origin field on sources. +func ParseConfig(data []byte, origin string) (*CatalogSourcesConfig, error) { + cfg := &CatalogSourcesConfig{} + if err := yaml.UnmarshalStrict(data, cfg); err != nil { + return nil, fmt.Errorf("failed to parse config: %w", err) + } + + // Set origin on all sources + for catalogName := range cfg.Catalogs { + section := cfg.Catalogs[catalogName] + for i := range section.Sources { + section.Sources[i].Origin = origin + } + cfg.Catalogs[catalogName] = section + } + + return cfg, nil +} + +// LoadConfigs loads and merges multiple config files. +// Later files take precedence over earlier ones for overlapping sources. +func LoadConfigs(paths []string) (*CatalogSourcesConfig, error) { + if len(paths) == 0 { + return &CatalogSourcesConfig{ + Catalogs: make(map[string]CatalogSection), + }, nil + } + + // Load first config as base + result, err := LoadConfig(paths[0]) + if err != nil { + return nil, err + } + + // Merge subsequent configs + for _, path := range paths[1:] { + cfg, err := LoadConfig(path) + if err != nil { + return nil, err + } + + result = MergeConfigs(result, cfg) + } + + return result, nil +} + +// MergeConfigs merges two configs, with override taking precedence. 
+func MergeConfigs(base, override *CatalogSourcesConfig) *CatalogSourcesConfig { + result := &CatalogSourcesConfig{ + APIVersion: override.APIVersion, + Kind: override.Kind, + Catalogs: make(map[string]CatalogSection), + } + + if result.APIVersion == "" { + result.APIVersion = base.APIVersion + } + if result.Kind == "" { + result.Kind = base.Kind + } + + // Copy base catalogs + maps.Copy(result.Catalogs, base.Catalogs) + + // Merge override catalogs + for name, overrideSection := range override.Catalogs { + if baseSection, exists := result.Catalogs[name]; exists { + result.Catalogs[name] = mergeCatalogSections(baseSection, overrideSection) + } else { + result.Catalogs[name] = overrideSection + } + } + + return result +} + +// mergeCatalogSections merges two CatalogSections. +func mergeCatalogSections(base, override CatalogSection) CatalogSection { + result := CatalogSection{ + Sources: make([]SourceConfig, 0), + Labels: override.Labels, + NamedQueries: make(map[string]map[string]FieldFilter), + } + + if result.Labels == nil { + result.Labels = base.Labels + } + + // Build source map from base + sourceMap := make(map[string]SourceConfig) + for _, s := range base.Sources { + sourceMap[s.ID] = s + } + + // Merge override sources + for _, s := range override.Sources { + if existing, ok := sourceMap[s.ID]; ok { + sourceMap[s.ID] = mergeSourceConfigs(existing, s) + } else { + sourceMap[s.ID] = s + } + } + + // Convert back to slice + for _, s := range sourceMap { + result.Sources = append(result.Sources, s) + } + + // Merge named queries + maps.Copy(result.NamedQueries, base.NamedQueries) + for name, filters := range override.NamedQueries { + if existing, ok := result.NamedQueries[name]; ok { + // Merge field filters + maps.Copy(existing, filters) + result.NamedQueries[name] = existing + } else { + result.NamedQueries[name] = filters + } + } + + return result +} + +// mergeSourceConfigs merges two SourceConfigs with field-level merging. 
+func mergeSourceConfigs(base, override SourceConfig) SourceConfig { + result := base + + result.ID = override.ID + + if override.Name != "" { + result.Name = override.Name + } + + if override.Type != "" { + result.Type = override.Type + } + + if override.Enabled != nil { + result.Enabled = override.Enabled + } + + if override.Labels != nil { + result.Labels = override.Labels + } + + if override.Properties != nil { + result.Properties = override.Properties + result.Origin = override.Origin + } + + if override.IncludedItems != nil { + result.IncludedItems = override.IncludedItems + } + + if override.ExcludedItems != nil { + result.ExcludedItems = override.ExcludedItems + } + + return result +} diff --git a/pkg/catalog/plugin/config_test.go b/pkg/catalog/plugin/config_test.go new file mode 100644 index 0000000000..85cb58dc3b --- /dev/null +++ b/pkg/catalog/plugin/config_test.go @@ -0,0 +1,172 @@ +package plugin + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseConfig(t *testing.T) { + yaml := ` +apiVersion: catalog/v1alpha1 +kind: CatalogSources +catalogs: + models: + sources: + - id: "source-1" + name: "Source One" + type: "yaml" + enabled: true + properties: + yamlCatalogPath: "./data/models.yaml" + - id: "source-2" + name: "Source Two" + type: "hf" + properties: + allowedOrganization: "redhat" + labels: + - name: "production" + color: "green" + namedQueries: + large-models: + parameters: + operator: "gt" + value: 7000000000 + datasets: + sources: + - id: "internal-datasets" + type: "yaml" + properties: + yamlCatalogPath: "./data/datasets.yaml" +` + + cfg, err := ParseConfig([]byte(yaml), "/test/path/sources.yaml") + require.NoError(t, err) + + assert.Equal(t, "catalog/v1alpha1", cfg.APIVersion) + assert.Equal(t, "CatalogSources", cfg.Kind) + + // Check models catalog + models, ok := cfg.Catalogs["models"] + assert.True(t, ok) + assert.Equal(t, 2, 
len(models.Sources)) + assert.Equal(t, "source-1", models.Sources[0].ID) + assert.Equal(t, "Source One", models.Sources[0].Name) + assert.Equal(t, "yaml", models.Sources[0].Type) + assert.True(t, *models.Sources[0].Enabled) + assert.Equal(t, "/test/path/sources.yaml", models.Sources[0].Origin) + + assert.Equal(t, 1, len(models.Labels)) + assert.Equal(t, "production", models.Labels[0]["name"]) + + assert.NotNil(t, models.NamedQueries) + assert.Contains(t, models.NamedQueries, "large-models") + + // Check datasets catalog + datasets, ok := cfg.Catalogs["datasets"] + assert.True(t, ok) + assert.Equal(t, 1, len(datasets.Sources)) +} + +func TestLoadConfig(t *testing.T) { + // Create a temp config file + yaml := ` +apiVersion: catalog/v1alpha1 +kind: CatalogSources +catalogs: + models: + sources: + - id: "test-source" + name: "Test Source" + type: "yaml" +` + + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "sources.yaml") + err := os.WriteFile(configPath, []byte(yaml), 0644) + require.NoError(t, err) + + cfg, err := LoadConfig(configPath) + require.NoError(t, err) + + assert.Equal(t, "catalog/v1alpha1", cfg.APIVersion) + assert.Contains(t, cfg.Catalogs, "models") + assert.Equal(t, configPath, cfg.Catalogs["models"].Sources[0].Origin) +} + +func TestMergeConfigs(t *testing.T) { + base := &CatalogSourcesConfig{ + APIVersion: "v1", + Kind: "CatalogSources", + Catalogs: map[string]CatalogSection{ + "models": { + Sources: []SourceConfig{ + {ID: "source-1", Name: "Base Source", Type: "yaml", Enabled: boolPtr(false)}, + }, + }, + }, + } + + override := &CatalogSourcesConfig{ + Catalogs: map[string]CatalogSection{ + "models": { + Sources: []SourceConfig{ + {ID: "source-1", Enabled: boolPtr(true)}, // Enable the source + {ID: "source-2", Name: "New Source", Type: "hf"}, + }, + }, + "datasets": { + Sources: []SourceConfig{ + {ID: "ds-1", Name: "Dataset Source", Type: "yaml"}, + }, + }, + }, + } + + result := MergeConfigs(base, override) + + assert.Equal(t, "v1", 
result.APIVersion) + + // Models should have merged sources + models := result.Catalogs["models"] + assert.Equal(t, 2, len(models.Sources)) + + // Find source-1 and verify it was merged + var source1 *SourceConfig + for i := range models.Sources { + if models.Sources[i].ID == "source-1" { + source1 = &models.Sources[i] + break + } + } + require.NotNil(t, source1) + assert.Equal(t, "Base Source", source1.Name) // Inherited from base + assert.Equal(t, "yaml", source1.Type) // Inherited from base + assert.True(t, *source1.Enabled) // Overridden + + // Datasets should be added + datasets := result.Catalogs["datasets"] + assert.Equal(t, 1, len(datasets.Sources)) + assert.Equal(t, "ds-1", datasets.Sources[0].ID) +} + +func TestSourceConfigIsEnabled(t *testing.T) { + // Default (nil) should be enabled + s := SourceConfig{} + assert.True(t, s.IsEnabled()) + + // Explicitly enabled + s.Enabled = boolPtr(true) + assert.True(t, s.IsEnabled()) + + // Explicitly disabled + s.Enabled = boolPtr(false) + assert.False(t, s.IsEnabled()) +} + +func boolPtr(b bool) *bool { + return &b +} diff --git a/pkg/catalog/plugin/plugin.go b/pkg/catalog/plugin/plugin.go new file mode 100644 index 0000000000..52f3128bc8 --- /dev/null +++ b/pkg/catalog/plugin/plugin.go @@ -0,0 +1,119 @@ +// Package plugin provides a plugin-based architecture for catalog services. +// Catalog types (models, datasets, etc.) register as plugins via init() and +// are mounted under a unified HTTP server. +package plugin + +import ( + "context" + "flag" + "log/slog" + + "github.com/go-chi/chi/v5" + "gorm.io/gorm" +) + +// CatalogPlugin defines the interface that all catalog plugins must implement. +// Plugins register themselves via init() using the Register function. +type CatalogPlugin interface { + // Identity returns the plugin name (e.g., "models", "datasets"). + // This name is used for routing and configuration lookup. + Name() string + + // Version returns the API version (e.g., "v1alpha1"). 
+ Version() string + + // Description returns a human-readable description of the plugin. + Description() string + + // Init initializes the plugin with its configuration. + // Called once during server startup before Start. + Init(ctx context.Context, cfg Config) error + + // Start begins background operations (hot-reload, watchers, etc.). + // Called after Init and after database migrations. + Start(ctx context.Context) error + + // Stop gracefully shuts down the plugin. + // Called during server shutdown. + Stop(ctx context.Context) error + + // Healthy returns true if the plugin is functioning correctly. + // Used for health check endpoints. + Healthy() bool + + // RegisterRoutes mounts the plugin's HTTP routes on the provided router. + // The router is already scoped to the plugin's base path. + RegisterRoutes(router chi.Router) error + + // Migrations returns database migrations for this plugin. + // Migrations are applied in order during server initialization. + Migrations() []Migration +} + +// BasePathProvider is an optional interface that plugins can implement +// to specify their own API base path. If not implemented, the server +// computes it as /api/{name}_catalog/{version}. +type BasePathProvider interface { + BasePath() string +} + +// SourceKeyProvider is an optional interface that plugins can implement +// to specify which key in the sources.yaml "catalogs" map they respond to. +// If not implemented, the plugin name is used as the config key. +// This allows the plugin name and config key to differ (e.g., plugin "model" +// can read from the "models" config section). +type SourceKeyProvider interface { + SourceKey() string +} + +// CatalogLoader defines the interface for data loading strategies. +// The core Loader[E, A] implements this by default. Plugins can +// register multiple loaders (e.g., core + custom). +type CatalogLoader interface { + // Start begins loading data and sets up any watchers/background operations. 
+ Start(ctx context.Context) error + + // Stop gracefully shuts down the loader. + Stop(ctx context.Context) error +} + +// FlagProvider is an optional interface that plugins can implement +// to register custom CLI flags before flag parsing. +type FlagProvider interface { + // RegisterFlags registers custom CLI flags for this plugin. + // Called before flag.Parse() during server startup. + RegisterFlags(fs *flag.FlagSet) +} + +// Migration represents a database migration for a plugin. +type Migration struct { + // Version is a unique identifier for this migration (e.g., "001", "20240101_initial"). + Version string + + // Description provides a human-readable description of what this migration does. + Description string + + // Up applies the migration. + Up func(db *gorm.DB) error + + // Down reverts the migration. + Down func(db *gorm.DB) error +} + +// Config is passed to each plugin during Init. +type Config struct { + // Section contains the plugin-specific configuration from sources.yaml. + Section CatalogSection + + // DB is the shared database connection. + DB *gorm.DB + + // Logger is a namespaced logger for this plugin. + Logger *slog.Logger + + // BasePath is the API base path for this plugin (e.g., "/api/models_catalog/v1alpha1"). + BasePath string + + // ConfigPaths are the paths to all sources.yaml files being used. + ConfigPaths []string +} diff --git a/pkg/catalog/plugin/registry.go b/pkg/catalog/plugin/registry.go new file mode 100644 index 0000000000..54da8c329d --- /dev/null +++ b/pkg/catalog/plugin/registry.go @@ -0,0 +1,93 @@ +package plugin + +import ( + "flag" + "fmt" + "sync" +) + +// globalRegistry is the singleton registry for all catalog plugins. +var globalRegistry = &Registry{ + plugins: make(map[string]CatalogPlugin), +} + +// Registry holds all registered catalog plugins. 
+type Registry struct { + mu sync.RWMutex + plugins map[string]CatalogPlugin + order []string // preserves registration order +} + +// Register adds a plugin to the global registry. +// This is typically called from a plugin's init() function. +// Panics if a plugin with the same name is already registered. +func Register(p CatalogPlugin) { + globalRegistry.mu.Lock() + defer globalRegistry.mu.Unlock() + + name := p.Name() + if _, exists := globalRegistry.plugins[name]; exists { + panic(fmt.Sprintf("plugin %q already registered", name)) + } + + globalRegistry.plugins[name] = p + globalRegistry.order = append(globalRegistry.order, name) +} + +// All returns all registered plugins in registration order. +func All() []CatalogPlugin { + globalRegistry.mu.RLock() + defer globalRegistry.mu.RUnlock() + + result := make([]CatalogPlugin, 0, len(globalRegistry.order)) + for _, name := range globalRegistry.order { + result = append(result, globalRegistry.plugins[name]) + } + return result +} + +// Get returns a plugin by name, or nil if not found. +func Get(name string) (CatalogPlugin, bool) { + globalRegistry.mu.RLock() + defer globalRegistry.mu.RUnlock() + + p, ok := globalRegistry.plugins[name] + return p, ok +} + +// Names returns all registered plugin names in registration order. +func Names() []string { + globalRegistry.mu.RLock() + defer globalRegistry.mu.RUnlock() + + result := make([]string, len(globalRegistry.order)) + copy(result, globalRegistry.order) + return result +} + +// Count returns the number of registered plugins. +func Count() int { + globalRegistry.mu.RLock() + defer globalRegistry.mu.RUnlock() + + return len(globalRegistry.plugins) +} + +// RegisterAllFlags iterates all registered plugins and calls RegisterFlags +// on those implementing FlagProvider. Call this before flag.Parse(). 
+func RegisterAllFlags(fs *flag.FlagSet) { + for _, p := range All() { + if fp, ok := p.(FlagProvider); ok { + fp.RegisterFlags(fs) + } + } +} + +// Reset clears the global registry. For testing only. +func Reset() { + globalRegistry.mu.Lock() + defer globalRegistry.mu.Unlock() + + globalRegistry.plugins = make(map[string]CatalogPlugin) + globalRegistry.order = nil +} diff --git a/pkg/catalog/plugin/registry_test.go b/pkg/catalog/plugin/registry_test.go new file mode 100644 index 0000000000..6d04623fe2 --- /dev/null +++ b/pkg/catalog/plugin/registry_test.go @@ -0,0 +1,110 @@ +package plugin + +import ( + "context" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" +) + +// mockPlugin is a minimal CatalogPlugin implementation for testing. +type mockPlugin struct { + name string + version string + description string + healthy bool +} + +func (p *mockPlugin) Name() string { return p.name } +func (p *mockPlugin) Version() string { return p.version } +func (p *mockPlugin) Description() string { return p.description } +func (p *mockPlugin) Init(ctx context.Context, cfg Config) error { return nil } +func (p *mockPlugin) Start(ctx context.Context) error { return nil } +func (p *mockPlugin) Stop(ctx context.Context) error { return nil } +func (p *mockPlugin) Healthy() bool { return p.healthy } +func (p *mockPlugin) RegisterRoutes(router chi.Router) error { return nil } +func (p *mockPlugin) Migrations() []Migration { return nil } + +func TestRegister(t *testing.T) { + Reset() // Clear any existing plugins + + plugin := &mockPlugin{ + name: "test-plugin", + version: "v1", + description: "Test plugin", + healthy: true, + } + + Register(plugin) + + // Verify plugin was registered + names := Names() + assert.Equal(t, 1, len(names)) + assert.Equal(t, "test-plugin", names[0]) + + // Verify Get works + p, ok := Get("test-plugin") + assert.True(t, ok) + assert.Equal(t, "test-plugin", p.Name()) + assert.Equal(t, "v1", p.Version()) + + // Verify 
non-existent plugin + _, ok = Get("non-existent") + assert.False(t, ok) + + Reset() +} + +func TestRegisterDuplicate(t *testing.T) { + Reset() + + plugin1 := &mockPlugin{name: "duplicate"} + plugin2 := &mockPlugin{name: "duplicate"} + + Register(plugin1) + + // Should panic on duplicate registration + assert.Panics(t, func() { + Register(plugin2) + }) + + Reset() +} + +func TestAll(t *testing.T) { + Reset() + + plugin1 := &mockPlugin{name: "plugin-a"} + plugin2 := &mockPlugin{name: "plugin-b"} + plugin3 := &mockPlugin{name: "plugin-c"} + + Register(plugin1) + Register(plugin2) + Register(plugin3) + + all := All() + assert.Equal(t, 3, len(all)) + + // Verify registration order is preserved + assert.Equal(t, "plugin-a", all[0].Name()) + assert.Equal(t, "plugin-b", all[1].Name()) + assert.Equal(t, "plugin-c", all[2].Name()) + + Reset() +} + +func TestCount(t *testing.T) { + Reset() + + assert.Equal(t, 0, Count()) + + Register(&mockPlugin{name: "p1"}) + assert.Equal(t, 1, Count()) + + Register(&mockPlugin{name: "p2"}) + assert.Equal(t, 2, Count()) + + Reset() + assert.Equal(t, 0, Count()) +} diff --git a/pkg/catalog/plugin/server.go b/pkg/catalog/plugin/server.go new file mode 100644 index 0000000000..78e8043d95 --- /dev/null +++ b/pkg/catalog/plugin/server.go @@ -0,0 +1,274 @@ +package plugin + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "sync" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/cors" + "gorm.io/gorm" +) + +// Server manages the lifecycle of catalog plugins and provides a unified HTTP server. +type Server struct { + router chi.Router + db *gorm.DB + config *CatalogSourcesConfig + configPaths []string + logger *slog.Logger + plugins []CatalogPlugin + mu sync.RWMutex +} + +// NewServer creates a new plugin server. 
+func NewServer(cfg *CatalogSourcesConfig, configPaths []string, db *gorm.DB, logger *slog.Logger) *Server { + if logger == nil { + logger = slog.Default() + } + + return &Server{ + db: db, + config: cfg, + configPaths: configPaths, + logger: logger, + plugins: make([]CatalogPlugin, 0), + } +} + +// Init initializes all registered plugins that have configuration. +func (s *Server) Init(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + for _, p := range All() { + // Use SourceKey if the plugin provides one, otherwise fall back to plugin name + configKey := p.Name() + if skp, ok := p.(SourceKeyProvider); ok { + configKey = skp.SourceKey() + } + + section, ok := s.config.Catalogs[configKey] + if !ok { + s.logger.Info("plugin has no sources configured", "plugin", p.Name(), "configKey", configKey) + section = CatalogSection{} + } + + // Use plugin's BasePath if it implements BasePathProvider, otherwise compute it. + var basePath string + if bp, ok := p.(BasePathProvider); ok { + basePath = bp.BasePath() + } else { + basePath = fmt.Sprintf("/api/%s_catalog/%s", p.Name(), p.Version()) + } + + // Only pass config paths to plugins that have sources configured. + // Unconfigured plugins should not try to parse the server config file. + var configPaths []string + if ok { + configPaths = s.configPaths + } + + pluginCfg := Config{ + Section: section, + DB: s.db, + Logger: s.logger.With("plugin", p.Name()), + BasePath: basePath, + ConfigPaths: configPaths, + } + + s.logger.Info("initializing plugin", "plugin", p.Name(), "version", p.Version(), "basePath", basePath) + + if err := p.Init(ctx, pluginCfg); err != nil { + return fmt.Errorf("plugin %s init failed: %w", p.Name(), err) + } + + s.plugins = append(s.plugins, p) + } + + return nil +} + +// MountRoutes creates the HTTP router with all plugin routes mounted. 
+func (s *Server) MountRoutes() chi.Router { + s.mu.RLock() + defer s.mu.RUnlock() + + s.router = chi.NewRouter() + + // Add common middleware + s.router.Use(middleware.RequestID) + s.router.Use(middleware.RealIP) + s.router.Use(middleware.Recoverer) + s.router.Use(cors.Handler(cors.Options{ + AllowedOrigins: []string{"https://*", "http://*"}, + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token", "X-PINGOTHER"}, + ExposedHeaders: []string{"Link"}, + AllowCredentials: false, + MaxAge: 300, + })) + + // Mount plugin routes + for _, p := range s.plugins { + var basePath string + if bp, ok := p.(BasePathProvider); ok { + basePath = bp.BasePath() + } else { + basePath = fmt.Sprintf("/api/%s_catalog/%s", p.Name(), p.Version()) + } + s.logger.Info("mounting plugin routes", "plugin", p.Name(), "basePath", basePath) + + s.router.Route(basePath, func(r chi.Router) { + if err := p.RegisterRoutes(r); err != nil { + s.logger.Error("failed to register routes", "plugin", p.Name(), "error", err) + } + }) + } + + // Add health endpoint + s.router.Get("/healthz", s.healthHandler) + s.router.Get("/readyz", s.readyHandler) + + // Add plugin info endpoint + s.router.Get("/api/plugins", s.pluginsHandler) + + return s.router +} + +// Start starts all plugins' background operations. +func (s *Server) Start(ctx context.Context) error { + s.mu.RLock() + defer s.mu.RUnlock() + + for _, p := range s.plugins { + s.logger.Info("starting plugin", "plugin", p.Name()) + if err := p.Start(ctx); err != nil { + return fmt.Errorf("plugin %s start failed: %w", p.Name(), err) + } + } + + return nil +} + +// Stop gracefully shuts down all plugins. 
+func (s *Server) Stop(ctx context.Context) error { + s.mu.RLock() + defer s.mu.RUnlock() + + var lastErr error + for _, p := range s.plugins { + s.logger.Info("stopping plugin", "plugin", p.Name()) + if err := p.Stop(ctx); err != nil { + s.logger.Error("plugin stop failed", "plugin", p.Name(), "error", err) + lastErr = err + } + } + + return lastErr +} + +// Router returns the underlying chi.Router. +func (s *Server) Router() chi.Router { + return s.router +} + +// Plugins returns the list of initialized plugins. +func (s *Server) Plugins() []CatalogPlugin { + s.mu.RLock() + defer s.mu.RUnlock() + + result := make([]CatalogPlugin, len(s.plugins)) + copy(result, s.plugins) + return result +} + +// healthHandler returns the health status of the server. +func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + response := map[string]string{ + "status": "ok", + } + + _ = json.NewEncoder(w).Encode(response) +} + +// readyHandler checks if all plugins are healthy. +func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) { + s.mu.RLock() + defer s.mu.RUnlock() + + allHealthy := true + pluginStatus := make(map[string]bool) + + for _, p := range s.plugins { + healthy := p.Healthy() + pluginStatus[p.Name()] = healthy + if !healthy { + allHealthy = false + } + } + + w.Header().Set("Content-Type", "application/json") + + response := map[string]any{ + "plugins": pluginStatus, + } + + if allHealthy { + response["status"] = "ready" + w.WriteHeader(http.StatusOK) + } else { + response["status"] = "not_ready" + w.WriteHeader(http.StatusServiceUnavailable) + } + + _ = json.NewEncoder(w).Encode(response) +} + +// pluginsHandler returns information about registered plugins. 
+func (s *Server) pluginsHandler(w http.ResponseWriter, r *http.Request) { + s.mu.RLock() + defer s.mu.RUnlock() + + type pluginInfo struct { + Name string `json:"name"` + Version string `json:"version"` + Description string `json:"description"` + BasePath string `json:"basePath"` + Healthy bool `json:"healthy"` + } + + plugins := make([]pluginInfo, 0, len(s.plugins)) + for _, p := range s.plugins { + var basePath string + if bp, ok := p.(BasePathProvider); ok { + basePath = bp.BasePath() + } else { + basePath = fmt.Sprintf("/api/%s_catalog/%s", p.Name(), p.Version()) + } + plugins = append(plugins, pluginInfo{ + Name: p.Name(), + Version: p.Version(), + Description: p.Description(), + BasePath: basePath, + Healthy: p.Healthy(), + }) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + response := map[string]any{ + "plugins": plugins, + "count": len(plugins), + } + + _ = json.NewEncoder(w).Encode(response) +} diff --git a/pkg/catalog/plugin/server_test.go b/pkg/catalog/plugin/server_test.go new file mode 100644 index 0000000000..c2400ab52f --- /dev/null +++ b/pkg/catalog/plugin/server_test.go @@ -0,0 +1,234 @@ +package plugin + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testPlugin is a CatalogPlugin implementation for testing. 
+type testPlugin struct { + name string + version string + description string + healthy bool + initCalled bool + startCalled bool + stopCalled bool +} + +func (p *testPlugin) Name() string { return p.name } +func (p *testPlugin) Version() string { return p.version } +func (p *testPlugin) Description() string { return p.description } + +func (p *testPlugin) Init(ctx context.Context, cfg Config) error { + p.initCalled = true + return nil +} + +func (p *testPlugin) Start(ctx context.Context) error { + p.startCalled = true + return nil +} + +func (p *testPlugin) Stop(ctx context.Context) error { + p.stopCalled = true + return nil +} + +func (p *testPlugin) Healthy() bool { return p.healthy } + +func (p *testPlugin) RegisterRoutes(router chi.Router) error { + router.Get("/test", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status":"ok"}`)) + }) + return nil +} + +func (p *testPlugin) Migrations() []Migration { return nil } + +func TestServerInit(t *testing.T) { + Reset() + + plugin := &testPlugin{ + name: "test", + version: "v1", + healthy: true, + } + Register(plugin) + + cfg := &CatalogSourcesConfig{ + Catalogs: map[string]CatalogSection{ + "test": { + Sources: []SourceConfig{}, + }, + }, + } + + server := NewServer(cfg, []string{}, nil, nil) + err := server.Init(context.Background()) + require.NoError(t, err) + + assert.True(t, plugin.initCalled) + assert.Equal(t, 1, len(server.Plugins())) + + Reset() +} + +func TestServerInitializesUnconfiguredPlugins(t *testing.T) { + Reset() + + plugin := &testPlugin{ + name: "test", + version: "v1", + healthy: true, + } + Register(plugin) + + // Empty config - no catalogs configured + cfg := &CatalogSourcesConfig{ + Catalogs: map[string]CatalogSection{}, + } + + server := NewServer(cfg, []string{}, nil, nil) + err := server.Init(context.Background()) + require.NoError(t, err) + + // Plugin should still be initialized even without config + assert.True(t, 
plugin.initCalled) + assert.Equal(t, 1, len(server.Plugins())) + + Reset() +} + +func TestServerHealthEndpoint(t *testing.T) { + Reset() + + cfg := &CatalogSourcesConfig{ + Catalogs: map[string]CatalogSection{}, + } + + server := NewServer(cfg, []string{}, nil, nil) + err := server.Init(context.Background()) + require.NoError(t, err) + + router := server.MountRoutes() + + // Test /healthz + req := httptest.NewRequest("GET", "/healthz", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, rec.Body.String(), "ok") + + Reset() +} + +func TestServerReadyEndpoint(t *testing.T) { + Reset() + + healthyPlugin := &testPlugin{ + name: "healthy", + version: "v1", + healthy: true, + } + Register(healthyPlugin) + + cfg := &CatalogSourcesConfig{ + Catalogs: map[string]CatalogSection{ + "healthy": {}, + }, + } + + server := NewServer(cfg, []string{}, nil, nil) + err := server.Init(context.Background()) + require.NoError(t, err) + + router := server.MountRoutes() + + // Test /readyz when all plugins are healthy + req := httptest.NewRequest("GET", "/readyz", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, rec.Body.String(), "ready") + + Reset() +} + +func TestServerPluginsEndpoint(t *testing.T) { + Reset() + + plugin := &testPlugin{ + name: "test", + version: "v1", + description: "Test plugin", + healthy: true, + } + Register(plugin) + + cfg := &CatalogSourcesConfig{ + Catalogs: map[string]CatalogSection{ + "test": {}, + }, + } + + server := NewServer(cfg, []string{}, nil, nil) + err := server.Init(context.Background()) + require.NoError(t, err) + + router := server.MountRoutes() + + // Test /api/plugins + req := httptest.NewRequest("GET", "/api/plugins", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + assert.Contains(t, body, 
"test") + assert.Contains(t, body, "v1") + + Reset() +} + +func TestServerStartStop(t *testing.T) { + Reset() + + plugin := &testPlugin{ + name: "test", + version: "v1", + healthy: true, + } + Register(plugin) + + cfg := &CatalogSourcesConfig{ + Catalogs: map[string]CatalogSection{ + "test": {}, + }, + } + + server := NewServer(cfg, []string{}, nil, nil) + err := server.Init(context.Background()) + require.NoError(t, err) + + // Start + err = server.Start(context.Background()) + require.NoError(t, err) + assert.True(t, plugin.startCalled) + + // Stop + err = server.Stop(context.Background()) + require.NoError(t, err) + assert.True(t, plugin.stopCalled) + + Reset() +} diff --git a/pkg/catalog/provider.go b/pkg/catalog/provider.go new file mode 100644 index 0000000000..e487a5d8aa --- /dev/null +++ b/pkg/catalog/provider.go @@ -0,0 +1,101 @@ +// Package catalog provides reusable abstractions for building catalog-style +// read-only aggregation services. It extracts common patterns from the Model Catalog +// to enable rapid creation of new catalog components (e.g., MCP Catalog). +package catalog + +import ( + "context" + "fmt" + "sync" +) + +// Record represents a single entity record with its associated artifacts. +// This is generic to support different entity types. +type Record[E any, A any] struct { + // Entity is the main entity being loaded (e.g., a catalog model). + // A nil Entity signals that a batch of records has been fully sent. + Entity E + + // Artifacts are the associated artifacts for this entity. + Artifacts []A +} + +// ProviderFunc emits records in a channel and is expected to spawn a goroutine +// and return immediately. The returned channel must close when the goroutine ends. +// The goroutine should end when the context is canceled, but may end sooner. +// +// The function may emit a record with a nil Entity to indicate that the +// complete set of entities has been sent (batch completion marker). 
+// +// Parameters: +// - ctx: Context for cancellation +// - source: The source configuration for this provider +// - reldir: The directory to resolve relative paths from (typically the config file's directory) +type ProviderFunc[E any, A any] func(ctx context.Context, source *Source, reldir string) (<-chan Record[E, A], error) + +// ProviderRegistry manages provider type registrations. +// It allows registering provider functions by name (e.g., "yaml", "http") +// and retrieving them for use when loading from sources. +type ProviderRegistry[E any, A any] struct { + mu sync.RWMutex + providers map[string]ProviderFunc[E, A] +} + +// NewProviderRegistry creates a new empty provider registry. +func NewProviderRegistry[E any, A any]() *ProviderRegistry[E, A] { + return &ProviderRegistry[E, A]{ + providers: make(map[string]ProviderFunc[E, A]), + } +} + +// Register adds a provider function with the given name. +// Returns an error if a provider with that name already exists. +func (r *ProviderRegistry[E, A]) Register(name string, fn ProviderFunc[E, A]) error { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.providers[name]; exists { + return fmt.Errorf("provider type %q already exists", name) + } + r.providers[name] = fn + return nil +} + +// MustRegister is like Register but panics on error. +// This is useful for init() functions. +func (r *ProviderRegistry[E, A]) MustRegister(name string, fn ProviderFunc[E, A]) { + if err := r.Register(name, fn); err != nil { + panic(err) + } +} + +// Get retrieves a provider function by name. +// Returns the function and true if found, or nil and false if not. +func (r *ProviderRegistry[E, A]) Get(name string) (ProviderFunc[E, A], bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + fn, ok := r.providers[name] + return fn, ok +} + +// Names returns a list of all registered provider names. 
+func (r *ProviderRegistry[E, A]) Names() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + names := make([]string, 0, len(r.providers)) + for name := range r.providers { + names = append(names, name) + } + return names +} + +// Has returns true if a provider with the given name is registered. +func (r *ProviderRegistry[E, A]) Has(name string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + _, ok := r.providers[name] + return ok +} diff --git a/pkg/catalog/provider_test.go b/pkg/catalog/provider_test.go new file mode 100644 index 0000000000..404a2ff373 --- /dev/null +++ b/pkg/catalog/provider_test.go @@ -0,0 +1,87 @@ +package catalog + +import ( + "context" + "testing" +) + +func TestProviderRegistry(t *testing.T) { + type Entity struct { + Name string + } + type Artifact struct { + URI string + } + + registry := NewProviderRegistry[*Entity, *Artifact]() + + // Test Register + providerFunc := func(ctx context.Context, source *Source, reldir string) (<-chan Record[*Entity, *Artifact], error) { + ch := make(chan Record[*Entity, *Artifact]) + close(ch) + return ch, nil + } + + err := registry.Register("test", providerFunc) + if err != nil { + t.Errorf("Register failed: %v", err) + } + + // Test duplicate registration + err = registry.Register("test", providerFunc) + if err == nil { + t.Error("Expected error for duplicate registration, got nil") + } + + // Test Get + fn, ok := registry.Get("test") + if !ok { + t.Error("Expected to find registered provider") + } + if fn == nil { + t.Error("Expected non-nil provider function") + } + + // Test Get non-existent + _, ok = registry.Get("nonexistent") + if ok { + t.Error("Expected not to find non-existent provider") + } + + // Test Has + if !registry.Has("test") { + t.Error("Expected Has to return true for registered provider") + } + if registry.Has("nonexistent") { + t.Error("Expected Has to return false for non-existent provider") + } + + // Test Names + names := registry.Names() + if len(names) != 1 || names[0] != 
"test" { + t.Errorf("Expected names to be [test], got %v", names) + } +} + +func TestMustRegister(t *testing.T) { + type Entity struct{} + type Artifact struct{} + + registry := NewProviderRegistry[*Entity, *Artifact]() + + // Should not panic + registry.MustRegister("test", func(ctx context.Context, source *Source, reldir string) (<-chan Record[*Entity, *Artifact], error) { + return nil, nil + }) + + // Should panic on duplicate + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic on duplicate MustRegister") + } + }() + + registry.MustRegister("test", func(ctx context.Context, source *Source, reldir string) (<-chan Record[*Entity, *Artifact], error) { + return nil, nil + }) +} diff --git a/pkg/catalog/providers/http/provider.go b/pkg/catalog/providers/http/provider.go new file mode 100644 index 0000000000..5710783f85 --- /dev/null +++ b/pkg/catalog/providers/http/provider.go @@ -0,0 +1,389 @@ +// Package http provides a base HTTP provider for catalog data. +// It handles HTTP fetching, polling, authentication, and rate limiting, +// while delegating entity-specific conversion to user-provided functions. +package http + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/kubeflow/model-registry/pkg/catalog" +) + +// Config configures an HTTP provider. +type Config[E any, A any] struct { + // BaseURLKey is the property key in source.Properties for the base URL. + // Defaults to "url" if empty. + BaseURLKey string + + // DefaultBaseURL is used if no URL is provided in source.Properties. + DefaultBaseURL string + + // SyncIntervalKey is the property key for sync interval. + // Defaults to "syncInterval" if empty. + SyncIntervalKey string + + // DefaultSyncInterval is the default polling interval. + // Defaults to 24 hours if zero. + DefaultSyncInterval time.Duration + + // HTTPClient is the HTTP client to use. + // Defaults to a client with 30 second timeout if nil. 
+ HTTPClient *http.Client + + // FetchRecords fetches records from the HTTP API. + // This is the main function that catalog-specific providers must implement. + FetchRecords func(ctx context.Context, client *http.Client, baseURL string, source *catalog.Source) ([]catalog.Record[E, A], error) + + // GetAuthHeader returns the authorization header name and value. + // If nil or returns empty strings, no auth header is added. + GetAuthHeader func(source *catalog.Source) (name, value string) + + // ValidateCredentials validates the API credentials before fetching. + // If nil, no validation is performed. + ValidateCredentials func(ctx context.Context, client *http.Client, baseURL string, source *catalog.Source) error + + // Logger for logging messages (optional). + Logger Logger + + // UserAgent is the User-Agent header value. + // Defaults to "model-registry-catalog" if empty. + UserAgent string +} + +// Logger is an interface for logging. +type Logger interface { + Infof(format string, args ...any) + Errorf(format string, args ...any) + Warningf(format string, args ...any) +} + +type noopLogger struct{} + +func (noopLogger) Infof(format string, args ...any) {} +func (noopLogger) Errorf(format string, args ...any) {} +func (noopLogger) Warningf(format string, args ...any) {} + +// Provider is an HTTP-based data provider with periodic polling. +type Provider[E any, A any] struct { + config Config[E, A] + client *http.Client + baseURL string + syncInterval time.Duration + source *catalog.Source + filter *catalog.ItemFilter + logger Logger +} + +// NewProvider creates a new HTTP provider with the given configuration. 
+func NewProvider[E any, A any](config Config[E, A], source *catalog.Source) (*Provider[E, A], error) { + // Parse base URL + baseURLKey := config.BaseURLKey + if baseURLKey == "" { + baseURLKey = "url" + } + + baseURL := config.DefaultBaseURL + if url, ok := source.Properties[baseURLKey].(string); ok && url != "" { + baseURL = url + } + if baseURL == "" { + return nil, fmt.Errorf("missing base URL (property %s or default)", baseURLKey) + } + + // Parse sync interval + syncIntervalKey := config.SyncIntervalKey + if syncIntervalKey == "" { + syncIntervalKey = "syncInterval" + } + + syncInterval := config.DefaultSyncInterval + if syncInterval == 0 { + syncInterval = 24 * time.Hour + } + if intervalStr, ok := source.Properties[syncIntervalKey].(string); ok && intervalStr != "" { + if parsed, err := time.ParseDuration(intervalStr); err == nil { + syncInterval = parsed + } + } + + // HTTP client + client := config.HTTPClient + if client == nil { + client = &http.Client{Timeout: 30 * time.Second} + } + + // Build filter from source configuration + filter, err := catalog.NewItemFilterFromSource(source, nil, nil) + if err != nil { + return nil, err + } + + logger := config.Logger + if logger == nil { + logger = noopLogger{} + } + + return &Provider[E, A]{ + config: config, + client: client, + baseURL: baseURL, + syncInterval: syncInterval, + source: source, + filter: filter, + logger: logger, + }, nil +} + +// Records starts fetching data and returns a channel of records. +// The channel is closed when the context is canceled. +// The provider polls periodically based on syncInterval. 
+func (p *Provider[E, A]) Records(ctx context.Context) (<-chan catalog.Record[E, A], error) { + // Validate credentials if configured + if p.config.ValidateCredentials != nil { + if err := p.config.ValidateCredentials(ctx, p.client, p.baseURL, p.source); err != nil { + return nil, fmt.Errorf("credential validation failed: %w", err) + } + } + + // Fetch initial data to catch errors early + records, err := p.fetch(ctx) + if err != nil { + return nil, err + } + + ch := make(chan catalog.Record[E, A]) + go func() { + defer close(ch) + + // Send initial records + p.emit(ctx, records, ch) + + // Set up periodic polling + ticker := time.NewTicker(p.syncInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + p.logger.Infof("Periodic sync: fetching records for source %s", p.source.ID) + records, err := p.fetch(ctx) + if err != nil { + p.logger.Errorf("Failed to fetch records: %v", err) + continue + } + p.emit(ctx, records, ch) + } + } + }() + + return ch, nil +} + +func (p *Provider[E, A]) fetch(ctx context.Context) ([]catalog.Record[E, A], error) { + if p.config.FetchRecords == nil { + return nil, fmt.Errorf("FetchRecords function not configured") + } + return p.config.FetchRecords(ctx, p.client, p.baseURL, p.source) +} + +func (p *Provider[E, A]) emit(ctx context.Context, records []catalog.Record[E, A], out chan<- catalog.Record[E, A]) { + done := ctx.Done() + for _, record := range records { + select { + case out <- record: + case <-done: + return + } + } + + // Send an empty record to indicate batch completion + var zero catalog.Record[E, A] + select { + case out <- zero: + case <-done: + } +} + +// NewProviderFunc creates a ProviderFunc that can be registered with a ProviderRegistry. 
+func NewProviderFunc[E any, A any](config Config[E, A]) catalog.ProviderFunc[E, A] { + return func(ctx context.Context, source *catalog.Source, reldir string) (<-chan catalog.Record[E, A], error) { + provider, err := NewProvider(config, source) + if err != nil { + return nil, err + } + return provider.Records(ctx) + } +} + +// Request represents an HTTP request configuration. +type Request struct { + Method string + URL string + Headers map[string]string + Body io.Reader +} + +// DoRequest performs an HTTP request with standard error handling. +func DoRequest[T any](ctx context.Context, client *http.Client, req Request) (*T, error) { + method := req.Method + if method == "" { + method = "GET" + } + + httpReq, err := http.NewRequestWithContext(ctx, method, req.URL, req.Body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + for key, value := range req.Headers { + httpReq.Header.Set(key, value) + } + + resp, err := client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var result T + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return &result, nil +} + +// DoRequestRaw performs an HTTP request and returns the raw response body. 
+func DoRequestRaw(ctx context.Context, client *http.Client, req Request) ([]byte, error) { + method := req.Method + if method == "" { + method = "GET" + } + + httpReq, err := http.NewRequestWithContext(ctx, method, req.URL, req.Body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + for key, value := range req.Headers { + httpReq.Header.Set(key, value) + } + + resp, err := client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(bodyBytes)) + } + + return io.ReadAll(resp.Body) +} + +// PaginatedFetcher handles paginated API responses. +type PaginatedFetcher[T any] struct { + Client *http.Client + Headers map[string]string + BuildURL func(cursor string) string + ParseNext func(response *http.Response, items []T) (cursor string, hasMore bool) + MaxItems int +} + +// FetchAll fetches all items from a paginated API. 
+func (f *PaginatedFetcher[T]) FetchAll(ctx context.Context) ([]T, error) { + var allItems []T + cursor := "" + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + if f.MaxItems > 0 && len(allItems) >= f.MaxItems { + break + } + + url := f.BuildURL(cursor) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + for key, value := range f.Headers { + req.Header.Set(key, value) + } + + resp, err := f.Client.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var items []T + if err := json.NewDecoder(resp.Body).Decode(&items); err != nil { + _ = resp.Body.Close() + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + nextCursor, hasMore := f.ParseNext(resp, items) + _ = resp.Body.Close() + + allItems = append(allItems, items...) + + if !hasMore || nextCursor == "" { + break + } + cursor = nextCursor + } + + return allItems, nil +} + +// RateLimiter provides simple rate limiting for API calls. +type RateLimiter struct { + interval time.Duration + lastCall time.Time +} + +// NewRateLimiter creates a rate limiter with the given minimum interval between calls. +func NewRateLimiter(interval time.Duration) *RateLimiter { + return &RateLimiter{interval: interval} +} + +// Wait waits until the next API call is allowed. 
+func (r *RateLimiter) Wait(ctx context.Context) error { + now := time.Now() + elapsed := now.Sub(r.lastCall) + if elapsed < r.interval { + select { + case <-time.After(r.interval - elapsed): + case <-ctx.Done(): + return ctx.Err() + } + } + r.lastCall = time.Now() + return nil +} diff --git a/pkg/catalog/providers/yaml/provider.go b/pkg/catalog/providers/yaml/provider.go new file mode 100644 index 0000000000..144829dbb6 --- /dev/null +++ b/pkg/catalog/providers/yaml/provider.go @@ -0,0 +1,231 @@ +// Package yaml provides a base YAML file provider for catalog data. +// It handles file reading, parsing, and hot-reloading, while delegating +// entity-specific conversion to user-provided functions. +package yaml + +import ( + "context" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/kubeflow/model-registry/pkg/catalog" + k8syaml "k8s.io/apimachinery/pkg/util/yaml" +) + +// Config configures a YAML provider. +type Config[E any, A any] struct { + // PathKey is the property key in source.Properties that contains the YAML file path. + // Defaults to "yamlCatalogPath" if empty. + PathKey string + + // Parse parses raw YAML bytes into a slice of entity records. + Parse func(data []byte) ([]catalog.Record[E, A], error) + + // Filter optionally filters records before emitting them. + // Return true to include the record, false to exclude it. + // If nil, all records are included. + Filter func(record catalog.Record[E, A]) bool + + // Logger for logging messages (optional). + Logger Logger + + // WatchInterval is the interval for checking file changes. + // Defaults to 5 seconds if zero. + WatchInterval time.Duration +} + +// Logger is an interface for logging. +type Logger interface { + Infof(format string, args ...any) + Errorf(format string, args ...any) +} + +type noopLogger struct{} + +func (noopLogger) Infof(format string, args ...any) {} +func (noopLogger) Errorf(format string, args ...any) {} + +// Provider is a YAML file-based data provider. 
+type Provider[E any, A any] struct { + config Config[E, A] + path string + filter *catalog.ItemFilter + logger Logger +} + +// NewProvider creates a new YAML provider with the given configuration. +// It reads the file path from source.Properties using the configured PathKey. +func NewProvider[E any, A any](config Config[E, A], source *catalog.Source, reldir string) (*Provider[E, A], error) { + pathKey := config.PathKey + if pathKey == "" { + pathKey = "yamlCatalogPath" + } + + path, ok := source.Properties[pathKey].(string) + if !ok || path == "" { + return nil, fmt.Errorf("missing %s string property", pathKey) + } + + // Resolve relative paths + if !filepath.IsAbs(path) { + path = filepath.Join(reldir, path) + } + + // Build filter from source configuration + filter, err := catalog.NewItemFilterFromSource(source, nil, nil) + if err != nil { + return nil, err + } + + logger := config.Logger + if logger == nil { + logger = noopLogger{} + } + + return &Provider[E, A]{ + config: config, + path: path, + filter: filter, + logger: logger, + }, nil +} + +// Records starts reading the YAML file and returns a channel of records. +// The channel is closed when the context is canceled. +// The provider watches for file changes and re-emits records when the file changes. 
func (p *Provider[E, A]) Records(ctx context.Context) (<-chan catalog.Record[E, A], error) {
	// Read initial data to catch errors early, before any goroutine starts.
	records, err := p.read()
	if err != nil {
		return nil, err
	}

	// Unbuffered: the goroutine blocks until the consumer drains each record.
	ch := make(chan catalog.Record[E, A])
	go func() {
		defer close(ch)

		// Send initial records
		p.emit(ctx, records, ch)

		// Watch for changes; returns (and closes ch) when ctx is canceled.
		p.watchAndReload(ctx, ch)
	}()

	return ch, nil
}

// read loads and parses the whole YAML file via the configured Parse function.
func (p *Provider[E, A]) read() ([]catalog.Record[E, A], error) {
	data, err := os.ReadFile(p.path)
	if err != nil {
		return nil, fmt.Errorf("failed to read YAML file %s: %w", p.path, err)
	}

	records, err := p.config.Parse(data)
	if err != nil {
		return nil, fmt.Errorf("failed to parse YAML file %s: %w", p.path, err)
	}

	return records, nil
}

// emit sends records to out, honoring ctx cancellation, and finishes the
// batch with a zero-value sentinel record.
// NOTE(review): only the optional config.Filter is applied here; p.filter
// (the source include/exclude ItemFilter) is never consulted in this file —
// verify that is intentional.
func (p *Provider[E, A]) emit(ctx context.Context, records []catalog.Record[E, A], out chan<- catalog.Record[E, A]) {
	done := ctx.Done()
	for _, record := range records {
		// Apply custom filter if provided
		if p.config.Filter != nil && !p.config.Filter(record) {
			continue
		}

		select {
		case out <- record:
		case <-done:
			return
		}
	}

	// Send an empty record to indicate batch completion
	var zero catalog.Record[E, A]
	select {
	case out <- zero:
	case <-done:
	}
}

// watchAndReload polls the file's mtime and re-emits the full record set
// whenever it advances. Runs until ctx is canceled.
func (p *Provider[E, A]) watchAndReload(ctx context.Context, ch chan<- catalog.Record[E, A]) {
	interval := p.config.WatchInterval
	if interval == 0 {
		interval = 5 * time.Second
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	// NOTE(review): the baseline mtime is sampled here, after the initial
	// read() in Records. A write landing between read() and this Stat will
	// carry this mtime and never trigger a reload — confirm acceptable.
	var lastModTime time.Time
	if info, err := os.Stat(p.path); err == nil {
		lastModTime = info.ModTime()
	}

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// Stat errors (e.g. file temporarily missing during an atomic
			// replace) are skipped; we retry on the next tick.
			info, err := os.Stat(p.path)
			if err != nil {
				continue
			}

			if info.ModTime().After(lastModTime) {
				lastModTime = info.ModTime()
				p.logger.Infof("Reloading YAML file %s", p.path)

				records, err := p.read()
				if err != nil {
					// Keep serving the previously emitted data on parse failure.
					p.logger.Errorf("Failed to reload YAML file: %v", err)
					continue
				}

				p.emit(ctx, records, ch)
			}
		}
	}
}

// NewProviderFunc creates a ProviderFunc that can be registered with a ProviderRegistry.
// This is a convenience function for creating providers using the standard pattern.
func NewProviderFunc[E any, A any](config Config[E, A]) catalog.ProviderFunc[E, A] {
	return func(ctx context.Context, source *catalog.Source, reldir string) (<-chan catalog.Record[E, A], error) {
		provider, err := NewProvider(config, source, reldir)
		if err != nil {
			return nil, err
		}
		return provider.Records(ctx)
	}
}

// SimpleCatalog is a generic structure for simple YAML catalogs.
type SimpleCatalog[E any] struct {
	Source   string `json:"source" yaml:"source"`
	Entities []E    `json:"entities" yaml:"entities"`
}

// ParseSimpleCatalog creates a parser for SimpleCatalog format.
// The toRecord function converts each entity to a Record.
// Unknown YAML fields are rejected (UnmarshalStrict).
func ParseSimpleCatalog[E any, A any](toRecord func(E) catalog.Record[E, A]) func([]byte) ([]catalog.Record[E, A], error) {
	return func(data []byte) ([]catalog.Record[E, A], error) {
		var cat SimpleCatalog[E]
		if err := k8syaml.UnmarshalStrict(data, &cat); err != nil {
			return nil, err
		}

		records := make([]catalog.Record[E, A], 0, len(cat.Entities))
		for _, entity := range cat.Entities {
			records = append(records, toRecord(entity))
		}

		return records, nil
	}
}

diff --git a/pkg/catalog/source.go b/pkg/catalog/source.go
new file mode 100644
index 0000000000..c4360599e8
--- /dev/null
+++ b/pkg/catalog/source.go

package catalog

import (
	"maps"
	"slices"
	"strings"
	"sync"
)

// Source represents a catalog data source configuration.
// A source defines where to fetch entities from and how to configure the provider.
type Source struct {
	// ID is the unique identifier for this source.
	ID string

	// Name is the human-readable display name for this source.
	Name string

	// Type identifies the provider type to use (e.g., "yaml", "http").
	Type string

	// Enabled indicates whether this source should be loaded.
	// Defaults to true if nil.
	Enabled *bool

	// Labels are tags used for filtering and categorization.
	Labels []string

	// Properties contains provider-specific configuration.
	Properties map[string]any

	// IncludedItems are glob patterns for items to include.
	// If non-empty, only items matching at least one pattern are included.
	IncludedItems []string

	// ExcludedItems are glob patterns for items to exclude.
	// Items matching any pattern are excluded, even if they match an include pattern.
	ExcludedItems []string

	// Origin is the absolute path of the config file this source was loaded from.
	// This is used for resolving relative paths in Properties.
	Origin string
}

// IsEnabled returns true if this source is enabled (defaults to true if nil).
func (s Source) IsEnabled() bool {
	return s.Enabled == nil || *s.Enabled
}

// originEntry holds sources from a single origin (config file).
type originEntry struct {
	origin  string
	sources map[string]Source
}

// SourceCollection manages catalog sources from multiple origins with priority-based merging.
// Later entries in the origin order take precedence over earlier ones.
type SourceCollection struct {
	mu      sync.RWMutex  // guards entries
	entries []originEntry // ordered lowest → highest priority
}

// NewSourceCollection creates a new SourceCollection with the given origin order.
// Origins listed later in the order take precedence over earlier ones.
// For example, if originOrder is ["default.yaml", "user.yaml"], sources from
// "user.yaml" will override sources with the same ID from "default.yaml".
+func NewSourceCollection(originOrder ...string) *SourceCollection { + entries := make([]originEntry, len(originOrder)) + for i, origin := range originOrder { + entries[i] = originEntry{origin: origin, sources: nil} + } + return &SourceCollection{ + entries: entries, + } +} + +// Merge adds sources from one origin, completely replacing anything that was +// previously from that origin. +// +// If a source with the same ID exists in multiple origins, fields from +// higher-priority origins (listed later in entries) override fields from +// lower-priority origins. Fields that are not set (zero value for strings, +// nil for pointers/slices/maps) in the override are inherited from the base. +func (sc *SourceCollection) Merge(origin string, sources map[string]Source) error { + sc.mu.Lock() + defer sc.mu.Unlock() + + // Find existing entry for this origin + for i := range sc.entries { + if sc.entries[i].origin == origin { + sc.entries[i].sources = sources + return nil + } + } + + // Origin not found, append it (dynamic registration) + sc.entries = append(sc.entries, originEntry{origin: origin, sources: sources}) + return nil +} + +// mergeSources performs field-level merging of two Source structs. +// Fields from 'override' take precedence over 'base' when they are explicitly set. 
+func mergeSources(base, override Source) Source { + result := base + + // ID is always taken from override (it's the key) + result.ID = override.ID + + // Name: override if non-empty + if override.Name != "" { + result.Name = override.Name + } + + // Enabled: override if non-nil + if override.Enabled != nil { + result.Enabled = override.Enabled + } + + // Labels: override if non-nil (empty slice means "explicitly no labels") + if override.Labels != nil { + result.Labels = override.Labels + } + + // IncludedItems: override if non-nil + if override.IncludedItems != nil { + result.IncludedItems = override.IncludedItems + } + + // ExcludedItems: override if non-nil + if override.ExcludedItems != nil { + result.ExcludedItems = override.ExcludedItems + } + + // Type: override if non-empty + if override.Type != "" { + result.Type = override.Type + } + + // Properties: override if non-nil (complete replacement, not deep merge) + if override.Properties != nil { + result.Properties = override.Properties + } + + // Origin: use override's origin if Properties are overridden + if override.Properties != nil && override.Origin != "" { + result.Origin = override.Origin + } + + return result +} + +// applyDefaults applies default values to a Source for fields that are not set. +func applyDefaults(source Source) Source { + // Default Enabled to true if not set + if source.Enabled == nil { + enabled := true + source.Enabled = &enabled + } + + // Default Labels to empty slice if not set + if source.Labels == nil { + source.Labels = []string{} + } + + return source +} + +// merged computes the merged view of all sources with field-level merging. +// Must be called with lock held. 
+func (sc *SourceCollection) merged() map[string]Source { + result := map[string]Source{} + + for _, entry := range sc.entries { + for id, source := range entry.sources { + if existing, ok := result[id]; ok { + // Field-level merge: existing is base, source is override + result[id] = mergeSources(existing, source) + } else { + result[id] = source + } + } + } + + // Apply defaults to all merged sources + for id, source := range result { + result[id] = applyDefaults(source) + } + + return result +} + +// AllSources returns all merged sources. +// All sources are returned regardless of enabled status. +func (sc *SourceCollection) AllSources() map[string]Source { + sc.mu.RLock() + defer sc.mu.RUnlock() + + result := map[string]Source{} + for id, source := range sc.merged() { + result[id] = source + } + return result +} + +// EnabledSources returns only sources that are enabled. +func (sc *SourceCollection) EnabledSources() map[string]Source { + sc.mu.RLock() + defer sc.mu.RUnlock() + + result := map[string]Source{} + for id, source := range sc.merged() { + if source.IsEnabled() { + result[id] = source + } + } + return result +} + +// Get returns a source by ID if it exists and is enabled. +func (sc *SourceCollection) Get(id string) (Source, bool) { + sc.mu.RLock() + defer sc.mu.RUnlock() + + merged := sc.merged() + source, exists := merged[id] + if !exists { + return Source{}, false + } + + // Only return if enabled + if source.IsEnabled() { + return source, true + } + return Source{}, false +} + +// ByLabel returns enabled sources that have any of the labels provided. +// The matching is case insensitive. +// +// If a label is "null", every source without a label is returned. 
func (sc *SourceCollection) ByLabel(labels []string) []Source {
	sc.mu.RLock()
	defer sc.mu.RUnlock()

	// Lowercase the requested labels once for case-insensitive matching.
	labelMap := make(map[string]struct{}, len(labels))
	for _, label := range labels {
		labelMap[strings.ToLower(label)] = struct{}{}
	}

	// Keyed by source ID, so a source matching several labels appears once.
	matches := map[string]Source{}
	sources := sc.merged()

	// Special label "null" selects enabled sources that have no labels at all.
	if _, hasNull := labelMap["null"]; hasNull {
		for _, source := range sources {
			if !source.IsEnabled() {
				continue
			}
			if len(source.Labels) == 0 {
				matches[source.ID] = source
			}
		}
	}

OUTER:
	for _, source := range sources {
		if !source.IsEnabled() {
			continue
		}
		for _, label := range source.Labels {
			if _, match := labelMap[strings.ToLower(label)]; match {
				matches[source.ID] = source
				continue OUTER
			}
		}
	}

	// Map iteration order is random, so the result order is unspecified.
	return slices.Collect(maps.Values(matches))
}

// IDs returns all source IDs (both enabled and disabled).
func (sc *SourceCollection) IDs() []string {
	sc.mu.RLock()
	defer sc.mu.RUnlock()

	merged := sc.merged()
	ids := make([]string, 0, len(merged))
	for id := range merged {
		ids = append(ids, id)
	}
	return ids
}

diff --git a/pkg/catalog/source_test.go b/pkg/catalog/source_test.go
new file mode 100644
index 0000000000..680931fba4
--- /dev/null
+++ b/pkg/catalog/source_test.go

package catalog

import (
	"testing"
)

// TestSourceIsEnabled covers the tri-state Enabled pointer: nil defaults to true.
func TestSourceIsEnabled(t *testing.T) {
	tests := []struct {
		name     string
		enabled  *bool
		expected bool
	}{
		{
			name:     "nil enabled defaults to true",
			enabled:  nil,
			expected: true,
		},
		{
			name:     "explicitly enabled",
			enabled:  boolPtr(true),
			expected: true,
		},
		{
			name:     "explicitly disabled",
			enabled:  boolPtr(false),
			expected: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := Source{Enabled: tt.enabled}
			if got := s.IsEnabled(); got != tt.expected {
				t.Errorf("IsEnabled() = %v, want %v", got, tt.expected)
			}
		})
	}
}

// TestSourceCollection exercises origin-priority, field-level merging:
// user.yaml enables a source that default.yaml disabled, while the
// untouched Name/Type fields are inherited from default.yaml.
func TestSourceCollection(t *testing.T) {
	sc := NewSourceCollection("default.yaml", "user.yaml")

	// Add default sources
	defaultSources := map[string]Source{
		"source1": {
			ID:      "source1",
			Name:    "Default Source 1",
			Type:    "yaml",
			Enabled: boolPtr(false),
		},
		"source2": {
			ID:   "source2",
			Name: "Default Source 2",
			Type: "http",
		},
	}
	if err := sc.Merge("default.yaml", defaultSources); err != nil {
		t.Fatalf("Merge default failed: %v", err)
	}

	// Add user sources (should override)
	userSources := map[string]Source{
		"source1": {
			ID:      "source1",
			Enabled: boolPtr(true), // Enable the disabled source
		},
	}
	if err := sc.Merge("user.yaml", userSources); err != nil {
		t.Fatalf("Merge user failed: %v", err)
	}

	// Test AllSources
	all := sc.AllSources()
	if len(all) != 2 {
		t.Errorf("Expected 2 sources, got %d", len(all))
	}

	// Test field-level merging: source1 should have merged fields
	s1, ok := all["source1"]
	if !ok {
		t.Fatal("Expected source1 in AllSources")
	}
	if s1.Name != "Default Source 1" {
		t.Errorf("Expected name 'Default Source 1', got '%s'", s1.Name)
	}
	if s1.Type != "yaml" {
		t.Errorf("Expected type 'yaml', got '%s'", s1.Type)
	}
	if !s1.IsEnabled() {
		t.Error("Expected source1 to be enabled after merge")
	}

	// Test Get (only returns enabled sources)
	_, ok = sc.Get("source1")
	if !ok {
		t.Error("Expected to find enabled source1")
	}

	// Test EnabledSources
	enabled := sc.EnabledSources()
	if len(enabled) != 2 {
		t.Errorf("Expected 2 enabled sources, got %d", len(enabled))
	}

	// Test IDs
	ids := sc.IDs()
	if len(ids) != 2 {
		t.Errorf("Expected 2 IDs, got %d", len(ids))
	}
}

// TestSourceCollectionByLabel covers case-insensitive matching, the special
// "null" label, and OR semantics for multiple labels.
func TestSourceCollectionByLabel(t *testing.T) {
	sc := NewSourceCollection()

	sources := map[string]Source{
		"source1": {
			ID:     "source1",
			Labels: []string{"prod", "ml"},
		},
		"source2": {
			ID:     "source2",
			Labels: []string{"dev"},
		},
		"source3": {
			ID:     "source3",
			Labels: nil, // No labels
		},
	}
	_ = sc.Merge("test.yaml", sources)

	// Test finding by label
	result := sc.ByLabel([]string{"prod"})
	if len(result) != 1 {
		t.Errorf("Expected 1 source with 'prod' label, got %d", len(result))
	}

	// Test case insensitive
	result = sc.ByLabel([]string{"PROD"})
	if len(result) != 1 {
		t.Errorf("Expected 1 source with 'PROD' label (case insensitive), got %d", len(result))
	}

	// Test "null" for sources without labels
	result = sc.ByLabel([]string{"null"})
	if len(result) != 1 {
		t.Errorf("Expected 1 source without labels, got %d", len(result))
	}

	// Test multiple labels (OR logic)
	result = sc.ByLabel([]string{"prod", "dev"})
	if len(result) != 2 {
		t.Errorf("Expected 2 sources with 'prod' or 'dev' labels, got %d", len(result))
	}
}

func TestSourceCollectionDynamicOrigin(t *testing.T) {
	// Test that origins can be added dynamically
	sc := NewSourceCollection()

	sources := map[string]Source{
		"source1": {ID: "source1", Type: "yaml"},
	}
	if err := sc.Merge("dynamic.yaml", sources); err != nil {
		t.Fatalf("Dynamic merge failed: %v", err)
	}

	all := sc.AllSources()
	if len(all) != 1 {
		t.Errorf("Expected 1 source after dynamic add, got %d", len(all))
	}
}

// boolPtr is a test helper returning a pointer to the given bool.
func boolPtr(b bool) *bool {
	return &b
}

diff --git a/pkg/catalog/watcher.go b/pkg/catalog/watcher.go
new file mode 100644
index 0000000000..ac855c76d3
--- /dev/null
+++ b/pkg/catalog/watcher.go

package catalog

import (
	"context"
	"os"
	"sync"
	"time"
)

// FileWatcher watches a file for changes using polling.
// This is a simple implementation that can be used if more sophisticated
// file watching (like fsnotify) is not available.
type FileWatcher struct {
	mu          sync.Mutex    // guards lastModTime
	path        string        // file being watched
	lastModTime time.Time     // last observed modification time
	interval    time.Duration // polling period
}

// NewFileWatcher creates a new file watcher.
+func NewFileWatcher(path string, interval time.Duration) *FileWatcher { + if interval == 0 { + interval = 5 * time.Second + } + + w := &FileWatcher{ + path: path, + interval: interval, + } + + if info, err := os.Stat(path); err == nil { + w.lastModTime = info.ModTime() + } + + return w +} + +// Watch returns a channel that receives a value whenever the file changes. +// The channel is closed when the context is canceled. +func (w *FileWatcher) Watch(ctx context.Context) <-chan struct{} { + ch := make(chan struct{}) + + go func() { + defer close(ch) + + ticker := time.NewTicker(w.interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if w.hasChanged() { + select { + case ch <- struct{}{}: + case <-ctx.Done(): + return + } + } + } + } + }() + + return ch +} + +func (w *FileWatcher) hasChanged() bool { + w.mu.Lock() + defer w.mu.Unlock() + + info, err := os.Stat(w.path) + if err != nil { + return false + } + + if info.ModTime().After(w.lastModTime) { + w.lastModTime = info.ModTime() + return true + } + + return false +} diff --git a/scripts/merge_catalog_specs.sh b/scripts/merge_catalog_specs.sh new file mode 100755 index 0000000000..53054205a4 --- /dev/null +++ b/scripts/merge_catalog_specs.sh @@ -0,0 +1,284 @@ +#!/bin/bash + +set -e + +cd "$(dirname "$(readlink -f "$0")")/.." + +if [ -z "$YQ" ]; then + if [ -e "bin/yq" ]; then + YQ="$(realpath "bin/yq")" + else + echo "Error: YQ is not set and bin/yq does not exist" >&2 + exit 1 + fi +fi + +# Temporary files tracked for cleanup +TEMP_FILES=() + +cleanup() { + rm -f "${TEMP_FILES[@]}" 2>/dev/null || true +} +trap cleanup EXIT + +# Register a temporary file for cleanup on exit +register_temp() { + TEMP_FILES+=("$1") +} + +usage() { + echo "Usage: $0 [--check] " + echo " --check: Check for differences in the generated merged catalog specification." 
    echo ""
    echo "This script merges the main catalog API with all plugin APIs to create"
    echo "a unified OpenAPI specification for documentation purposes."
    echo ""
    echo "Example: $0 catalog-spec.yaml"
    # NOTE(review): usage always exits 0, even when reached from the
    # unknown-option error path below — consider a non-zero exit there.
    exit 0
}

# Load common schema names once (pipe-delimited string, e.g. "BaseResource|BaseResourceList|Error")
COMMON_SCHEMAS=""
if [[ -f "api/openapi/src/lib/common.yaml" ]]; then
    COMMON_SCHEMAS=$($YQ eval '.components.schemas | keys | join("|")' api/openapi/src/lib/common.yaml 2>/dev/null || echo "")
fi

# Check if a schema name is in the common schemas list.
is_common_schema() {
    local name="$1"
    # Pipe-delimited containment test; relies on schema names never containing '|'.
    [[ -n "$COMMON_SCHEMAS" && "|${COMMON_SCHEMAS}|" == *"|${name}|"* ]]
}

# Function to preprocess a plugin spec to avoid conflicts
# $1: plugin spec file, $2: plugin name, $3: output (preprocessed) file.
preprocess_plugin_spec() {
    local spec_file="$1"
    local plugin_name="$2"
    local temp_file="$3"

    # Create capitalized prefix for schemas (e.g., mcp -> Mcp_)
    # NOTE(review): ${var^} requires bash >= 4 (not available in macOS /bin/bash 3.2).
    local schema_prefix="${plugin_name^}_"

    # Create lowercase prefix for operation IDs (e.g., mcp -> mcp_)
    local op_prefix="${plugin_name}_"

    # Start with original spec
    cp "$spec_file" "$temp_file"

    # 1. Prefix plugin-specific schema definitions (exclude common schemas)
    local schema_names
    schema_names=$($YQ eval '.components.schemas | keys | .[]' "$temp_file" 2>/dev/null || echo "")

    if [[ -n "$COMMON_SCHEMAS" ]]; then
        # Build a single yq expression to rename all plugin-specific schemas at once
        local yq_expr=""
        while IFS= read -r schema_name; do
            [[ -z "$schema_name" ]] && continue
            if ! is_common_schema "$schema_name"; then
                local new_name="${schema_prefix}${schema_name}"
                if [[ -n "$yq_expr" ]]; then
                    yq_expr="${yq_expr} | "
                fi
                yq_expr="${yq_expr}(.components.schemas[\"${new_name}\"] = .components.schemas[\"${schema_name}\"]) | del(.components.schemas[\"${schema_name}\"])"
            fi
        done <<< "$schema_names"

        if [[ -n "$yq_expr" ]]; then
            if ! $YQ eval -i "$yq_expr" "$temp_file" 2>/dev/null; then
                echo "Warning: failed to rename schemas for plugin $plugin_name" >&2
            fi
        fi
    else
        # Fallback: manually prefix non-BaseResource schemas
        local yq_expr=""
        while IFS= read -r schema_name; do
            if [[ -n "$schema_name" && "$schema_name" != "BaseResource" && "$schema_name" != "BaseResourceList" && "$schema_name" != "BaseResourceDates" ]]; then
                if [[ -n "$yq_expr" ]]; then
                    yq_expr="${yq_expr} | "
                fi
                yq_expr="${yq_expr}(.components.schemas[\"${schema_prefix}${schema_name}\"] = .components.schemas[\"${schema_name}\"]) | del(.components.schemas[\"${schema_name}\"])"
            fi
        done <<< "$schema_names"

        if [[ -n "$yq_expr" ]]; then
            if ! $YQ eval -i "$yq_expr" "$temp_file" 2>/dev/null; then
                echo "Warning: failed to rename schemas for plugin $plugin_name (fallback)" >&2
            fi
        fi
    fi

    # 2. Update all $ref pointers to use prefixed schema names (except for common schemas)
    # NOTE(review): sed -i without a suffix argument is GNU-only; BSD/macOS sed needs `sed -i ''`.
    # First, prefix all schema references throughout the document
    sed -i 's|#/components/schemas/\([A-Za-z][A-Za-z0-9]*\)|#/components/schemas/'"${schema_prefix}"'\1|g' "$temp_file"
    # Then, un-prefix each common schema individually
    if [[ -n "$COMMON_SCHEMAS" ]]; then
        IFS='|' read -ra SCHEMA_ARRAY <<< "$COMMON_SCHEMAS"
        for schema in "${SCHEMA_ARRAY[@]}"; do
            sed -i 's|#/components/schemas/'"${schema_prefix}${schema}"'|#/components/schemas/'"${schema}"'|g' "$temp_file"
        done
    else
        sed -i 's|#/components/schemas/'"${schema_prefix}"'BaseResource|#/components/schemas/BaseResource|g' "$temp_file"
        sed -i 's|#/components/schemas/'"${schema_prefix}"'BaseResourceList|#/components/schemas/BaseResourceList|g' "$temp_file"
        sed -i 's|#/components/schemas/'"${schema_prefix}"'BaseResourceDates|#/components/schemas/BaseResourceDates|g' "$temp_file"
    fi

    # 3. Resolve external references to common schemas before merging
    sed -i 's|lib/common.yaml#/components/schemas/|#/components/schemas/|g' "$temp_file"

    # 4. Prefix operation IDs to ensure uniqueness
    sed -i 's/operationId: \(.*\)/operationId: '"${op_prefix}"'\1/' "$temp_file"

    # 5. Convert relative paths to absolute paths using the plugin's server base URL
    local plugin_base_url
    plugin_base_url=$($YQ eval '.servers[0].url' "$temp_file" 2>/dev/null || echo "")
    if [[ -n "$plugin_base_url" && "$plugin_base_url" != "null" ]]; then
        local yq_path_expr=""
        local paths
        paths=$($YQ eval '.paths | keys | .[]' "$temp_file" 2>/dev/null || echo "")
        while IFS= read -r path; do
            if [[ -n "$path" && "$path" != "null" ]]; then
                local absolute_path="${plugin_base_url}${path}"
                if [[ -n "$yq_path_expr" ]]; then
                    yq_path_expr="${yq_path_expr} | "
                fi
                yq_path_expr="${yq_path_expr}(.paths[\"${absolute_path}\"] = .paths[\"${path}\"]) | del(.paths[\"${path}\"])"
            fi
        done <<< "$paths"

        if [[ -n "$yq_path_expr" ]]; then
            if ! $YQ eval -i "$yq_path_expr" "$temp_file" 2>/dev/null; then
                echo "Warning: failed to rewrite paths for plugin $plugin_name" >&2
            fi
        fi

        # Remove the server configuration since paths are now absolute
        $YQ eval -i 'del(.servers)' "$temp_file" 2>/dev/null || true
    fi

    # 6. Remove the info section to avoid overwriting main catalog's info during merge
    $YQ eval -i 'del(.info)' "$temp_file" 2>/dev/null || true
}

# ---- Argument parsing ----
CHECK=false
BASENAME=""
while [[ $# -gt 0 ]]; do
    case "$1" in
    --check)
        CHECK=true
        shift
        ;;
    -h|--help)
        usage
        ;;
    *)
        # Reject anything that looks like an unrecognized flag.
        if [[ "${1#-}" != "$1" ]]; then
            echo "Unknown option: $1"
            usage
        fi
        # At most one positional argument (the output basename) is allowed.
        if [[ "$BASENAME" != "" ]]; then
            usage
        fi

        BASENAME=$1
        shift
        ;;
    esac
done

if [[ "$BASENAME" == "" ]]; then
    usage
fi

BASENAME=$(basename "$BASENAME")
MAIN_CATALOG="api/openapi/catalog.yaml"

if [[ !
-f "$MAIN_CATALOG" ]]; then
    echo "Main catalog specification not found at $MAIN_CATALOG"
    exit 1
fi

# In --check mode, generate into a throwaway temp file and diff it against
# the committed artifact instead of overwriting it.
OUT_FILE="api/openapi/$BASENAME"
if [[ "$CHECK" == "true" ]]; then
    OUT_FILE="$(mktemp -t modelregistry_catalog_spec_tempXXXXXX).yaml"
    register_temp "$OUT_FILE"
fi

# Auto-discover plugin OpenAPI specifications (exclude src subdirectory)
PLUGIN_SPECS=()
while IFS= read -r spec; do
    PLUGIN_SPECS+=("$spec")
done < <(find catalog/plugins/*/api/openapi -maxdepth 1 -name "openapi.yaml" -type f 2>/dev/null | sort || true)

echo "Merging catalog specifications..."
echo " Main catalog: $MAIN_CATALOG"

# Start with the main catalog specification
cp "$MAIN_CATALOG" "$OUT_FILE"

# Track which plugins we're merging for enhanced description
PLUGIN_NAMES=()

# Process each plugin spec
for plugin_spec in "${PLUGIN_SPECS[@]}"; do
    # Extract plugin name from path (e.g., catalog/plugins/mcp/... -> mcp)
    plugin_name=${plugin_spec#catalog/plugins/}
    plugin_name=${plugin_name%%/*}

    if [[ -z "$plugin_name" ]]; then
        echo "Warning: Could not extract plugin name from $plugin_spec, skipping..."
        continue
    fi

    echo " Plugin: $plugin_name ($plugin_spec)"
    PLUGIN_NAMES+=("$plugin_name")

    # Preprocess the plugin spec to avoid conflicts
    temp_preprocessed="$(mktemp -t "preprocessed_${plugin_name}_XXXXXX").yaml"
    register_temp "$temp_preprocessed"
    preprocess_plugin_spec "$plugin_spec" "$plugin_name" "$temp_preprocessed"

    # Merge the preprocessed plugin spec with the main spec.
    # yq's `*` operator deep-merges maps, so plugin paths/schemas are folded
    # into the accumulated document.
    temp_merged="$(mktemp -t merged_tempXXXXXX).yaml"
    register_temp "$temp_merged"
    $YQ eval-all '. as $item ireduce ({}; . * $item)' "$OUT_FILE" "$temp_preprocessed" > "$temp_merged"
    mv "$temp_merged" "$OUT_FILE"
done

# Update the main spec's info section to reflect the merged content
if [[ ${#PLUGIN_NAMES[@]} -gt 0 ]]; then
    # NOTE(review): plugin names are interpolated directly into the yq
    # expressions below; names containing quotes would break them — assumed
    # safe since names come from repository directory names.
    plugin_list=$(IFS=', '; echo "${PLUGIN_NAMES[*]}")
    $YQ eval -i '.info.description = .info.description + "\n\nThis unified specification includes APIs from the following plugins: '"$plugin_list"'."' "$OUT_FILE"

    # Add a custom extension to track included plugins
    plugins_json="[$(printf '"%s",' "${PLUGIN_NAMES[@]}" | sed 's/,$//')]"
    $YQ eval -i '.["x-catalog-plugins"] = '"$plugins_json" "$OUT_FILE"
fi

# Re-order the keys in the generated file (following merge_openapi.sh pattern)
$YQ eval -i '
    {
        "openapi": .openapi,
        "info": .info,
        "servers": .servers,
        "paths": .paths,
        "components": .components,
        "security": .security,
        "tags": .tags,
        "x-catalog-plugins": .["x-catalog-plugins"]
    } |
    sort_keys(.paths) |
    sort_keys(.components.schemas) |
    sort_keys(.components.responses) |
    sort_keys(.components.parameters)
' "$OUT_FILE"

if [[ "$CHECK" == "true" ]]; then
    # Under `set -e` a non-zero diff aborts here anyway; `exit $?` keeps the
    # intent explicit and propagates diff's status when set -e is relaxed.
    diff -u "api/openapi/$BASENAME" "$OUT_FILE"
    exit $?
fi

echo "Merged catalog specification generated: $OUT_FILE"
if [[ ${#PLUGIN_NAMES[@]} -gt 0 ]]; then
    echo "Included plugins: $(IFS=', '; echo "${PLUGIN_NAMES[*]}")"
else
    echo "No plugin specifications found - using main catalog only"
fi