DopeorNope 개발일지

RuntimeError: expected scalar type Half but found Float 에러 해결방법 본문

에러 노트

RuntimeError: expected scalar type Half but found Float 에러 해결방법

DopeorNope 2024. 1. 2. 03:11

Axolotl에서 Pre-train 과정에서 다음과 같은 에러가 발생함

 

RuntimeError: expected scalar type Half but found Float

 

FP16으로 내가 불러와서 지금 Half(원래는 32비트 이기때문에, FP16은 Half임)로  불러왔지만,

 

데이터가 지금 float이기 때문에 문제가 발생함

 

 

이럴경우 이와 같이 문제를 해결하면 됨.

 

 

train.py에서 아래와 같이 trainer.train 하는 곳에서 'cuda'알아서 오토캐스트 해주면 데이터에 맞게 알아서 해결됨.

 

  if cfg.flash_optimum:
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=True, enable_mem_efficient=True
        ):
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else:
        with torch.autocast("cuda"): # 추가함
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)