102 lines
2.3 KiB
Go
102 lines
2.3 KiB
Go
|
package main
|
||
|
|
||
|
import (
|
||
|
"encoding/binary"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"strings"
|
||
|
)
|
||
|
|
||
|
// netIPToUint32 packs four address bytes into a uint32, most significant byte
// first (so 10.0.0.1 becomes 0x0A000001).
//
// A 16-byte slice is treated as the IPv4-in-IPv6 form that net.IP uses, so
// only its last four bytes are read; any other length reads the first four
// bytes. The 16-byte prefix is not checked, so a genuine IPv6 address is
// silently truncated. Input shorter than four bytes (for example the nil
// net.IP returned by a failed net.ParseIP) panics. Callers are expected to
// validate first, e.g. with net.IP.To4.
func netIPToUint32(ip []byte) uint32 {
	b := ip
	if len(b) == 16 {
		b = b[12:]
	}
	return binary.BigEndian.Uint32(b)
}
|
||
|
|
||
|
// IPRange is a span of IPv4 addresses, each encoded as a big-endian uint32
// (10.0.0.1 is 0x0A000001). Both bounds are inclusive; a single address is
// represented with Start == End.
type IPRange struct {
	Start, End uint32
}
|
||
|
|
||
|
type IPMatcher struct {
|
||
|
ranges []IPRange
|
||
|
}
|
||
|
|
||
|
// int2ip unpacks nn, most significant byte first, into a 4-byte net.IP
// (0x0A000001 becomes 10.0.0.1). For 4-byte input it is the inverse of
// netIPToUint32.
func int2ip(nn uint32) net.IP {
	var octets [4]byte
	binary.BigEndian.PutUint32(octets[:], nn)
	return net.IP(octets[:])
}
|
||
|
|
||
|
func (ipm *IPMatcher) Set(rules []string) error {
|
||
|
var ranges []IPRange
|
||
|
|
||
|
for _, rule := range rules {
|
||
|
if strings.Contains(rule, "/") {
|
||
|
// CIDR
|
||
|
_, network, err := net.ParseCIDR(rule)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("error parsing rules for %s: %s", rule, err)
|
||
|
}
|
||
|
ranges = append(ranges, IPRange{
|
||
|
Start: netIPToUint32(network.IP),
|
||
|
End: (netIPToUint32(network.IP) & netIPToUint32(network.Mask)) | (netIPToUint32(network.Mask) ^ 0xffffffff),
|
||
|
})
|
||
|
} else if strings.Contains(rule, "-") {
|
||
|
// Range
|
||
|
parts := strings.Split(rule, "-")
|
||
|
ranges = append(ranges, IPRange{
|
||
|
Start: netIPToUint32(net.ParseIP(parts[0])),
|
||
|
End: netIPToUint32(net.ParseIP(parts[1])),
|
||
|
})
|
||
|
} else {
|
||
|
// Single IP
|
||
|
ranges = append(ranges,
|
||
|
IPRange{
|
||
|
Start: netIPToUint32(net.ParseIP(rule)),
|
||
|
End: netIPToUint32(net.ParseIP(rule)),
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// TODO: Sort ranges
|
||
|
// TODO: Find overlapping ranges
|
||
|
ipm.ranges = ranges
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (ipm *IPMatcher) Matches(ip net.IP) bool {
|
||
|
incomingIp := netIPToUint32(ip)
|
||
|
for _, ipRange := range ipm.ranges {
|
||
|
if (ipRange.Start <= incomingIp) && (incomingIp <= ipRange.End) {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
func main() {
|
||
|
matcher := IPMatcher{}
|
||
|
matcher.Set([]string{
|
||
|
"11.0.0.0/24",
|
||
|
"10.0.0.1",
|
||
|
"12.0.0.1-12.0.255.255",
|
||
|
})
|
||
|
|
||
|
assert(matcher.Matches(net.ParseIP("10.0.0.2")), false, "10.0.0.2 should not match")
|
||
|
assert(matcher.Matches(net.ParseIP("10.0.0.1")), true, "10.0.0.1 should match")
|
||
|
assert(matcher.Matches(net.ParseIP("11.0.0.50")), true, "11.0.0.50 should match")
|
||
|
assert(matcher.Matches(net.ParseIP("11.0.0.255")), true, "11.0.0.255 should match")
|
||
|
assert(matcher.Matches(net.ParseIP("12.0.0.0")), false, "12.0.0.0 should not match")
|
||
|
assert(matcher.Matches(net.ParseIP("12.0.50.1")), true, "12.0.50.1 should match")
|
||
|
fmt.Println("done!")
|
||
|
}
|
||
|
|
||
|
// assert prints description on its own line to standard output when result
// differs from expected. It prints nothing when they agree, and it never
// stops execution.
func assert(result, expected bool, description string) {
	if result == expected {
		return
	}
	fmt.Println(description)
}
|