import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeoutError
from typing import Collection, List, Optional

import pinecone
from sentry_sdk import start_span
# SearchProvider, SearchMatch, and the @instrument decorator are assumed to be
# defined elsewhere in this codebase.
class PineconeProvider(SearchProvider):
    # One future is submitted per query vector; the pool is sized so that
    # MAX_CONCURRENT_QUERIES batches of QUERY_BATCH_SIZE vectors can run at once.
    QUERY_BATCH_SIZE = 20
    MAX_CONCURRENT_QUERIES = 4

    def __init__(self, api_key: Optional[str], index_name: str):
        self._api_key: Optional[str] = None
        self._client: Optional[pinecone.GRPCIndex] = None
        self._executor = ThreadPoolExecutor(
            max_workers=self.QUERY_BATCH_SIZE * self.MAX_CONCURRENT_QUERIES
        )
        self.index_name = index_name
        if api_key is not None:
            self.set_api_key(api_key)
    def set_api_key(self, api_key: Optional[str]) -> None:
        # (Re)initialize the Pinecone client; clearing the key disables the provider.
        self._api_key = api_key
        if self._api_key:
            pinecone.init(api_key=self._api_key)
            self._client = pinecone.GRPCIndex(self.index_name)
        else:
            self._client = None
    @instrument
    def _query(
        self,
        queries: List[List[float]],
        product_sets: Collection[str],
        num_neighbors: int,
        return_metadata: bool,
        return_features: bool
    ) -> List[List[SearchMatch]]:
        if self._client is None:
            raise RuntimeError('API key is not set')
        # Tag the Sentry span so slow or oversized queries are easy to spot in traces.
        with start_span(op='pinecone.query') as span:
            span.set_tag('pinecone.queries', len(queries))
            span.set_tag('pinecone.top_k', num_neighbors)
            span.set_tag('pinecone.include_metadata', return_metadata)
            span.set_tag('pinecone.include_values', return_features)
            response = self._client.query(
                queries=queries,
                top_k=num_neighbors,
                include_metadata=return_metadata,
                include_values=return_features
            )
        # Flatten the Pinecone response into SearchMatch objects, one list per query.
        transformed_response = []
        for result in response.results:
            matches = []
            for match in result.matches:
                search_match = SearchMatch(
                    crop_id=match.id,
                    score=match.score
                )
                if return_metadata:
                    search_match.metadata = match.metadata
                if return_features and match.values:
                    search_match.features = match.values
                matches.append(search_match)
            transformed_response.append(matches)
        return transformed_response
    @instrument
    def query(
        self,
        queries: List[List[float]],
        product_sets: Collection[str],
        num_neighbors: int,
        return_metadata: bool,
        return_features: bool
    ) -> List[List[SearchMatch]]:
        # Fan out: submit each query vector as its own single-vector request so the
        # thread pool can run them in parallel.
        futures = {
            i: self._executor.submit(
                self._query,
                [query_vector],
                product_sets,
                num_neighbors=num_neighbors,
                return_metadata=return_metadata,
                return_features=return_features
            )
            for i, query_vector in enumerate(queries)
        }
        # Poll the outstanding futures until every query has a result, backing off
        # briefly between passes so results are collected as they finish.
        chunk_responses = {}
        while len(chunk_responses) != len(queries):
            for i, future in futures.items():
                if i in chunk_responses:
                    continue
                try:
                    # _query was given a single vector, so take its only result list.
                    chunk_responses[i] = future.result(timeout=1)[0]
                except FutureTimeoutError:
                    continue
            if len(chunk_responses) != len(queries):
                time.sleep(1)
        # Reassemble the per-query results in the original order.
        return [chunk_responses[i] for i in range(len(queries))]
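
# Illustrative usage only (not part of the provider): the API key, index name,
# and embedding below are hypothetical placeholders, and the printed attributes
# follow the SearchMatch fields set in _query above.
if __name__ == '__main__':
    provider = PineconeProvider(api_key='YOUR_PINECONE_API_KEY', index_name='crops')
    # A single toy 4-dimensional embedding; a real index expects vectors of the
    # dimensionality it was created with.
    embedding = [0.1, 0.2, 0.3, 0.4]
    results = provider.query(
        queries=[embedding],
        product_sets=[],
        num_neighbors=5,
        return_metadata=True,
        return_features=False
    )
    for match in results[0]:
        print(match.crop_id, match.score)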