# type: ignore import os import shutil import sys import tempfile import traceback from datetime import datetime, timedelta from typing import Any import httpx import numpy as np import pytest import chromadb import chromadb.server.fastapi from chromadb.api.fastapi import FastAPI from chromadb.api.types import ( Document, EmbeddingFunction, QueryResult, TYPE_KEY, SPARSE_VECTOR_TYPE_VALUE, ) from chromadb.config import Settings from chromadb.errors import ( ChromaError, NotFoundError, InvalidArgumentError, ) from chromadb.utils.embedding_functions import DefaultEmbeddingFunction @pytest.fixture def persist_dir(): return tempfile.mkdtemp() @pytest.fixture def local_persist_api(persist_dir): client = chromadb.Client( Settings( chroma_api_impl="chromadb.api.segment.SegmentAPI", chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB", chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB", chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager", allow_reset=True, is_persistent=True, persist_directory=persist_dir, ), ) yield client client.clear_system_cache() if os.path.exists(persist_dir): shutil.rmtree(persist_dir, ignore_errors=True) # https://docs.pytest.org/en/6.2.x/fixture.html#fixtures-can-be-requested-more-than-once-per-test-return-values-are-cached @pytest.fixture def local_persist_api_cache_bust(persist_dir): client = chromadb.Client( Settings( chroma_api_impl="chromadb.api.segment.SegmentAPI", chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB", chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB", chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager", allow_reset=True, is_persistent=True, persist_directory=persist_dir, ), ) yield client client.clear_system_cache() if os.path.exists(persist_dir): shutil.rmtree(persist_dir, ignore_errors=True) def approx_equal(a, b, tolerance=1e-6) -> bool: return abs(a - b) < tolerance def vector_approx_equal(a, b, tolerance: float = 1e-6) -> bool: if len(a) != len(b): return False return all([approx_equal(a, b, tolerance) for a, b in zip(a, b)]) @pytest.mark.parametrize("api_fixture", [local_persist_api]) def test_persist_index_loading(api_fixture, request): client = request.getfixturevalue("local_persist_api") client.reset() collection = client.create_collection("test") collection.add(ids="id1", documents="hello") api2 = request.getfixturevalue("local_persist_api_cache_bust") collection = api2.get_collection("test") includes = ["embeddings", "documents", "metadatas", "distances"] nn = collection.query( query_texts="hello", n_results=1, include=["embeddings", "documents", "metadatas", "distances"], ) for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 elif key == "included": assert set(nn[key]) == set(includes) else: assert nn[key] is None @pytest.mark.parametrize("api_fixture", [local_persist_api]) def test_persist_index_loading_embedding_function(api_fixture, request): class TestEF(EmbeddingFunction[Document]): def __call__(self, input): return [np.array([1, 2, 3]) for _ in range(len(input))] def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) def name(self) -> str: return "test" def build_from_config(self, config: dict[str, Any]) -> None: pass def get_config(self) -> dict[str, Any]: return {} client = request.getfixturevalue("local_persist_api") client.reset() collection = client.create_collection("test", embedding_function=TestEF()) collection.add(ids="id1", documents="hello") client2 = request.getfixturevalue("local_persist_api_cache_bust") collection = client2.get_collection("test", embedding_function=TestEF()) includes = ["embeddings", "documents", "metadatas", "distances"] nn = collection.query( query_texts="hello", n_results=1, include=includes, ) for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 elif key == "included": assert set(nn[key]) == set(includes) else: assert nn[key] is None @pytest.mark.parametrize("api_fixture", [local_persist_api]) def test_persist_index_get_or_create_embedding_function(api_fixture, request): class TestEF(EmbeddingFunction[Document]): def __call__(self, input): return [np.array([1, 2, 3]) for _ in range(len(input))] def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) def name(self) -> str: return "test" def build_from_config(self, config: dict[str, Any]) -> None: pass def get_config(self) -> dict[str, Any]: return {} api = request.getfixturevalue("local_persist_api") api.reset() collection = api.get_or_create_collection("test", embedding_function=TestEF()) collection.add(ids="id1", documents="hello") api2 = request.getfixturevalue("local_persist_api_cache_bust") collection = api2.get_or_create_collection("test", embedding_function=TestEF()) includes = ["embeddings", "documents", "metadatas", "distances"] nn = collection.query( query_texts="hello", n_results=1, include=includes, ) for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 elif key == "included": assert set(nn[key]) == set(includes) else: assert nn[key] is None assert nn["ids"] == [["id1"]] assert nn["embeddings"][0][0].tolist() == [1, 2, 3] assert nn["documents"] == [["hello"]] assert nn["distances"] == [[0]] @pytest.mark.parametrize("api_fixture", [local_persist_api]) def test_persist(api_fixture, request): client = request.getfixturevalue(api_fixture.__name__) client.reset() collection = client.create_collection("testspace") collection.add(**batch_records) assert collection.count() == 2 client = request.getfixturevalue(api_fixture.__name__) collection = client.get_collection("testspace") assert collection.count() == 2 client.delete_collection("testspace") client = request.getfixturevalue(api_fixture.__name__) assert client.list_collections() == [] def test_heartbeat(client): heartbeat_ns = client.heartbeat() assert isinstance(heartbeat_ns, int) heartbeat_s = heartbeat_ns // 10**9 heartbeat = datetime.fromtimestamp(heartbeat_s) assert heartbeat > datetime.now() - timedelta(seconds=10) def test_max_batch_size(client): batch_size = client.get_max_batch_size() assert batch_size > 0 def test_supports_base64_encoding(client): if not isinstance(client, FastAPI): pytest.skip("Not a FastAPI instance") client.reset() supports_base64_encoding = client.supports_base64_encoding() assert supports_base64_encoding is True def test_supports_base64_encoding_legacy(client): if not isinstance(client, FastAPI): pytest.skip("Not a FastAPI instance") client.reset() # legacy server does not give back supports_base64_encoding client.pre_flight_checks = { "max_batch_size": 100, } assert client.supports_base64_encoding() is False assert client.get_max_batch_size() == 100 def test_pre_flight_checks(client): if not isinstance(client, FastAPI): pytest.skip("Not a FastAPI instance") resp = httpx.get(f"{client._api_url}/pre-flight-checks") assert resp.status_code == 200 assert resp.json() is not None assert "max_batch_size" in resp.json().keys() assert "supports_base64_encoding" in resp.json().keys() batch_records = { "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], "ids": ["https://example.com/1", "https://example.com/2"], } def test_add(client): client.reset() collection = client.create_collection("testspace") collection.add(**batch_records) assert collection.count() == 2 def test_collection_add_with_invalid_collection_throws(client): client.reset() collection = client.create_collection("test") client.delete_collection("test") with pytest.raises(NotFoundError, match=r"Collection .* does not exist"): collection.add(**batch_records) def test_get_or_create(client): client.reset() collection = client.create_collection("testspace") collection.add(**batch_records) assert collection.count() == 2 with pytest.raises(Exception): collection = client.create_collection("testspace") collection = client.get_or_create_collection("testspace") assert collection.count() == 2 minimal_records = { "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], "ids": ["https://example.com/1", "https://example.com/2"], } def test_add_minimal(client): client.reset() collection = client.create_collection("testspace") collection.add(**minimal_records) assert collection.count() == 2 def test_get_from_db(client): client.reset() collection = client.create_collection("testspace") collection.add(**batch_records) includes = ["embeddings", "documents", "metadatas"] records = collection.get(include=includes) for key in records.keys(): if (key in includes) or (key == "ids"): assert len(records[key]) == 2 elif key == "included": assert set(records[key]) == set(includes) else: assert records[key] is None def test_collection_get_with_invalid_collection_throws(client): client.reset() collection = client.create_collection("test") client.delete_collection("test") with pytest.raises(NotFoundError, match=r"Collection .* does not exist"): collection.get() def test_reset_db(client): client.reset() collection = client.create_collection("testspace") collection.add(**batch_records) assert collection.count() == 2 client.reset() assert len(client.list_collections()) == 0 def test_get_nearest_neighbors(client): client.reset() collection = client.create_collection("testspace") collection.add(**batch_records) includes = ["embeddings", "documents", "metadatas", "distances"] nn = collection.query( query_embeddings=[1.1, 2.3, 3.2], n_results=1, include=includes, ) for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 elif key == "included": assert set(nn[key]) == set(includes) else: assert nn[key] is None nn = collection.query( query_embeddings=[[1.1, 2.3, 3.2]], n_results=1, include=includes, ) for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 elif key == "included": assert set(nn[key]) == set(includes) else: assert nn[key] is None nn = collection.query( query_embeddings=[[1.1, 2.3, 3.2], [0.1, 2.3, 4.5]], n_results=1, include=includes, ) for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 2 elif key == "included": assert set(nn[key]) == set(includes) else: assert nn[key] is None def test_delete(client): client.reset() collection = client.create_collection("testspace") collection.add(**batch_records) assert collection.count() == 2 with pytest.raises(Exception): collection.delete() def test_delete_returns_none(client): client.reset() collection = client.create_collection("testspace") collection.add(**batch_records) assert collection.count() == 2 assert collection.delete(ids=batch_records["ids"]) is None def test_delete_with_index(client): client.reset() collection = client.create_collection("testspace") collection.add(**batch_records) assert collection.count() == 2 collection.query(query_embeddings=[[1.1, 2.3, 3.2]], n_results=1) def test_collection_delete_with_invalid_collection_throws(client): client.reset() collection = client.create_collection("test") client.delete_collection("test") with pytest.raises(NotFoundError, match=r"Collection .* does not exist"): collection.delete(ids=["id1"]) def test_count(client): client.reset() collection = client.create_collection("testspace") assert collection.count() == 0 collection.add(**batch_records) assert collection.count() == 2 def test_collection_count_with_invalid_collection_throws(client): client.reset() collection = client.create_collection("test") client.delete_collection("test") with pytest.raises(NotFoundError, match=r"Collection .* does not exist"): collection.count() def test_modify(client): client.reset() collection = client.create_collection("testspace") collection.modify(name="testspace2") # collection name is modify assert collection.name == "testspace2" def test_collection_modify_with_invalid_collection_throws(client): client.reset() collection = client.create_collection("test") client.delete_collection("test") with pytest.raises(NotFoundError, match=r"Collection .* does not exist"): collection.modify(name="test2") def test_modify_error_on_existing_name(client): client.reset() client.create_collection("testspace") c2 = client.create_collection("testspace2") with pytest.raises(Exception): c2.modify(name="testspace") def test_modify_warn_on_DF_change(client, caplog): client.reset() collection = client.create_collection("testspace") with pytest.raises(Exception, match="not supported"): collection.modify(metadata={"hnsw:space": "cosine"}) def test_metadata_cru(client): client.reset() metadata_a = {"a": 1, "b": 2} # Test create metadata collection = client.create_collection("testspace", metadata=metadata_a) assert collection.metadata is not None assert collection.metadata["a"] == 1 assert collection.metadata["b"] == 2 # Test get metadata collection = client.get_collection("testspace") assert collection.metadata is not None assert collection.metadata["a"] == 1 assert collection.metadata["b"] == 2 # Test modify metadata collection.modify(metadata={"a": 2, "c": 3}) assert collection.metadata["a"] == 2 assert collection.metadata["c"] == 3 assert "b" not in collection.metadata # Test get after modify metadata collection = client.get_collection("testspace") assert collection.metadata is not None assert collection.metadata["a"] == 2 assert collection.metadata["c"] == 3 assert "b" not in collection.metadata # Test name exists get_or_create_metadata collection = client.get_or_create_collection("testspace") assert collection.metadata is not None assert collection.metadata["a"] == 2 assert collection.metadata["c"] == 3 # Test name exists create metadata collection = client.get_or_create_collection("testspace2") assert collection.metadata is None # Test list collections collections = client.list_collections() for collection in collections: if collection.name == "testspace": assert collection.metadata is not None assert collection.metadata["a"] == 2 assert collection.metadata["c"] == 3 elif collection.name == "testspace2": assert collection.metadata is None def test_increment_index_on(client): client.reset() collection = client.create_collection("testspace") collection.add(**batch_records) assert collection.count() == 2 includes = ["embeddings", "documents", "metadatas", "distances"] # increment index nn = collection.query( query_embeddings=[[1.1, 2.3, 3.2]], n_results=1, include=includes, ) for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 elif key == "included": assert set(nn[key]) == set(includes) else: assert nn[key] is None def test_add_a_collection(client): client.reset() client.create_collection("testspace") # get collection does not throw an error collection = client.get_collection("testspace") assert collection.name == "testspace" # get collection should throw an error if collection does not exist with pytest.raises(Exception): collection = client.get_collection("testspace2") def test_error_includes_trace_id(http_client): http_client.reset() with pytest.raises(ChromaError) as error: http_client.get_collection("testspace2") assert error.value.trace_id is not None def test_list_collections(client): client.reset() client.create_collection("testspace") client.create_collection("testspace2") # get collection does not throw an error collections = client.list_collections() assert len(collections) == 2 def test_reset(client): client.reset() client.create_collection("testspace") client.create_collection("testspace2") # get collection does not throw an error collections = client.list_collections() assert len(collections) == 2 client.reset() collections = client.list_collections() assert len(collections) == 0 def test_peek(client): client.reset() collection = client.create_collection("testspace") collection.add(**batch_records) assert collection.count() == 2 # peek peek = collection.peek() print(peek) for key in peek.keys(): if key in ["embeddings", "documents", "metadatas"] or key == "ids": assert len(peek[key]) == 2 elif key == "included": assert set(peek[key]) == set(["embeddings", "metadatas", "documents"]) else: assert peek[key] is None def test_collection_peek_with_invalid_collection_throws(client): client.reset() collection = client.create_collection("test") client.delete_collection("test") with pytest.raises(NotFoundError, match=r"Collection .* does not exist"): collection.peek() def test_collection_query_with_invalid_collection_throws(client): client.reset() collection = client.create_collection("test") client.delete_collection("test") with pytest.raises(NotFoundError, match=r"Collection .* does not exist"): collection.query(query_texts=["test"]) def test_collection_update_with_invalid_collection_throws(client): client.reset() collection = client.create_collection("test") client.delete_collection("test") with pytest.raises(NotFoundError, match=r"Collection .* does not exist"): collection.update(ids=["id1"], documents=["test"]) # TEST METADATA AND METADATA FILTERING # region metadata_records = { "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], "ids": ["id1", "id2"], "metadatas": [ {"int_value": 1, "string_value": "one", "float_value": 1.001}, {"int_value": 2}, ], } def test_metadata_add_get_int_float(client): client.reset() collection = client.create_collection("test_int") collection.add(**metadata_records) items = collection.get(ids=["id1", "id2"]) assert items["metadatas"][0]["int_value"] == 1 assert items["metadatas"][0]["float_value"] == 1.001 assert items["metadatas"][1]["int_value"] == 2 assert isinstance(items["metadatas"][0]["int_value"], int) assert isinstance(items["metadatas"][0]["float_value"], float) def test_metadata_add_query_int_float(client): client.reset() collection = client.create_collection("test_int") collection.add(**metadata_records) items: QueryResult = collection.query( query_embeddings=[[1.1, 2.3, 3.2]], n_results=1 ) assert items["metadatas"] is not None assert items["metadatas"][0][0]["int_value"] == 1 assert items["metadatas"][0][0]["float_value"] == 1.001 assert isinstance(items["metadatas"][0][0]["int_value"], int) assert isinstance(items["metadatas"][0][0]["float_value"], float) def test_metadata_get_where_string(client): client.reset() collection = client.create_collection("test_int") collection.add(**metadata_records) items = collection.get(where={"string_value": "one"}) assert items["metadatas"][0]["int_value"] == 1 assert items["metadatas"][0]["string_value"] == "one" def test_metadata_get_where_int(client): client.reset() collection = client.create_collection("test_int") collection.add(**metadata_records) items = collection.get(where={"int_value": 1}) assert items["metadatas"][0]["int_value"] == 1 assert items["metadatas"][0]["string_value"] == "one" def test_metadata_get_where_float(client): client.reset() collection = client.create_collection("test_int") collection.add(**metadata_records) items = collection.get(where={"float_value": 1.001}) assert items["metadatas"][0]["int_value"] == 1 assert items["metadatas"][0]["string_value"] == "one" assert items["metadatas"][0]["float_value"] == 1.001 def test_metadata_update_get_int_float(client): client.reset() collection = client.create_collection("test_int") collection.add(**metadata_records) collection.update( ids=["id1"], metadatas=[{"int_value": 2, "string_value": "two", "float_value": 2.002}], ) items = collection.get(ids=["id1"]) assert items["metadatas"][0]["int_value"] == 2 assert items["metadatas"][0]["string_value"] == "two" assert items["metadatas"][0]["float_value"] == 2.002 bad_metadata_records = { "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], "ids": ["id1", "id2"], "metadatas": [{"value": {"nested": "5"}}, {"value": [1, 2, 3]}], } def test_metadata_validation_add(client): client.reset() collection = client.create_collection("test_metadata_validation") with pytest.raises(ValueError, match="metadata"): collection.add(**bad_metadata_records) def test_metadata_validation_update(client): client.reset() collection = client.create_collection("test_metadata_validation") collection.add(**metadata_records) with pytest.raises(ValueError, match="metadata"): collection.update(ids=["id1"], metadatas={"value": {"nested": "5"}}) def test_where_validation_get(client): client.reset() collection = client.create_collection("test_where_validation") with pytest.raises(ValueError, match="where"): collection.get(where={"value": {"nested": "5"}}) def test_where_validation_query(client): client.reset() collection = client.create_collection("test_where_validation") with pytest.raises(ValueError, match="where"): collection.query(query_embeddings=[0, 0, 0], where={"value": {"nested": "5"}}) operator_records = { "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], "ids": ["id1", "id2"], "metadatas": [ {"int_value": 1, "string_value": "one", "float_value": 1.001}, {"int_value": 2, "float_value": 2.002, "string_value": "two"}, ], } def test_where_lt(client): client.reset() collection = client.create_collection("test_where_lt") collection.add(**operator_records) items = collection.get(where={"int_value": {"$lt": 2}}) assert len(items["metadatas"]) == 1 def test_where_lte(client): client.reset() collection = client.create_collection("test_where_lte") collection.add(**operator_records) items = collection.get(where={"int_value": {"$lte": 2.0}}) assert len(items["metadatas"]) == 2 def test_where_gt(client): client.reset() collection = client.create_collection("test_where_lte") collection.add(**operator_records) items = collection.get(where={"float_value": {"$gt": -1.4}}) assert len(items["metadatas"]) == 2 def test_where_gte(client): client.reset() collection = client.create_collection("test_where_lte") collection.add(**operator_records) items = collection.get(where={"float_value": {"$gte": 2.002}}) assert len(items["metadatas"]) == 1 def test_where_ne_string(client): client.reset() collection = client.create_collection("test_where_lte") collection.add(**operator_records) items = collection.get(where={"string_value": {"$ne": "two"}}) assert len(items["metadatas"]) == 1 def test_where_ne_eq_number(client): client.reset() collection = client.create_collection("test_where_lte") collection.add(**operator_records) items = collection.get(where={"int_value": {"$ne": 1}}) assert len(items["metadatas"]) == 1 items = collection.get(where={"float_value": {"$eq": 2.002}}) assert len(items["metadatas"]) == 1 def test_where_valid_operators(client): client.reset() collection = client.create_collection("test_where_valid_operators") collection.add(**operator_records) with pytest.raises(ValueError): collection.get(where={"int_value": {"$invalid": 2}}) with pytest.raises(ValueError): collection.get(where={"int_value": {"$lt": "2"}}) with pytest.raises(ValueError): collection.get(where={"int_value": {"$lt": 2, "$gt": 1}}) # Test invalid $and, $or with pytest.raises(ValueError): collection.get(where={"$and": {"int_value": {"$lt": 2}}}) with pytest.raises(ValueError): collection.get( where={"int_value": {"$lt": 2}, "$or": {"int_value": {"$gt": 1}}} ) with pytest.raises(ValueError): collection.get( where={"$gt": [{"int_value": {"$lt": 2}}, {"int_value": {"$gt": 1}}]} ) with pytest.raises(ValueError): collection.get(where={"$or": [{"int_value": {"$lt": 2}}]}) with pytest.raises(ValueError): collection.get(where={"$or": []}) with pytest.raises(ValueError): collection.get(where={"a": {"$contains": "test"}}) with pytest.raises(ValueError): collection.get( where={ "$or": [ {"a": {"$contains": "first"}}, # invalid {"$contains": "second"}, # valid ] } ) # TODO: Define the dimensionality of these embeddingds in terms of the default record bad_dimensionality_records = { "embeddings": [[1.1, 2.3, 3.2, 4.5], [1.2, 2.24, 3.2, 4.5]], "ids": ["id1", "id2"], } bad_dimensionality_query = { "query_embeddings": [[1.1, 2.3, 3.2, 4.5], [1.2, 2.24, 3.2, 4.5]], } bad_number_of_results_query = { "query_embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], "n_results": 100, } def test_dimensionality_validation_add(client): client.reset() collection = client.create_collection("test_dimensionality_validation") collection.add(**minimal_records) with pytest.raises(Exception) as e: collection.add(**bad_dimensionality_records) assert "dimension" in str(e.value) def test_dimensionality_validation_query(client): client.reset() collection = client.create_collection("test_dimensionality_validation_query") collection.add(**minimal_records) with pytest.raises(Exception) as e: collection.query(**bad_dimensionality_query) assert "dimension" in str(e.value) def test_query_document_valid_operators(client): client.reset() collection = client.create_collection("test_where_valid_operators") collection.add(**operator_records) with pytest.raises(ValueError, match="where document"): collection.get(where_document={"$lt": {"$nested": 2}}) with pytest.raises(ValueError, match="where document"): collection.query(query_embeddings=[0, 0, 0], where_document={"$contains": 2}) with pytest.raises(ValueError, match="where document"): collection.get(where_document={"$contains": []}) # Test invalid $contains with pytest.raises(ValueError, match="where document"): collection.get(where_document={"$contains": {"text": "hello"}}) # Test invalid $not_contains with pytest.raises(ValueError, match="where document"): collection.get(where_document={"$not_contains": {"text": "hello"}}) # Test invalid $and, $or with pytest.raises(ValueError): collection.get(where_document={"$and": {"$unsupported": "doc"}}) with pytest.raises(ValueError): collection.get( where_document={"$or": [{"$unsupported": "doc"}, {"$unsupported": "doc"}]} ) with pytest.raises(ValueError): collection.get(where_document={"$or": [{"$contains": "doc"}]}) with pytest.raises(ValueError): collection.get(where_document={"$or": []}) with pytest.raises(ValueError): collection.get( where_document={ "$or": [{"$and": [{"$contains": "doc"}]}, {"$contains": "doc"}] } ) contains_records = { "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], "documents": ["this is doc1 and it's great!", "doc2 is also great!"], "ids": ["id1", "id2"], "metadatas": [ {"int_value": 1, "string_value": "one", "float_value": 1.001}, {"int_value": 2, "float_value": 2.002, "string_value": "two"}, ], } def test_get_where_document(client): client.reset() collection = client.create_collection("test_get_where_document") collection.add(**contains_records) items = collection.get(where_document={"$contains": "doc1"}) assert len(items["metadatas"]) == 1 items = collection.get(where_document={"$contains": "great"}) assert len(items["metadatas"]) == 2 items = collection.get(where_document={"$contains": "bad"}) assert len(items["metadatas"]) == 0 def test_query_where_document(client): client.reset() collection = client.create_collection("test_query_where_document") collection.add(**contains_records) items = collection.query( query_embeddings=[1, 0, 0], where_document={"$contains": "doc1"}, n_results=1 ) assert len(items["metadatas"][0]) == 1 items = collection.query( query_embeddings=[0, 0, 0], where_document={"$contains": "great"}, n_results=2 ) assert len(items["metadatas"][0]) == 2 with pytest.raises(Exception) as e: items = collection.query( query_embeddings=[0, 0, 0], where_document={"$contains": "bad"}, n_results=1 ) assert "datapoints" in str(e.value) def test_delete_where_document(client): client.reset() collection = client.create_collection("test_delete_where_document") collection.add(**contains_records) collection.delete(where_document={"$contains": "doc1"}) assert collection.count() == 1 collection.delete(where_document={"$contains": "bad"}) assert collection.count() == 1 collection.delete(where_document={"$contains": "great"}) assert collection.count() == 0 logical_operator_records = { "embeddings": [ [1.1, 2.3, 3.2], [1.2, 2.24, 3.2], [1.3, 2.25, 3.2], [1.4, 2.26, 3.2], ], "ids": ["id1", "id2", "id3", "id4"], "metadatas": [ {"int_value": 1, "string_value": "one", "float_value": 1.001, "is": "doc"}, {"int_value": 2, "float_value": 2.002, "string_value": "two", "is": "doc"}, {"int_value": 3, "float_value": 3.003, "string_value": "three", "is": "doc"}, {"int_value": 4, "float_value": 4.004, "string_value": "four", "is": "doc"}, ], "documents": [ "this document is first and great", "this document is second and great", "this document is third and great", "this document is fourth and great", ], } def test_where_logical_operators(client): client.reset() collection = client.create_collection("test_logical_operators") collection.add(**logical_operator_records) items = collection.get( where={ "$and": [ {"$or": [{"int_value": {"$gte": 3}}, {"float_value": {"$lt": 1.9}}]}, {"is": "doc"}, ] } ) assert len(items["metadatas"]) == 3 items = collection.get( where={ "$or": [ { "$and": [ {"int_value": {"$eq": 3}}, {"string_value": {"$eq": "three"}}, ] }, { "$and": [ {"int_value": {"$eq": 4}}, {"string_value": {"$eq": "four"}}, ] }, ] } ) assert len(items["metadatas"]) == 2 items = collection.get( where={ "$and": [ { "$or": [ {"int_value": {"$eq": 1}}, {"string_value": {"$eq": "two"}}, ] }, { "$or": [ {"int_value": {"$eq": 2}}, {"string_value": {"$eq": "one"}}, ] }, ] } ) assert len(items["metadatas"]) == 2 def test_where_document_logical_operators(client): client.reset() collection = client.create_collection("test_document_logical_operators") collection.add(**logical_operator_records) items = collection.get( where_document={ "$and": [ {"$contains": "first"}, {"$contains": "doc"}, ] } ) assert len(items["metadatas"]) == 1 items = collection.get( where_document={ "$or": [ {"$contains": "first"}, {"$contains": "second"}, ] } ) assert len(items["metadatas"]) == 2 items = collection.get( where_document={ "$or": [ {"$contains": "first"}, {"$contains": "second"}, ] }, where={ "int_value": {"$ne": 2}, }, ) assert len(items["metadatas"]) == 1 # endregion records = { "embeddings": [[0, 0, 0], [1.2, 2.24, 3.2]], "ids": ["id1", "id2"], "metadatas": [ {"int_value": 1, "string_value": "one", "float_value": 1.001}, {"int_value": 2}, ], "documents": ["this document is first", "this document is second"], } def test_query_include(client): client.reset() collection = client.create_collection("test_query_include") collection.add(**records) include = ["metadatas", "documents", "distances"] items = collection.query( query_embeddings=[0, 0, 0], include=include, n_results=1, ) assert items["embeddings"] is None assert items["ids"][0][0] == "id1" assert items["metadatas"][0][0]["int_value"] == 1 assert set(items["included"]) == set(include) include = ["embeddings", "documents", "distances"] items = collection.query( query_embeddings=[0, 0, 0], include=include, n_results=1, ) assert items["metadatas"] is None assert items["ids"][0][0] == "id1" assert set(items["included"]) == set(include) items = collection.query( query_embeddings=[[0, 0, 0], [1, 2, 1.2]], include=[], n_results=2, ) assert items["documents"] is None assert items["metadatas"] is None assert items["embeddings"] is None assert items["distances"] is None assert items["ids"][0][0] == "id1" assert items["ids"][0][1] == "id2" def test_get_include(client): client.reset() collection = client.create_collection("test_get_include") collection.add(**records) include = ["metadatas", "documents"] items = collection.get(include=include, where={"int_value": 1}) assert items["embeddings"] is None assert items["ids"][0] == "id1" assert items["metadatas"][0]["int_value"] == 1 assert items["documents"][0] == "this document is first" assert set(items["included"]) == set(include) include = ["embeddings", "documents"] items = collection.get(include=include) assert items["metadatas"] is None assert items["ids"][0] == "id1" assert approx_equal(items["embeddings"][1][0], 1.2) assert set(items["included"]) == set(include) items = collection.get(include=[]) assert items["documents"] is None assert items["metadatas"] is None assert items["embeddings"] is None assert items["ids"][0] == "id1" assert items["included"] == [] with pytest.raises(ValueError, match="include"): items = collection.get(include=["metadatas", "undefined"]) with pytest.raises(ValueError, match="include"): items = collection.get(include=None) # make sure query results are returned in the right order def test_query_order(client): client.reset() collection = client.create_collection("test_query_order") collection.add(**records) items = collection.query( query_embeddings=[1.2, 2.24, 3.2], include=["metadatas", "documents", "distances"], n_results=2, ) assert items["documents"][0][0] == "this document is second" assert items["documents"][0][1] == "this document is first" # test to make sure add, get, delete error on invalid id input def test_invalid_id(client): client.reset() collection = client.create_collection("test_invalid_id") # Add with non-string id with pytest.raises(ValueError) as e: collection.add(embeddings=[0, 0, 0], ids=[1], metadatas=[{}]) assert "ID" in str(e.value) # Get with non-list id with pytest.raises(ValueError) as e: collection.get(ids=1) assert "ID" in str(e.value) # Delete with malformed ids with pytest.raises(ValueError) as e: collection.delete(ids=["valid", 0]) assert "ID" in str(e.value) def test_index_params(client): EPS = 1e-12 # first standard add client.reset() collection = client.create_collection(name="test_index_params") collection.add(**records) items = collection.query( query_embeddings=[0.6, 1.12, 1.6], n_results=1, ) assert items["distances"][0][0] > 4 # cosine client.reset() collection = client.create_collection( name="test_index_params", metadata={"hnsw:space": "cosine", "hnsw:construction_ef": 20, "hnsw:M": 5}, ) collection.add(**records) items = collection.query( query_embeddings=[0.6, 1.12, 1.6], n_results=1, ) assert items["distances"][0][0] > 0 - EPS assert items["distances"][0][0] < 1 + EPS # ip client.reset() collection = client.create_collection( name="test_index_params", metadata={"hnsw:space": "ip"} ) collection.add(**records) items = collection.query( query_embeddings=[0.6, 1.12, 1.6], n_results=1, ) assert items["distances"][0][0] < -5 def test_invalid_index_params(client): client.reset() with pytest.raises(InvalidArgumentError): collection = client.create_collection( name="test_index_params", metadata={"hnsw:space": "foobar"} ) collection.add(**records) def test_persist_index_loading_params(client, request): client = request.getfixturevalue("local_persist_api") client.reset() collection = client.create_collection( "test", metadata={"hnsw:space": "ip"}, ) collection.add(ids="id1", documents="hello") api2 = request.getfixturevalue("local_persist_api_cache_bust") collection = api2.get_collection( "test", ) assert collection.metadata["hnsw:space"] == "ip" includes = ["embeddings", "documents", "metadatas", "distances"] nn = collection.query( query_texts="hello", n_results=1, include=includes, ) for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 elif key == "included": assert set(nn[key]) == set(includes) else: assert nn[key] is None def test_add_large(client): client.reset() collection = client.create_collection("testspace") # Test adding a large number of records large_records = np.random.rand(2000, 512).astype(np.float32).tolist() collection.add( embeddings=large_records, ids=[f"http://example.com/{i}" for i in range(len(large_records))], ) assert collection.count() == len(large_records) # test get_version def test_get_version(client): client.reset() version = client.get_version() # assert version matches the pattern x.y.z import re assert re.match(r"\d+\.\d+\.\d+", version) # test delete_collection def test_delete_collection(client): client.reset() collection = client.create_collection("test_delete_collection") collection.add(**records) assert len(client.list_collections()) == 1 client.delete_collection("test_delete_collection") assert len(client.list_collections()) == 0 # test default embedding function def test_default_embedding(): embedding_function = DefaultEmbeddingFunction() docs = ["this is a test" for _ in range(64)] embeddings = embedding_function(docs) assert len(embeddings) == 64 def test_multiple_collections(client): embeddings1 = np.random.rand(10, 512).astype(np.float32).tolist() embeddings2 = np.random.rand(10, 512).astype(np.float32).tolist() ids1 = [f"http://example.com/1/{i}" for i in range(len(embeddings1))] ids2 = [f"http://example.com/2/{i}" for i in range(len(embeddings2))] client.reset() coll1 = client.create_collection("coll1") coll1.add(embeddings=embeddings1, ids=ids1) coll2 = client.create_collection("coll2") coll2.add(embeddings=embeddings2, ids=ids2) assert len(client.list_collections()) == 2 assert coll1.count() == len(embeddings1) assert coll2.count() == len(embeddings2) results1 = coll1.query(query_embeddings=embeddings1[0], n_results=1) results2 = coll2.query(query_embeddings=embeddings2[0], n_results=1) # progressively check the results are what we expect so we can debug when/if flakes happen assert len(results1["ids"]) > 0 assert len(results2["ids"]) > 0 assert len(results1["ids"][0]) > 0 assert len(results2["ids"][0]) > 0 assert results1["ids"][0][0] == ids1[0] assert results2["ids"][0][0] == ids2[0] def test_update_query(client): client.reset() collection = client.create_collection("test_update_query") collection.add(**records) updated_records = { "ids": [records["ids"][0]], "embeddings": [[0.1, 0.2, 0.3]], "documents": ["updated document"], "metadatas": [{"foo": "bar"}], } collection.update(**updated_records) # test query results = collection.query( query_embeddings=updated_records["embeddings"], n_results=1, include=["embeddings", "documents", "metadatas"], ) assert len(results["ids"][0]) == 1 assert results["ids"][0][0] == updated_records["ids"][0] assert results["documents"][0][0] == updated_records["documents"][0] assert results["metadatas"][0][0]["foo"] == "bar" assert vector_approx_equal( results["embeddings"][0][0], updated_records["embeddings"][0] ) def test_get_nearest_neighbors_where_n_results_more_than_element(client): client.reset() collection = client.create_collection("testspace") collection.add(**records) includes = ["embeddings", "documents", "metadatas", "distances"] results = collection.query( query_embeddings=[[1.1, 2.3, 3.2]], n_results=5, include=includes, ) for key in results.keys(): if key in includes or key == "ids": assert len(results[key][0]) == 2 elif key == "included": assert set(results[key]) == set(includes) else: assert results[key] is None def test_invalid_n_results_param(client): client.reset() collection = client.create_collection("testspace") collection.add(**records) with pytest.raises(TypeError) as exc: collection.query( query_embeddings=[[1.1, 2.3, 3.2]], n_results=-1, include=["embeddings", "documents", "metadatas", "distances"], ) assert "Number of requested results -1, cannot be negative, or zero." in str( exc.value ) assert exc.type == TypeError with pytest.raises(ValueError) as exc: collection.query( query_embeddings=[[1.1, 2.3, 3.2]], n_results="one", include=["embeddings", "documents", "metadatas", "distances"], ) assert "int" in str(exc.value) assert exc.type == ValueError initial_records = { "embeddings": [[0, 0, 0], [1.2, 2.24, 3.2], [2.2, 3.24, 4.2]], "ids": ["id1", "id2", "id3"], "metadatas": [ {"int_value": 1, "string_value": "one", "float_value": 1.001}, {"int_value": 2}, {"string_value": "three"}, ], "documents": [ "this document is first", "this document is second", "this document is third", ], } new_records = { "embeddings": [[3.0, 3.0, 1.1], [3.2, 4.24, 5.2]], "ids": ["id1", "id4"], "metadatas": [ {"int_value": 1, "string_value": "one_of_one", "float_value": 1.001}, {"int_value": 4}, ], "documents": [ "this document is even more first", "this document is new and fourth", ], } def test_upsert(client): client.reset() collection = client.create_collection("test") collection.add(**initial_records) assert collection.count() == 3 collection.upsert(**new_records) assert collection.count() == 4 get_result = collection.get( include=["embeddings", "metadatas", "documents"], ids=new_records["ids"][0] ) assert vector_approx_equal( get_result["embeddings"][0], new_records["embeddings"][0] ) assert get_result["metadatas"][0] == new_records["metadatas"][0] assert get_result["documents"][0] == new_records["documents"][0] query_result = collection.query( query_embeddings=get_result["embeddings"], n_results=1, include=["embeddings", "metadatas", "documents"], ) assert vector_approx_equal( query_result["embeddings"][0][0], new_records["embeddings"][0] ) assert query_result["metadatas"][0][0] == new_records["metadatas"][0] assert query_result["documents"][0][0] == new_records["documents"][0] collection.delete(ids=initial_records["ids"][2]) collection.upsert( ids=initial_records["ids"][2], embeddings=[[1.1, 0.99, 2.21]], metadatas=[{"string_value": "a new string value"}], ) assert collection.count() == 4 get_result = collection.get( include=["embeddings", "metadatas", "documents"], ids=["id3"] ) assert vector_approx_equal(get_result["embeddings"][0], [1.1, 0.99, 2.21]) assert get_result["metadatas"][0] == {"string_value": "a new string value"} assert get_result["documents"][0] is None def test_collection_upsert_with_invalid_collection_throws(client): client.reset() collection = client.create_collection("test") client.delete_collection("test") with pytest.raises(NotFoundError, match=r"Collection .* does not exist"): collection.upsert(**initial_records) # test to make sure add, query, update, upsert error on invalid embeddings input def test_invalid_embeddings(client): client.reset() collection = client.create_collection("test_invalid_embeddings") # Add with string embeddings invalid_records = { "embeddings": [["0", "0", "0"], ["1.2", "2.24", "3.2"]], "ids": ["id1", "id2"], } with pytest.raises(ValueError) as e: collection.add(**invalid_records) assert "embedding" in str(e.value) # Query with invalid embeddings with pytest.raises(ValueError) as e: collection.query( query_embeddings=[["1.1", "2.3", "3.2"]], n_results=1, ) assert "embedding" in str(e.value) # Update with invalid embeddings invalid_records = { "embeddings": [[[0], [0], [0]], [[1.2], [2.24], [3.2]]], "ids": ["id1", "id2"], } with pytest.raises(ValueError) as e: collection.update(**invalid_records) assert "embedding" in str(e.value) # Upsert with invalid embeddings invalid_records = { "embeddings": [[[1.1, 2.3, 3.2]], [[1.2, 2.24, 3.2]]], "ids": ["id1", "id2"], } with pytest.raises(ValueError) as e: collection.upsert(**invalid_records) assert "embedding" in str(e.value) # test to make sure update shows exception for bad dimensionality def test_dimensionality_exception_update(client): client.reset() collection = client.create_collection("test_dimensionality_update_exception") collection.add(**minimal_records) with pytest.raises(Exception) as e: collection.update(**bad_dimensionality_records) assert "dimension" in str(e.value) # test to make sure upsert shows exception for bad dimensionality def test_dimensionality_exception_upsert(client): client.reset() collection = client.create_collection("test_dimensionality_upsert_exception") collection.add(**minimal_records) with pytest.raises(Exception) as e: collection.upsert(**bad_dimensionality_records) assert "dimension" in str(e.value) # this may be flaky on windows, so we rerun it @pytest.mark.flaky(reruns=3, condition=sys.platform.startswith("win32")) def test_ssl_self_signed(client_ssl): if os.environ.get("CHROMA_INTEGRATION_TEST_ONLY"): pytest.skip("Skipping test for integration test") client_ssl.heartbeat() # this may be flaky on windows, so we rerun it @pytest.mark.flaky(reruns=3, condition=sys.platform.startswith("win32")) def test_ssl_self_signed_without_ssl_verify(client_ssl): if os.environ.get("CHROMA_INTEGRATION_TEST_ONLY"): pytest.skip("Skipping test for integration test") client_ssl.heartbeat() _port = client_ssl._server._settings.chroma_server_http_port with pytest.raises(ValueError) as e: chromadb.HttpClient(ssl=True, port=_port) stack_trace = traceback.format_exception( type(e.value), e.value, e.value.__traceback__ ) client_ssl.clear_system_cache() assert "CERTIFICATE_VERIFY_FAILED" in "".join(stack_trace) def test_query_id_filtering_small_dataset(client): client.reset() collection = client.create_collection("test_query_id_filtering_small") num_vectors = 100 dim = 512 small_records = np.random.rand(100, 512).astype(np.float32).tolist() ids = [f"{i}" for i in range(num_vectors)] collection.add( embeddings=small_records, ids=ids, ) query_ids = [f"{i}" for i in range(0, num_vectors, 10)] query_embedding = np.random.rand(dim).astype(np.float32).tolist() results = collection.query( query_embeddings=query_embedding, ids=query_ids, n_results=num_vectors, include=[], ) all_returned_ids = [item for sublist in results["ids"] for item in sublist] assert all(id in query_ids for id in all_returned_ids) def test_query_id_filtering_medium_dataset(client): client.reset() collection = client.create_collection("test_query_id_filtering_medium") num_vectors = 1000 dim = 512 medium_records = np.random.rand(num_vectors, dim).astype(np.float32).tolist() ids = [f"{i}" for i in range(num_vectors)] collection.add( embeddings=medium_records, ids=ids, ) query_ids = [f"{i}" for i in range(0, num_vectors, 10)] query_embedding = np.random.rand(dim).astype(np.float32).tolist() results = collection.query( query_embeddings=query_embedding, ids=query_ids, n_results=num_vectors, include=[], ) all_returned_ids = [item for sublist in results["ids"] for item in sublist] assert all(id in query_ids for id in all_returned_ids) multi_query_embeddings = [ np.random.rand(dim).astype(np.float32).tolist() for _ in range(3) ] multi_results = collection.query( query_embeddings=multi_query_embeddings, ids=query_ids, n_results=10, include=[], ) for result_set in multi_results["ids"]: assert all(id in query_ids for id in result_set) def test_query_id_filtering_e2e(client): client.reset() collection = client.create_collection("test_query_id_filtering_e2e") dim = 512 num_vectors = 100 embeddings = np.random.rand(num_vectors, dim).astype(np.float32).tolist() ids = [f"{i}" for i in range(num_vectors)] metadatas = [{"index": i} for i in range(num_vectors)] collection.add( embeddings=embeddings, ids=ids, metadatas=metadatas, ) ids_to_delete = [f"{i}" for i in range(10, 30)] collection.delete(ids=ids_to_delete) # modify some existing ids, and add some new ones to check query returns updated metadata ids_to_upsert_existing = [f"{i}" for i in range(30, 50)] new_num_vectors = num_vectors + 20 ids_to_upsert_new = [f"{i}" for i in range(num_vectors, new_num_vectors)] upsert_embeddings = ( np.random.rand(len(ids_to_upsert_existing) + len(ids_to_upsert_new), dim) .astype(np.float32) .tolist() ) upsert_metadatas = [ {"index": i, "upserted": True} for i in range(len(upsert_embeddings)) ] collection.upsert( embeddings=upsert_embeddings, ids=ids_to_upsert_existing + ids_to_upsert_new, metadatas=upsert_metadatas, ) valid_query_ids = ( [f"{i}" for i in range(5, 10)] # subset of existing ids + [f"{i}" for i in range(35, 45)] # subset of existing, but upserted + [ f"{i}" for i in range(num_vectors + 5, num_vectors + 15) ] # subset of new upserted ids ) includes = ["metadatas"] query_embedding = np.random.rand(dim).astype(np.float32).tolist() results = collection.query( query_embeddings=query_embedding, ids=valid_query_ids, n_results=new_num_vectors, include=includes, ) all_returned_ids = [item for sublist in results["ids"] for item in sublist] assert all(id in valid_query_ids for id in all_returned_ids) for result_index, id_list in enumerate(results["ids"]): for item_index, item_id in enumerate(id_list): if item_id in ids_to_upsert_existing or item_id in ids_to_upsert_new: # checks if metadata correctly has upserted flag assert results["metadatas"][result_index][item_index]["upserted"] upserted_id = ids_to_upsert_existing[0] # test single id filtering results = collection.query( query_embeddings=query_embedding, ids=upserted_id, n_results=1, include=includes, ) assert results["metadatas"][0][0]["upserted"] deleted_id = ids_to_delete[0] # test deleted id filter raises with pytest.raises(Exception) as error: collection.query( query_embeddings=query_embedding, ids=deleted_id, n_results=1, include=includes, ) assert "Error finding id" in str(error.value) def test_validate_sparse_vector(): """Test SparseVector validation in __post_init__.""" from chromadb.base_types import SparseVector # Test 1: Valid sparse vector - should not raise SparseVector(indices=[0, 2, 5], values=[0.1, 0.5, 0.9]) # Test 2: Valid sparse vector with empty lists - should not raise SparseVector(indices=[], values=[]) # Test 4: Invalid - indices not a list with pytest.raises(ValueError, match="Expected SparseVector indices to be a list"): SparseVector(indices="not_a_list", values=[0.1, 0.2]) # type: ignore # Test 5: Invalid - values not a list with pytest.raises(ValueError, match="Expected SparseVector values to be a list"): SparseVector(indices=[0, 1], values="not_a_list") # type: ignore # Test 6: Invalid - mismatched lengths with pytest.raises( ValueError, match="indices and values must have the same length" ): SparseVector(indices=[0, 1, 2], values=[0.1, 0.2]) # Test 7: Invalid - non-integer index with pytest.raises(ValueError, match="SparseVector indices must be integers"): SparseVector(indices=[0, "not_int", 2], values=[0.1, 0.2, 0.3]) # type: ignore # Test 8: Invalid - negative index with pytest.raises(ValueError, match="SparseVector indices must be non-negative"): SparseVector(indices=[0, -1, 2], values=[0.1, 0.2, 0.3]) # Test 9: Invalid - non-numeric value with pytest.raises(ValueError, match="SparseVector values must be numbers"): SparseVector(indices=[0, 1, 2], values=[0.1, "not_number", 0.3]) # type: ignore # Test 10: Invalid - float indices (not integers) with pytest.raises(ValueError, match="SparseVector indices must be integers"): SparseVector(indices=[0.0, 1.0, 2.0], values=[0.1, 0.2, 0.3]) # type: ignore # Test 11: Valid - integer values (not just floats) SparseVector(indices=[0, 1, 2], values=[1, 2, 3]) # Test 12: Valid - mixed int and float values SparseVector(indices=[0, 1, 2], values=[1, 2.5, 3]) # Test 13: Valid - large indices SparseVector(indices=[100, 1000, 10000], values=[0.1, 0.2, 0.3]) # Test 14: Invalid - None as value with pytest.raises(ValueError, match="SparseVector values must be numbers"): SparseVector(indices=[0, 1], values=[0.1, None]) # type: ignore # Test 15: Invalid - None as index with pytest.raises(ValueError, match="SparseVector indices must be integers"): SparseVector(indices=[0, None], values=[0.1, 0.2]) # type: ignore # Test 16: Valid - single element SparseVector(indices=[42], values=[3.14]) # Test 17: Boolean values are actually valid (bool is subclass of int in Python) SparseVector(indices=[0, 1], values=[True, False]) # True=1, False=0 # Test 18: Invalid - unsorted indices with pytest.raises( ValueError, match="indices must be sorted in strictly ascending order" ): SparseVector(indices=[0, 2, 1], values=[0.1, 0.2, 0.3]) # Test 19: Invalid - duplicate indices (not strictly ascending) with pytest.raises( ValueError, match="indices must be sorted in strictly ascending order" ): SparseVector(indices=[0, 1, 1, 2], values=[0.1, 0.2, 0.3, 0.4]) # Test 20: Invalid - descending order with pytest.raises( ValueError, match="indices must be sorted in strictly ascending order" ): SparseVector(indices=[5, 3, 1], values=[0.5, 0.3, 0.1]) def test_sparse_vector_in_metadata_validation(): """Test that sparse vectors are properly validated in metadata.""" from chromadb.api.types import validate_metadata from chromadb.base_types import SparseVector # Test 1: Valid metadata with sparse vectors sparse_vector_1 = SparseVector(indices=[0, 2, 5], values=[0.1, 0.5, 0.9]) sparse_vector_2 = SparseVector(indices=[1, 3, 4], values=[0.2, 0.4, 0.6]) metadata_1 = { "text": "document 1", "sparse_embedding": sparse_vector_1, "score": 0.5, } metadata_2 = { "text": "document 2", "sparse_embedding": sparse_vector_2, "score": 0.8, } validate_metadata(metadata_1) validate_metadata(metadata_2) # Test 2: Valid metadata with empty sparse vector metadata_empty = { "text": "empty sparse", "sparse_vec": SparseVector(indices=[], values=[]), } validate_metadata(metadata_empty) # Test 3: Invalid sparse vector in metadata (construction fails) with pytest.raises( ValueError, match="indices and values must have the same length" ): invalid_metadata = { "text": "invalid", "sparse_embedding": SparseVector(indices=[0, 1], values=[0.1]), } # Test 4: Invalid dict in metadata (not a SparseVector dataclass) invalid_metadata_2 = { "text": "missing indices", "sparse_embedding": {"values": [0.1, 0.2]}, } with pytest.raises( ValueError, match="Expected metadata value to be a str, int, float, bool, SparseVector, or None", ): validate_metadata(invalid_metadata_2) # Test 5: Invalid sparse vector - negative index (construction fails) with pytest.raises(ValueError, match="SparseVector indices must be non-negative"): invalid_metadata_3 = { "text": "negative index", "sparse_embedding": SparseVector( indices=[0, -1, 2], values=[0.1, 0.2, 0.3] ), } # Test 6: Invalid sparse vector - non-numeric value (construction fails) with pytest.raises(ValueError, match="SparseVector values must be numbers"): invalid_metadata_4 = { "text": "non-numeric value", "sparse_embedding": SparseVector( indices=[0, 1], values=[0.1, "not_a_number"] ), # type: ignore } # Test 7: Multiple sparse vectors in metadata metadata_multiple = { "text": "multiple sparse vectors", "sparse_1": SparseVector(indices=[0, 1], values=[0.1, 0.2]), "sparse_2": SparseVector(indices=[2, 3, 4], values=[0.3, 0.4, 0.5]), "regular_field": 42, } validate_metadata(metadata_multiple) # Test 8: Regular dict (not SparseVector) should be rejected metadata_nested = { "config": "some_config", "sparse_vector": {"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}, } with pytest.raises( ValueError, match="Expected metadata value to be a str, int, float, bool, SparseVector, or None", ): validate_metadata(metadata_nested) # Test 9: Large sparse vector large_sparse = SparseVector( indices=list(range(1000)), values=[float(i) * 0.001 for i in range(1000)], ) metadata_large = {"text": "large sparse", "large_sparse_vec": large_sparse} validate_metadata(metadata_large) def test_sparse_vector_dict_format_normalization(): """Test that dict-format sparse vectors are normalized to SparseVector instances.""" from chromadb.api.types import normalize_metadata, validate_metadata from chromadb.base_types import SparseVector # Test 1: Dict format with #type='sparse_vector' should be converted metadata_dict_format = { "text": "test document", "sparse": { TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE, "indices": [0, 2, 5], "values": [1.0, 2.0, 3.0], }, } normalized = normalize_metadata(metadata_dict_format) assert isinstance(normalized["sparse"], SparseVector) assert normalized["sparse"].indices == [0, 2, 5] assert normalized["sparse"].values == [1.0, 2.0, 3.0] # Should pass validation after normalization validate_metadata(normalized) # Test 2: SparseVector instance should pass through unchanged sparse_instance = SparseVector(indices=[1, 3, 4], values=[0.5, 1.5, 2.5]) metadata_instance_format = { "text": "test document", "sparse": sparse_instance, } normalized2 = normalize_metadata(metadata_instance_format) assert normalized2["sparse"] is sparse_instance # Same object validate_metadata(normalized2) # Test 3: Dict format with unsorted indices should be rejected during normalization metadata_unsorted = { "text": "unsorted", "sparse": { TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE, "indices": [5, 0, 2], "values": [3.0, 1.0, 2.0], }, } with pytest.raises( ValueError, match="indices must be sorted in strictly ascending order" ): normalize_metadata(metadata_unsorted) # Test 4: Dict format with duplicate indices should be rejected metadata_duplicates = { "text": "duplicates", "sparse": { TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE, "indices": [0, 2, 2], "values": [1.0, 2.0, 3.0], }, } with pytest.raises( ValueError, match="indices must be sorted in strictly ascending order" ): normalize_metadata(metadata_duplicates) # Test 5: Dict format with negative indices should be rejected metadata_negative = { "text": "negative", "sparse": { TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE, "indices": [-1, 0, 2], "values": [1.0, 2.0, 3.0], }, } with pytest.raises(ValueError, match="indices must be non-negative"): normalize_metadata(metadata_negative) # Test 6: Dict format with length mismatch should be rejected metadata_mismatch = { "text": "mismatch", "sparse": { TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE, "indices": [0, 2], "values": [1.0, 2.0, 3.0], }, } with pytest.raises( ValueError, match="indices and values must have the same length" ): normalize_metadata(metadata_mismatch) # Test 7: Regular dict without #type should not be converted metadata_regular_dict = { "text": "regular", "config": {"key": "value"}, } normalized3 = normalize_metadata(metadata_regular_dict) assert isinstance(normalized3["config"], dict) assert normalized3["config"]["key"] == "value" # Test 8: Empty sparse vector in dict format metadata_empty = { "text": "empty", "sparse": {TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE, "indices": [], "values": []}, } normalized4 = normalize_metadata(metadata_empty) assert isinstance(normalized4["sparse"], SparseVector) assert normalized4["sparse"].indices == [] assert normalized4["sparse"].values == [] # Test 9: Multiple sparse vectors in dict format metadata_multiple = { "sparse1": { TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE, "indices": [0, 1], "values": [1.0, 2.0], }, "sparse2": { TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE, "indices": [2, 3], "values": [3.0, 4.0], }, "regular": 42, } normalized5 = normalize_metadata(metadata_multiple) assert isinstance(normalized5["sparse1"], SparseVector) assert isinstance(normalized5["sparse2"], SparseVector) assert normalized5["regular"] == 42 def test_sparse_vector_dict_format_in_record_set(): """Test that dict-format sparse vectors work in normalize_insert_record_set.""" from chromadb.api.types import ( normalize_insert_record_set, validate_insert_record_set, ) from chromadb.base_types import SparseVector # Test 1: Mix of dict format and SparseVector instances record_set = normalize_insert_record_set( ids=["doc1", "doc2", "doc3"], embeddings=None, metadatas=[ { "text": "test1", "sparse": { TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE, "indices": [0, 2], "values": [1.0, 2.0], }, }, { "text": "test2", "sparse": SparseVector(indices=[1, 3], values=[1.5, 2.5]), }, {"text": "test3"}, # No sparse vector ], documents=["doc one", "doc two", "doc three"], ) # Both should be converted to SparseVector instances assert isinstance(record_set["metadatas"][0]["sparse"], SparseVector) assert isinstance(record_set["metadatas"][1]["sparse"], SparseVector) assert "sparse" not in record_set["metadatas"][2] # Validation should pass validate_insert_record_set(record_set) # Test 2: Verify values are correct after normalization assert record_set["metadatas"][0]["sparse"].indices == [0, 2] assert record_set["metadatas"][0]["sparse"].values == [1.0, 2.0] assert record_set["metadatas"][1]["sparse"].indices == [1, 3] assert record_set["metadatas"][1]["sparse"].values == [1.5, 2.5] def test_search_result_rows() -> None: """Test the SearchResult.rows() method for converting column-major to row-major format.""" from chromadb.api.types import SearchResult # Test 1: Basic single payload with all fields result = SearchResult( { "ids": [["id1", "id2", "id3"]], "documents": [["doc1", "doc2", "doc3"]], "embeddings": [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]], "metadatas": [[{"key": "a"}, {"key": "b"}, {"key": "c"}]], "scores": [[0.9, 0.8, 0.7]], "select": [["document", "score", "metadata"]], } ) rows = result.rows() assert len(rows) == 1 # One payload assert len(rows[0]) == 3 # Three results # Check first row assert rows[0][0]["id"] == "id1" assert rows[0][0]["document"] == "doc1" assert rows[0][0]["embedding"] == [1.0, 2.0] assert rows[0][0]["metadata"] == {"key": "a"} assert rows[0][0]["score"] == 0.9 # Check all rows have all fields for row in rows[0]: assert "id" in row assert "document" in row assert "embedding" in row assert "metadata" in row assert "score" in row # Test 2: Multiple payloads result = SearchResult( { "ids": [["a1", "a2"], ["b1", "b2", "b3"]], "documents": [["doc_a1", "doc_a2"], ["doc_b1", "doc_b2", "doc_b3"]], "embeddings": [ None, [[1.0], [2.0], [3.0]], ], # First payload has no embeddings "metadatas": [[{"x": 1}, {"x": 2}], None], # Second payload has no metadata "scores": [[0.5, 0.4], [0.9, 0.8, 0.7]], "select": [["document", "score"], ["embedding", "score"]], } ) rows = result.rows() assert len(rows) == 2 # Two payloads assert len(rows[0]) == 2 # First payload has 2 results assert len(rows[1]) == 3 # Second payload has 3 results # First payload - has docs, metadata, scores but no embeddings assert rows[0][0] == { "id": "a1", "document": "doc_a1", "metadata": {"x": 1}, "score": 0.5, } assert rows[0][1] == { "id": "a2", "document": "doc_a2", "metadata": {"x": 2}, "score": 0.4, } # Second payload - has docs, embeddings, scores but no metadata assert rows[1][0] == { "id": "b1", "document": "doc_b1", "embedding": [1.0], "score": 0.9, } assert rows[1][1] == { "id": "b2", "document": "doc_b2", "embedding": [2.0], "score": 0.8, } assert rows[1][2] == { "id": "b3", "document": "doc_b3", "embedding": [3.0], "score": 0.7, } # Test 3: Empty result result = SearchResult( { "ids": [], "documents": [], "embeddings": [], "metadatas": [], "scores": [], "select": [], } ) rows = result.rows() assert rows == [] # Test 4: Sparse data with None values in lists result = SearchResult( { "ids": [["id1", "id2", "id3"]], "documents": [[None, "doc2", None]], # Sparse documents "embeddings": None, # No embeddings at all "metadatas": [[{"a": 1}, None, {"c": 3}]], # Sparse metadata "scores": [[0.9, None, 0.7]], # Sparse scores "select": [["document", "metadata", "score"]], } ) rows = result.rows() assert len(rows) == 1 assert len(rows[0]) == 3 # First row - only has metadata and score assert rows[0][0] == {"id": "id1", "metadata": {"a": 1}, "score": 0.9} # Second row - only has document assert rows[0][1] == {"id": "id2", "document": "doc2"} # Third row - has metadata and score assert rows[0][2] == {"id": "id3", "metadata": {"c": 3}, "score": 0.7} # Test 5: Only IDs (minimal result) result = SearchResult( { "ids": [["id1", "id2"]], "documents": None, "embeddings": None, "metadatas": None, "scores": None, "select": [[]], } ) rows = result.rows() assert len(rows) == 1 assert len(rows[0]) == 2 assert rows[0][0] == {"id": "id1"} assert rows[0][1] == {"id": "id2"} # Test 6: SearchResult works as dict (backward compatibility) result = SearchResult( { "ids": [["test"]], "documents": [["test doc"]], "metadatas": [[{"test": True}]], "embeddings": [[[0.1, 0.2]]], "scores": [[0.99]], "select": [["all"]], } ) # Should work as dict assert result["ids"] == [["test"]] assert result.get("documents") == [["test doc"]] assert "metadatas" in result assert len(result) == 6 # Should have 6 keys # Should also have rows() method rows = result.rows() assert len(rows[0]) == 1 assert rows[0][0]["id"] == "test" print("All SearchResult.rows() tests passed!") def test_rrf_to_dict() -> None: """Test the Rrf (Reciprocal Rank Fusion) to_dict conversion.""" # Note: In these tests, "sparse_embedding" is just an example metadata field name. # Users can store any data in metadata fields and reference them by name (without # prefix). # The "#embedding" key refers to the special main embedding field. import pytest from chromadb.execution.expression.operator import Rrf, Knn, Val # Test 1: Basic RRF with two KNN rankings (equal weight) rrf = Rrf( [ Knn(query=[0.1, 0.2], return_rank=True), Knn(query=[0.3, 0.4], key="sparse_embedding", return_rank=True), ] ) result = rrf.to_dict() # RRF formula: -sum(weight_i / (k + rank_i)) # With default k=60 and equal weights (1.0 each) # Expected: -(1.0/(60 + knn1) + 1.0/(60 + knn2)) expected = { "$mul": [ {"$val": -1}, { "$sum": [ { "$div": { "left": {"$val": 1.0}, "right": { "$sum": [ {"$val": 60}, { "$knn": { "query": [0.1, 0.2], "key": "#embedding", "limit": 16, "return_rank": True, } }, ] }, } }, { "$div": { "left": {"$val": 1.0}, "right": { "$sum": [ {"$val": 60}, { "$knn": { "query": [0.3, 0.4], "key": "sparse_embedding", "limit": 16, "return_rank": True, } }, ] }, } }, ] }, ] } assert result == expected # Test 2: RRF with custom weights and k rrf_weighted = Rrf( ranks=[ Knn(query=[0.1, 0.2], return_rank=True), Knn(query=[0.3, 0.4], key="sparse_embedding", return_rank=True), ], weights=[2.0, 1.0], # Dense is 2x more important k=100, ) result_weighted = rrf_weighted.to_dict() # Expected: -(2.0/(100 + knn1) + 1.0/(100 + knn2)) expected_weighted = { "$mul": [ {"$val": -1}, { "$sum": [ { "$div": { "left": {"$val": 2.0}, "right": { "$sum": [ {"$val": 100}, { "$knn": { "query": [0.1, 0.2], "key": "#embedding", "limit": 16, "return_rank": True, } }, ] }, } }, { "$div": { "left": {"$val": 1.0}, "right": { "$sum": [ {"$val": 100}, { "$knn": { "query": [0.3, 0.4], "key": "sparse_embedding", "limit": 16, "return_rank": True, } }, ] }, } }, ] }, ] } assert result_weighted == expected_weighted # Test 3: RRF with three rankings rrf_three = Rrf( [ Knn(query=[0.1, 0.2], return_rank=True), Knn(query=[0.3, 0.4], key="sparse_embedding", return_rank=True), Val(5.0), # Can also include constant rank ] ) result_three = rrf_three.to_dict() # Verify it has three terms in the sum assert "$mul" in result_three assert "$sum" in result_three["$mul"][1] terms = result_three["$mul"][1]["$sum"] assert len(terms) == 3 # Three ranking strategies # Test 4: Error case - mismatched weights with pytest.raises( ValueError, match="Number of weights .* must match number of ranks" ): rrf_bad = Rrf( ranks=[ Knn(query=[0.1, 0.2], return_rank=True), Knn(query=[0.3, 0.4], return_rank=True), ], weights=[1.0], # Only one weight for two ranks ) rrf_bad.to_dict() # Test 5: Error case - negative weights with pytest.raises(ValueError, match="All weights must be non-negative"): rrf_negative = Rrf( ranks=[ Knn(query=[0.1, 0.2], return_rank=True), Knn(query=[0.3, 0.4], return_rank=True), ], weights=[1.0, -1.0], # Negative weight ) rrf_negative.to_dict() # Test 6: Error case - empty ranks list with pytest.raises(ValueError, match="RRF requires at least one rank"): rrf_empty = Rrf([]) rrf_empty.to_dict() # Validation happens in to_dict() # Test 7: Error case - negative k value with pytest.raises(ValueError, match="k must be positive"): rrf_neg_k = Rrf([Val(1.0)], k=-5) rrf_neg_k.to_dict() # Validation happens in to_dict() # Test 8: Error case - zero k value with pytest.raises(ValueError, match="k must be positive"): rrf_zero_k = Rrf([Val(1.0)], k=0) rrf_zero_k.to_dict() # Validation happens in to_dict() # Test 9: Normalize flag with weights rrf_normalized = Rrf( ranks=[ Knn(query=[0.1, 0.2], return_rank=True), Knn(query=[0.3, 0.4], key="sparse_embedding", return_rank=True), ], weights=[3.0, 1.0], # Will be normalized to [0.75, 0.25] normalize=True, k=100, ) result_normalized = rrf_normalized.to_dict() # Expected: -(0.75/(100 + knn1) + 0.25/(100 + knn2)) expected_normalized = { "$mul": [ {"$val": -1}, { "$sum": [ { "$div": { "left": {"$val": 0.75}, "right": { "$sum": [ {"$val": 100}, { "$knn": { "query": [0.1, 0.2], "key": "#embedding", "limit": 16, "return_rank": True, } }, ] }, } }, { "$div": { "left": {"$val": 0.25}, "right": { "$sum": [ {"$val": 100}, { "$knn": { "query": [0.3, 0.4], "key": "sparse_embedding", "limit": 16, "return_rank": True, } }, ] }, } }, ] }, ] } assert result_normalized == expected_normalized # Test 10: Normalize flag without weights (should work with defaults) rrf_normalize_defaults = Rrf( ranks=[ Knn(query=[0.1, 0.2], return_rank=True), Knn(query=[0.3, 0.4], return_rank=True), ], normalize=True, # Will normalize [1.0, 1.0] to [0.5, 0.5] ) result_defaults = rrf_normalize_defaults.to_dict() # Both weights should be 0.5 after normalization expected_defaults = { "$mul": [ {"$val": -1}, { "$sum": [ { "$div": { "left": {"$val": 0.5}, "right": { "$sum": [ {"$val": 60}, # Default k=60 { "$knn": { "query": [0.1, 0.2], "key": "#embedding", "limit": 16, "return_rank": True, } }, ] }, } }, { "$div": { "left": {"$val": 0.5}, "right": { "$sum": [ {"$val": 60}, { "$knn": { "query": [0.3, 0.4], "key": "#embedding", "limit": 16, "return_rank": True, } }, ] }, } }, ] }, ] } assert result_defaults == expected_defaults # Test 11: Error case - normalize with all zero weights with pytest.raises(ValueError, match="Sum of weights must be positive"): rrf_zero_weights = Rrf( ranks=[ Knn(query=[0.1, 0.2], return_rank=True), Knn(query=[0.3, 0.4], return_rank=True), ], weights=[0.0, 0.0], normalize=True, ) rrf_zero_weights.to_dict() print("All RRF tests passed!") def test_group_by_serialization() -> None: """Test GroupBy, MinK, and MaxK serialization and deserialization.""" import pytest from chromadb.execution.expression.operator import ( GroupBy, MinK, MaxK, Key, Aggregate, ) # to_dict with OneOrMany keys group_by = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3)) assert group_by.to_dict() == { "keys": ["category"], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}, } # to_dict with multiple keys and MaxK group_by = GroupBy( keys=[Key("year"), Key("category")], aggregate=MaxK(keys=[Key.SCORE, Key("priority")], k=5), ) assert group_by.to_dict() == { "keys": ["year", "category"], "aggregate": {"$max_k": {"keys": ["#score", "priority"], "k": 5}}, } # Round-trip original = GroupBy(keys=[Key("category")], aggregate=MinK(keys=[Key.SCORE], k=3)) assert GroupBy.from_dict(original.to_dict()).to_dict() == original.to_dict() # Empty GroupBy serializes to {} and from_dict({}) returns default GroupBy empty_group_by = GroupBy() assert empty_group_by.to_dict() == {} assert GroupBy.from_dict({}).to_dict() == {} # Error cases with pytest.raises(ValueError, match="requires 'keys' field"): GroupBy.from_dict({"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}) with pytest.raises(ValueError, match="requires 'aggregate' field"): GroupBy.from_dict({"keys": ["category"]}) with pytest.raises(ValueError, match="keys cannot be empty"): GroupBy.from_dict( {"keys": [], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}} ) with pytest.raises(ValueError, match="Unknown aggregate operator"): Aggregate.from_dict({"$unknown": {"keys": ["#score"], "k": 3}}) # Expression API Tests - Testing dict support and from_dict methods class TestSearchDictSupport: """Test Search class dict input support.""" def test_search_with_dict_where(self): """Test Search accepts dict for where parameter.""" from chromadb.execution.expression.plan import Search from chromadb.execution.expression.operator import Where # Simple equality search = Search(where={"status": "active"}) assert search._where is not None assert isinstance(search._where, Where) # Complex where with operators search = Search(where={"$and": [{"status": "active"}, {"score": {"$gt": 0.5}}]}) assert search._where is not None def test_search_with_dict_rank(self): """Test Search accepts dict for rank parameter.""" from chromadb.execution.expression.plan import Search from chromadb.execution.expression.operator import Rank # KNN ranking search = Search(rank={"$knn": {"query": [0.1, 0.2]}}) assert search._rank is not None assert isinstance(search._rank, Rank) # Val ranking search = Search(rank={"$val": 0.5}) assert search._rank is not None def test_search_with_dict_limit(self): """Test Search accepts dict and int for limit parameter.""" from chromadb.execution.expression.plan import Search # Dict limit search = Search(limit={"limit": 10, "offset": 5}) assert search._limit.limit == 10 assert search._limit.offset == 5 # Int limit (creates Limit with offset=0) search = Search(limit=10) assert search._limit.limit == 10 assert search._limit.offset == 0 def test_search_with_dict_select(self): """Test Search accepts dict, list, and set for select parameter.""" from chromadb.execution.expression.plan import Search # Dict select search = Search(select={"keys": ["#document", "#score"]}) assert search._select is not None # List select search = Search(select=["#document", "#metadata"]) assert search._select is not None # Set select search = Search(select={"#document", "#embedding"}) assert search._select is not None def test_search_mixed_inputs(self): """Test Search with mixed expression and dict inputs.""" from chromadb.execution.expression.plan import Search from chromadb.execution.expression.operator import Key search = Search( where=Key("status") == "active", # Expression rank={"$knn": {"query": [0.1, 0.2]}}, # Dict limit=10, # Int select=["#document"], # List ) assert search._where is not None assert search._rank is not None assert search._limit.limit == 10 assert search._select is not None def test_search_builder_methods_with_dicts(self): """Test Search builder methods accept dicts.""" from chromadb.execution.expression.plan import Search search = Search().where({"status": "active"}).rank({"$val": 0.5}) assert search._where is not None assert search._rank is not None def test_search_invalid_inputs(self): """Test Search rejects invalid input types.""" import pytest from chromadb.execution.expression.plan import Search with pytest.raises(TypeError, match="where must be"): Search(where="invalid") with pytest.raises(TypeError, match="rank must be"): Search(rank=0.5) # Primitive numbers not allowed with pytest.raises(TypeError, match="limit must be"): Search(limit="10") with pytest.raises(TypeError, match="select must be"): Search(select=123) def test_search_with_group_by(self): """Test Search accepts group_by as dict, object, and builder method.""" import pytest from chromadb.execution.expression.plan import Search from chromadb.execution.expression.operator import GroupBy, MinK, Key # Dict input search = Search( group_by={ "keys": ["category"], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}, } ) assert isinstance(search._group_by, GroupBy) # Object input and builder method group_by = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3)) assert Search(group_by=group_by)._group_by is group_by assert Search().group_by(group_by)._group_by.aggregate is not None # Invalid inputs with pytest.raises(TypeError, match="group_by must be"): Search(group_by="invalid") with pytest.raises(ValueError, match="requires 'aggregate' field"): Search(group_by={"keys": ["category"]}) def test_search_group_by_serialization(self): """Test Search serializes group_by correctly.""" from chromadb.execution.expression.plan import Search from chromadb.execution.expression.operator import GroupBy, MinK, Key, Knn # Without group_by - empty dict search = Search().rank(Knn(query=[0.1, 0.2])).limit(10) assert search.to_dict()["group_by"] == {} # With group_by - has keys and aggregate search = Search().group_by( GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3)) ) result = search.to_dict()["group_by"] assert result["keys"] == ["category"] assert result["aggregate"] == {"$min_k": {"keys": ["#score"], "k": 3}} class TestWhereFromDict: """Test Where.from_dict() conversion.""" def test_simple_equality(self): """Test simple equality conversion.""" from chromadb.execution.expression.operator import Where, Eq # Shorthand for equality where = Where.from_dict({"status": "active"}) assert isinstance(where, Eq) # Explicit $eq where = Where.from_dict({"status": {"$eq": "active"}}) assert isinstance(where, Eq) def test_comparison_operators(self): """Test comparison operator conversions.""" from chromadb.execution.expression.operator import Where, Ne, Gt, Gte, Lt, Lte # $ne where = Where.from_dict({"status": {"$ne": "inactive"}}) assert isinstance(where, Ne) # $gt where = Where.from_dict({"score": {"$gt": 0.5}}) assert isinstance(where, Gt) # $gte where = Where.from_dict({"score": {"$gte": 0.5}}) assert isinstance(where, Gte) # $lt where = Where.from_dict({"score": {"$lt": 1.0}}) assert isinstance(where, Lt) # $lte where = Where.from_dict({"score": {"$lte": 1.0}}) assert isinstance(where, Lte) def test_membership_operators(self): """Test membership operator conversions.""" from chromadb.execution.expression.operator import Where, In, Nin # $in where = Where.from_dict({"status": {"$in": ["active", "pending"]}}) assert isinstance(where, In) # $nin (not in) where = Where.from_dict({"status": {"$nin": ["deleted", "archived"]}}) assert isinstance(where, Nin) def test_string_operators(self): """Test string operator conversions.""" from chromadb.execution.expression.operator import ( Where, Contains, NotContains, Regex, NotRegex, ) # $contains where = Where.from_dict({"text": {"$contains": "hello"}}) assert isinstance(where, Contains) # $not_contains where = Where.from_dict({"text": {"$not_contains": "spam"}}) assert isinstance(where, NotContains) # $regex where = Where.from_dict({"text": {"$regex": "^test.*"}}) assert isinstance(where, Regex) # $not_regex where = Where.from_dict({"text": {"$not_regex": r"\d+"}}) assert isinstance(where, NotRegex) def test_logical_operators(self): """Test logical operator conversions.""" from chromadb.execution.expression.operator import Where, And, Or # $and where = Where.from_dict( {"$and": [{"status": "active"}, {"score": {"$gt": 0.5}}]} ) assert isinstance(where, And) # $or where = Where.from_dict({"$or": [{"status": "active"}, {"status": "pending"}]}) assert isinstance(where, Or) def test_nested_logical_operators(self): """Test nested logical operations.""" from chromadb.execution.expression.operator import Where, And where = Where.from_dict( { "$and": [ {"$or": [{"status": "active"}, {"status": "pending"}]}, {"score": {"$gte": 0.5}}, ] } ) assert isinstance(where, And) def test_special_keys(self): """Test special key handling.""" from chromadb.execution.expression.operator import Where, In # ID key where = Where.from_dict({"#id": {"$in": ["id1", "id2"]}}) assert isinstance(where, In) def test_invalid_where_dicts(self): """Test invalid Where dict inputs.""" import pytest from chromadb.execution.expression.operator import Where with pytest.raises(TypeError, match="Expected dict"): Where.from_dict("not a dict") with pytest.raises(ValueError, match="cannot be empty"): Where.from_dict({}) with pytest.raises(ValueError, match="requires at least one condition"): Where.from_dict({"$and": []}) class TestRankFromDict: """Test Rank.from_dict() conversion.""" def test_val_conversion(self): """Test Val conversion.""" from chromadb.execution.expression.operator import Rank, Val rank = Rank.from_dict({"$val": 0.5}) assert isinstance(rank, Val) assert rank.value == 0.5 def test_knn_conversion(self): """Test KNN conversion.""" import numpy as np from chromadb.execution.expression.operator import Rank, Knn # Basic KNN with defaults rank = Rank.from_dict({"$knn": {"query": [0.1, 0.2]}}) assert isinstance(rank, Knn) # Handle both list and numpy array cases if isinstance(rank.query, np.ndarray): # Use allclose for floating point comparison with dtype tolerance assert np.allclose(rank.query, np.array([0.1, 0.2])) else: assert rank.query == [0.1, 0.2] assert rank.key == "#embedding" # default assert rank.limit == 16 # default # KNN with custom parameters rank = Rank.from_dict( { "$knn": { "query": [0.1, 0.2], "key": "sparse_embedding", "limit": 256, "return_rank": True, } } ) assert rank.key == "sparse_embedding" assert rank.limit == 256 assert rank.return_rank def test_arithmetic_operators(self): """Test arithmetic operator conversions.""" from chromadb.execution.expression.operator import Rank, Sum, Sub, Mul, Div # $sum rank = Rank.from_dict({"$sum": [{"$val": 0.5}, {"$val": 0.3}]}) assert isinstance(rank, Sum) # $sub rank = Rank.from_dict({"$sub": {"left": {"$val": 1.0}, "right": {"$val": 0.3}}}) assert isinstance(rank, Sub) # $mul rank = Rank.from_dict({"$mul": [{"$val": 2.0}, {"$val": 0.5}]}) assert isinstance(rank, Mul) # $div rank = Rank.from_dict({"$div": {"left": {"$val": 1.0}, "right": {"$val": 2.0}}}) assert isinstance(rank, Div) def test_math_functions(self): """Test math function conversions.""" from chromadb.execution.expression.operator import Rank, Abs, Exp, Log # $abs rank = Rank.from_dict({"$abs": {"$val": -0.5}}) assert isinstance(rank, Abs) # $exp rank = Rank.from_dict({"$exp": {"$val": 1.0}}) assert isinstance(rank, Exp) # $log rank = Rank.from_dict({"$log": {"$val": 2.0}}) assert isinstance(rank, Log) def test_aggregation_functions(self): """Test min/max conversions.""" from chromadb.execution.expression.operator import Rank, Max, Min # $max rank = Rank.from_dict({"$max": [{"$val": 0.5}, {"$val": 0.8}]}) assert isinstance(rank, Max) # $min rank = Rank.from_dict({"$min": [{"$val": 0.5}, {"$val": 0.8}]}) assert isinstance(rank, Min) def test_complex_rank_expression(self): """Test complex nested rank expressions.""" from chromadb.execution.expression.operator import Rank, Sum rank = Rank.from_dict( { "$sum": [ {"$mul": [{"$knn": {"query": [0.1, 0.2]}}, {"$val": 0.8}]}, {"$mul": [{"$val": 0.5}, {"$val": 0.2}]}, ] } ) assert isinstance(rank, Sum) def test_invalid_rank_dicts(self): """Test invalid Rank dict inputs.""" import pytest from chromadb.execution.expression.operator import Rank with pytest.raises(TypeError, match="Expected dict"): Rank.from_dict("not a dict") with pytest.raises(ValueError, match="cannot be empty"): Rank.from_dict({}) with pytest.raises(ValueError, match="exactly one operator"): Rank.from_dict({"$val": 0.5, "$knn": {"query": [0.1]}}) with pytest.raises(TypeError, match="requires a number"): Rank.from_dict({"$val": "not a number"}) class TestLimitFromDict: """Test Limit.from_dict() conversion.""" def test_limit_only(self): """Test limit without offset.""" from chromadb.execution.expression.operator import Limit limit = Limit.from_dict({"limit": 20}) assert limit.limit == 20 assert limit.offset == 0 # default def test_offset_only(self): """Test offset without limit.""" from chromadb.execution.expression.operator import Limit limit = Limit.from_dict({"offset": 10}) assert limit.offset == 10 assert limit.limit is None def test_limit_and_offset(self): """Test both limit and offset.""" from chromadb.execution.expression.operator import Limit limit = Limit.from_dict({"limit": 20, "offset": 10}) assert limit.limit == 20 assert limit.offset == 10 def test_validation(self): """Test Limit validation.""" import pytest from chromadb.execution.expression.operator import Limit # Negative limit with pytest.raises(ValueError, match="must be positive"): Limit.from_dict({"limit": -1}) # Zero limit with pytest.raises(ValueError, match="must be positive"): Limit.from_dict({"limit": 0}) # Negative offset with pytest.raises(ValueError, match="must be non-negative"): Limit.from_dict({"offset": -1}) def test_invalid_types(self): """Test type validation.""" import pytest from chromadb.execution.expression.operator import Limit with pytest.raises(TypeError, match="Expected dict"): Limit.from_dict("not a dict") with pytest.raises(TypeError, match="must be an integer"): Limit.from_dict({"limit": "20"}) with pytest.raises(TypeError, match="must be an integer"): Limit.from_dict({"offset": 10.5}) def test_unexpected_keys(self): """Test rejection of unexpected keys.""" import pytest from chromadb.execution.expression.operator import Limit with pytest.raises(ValueError, match="Unexpected keys"): Limit.from_dict({"limit": 10, "invalid": "key"}) class TestSelectFromDict: """Test Select.from_dict() conversion.""" def test_special_keys(self): """Test special key conversion.""" from chromadb.execution.expression.operator import Select, Key select = Select.from_dict( {"keys": ["#document", "#embedding", "#metadata", "#score"]} ) assert Key.DOCUMENT in select.keys assert Key.EMBEDDING in select.keys assert Key.METADATA in select.keys assert Key.SCORE in select.keys def test_metadata_keys(self): """Test regular metadata field keys.""" from chromadb.execution.expression.operator import Select, Key select = Select.from_dict({"keys": ["title", "author", "date"]}) assert Key("title") in select.keys assert Key("author") in select.keys assert Key("date") in select.keys def test_mixed_keys(self): """Test mix of special and metadata keys.""" from chromadb.execution.expression.operator import Select, Key select = Select.from_dict({"keys": ["#document", "title", "#score"]}) assert Key.DOCUMENT in select.keys assert Key("title") in select.keys assert Key.SCORE in select.keys def test_empty_keys(self): """Test empty keys list.""" from chromadb.execution.expression.operator import Select select = Select.from_dict({"keys": []}) assert len(select.keys) == 0 def test_validation(self): """Test Select validation.""" import pytest from chromadb.execution.expression.operator import Select with pytest.raises(TypeError, match="Expected dict"): Select.from_dict("not a dict") with pytest.raises(TypeError, match="must be a list/tuple/set"): Select.from_dict({"keys": "not a list"}) with pytest.raises(TypeError, match="must be a string"): Select.from_dict({"keys": [123]}) def test_unexpected_keys(self): """Test rejection of unexpected keys.""" import pytest from chromadb.execution.expression.operator import Select with pytest.raises(ValueError, match="Unexpected keys"): Select.from_dict({"keys": [], "invalid": "key"}) class TestRoundTripConversion: """Test that to_dict() and from_dict() round-trip correctly.""" def test_where_round_trip(self): """Test Where round-trip conversion.""" from chromadb.execution.expression.operator import Where, And, Key original = And([Key("status") == "active", Key("score") > 0.5]) dict_form = original.to_dict() restored = Where.from_dict(dict_form) assert restored.to_dict() == dict_form def test_rank_round_trip(self): """Test Rank round-trip conversion.""" import numpy as np from chromadb.execution.expression.operator import Rank, Knn, Val original = Knn(query=[0.1, 0.2]) * 0.8 + Val(0.5) * 0.2 dict_form = original.to_dict() restored = Rank.from_dict(dict_form) restored_dict = restored.to_dict() # Compare with float32 precision tolerance for KNN queries # The normalize_embeddings function converts to float32, causing precision differences def compare_dicts(d1, d2): if isinstance(d1, dict) and isinstance(d2, dict): if "$knn" in d1 and "$knn" in d2: # Special handling for KNN queries knn1, knn2 = d1["$knn"], d2["$knn"] if "query" in knn1 and "query" in knn2: # Compare queries with float32 precision q1 = np.array(knn1["query"], dtype=np.float32) q2 = np.array(knn2["query"], dtype=np.float32) if not np.allclose(q1, q2): return False # Compare other fields exactly for key in knn1: if key != "query" and knn1[key] != knn2.get(key): return False return True # Recursively compare other dict structures if set(d1.keys()) != set(d2.keys()): return False for key in d1: if not compare_dicts(d1[key], d2[key]): return False return True elif isinstance(d1, list) and isinstance(d2, list): if len(d1) != len(d2): return False return all(compare_dicts(a, b) for a, b in zip(d1, d2)) else: return d1 == d2 assert compare_dicts(restored_dict, dict_form) def test_limit_round_trip(self): """Test Limit round-trip conversion.""" from chromadb.execution.expression.operator import Limit original = Limit(limit=20, offset=10) dict_form = original.to_dict() restored = Limit.from_dict(dict_form) assert restored.to_dict() == dict_form def test_select_round_trip(self): """Test Select round-trip conversion.""" from chromadb.execution.expression.operator import Select, Key original = Select(keys={Key.DOCUMENT, Key("title"), Key.SCORE}) dict_form = original.to_dict() restored = Select.from_dict(dict_form) # Note: Set order might differ, so compare sets assert set(restored.to_dict()["keys"]) == set(dict_form["keys"]) def test_search_round_trip(self): """Test Search round-trip through dict inputs.""" import numpy as np from chromadb.execution.expression.plan import Search from chromadb.execution.expression.operator import Key, Knn, Limit, Select original_search = Search( where=Key("status") == "active", rank=Knn(query=[0.1, 0.2]), limit=Limit(limit=10), select=Select(keys={Key.DOCUMENT}), ) # Convert to dict search_dict = original_search.to_dict() # Create new Search from dicts new_search = Search( where=search_dict["filter"] if search_dict["filter"] else None, rank=search_dict["rank"] if search_dict["rank"] else None, limit=search_dict["limit"], select=search_dict["select"], ) # Get new dict new_dict = new_search.to_dict() # Compare with float32 tolerance for KNN queries # Use the same comparison function as test_rank_round_trip def compare_search_dicts(d1, d2): if isinstance(d1, dict) and isinstance(d2, dict): # Special handling for rank field with KNN if "rank" in d1 and "rank" in d2: rank1, rank2 = d1["rank"], d2["rank"] if isinstance(rank1, dict) and isinstance(rank2, dict): if "$knn" in rank1 and "$knn" in rank2: knn1, knn2 = rank1["$knn"], rank2["$knn"] if "query" in knn1 and "query" in knn2: q1 = np.array(knn1["query"], dtype=np.float32) q2 = np.array(knn2["query"], dtype=np.float32) if not np.allclose(q1, q2): return False # Compare other KNN fields for key in knn1: if key != "query" and knn1[key] != knn2.get(key): return False # Compare other fields in the dict for key in d1: if key != "rank" and d1[key] != d2.get(key): return False return True # Normal dict comparison if set(d1.keys()) != set(d2.keys()): return False for key in d1: if isinstance(d1[key], dict) and isinstance(d2[key], dict): if not compare_search_dicts(d1[key], d2[key]): return False elif d1[key] != d2[key]: return False return True else: return d1 == d2 assert compare_search_dicts(new_dict, search_dict) def test_search_round_trip_with_group_by(self): """Test Search round-trip with group_by.""" from chromadb.execution.expression.plan import Search from chromadb.execution.expression.operator import Key, GroupBy, MinK original = Search( where=Key("status") == "active", group_by=GroupBy( keys=[Key("category")], aggregate=MinK(keys=[Key.SCORE], k=3), ), ) # Verify group_by round-trip search_dict = original.to_dict() assert search_dict["group_by"]["keys"] == ["category"] assert search_dict["group_by"]["aggregate"] == { "$min_k": {"keys": ["#score"], "k": 3} } # Reconstruct and compare group_by restored = Search(group_by=GroupBy.from_dict(search_dict["group_by"])) assert restored.to_dict()["group_by"] == search_dict["group_by"]