adjust unit tests for test_save_load_float16#12500
Conversation
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
|
@a-r-r-o-w @DN6 pls help review, thx! |
|
Not sure I understand the issue here. This specific T5 module is kept in fp32 on purpose, why forcing a fp16 cast in the test? |
|
@regisss Hi, the purpose of this test case is to compare the output of pipelines using fp16 dtype( |
| if hasattr(component, "half"): | ||
| # Although all components for pipe_loaded should be float16 now, some submodules still use fp32, like in https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/models/t5/modeling_t5.py#L783, so we need to do the conversion again manally to align with the datatype we use in pipe exactly | ||
| component = component.to(torch_device).half() |
There was a problem hiding this comment.
This doesn't seem right at all. torch_dtype should be able to take care of it. I just ran it on my GPU for SD and it worked fine.
There was a problem hiding this comment.
Hi @sayakpaul , I tested on A100, and when I print pipe_loaded.text_encoder.encoder.block[0].layer[1].DenseReluDense.wo.weight.dtype in L1455 , it returns torch.float32, not torch.float16, and the max_diff in L1456 is np.float16(0.0004883). When we apply this PR to align excatly with the behavior in pipe, the max_diff is 0. I think it's better to adjust the test case to make the output comparison of pipe and pipe_loaded apple to apple. WDYT?
There was a problem hiding this comment.
My point is torch_dtype in from_pretrained() should be enough for the model to be in fp16. Setting it with half() after loading the model in the FP16 torch_dtype seems erroneous to me.
I also ran the test on an A100, and it wasn't a problem. So, I am not sure if this test fix is correct at all.
There was a problem hiding this comment.
I printed pipe_loaded.text_encoder.encoder.block[0].layer[1].DenseReluDense.wo.weight.dtype after pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16), and it returns torch.float32, it is root caused in L783, so I manualy add .half() to pipe_loaded, although it looks a bit wierd... On A100, the tolerance value is OK, but I think from the fundamentals perspective, the output from pipelines loaded from former saved should be exactly the same, that is the max_diff should be 0, right?
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
|
@sayakpaul Hi, I adjusted the test code to pass |
| supports_dduf = False | ||
|
|
||
| def get_dummy_components(self): | ||
| def get_dummy_components(self, dtype=torch.float32): |
There was a problem hiding this comment.
pls refer to L246-L256 (Sorry I only found Chinese version for this explanation). Using torch.Tensor.to method will convert all weights, while using torch_dtype parameter with from_pretrained will preserve layers in _keep_in_fp32_modules. For wan models, all components of pipe will be fp16 dtype while it is not the case for pipe_loaded. Here I override test_save_load_float16 function seperately for wan models.
sayakpaul
left a comment
There was a problem hiding this comment.
I am honestly not sure about the changes introduced in this PR. We have gone over multiple comments and so far, I haven't been able to manually verify myself the failures this PR tries to solve.
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
…into wan-pipeline
| pass | ||
|
|
||
| @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") | ||
| def test_save_load_float16(self, expected_max_diff=1e-2): |
There was a problem hiding this comment.
I still don't know then how on my end the tests are passing.
There was a problem hiding this comment.
I think it should be related with the input. When I set all the seed in get_dummy_components to 1, the max_diff on A100 is np.float16(0.2366), and when set seed to 42, the output will be all nan value. After this PR, the max_diff will all be 0 for all the seed
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
| # Use from_pretrained with a tiny model to ensure proper dtype handling | ||
| # This ensures _keep_in_fp32_modules and _skip_layerwise_casting_patterns are respected | ||
| transformer = WanTransformer3DModel.from_pretrained( | ||
| "Kaixuanliu/tiny-random-wan-transformer", |
There was a problem hiding this comment.
pls replace my model space. We have to use from_pretrained here to make all the submodules' dtype correctly loaded.
| qk_norm="rms_norm_across_heads", | ||
| rope_max_seq_len=32, | ||
| transformer_2 = WanTransformer3DModel.from_pretrained( | ||
| "Kaixuanliu/tiny-random-wan-transformer", |
|
CC @yao-matrix |
|
I think it would be best to modify the def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
for name, module in components.items():
module = module.to(torch_device)
# Account for components with _keep_in_fp32_modules
if hasattr(module, "_keep_in_fp32_modules"):
for name, param in module.named_parameters():
if any(
module_to_keep_in_fp32 in name.split(".")
for module_to_keep_in_fp32 in module._keep_in_fp32_modules
):
param.data = param.data.to(torch.float32)
else:
param.data = param.data.to(torch.float16)
elif hasattr(module, "half"):
components[name] = module.to(torch_device).half() |
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
|
@DN6 Hi, I think this is a good advice, looks much better now. Have updated the code following your advice, thx! |
|
@sayakpaul @DN6 Hi, can this PR be merged now? |
|
@sayakpaul Hi, the failed CI cases should have nothing to do with this PR, can you help merge? |
|
Yeah will merge shortly. Thanks for your contributions! |

When we run unit test like
pytest -rA tests/pipelines/wan/test_wan_22.py::Wan22PipelineFastTests::test_save_load_float16, we found that the pipeline runs w/ all fp16 datatype, but after save and reload, some parts of text-encoder inpipe_loadeduses fp32, although we set torch_dtype to fp16 explicitly. Deep investigation found that the root cause is here: L783. Here we made an adjustment to the test case to manually add thecomponent = component.to(torch_device).half()operation to align excatly with the behavior inpipe