diff --git a/lib/classifier/bayes.rb b/lib/classifier/bayes.rb index 09a9a62..622360d 100644 --- a/lib/classifier/bayes.rb +++ b/lib/classifier/bayes.rb @@ -328,20 +328,14 @@ def remove_category(category) # puts "#{progress.completed} documents processed" # end # - # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void - def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE) - category = category.prepare_category_name - raise StandardError, "No such category: #{category}" unless @categories.key?(category) - - reader = Streaming::LineReader.new(io, batch_size: batch_size) - total = reader.estimate_line_count - progress = Streaming::Progress.new(total: total) - - reader.each_batch do |batch| - train_batch_internal(category, batch) - progress.completed += batch.size - progress.current_batch += 1 - yield progress if block_given? + # @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void + def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) + raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty? + raise ArgumentError, 'Provide both category and io, or use keyword arguments' if [category, io].one?(&:nil?) + + pairs = category && io ? { category => io } : categories + pairs.each do |cat, stream| + stream_train_category(cat, stream, batch_size: batch_size, &) end end @@ -389,6 +383,25 @@ def self.load_checkpoint(storage:, checkpoint_id:) private + # Trains from an IO stream with a single category. + # @rbs (String | Symbol, IO, batch_size: Integer) { (Streaming::Progress) -> void } -> void + def stream_train_category(category, io, batch_size:) + category = category.prepare_category_name + raise ArgumentError, "No such category: #{category}" unless @categories.key?(category) + raise ArgumentError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) + + reader = Streaming::LineReader.new(io, batch_size: batch_size) + total = reader.estimate_line_count + progress = Streaming::Progress.new(total: total) + + reader.each_batch do |batch| + train_batch_internal(category, batch) + progress.completed += batch.size + progress.current_batch += 1 + yield progress if block_given? + end + end + # Trains a batch of documents for a single category. # @rbs (String | Symbol, Array[String], ?batch_size: Integer) { (Streaming::Progress) -> void } -> void def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT_BATCH_SIZE) diff --git a/lib/classifier/knn.rb b/lib/classifier/knn.rb index 37ae817..0c06dfa 100644 --- a/lib/classifier/knn.rb +++ b/lib/classifier/knn.rb @@ -268,9 +268,10 @@ def self.load_checkpoint(storage:, checkpoint_id:) # puts "#{progress.completed} documents processed" # end # - # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void - def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE, &block) - @lsi.train_from_stream(category, io, batch_size: batch_size, &block) + # @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void + def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) + # @type var categories: untype + @lsi.train_from_stream(category, io, batch_size: batch_size, **categories, &) synchronize { @dirty = true } end diff --git a/lib/classifier/logistic_regression.rb b/lib/classifier/logistic_regression.rb index 72de317..0d1819e 100644 --- a/lib/classifier/logistic_regression.rb +++ b/lib/classifier/logistic_regression.rb @@ -390,28 +390,14 @@ def self.load_checkpoint(storage:, checkpoint_id:) # end # classifier.fit # - # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void - def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE) - category = category.to_s.prepare_category_name - raise StandardError, "No such category: #{category}" unless @categories.include?(category) - - reader = Streaming::LineReader.new(io, batch_size: batch_size) - total = reader.estimate_line_count - progress = Streaming::Progress.new(total: total) - - reader.each_batch do |batch| - synchronize do - batch.each do |text| - features = text.word_hash(@min_word_length) - features.each_key { |word| @vocabulary[word] = true } - @training_data << { category: category, features: features } - end - @fitted = false - @dirty = true - end - progress.completed += batch.size - progress.current_batch += 1 - yield progress if block_given? + # @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void + def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) + raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty? + raise ArgumentError, 'Provide both category and io, or use keyword arguments' if [category, io].one?(&:nil?) + + pairs = category && io ? { category => io } : categories + pairs.each do |cat, stream| + stream_train_category(cat, stream, batch_size:, &) end end @@ -440,6 +426,33 @@ def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_ private + # Trains from an IO stream with a single category. + # @rbs (String | Symbol, IO, batch_size: Integer) { (Streaming::Progress) -> void } -> void + def stream_train_category(category, io, batch_size:) + category = category.to_s.prepare_category_name + raise ArgumentError, "No such category: #{category}" unless @categories.include?(category) + raise ArgumentError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) + + reader = Streaming::LineReader.new(io, batch_size: batch_size) + total = reader.estimate_line_count + progress = Streaming::Progress.new(total: total) + + reader.each_batch do |batch| + synchronize do + batch.each do |text| + features = text.word_hash(@min_word_length) + features.each_key { |word| @vocabulary[word] = true } + @training_data << { category: category, features: features } + end + @fitted = false + @dirty = true + end + progress.completed += batch.size + progress.current_batch += 1 + yield progress if block_given? + end + end + # Trains a batch of documents for a single category. # @rbs (String | Symbol, Array[String], ?batch_size: Integer) { (Streaming::Progress) -> void } -> void def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT_BATCH_SIZE) diff --git a/lib/classifier/lsi.rb b/lib/classifier/lsi.rb index 5c27715..3476cd9 100644 --- a/lib/classifier/lsi.rb +++ b/lib/classifier/lsi.rb @@ -662,21 +662,22 @@ def self.load_checkpoint(storage:, checkpoint_id:) # puts "#{progress.completed} documents processed" # end # - # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void - def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE) - original_auto_rebuild = @auto_rebuild - @auto_rebuild = false - + # rubocop:disable Metrics/CyclomaticComplexity, Metrics/PerceivedComplexity + # @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void + def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) + # rubocop:enable Metrics/CyclomaticComplexity, Metrics/PerceivedComplexity + raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty? + raise ArgumentError, 'Provide both category and io, or use keyword arguments' if [category, io].one?(&:nil?) + + pairs = category && io ? { category => io } : categories + pairs.each_value do |io| + raise ArgumentError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) + end begin - reader = Streaming::LineReader.new(io, batch_size: batch_size) - total = reader.estimate_line_count - progress = Streaming::Progress.new(total: total) - - reader.each_batch do |batch| - batch.each { |text| add_item(text, category) } - progress.completed += batch.size - progress.current_batch += 1 - yield progress if block_given? + original_auto_rebuild = @auto_rebuild + @auto_rebuild = false + pairs.each do |cat, stream| + stream_train_category(cat, stream, batch_size:, &) end ensure @auto_rebuild = original_auto_rebuild @@ -729,6 +730,21 @@ def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_ private + # Trains from an IO stream with a single category. + # @rbs (String | Symbol, IO, batch_size: Integer) { (Streaming::Progress) -> void } -> void + def stream_train_category(category, io, batch_size:) + reader = Streaming::LineReader.new(io, batch_size: batch_size) + total = reader.estimate_line_count + progress = Streaming::Progress.new(total: total) + + reader.each_batch do |batch| + batch.each { |text| add_item(text, category) } + progress.completed += batch.size + progress.current_batch += 1 + yield progress if block_given? + end + end + # Restores LSI state from a JSON string (used by reload) # @rbs (String) -> void def restore_from_json(json) diff --git a/lib/classifier/streaming.rb b/lib/classifier/streaming.rb index 3c228b4..7238267 100644 --- a/lib/classifier/streaming.rb +++ b/lib/classifier/streaming.rb @@ -26,8 +26,8 @@ module Streaming # Trains the classifier from an IO stream. # Each line in the stream is treated as a separate document. # - # @rbs (Symbol | String, IO, ?batch_size: Integer) { (Progress) -> void } -> void - def train_from_stream(category, io, batch_size: DEFAULT_BATCH_SIZE, &block) + # @rbs (?(Symbol | String | nil), ?IO?, ?batch_size: Integer, **IO) { (Progress) -> void } -> void + def train_from_stream(category = nil, io = nil, batch_size: DEFAULT_BATCH_SIZE, **categories, &block) raise NotImplementedError, "#{self.class} must implement train_from_stream" end diff --git a/test/bayes/streaming_test.rb b/test/bayes/streaming_test.rb index fd2f311..16b0a75 100644 --- a/test/bayes/streaming_test.rb +++ b/test/bayes/streaming_test.rb @@ -17,6 +17,23 @@ def test_train_from_stream_basic assert_equal 'Spam', @classifier.classify('buy cheap free') end + def test_train_from_stream_many_categories + classifier = Classifier::Bayes.new('Spam', 'Ham') + classifier.train_from_stream( + spam: StringIO.new("buy now cheap\nfree money\nlimited offer\n"), + ham: StringIO.new("hello friend\nmeeting tomorrow\n") + ) + + assert_equal 'Spam', classifier.classify('buy free') + assert_equal 'Ham', classifier.classify('hello meeting') + end + + def test_train_from_stream_invalid_io_type + assert_raises(ArgumentError) do + @classifier.train_from_stream(spam: Object.new) + end + end + def test_train_from_stream_empty_io io = StringIO.new('') @classifier.train_from_stream(:spam, io) diff --git a/test/knn/streaming_test.rb b/test/knn/streaming_test.rb new file mode 100644 index 0000000..f3cdd4d --- /dev/null +++ b/test/knn/streaming_test.rb @@ -0,0 +1,37 @@ +require_relative '../test_helper' +require 'stringio' + +class KNNStreamingTest < Minitest::Test + def test_train_from_stream_basic + knn = Classifier::KNN.new + knn.train_from_stream(:spam, StringIO.new("buy now cheap\nfree money\nlimited offer\n")) + + assert_equal 'spam', knn.classify('buy cheap free') + end + + def test_train_from_stream_many_categories + knn = Classifier::KNN.new + knn.train_from_stream( + spam: StringIO.new("buy now cheap\nfree money\nlimited offer\n"), + ham: StringIO.new("hello friend\nmeeting tomorrow\nhello fellow\n") + ) + + assert_equal 'spam', knn.classify('free offer') + assert_equal 'ham', knn.classify('hello') + end + + def test_train_from_stream_invalid_io_type + knn = Classifier::KNN.new + assert_raises(ArgumentError) { knn.train_from_stream(spam: Object.new) } + end + + def test_train_from_stream_raises_without_args + knn = Classifier::KNN.new + assert_raises(ArgumentError) { knn.train_from_stream } + end + + def test_train_from_stream_raises_with_partial_args + knn = Classifier::KNN.new + assert_raises(ArgumentError) { knn.train_from_stream(:spam) } + end +end diff --git a/test/logistic_regression/streaming_test.rb b/test/logistic_regression/streaming_test.rb new file mode 100644 index 0000000..2fd542b --- /dev/null +++ b/test/logistic_regression/streaming_test.rb @@ -0,0 +1,39 @@ +require_relative '../test_helper' +require 'stringio' + +class LogisticRegressionStreamingTest < Minitest::Test + def test_train_from_stream_basic + classifier = Classifier::LogisticRegression.new('Spam', 'Ham') + classifier.train_from_stream(:spam, StringIO.new("buy now cheap\nfree money\nlimited offer\n")) + classifier.fit + + assert_equal 'Spam', classifier.classify('buy cheap free') + end + + def test_train_from_stream_many_categories + classifier = Classifier::LogisticRegression.new('Spam', 'Ham') + classifier.train_from_stream( + spam: StringIO.new("buy now cheap\nfree money\nlimited offer\n"), + ham: StringIO.new("hello friend\nmeeting tomorrow\n") + ) + classifier.fit + + assert_equal 'Spam', classifier.classify('buy free') + assert_equal 'Ham', classifier.classify('hello meeting') + end + + def test_train_from_stream_invalid_io_type + classifier = Classifier::LogisticRegression.new('Spam', 'Ham') + assert_raises(ArgumentError) { classifier.train_from_stream(spam: Object.new) } + end + + def test_train_from_stream_raises_without_args + classifier = Classifier::LogisticRegression.new('Spam', 'Ham') + assert_raises(ArgumentError) { classifier.train_from_stream } + end + + def test_train_from_stream_raises_with_partial_args + classifier = Classifier::LogisticRegression.new('Spam', 'Ham') + assert_raises(ArgumentError) { classifier.train_from_stream(:spam) } + end +end diff --git a/test/lsi/streaming_test.rb b/test/lsi/streaming_test.rb index 94bfaaf..787e625 100644 --- a/test/lsi/streaming_test.rb +++ b/test/lsi/streaming_test.rb @@ -23,6 +23,39 @@ def test_train_from_stream_basic assert_equal 'dog', result.to_s end + def test_train_from_stream_many_categories + lsi = Classifier::LSI.new + lsi.train_from_stream( + dog: StringIO.new("dogs are loyal pets\npuppies are playful\ndogs bark at strangers\n"), + cat: StringIO.new("cats are independent\nkittens are curious\ncats meow softly\n") + ) + + assert_equal :dog, lsi.classify('loyal pet that barks') + assert_equal :cat, lsi.classify('independent curious pet') + end + + def test_train_from_stream_raises_without_args + assert_raises(ArgumentError) { @lsi.train_from_stream } + end + + def test_train_from_stream_raises_with_partial_args + assert_raises(ArgumentError) { @lsi.train_from_stream(:spam) } + end + + def test_train_from_stream_invalid_io_type + assert_raises(ArgumentError) { @lsi.train_from_stream(category: Object.new) } + end + + def test_train_from_stream_with_invalid_io_type_does_not_modify_auto_rebuild_setting + @lsi = Classifier::LSI.new(auto_rebuild: true) + + assert_raises(ArgumentError) do + @lsi.train_from_stream(cat1: StringIO.new("one\ntwo\n"), cat2: Object.new) + end + + assert @lsi.auto_rebuild + end + def test_train_from_stream_empty_io @lsi.train_from_stream(:category, StringIO.new('')) @@ -82,6 +115,18 @@ def test_train_from_stream_rebuilds_index_when_auto_rebuild refute_predicate @lsi, :needs_rebuild? end + def test_train_from_stream_with_keyword_categories_rebuilds_index_when_auto_rebuild + @lsi = Classifier::LSI.new(auto_rebuild: true) + + @lsi.train_from_stream( + dog: StringIO.new("dogs are loyal\ndogs bark\n"), + cat: StringIO.new("cats are independent\ncats meow\n") + ) + + # Index should be built + refute_predicate @lsi, :needs_rebuild? + end + def test_train_from_stream_skips_rebuild_when_auto_rebuild_false @lsi = Classifier::LSI.new(auto_rebuild: false) @@ -91,6 +136,18 @@ def test_train_from_stream_skips_rebuild_when_auto_rebuild_false assert_predicate @lsi, :needs_rebuild? end + def test_train_from_stream_with_keyword_categories_skips_rebuild_when_auto_rebuild_false + @lsi = Classifier::LSI.new(auto_rebuild: false) + + @lsi.train_from_stream( + cat1: StringIO.new("document one\ndocument two\n"), + cat2: StringIO.new("document three\ndocument four\n") + ) + + # Index should need rebuild + assert_predicate @lsi, :needs_rebuild? + end + def test_train_from_stream_with_file Tempfile.create(['corpus', '.txt']) do |file| file.puts 'dogs are loyal pets'