Sentence Similarity
sentence-transformers
ONNX
Safetensors
Japanese
English
loss:MatryoshkaLoss
loss:MultipleNegativesRankingLoss
Instructions to use hotchpotch/static-embedding-japanese with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use hotchpotch/static-embedding-japanese with sentence-transformers:
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("hotchpotch/static-embedding-japanese")

sentences = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium."
]
embeddings = model.encode(sentences)

similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)
# [3, 3]
```
- Notebooks
- Google Colab
- Kaggle
| # static-embedding-japanese trainer.py | |
| # base: https://huggingface.co/blog/static-embeddings | |
| # MIT License | |
| import logging | |
| import os | |
| import random | |
| from pathlib import Path | |
| from sentence_transformers import ( | |
| SentenceTransformer, | |
| SentenceTransformerModelCardData, | |
| SentenceTransformerTrainer, | |
| SentenceTransformerTrainingArguments, | |
| ) | |
| from sentence_transformers.evaluation import NanoBEIREvaluator | |
| from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss | |
| from sentence_transformers.models.StaticEmbedding import StaticEmbedding | |
| from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers | |
| from transformers import AutoTokenizer | |
| from datasets import Dataset, DatasetDict, load_dataset | |
# Experiment identifier; it is interpolated into the run name built in main().
EXP = "030"
print("EXP:", EXP)

# Directory one level above the directory that contains this script.
# Dataset caches and the final model are stored relative to it.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
print(PROJECT_ROOT)

# English datasets kept for training. Commented-out names are still present in
# the on-disk cache built by _load_train_eval_datasets_en() but are filtered out.
EN_TARGET_DATASETS = [
    # "gooaq", # non-commercial
    "msmarco",
    "squad",
    # "s2orc", # large
    "allnli",
    # "paq", # large
    "trivia_qa",
    # "msmarco_10m",
    "swim_ir",
    # "pubmedqa",
    "miracl",
    # "mldr", # non-commercial
    "mr_tydi",
]

# Japanese datasets; each name is passed as the config name when loading
# "hotchpotch/sentence_transformer_japanese" in load_train_eval_datasets_jp().
JA_TARGET_DATASETS = [
    "hpprc_emb__auto-wiki-nli-triplet",
    "hpprc_emb__auto-wiki-qa",
    "hpprc_emb__auto-wiki-qa-nemotron",
    "hpprc_emb__auto-wiki-qa-pair",
    "hpprc_emb__baobab-wiki-retrieval",
    # "hpprc_emb__jagovfaqs", the test set of a JMTEB task contains the answers
    "hpprc_emb__janli-triplet",
    "hpprc_emb__jaquad",
    "hpprc_emb__jqara",  # same domain as a JMTEB task
    "hpprc_emb__jsnli-triplet",
    "hpprc_emb__jsquad",
    "hpprc_emb__miracl",  # same domain as a JMTEB task
    "hpprc_emb__mkqa",
    "hpprc_emb__mkqa-triplet",
    # "hpprc_emb__mmarco", noisy: contains garbled characters (mojibake) etc.
    "hpprc_emb__mr-tydi",  # same domain as a JMTEB task
    "hpprc_emb__nu-mnli-triplet",
    "hpprc_emb__nu-snli-triplet",
    # "hpprc_emb__paws-x-triplet", possibly included in the test set of a JMTEB task?
    "hpprc_emb__quiz-no-mori",
    "hpprc_emb__quiz-works",
    "hpprc_emb__snow-triplet",
    "hpprc_llmjp-kaken",
    "hpprc_llmjp_warp_html",
    "hpprc_mqa_ja",
    "hpprc_msmarco_ja",
]

# Oversampling factors: main() repeats the training rows of each listed dataset
# this many times before training.
AUG_FACTOR_DATASETS = {
    "hpprc_emb__miracl": 20,
    "hpprc_emb__mr-tydi": 20,
    "hpprc_emb__jqara": 10,
    "hpprc_emb__baobab-wiki-retrieval": 5,
    "hpprc_emb__mkqa": 5,
    "hpprc_emb__auto-wiki-qa-nemotron": 2,
    "hpprc_msmarco_ja": 2,
}

# NOTE(review): presumably set to silence the tokenizers fork/parallelism
# warning when dataloader worker processes are used - confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
# Seed Python's built-in random module.
random.seed(12)
# Declarative description of every English source dataset.
#
# Keys:
#   name       - key used in the resulting DatasetDicts and in log messages
#   path       - Hugging Face Hub dataset id
#   config     - optional dataset config name (omitted -> default config)
#   split      - split expression to load (omitted -> "train")
#   columns    - optional list of columns to keep before splitting
#   test_size  - rows held out for evaluation (omitted -> 10_000)
#   eval_split - if present, the eval set is loaded from this hub split
#                instead of being carved out of the training split
#
# The order matters: it is the load order and the key order of the saved
# DatasetDicts.
_EN_DATASET_SPECS = [
    {"name": "gooaq", "path": "sentence-transformers/gooaq"},
    {
        "name": "msmarco",
        "path": "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
        "config": "triplet",
    },
    {"name": "squad", "path": "sentence-transformers/squad"},
    {
        "name": "s2orc",
        "path": "sentence-transformers/s2orc",
        "config": "title-abstract-pair",
        "split": "train[:100000]",
    },
    {
        "name": "allnli",
        "path": "sentence-transformers/all-nli",
        "config": "triplet",
        "eval_split": "dev",
    },
    {"name": "paq", "path": "sentence-transformers/paq"},
    {
        "name": "trivia_qa",
        "path": "sentence-transformers/trivia-qa",
        "test_size": 5_000,
    },
    {"name": "msmarco_10m", "path": "bclavie/msmarco-10m-triplets"},
    {
        "name": "swim_ir",
        "path": "nthakur/swim-ir-monolingual",
        "config": "en",
        "columns": ["query", "text"],
    },
    # NOTE: 20 negatives
    {
        "name": "pubmedqa",
        "path": "sentence-transformers/pubmedqa",
        "config": "triplet-20",
        "test_size": 100,
    },
    # NOTE: A lot of overlap with anchor/positives
    {
        "name": "miracl",
        "path": "sentence-transformers/miracl",
        "config": "en-triplet-all",
    },
    # NOTE: A lot of overlap with anchor/positives
    {
        "name": "mldr",
        "path": "sentence-transformers/mldr",
        "config": "en-triplet-all",
    },
    # NOTE: A lot of overlap with anchor/positives
    {
        "name": "mr_tydi",
        "path": "sentence-transformers/mr-tydi",
        "config": "en-triplet-all",
    },
]


def _load_en_dataset_pair(spec):
    """Load one entry of ``_EN_DATASET_SPECS`` and return ``(train, eval)``."""
    name = spec["name"]
    config = spec.get("config")

    print(f"Loading {name} dataset...")
    train_ds = load_dataset(spec["path"], config, split=spec.get("split", "train"))
    if "columns" in spec:
        train_ds = train_ds.select_columns(spec["columns"])

    if "eval_split" in spec:
        # The hub dataset ships its own evaluation split; use it untouched.
        eval_ds = load_dataset(spec["path"], config, split=spec["eval_split"])
    else:
        # Fixed seed so the held-out rows are the same on every rebuild.
        parts = train_ds.train_test_split(
            test_size=spec.get("test_size", 10_000), seed=12
        )
        train_ds = parts["train"]
        eval_ds = parts["test"]
    print(f"Loaded {name} dataset.")
    return train_ds, eval_ds


def _load_train_eval_datasets_en():
    """
    Return the English ``(train, eval)`` DatasetDicts, building the cache if needed.

    The datasets are first looked up under ``PROJECT_ROOT / "datasets"``. If
    either cached DatasetDict is missing, every dataset in
    ``_EN_DATASET_SPECS`` is loaded with the datasets library, split into
    train/eval, saved to that cache location, and returned.

    All datasets are loaded and cached regardless of which ones will be used
    for training; selecting a subset is left to the caller.
    """
    en_train_dataset_dir = PROJECT_ROOT / "datasets" / "en_train_dataset"
    en_eval_dataset_dir = PROJECT_ROOT / "datasets" / "en_eval_dataset"
    try:
        train_dataset = DatasetDict.load_from_disk(en_train_dataset_dir)
        eval_dataset = DatasetDict.load_from_disk(en_eval_dataset_dir)
        return train_dataset, eval_dataset
    except FileNotFoundError:
        train_splits = {}
        eval_splits = {}
        for spec in _EN_DATASET_SPECS:
            train_ds, eval_ds = _load_en_dataset_pair(spec)
            train_splits[spec["name"]] = train_ds
            eval_splits[spec["name"]] = eval_ds

        train_dataset = DatasetDict(train_splits)
        eval_dataset = DatasetDict(eval_splits)
        train_dataset.save_to_disk(str(en_train_dataset_dir))
        eval_dataset.save_to_disk(str(en_eval_dataset_dir))
        return train_dataset, eval_dataset
def load_train_eval_datasets_en(target_dataset_names: "list[str] | None" = None):
    """
    Load the English train/eval datasets restricted to the requested names.

    Args:
        target_dataset_names: Names of the English datasets to keep. ``None``
            or an empty list means "no English data".

    Returns:
        A ``(train, eval)`` tuple of DatasetDicts. Both are empty when no
        target names are given; in that case nothing is loaded at all.
    """
    print("Loading train and eval datasets...")
    if not target_dataset_names:
        return DatasetDict(), DatasetDict()

    wanted = set(target_dataset_names)
    train_dataset, eval_dataset = _load_train_eval_datasets_en()
    # Snapshot the keys first: entries are deleted while walking them.
    for ds_name in list(train_dataset.keys()):
        if ds_name not in wanted:
            del train_dataset[ds_name]
            del eval_dataset[ds_name]
            continue
        print(
            "target en ds",
            ds_name,
            len(train_dataset[ds_name]),
            len(eval_dataset[ds_name]),
        )
    return train_dataset, eval_dataset
def load_train_eval_datasets_jp(target_dataset_names: "list[str] | None" = None):
    """
    Load the Japanese train/eval datasets for the requested names.

    Each dataset is cached individually under ``PROJECT_ROOT / "datasets"``.
    When a cached copy is missing, the dataset is loaded as a config of
    ``hotchpotch/sentence_transformer_japanese``, split into train/eval and
    saved to the cache.

    Args:
        target_dataset_names: Config names to load. ``None`` or an empty list
            yields two empty DatasetDicts.

    Returns:
        A ``(train, eval)`` tuple of DatasetDicts keyed by dataset name.
    """
    print("Loading train and eval datasets...")
    jp_train_dataset_dir = PROJECT_ROOT / "datasets" / "jp_train_dataset"
    jp_eval_dataset_dir = PROJECT_ROOT / "datasets" / "jp_eval_dataset"
    train_dataset_dict = {}
    eval_dataset_dict = {}
    for ds_name in target_dataset_names or []:
        print("loading jp ds", ds_name)
        train_path = f"{jp_train_dataset_dir}/{ds_name}"
        eval_path = f"{jp_eval_dataset_dir}/{ds_name}"
        try:
            train_ds = Dataset.load_from_disk(train_path)
            eval_ds = Dataset.load_from_disk(eval_path)
        except FileNotFoundError:
            print(f"{ds_name} not found, loading from datasets library...")
            ds = load_dataset(
                "hotchpotch/sentence_transformer_japanese", ds_name, split="train"
            )
            ds_size = len(ds)
            # Hold out 1% of the rows, capped at 3000. Clamp to at least one
            # row: for datasets with fewer than 100 rows the 1% rule yields 0,
            # which train_test_split does not accept as a test size.
            test_size = max(1, min(3000, ds_size // 100))
            splitted = ds.train_test_split(test_size=test_size, seed=12)
            train_ds = splitted["train"]
            eval_ds = splitted["test"]
            # Cache both halves so the next run takes the fast path above.
            train_ds.save_to_disk(train_path)
            eval_ds.save_to_disk(eval_path)
        train_dataset_dict[ds_name] = train_ds
        eval_dataset_dict[ds_name] = eval_ds
    return DatasetDict(train_dataset_dict), DatasetDict(eval_dataset_dict)
def _repeat_rows(dataset, factor):
    """Return ``dataset`` with its rows repeated so each appears ``factor`` times."""
    columns = dataset.column_names

    def repeat_batch(batch):
        # With batched=True each column arrives as a list, so ``list * int``
        # repeats the whole batch ``factor`` times.
        return {col: batch[col] * factor for col in columns}

    return dataset.map(repeat_batch, batched=True, num_proc=11)


def main():
    """
    Train the static embedding model end to end.

    Builds a StaticEmbedding-based SentenceTransformer, loads and merges the
    English and Japanese datasets, oversamples the datasets listed in
    ``AUG_FACTOR_DATASETS``, trains with MultipleNegativesRankingLoss wrapped
    in MatryoshkaLoss, evaluates with NanoBEIREvaluator before and after
    training, and saves the final model under ``PROJECT_ROOT / "models"``.
    """
    # 1. Load a model to finetune with 2. (Optional) model card data
    print("Loading model...")
    static_embedding = StaticEmbedding(
        AutoTokenizer.from_pretrained("hotchpotch/xlm-roberta-japanese-tokenizer"),
        embedding_dim=1024,
    )
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            language="ja",
            license="mit",
            model_name="Static Embeddings with japanese tokenizer finetuned on various datasets",
        ),
    )

    # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
    print("Loading datasets...")
    train_dataset_en, eval_dataset_en = load_train_eval_datasets_en(EN_TARGET_DATASETS)
    train_dataset_jp, eval_dataset_jp = load_train_eval_datasets_jp(JA_TARGET_DATASETS)

    # merge
    print("Merging datasets...")
    train_dataset = DatasetDict({**train_dataset_en, **train_dataset_jp})
    eval_dataset = DatasetDict({**eval_dataset_en, **eval_dataset_jp})

    # Oversample selected datasets. Only the training split is touched.
    for ds_name, aug_factor in AUG_FACTOR_DATASETS.items():
        if ds_name not in train_dataset:
            # The target lists are edited frequently; an oversampling entry
            # for a dataset that is not loaded must not abort the run.
            print("skip data augmentation (dataset not loaded)", ds_name)
            continue
        before_len = len(train_dataset[ds_name])
        train_dataset[ds_name] = _repeat_rows(train_dataset[ds_name], aug_factor)
        print("data augmented", ds_name, before_len, len(train_dataset[ds_name]))

    for train_ds_name in train_dataset:
        print(
            "train ds",
            train_ds_name,
            len(train_dataset[train_ds_name]),
            len(eval_dataset[train_ds_name]),
        )

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])

    # 5. (Optional) Specify training arguments
    run_name = f"static-retrieval-mrl-jp-v1_{EXP}"
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        # NOTE(review): this path is relative to the current working directory,
        # while the final model below is saved under PROJECT_ROOT - confirm
        # that the difference is intended.
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=2,
        per_device_train_batch_size=2048 * 3,
        # gradient_accumulation_steps=4,
        per_device_eval_batch_size=2048,
        learning_rate=2e-1,
        lr_scheduler_type="cosine",
        # optim="adafactor",
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=200,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=20,
        logging_steps=20,
        logging_first_step=True,
        dataloader_prefetch_factor=4,
        dataloader_num_workers=15,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
    )

    # 6. (Optional) Create an evaluator & evaluate the base model
    evaluator = NanoBEIREvaluator()
    evaluator(model)

    # 7. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # (Optional) Evaluate the trained model on the evaluator after training
    evaluator(model)

    # 8. Save the trained model
    model.save_pretrained(f"{PROJECT_ROOT}/models/{run_name}/final")

    # 9. (Optional) Push it to the Hugging Face Hub
    # model.push_to_hub(run_name, private=True)


if __name__ == "__main__":
    main()