diff --git a/candle-examples/examples/marian-mt/main.rs b/candle-examples/examples/marian-mt/main.rs index 76445bdb5e..4f9f434e37 100644 --- a/candle-examples/examples/marian-mt/main.rs +++ b/candle-examples/examples/marian-mt/main.rs @@ -34,6 +34,8 @@ enum LanguagePair { EnFr, #[value(name = "en-ru")] EnRu, + #[value(name = "zh-en")] + ZhEn, } // TODO: Maybe add support for the conditional prompt. @@ -81,6 +83,7 @@ pub fn main() -> anyhow::Result<()> { (Which::Base, LanguagePair::EnEs) => marian::Config::opus_mt_en_es(), (Which::Base, LanguagePair::EnFr) => marian::Config::opus_mt_fr_en(), (Which::Base, LanguagePair::EnRu) => marian::Config::opus_mt_en_ru(), + (Which::Base, LanguagePair::ZhEn) => marian::Config::opus_mt_zh_en(), (Which::Big, lp) => anyhow::bail!("big is not supported for language pair {lp:?}"), }; let tokenizer_default_repo = match args.language_pair { @@ -90,6 +93,7 @@ pub fn main() -> anyhow::Result<()> { | LanguagePair::EnEs | LanguagePair::EnFr | LanguagePair::EnRu => "KeighBee/candle-marian", + LanguagePair::ZhEn => "crlf0710/candle-marian", }; let tokenizer = { let tokenizer = match args.tokenizer { @@ -103,6 +107,7 @@ pub fn main() -> anyhow::Result<()> { (Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-en.json", (Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-en.json", (Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-en.json", + (Which::Base, LanguagePair::ZhEn) => "tokenizer-marian-base-zh-en-zh.json", (Which::Big, lp) => { anyhow::bail!("big is not supported for language pair {lp:?}") } @@ -127,6 +132,7 @@ pub fn main() -> anyhow::Result<()> { (Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-es.json", (Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-fr.json", (Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-ru.json", + (Which::Base, LanguagePair::ZhEn) => "tokenizer-marian-base-zh-en-en.json", (Which::Big, lp) => { anyhow::bail!("big is not supported for language pair {lp:?}") } @@ -180,6 +186,11 @@ pub fn main() -> anyhow::Result<()> { hf_hub::RepoType::Model, "refs/pr/7".to_string(), )), + (Which::Base, LanguagePair::ZhEn) => api.repo(hf_hub::Repo::with_revision( + "Helsinki-NLP/opus-mt-zh-en".to_string(), + hf_hub::RepoType::Model, + "refs/pr/12".to_string(), + )), (Which::Big, lp) => { anyhow::bail!("big is not supported for language pair {lp:?}") } diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs index ad57b876e1..cc459e64a6 100644 --- a/candle-transformers/src/models/marian.rs +++ b/candle-transformers/src/models/marian.rs @@ -106,6 +106,30 @@ impl Config { } } + pub fn opus_mt_zh_en() -> Self { + Self { + activation_function: candle_nn::Activation::Swish, + d_model: 512, + decoder_attention_heads: 8, + decoder_ffn_dim: 2048, + decoder_layers: 6, + decoder_start_token_id: 65000, + decoder_vocab_size: Some(65001), + encoder_attention_heads: 8, + encoder_ffn_dim: 2048, + encoder_layers: 6, + eos_token_id: 0, + forced_eos_token_id: 0, + is_encoder_decoder: true, + max_position_embeddings: 512, + pad_token_id: 65000, + scale_embedding: true, + share_encoder_decoder_embeddings: true, + use_cache: true, + vocab_size: 65001, + } + } + pub fn opus_mt_en_hi() -> Self { Self { activation_function: candle_nn::Activation::Swish,