Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
591 changes: 590 additions & 1 deletion Cargo.lock

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ ndarray = "0.16.1"
serde = { version = "1.0.228", features = ["serde_derive"] }
tracing = "0.1.41"
thiserror = "2.0.17"
image-ndarray = "0.1.5"

[workspace.dependencies.bytes]
version = "1.11.1"
features = ["serde"]

[workspace.dependencies.image]
version = "0.25.10"
features = ["serde"]

[workspace.dependencies.parking_lot]
version = "0.12.5"
Expand Down
74 changes: 50 additions & 24 deletions encoderfile-runtime/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
use parking_lot::Mutex;
use std::{
fs::File,
io::{BufReader, Read, Seek},
sync::Arc,
error::Error, fs::File, io::{BufReader, Read, Seek, Error as IOError, ErrorKind}, sync::Arc,
};

use anyhow::Result;
use clap::Parser;
use encoderfile::{
common::{
ModelType,
model_type::{Embedding, SentenceEmbedding, SequenceClassification, TokenClassification},
},
runtime::{EncoderfileLoader, EncoderfileState, load_assets},
transport::cli::Cli,
common::{ModelConfig, model_type::{
Embedding, ImageClassification, ModelType, SentenceEmbedding, SequenceClassification, TokenClassification
}},
runtime::{ClassifierState, EncoderfileLoader, EncoderfileState, FeatureExtractorState, ImageInputState, TextInputState, load_assets},
transport::cli::{TextCli, ImageCli},
};

#[tokio::main]
Expand All @@ -30,49 +27,78 @@ async fn main() -> Result<()> {
}

