Skip to content

Commit

Permalink
SBERT - enable batch embedding and fix sorting
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Feb 28, 2023
1 parent 7b618aa commit 9e01a2e
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 26 deletions.
63 changes: 47 additions & 16 deletions orangecontrib/text/tests/test_sbert.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,42 @@
import base64
import json
import unittest
from unittest.mock import patch
from collections.abc import Iterator
import zlib
from unittest.mock import patch, ANY
import asyncio

from orangecontrib.text.vectorization.sbert import SBERT, EMB_DIM
from orangecontrib.text import Corpus

PATCH_METHOD = 'httpx.AsyncClient.post'
RESPONSE = [
f'{{ "embedding": {[i] * EMB_DIM} }}'.encode()
for i in range(9)
]

RESPONSES = {
t: [i] * EMB_DIM for i, t in enumerate(Corpus.from_file("deerwester").documents)
}
RESPONSE_NONE = RESPONSES.copy()
RESPONSE_NONE[list(RESPONSE_NONE.keys())[-1]] = None
IDEAL_RESPONSE = [[i] * EMB_DIM for i in range(9)]


class DummyResponse:

def __init__(self, content):
self.content = content


def make_dummy_post(response, sleep=0):
def _decompress_text(instance):
return zlib.decompress(base64.b64decode(instance.encode("utf-8"))).decode("utf-8")


def make_dummy_post(responses, sleep=0):
@staticmethod
async def dummy_post(url, headers, data=None, content=None):
assert data or content
await asyncio.sleep(sleep)
return DummyResponse(
content=next(response) if isinstance(response, Iterator) else response
)
data = json.loads(content.decode("utf-8", "replace"))
data_ = data if isinstance(data, list) else [data]
texts = [_decompress_text(instance) for instance in data_]
responses_ = [responses[t] for t in texts]
r = {"embedding": responses_ if isinstance(data, list) else responses_[0]}
return DummyResponse(content=json.dumps(r).encode("utf-8"))
return dummy_post


Expand All @@ -51,25 +59,25 @@ def test_empty_corpus(self, mock):
dict()
)

@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
@patch(PATCH_METHOD, make_dummy_post(RESPONSES))
def test_success(self):
result = self.sbert(self.corpus.documents)
self.assertEqual(result, IDEAL_RESPONSE)

@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE[:-1] + [None] * 3)))
@patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
def test_none_result(self):
result = self.sbert(self.corpus.documents)
self.assertEqual(result, IDEAL_RESPONSE[:-1] + [None])

@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
@patch(PATCH_METHOD, make_dummy_post(RESPONSES))
def test_transform(self):
res, skipped = self.sbert.transform(self.corpus)
self.assertIsNone(skipped)
self.assertEqual(len(self.corpus), len(res))
self.assertTupleEqual(self.corpus.domain.metas, res.domain.metas)
self.assertEqual(384, len(res.domain.attributes))

@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE[:-1] + [None] * 3)))
@patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
def test_transform_skipped(self):
res, skipped = self.sbert.transform(self.corpus)
self.assertEqual(len(self.corpus) - 1, len(res))
Expand All @@ -80,6 +88,29 @@ def test_transform_skipped(self):
self.assertTupleEqual(self.corpus.domain.metas, skipped.domain.metas)
self.assertEqual(0, len(skipped.domain.attributes))

@patch(PATCH_METHOD, make_dummy_post(RESPONSES))
def test_batches_success(self):
for i in range(1, 11): # try different batch sizes
result = self.sbert.embed_batches(self.corpus.documents, i)
self.assertEqual(result, IDEAL_RESPONSE)

@patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
def test_batches_none_result(self):
for i in range(1, 11): # try different batch sizes
result = self.sbert.embed_batches(self.corpus.documents, i)
self.assertEqual(result, IDEAL_RESPONSE[:-1] + [None])

@patch("orangecontrib.text.vectorization.sbert._ServerCommunicator.embedd_data")
def test_reordered(self, mock):
"""Test that texts are reordered according to their length"""
self.sbert(self.corpus.documents)
mock.assert_called_with(
tuple(sorted(self.corpus.documents, key=len, reverse=True)), callback=ANY
)

