diff --git a/internal/simdgen/xed.go b/internal/simdgen/xed.go
index 004a815..4436043 100644
--- a/internal/simdgen/xed.go
+++ b/internal/simdgen/xed.go
@@ -190,6 +190,10 @@
// complicated.
action, ok := actionEncoding[op.Action]
if !ok {
+ if strings.HasPrefix(op.Name, "EMX_BROADCAST") {
+ // BROADCAST appears to contain an obsolete operand.
+ return nil, nil
+ }
return nil, fmt.Errorf("unknown action %q", op.Action)
}
common := operandCommon{action: action}
@@ -249,7 +253,9 @@
if err != nil {
return unify.Tuple{}, unify.Tuple{}, err
}
- ops = append(ops, op)
+ if op != nil {
+ ops = append(ops, op)
+ }
}
// XED doesn't encode the size of mask operands. If there are mask operands,
@@ -272,6 +278,7 @@
var masks []int
var rSizes, wSizes, sizes []vecShape
allMasks := true
+ hasWMask := false
for i, op := range ops {
action := op.common().action
if _, ok := op.(operandMask); ok {
@@ -281,6 +288,9 @@
if action.r == r || action.w == w {
masks = append(masks, i)
}
+ if action.w {
+ hasWMask = true
+ }
} else {
allMasks = false
if reg, ok := op.(operandVReg); ok {
@@ -320,11 +330,17 @@
}
return nil
}
- return fmt.Errorf("cannot infer mask size: no register operands")
+ return fmt.Errorf("cannot infer mask size: no register operands: %+v", operands)
}
shape, ok := singular(sizes)
if !ok {
- return fmt.Errorf("cannot infer mask size: multiple register sizes %v", sizes)
+ if !hasWMask && len(wSizes) == 1 && len(masks) == 1 {
+ // This pattern looks like a predicate mask, so its shape should align with the
+ // output. TODO: verify this is a safe assumption.
+ shape = wSizes[0]
+ } else {
+ return fmt.Errorf("cannot infer mask size: multiple register sizes %v", sizes)
+ }
}
for _, i := range masks {
m := ops[i].(operandMask)
@@ -407,6 +423,10 @@
return 256, true
case strings.HasPrefix(rhs, "ZMM_"):
return 512, true
+ case strings.HasPrefix(rhs, "GPR64_"), strings.HasPrefix(rhs, "VGPR64_"):
+ return 64, true
+ case strings.HasPrefix(rhs, "GPR32_"), strings.HasPrefix(rhs, "VGPR32_"):
+ return 32, true
}
return 0, false
}
@@ -475,6 +495,19 @@
// These just use the lower INT8 in each 16 bit field.
// As far as I can tell, "2I8" is a typo.
return scalarBaseInt, 8, true
+ case "2u16", "2U16":
+ // Some VPDP* instructions have it.
+ // TODO: does "z" mean it has zeroing?
+ return scalarBaseUint, 16, true
+ case "2i16", "2I16":
+ // Some VPDP* instructions have it.
+ return scalarBaseInt, 16, true
+ case "4u8", "4U8":
+ // Some VPDP* instructions have it.
+ return scalarBaseUint, 8, true
+ case "4i8", "4I8":
+ // Some VPDP* instructions have it.
+ return scalarBaseInt, 8, true
}
// The rest follow a simple pattern.