macro_rules! run_cli {
($model_type:ident, $cli:expr, $config:expr, $session:expr, $tokenizer:expr, $model_config:expr) => {{
($model_type:ident, $cli:expr, $config:expr, $session:expr, $input_state:expr, $task_state:expr) => {{
let state = Arc::new(EncoderfileState::<$model_type>::new(
$config,
$session,
$tokenizer,
$model_config,
$input_state,
$task_state,
));
$cli.command.execute(state).await
}};
}

async fn entrypoint<'a, R: Read + Seek>(loader: &mut EncoderfileLoader<'a, R>) -> Result<()> {
let cli = Cli::parse();
let session = Mutex::new(loader.session()?);
let model_config = loader.model_config()?;
let tokenizer = loader.tokenizer()?;
let config = loader.encoderfile_config()?;
// TODO clear out lifetimes in state and loader to avoid

/// Builds the per-task state shared by all classification model types
/// (sequence, token, and image classification).
///
/// The label maps and label count are cloned/copied out of `model_config`
/// so the returned `ClassifierState` owns its data independently of the
/// loader's config.
fn class_task_state(model_config: &ModelConfig) -> ClassifierState {
    // TODO(review, kept from original): if num_labels is set, make a vector of labels;
    // if id2label is set, make sure its keys cover 0..n-1. Neither check is
    // implemented yet — inconsistent configs pass through unvalidated.
    ClassifierState {
        id2label: model_config.id2label.clone(),
        label2id: model_config.label2id.clone(),
        num_labels: model_config.num_labels,
    }
}

match loader.model_type() {
ModelType::Embedding => run_cli!(Embedding, cli, config, session, tokenizer, model_config),
ModelType::Embedding => run_cli!(
Embedding,
TextCli::parse(),
config,
session,
TextInputState { tokenizer: loader.tokenizer()?, model_config },
FeatureExtractorState {}
),
ModelType::SequenceClassification => run_cli!(
SequenceClassification,
cli,
TextCli::parse(),
config,
session,
tokenizer,
model_config
TextInputState { tokenizer: loader.tokenizer()?, model_config: model_config.clone() },
class_task_state(&model_config)
),
ModelType::TokenClassification => run_cli!(
TokenClassification,
cli,
TextCli::parse(),
config,
session,
tokenizer,
model_config
TextInputState { tokenizer: loader.tokenizer()?, model_config: model_config.clone() },
class_task_state(&model_config)
),
ModelType::SentenceEmbedding => run_cli!(
SentenceEmbedding,
cli,
TextCli::parse(),
config,
session,
TextInputState { tokenizer: loader.tokenizer()?, model_config },
FeatureExtractorState {}
),
ModelType::ImageClassification => run_cli!(
ImageClassification,
ImageCli::parse(),
config,
session,
tokenizer,
model_config
ImageInputState {
height: model_config.height(),
width: model_config.width(),
num_channels: model_config.num_channels().ok_or(IOError::new(ErrorKind::InvalidData, "Missing required configuration field"))?,
image_size: model_config.image_size,
},
class_task_state(&model_config)
),
}
}
10 changes: 10 additions & 0 deletions encoderfile/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,18 @@ workspace = true
[dependencies.serde_json]
workspace = true

[dependencies.bytes]
workspace = true

[dependencies.ndarray]
workspace = true

[dependencies.image]
workspace = true

[dependencies.image-ndarray]
workspace = true

[dependencies.figment]
version = "0.10.19"
features = ["env", "serde_yaml", "yaml"]
Expand Down Expand Up @@ -211,6 +220,7 @@ optional = true

[dependencies.axum]
version = "0.8.6"
features = ["multipart"]
optional = true

[dependencies.axum-server]
Expand Down
8 changes: 4 additions & 4 deletions encoderfile/benches/postprocessing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fn main() {

#[divan::bench(args = [(8, 16, 384), (16, 128, 768), (64, 512, 1024)])]
fn embedding_postprocess(b: Bencher, dim: (usize, usize, usize)) {
let tokenizer = &embedding_state().tokenizer;
let tokenizer = &embedding_state().per_model_input_state.tokenizer;
let (batch, tokens, hidden) = dim;

// Random embeddings
Expand All @@ -35,7 +35,7 @@ fn embedding_postprocess(b: Bencher, dim: (usize, usize, usize)) {
#[divan::bench(args = [8, 16, 64])]
fn sequence_classification_postprocess(b: Bencher, batch: usize) {
let state = sequence_classification_state();
let config = &state.model_config;
let config = &state.per_task_state;
let n_labels = config.id2label.clone().unwrap().len();

let mut rng = rand::rng();
Expand All @@ -51,10 +51,10 @@ fn sequence_classification_postprocess(b: Bencher, batch: usize) {
#[divan::bench(args = [(8, 16), (16, 128), (64, 512)])]
fn token_classification_postprocess(b: Bencher, dim: (usize, usize)) {
let state = token_classification_state();
let config = &state.model_config;
let config = &state.per_task_state;
let n_labels = config.id2label.clone().unwrap().len();

let tokenizer = &embedding_state().tokenizer;
let tokenizer = &embedding_state().per_model_input_state.tokenizer;
let (batch, tokens) = dim;

// Random embeddings
Expand Down
4 changes: 4 additions & 0 deletions encoderfile/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
"proto/sequence_classification.proto",
"proto/token_classification.proto",
"proto/sentence_embedding.proto",
"proto/image_classification.proto",
"proto/manifest.proto",
"proto/image_types.proto",
],
&[
"proto/embedding",
"proto/sequence_classification",
"proto/token_classification",
"proto/sentence_embedding",
"proto/image_classification",
"proto/manifest",
"proto/image_types",
],
)?;

Expand Down
21 changes: 21 additions & 0 deletions encoderfile/proto/image_classification.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
syntax = "proto3";

package encoderfile.image_classification;

// metadata.proto supplies the shared GetModelMetadata request/response;
// image_types.proto supplies ImageInput and ImageLabels used below.
import "proto/metadata.proto";
import "proto/image_types.proto";

// Inference service for image-classification encoderfiles.
service ImageClassificationInference {
  // Classifies a batch of images; one ImageLabels entry per input image.
  rpc Predict(ImageClassificationRequest) returns (ImageClassificationResponse);
  // Returns model id, type, and label map for this encoderfile.
  rpc GetModelMetadata(encoderfile.metadata.GetModelMetadataRequest) returns (encoderfile.metadata.GetModelMetadataResponse);
}

message ImageClassificationRequest {
  // Batch of encoded images to classify.
  repeated encoderfile.image_types.ImageInput inputs = 1;
  // Free-form request metadata (field 11 leaves room for future payload fields).
  map<string, string> metadata = 11;
}

message ImageClassificationResponse {
  // One ImageLabels per request input, in request order.
  repeated encoderfile.image_types.ImageLabels labels = 1;
  // Free-form response metadata, mirroring the request convention.
  map<string, string> metadata = 11;
}
30 changes: 30 additions & 0 deletions encoderfile/proto/image_segmentation.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
syntax = "proto3";

package encoderfile.image_segmentation;

// metadata.proto supplies the shared GetModelMetadata request/response.
// image_types.proto is required for ImageInput and ImageLabelScore referenced
// below; the original imported proto/token.proto (unused here) instead, which
// would fail protoc resolution of encoderfile.image_types.* symbols.
import "proto/metadata.proto";
import "proto/image_types.proto";

// Inference service for image-segmentation encoderfiles.
service ImageSegmentation {
  // Segments a batch of images; one ImageSegments entry per input image.
  rpc Predict(ImageSegmentationRequest) returns (ImageSegmentationResponse);
  // Returns model id, type, and label map for this encoderfile.
  rpc GetModelMetadata(encoderfile.metadata.GetModelMetadataRequest) returns (encoderfile.metadata.GetModelMetadataResponse);
}

message ImageSegmentationRequest {
  // Batch of encoded images to segment.
  repeated encoderfile.image_types.ImageInput images = 1;
  // Free-form request metadata (field 11 matches the other services' convention).
  map<string, string> metadata = 11;
}

// A single segment: its label/score plus a mask.
// NOTE(review): the mask encoding (bitmap? RLE? PNG?) is not specified here —
// document it before clients depend on a particular format.
message ImageSegment {
  encoderfile.image_types.ImageLabelScore label = 1;
  bytes mask = 2;
}

// All segments found in one image.
message ImageSegments {
  repeated ImageSegment segments = 1;
}

message ImageSegmentationResponse {
  // One ImageSegments per request image, in request order.
  repeated ImageSegments segments_batch = 1;
  // Free-form response metadata.
  map<string, string> metadata = 11;
}
16 changes: 16 additions & 0 deletions encoderfile/proto/image_types.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
syntax = "proto3";

// Shared image message types used by the image_* services
// (classification, segmentation, object detection).
package encoderfile.image_types;

message ImageInput {
  // Encoded image bytes. NOTE(review): accepted formats (PNG/JPEG/raw) are not
  // specified here — confirm against the runtime's decoder before relying on one.
  bytes image = 1;
}

message ImageLabelScore {
  // Human-readable label name.
  string label = 1;
  // Confidence score; optional so "label without score" is representable.
  optional float score = 2;
}

message ImageLabels {
  // All labels predicted for a single image.
  repeated ImageLabelScore labels = 1;
}
5 changes: 5 additions & 0 deletions encoderfile/proto/metadata.proto
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ message GetModelMetadataRequest {}

message GetModelMetadataResponse {
string model_id = 1;
// TODO decide if we want a model family/area at a higher level
ModelType model_type = 2;
map<uint32, string> id2label = 3;
}
Expand All @@ -16,4 +17,8 @@ enum ModelType {
SEQUENCE_CLASSIFICATION = 2;
TOKEN_CLASSIFICATION = 3;
SENTENCE_EMBEDDING = 4;

IMAGE_CLASSIFICATION = 21;
// IMAGE_SEGMENTATION = 22;
// OBJECT_DETECTION = 23;
}
33 changes: 33 additions & 0 deletions encoderfile/proto/object_detection.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
syntax = "proto3";

package encoderfile.object_detection;

// metadata.proto supplies the shared GetModelMetadata request/response.
// image_types.proto is required for ImageInput and ImageLabelScore referenced
// below; the original imported proto/token.proto (unused here) instead, which
// would fail protoc resolution of encoderfile.image_types.* symbols.
import "proto/metadata.proto";
import "proto/image_types.proto";

// Inference service for object-detection encoderfiles.
service ObjectDetection {
  // Detects objects in a batch of images; one ImageBoundingBoxes per input.
  rpc Predict(ObjectDetectionRequest) returns (ObjectDetectionResponse);
  // Returns model id, type, and label map for this encoderfile.
  rpc GetModelMetadata(encoderfile.metadata.GetModelMetadataRequest) returns (encoderfile.metadata.GetModelMetadataResponse);
}

message ObjectDetectionRequest {
  // Batch of encoded images to run detection on.
  repeated encoderfile.image_types.ImageInput inputs = 1;
  // Free-form request metadata (field 11 matches the other services' convention).
  map<string, string> metadata = 11;
}

// One detected object: its label/score and an axis-aligned bounding box.
// Fix: proto3 requires `type name = tag;` — the original wrote `xmin int32 = 2;`
// (name before type), which does not parse under protoc.
// NOTE(review): coordinate convention (pixels vs. normalized, origin corner)
// is unspecified — document it before clients depend on one.
message ImageBoundingBox {
  encoderfile.image_types.ImageLabelScore label = 1;
  int32 xmin = 2;
  int32 xmax = 3;
  int32 ymin = 4;
  int32 ymax = 5;
}

// All boxes detected in one image.
message ImageBoundingBoxes {
  repeated ImageBoundingBox box = 1;
}

message ObjectDetectionResponse {
  // One ImageBoundingBoxes per request image, in request order.
  repeated ImageBoundingBoxes boxes = 1;
  // Free-form response metadata.
  map<string, string> metadata = 11;
}
1 change: 0 additions & 1 deletion encoderfile/proto/sentence_embedding.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ syntax = "proto3";

package encoderfile.sentence_embedding;

import "proto/token.proto";
import "proto/metadata.proto";

service SentenceEmbeddingInference {
Expand Down
17 changes: 13 additions & 4 deletions encoderfile/src/builder/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ use crate::{
codec::EncoderfileCodec,
},
generated::manifest::Backend,
runtime::{InputType, Input}
};
use anyhow::{Context, Result};
use ort::session::input;
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
Expand All @@ -27,6 +29,11 @@ pub struct EncoderfileBuilder {
pub config: BuildConfig,
}

/// Validates a build input before packaging.
///
/// NOTE(review): this is currently a stub — it accepts every `Input`
/// unconditionally and ignores its argument. Implement real checks here
/// (or remove the hook) before relying on it; callers get `Ok(())` for
/// any input today.
pub fn validate(input: &Input) -> Result<()> {
    Ok(())
}


impl EncoderfileBuilder {
pub fn new(config: BuildConfig) -> EncoderfileBuilder {
Self { config }
Expand Down Expand Up @@ -90,10 +97,12 @@ impl EncoderfileBuilder {
}

// validate tokenizer
let tokenizer_asset =
crate::builder::tokenizer::validate_tokenizer(&self.config.encoderfile)?;
planned_assets.push(tokenizer_asset);
terminal::success("Tokenizer validated");
if self.config.encoderfile.model_type.input_type() == crate::runtime::Input::Text {
let tokenizer_asset =
crate::builder::tokenizer::validate_tokenizer(&self.config.encoderfile)?;
planned_assets.push(tokenizer_asset);
terminal::success("Tokenizer validated");
}

// initialize final binary
terminal::info("Writing encoderfile...");
Expand Down
Loading