diff --git a/search_r1/search/retrieval_server.py b/search_r1/search/retrieval_server.py index b30b6ce..9839ebd 100644 --- a/search_r1/search/retrieval_server.py +++ b/search_r1/search/retrieval_server.py @@ -125,6 +125,10 @@ class Encoder: query_emb = query_emb.detach().cpu().numpy() query_emb = query_emb.astype(np.float32, order="C") + + del inputs, output + torch.cuda.empty_cache() + return query_emb class BaseRetriever: @@ -266,6 +270,10 @@ class DenseRetriever(BaseRetriever): results.extend(batch_results) scores.extend(batch_scores) + + del batch_emb, batch_scores, batch_idxs, query_batch, flat_idxs, batch_results + torch.cuda.empty_cache() + if return_score: return results, scores else: