clean up retrieval cache

This commit is contained in:
PeterGriffinJin
2025-03-23 14:33:14 +00:00
parent 6272082a64
commit f5204213d3

View File

@@ -125,6 +125,10 @@ class Encoder:
query_emb = query_emb.detach().cpu().numpy()
query_emb = query_emb.astype(np.float32, order="C")
del inputs, output
torch.cuda.empty_cache()
return query_emb
class BaseRetriever:
@@ -266,6 +270,10 @@ class DenseRetriever(BaseRetriever):
results.extend(batch_results)
scores.extend(batch_scores)
del batch_emb, batch_scores, batch_idxs, query_batch, flat_idxs, batch_results
torch.cuda.empty_cache()
if return_score:
return results, scores
else: