diff --git a/src/accelerate/commands/estimate.py b/src/accelerate/commands/estimate.py index 2cd731b2221..52e18b55387 100644 --- a/src/accelerate/commands/estimate.py +++ b/src/accelerate/commands/estimate.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch from huggingface_hub import model_info from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError @@ -62,7 +63,8 @@ def check_has_model(error): def create_empty_model(model_name: str, library_name: str, trust_remote_code: bool = False, access_token: str = None): """ - Creates an empty model from its parent library on the `Hub` to calculate the overall memory consumption. + Creates an empty model in full precision from its parent library on the `Hub` to calculate the overall memory + consumption. Args: model_name (`str`): @@ -120,7 +122,8 @@ def create_empty_model(model_name: str, library_name: str, trust_remote_code: bo break if value is not None: constructor = getattr(transformers, value) - model = constructor.from_config(config, trust_remote_code=trust_remote_code) + # we need to pass the dtype, otherwise it is going to use the torch_dtype that is saved in the config + model = constructor.from_config(config, torch_dtype=torch.float32, trust_remote_code=trust_remote_code) elif library_name == "timm": if not is_timm_available(): raise ImportError(