Skip to content

Commit

Permalink
Add support for saturation mode
Browse files Browse the repository at this point in the history
  • Loading branch information
mfasi committed May 30, 2024
1 parent f38d309 commit 3edb25d
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 107 deletions.
68 changes: 52 additions & 16 deletions mex/cpfloat.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ void mexFunction(int nlhs,
fpopts->precision = 11;
fpopts->emin = -14;
fpopts->emax = 15;
fpopts->subnormal = CPFLOAT_SUBN_USE;
fpopts->explim = CPFLOAT_EXPRANGE_TARG;
fpopts->round = CPFLOAT_RND_NE;
fpopts->subnormal = CPFLOAT_SAT_NO;
fpopts->subnormal = CPFLOAT_SUBN_USE;

fpopts->flip = CPFLOAT_SOFTERR_NO;
fpopts->p = 0.5;

fpopts->bitseed = NULL;
fpopts->randseedf = NULL;
fpopts->randseed = NULL;
Expand All @@ -54,6 +57,7 @@ void mexFunction(int nlhs,
/* Parse second argument and populate fpopts structure. */
if (nrhs > 1) {
bool is_subn_rnd_default = false;
bool is_saturation_default = false;
if(!mxIsEmpty(prhs[1]) && !mxIsStruct(prhs[1])) {
mexErrMsgIdAndTxt("cpfloat:invalidstruct",
"Second argument must be a struct.");
Expand All @@ -62,7 +66,7 @@ void mexFunction(int nlhs,

if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
/* Use default format, for compatibility with chop. */
/* Set default format, for compatibility with chop. */
strcpy(fpopts->format, "h");
else if (mxGetClassID(tmp) == mxCHAR_CLASS)
strcpy(fpopts->format, mxArrayToString(tmp));
Expand All @@ -80,6 +84,7 @@ void mexFunction(int nlhs,
fpopts->precision = 4;
fpopts->emin = -6;
fpopts->emax = 8;
is_saturation_default = true;
} else if (!strcmp(fpopts->format, "q52") ||
!strcmp(fpopts->format, "fp8-e5m2") ||
!strcmp(fpopts->format, "E5M2")) {
Expand Down Expand Up @@ -161,6 +166,31 @@ void mexFunction(int nlhs,
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->round = *((double *)mxGetData(tmp));
}
tmp = mxGetField(prhs[1], 0, "saturation");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->saturation = CPFLOAT_SAT_NO;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->saturation = *((double *)mxGetData(tmp));
} else {
if (is_saturation_default)
fpopts->saturation = CPFLOAT_SAT_USE; /* Default for E4M3. */
else
fpopts->saturation = CPFLOAT_SAT_NO;
}
tmp = mxGetField(prhs[1], 0, "subnormal");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->subnormal = CPFLOAT_SUBN_USE;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->subnormal = *((double *)mxGetData(tmp));
} else {
if (is_subn_rnd_default)
fpopts->subnormal = CPFLOAT_SUBN_RND; /* Default for bfloat16. */
else
fpopts->subnormal = CPFLOAT_SUBN_USE;
}

tmp = mxGetField(prhs[1], 0, "flip");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
Expand Down Expand Up @@ -288,10 +318,11 @@ void mexFunction(int nlhs,

/* Allocate and return second output. */
if (nlhs > 1) {
const char* field_names[] = {"format", "params", "subnormal", "round",
"flip", "p", "explim"};
const char* field_names[] = {"format", "params", "explim",
"round", "saturation", "subnormal",
"flip", "p"};
mwSize dims[2] = {1, 1};
plhs[1] = mxCreateStructArray(2, dims, 7, field_names);
plhs[1] = mxCreateStructArray(2, dims, 8, field_names);
mxSetFieldByNumber(plhs[1], 0, 0, mxCreateString(fpopts->format));

mxArray *outparams = mxCreateDoubleMatrix(1,3,mxREAL);
Expand All @@ -301,30 +332,35 @@ void mexFunction(int nlhs,
outparamsptr[2] = fpopts->emax;
mxSetFieldByNumber(plhs[1], 0, 1, outparams);

mxArray *outsubnormal = mxCreateDoubleMatrix(1,1,mxREAL);
double *outsubnormalptr = mxGetData(outsubnormal);
outsubnormalptr[0] = fpopts->subnormal;
mxSetFieldByNumber(plhs[1], 0, 2, outsubnormal);
mxArray *outexplim = mxCreateDoubleMatrix(1, 1, mxREAL);
double *outexplimptr = mxGetData(outexplim);
outexplimptr[0] = fpopts->explim;
mxSetFieldByNumber(plhs[1], 0, 2, outexplim);

mxArray *outround = mxCreateDoubleMatrix(1,1,mxREAL);
double *outroundptr = mxGetData(outround);
outroundptr[0] = fpopts->round;
mxSetFieldByNumber(plhs[1], 0, 3, outround);

mxArray *outsaturation = mxCreateDoubleMatrix(1,1,mxREAL);
double *outsaturationptr = mxGetData(outsaturation);
outsaturationptr[0] = fpopts->saturation;
mxSetFieldByNumber(plhs[1], 0, 4, outsaturation);

mxArray *outsubnormal = mxCreateDoubleMatrix(1,1,mxREAL);
double *outsubnormalptr = mxGetData(outsubnormal);
outsubnormalptr[0] = fpopts->subnormal;
mxSetFieldByNumber(plhs[1], 0, 5, outsubnormal);

mxArray *outflip = mxCreateDoubleMatrix(1,1,mxREAL);
double *outflipptr = mxGetData(outflip);
outflipptr[0] = fpopts->flip;
mxSetFieldByNumber(plhs[1], 0, 4, outflip);
mxSetFieldByNumber(plhs[1], 0, 6, outflip);

mxArray *outp = mxCreateDoubleMatrix(1,1,mxREAL);
double *outpptr = mxGetData(outp);
outpptr[0] = fpopts->p;
mxSetFieldByNumber(plhs[1], 0, 5, outp);

mxArray *outexplim = mxCreateDoubleMatrix(1,1,mxREAL);
double *outexplimptr = mxGetData(outexplim);
outexplimptr[0] = fpopts->explim;
mxSetFieldByNumber(plhs[1], 0, 6, outexplim);
mxSetFieldByNumber(plhs[1], 0, 7, outp);

}
if (nlhs > 2)
Expand Down
16 changes: 11 additions & 5 deletions mex/cpfloat.m
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@
% the target format, respectively. The default value of this field is
% the vector [11,-14,15].
%
% * The scalar FPOPTS.subnormal specifies the support for subnormal numbers.
% The target floating-point format will not support subnormal numbers if
% this field is set to 0, and will support them otherwise. The default value
% for this field is 0 if the target format is 'bfloat16' and 1 otherwise.
%
% * The scalar FPOPTS.explim specifies the support for an extended exponent
% range. The target floating-point format will have the exponent range of
% the storage format ('single' or 'double', depending on the class of X) if
Expand All @@ -63,6 +58,17 @@
% Any other value results in no rounding. The default value for this field
% is 1.
%
% * The scalar FPOPTS.saturation specifies whether saturation arithmetic is in
% use. On overflow, the target floating-point format will use the largest
% representable floating-point if this field is set to 0, and infinity
% otherwise. The default value for this field is 1 if the target format is
% 'E4M3' and 1 otherwise.

% * The scalar FPOPTS.subnormal specifies the support for subnormal numbers.
% The target floating-point format will not support subnormal numbers if
% this field is set to 0, and will support them otherwise. The default value
% for this field is 0 if the target format is 'bfloat16' and 1 otherwise.
%
% * The scalar FPOPTS.flip specifies whether the function should simulate the
% occurrence of a single bit flip striking the floating-point representation
% of elements of Y. Possible values are:
Expand Down
39 changes: 31 additions & 8 deletions src/cpfloat_definitions.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
*
* + @ref cpfloat_explim_t,
* + @ref cpfloat_rounding_t,
* + @ref cpfloat_saturation_t,
* + @ref cpfloat_softerr_t,
* + @ref cpfloat_subnormal_t,
*
Expand Down Expand Up @@ -88,6 +89,16 @@ typedef enum {
CPFLOAT_NO_RND = 8,
} cpfloat_rounding_t;

/**
* @brief Saturation modes available in CPFloat.
*/
typedef enum {
/** Use standard arithmetic. */
CPFLOAT_SAT_NO = 0,
/** Use saturation arithmetic. */
CPFLOAT_SAT_USE = 1,
} cpfloat_saturation_t;

/**
* @brief Soft fault simulation modes available in CPFloat.
*/
Expand Down Expand Up @@ -214,14 +225,6 @@ typedef struct {
* exponent is larger than the maximum allowed by the storage format.
*/
cpfloat_exponent_t emax;
/**
* @brief Support for subnormal numbers in target format.
*
* @details Subnormal numbers are supported if this field is set to
* `CPFLOAT_SUBN_USE` and rounded to a normal number using the current
* rounding mode if it is set to `CPFLOAT_SUBN_RND`.
*/
cpfloat_subnormal_t subnormal;
/**
* @brief Support for extended exponents in target format.
*
Expand Down Expand Up @@ -256,6 +259,24 @@ typedef struct {
* those in the list above is specified.
*/
cpfloat_rounding_t round;
/**
* @brief Support for subnormal numbers in target format.
*
* @details Subnormal numbers are supported if this field is set to
* `CPFLOAT_SUBN_USE` and rounded to a normal number using the current
* rounding mode if it is set to `CPFLOAT_SUBN_RND`.
*/
cpfloat_saturation_t saturation;
/**
* @brief Support for subnormal numbers in target format.
*
* @details Subnormal numbers are supported if this field is set to
* `CPFLOAT_SUBN_USE` and rounded to a normal number using the current
* rounding mode if it is set to `CPFLOAT_SUBN_RND`.
*/
cpfloat_subnormal_t subnormal;

/* Bit flips. */
/**
* @brief Support for soft errors.
*
Expand All @@ -281,6 +302,8 @@ typedef struct {
* contain a number in the interval [0,1].
*/
double p;

/* Internal: state of pseudo-random number generator. */
/**
* @brief Internal state of pseudo-random number generator for single bits.
*
Expand Down
33 changes: 18 additions & 15 deletions src/cpfloat_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ typedef struct {
cpfloat_subnormal_t subnormal;
cpfloat_rounding_t round;
FPTYPE ftzthreshold;
FPTYPE ofvalue;
FPTYPE xmin;
FPTYPE xmax;
FPTYPE xbnd;
Expand Down Expand Up @@ -317,15 +318,17 @@ static inline FPPARAMS COMPUTE_GLOBAL_PARAMS(const optstruct *fpopts,
FPTYPE xmin = ldexp(1., emin); /* Smallest pos. normal. */
FPTYPE xmins = ldexp(1., emin-precision+1); /* Smallest pos. subnormal. */
FPTYPE ftzthreshold = (fpopts->subnormal == CPFLOAT_SUBN_USE) ? xmins : xmin;

FPTYPE xmax = ldexp(1., emax) * (2-ldexp(1., 1-precision));
FPTYPE xbnd = ldexp(1., emax) * (2-ldexp(1., -precision));
FPTYPE ofvalue = (fpopts->saturation == CPFLOAT_SAT_USE) ? xmax : INFINITY;

/* Bitmasks. */
INTTYPE leadmask = FULLMASK << (DEFPREC-precision); /* To keep. */
INTTYPE trailmask = leadmask ^ FULLMASK; /* To discard. */

FPPARAMS params = {precision, emax, emin, fpopts->subnormal, fpopts->round,
ftzthreshold, xmin, xmax, xbnd,
ftzthreshold, ofvalue, xmin, xmax, xbnd,
leadmask, trailmask, NULL, NULL};

return params;
Expand Down Expand Up @@ -401,7 +404,7 @@ static inline void UPDATE_LOCAL_PARAMS(const FPTYPE *A,
else \
*(x) = FPOF(SIGN(y) | INTOFCONST(p->ftzthreshold)); \
} else if (ABS(y) >= p->xbnd) { /* Overflow */ \
*(x) = FPOF(SIGN(y) | INTOFCONST(INFINITY)); \
*(x) = FPOF(SIGN(y) | INTOFCONST(p->ofvalue)); \
} else { \
UPDATE_LOCAL_PARAMS(y, p, lp); \
*(x) = FPOF((INTOF(y) + \
Expand All @@ -425,7 +428,7 @@ static inline void UPDATE_LOCAL_PARAMS(const FPTYPE *A,
else \
*(x) = FPOF(SIGN(y)); \
} else if (ABS(y) > p->xbnd) { /* Overflow */ \
*(x) = FPOF(SIGN(y) | INTOFCONST(INFINITY)); \
*(x) = FPOF(SIGN(y) | INTOFCONST(p->ofvalue)); \
} else { \
UPDATE_LOCAL_PARAMS(y, p, lp); \
*(x) = FPOF((INTOF(y) + (lp->trailmask>>1)) & lp->leadmask); \
Expand All @@ -450,7 +453,7 @@ static inline void UPDATE_LOCAL_PARAMS(const FPTYPE *A,
else \
*(x) = FPOF(SIGN(y) | INTOFCONST(p->ftzthreshold)); \
} else if (ABS(y) >= p->xbnd) { /* Overflow */ \
*(x) = FPOF(SIGN(y) | INTOFCONST(INFINITY)); \
*(x) = FPOF(SIGN(y) | INTOFCONST(p->ofvalue)); \
} else { \
UPDATE_LOCAL_PARAMS(y, p, lp); \
INTTYPE LSB = ((INTOF(y) >> (DEFPREC-lp->precision)) & INTCONST(1)) \
Expand All @@ -477,11 +480,11 @@ static inline void UPDATE_LOCAL_PARAMS(const FPTYPE *A,
*(x) = *(y) > 0 ? p->ftzthreshold : 0; \
} else if (ABS(y) > p->xmax) { /* Overflow */ \
if (*(y) > p->xmax) \
*(x) = INFINITY; \
else if (*(y) < -p->xmax && *(y) != -INFINITY) \
*(x) = p->ofvalue; \
else if (*(y) < -p->xmax && *(y) != -p->ofvalue) \
*(x) = -p->xmax; \
else /* *(y) == -INFINITY */ \
*(x) = -INFINITY; \
*(x) = -p->ofvalue; \
} else { \
UPDATE_LOCAL_PARAMS(y, p, lp); \
if (SIGN(y) == 0) /* Add ulp if x is positive. */ \
Expand Down Expand Up @@ -509,11 +512,11 @@ static inline void UPDATE_LOCAL_PARAMS(const FPTYPE *A,
*(x) = *(y) >= 0 ? 0 : -p->ftzthreshold; \
} else if (ABS(y) > p->xmax) { /* Overflow */ \
if (*(y) < -p->xmax) \
*(x) = -INFINITY; \
else if (*(y) > p->xmax && *(y) != INFINITY) \
*(x) = -p->ofvalue; \
else if (*(y) > p->xmax && *(y) != p->ofvalue) \
*(x) = p->xmax; \
else /* *(y) == INFINITY */ \
*(x) = INFINITY; \
*(x) = p->ofvalue; \
} else { \
UPDATE_LOCAL_PARAMS(y, p, lp); \
if (SIGN(y)) /* Subtract ulp if x is positive. */ \
Expand All @@ -532,7 +535,7 @@ static inline void UPDATE_LOCAL_PARAMS(const FPTYPE *A,
#define RD_TWD_ZERO_SCALAR_OTHER_EXP(x, y, p, lp) \
if (ABS(y) < p->ftzthreshold) { /* Underflow */ \
*(x) = FPOF(SIGN(y)); \
} else if (ABS(y) > p->xmax && ABS(y) != INFINITY) { /* Overflow */ \
} else if (ABS(y) > p->xmax && ABS(y) != p->ofvalue) { /* Overflow */ \
*(x) = FPOF(SIGN(y) | INTOFCONST(p->xmax)); \
} else { \
UPDATE_LOCAL_PARAMS(y, p, lp); \
Expand Down Expand Up @@ -581,17 +584,17 @@ static inline void UPDATE_LOCAL_PARAMS(const FPTYPE *A,
*(x) = *(y); \
} \
if (ABS(x) >= p->xbnd) /* Overflow */ \
*(x) = FPOF(SIGN(y) | INTOFCONST(INFINITY));
*(x) = FPOF(SIGN(y) | INTOFCONST(p->ofvalue));

/* Stochastic rounding with equal probabilities. */
#define RS_EQUI_SCALAR(x, y, p, lp) \
UPDATE_LOCAL_PARAMS(y, p, lp); \
if (ABS(y) < p->ftzthreshold && *(y) != 0) { /* Underflow */ \
randombit = GENBIT(p->BITSEED); \
*(x) = FPOF(SIGN(y) | INTOFCONST(randombit ? p->ftzthreshold : 0)); \
} else if (ABS(y) > p->xmax && ABS(y) != INFINITY) { /* Overflow */ \
} else if (ABS(y) > p->xmax && ABS(y) != p->ofvalue) { /* Overflow */ \
randombit = GENBIT(p->BITSEED); \
*(x) = FPOF(SIGN(y) | INTOFCONST(randombit ? INFINITY : p->xmax)); \
*(x) = FPOF(SIGN(y) | INTOFCONST(randombit ? p->ofvalue : p->xmax)); \
} else if ((INTOF(y) & lp->trailmask)) { /* Not exactly representable. */ \
randombit = GENBIT(p->BITSEED); \
*(x) = FPOF(INTOF(y) & lp->leadmask); \
Expand All @@ -604,7 +607,7 @@ static inline void UPDATE_LOCAL_PARAMS(const FPTYPE *A,
#define RO_SCALAR(x, y, p, lp) \
if (ABS(y) < p->ftzthreshold && *(y) != 0) { /* Underflow */ \
*(x) = FPOF(SIGN(y) | INTOFCONST(p->ftzthreshold)); \
} else if (ABS(y) > p->xmax && ABS(y) != INFINITY) { /* Overflow */ \
} else if (ABS(y) > p->xmax && ABS(y) != p->ofvalue) { /* Overflow */ \
*(x) = FPOF(SIGN(y) | INTOFCONST(p->xmax)); \
} else { \
UPDATE_LOCAL_PARAMS(y, p, lp); \
Expand Down
Loading

0 comments on commit 3edb25d

Please sign in to comment.