diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 6574d49fb132..dbb51475ec21 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -80,7 +80,6 @@ def __init__(self, model, config): self.mp_group = config.tensor_parallel.tp_group self.mpu = config.tensor_parallel.mpu - #self._validate_args(self.mpu, config.replace_with_kernel_inject) self.quantize_merge_count = 1 self.quantization_scales = None @@ -300,29 +299,6 @@ def _init_quantization_setting(self, quantization_setting): f"mlp_extra_grouping = {self.mlp_extra_grouping}, " f"quantize_groups = {self.quantize_groups}", [0]) - # TODO: remove this function and add this functionality to pydantic config checking - def _validate_args(self, mpu, replace_with_kernel_inject): - # TODO: to support SD pipeline we need to avoid this check for now - if replace_with_kernel_inject and not isinstance(self.module, Module): - raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}") - if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1: - raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}") - - if mpu: - methods = ["get_model_parallel_group", "get_data_parallel_group"] - for method in methods: - if not hasattr(mpu, method): - raise ValueError(f"mpu is missing {method}") - if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)): - raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}") - - supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16] - if self._config.dtype not in supported_dtypes: - raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}") - - if self.injection_dict is not None and not isinstance(self.injection_dict, dict): - raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}") - def load_model_with_checkpoint(self, r_module): self.mp_replace = ReplaceWithTensorSlicing( mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)