【PyTorch】教程:torch.nn.Hardtanh
迪丽瓦拉
2024-05-29 15:11:38
0

torch.nn.Hardtanh

原型

CLASS torch.nn.Hardtanh(min_val=- 1.0, max_val=1.0, inplace=False, min_value=None, max_value=None)

参数

  • min_val ([float]) – 线性区域的最小值,默认为 -1
  • max_val ([float]) – 线性区域的最大值,默认为 1
  • inplace ([bool]) – 默认为 False

定义

HardTanh(x)={max_valif x>max_val min_valif x \text{ max\_val } \\ \text{min\_val} & \text{ if } x < \text{ min\_val } \\ x & \text{ otherwise } \\ \end{cases} HardTanh(x)=⎩⎧​max_valmin_valx​ if x> max_val  if x< min_val  otherwise ​

在这里插入图片描述

代码

import torch
import torch.nn as nnm = nn.Hardtanh(-2, 2)
input = torch.randn(2)
output = m(input)
print("input: ", input)    # input:  tensor([2.1926, 0.2211])
print("output: ", output)  # output:  tensor([2.0000, 0.2211])

【参考】

Hardtanh — PyTorch 1.13 documentation

相关内容