This commit is contained in:
Simon Alibert
2024-06-06 15:23:49 +00:00
parent bdc0ebd36a
commit a86f387554
3 changed files with 102 additions and 8 deletions

View File

@@ -98,7 +98,7 @@ class RangeRandomSharpness(Transform):
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def make_transforms(cfg):
def make_transforms(cfg, to_dtype: torch.dtype = torch.float32):
transforms_list = [
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
@@ -118,4 +118,6 @@ def make_transforms(cfg):
transforms_list, p=transforms_weights, n_subset=cfg.max_num_transforms, random_order=cfg.random_order
)
return v2.Compose([transforms, v2.ToDtype(torch.float32, scale=True)])
# return transforms
# return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=True)])
return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=False)])