self.sbert([["1", "2"], ["4", "5", "6"], ["0"]])
mock.assert_called_with((["4", "5", "6"], ["1", "2"], ["0"]), callback=ANY)


if __name__ == "__main__":
unittest.main()
70 changes: 60 additions & 10 deletions orangecontrib/text/vectorization/sbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import zlib
import sys
from threading import Thread
from typing import Any, List, Optional, Callable, Tuple
from typing import Any, List, Optional, Callable, Tuple, Union

import numpy as np
from Orange.misc.server_embedder import ServerEmbedderCommunicator
Expand All @@ -29,18 +29,20 @@ def __init__(self) -> None:
)

def __call__(
self, texts: List[str], callback: Callable = dummy_callback
) -> List[Optional[List[float]]]:
self, texts: List[Union[str, List[str]]], callback: Callable = dummy_callback
) -> List[Union[Optional[List[float]], List[Optional[List[float]]]]]:
"""Computes embeddings for given documents.
Parameters
----------
texts
A list of raw texts.
A list of texts or list of text batches (list with text)
Returns
-------
An array of embeddings.
List of embeddings for each document. Each item in the list can be either
list of numbers (embedding) or a None when embedding fails.
When texts is list of batches also responses are returned in batches.
"""
if len(texts) == 0:
return []
Expand All @@ -49,7 +51,7 @@ def __call__(
# at the end and thus add extra time to the complete embedding time
sorted_texts = sorted(
enumerate(texts),
key=lambda x: len(x[1][0]) if x[1] is not None else 0,
key=lambda x: len(x[1]) if x[1] is not None else 0,
reverse=True,
)
indices, sorted_texts = zip(*sorted_texts)
Expand Down Expand Up @@ -111,6 +113,44 @@ def _transform(

return new_corpus, skipped_corpus

def embed_batches(
self,
documents: List[str],
batch_size: int,
*,
callback: Callable = dummy_callback
) -> List[Optional[List[float]]]:
"""
Embed documents by sending batches of documents to the server instead of
sending one document per request. Using this method is suggested when
documents are words or extra short documents. Since they embed fast, the
bottleneck is sending requests to the server, and for those, it is
faster to send them in batches. In the case of documents with at least a
few sentences, the bottleneck is embedding itself. In this case, sending
them in separate requests can speed up embedding since the embedding
process can be more redistributed between workers.
Parameters
----------
documents
List of document that will be sent to the server
batch_size
Number of documents in one batch sent to the server
callback
Callback for reporting the progress
Returns
-------
List of embeddings for each document. Each item in the list can be either
list of numbers (embedding) or a None when embedding fails.
"""
batches = [
documents[ndx : ndx + batch_size]
for ndx in range(0, len(documents), batch_size)
]
embeddings_batches = self(batches)
return [emb for batch in embeddings_batches for emb in batch]

def report(self) -> Tuple[Tuple[str, str], ...]:
"""Reports on current parameters of DocumentEmbedder.
Expand Down Expand Up @@ -164,10 +204,20 @@ def embedd_data(
else:
return asyncio.run(self.embedd_batch(data, callback=callback))

async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
data = base64.b64encode(
zlib.compress(data_instance.encode("utf-8", "replace"), level=-1)
).decode("utf-8", "replace")
async def _encode_data_instance(
self, data_instance: Union[str, List[str]]
) -> Optional[bytes]:
def compress_text(text):
return base64.b64encode(
zlib.compress(text.encode("utf-8", "replace"), level=-1)
).decode("utf-8", "replace")

if isinstance(data_instance, str):
# single document in request
data = compress_text(data_instance)
else:
# request is batch (list of documents)
data = [compress_text(text) for text in data_instance]
if sys.getsizeof(data) > 500000:
# Document in corpus is too large. Size limit is 500 KB
# (after compression). - document skipped
Expand Down

0 comments on commit 9e01a2e

Please sign in to comment.