// Package main demonstrates IPMatcher, which checks IPv4 addresses against a
// list of rules written as single addresses ("10.0.0.1"), CIDR blocks
// ("11.0.0.0/24"), or inclusive dash-separated ranges
// ("12.0.0.1-12.0.255.255").
package main

import (
	"encoding/binary"
	"errors"
	"fmt"
	"net"
	"strings"
)

// netIPToUint32 converts a 4-byte address or mask to its big-endian uint32
// value. For a 16-byte input only the last four bytes are used, which is the
// IPv4 part of an IPv4-mapped IPv6 address. It panics if ip has fewer than
// four bytes, so callers must validate their input first (e.g. with To4).
func netIPToUint32(ip []byte) uint32 {
	if len(ip) == net.IPv6len {
		return binary.BigEndian.Uint32(ip[12:16])
	}
	return binary.BigEndian.Uint32(ip)
}

// IPRange is an inclusive range of IPv4 addresses in uint32 form.
type IPRange struct {
	Start uint32
	End   uint32
}

// IPMatcher reports whether IPv4 addresses fall inside any configured range.
// The zero value matches nothing; call Set to load rules.
type IPMatcher struct {
	ranges []IPRange
}

// int2ip converts a uint32 back into a 4-byte net.IP.
func int2ip(nn uint32) net.IP {
	ip := make(net.IP, net.IPv4len)
	binary.BigEndian.PutUint32(ip, nn)
	return ip
}

// parseIPv4 parses s (surrounding whitespace ignored) as an IPv4 address,
// including the IPv4-mapped IPv6 form, and returns it as a uint32.
// Anything else, including plain IPv6 addresses, is an error.
func parseIPv4(s string) (uint32, error) {
	ip := net.ParseIP(strings.TrimSpace(s)).To4()
	if ip == nil {
		return 0, fmt.Errorf("invalid IPv4 address %q", s)
	}
	return netIPToUint32(ip), nil
}

// parseRule converts one rule into an IPRange. A rule containing "/" is a
// CIDR block, one containing "-" is an inclusive start-end range, and
// anything else is a single address. Non-IPv4 input is rejected rather than
// truncated to its last four bytes.
func parseRule(rule string) (IPRange, error) {
	switch {
	case strings.Contains(rule, "/"):
		_, network, err := net.ParseCIDR(strings.TrimSpace(rule))
		if err != nil {
			return IPRange{}, err
		}
		ip4 := network.IP.To4()
		// An IPv6 prefix (or an IPv4-mapped one with a 128-bit mask) cannot be
		// expressed as a 32-bit range; truncating it would match wrong addresses.
		if _, bits := network.Mask.Size(); ip4 == nil || bits != 8*net.IPv4len {
			return IPRange{}, errors.New("not an IPv4 network")
		}
		// ParseCIDR already masks network.IP, so it is the first address;
		// setting every host bit yields the last one.
		start := netIPToUint32(ip4)
		return IPRange{Start: start, End: start | ^netIPToUint32(network.Mask)}, nil

	case strings.Contains(rule, "-"):
		// SplitN keeps any extra "-" in the second part so it fails to parse
		// instead of being silently dropped.
		parts := strings.SplitN(rule, "-", 2)
		start, err := parseIPv4(parts[0])
		if err != nil {
			return IPRange{}, err
		}
		end, err := parseIPv4(parts[1])
		if err != nil {
			return IPRange{}, err
		}
		if start > end {
			return IPRange{}, fmt.Errorf("range start %s is after end %s", int2ip(start), int2ip(end))
		}
		return IPRange{Start: start, End: end}, nil

	default:
		addr, err := parseIPv4(rule)
		if err != nil {
			return IPRange{}, err
		}
		return IPRange{Start: addr, End: addr}, nil
	}
}

// Set replaces the matcher's rules. Each rule is a single IPv4 address, a
// CIDR block, or an inclusive "start-end" range. If any rule is invalid, Set
// returns an error naming it and leaves the previously configured rules
// untouched.
func (ipm *IPMatcher) Set(rules []string) error {
	ranges := make([]IPRange, 0, len(rules))
	for _, rule := range rules {
		r, err := parseRule(rule)
		if err != nil {
			return fmt.Errorf("parsing rule %q: %w", rule, err)
		}
		ranges = append(ranges, r)
	}
	// TODO: sort and merge overlapping ranges so Matches can binary-search.
	ipm.ranges = ranges
	return nil
}

// Matches reports whether ip lies inside any configured range. Addresses that
// are not IPv4 (including nil and non-mapped IPv6) never match. The lookup is
// a linear scan over the ranges.
func (ipm *IPMatcher) Matches(ip net.IP) bool {
	ip4 := ip.To4()
	if ip4 == nil {
		return false
	}
	incomingIP := netIPToUint32(ip4)
	for _, ipRange := range ipm.ranges {
		if ipRange.Start <= incomingIP && incomingIP <= ipRange.End {
			return true
		}
	}
	return false
}

func main() {
	matcher := IPMatcher{}
	if err := matcher.Set([]string{
		"11.0.0.0/24",
		"10.0.0.1",
		"12.0.0.1-12.0.255.255",
	}); err != nil {
		fmt.Println(err)
		return
	}

	assert(matcher.Matches(net.ParseIP("10.0.0.2")), false, "10.0.0.2 should not match")
	assert(matcher.Matches(net.ParseIP("10.0.0.1")), true, "10.0.0.1 should match")
	assert(matcher.Matches(net.ParseIP("11.0.0.50")), true, "11.0.0.50 should match")
	assert(matcher.Matches(net.ParseIP("11.0.0.255")), true, "11.0.0.255 should match")
	assert(matcher.Matches(net.ParseIP("12.0.0.0")), false, "12.0.0.0 should not match")
	assert(matcher.Matches(net.ParseIP("12.0.50.1")), true, "12.0.50.1 should match")
	// ::b00:32 ends in the bytes of 11.0.0.50; an IPv6 address must not match.
	assert(matcher.Matches(net.ParseIP("::b00:32")), false, "::b00:32 (IPv6) should not match")
	fmt.Println("done!")
}

// assert prints description when result differs from expected.
func assert(result, expected bool, description string) {
	if result != expected {
		fmt.Println(description)
	}
}